Shortcuts

Source code for libcom.os_insert.os_insert

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

from typing import Literal

import os

import cv2
import numpy as np

from libcom.utils.model_download import download_pretrained_model, download_entire_folder
from .source.insertanything_infer import InsertAnythingModel, insertanything_infer, run_insertanything
from .source.objectstitch_infer import (
    ObjectStitchConfig,
    run_objectstitch_single_image,
    run_objectstitch_single_image_from_images,
    load_objectstitch_model_and_sampler,
    run_objectstitch_single_image_from_images_cached,
)
from .source.sam_on_objectstitch import (
    SamOnObjectStitchConfig,
    build_sam_predictor,
    run_sam_on_objectstitch,
    run_sam_on_objectstitch_with_predictor,
)
from .source.utils import load_bbox_txt, make_rect_mask_from_bbox


@dataclass
class OSInsertConfig:
    model_dir: Path
    device: str = "cuda:0"
    objectstitch_ckpt_path: Path | None = None
    objectstitch_config_path: Path | None = None
    objectstitch_clip_dir: Path | None = None
    sam_checkpoint: Path | None = None
    flux_fill_path: Path | None = None
    flux_redux_path: Path | None = None
    ia_lora_path: Path | None = None

[docs]class OSInsertModel: """ High-level OSInsert interface. This model provides a unified interface for object insertion with two modes: conservative and aggressive. It internally combines multiple sub-models such as InsertAnything, ObjectStitch, and SAM. Modes ----- - ``aggressive``: ObjectStitch + SAM + InsertAnything pipeline. Suitable for more complex and flexible compositions. - ``conservative``: Directly uses background + bbox to generate mask, then performs insertion via InsertAnything. Faster and more stable. Args: device (str): Device to run the model on (e.g., "cuda:0", "cpu"). model_dir (str | Path | None): Root directory of all model checkpoints. eager_aggressive_init (bool): If True, preload ObjectStitch and SAM models at initialization. Otherwise, they will be lazily loaded when first used. objectstitch_ckpt_path (str | Path | None): Path to ObjectStitch checkpoint. objectstitch_config_path (str | Path | None): Path to ObjectStitch config file. objectstitch_clip_dir (str | Path | None): Path to CLIP model directory used by ObjectStitch. sam_checkpoint (str | Path | None): Path to SAM (Segment Anything Model) checkpoint. flux_fill_path (str | Path | None): Path to Flux Fill model directory. flux_redux_path (str | Path | None): Path to Flux Redux model directory. ia_lora_path (str | Path | None): Path to LoRA weights for InsertAnything. Notes ----- - InsertAnything is initialized during class construction. - ObjectStitch and SAM are lazily initialized (unless ``eager_aggressive_init=True``), and then cached for reuse. - Conservative mode does not require ObjectStitch. Examples: >>> import cv2 >>> from libcom import OSInsertModel >>> model = OSInsertModel( >>> device="cuda:0" >>> ) >>> bg = cv2.imread("tests/osinsert/background/Demo_0.png") >>> fg = cv2.imread("tests/osinsert/foreground/Demo_0.png") >>> fg_mask = cv2.imread( >>> "tests/osinsert/foreground_mask/Demo_0.png", >>> cv2.IMREAD_GRAYSCALE >>> ) >>> bbox = (175, 184, 363, 372) >>> result = model.infer_images( >>> background=bg, >>> foreground=fg, >>> foreground_mask=fg_mask, >>> bbox_xyxy=bbox, >>> mode="conservative", # or "aggressive" >>> verbose=False, >>> seed=123, >>> strength=1.0, >>> split_ratio=0.33, >>> save_path="result_dir/conservative", >>> ) Expected result: The foreground object is inserted into the background image at the specified bounding box, with realistic blending. .. image:: _static/image/os_insert_result.jpg """ def __init__( self, device: str = "cuda:0", model_dir: str | Path | None = None, *, eager_aggressive_init: bool = False, objectstitch_ckpt_path: str | Path | None = None, objectstitch_config_path: str | Path | None = None, objectstitch_clip_dir: str | Path | None = None, sam_checkpoint: str | Path | None = None, flux_fill_path: str | Path | None = None, flux_redux_path: str | Path | None = None, ia_lora_path: str | Path | None = None, ) -> None: # 设置模型目录 cur_dir = os.path.dirname(os.path.abspath(__file__)) self.model_dir = os.environ.get('LIBCOM_MODEL_DIR', cur_dir) # 设置默认路径 if model_dir is None: model_dir = self.model_dir # 下载必要的模型权重 self._download_model_weights() self.config = OSInsertConfig( model_dir=Path(model_dir), device=device, objectstitch_ckpt_path=Path(objectstitch_ckpt_path) if objectstitch_ckpt_path is not None else None, objectstitch_config_path=Path(objectstitch_config_path) if objectstitch_config_path is not None else None, objectstitch_clip_dir=Path(objectstitch_clip_dir) if objectstitch_clip_dir is not None else None, sam_checkpoint=Path(sam_checkpoint) if sam_checkpoint is not None else None, flux_fill_path=Path(flux_fill_path) if flux_fill_path is not None else None, flux_redux_path=Path(flux_redux_path) if flux_redux_path is not None else None, ia_lora_path=Path(ia_lora_path) if ia_lora_path is not None else None, ) self._ia_net = InsertAnythingModel( model_dir=self.config.model_dir, flux_fill_path=self.config.flux_fill_path, flux_redux_path=self.config.flux_redux_path, ia_lora_path=self.config.ia_lora_path, device=self.config.device, ) self._os_model = None self._os_sampler = None self._sam_predictor = None if eager_aggressive_init: self._get_objectstitch_model_and_sampler() self._get_sam_predictor() def _download_model_weights(self): """下载所有必要的模型权重""" # ObjectStitch 模型权重 objectstitch_ckpt_path = os.path.join(self.model_dir, 'pretrained_models', 'ObjectStitch.pth') download_pretrained_model(objectstitch_ckpt_path) # ObjectStitch 配置文件 - 使用 source 目录中的本地文件 objectstitch_config_path = os.path.join(os.path.dirname(__file__), 'source', 'v1.yaml') if not os.path.exists(objectstitch_config_path): print(f'v1.yaml not found at {objectstitch_config_path}') raise Exception('v1.yaml not found') print('Using local v1.yaml configuration file from source directory') # CLIP 模型 clip_dir = os.path.join(self.model_dir, 'shared_pretrained_models', 'openai-clip-vit-large-patch14') download_entire_folder(clip_dir) # SAM 模型权重 sam_ckpt_path = os.path.join(self.model_dir, 'pretrained_models', 'sam_vit_h_4b8939.pth') download_pretrained_model(sam_ckpt_path) # InsertAnything LoRA 权重 ia_lora_path = os.path.join(self.model_dir, 'pretrained_models', 'insert_anything_lora.safetensors') download_pretrained_model(ia_lora_path) def _get_objectstitch_model_and_sampler(self): if self._os_model is not None and self._os_sampler is not None: return self._os_model, self._os_sampler objectstitch_ckpt = self.config.objectstitch_ckpt_path if objectstitch_ckpt is None: objectstitch_ckpt = os.path.join(self.model_dir, 'pretrained_models', 'ObjectStitch.pth') objectstitch_cfg = self.config.objectstitch_config_path if objectstitch_cfg is None: objectstitch_cfg = os.path.join(os.path.dirname(__file__), 'source', 'v1.yaml') clip_dir = self.config.objectstitch_clip_dir if clip_dir is None: clip_dir = os.path.join(self.model_dir, 'shared_pretrained_models', 'openai-clip-vit-large-patch14') os_cfg = ObjectStitchConfig( ckpt_path=objectstitch_ckpt, config_path=objectstitch_cfg, clip_dir=clip_dir, device=self.config.device, ) self._os_model, self._os_sampler = load_objectstitch_model_and_sampler(config=os_cfg) return self._os_model, self._os_sampler def _get_sam_predictor(self): if self._sam_predictor is not None: return self._sam_predictor sam_ckpt = self.config.sam_checkpoint if sam_ckpt is None: sam_ckpt = os.path.join(self.model_dir, 'pretrained_models', 'sam_vit_h_4b8939.pth') sam_cfg = SamOnObjectStitchConfig( sam_checkpoint=sam_ckpt, device=self.config.device, ) self._sam_predictor = build_sam_predictor(config=sam_cfg) return self._sam_predictor
[docs] def __call__( self, background_path: str | Path, foreground_path: str | Path, foreground_mask_path: str | Path, bbox: list[int], result_dir: str | Path, mode: Literal["aggressive", "conservative"] = "conservative", cleanup_intermediate: bool = True, verbose: bool = False, seed: int = 123, strength: float = 1.0, split_ratio: float = 0.5, ) -> np.ndarray | None: """Run a single OSInsert inference. Parameters ---------- background_path: Path to the background image. foreground_path: Path to the foreground image used as the InsertAnything reference image. foreground_mask_path: Binary mask for the foreground image. bbox: List containing ``[x1, y1, x2, y2]``, specifying the insertion region on the background image. result_dir: Directory where the final composed image will be written. mode: - ``"conservative"``: background + bbox -> mask -> InsertAnything. - ``"aggressive"``: ObjectStitch + SAM -> combined source/mask -> InsertAnything. cleanup_intermediate: Deprecated. Present for backward compatibility. verbose: If True, save intermediate artifacts into ``result_dir/intermediates``. Default False (do not save intermediates). seed: Random seed for InsertAnything. strength: InsertAnything strength parameter. Returns: Generated composited image (np.array): The inserted result. """ if mode not in {"aggressive", "conservative"}: raise ValueError(f"Unsupported mode: {mode}") # ------------------------------------------------------------------ # Path normalization and output directory. # ------------------------------------------------------------------ background_path = Path(background_path) foreground_path = Path(foreground_path) foreground_mask_path = Path(foreground_mask_path) result_dir = Path(result_dir) os.makedirs(result_dir, exist_ok=True) intermediates_dir = result_dir / "intermediates" if verbose: os.makedirs(intermediates_dir, exist_ok=True) # InsertAnything expects a list of seeds. seeds = [seed] # Load background once; used by both modes. bg = cv2.imread(str(background_path)) if bg is None: raise FileNotFoundError(background_path) h, w = bg.shape[:2] fg = cv2.imread(str(foreground_path)) if fg is None: raise FileNotFoundError(foreground_path) fg_mask = cv2.imread(str(foreground_mask_path)) if fg_mask is None: raise FileNotFoundError(foreground_mask_path) if fg_mask.ndim == 3: fg_mask = fg_mask[:, :, 0] # 直接使用传入的 bbox 列表 assert len(bbox) == 4, f"bbox should be a list of 4 integers: [x1, y1, x2, y2], got {bbox}" # ------------------------------------------------------------------ # Aggressive mode: ObjectStitch + SAM + InsertAnything. # ------------------------------------------------------------------ if mode == "aggressive": # 1) ObjectStitch coarse composite (cached weights). os_model, os_sampler = self._get_objectstitch_model_and_sampler() os_rgb = run_objectstitch_single_image_from_images_cached( background=cv2.cvtColor(bg, cv2.COLOR_BGR2RGB), foreground=cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), foreground_mask=fg_mask, bbox_xyxy=tuple(bbox), model=os_model, sampler=os_sampler, device=self.config.device, seed=seed, ) # 2) SAM mask on top of ObjectStitch composite (in-memory). predictor = self._get_sam_predictor() sam_mask = run_sam_on_objectstitch_with_predictor( predictor=predictor, os_image=cv2.cvtColor(os_rgb, cv2.COLOR_RGB2BGR), bg_shape_hw=(h, w), bbox_xyxy_bg=tuple(bbox), ) # 3) Construct InsertAnything source & mask following # exp/run_insertanything_strength_sweep_dispatch.py::make_source_and_mask bg_bgr = bg # already read above os_bgr = cv2.cvtColor(os_rgb, cv2.COLOR_RGB2BGR) hh, ww = bg_bgr.shape[:2] os_bgr = cv2.resize(os_bgr, (ww, hh), interpolation=cv2.INTER_AREA) if sam_mask.shape[:2] != (hh, ww): sam_mask = cv2.resize(sam_mask, (ww, hh), interpolation=cv2.INTER_NEAREST) m = (sam_mask > 127).astype(np.float32) m3 = np.stack([m, m, m], axis=-1) src_bgr = bg_bgr.astype(np.float32) * (1.0 - m3) + os_bgr.astype(np.float32) * m3 src_bgr = np.clip(src_bgr, 0, 255).astype(np.uint8) # BBox-based mask for InsertAnything second-stage (bbox mask) bbox_mask = make_rect_mask_from_bbox(h, w, bbox) if verbose: cv2.imwrite(str(intermediates_dir / "objectstitch_coarse_rgb.png"), cv2.cvtColor(os_rgb, cv2.COLOR_RGB2BGR)) cv2.imwrite(str(intermediates_dir / "sam_mask.png"), sam_mask) cv2.imwrite(str(intermediates_dir / "blended_source.png"), src_bgr) cv2.imwrite(str(intermediates_dir / "bbox_mask.png"), bbox_mask) # 4) InsertAnything refinement using blended source, bbox mask # (for the second half of denoising) and SAM mask # (for the first half, wired via sam_mask_path). result = insertanything_infer( source_image=cv2.cvtColor(src_bgr, cv2.COLOR_BGR2RGB), mask_image=bbox_mask, ref_image=cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), ref_mask=fg_mask, sam_mask=sam_mask, seeds=seeds, strength=strength, split_ratio=split_ratio, save_path=str(result_dir), filename_suffix="", net=self._ia_net, return_image=True, ) return result # ------------------------------------------------------------------ # Conservative mode: background + bbox -> mask -> InsertAnything. # ------------------------------------------------------------------ mask = make_rect_mask_from_bbox(h, w, bbox) if verbose: cv2.imwrite(str(intermediates_dir / "bbox_mask.png"), mask) result = insertanything_infer( source_image=cv2.cvtColor(bg, cv2.COLOR_BGR2RGB), mask_image=mask, ref_image=cv2.cvtColor(fg, cv2.COLOR_BGR2RGB), ref_mask=fg_mask, seeds=seeds, strength=strength, split_ratio=split_ratio, save_path=str(result_dir), filename_suffix="", net=self._ia_net, return_image=True, ) return result
def infer_images( self, *, background: np.ndarray, foreground: np.ndarray, foreground_mask: np.ndarray, bbox_xyxy: tuple[int, int, int, int], mode: Literal["aggressive", "conservative"] = "conservative", verbose: bool = False, seed: int = 123, strength: float = 1.0, split_ratio: float = 0.5, save_path: str | Path | None = None, filename_suffix: str = "", ) -> np.ndarray | None: if background.ndim != 3: raise ValueError("background must be HxWx3") h, w = background.shape[:2] out_dir = str(save_path) if save_path is not None else "./result" if verbose and save_path is not None: inter_dir = Path(save_path) / "intermediates" os.makedirs(inter_dir, exist_ok=True) # InsertAnything expects a list of seeds. seeds = [seed] if mode == "aggressive": os_model, os_sampler = self._get_objectstitch_model_and_sampler() os_rgb = run_objectstitch_single_image_from_images_cached( background=background[:, :, ::-1], foreground=foreground[:, :, ::-1], foreground_mask=foreground_mask, bbox_xyxy=tuple(bbox_xyxy), model=os_model, sampler=os_sampler, device=self.config.device, seed=seed, ) predictor = self._get_sam_predictor() sam_mask = run_sam_on_objectstitch_with_predictor( predictor=predictor, os_image=cv2.cvtColor(os_rgb, cv2.COLOR_RGB2BGR), bg_shape_hw=(h, w), bbox_xyxy_bg=tuple(bbox_xyxy), ) bg_bgr = background os_bgr = cv2.cvtColor(os_rgb, cv2.COLOR_RGB2BGR) hh, ww = bg_bgr.shape[:2] os_bgr = cv2.resize(os_bgr, (ww, hh), interpolation=cv2.INTER_AREA) if sam_mask.shape[:2] != (hh, ww): sam_mask = cv2.resize(sam_mask, (ww, hh), interpolation=cv2.INTER_NEAREST) m = (sam_mask > 127).astype(np.float32) m3 = np.stack([m, m, m], axis=-1) src_bgr = bg_bgr.astype(np.float32) * (1.0 - m3) + os_bgr.astype(np.float32) * m3 src_bgr = np.clip(src_bgr, 0, 255).astype(np.uint8) bbox_mask = make_rect_mask_from_bbox(h, w, list(bbox_xyxy)) if verbose and save_path is not None: inter_dir = Path(save_path) / "intermediates" prefix = f"{filename_suffix}_" if filename_suffix else "" cv2.imwrite(str(inter_dir / f"{prefix}objectstitch_coarse_rgb.png"), cv2.cvtColor(os_rgb, cv2.COLOR_RGB2BGR)) cv2.imwrite(str(inter_dir / f"{prefix}sam_mask.png"), sam_mask) cv2.imwrite(str(inter_dir / f"{prefix}blended_source.png"), src_bgr) cv2.imwrite(str(inter_dir / f"{prefix}bbox_mask.png"), bbox_mask) return insertanything_infer( source_image=cv2.cvtColor(src_bgr, cv2.COLOR_BGR2RGB), mask_image=bbox_mask, ref_image=cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), ref_mask=foreground_mask, sam_mask=sam_mask, seeds=seeds, strength=strength, split_ratio=split_ratio, save_path=out_dir, filename_suffix=filename_suffix, net=self._ia_net, return_image=True, ) if mode != "conservative": raise ValueError(f"Unsupported mode: {mode}") mask = make_rect_mask_from_bbox(h, w, list(bbox_xyxy)) if verbose and save_path is not None: inter_dir = Path(save_path) / "intermediates" prefix = f"{filename_suffix}_" if filename_suffix else "" cv2.imwrite(str(inter_dir / f"{prefix}bbox_mask.png"), mask) return insertanything_infer( source_image=cv2.cvtColor(background, cv2.COLOR_BGR2RGB), mask_image=mask, ref_image=cv2.cvtColor(foreground, cv2.COLOR_BGR2RGB), ref_mask=foreground_mask, seeds=seeds, strength=strength, split_ratio=split_ratio, save_path=out_dir, filename_suffix=filename_suffix, net=self._ia_net, return_image=True, )