# Source code for libcom.inharmonious_region_localization.inharmonious_region_localization
import torch
import torchvision
from libcom.utils.model_download import download_pretrained_model
from libcom.utils.process_image import *
from libcom.utils.environment import *
import os
import torchvision.transforms as transforms
from .source.madis_net import *
# Directory that contains this module file.
cur_dir = os.path.split(os.path.abspath(__file__))[0]
# Root directory for model weights: the LIBCOM_MODEL_DIR environment variable
# when it is set, otherwise the directory of this module.
model_dir = os.getenv('LIBCOM_MODEL_DIR', cur_dir)
# Names accepted as ``model_type`` by InharmoniousLocalizationModel.
model_set = ['IHDRNet']
class InharmoniousLocalizationModel:
    """
    Inharmonious region localization model.

    Args:
        device (int | str | torch.device): gpu id; the value is passed to ``check_gpu_device``. Default: 0.
        model_type (str): predefined model type, must be one of ``model_set``.
        kwargs (dict): other parameters for building model (stored in ``self.option``).

    Examples:
        >>> from libcom import InharmoniousLocalizationModel
        >>> import cv2
        >>> import numpy as np
        >>> net = InharmoniousLocalizationModel(device=0)
        >>> comp_img1 = '../tests/source/composite/comp1_MadisNet.png'
        >>> inharmonious_localization1 = net(comp_img1)
        >>> comp_img2 = '../tests/source/composite/comp2_MadisNet.png'
        >>> inharmonious_localization2 = net(comp_img2)
        >>> cv2.imwrite('../docs/_static/image/inharmonious_localization_result1.jpg', np.concatenate([cv2.resize(cv2.imread(comp_img1),(256,256)), inharmonious_localization1],axis=1))
        >>> cv2.imwrite('../docs/_static/image/inharmonious_localization_result2.jpg', np.concatenate([cv2.resize(cv2.imread(comp_img2),(256,256)), inharmonious_localization2],axis=1))

    Expected result:

    .. image:: _static/image/inharmonious_localization_result3_4.jpg
    """

    def __init__(self, device=0, model_type='IHDRNet', **kwargs):
        """
        Fetch the pretrained weights, then build the network and the input transform.

        Args:
            device (int | str | torch.device): gpu id, resolved through ``check_gpu_device``.
            model_type (str): predefined model type, must be one of ``model_set``.
            kwargs (dict): other parameters, kept in ``self.option``.
        """
        assert model_type in model_set, f'Not implementation for {model_type}'
        self.model_type = model_type
        self.option = kwargs
        # Both weight files are expected under <model_dir>/pretrained_models;
        # download_pretrained_model is called on each path before loading.
        weight_path_g = os.path.join(model_dir, 'pretrained_models', 'Inharmonious_G.pth')
        download_pretrained_model(weight_path_g)
        weight_path_ihdrnet = os.path.join(model_dir, 'pretrained_models', 'IHDRNet.pth')
        download_pretrained_model(weight_path_ihdrnet)
        self.device = check_gpu_device(device)
        self.build_pretrained_model(weight_path_g, weight_path_ihdrnet)
        self.build_data_transformer()

    @staticmethod
    def _load_state_dict(weight_path):
        """Return the ``'state_dict'`` entry of the checkpoint at ``weight_path``, loaded onto CPU."""
        # map_location='cpu' keeps loading independent of the device the
        # checkpoint was saved from (e.g. a CUDA checkpoint on a CPU-only
        # host); the assembled model is moved to the target device afterwards.
        checkpoint = torch.load(weight_path, map_location='cpu', weights_only=True)
        return checkpoint['state_dict']

    def build_pretrained_model(self, weight_path_g, weight_path_ihdrnet):
        """
        Build ``MadisNet``, load both sub-network checkpoints and store the model.

        Args:
            weight_path_g (str): checkpoint path for the ``g`` sub-module.
            weight_path_ihdrnet (str): checkpoint path for the ``ihdrnet`` sub-module.
        """
        model = MadisNet()
        model.g.load_state_dict(self._load_state_dict(weight_path_g))
        model.ihdrnet.load_state_dict(self._load_state_dict(weight_path_ihdrnet))
        # Move to the selected device and switch to evaluation mode.
        self.MadisNet_model = model.to(self.device).eval()

    def build_data_transformer(self):
        """Create the input transform: ``ToTensor`` followed by per-channel normalization."""
        self.transformer = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                (0.485, 0.456, 0.406),
                (0.229, 0.224, 0.225)
            )
        ])

    def inputs_preprocess(self, composite_image):
        """
        Turn a composite image into a network input tensor.

        Args:
            composite_image (str | numpy.ndarray): The path to composite image or the composite image in ndarray form.

        Returns:
            torch.Tensor: float tensor on ``self.device`` with a leading batch dimension of 1.
        """
        img = read_image_opencv(composite_image)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # The image is always resized to 256x256 before being fed to the network.
        img = cv2.resize(img, (256, 256))
        img = self.transformer(img).float().to(self.device).unsqueeze(0)
        return img

    def outputs_postprocess(self, outputs):
        """
        Convert a network output tensor into a channel-last numpy image.

        Args:
            outputs (torch.Tensor): network output; a leading batch dimension is squeezed out when the tensor is 4-D.

        Returns:
            numpy.ndarray: values scaled by 255 and clamped to [0, 255], channel order swapped by ``cv2.COLOR_RGB2BGR``.
        """
        if len(outputs.shape) == 4:
            outputs = outputs.squeeze(0)
        # Channel-first -> channel-last, rescale to the 0..255 range, then hand over to numpy.
        outputs = (torch.clamp(255.0 * outputs.permute(1, 2, 0), 0, 255)).cpu().numpy()
        outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
        return outputs

    @torch.no_grad()
    def __call__(self, composite_image):
        """
        Given a composite image, predict the mask of the inharmonious region.

        Args:
            composite_image (str | numpy.ndarray): The path to composite image or the composite image in ndarray form.

        Returns:
            inharmonious_mask (np.array): The inharmonious mask.
        """
        img = self.inputs_preprocess(composite_image)
        # Only the first element of the model output is post-processed.
        outputs = self.MadisNet_model(img)[0]
        preds = self.outputs_postprocess(outputs)
        return preds