# Shortcuts
#
# Source code for libcom.fopa_heat_map.fopa_heat_map

import csv
import random
import shutil
import torch
from libcom.utils.model_download import download_pretrained_model
from libcom.utils.process_image import *
from libcom.utils.environment import *
import torch 
import os
import torchvision.transforms as transforms

import numpy as np
from PIL import Image
import libcom.fopa_heat_map.source.network.ObPlaNet_simple as network
from libcom.fopa_heat_map.source.data.OBdataset import make_composite_PIL
from libcom.fopa_heat_map.source.prepare_multi_fg_scales import prepare_multi_fg_scales
from libcom.fopa_heat_map.source.data.all_transforms import Compose, JointResize

# Absolute directory that contains this module.
cur_dir = os.path.dirname(os.path.abspath(__file__))
# Root folder under which 'pretrained_models' is looked up: taken from the
# LIBCOM_MODEL_DIR environment variable, falling back to this module's folder.
model_dir = os.getenv('LIBCOM_MODEL_DIR', cur_dir)
# Model types accepted by FOPAHeatMapModel.
model_set = ['fopa']


class FOPAHeatMapModel:
    """
    Generate a heatmap for a pair of scaled foreground and background.

    Args:
        device (str | torch.device): gpu id
        model_type (str): predefined model type

    Examples:
        >>> test_set = get_test_list_fopa_heatmap()
        >>> result_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'results', task_name)
        >>> if os.path.exists(result_dir):
        >>>     shutil.rmtree(result_dir)
        >>> os.makedirs(result_dir, exist_ok=True)
        >>> os.makedirs(os.path.join(result_dir, 'grid'), exist_ok=True)
        >>> print(f'begin testing {task_name}...')
        >>> net = FOPAHeatMapModel(device=0)
        >>> for pair in test_set[:1]:
        >>>     fg_img, fg_mask, bg_img = pair['foreground'], pair['foreground_mask'], pair['background']
        >>>     bboxes, heatmaps = net(fg_img, fg_mask, bg_img, cache_dir=os.path.join(result_dir, 'cache'), heatmap_dir=os.path.join(result_dir, 'heatmap'))
        >>>     img_name = os.path.basename(bg_img).replace('.png', '.jpg')
        >>>     grid_img = make_image_grid([bg_img, fg_img, heatmaps[0]])
        >>>     res_path = os.path.join(result_dir, 'grid', img_name)
        >>>     cv2.imwrite(res_path, grid_img)
        >>>     print('save result to ', res_path)
        >>> print(f'end testing {task_name}!')

    Expected result:

    .. image:: _static/image/fopa_heatmap_FOPA_result1.png
        :scale: 72 %
    .. image:: _static/image/fopa_heatmap_FOPA_result3.png
        :scale: 52 %
    """

    def __init__(self, device=0, model_type='fopa', **kwargs):
        """
        Resolve the device, fetch the FOPA/SOPA weights and build the model.

        Args:
            device: value handed to ``check_gpu_device`` to obtain the torch device.
            model_type (str): must be one of ``model_set``.
            **kwargs: stored unchanged in ``self.option``.
        """
        assert model_type in model_set, f'Not implementation for {model_type}'
        self.model_type = model_type
        self.option = kwargs
        self.device = check_gpu_device(device)
        # NOTE(review): download_pretrained_model is presumably a no-op when the
        # file already exists -- confirm in libcom.utils.model_download.
        fopa_weight = os.path.join(model_dir, 'pretrained_models', 'FOPA.pth')
        download_pretrained_model(fopa_weight)
        sopa_weight = os.path.join(model_dir, 'pretrained_models', 'SOPA.pth')
        download_pretrained_model(sopa_weight)
        self.build_pretrained_model(sopa_weight, fopa_weight)
        self.build_data_transformer()

    def build_pretrained_model(self, sopa_weight, fopa_weight):
        """
        Instantiate ``ObPlaNet_resnet18`` and load the FOPA state dict into it.

        Args:
            sopa_weight (str): path passed to the network as ``weight_path``.
            fopa_weight (str): path of the state dict loaded into the network.
        """
        model = network.ObPlaNet_resnet18(pretrained=False, weight_path=sopa_weight).to(self.device)
        # Load onto CPU first; load_state_dict copies into the parameters,
        # which already live on self.device.
        model.load_state_dict(torch.load(fopa_weight, map_location='cpu'))
        self.model = model.to(self.device).eval()

    def build_data_transformer(self):
        """Create the joint resize (256) and the tensor transforms for images and masks."""
        self.triple_transform = Compose([JointResize(256)])
        self.image_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Commonly used ImageNet channel mean / std values.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.mask_transform = transforms.ToTensor()

    @staticmethod
    def _scaled_stem(row):
        """Return ``<fg>_<bg>_<newWidth>_<newHeight>`` for one CSV row (names cut at their first '.')."""
        return '{}_{}_{}_{}'.format(
            row["fg_name"].split(".")[0],
            row["bg_name"].split(".")[0],
            int(row["newWidth"]),
            int(row["newHeight"]),
        )

    @staticmethod
    def _keep_valid_centers(heatmap, fg_h, fg_w, bg_h, bg_w):
        """
        Copy ``heatmap`` into a float64 array with a border band set to zero.

        The band is half the foreground size, rescaled from background pixels to
        heatmap pixels (presumably so that a foreground centred on a surviving
        pixel does not cross the background border -- confirm with the authors).

        Args:
            heatmap (np.ndarray): 2-D score map.
            fg_h, fg_w (int): scaled foreground height / width in background pixels.
            bg_h, bg_w (int): background height / width in pixels.

        Returns:
            np.ndarray: float64 array of the same shape as ``heatmap``.
        """
        centered = np.zeros_like(heatmap, dtype=np.float64)
        hb = int(fg_h / bg_h * heatmap.shape[0] / 2)
        wb = int(fg_w / bg_w * heatmap.shape[1] / 2)
        # Explicit end indices: a ``hb:-hb`` slice is empty when hb == 0, which
        # would wipe the whole heatmap for very small foregrounds.
        rows = slice(hb, heatmap.shape[0] - hb)
        cols = slice(wb, heatmap.shape[1] - wb)
        centered[rows, cols] = heatmap[rows, cols]
        return centered

    @staticmethod
    def _bbox_from_composite_name(file_name):
        """
        Parse ``[x, y, w, h]`` from a composite file name ``..._<x>_<y>_<w>_<h>.<ext>``.

        Only the last four '_'-separated fields are read, so foreground or
        background names that contain underscores do not shift the result.
        """
        stem = os.path.splitext(os.path.basename(file_name))[0]
        return [int(value) for value in stem.split('_')[-4:]]

    # no_grad here as well, so that a direct call (not through __call__) does not
    # build an autograd graph for every forward pass.
    @torch.no_grad()
    def generate_heatmap(self, background_image, csv_path, scaled_fg_dir, scaled_mask_dir, heatmap_dir):
        """
        Run the network once per CSV row and save one heatmap image per row.

        Args:
            background_image (str): background image path.
            csv_path (str): CSV with columns ``fg_name``, ``bg_name``, ``newWidth``,
                ``newHeight`` and ``scale``.
            scaled_fg_dir (str): folder holding the scaled foreground images.
            scaled_mask_dir (str): folder holding the scaled foreground masks.
            heatmap_dir (str): output folder (created if missing).

        Returns:
            list[str]: paths of the saved heatmaps, in CSV row order.
        """
        os.makedirs(heatmap_dir, exist_ok=True)
        heatmap_list = []
        to_pil = transforms.ToPILImage()
        with open(csv_path, mode='r', newline='') as csv_file:
            for row in csv.DictReader(csv_file):
                stem = self._scaled_stem(row)
                # Scaled foreground and mask share the same file name in their own folders.
                scaled_name = stem + '.jpg'
                save_name = '{}_{}.jpg'.format(stem, row['scale'])

                bg_img = read_image_pil(background_image)
                fg_img = read_image_pil(os.path.join(scaled_fg_dir, scaled_name))
                mask = read_mask_pil(os.path.join(scaled_mask_dir, scaled_name))

                bg_t, fg_t, mask_t = self.triple_transform(bg_img, fg_img, mask)
                mask_t = self.mask_transform(mask_t).to(self.device)
                fg_t = self.image_transform(fg_t).to(self.device)
                bg_t = self.image_transform(bg_t).to(self.device)

                outs, _ = self.model(bg_t.unsqueeze(0), fg_t.unsqueeze(0), mask_t.unsqueeze(0), 'test')
                # Keep channel 1 of the softmax over dim 1
                # (presumably the "reasonable placement" class -- TODO confirm).
                outs = torch.softmax(outs, dim=1)[:, 1, :, :]
                out_file = os.path.join(heatmap_dir, save_name)
                to_pil(outs).save(out_file)
                heatmap_list.append(out_file)
        return heatmap_list

    def generate_bounding_box(self, foreground_image, foreground_mask, background_image, csv_path, cache_dir,
                              heatmap_dir, fg_scale_num, composite_num_choose, composite_num):
        """
        Build composites at the best-scoring placements and randomly pick some of them.

        Every ``fg_scale_num`` consecutive CSV rows are treated as one group. For a
        complete group, the ``composite_num`` highest heatmap pixels over all scales
        are turned into composite image / mask pairs saved below ``cache_dir``.
        Afterwards ``composite_num_choose`` composites are sampled from that folder,
        copied to a ``..._RandomSelect_<n>`` folder and their boxes returned.

        Args:
            foreground_image (str): foreground image path.
            foreground_mask (str): foreground mask path.
            background_image (str): background image path.
            csv_path (str): CSV with columns ``fg_name``, ``bg_name``, ``newWidth``,
                ``newHeight`` and ``scale``.
            cache_dir (str): root folder for the composite outputs.
            heatmap_dir (str): folder holding the heatmaps written by ``generate_heatmap``.
            fg_scale_num (int): number of rows (scales) per group.
            composite_num_choose (int): number of composites to sample.
            composite_num (int): number of top-scoring placements to composite.

        Returns:
            list[list[int]]: ``[x, y, w, h]`` per selected composite, where ``x, y`` is
            the top-left corner in background pixels. Shorter than
            ``composite_num_choose`` when fewer composites are available.

        Raises:
            ValueError: if the CSV contains no data rows.
        """
        icount = 0
        composite_dir = None
        with open(csv_path, mode='r', newline='') as csv_file:
            for row in csv.DictReader(csv_file):
                scale = row['scale']
                w = int(row['newWidth'])
                h = int(row['newHeight'])
                fg_stem = row["fg_name"].split(".")[0]
                bg_stem = row["bg_name"].split(".")[0]
                # Same naming scheme as the files written by generate_heatmap.
                save_name = f'{fg_stem}_{bg_stem}_{w}_{h}_{scale}.jpg'

                if icount == 0:
                    # First row of a group: load the inputs and reset the accumulators.
                    bg_img = read_image_pil(background_image)
                    fg_tocp = read_image_pil(foreground_image)
                    mask_tocp = read_image_pil(foreground_mask)
                    composite_dir = os.path.join(cache_dir, f'{fg_scale_num}scales_composite',
                                                 f'{fg_stem}_{bg_stem}')
                    os.makedirs(composite_dir, exist_ok=True)
                    heatmap_center_list = []
                    fg_size_list = []
                icount += 1

                heatmap = np.array(Image.open(os.path.join(heatmap_dir, save_name)))
                heatmap_center_list.append(
                    self._keep_valid_centers(heatmap, h, w, bg_img.height, bg_img.width)
                )
                fg_size_list.append((h, w))

                if icount == fg_scale_num:
                    icount = 0
                    heatmap_center_stack = np.stack(heatmap_center_list)
                    # Rank every (scale, row, col) entry by descending score.
                    flat_order = np.argsort(-heatmap_center_stack, axis=None)
                    ranked = np.unravel_index(flat_order, heatmap_center_stack.shape)
                    # Never ask for more placements than there are ranked entries.
                    for i in range(min(composite_num, flat_order.size)):
                        iscale, row_idx, col_idx = ranked[0][i], ranked[1][i], ranked[2][i]
                        fg_h, fg_w = fg_size_list[iscale]
                        # Heatmap coordinates -> background pixels (placement centre).
                        center_x = col_idx / heatmap.shape[1] * bg_img.width
                        center_y = row_idx / heatmap.shape[0] * bg_img.height
                        x = int(center_x - fg_w / 2)
                        y = int(center_y - fg_h / 2)
                        composite_img, composite_msk = make_composite_PIL(
                            fg_tocp, mask_tocp, bg_img, [x, y, fg_w, fg_h], return_mask=True
                        )
                        composite_stem = f'{fg_stem}_{bg_stem}_{x}_{y}_{fg_w}_{fg_h}'
                        composite_img.save(os.path.join(composite_dir, composite_stem + '.jpg'))
                        composite_msk.save(os.path.join(composite_dir, composite_stem + '.png'))

        if composite_dir is None:
            raise ValueError(f'no rows found in {csv_path}')

        composite_dir_choose = os.path.join(
            cache_dir,
            f'{fg_scale_num}scales_composite_RandomSelect_{composite_num_choose}',
            f'{fg_stem}_{bg_stem}',
        )
        os.makedirs(composite_dir_choose, exist_ok=True)

        image_files = [f for f in os.listdir(composite_dir) if f.endswith(".jpg")]
        # Distinct placements can map to the same file name, so fewer files than
        # requested may exist; clamp instead of letting random.sample raise.
        selected_images = random.sample(image_files, min(composite_num_choose, len(image_files)))

        bbox_list = []
        for image in selected_images:
            # Swap only the file extension; a str.replace on the full path would
            # also rewrite '.jpg' occurring inside directory names.
            mask_name = os.path.splitext(image)[0] + '.png'
            shutil.copyfile(os.path.join(composite_dir, image), os.path.join(composite_dir_choose, image))
            shutil.copyfile(os.path.join(composite_dir, mask_name),
                            os.path.join(composite_dir_choose, mask_name))
            bbox_list.append(self._bbox_from_composite_name(image))
        return bbox_list

    @torch.no_grad()
    def __call__(self, foreground_image, foreground_mask, background_image, cache_dir,
                 heatmap_dir, fg_scale_num=16, composite_num_choose=3, composite_num=50):
        """
        Generate a heatmap for a pair of scaled foreground and background.

        Args:
            foreground_image: foreground image path
            foreground_mask: foreground mask path
            background_image: background image path
            cache_dir: folder path where scaled foreground images, scaled mask images and composite images are stored
            heatmap_dir: folder path where heatmaps are stored
            fg_scale_num: number of scales of scaled foreground images and mask images
            composite_num_choose: the number of chosen composite images
            composite_num: the number of composite images with the highest score

        Returns:
            box_list: list of ``[x, y, w, h]`` bounding boxes of the chosen composite images
            heatmap_list: the path of heatmaps
        """
        scaled_fg_dir, scaled_mask_dir, csv_file = prepare_multi_fg_scales(
            cache_dir, foreground_image, foreground_mask, background_image, fg_scale_num
        )
        heatmap_list = self.generate_heatmap(background_image, csv_file, scaled_fg_dir, scaled_mask_dir, heatmap_dir)
        box_list = self.generate_bounding_box(foreground_image, foreground_mask, background_image, csv_file,
                                              cache_dir, heatmap_dir, fg_scale_num, composite_num_choose,
                                              composite_num)
        return box_list, heatmap_list