Shortcuts

Source code for libcom.fos_score.fos_score_prediction

import torch
import torchvision
from libcom.utils.model_download import download_pretrained_model
from libcom.utils.process_image import *
from libcom.utils.environment import *
from libcom.fos_score.source.config import Config
from libcom.fos_score.source.networks import StudentModel, SingleScaleD
import torch 
import os
import math
import datetime
import torchvision.transforms as transforms
from torchvision.utils import save_image
from einops import rearrange
# Directory that contains this module; FOSScoreModel joins it with the
# relative path of the bundled YAML config file.
cur_dir   = os.path.dirname(os.path.abspath(__file__))
# Root directory under which 'pretrained_models/<model_type>.pth' is looked up.
# The LIBCOM_MODEL_DIR environment variable overrides the default (cur_dir).
model_dir = os.environ.get('LIBCOM_MODEL_DIR',cur_dir)
# Model types accepted by FOSScoreModel (checked in its constructor).
model_set = ['FOS_D'] 

class FOSScoreModel:
    """
    Foreground object search score prediction model.

    Args:
        device (int | str | torch.device): gpu id, forwarded to ``check_gpu_device``.
        model_type (str): predefined model type, must be one of ``model_set``.
        kwargs (dict): other parameters for building model (stored in ``self.option``).

    Examples:
        >>> from libcom.utils.process_image import make_image_grid
        >>> from libcom import FOSScoreModel
        >>> import cv2
        >>> import torch
        >>> task_name = 'fos_score_prediction'
        >>> MODEL_TYPE = 'FOS_D'
        >>> background = '../tests/source/background/f80eda2459853824_m09g1w_b2413ec8_11.png'
        >>> fg_bbox = [175, 82, 309, 310] # x1,y1,x2,y2
        >>> foreground = '../tests/source/foreground/f80eda2459853824_m09g1w_b2413ec8_11.png'
        >>> foreground_mask = '../tests/source/foreground_mask/f80eda2459853824_m09g1w_b2413ec8_11.png'
        >>> composite_image = '../tests/source/composite/f80eda2459853824_m09g1w_b2413ec8_11.png'
        >>> net = FOSScoreModel(device=0, model_type=MODEL_TYPE)
        >>> score = net(background, foreground, fg_bbox, foreground_mask=foreground_mask)
        >>> grid_img = make_image_grid([background, foreground, composite_image], text_list=[f'fos_score:{score:.2f}'])
        >>> cv2.imshow('fos_score_demo', grid_img)

    Expected result:

    .. image:: _static/image/fos_score_result3.jpg
        :scale: 50 %
    .. image:: _static/image/fos_score_result2.jpg
        :scale: 50 %

    """

    def __init__(self, device=0, model_type='FOS_D', **kwargs):
        assert model_type in model_set, f'Not implementation for {model_type}'
        self.model_type = model_type
        self.option = kwargs
        # NOTE: despite the attribute names these are not the ImageNet
        # statistics; mean = std = 0.5 maps ToTensor's [0, 1] output to [-1, 1].
        self.IMAGE_NET_MEAN = [0.5, 0.5, 0.5]
        self.IMAGE_NET_STD = [0.5, 0.5, 0.5]
        config_file = os.path.join(cur_dir, 'source/config/config_rfosd.yaml')
        self.cfg = Config(config_file)

        weight_path = os.path.join(model_dir, 'pretrained_models', '{}.pth'.format(self.model_type))
        download_pretrained_model(weight_path)

        self.device = check_gpu_device(device)
        self.build_pretrained_model(weight_path)
        self.build_data_transformer()

    def build_pretrained_model(self, weight_path):
        """
        Build pretrained model from path of weight.

        The state dict is loaded on CPU with ``weights_only=True`` and
        ``strict=True``; the model is then moved to ``self.device`` and put in
        eval mode.
        """
        model = SingleScaleD(False)
        model.load_state_dict(torch.load(weight_path, map_location='cpu', weights_only=True), strict=True)
        self.model = model.to(self.device).eval()

    def build_data_transformer(self):
        """Create the resize / to-tensor / normalize pipeline applied to PIL inputs."""
        self.image_size = self.cfg.image_size
        self.transformer = transforms.Compose([
            transforms.Resize(self.cfg.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGE_NET_MEAN, std=self.IMAGE_NET_STD)
        ])

    @staticmethod
    def _normalize_bbox(bbox):
        """Return ``bbox`` as a tuple of four Python ints ``(x1, y1, x2, y2)``.

        Float or numpy-scalar coordinates are rounded to the nearest integer so
        they can be used for slicing and as an OpenCV size.

        Raises:
            ValueError: if ``bbox`` does not contain exactly four values.
        """
        coords = [int(round(float(v))) for v in bbox]
        if len(coords) != 4:
            raise ValueError(f'bbox must contain 4 values [x1, y1, x2, y2], got {len(coords)}')
        return tuple(coords)

    def inputs_preprocess(self, background_image, foreground_image, foreground_mask, bbox):
        """
        Paste the foreground into the background and crop the region around it.

        Args:
            background_image (str | numpy.ndarray): background, read with ``read_image_opencv``.
            foreground_image (str | numpy.ndarray): foreground, read with ``read_image_opencv``.
            foreground_mask (str | numpy.ndarray | None): optional foreground mask; every
                foreground pixel whose mask value is not 255 is replaced by gray (128).
            bbox (list): [x1, y1, x2, y2] location of the foreground in the background.

        Returns:
            PIL.Image.Image: RGB crop of the composite, enlarged around ``bbox``
            as computed by ``get_crop_bbox``.

        Raises:
            ValueError: if ``bbox`` is empty or lies completely outside the background.
        """
        # assumes read_image_opencv returns BGR arrays (OpenCV convention) - TODO confirm
        bg = cv2.cvtColor(read_image_opencv(background_image), cv2.COLOR_BGR2RGB)
        bg_h, bg_w = bg.shape[:2]
        fg = cv2.cvtColor(read_image_opencv(foreground_image), cv2.COLOR_BGR2RGB)

        if foreground_mask is not None:
            fg_mask = read_image_opencv(foreground_mask)
            # A mask that is already single-channel cannot go through BGR2GRAY.
            if fg_mask.ndim == 3:
                fg_mask = cv2.cvtColor(fg_mask, cv2.COLOR_BGR2GRAY)
            assert fg.shape[:2] == fg_mask.shape, \
                f'foreground {fg.shape[:2]} and mask {fg_mask.shape} differ in size'
            fg[fg_mask != 255] = 128

        x1, y1, x2, y2 = self._normalize_bbox(bbox)
        if x2 <= x1 or y2 <= y1:
            raise ValueError(f'bbox must satisfy x1 < x2 and y1 < y2, got {bbox}')
        rs_fg = cv2.resize(fg, (x2 - x1, y2 - y1))

        # Paste only the part of the box that falls inside the background so a
        # box touching or crossing the border does not break the assignment.
        px1, py1 = max(x1, 0), max(y1, 0)
        px2, py2 = min(x2, bg_w), min(y2, bg_h)
        if px2 <= px1 or py2 <= py1:
            raise ValueError(f'bbox {bbox} does not overlap the {bg_w}x{bg_h} background')
        bg[py1:py2, px1:px2] = rs_fg[py1 - y1:py2 - y1, px1 - x1:px2 - x1]

        cx1, cy1, cx2, cy2 = self.get_crop_bbox((x1, y1, x2, y2), bg_h, bg_w)
        scale_comp = np.ascontiguousarray(bg[cy1:cy2, cx1:cx2])
        return Image.fromarray(scale_comp, mode="RGB")

    def get_crop_bbox(self, bbox, bg_h, bg_w):
        """
        Enlarge ``bbox`` around its center and clamp it to the background.

        Each side is extended by ``int(side * (sqrt(2) - 1) / 2)`` pixels in
        both directions, i.e. width and height grow by a factor of about sqrt(2).

        Args:
            bbox (list): [x1, y1, x2, y2]; non-integer coordinates are rounded.
            bg_h (int): background height.
            bg_w (int): background width.

        Returns:
            tuple: (new_x1, new_y1, new_x2, new_y2) as Python ints.
        """
        x1, y1, x2, y2 = self._normalize_bbox(bbox)
        add_w = int((x2 - x1) * (math.sqrt(2) - 1) / 2)
        add_h = int((y2 - y1) * (math.sqrt(2) - 1) / 2)
        new_y1 = max(0, y1 - add_h)
        new_y2 = min(bg_h, y2 + add_h)
        new_x1 = max(0, x1 - add_w)
        new_x2 = min(bg_w, x2 + add_w)
        return new_x1, new_y1, new_x2, new_y2

    def prepare_input_encoders(self, background_image, foreground_image, bounding_box):
        """
        Build the tensor inputs (background, foreground, boxes) for an encoder model.

        This helper is not used by ``__call__``.

        Args:
            background_image: background image; must be accepted by ``numpy.array``
                and is converted to a PIL image after the box has been filled.
            foreground_image (numpy.ndarray): foreground image, padded to a square.
            bounding_box (list): [x1, y1, x2, y2] in background pixel coordinates.

        Returns:
            dict: 'bg' and 'fg' image tensors plus 'query_box' and 'crop_box',
            each with a leading batch dimension of 1. The boxes are rescaled by
            ``cfg.image_size`` relative to the background size and rounded.
        """
        bounding_box = self._normalize_bbox(bounding_box)
        background_image = fill_box_with_specified_pixel(background_image, bounding_box, self.cfg.fill_pixel)
        bg_t = self.transformer(background_image)
        bg_ori_w, bg_ori_h = background_image.size

        ori_x1, ori_y1, ori_x2, ori_y2 = bounding_box
        query_box = torch.tensor([ori_x1 / bg_ori_w, ori_y1 / bg_ori_h,
                                  ori_x2 / bg_ori_w, ori_y2 / bg_ori_h]) * self.cfg.image_size
        query_box = torch.round(query_box).float()

        # Same enlarged, clamped box that inputs_preprocess crops with.
        new_x1, new_y1, new_x2, new_y2 = self.get_crop_bbox(bounding_box, bg_ori_h, bg_ori_w)
        new_box = torch.tensor([new_x1 / bg_ori_w, new_y1 / bg_ori_h,
                                new_x2 / bg_ori_w, new_y2 / bg_ori_h]) * self.cfg.image_size
        new_box = torch.round(new_box).float()

        pad_fg = padding_to_square(foreground_image.copy(), self.cfg.pad_pixel)
        pad_fg = Image.fromarray(pad_fg)
        fg_t = self.transformer(pad_fg)

        sample = {
            'bg': bg_t.unsqueeze(0),
            'fg': fg_t.unsqueeze(0),
            'query_box': query_box.unsqueeze(0),
            'crop_box': new_box.unsqueeze(0),
        }
        return sample

    def prepare_input_disc(self, composite_image):
        """Apply the data transformer to the composite and add a batch dimension."""
        comp = self.transformer(composite_image).unsqueeze(0)
        return comp

    def preprocess_image(self, image):
        """Convert an ``h w c`` array to a ``1 c h w`` float tensor on ``self.device``."""
        image = torch.from_numpy(image).float() / 127.5 - 1  # [-1, 1] for 0..255 input
        image = rearrange(image, "h w c -> 1 c h w")
        image = image.to(self.device)
        return image

    @torch.no_grad()
    def __call__(self, background_image, foreground_image, bounding_box, foreground_mask=None):
        """
        Predicting the compatibility score between the given background and the given foreground.

        Args:
            background_image (str | numpy.ndarray): The path to background image or the background image in ndarray form.
            foreground_image (str | numpy.ndarray): The path to foreground image or the foreground image in ndarray form.
            bounding_box (list): The bounding box which indicates the foreground's location in the background. [x1, y1, x2, y2].
            foreground_mask (str | numpy.ndarray): Mask of foreground image which indicates the foreground object region in the foreground image. default: None.

        Returns:
            fos_score (float): Predicted compatibility score between the given background image and the given foreground image.

        """
        composite_image = self.inputs_preprocess(background_image, foreground_image, foreground_mask, bounding_box)
        composite_image = self.prepare_input_disc(composite_image).to(self.device)
        # The model returns a pair; the last entry of its second output is the score.
        _, score = self.model(composite_image)
        output = score[-1].item()
        return output
def fill_box_with_specified_pixel(bg_im, query_box, fill_value):
    """Return a PIL copy of ``bg_im`` whose ``query_box`` region is set to ``fill_value``.

    Args:
        bg_im: image accepted by ``numpy.array``.
        query_box: sequence whose first four entries are x1, y1, x2, y2.
        fill_value: value written into every pixel inside the box.
    """
    left, top = query_box[0:2]
    right, bottom = query_box[2:4]
    canvas = np.array(bg_im)
    canvas[top:bottom, left:right] = fill_value
    return Image.fromarray(canvas)


def padding_to_square(src_img, pad_pixel=255):
    """Pad the shorter side of ``src_img`` with ``pad_pixel`` so it becomes square.

    The padding is split between both sides of the shorter axis, with the
    extra pixel (odd difference) going to the bottom / right. An image that
    is already square is returned unchanged (same object).
    """
    height, width = src_img.shape[:2]
    if height == width:
        return src_img

    # Only one of the two differences is positive; the other axis gets no padding.
    extra_rows = max(width - height, 0)
    extra_cols = max(height - width, 0)
    top, odd_row = divmod(extra_rows, 2)
    left, odd_col = divmod(extra_cols, 2)

    pad_width = [(top, top + odd_row), (left, left + odd_col)]
    if len(src_img.shape) == 3:
        # Leave the channel axis untouched.
        pad_width.append((0, 0))
    return np.pad(src_img, pad_width, 'constant', constant_values=pad_pixel)