# Source code for libcom.harmony_score.harmony_score_prediction
import torch
import torchvision
from libcom.utils.model_download import download_pretrained_model
from libcom.utils.process_image import *
from libcom.utils.environment import *
from libcom.harmony_score.source.bargainnet import StyleEncoder
import torch
import os
import torchvision.transforms as transforms
import math
# Absolute path of the directory that holds this module.
cur_dir = os.path.split(os.path.abspath(__file__))[0]
# Root folder under which model weights are looked up: the LIBCOM_MODEL_DIR
# environment variable when it is set, otherwise this module's directory.
model_dir = os.getenv('LIBCOM_MODEL_DIR', cur_dir)
# Model types accepted by HarmonyScoreModel.
model_set = ['BargainNet']
class HarmonyScoreModel:
    """
    Harmony score prediction model.

    Scores how compatible the foreground and the background of a composite
    image are by comparing the two style codes that a pretrained
    ``StyleEncoder`` (BargainNet) produces for the same image under the
    background mask and under the foreground mask.

    Args:
        device (int | str | torch.device): device specifier such as a gpu id; passed to ``check_gpu_device``. Default: 0.
        model_type (str): predefined model type, must be listed in ``model_set``. Default: 'BargainNet'.
        kwargs (dict): other parameters for building model; kept in ``self.option``.

    Examples:
        >>> from libcom import HarmonyScoreModel
        >>> from libcom.utils.process_image import make_image_grid
        >>> import cv2
        >>> net = HarmonyScoreModel(device=0, model_type='BargainNet')
        >>> test_dir = '../tests/harmony_score_prediction/'
        >>> img_names = ['vaulted-cellar-247391_inharm.jpg', 'ameland-5651866_harm.jpg']
        >>> vis_list,scores = [], []
        >>> for img_name in img_names:
        ...     comp_img = test_dir + 'composite/' + img_name
        ...     comp_mask = test_dir + 'composite_mask/' + img_name
        ...     score = net(comp_img, comp_mask)
        ...     vis_list += [comp_img, comp_mask]
        ...     scores.append(score)
        >>> grid_img = make_image_grid(vis_list, text_list=[f'harmony_score:{scores[0]:.2f}', 'composite-mask', f'harmony_score:{scores[1]:.2f}', 'composite-mask'])
        >>> cv2.imwrite('../docs/_static/image/harmonyscore_result1.jpg', grid_img)

    Expected result:

    .. image:: _static/image/harmonyscore_result1.jpg
        :scale: 38 %

    """

    # Decay rate of the exponential that maps a style distance to a score
    # in (0, 1]; a distance of 0 gives a score of exactly 1.
    _DISTANCE_DECAY = 0.04212

    def __init__(self, device=0, model_type='BargainNet', **kwargs):
        """
        Download the pretrained weights if needed and build the model.

        Args:
            device (int | str | torch.device): device specifier passed to ``check_gpu_device``.
            model_type (str): must be one of the entries of ``model_set``.
            kwargs (dict): extra options, stored unchanged in ``self.option``.
        """
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert model_type in model_set, f'Not implementation for {model_type}'
        self.model_type = model_type
        self.option = kwargs
        weight_path = os.path.join(model_dir, 'pretrained_models', 'BargainNet.pth')
        download_pretrained_model(weight_path)
        self.device = check_gpu_device(device)
        self.build_pretrained_model(weight_path)
        self.build_data_transformer()

    def build_pretrained_model(self, weight_path):
        """
        Create the style encoder, load its weights and switch it to eval mode.

        Args:
            weight_path (str): path of the checkpoint passed to ``torch.load``.
        """
        model = StyleEncoder(style_dim=16)
        # Load on CPU first; the model is moved to the target device afterwards.
        state_dict = torch.load(weight_path, map_location='cpu', weights_only=True)
        model.load_state_dict(state_dict)
        self.model = model.to(self.device).eval()

    def build_data_transformer(self):
        """
        Build the transforms applied to the image and to the masks.

        Both pipelines resize to ``image_size`` x ``image_size`` and convert to
        a tensor; only the image pipeline additionally normalizes each of the
        three channels with mean 0.5 and std 0.5.
        """
        self.image_size = 256
        target_size = (self.image_size, self.image_size)
        self.transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
        ])

    def inputs_preprocess(self, composite_image, composite_mask):
        """
        Read the composite image and its mask and convert them to model inputs.

        Args:
            composite_image (str | numpy.ndarray): composite image or its path, read with ``read_image_opencv``.
            composite_mask (str | numpy.ndarray): foreground mask or its path, read with ``read_mask_opencv``.

        Returns:
            tuple: ``(image, bg_mask, fg_mask)`` tensors, each with a leading batch dimension of size 1 and moved to ``self.device``.
        """
        image = read_image_opencv(composite_image)
        mask = read_mask_opencv(composite_mask)
        # NOTE(review): the division assumes mask values span 0-255 — confirm
        # against read_mask_opencv.
        fg_mask = mask.astype(np.float32) / 255.
        # The background mask is the complement of the foreground mask.
        bg_mask = 1 - fg_mask
        fg_mask = self.mask_transform(Image.fromarray(fg_mask))
        fg_mask = fg_mask.unsqueeze(0).to(self.device)
        bg_mask = self.mask_transform(Image.fromarray(bg_mask))
        bg_mask = bg_mask.unsqueeze(0).to(self.device)
        # Swap the channel order from BGR to RGB before building the PIL image.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(Image.fromarray(image))
        image = image.unsqueeze(0).to(self.device)
        return image, bg_mask, fg_mask

    def outputs_postprocess(self, bg_style, fg_style) -> float:
        """
        Convert a pair of style codes into a harmony score.

        Args:
            bg_style (torch.Tensor): style code obtained with the background mask.
            fg_style (torch.Tensor): style code obtained with the foreground mask.

        Returns:
            float: ``exp(-0.04212 * distance)``, where ``distance`` is the Euclidean distance between the two codes; equals 1 for identical codes and decreases as the distance grows.
        """
        eucl_dist = self.Euclidean_distance(bg_style, fg_style)
        # convert distance to harmony level which lies in 0 and 1
        harm_level = math.exp(-self._DISTANCE_DECAY * eucl_dist)
        return harm_level

    def Euclidean_distance(self, vec1, vec2):
        """
        Compute the Euclidean (L2) distance between two tensors.

        Args:
            vec1 (torch.Tensor): first tensor.
            vec2 (torch.Tensor): second tensor, broadcastable against ``vec1``.

        Returns:
            numpy floating scalar: square root of the sum of squared element-wise differences over all elements.
        """
        # detach() lets this helper also accept tensors that track gradients,
        # for which Tensor.numpy() would otherwise raise.
        vec1 = vec1.detach().cpu().numpy()
        vec2 = vec2.detach().cpu().numpy()
        dist = np.sqrt(np.sum((vec1 - vec2)**2))
        return dist

    @torch.no_grad()
    def __call__(self, composite_image, composite_mask) -> float:
        """
        Predicting the compatibility score between background and foreground in the given composite image.

        Args:
            composite_image (str | numpy.ndarray): The path to composite image or the composite image in ndarray form.
            composite_mask (str | numpy.ndarray): Mask of composite image which indicates the foreground object region in the composite image.

        Returns:
            harmony_score (float): Predicted harmony score within [0,1] between background region and foreground region of the given composite image. Larger harmony score implies more harmonious composite image.

        """
        im, bg_mask, fg_mask = self.inputs_preprocess(composite_image, composite_mask)
        # The same image is encoded twice, once per region mask.
        bg_style = self.model(im, bg_mask)
        fg_style = self.model(im, fg_mask)
        preds = self.outputs_postprocess(bg_style, fg_style)
        return preds