import torch
import torchvision
from libcom.utils.model_download import download_pretrained_model, download_entire_folder
from libcom.utils.process_image import *
from libcom.utils.environment import *
import os
import torchvision.transforms as transforms
import sys
import cv2
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor
# Directory containing this module; also the fallback root for pretrained models.
cur_dir = os.path.dirname(os.path.abspath(__file__))
# Put the bundled ``source/src`` directory at the front of ``sys.path`` so the
# modules inside it can be imported by their top-level names. The membership
# check keeps repeated imports of this module from adding duplicate entries.
src_dir = os.path.join(cur_dir, 'source/src')
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)
from libcom.image_harmonization.source.pct_net import *
from libcom.image_harmonization.source.src.lbm.inference import get_model
from diffusers import FlowMatchEulerDiscreteScheduler
# =======================================================
# Root directory under which ``pretrained_models`` is looked up; the
# LIBCOM_MODEL_DIR environment variable overrides the module directory.
model_dir = os.getenv('LIBCOM_MODEL_DIR', cur_dir)
# Model types accepted by ImageHarmonizationModel.
model_set = ['PCTNet', 'LBM']
class ImageHarmonizationModel:
    """
    Image harmonization model.

    Args:
        device (int | str | torch.device): GPU id or device specifier; it is
            resolved through ``check_gpu_device``.
        model_type (str): predefined model type, 'PCTNet' or 'LBM'.
        kwargs (dict): other parameters, stored in ``self.option``.
            For LBM, 'steps' and 'resolution' given here become the defaults
            used at inference time when a call does not pass them.

    Examples:
        >>> from libcom import ImageHarmonizationModel
        >>> import cv2
        >>> import os
        >>> import numpy as np
        >>> from PIL import Image
        >>> #Use PCTNet
        >>> PCTNet = ImageHarmonizationModel(device=0, model_type='PCTNet')
        >>> comp_img1 = '../tests/source/composite/comp1_PCTNet.jpg'
        >>> comp_mask1 = '../tests/source/composite_mask/mask1_PCTNet.png'
        >>> PCT_result1 = PCTNet(comp_img1, comp_mask1)
        >>> cv2.imwrite('../docs/_static/image/image_harmonization_PCT_result1.jpg', np.concatenate([cv2.imread(comp_img1), cv2.imread(comp_mask1), PCT_result1],axis=1))
        >>> #Use LBM
        >>> LBM = ImageHarmonizationModel(device=0, model_type='LBM')
        >>> comp_img = '../tests/source/composite/1.jpg'
        >>> comp_mask = '../tests/source/composite_mask/1.png'
        >>> LBM_result = LBM(comp_img, comp_mask, steps=4)
        >>> cv2.imwrite('../docs/_static/image/image_harmonization_LBM_result.jpg', np.concatenate([cv2.imread(comp_img), cv2.imread(comp_mask), LBM_result],axis=1))

    Expected result:

    .. image:: _static/image/image_harmonization_PCT_result1.jpg

    .. image:: _static/image/image_harmonization_LBM_result.jpg
    """

    def __init__(self, device=0, model_type='PCTNet', **kwargs):
        assert model_type in model_set, f'Not implementation for {model_type}'
        self.model_type = model_type
        # NOTE: only 'steps' and 'resolution' are read back from these options
        # (in _inference_lbm); other keys are stored but not consulted here.
        self.option = kwargs
        self.device = check_gpu_device(device)
        if self.model_type == 'PCTNet':
            weight_path = os.path.join(model_dir, 'pretrained_models', 'PCTNet.pth')
            download_pretrained_model(weight_path)
            # The LUT file is fetched next to the weights but is not passed on
            # explicitly -- presumably PCTNet locates it itself (TODO confirm).
            lut_path = os.path.join(model_dir, 'pretrained_models', 'IdentityLUT33.txt')
            download_pretrained_model(lut_path)
            self.build_pretrained_model(weight_path)
            self.build_data_transformer()
        elif self.model_type == 'LBM':
            lbm_dir = os.path.join(model_dir, 'pretrained_models', 'lbm_ckpt')
            download_entire_folder(lbm_dir)
            self.build_pretrained_model(lbm_dir)

    def build_pretrained_model(self, weight_path):
        """
        Build ``self.model`` for the configured model type and set it to eval mode.

        Args:
            weight_path (str): for PCTNet, the path of the state-dict file;
                for LBM, the checkpoint directory handed to ``get_model``.
        """
        if self.model_type == 'LBM':
            self.model = get_model(weight_path, torch_dtype=torch.bfloat16, device=self.device)
            self.model.bridge_noise_sigma = 0.005
            # Fall back to the SDXL scheduler config when the checkpoint did
            # not provide a sampling scheduler.
            if self.model.sampling_noise_scheduler is None:
                self.model.sampling_noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
                    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
                )
            self.model.eval()
        else:
            model = PCTNet()
            # Load on CPU first, then move the whole module to the target device.
            model.load_state_dict(torch.load(weight_path, map_location='cpu', weights_only=True))
            self.model = model.to(self.device).eval()

    def build_data_transformer(self):
        """Create ``self.transformer``, the ndarray-to-tensor transform used by PCTNet preprocessing."""
        self.transformer = transforms.Compose([
            transforms.ToTensor(),
        ])

    def inputs_preprocess(self, composite_image, composite_mask):
        """
        Prepare the full-resolution and 256x256 inputs for PCTNet.

        Args:
            composite_image (str | numpy.ndarray): composite image, read via ``read_image_opencv``.
            composite_mask (str | numpy.ndarray): foreground mask, read via ``read_mask_opencv``.

        Returns:
            tuple: ``(img, mask, img_lr, mask_lr)`` float tensors on ``self.device``;
            the ``*_lr`` entries are the 256x256 resized versions.
        """
        img = read_image_opencv(composite_image)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Scale the mask by 1/255 (assumes read_mask_opencv yields 0..255 values).
        mask = read_mask_opencv(composite_mask) / 255.0
        img_lr = cv2.resize(img, (256, 256))
        mask_lr = cv2.resize(mask, (256, 256))
        # to tensor
        img = self.transformer(img).float().to(self.device)
        mask = self.transformer(mask).float().to(self.device)
        img_lr = self.transformer(img_lr).float().to(self.device)
        mask_lr = self.transformer(mask_lr).float().to(self.device)
        return img, mask, img_lr, mask_lr

    def outputs_postprocess(self, outputs):
        """
        Convert the PCTNet output tensor into a BGR numpy image.

        A leading batch dimension of a 4-D tensor is squeezed, channels are
        moved last, and values are scaled by 255 and clamped to [0, 255].
        The array is not cast to ``uint8`` here; it keeps the dtype of the tensor.

        Args:
            outputs (torch.Tensor): RGB tensor with 3 or 4 dimensions, channels first.

        Returns:
            numpy.ndarray: BGR image with values in [0, 255].
        """
        if len(outputs.shape) == 4:
            outputs = outputs.squeeze(0)
        outputs = (torch.clamp(255.0 * outputs.permute(1, 2, 0), 0, 255)).cpu().numpy()
        outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
        return outputs

    @torch.no_grad()
    def __call__(self, composite_image, composite_mask, **kwargs):
        """
        Given a composite image and a foreground mask, perform harmonization on the foreground.

        Args:
            composite_image (str | numpy.ndarray): The path to composite image or the composite image in ndarray form.
            composite_mask (str | numpy.ndarray): Mask of composite image which indicates the foreground object region in the composite image.
            **kwargs: Extra parameters for inference (e.g., steps=4, resolution=1024 for LBM).

        Returns:
            harmonized_image (np.array): The harmonized result in BGR channel order.
        """
        if self.model_type == 'LBM':
            return self._inference_lbm(composite_image, composite_mask, **kwargs)
        img, mask, img_lr, mask_lr = self.inputs_preprocess(composite_image, composite_mask)
        outputs = self.model(img_lr, img, mask_lr, mask)
        preds = self.outputs_postprocess(outputs)
        return preds

    @staticmethod
    def _is_path_like(obj):
        """Return True when ``obj`` is a filesystem path (``str`` or ``os.PathLike``)."""
        return isinstance(obj, (str, os.PathLike))

    @torch.no_grad()
    def _inference_lbm(self, composite_image, composite_mask, **kwargs):
        """
        Run LBM harmonization and return the result at the input's original size.

        Args:
            composite_image (str | os.PathLike | numpy.ndarray): image path, or an
                ndarray that is treated as BGR.
            composite_mask (str | os.PathLike | numpy.ndarray): mask path, or an
                ndarray; for a 3-D array only the first channel is used.
            **kwargs: 'steps' (sampling steps, default 4) and 'resolution'
                (square inference size, default 1024). Values given here take
                precedence over those passed to the constructor.

        Returns:
            numpy.ndarray: ``uint8`` BGR image resized back to the original width and height.
        """
        # Per-call values win over the constructor-time defaults in self.option.
        steps = int(kwargs.get('steps', self.option.get('steps', 4)))
        inference_size = int(kwargs.get('resolution', self.option.get('resolution', 1024)))
        # The mask is resized to 1/8 of the inference resolution.
        latent_size = inference_size // 8

        if self._is_path_like(composite_image):
            src_raw = Image.open(composite_image).convert("RGB")
        else:
            src_raw = Image.fromarray(cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB))

        if self._is_path_like(composite_mask):
            mask_raw = Image.open(composite_mask).convert("L")
        else:
            if len(composite_mask.shape) == 3:
                composite_mask = composite_mask[:, :, 0]
            mask_raw = Image.fromarray(composite_mask).convert("L")

        # Remember the input size so the prediction can be resized back at the end.
        ori_w, ori_h = src_raw.size
        to_tensor = ToTensor()

        src_resized = src_raw.resize((inference_size, inference_size), Image.BILINEAR)
        # ToTensor yields values in [0, 1]; remap them to [-1, 1].
        src_t = (to_tensor(src_resized).unsqueeze(0) * 2 - 1).to(self.device, dtype=torch.bfloat16)
        batch = {"source_image_paste": src_t}

        mask_latent_img = mask_raw.resize((latent_size, latent_size), Image.BILINEAR)
        mask_t = to_tensor(mask_latent_img).unsqueeze(0).to(self.device, dtype=torch.bfloat16)

        z_source = self.model.vae.encode(batch["source_image_paste"])
        output_tensor = self.model.sample(
            z=z_source,
            num_steps=steps,
            conditioner_inputs=batch,
            max_samples=1,
            mask=mask_t
        ).clamp(-1, 1)

        # Map [-1, 1] back to [0, 1], then to an 8-bit HWC image.
        res_tensor = (output_tensor[0].float().cpu() + 1) / 2
        preds = (torch.clamp(255.0 * res_tensor.permute(1, 2, 0), 0, 255)).numpy().astype(np.uint8)
        preds = cv2.cvtColor(preds, cv2.COLOR_RGB2BGR)
        preds = cv2.resize(preds, (ori_w, ori_h))
        return preds