• Docs >
  • Module code >
  • libcom.kontext_blending_harmonization.kontext_blending_harmonization
Shortcuts

Source code for libcom.kontext_blending_harmonization.kontext_blending_harmonization

import os
import torch
from PIL import Image
import os
import numpy as np
from .source.pipeline import FluxKontextPipeline 
from diffusers import FluxTransformer2DModel
import cv2
from libcom.utils.model_download import download_pretrained_model

cur_dir   = os.path.dirname(os.path.abspath(__file__))
model_dir = os.environ.get('LIBCOM_MODEL_DIR', cur_dir)
model_set = ['Kontext_blend', 'Kontext_harm'] 

[docs]class KontextBlendingHarmonizationModel: """ Flux Kontext based image blending and harmonization model. Args: device (str | torch.device): gpu id model_type (str): predefined model type. "Kontext_blend" refers to the version fintuned on the image blending task. "Kontext_harm" refers to the version finetuned on the image harmonization task. default: "Kontext_blend" kwargs (dict): other parameters for building model Examples: >>> from libcom import KontextBlendingHarmonizationModel >>> from libcom.utils.process_image import make_image_grid, draw_bbox_on_image >>> import cv2 >>> net = KontextBlendingHarmonizationModel(device=0, model_type="Kontext_blend") >>> img_names = ["000000049931.png", "000000460450.png", "6c5601278dcb5e6d_m09728_f5cd2891_17.png"] >>> bboxes = [[168, 137, 488, 413], [134, 158, 399, 511], [130, 91, 392, 271]] >>> test_dir = 'tests/controllable_composition/' >>> for i in range(len(img_names)): >>> bg_img = test_dir + 'background/' + img_names[i] >>> fg_img = test_dir + 'foreground/' + img_names[i] >>> bbox = bboxes[i] >>> mask = test_dir + 'foreground_mask/' + img_names[i] >>> comp = net(bg_img, fg_img, bbox, mask) >>> bg_img = draw_bbox_on_image(bg_img, bbox) >>> grid_img = make_image_grid([bg_img, fg_img, comp[0]]) >>> cv2.imwrite('../docs/_static/image/kontext_result{}.jpg'.format(i+1), grid_img) Expected result: .. image:: _static/image/kontext_result1.jpg :scale: 21 % .. image:: _static/image/kontext_result2.jpg :scale: 21 % """ def __init__(self, device=0, model_type='Kontext_blend', **kwargs): assert model_type in model_set, f'Not implementation for {model_type}' self.model_type = model_type self.option = kwargs weight_path = "black-forest-labs/FLUX.1-Kontext-dev" blending_lora_path = os.path.join(cur_dir, 'pretrained_models', 'flux_kontext_blending.safetensors') harmonization_lora_path = os.path.join(cur_dir, 'pretrained_models', 'flux_kontext_harmonization.safetensors') download_pretrained_model(blending_lora_path) download_pretrained_model(harmonization_lora_path) if model_type == 'Kontext_blend': lora_path = blending_lora_path elif model_type == 'Kontext_harm': lora_path = harmonization_lora_path else: raise ValueError(f"Unsupported model type: {model_type}") self.device = f"cuda:{device}" self.build_pretrained_model(weight_path, lora_path) def build_pretrained_model(self, weight_path, lora_path): transformer = FluxTransformer2DModel.from_pretrained(weight_path, subfolder="transformer") transformer.requires_grad_(False) self.pipeline = FluxKontextPipeline.from_pretrained( weight_path, transformer=transformer, torch_dtype=torch.bfloat16 ) self.pipeline.load_lora_weights(lora_path) self.pipeline = self.pipeline.to(self.device, dtype=torch.bfloat16) def generate_initial_image(self, bg_path, fg_path, fg_mask_path, bbox): bg_img = Image.open(bg_path).convert('RGB') fg_img = Image.open(fg_path).convert('RGB') if fg_mask_path is not None: fg_mask = Image.open(fg_mask_path).convert('L') # 灰度 fg_mask = fg_mask.resize((fg_img.width, fg_img.height)) fg_img_np = np.array(fg_img) fg_mask_np = np.array(fg_mask) # mask>127认为是物体,其他是背景 -> 背景改成白色 white = np.ones_like(fg_img_np) * 255 fg_img_np = np.where(fg_mask_np[..., None] > 127, fg_img_np, white) fg_img = Image.fromarray(fg_img_np) # 按照bbox缩放前景 x1, y1, x2, y2 = [int(v) for v in bbox] fg_resized = fg_img.resize((x2 - x1, y2 - y1)) # 把前景贴到背景上 bg_img.paste(fg_resized, (x1, y1, x2, y2)) bg_img = bg_img.resize((1024, 1024)) return bg_img def inputs_preprocess(self, background_image, foreground_image, bbox, foreground_mask): image = self.generate_initial_image(background_image, foreground_image, foreground_mask, bbox) return image @torch.no_grad() def __call__(self, background_image, foreground_image, bbox, foreground_mask=None, prompt='put it here', num_samples=1, sample_steps=28, guidance_scale=2.5, seed=321): """ Kontext based image blending and harmonization. Args: background_image (str): The path to background image. foreground_image (str): The path to foreground image. bbox (list): The bounding box which indicates the foreground's location in the background. [x1, y1, x2, y2]. foreground_mask (None | str): Mask of foreground image which indicates the foreground object region in the foreground image. default: None. prompt (str): The text prompt to guide the image generation. default: 'put it here'. num_samples (int): Number of images to be generated for each task. default: 1. sample_steps (int): Number of denoising steps. The recommended setting is 28 for FlowMatchEulerDiscreteScheduler. default: 28. guidance_scale (int): Scale in classifier-free guidance (minimum: 1; maximum: 20). default: 2.5. seed (int): Random Seed is used to reproduce results and same seed will lead to same results. Returns: composite_images (numpy.ndarray): Generated images with a shape of 512x512x3 or Nx512x512x3, where N indicates the number of generated images. """ image = self.inputs_preprocess(background_image, foreground_image, bbox, foreground_mask) comp_image_list = [] for i in range(num_samples): comp_img = self.pipeline( image=image, prompt=prompt, height=1024, width=1024, guidance_scale=guidance_scale, num_inference_steps=sample_steps, max_sequence_length=512, ).images[0] comp_image_list.append(cv2.cvtColor(np.array(comp_img), cv2.COLOR_RGB2BGR)) composite_images = np.stack(comp_image_list, axis=0) return composite_images