import torch
import torchvision
from libcom.utils.model_download import download_pretrained_model, download_entire_folder
from libcom.utils.process_image import *
from libcom.utils.environment import *
import torch
import torch.nn.functional as F
import os
import torchvision.transforms as transforms
from tqdm.auto import tqdm
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from PIL import Image
import torchvision.transforms.functional as tf
from .source.PHDNet.phdnet import PHDNet
from .source.PHDiffusion.ldm.modules.encoders.adapter import Adapter,NoRes_Adapter
from .source.PHDiffusion.ldm.models.diffusion.scheduling_pndm import PNDMScheduler
from .source.PHDiffusion.ldm.util import instantiate_from_config
import logging
# Raise the log threshold of the `transformers` logger so that only errors are emitted.
logging.getLogger('transformers').setLevel(logging.ERROR) # disable transformer lib warning
# Absolute path of the directory containing this module; used below to locate
# the bundled PHDiffusion config file.
cur_dir = os.path.dirname(os.path.abspath(__file__))
# Root directory under which pretrained weights are looked up. Defaults to this
# module's directory and can be overridden with the LIBCOM_MODEL_DIR environment variable.
model_dir = os.environ.get('LIBCOM_MODEL_DIR',cur_dir)
# Model types accepted by PainterlyHarmonizationModel(model_type=...).
model_set = ['PHDNet', 'PHDiffusion']
class PainterlyHarmonizationModel:
    """
    Painterly image harmonization prediction model.

    Args:
        device (str | torch.device): gpu id
        model_type (str): predefined model type, one of 'PHDNet' or 'PHDiffusion'.
        kwargs (dict): use_residual (bool): whether to use adapter with residual or not for PHDiffusion
            (defaults to True when omitted).

    Examples:
        >>> from libcom.utils.process_image import make_image_grid
        >>> from libcom import PainterlyHarmonizationModel
        >>> import cv2
        >>> import torch
        >>> task_name = 'painterly_image_harmonization'
        >>> MODEL_TYPE = 'PHDNet' # choose from 'PHDNet', 'PHDiffusion'
        >>> comp_img = '../tests/painterly_harmonization_source/composite/3.png'
        >>> comp_mask = '../tests/painterly_harmonization_source/composite_mask/3.png'
        >>> net = PainterlyHarmonizationModel(device=0, model_type=MODEL_TYPE)
        >>> output_img = net(comp_img, comp_mask)
        >>> grid_img = make_image_grid([comp_img, comp_mask, output_img])
        >>> cv2.imshow('painterly_image_harmonization_demo', grid_img)

    Expected result:

    .. image:: _static/image/painterly_image_harmonization_result2.jpg

    .. image:: _static/image/painterly_image_harmonization_result3.jpg
    """

    def __init__(self, device=0, model_type='PHDNet', **kwargs):
        """Resolve weight paths, download missing weights and build the selected model."""
        assert model_type in model_set, f'Not implementation for {model_type}'
        self.model_type = model_type
        self.option = kwargs
        if model_type == 'PHDNet':
            weight_path = os.path.join(model_dir, 'pretrained_models', model_type + '.pth')
            self.device = check_gpu_device(device)
            download_pretrained_model(weight_path)
            self.build_pretrained_model(weight_path)
        elif model_type == 'PHDiffusion':
            self.use_residual = self.option.get('use_residual', True)
            # The residual and non-residual adapters ship as separate checkpoints.
            suffix = 'WithRes.pth' if self.use_residual else '.pth'
            phdiff_weight_path = os.path.join(model_dir, 'pretrained_models', model_type + suffix)
            sd_weight_path = os.path.join(model_dir, '../shared_pretrained_models', 'sd-v1-4.ckpt')
            download_pretrained_model(sd_weight_path)
            download_pretrained_model(phdiff_weight_path)
            self.device = check_gpu_device(device)
            self.build_pretrained_model(sd_weight_path, phdiff_weight_path)
        # Both model types share the same input transforms.
        self.build_data_transformer()

    def build_pretrained_model(self, *weight_path):
        """
        Instantiate the network(s) and load pretrained weights.

        Args:
            weight_path (str): either one path (PHDNet weights) or two paths
                (Stable Diffusion checkpoint, PHDiffusion checkpoint), in that order.

        Raises:
            ValueError: if the number of paths is neither 1 nor 2.
        """
        if len(weight_path) == 1:
            # build PHDNet
            assert self.model_type == 'PHDNet', self.model_type
            model = PHDNet(self.device)
            model.load_networks(weight_path[0])
            self.model = model.to(self.device).eval()
        elif len(weight_path) == 2:
            # build PHDiffusion model
            sd_weight_path, phdiff_weight_path = weight_path
            assert self.model_type == 'PHDiffusion', self.model_type
            config_path = os.path.join(cur_dir, 'source', 'PHDiffusion', 'stable_diffusion.yaml')
            self.config = OmegaConf.load(config_path)
            self.config.model.params.cond_stage_config.params.model_path = "openai/clip-vit-large-patch14"
            # NOTE: weights_only=False unpickles arbitrary Python objects; only
            # load checkpoints obtained from a trusted source.
            pl_sd = torch.load(sd_weight_path, map_location="cpu", weights_only=False)
            model = instantiate_from_config(self.config.model)
            # strict=False: missing/unexpected keys between the checkpoint and
            # the instantiated model are tolerated here.
            model.load_state_dict(pl_sd["state_dict"], strict=False)

            adapter_cls = Adapter if self.use_residual else NoRes_Adapter
            adapter = adapter_cls(cin=64 * 4, channels=[320, 640, 1280, 1280],
                                  nums_rb=2, ksize=1, sk=True, use_conv=False)
            # The PHDiffusion checkpoint stores the adapter weights under 'ad'
            # and the interaction-block weights under 'interact'.
            model_resume_state = torch.load(phdiff_weight_path, map_location='cpu')
            adapter.load_state_dict(model_resume_state['ad'])
            model.model.diffusion_model.interact_blocks.load_state_dict(model_resume_state['interact'])

            self.model = model.to(self.device).eval()
            self.adapter = adapter.to(self.device).eval()
            self.scheduler = PNDMScheduler(
                beta_end=0.012,
                beta_schedule='scaled_linear',
                beta_start=0.00085,
                num_train_timesteps=1000,
                set_alpha_to_one=False,
                skip_prk_steps=True,
                steps_offset=1,
            )
        else:
            # Previously this case fell through silently and left the instance
            # without a model; fail loudly instead.
            raise ValueError(
                f'Expected 1 (PHDNet) or 2 (PHDiffusion) weight paths, got {len(weight_path)}'
            )

    def build_data_transformer(self):
        """Create the image and mask transforms (resize to 512x512, convert to tensor)."""
        self.image_size = 512
        target_size = (self.image_size, self.image_size)
        # Images are additionally normalized with mean 0.5 / std 0.5 per channel,
        # i.e. mapped from [0, 1] to [-1, 1]; masks stay in [0, 1].
        self.transformer = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.mask_transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
        ])

    def inputs_preprocess(self, composite_image, composite_mask):
        """
        Load the composite image and mask and convert them to tensors.

        For PHDiffusion a batch dimension is added and both tensors are moved to
        ``self.device``; for PHDNet the tensors are returned as produced by the
        transforms.

        Returns:
            tuple: (image tensor, mask tensor).
        """
        img = self.transformer(read_image_pil(composite_image))
        mask = self.mask_transform(read_mask_pil(composite_mask).convert('L'))
        if self.model_type == 'PHDiffusion':
            img = img.unsqueeze(0).to(self.device)
            mask = mask.unsqueeze(0).to(self.device)
        return img, mask

    def outputs_postprocess(self, outputs):
        """
        Convert a network output tensor to a uint8 BGR image array.

        Args:
            outputs (torch.Tensor): channel-first image tensor, optionally with a
                leading batch dimension that is squeezed when it has size 1.
                Values are mapped from [-1, 1] to [0, 255] and clamped.

        Returns:
            numpy.ndarray: channel-last uint8 image in BGR channel order.
        """
        if outputs.dim() == 4:
            outputs = outputs.squeeze(0)
        outputs = torch.clamp((outputs.permute(1, 2, 0) + 1) / 2.0 * 255, 0, 255).cpu().numpy()
        outputs = outputs.astype(np.uint8)
        return cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)

    @torch.no_grad()
    def PHDNet_inference(self, composite_image, composite_mask):
        """Run PHDNet on the composite image/mask and return the harmonized BGR image."""
        # NOTE(review): inputs are not batched or moved to the device here;
        # presumably PHDNet handles that internally -- confirm against PHDNet.
        comp, mask = self.inputs_preprocess(composite_image, composite_mask)
        outputs = self.model(comp, mask)
        return self.outputs_postprocess(outputs)

    @torch.no_grad()
    def _phdiffusion_inference(self, composite_image, composite_mask, sample_steps, strength, random_seed):
        """Run the PHDiffusion sampling loop and return the harmonized BGR image."""
        if random_seed is not None:
            seed_everything(random_seed)
        comp, mask = self.inputs_preprocess(composite_image, composite_mask)

        # Only the last int(sample_steps * strength) steps of the schedule are run.
        self.scheduler.set_timesteps(sample_steps, device=self.device)
        init_timestep = min(int(sample_steps * strength), sample_steps)
        t_start = max(sample_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        # Conditioning for an empty text prompt.
        c = self.model.get_learned_conditioning([''])
        batch_size = mask.shape[0]
        latent_timestep = timesteps[:1].repeat(batch_size)

        # Encode the composite image into latent space and noise it to the
        # first timestep that will be sampled.
        x_0 = self.model.get_first_stage_encoding(self.model.encode_first_stage(comp))
        mask_latents = F.interpolate(mask, size=x_0.shape[-2:])
        noise = torch.randn(x_0.shape, device=x_0.device, dtype=x_0.dtype)
        latents = self.scheduler.add_noise(x_0, noise, latent_timestep)

        # The adapter consumes the image and mask concatenated along the channel axis.
        adapter_input = torch.cat((comp, mask), dim=1).to(dtype=comp.dtype)
        features_adapter = self.adapter(adapter_input)

        for t in timesteps:
            # Re-noise the clean latents to the current timestep and use them
            # outside the foreground mask, so only the masked region keeps the
            # denoised latents.
            noise = torch.randn(x_0.shape, device=x_0.device, dtype=x_0.dtype)
            t_latents = self.scheduler.add_noise(x_0, noise, t)
            latents = latents * mask_latents + t_latents * (1 - mask_latents)
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            # predict the noise residual
            noise_pred = self.model.model.diffusion_model(x=latent_model_input,
                                                          fg_mask=mask_latents,
                                                          timesteps=t.repeat(batch_size),
                                                          context=c,
                                                          features_adapter=features_adapter)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        x_samples_ddim = self.model.decode_first_stage(latents)
        return self.outputs_postprocess(x_samples_ddim)

    @torch.no_grad()
    def __call__(self, composite_image, composite_mask, sample_steps=50, strength=0.7, random_seed=None):
        """
        Generating the harmonized image for the given composite image and the corresponding composite mask.

        Args:
            composite_image (str | numpy.ndarray): The path to the composite image or the composite image in ndarray form.
            composite_mask (str | numpy.ndarray): The path to the composite mask or the composite mask in ndarray form.
            sample_steps (int): Default total step in the inference process of PHDiffusion. Ignored by PHDNet.
            strength (float): A hyper-parameter that decides the total step (strength * sample_steps) for PHDiffusion. Ignored by PHDNet.
            random_seed (int | None): If not None, passed to ``seed_everything`` before PHDiffusion sampling. Ignored by PHDNet.

        Returns:
            preds (numpy.ndarray): Generated harmonized image for the given composite image and the corresponding composite mask, with BGR channel.
        """
        if self.model_type == 'PHDNet':
            return self.PHDNet_inference(composite_image, composite_mask)
        return self._phdiffusion_inference(composite_image, composite_mask,
                                           sample_steps, strength, random_seed)