class ImaginePrompt(BaseModel, protected_namespaces=()):
# Pydantic model bundling the settings for one image request: text prompts,
# optional init/mask/control images, model weights, solver options and
# output options. Most normalization happens in the validators defined on
# this class rather than in the field declarations themselves.
#
# extra="forbid": unknown keyword arguments are rejected.
# validate_assignment=True: assigning to a field re-runs validation.
model_config = ConfigDict(extra="forbid", validate_assignment=True)
# Prompts are stored as lists of WeightedPrompt. The declared default of None
# is converted to [] by the "before" validator make_into_weighted_prompts.
prompt: List[WeightedPrompt] = Field(default=None, validate_default=True) # type: ignore
negative_prompt: List[WeightedPrompt] = Field(
default_factory=list, validate_default=True
)
# Constrained to the inclusive range [-50, 50].
prompt_strength: float = Field(default=7.5, le=50, ge=-50, validate_default=True)
init_image: LazyLoadingImage | None = Field(
None, description="base64 encoded image", validate_default=True
)
# None is replaced by the validate_init_image_strength model validator.
init_image_strength: float | None = Field(
ge=0, le=1, default=None, validate_default=True
)
image_prompt: List[LazyLoadingImage] | None = Field(None, validate_default=True)
# Field default is 0.0, but __init__ substitutes 0.35 for any falsy value
# before the field is populated.
image_prompt_strength: float = Field(ge=0, le=1, default=0.0)
control_inputs: List[ControlInput] = Field(
default_factory=list, validate_default=True
)
# mask_prompt and mask_image are mutually exclusive (see validate_mask_image)
# and both require init_image (see validate_mask_prompt).
mask_prompt: str | None = Field(
default=None,
description="text description of the things to be masked",
validate_default=True,
)
mask_image: LazyLoadingImage | None = Field(default=None, validate_default=True)
mask_mode: MaskMode = MaskMode.REPLACE
mask_modify_original: bool = True
outpaint: str | None = ""
# Populated with a resolved ModelWeightsConfig by the resolve_model_weights
# "before" model validator.
model_weights: config.ModelWeightsConfig = Field( # type: ignore
default=config.DEFAULT_MODEL_WEIGHTS, validate_default=True
)
solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True)
seed: int | None = Field(default=None, validate_default=True)
# steps and size declare no default here; __init__ always forwards a value
# (None when omitted) and the "before" validators validate_steps /
# validate_image_size turn None into a concrete value.
steps: int = Field(validate_default=True)
size: tuple[int, int] = Field(validate_default=True)
upscale: bool = False
fix_faces: bool = False
# Field default is 0.5 while the __init__ keyword default is 0.2.
fix_faces_fidelity: float | None = Field(0.5, ge=0, le=1, validate_default=True)
# NOTE(review): annotated as `str | None`, yet validate_conditioning rejects
# anything that is not a torch.Tensor -- confirm the intended type.
conditioning: str | None = None
# Normalized by validate_tile_mode to one of "", "x", "y", "xy".
tile_mode: str = ""
allow_compose_phase: bool = True
is_intermediate: bool = False
collect_progress_latents: bool = False
caption_text: str = Field(
"", description="text to be overlaid on the image", validate_default=True
)
# No default here; a None value is filled in by the
# set_default_composition_strength "before" model validator.
composition_strength: float = Field(ge=0, le=1, validate_default=True)
inpaint_method: InpaintMethod = "finetune"
def __init__(
    self,
    prompt: PromptInput = "",
    *,
    negative_prompt: PromptInput = None,
    prompt_strength: float | None = 7.5,
    init_image: LazyLoadingImage | None = None,
    init_image_strength: float | None = None,
    image_prompt: LazyLoadingImage | List[LazyLoadingImage] | None = None,
    image_prompt_strength: float | None = 0.35,
    control_inputs: List[ControlInput] | None = None,
    mask_prompt: str | None = None,
    mask_image: LazyLoadingImage | None = None,
    mask_mode: MaskInput = MaskMode.REPLACE,
    mask_modify_original: bool = True,
    outpaint: str | None = "",
    model_weights: str | config.ModelWeightsConfig = config.DEFAULT_MODEL_WEIGHTS,
    solver_type: str = config.DEFAULT_SOLVER,
    seed: int | None = None,
    steps: int | None = None,
    size: int | str | tuple[int, int] | None = None,
    upscale: bool = False,
    fix_faces: bool = False,
    fix_faces_fidelity: float | None = 0.2,
    conditioning: str | None = None,
    tile_mode: str = "",
    allow_compose_phase: bool = True,
    is_intermediate: bool = False,
    collect_progress_latents: bool = False,
    caption_text: str = "",
    composition_strength: float | None = 0.5,
    inpaint_method: InpaintMethod = "finetune",
):
    """Build a prompt from loosely-typed arguments.

    ``prompt`` is the only positional argument; everything else is
    keyword-only. Two values are normalized here before being handed to
    the pydantic initializer; all remaining coercion and defaulting is
    done by the validators on the class.
    """
    # A single (truthy, non-list) image prompt is wrapped into a list.
    if image_prompt and not isinstance(image_prompt, list):
        image_prompt = [image_prompt]

    # NOTE(review): any falsy strength -- None but also 0 / 0.0 -- is
    # replaced by 0.35; confirm that an explicit 0 is not meant to be kept.
    image_prompt_strength = image_prompt_strength or 0.35

    field_values = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "prompt_strength": prompt_strength,
        "init_image": init_image,
        "init_image_strength": init_image_strength,
        "image_prompt": image_prompt,
        "image_prompt_strength": image_prompt_strength,
        "control_inputs": control_inputs,
        "mask_prompt": mask_prompt,
        "mask_image": mask_image,
        "mask_mode": mask_mode,
        "mask_modify_original": mask_modify_original,
        "outpaint": outpaint,
        "model_weights": model_weights,
        "solver_type": solver_type,
        "seed": seed,
        "steps": steps,
        "size": size,
        "upscale": upscale,
        "fix_faces": fix_faces,
        "fix_faces_fidelity": fix_faces_fidelity,
        "conditioning": conditioning,
        "tile_mode": tile_mode,
        "allow_compose_phase": allow_compose_phase,
        "is_intermediate": is_intermediate,
        "collect_progress_latents": collect_progress_latents,
        "caption_text": caption_text,
        "composition_strength": composition_strength,
        "inpaint_method": inpaint_method,
    }
    super().__init__(**field_values)

    self._default_negative_prompt = None
@field_validator("prompt", "negative_prompt", mode="before")
def make_into_weighted_prompts(
cls,
value: PromptInput,
) -> list[WeightedPrompt]:
match value:
case None:
return []
case str():
if value is not None:
return [WeightedPrompt(text=value)]
else:
return []
case WeightedPrompt():
return [value]
case list():
if all(isinstance(item, str) for item in value):
return [WeightedPrompt(text=str(p)) for p in value]
elif all(isinstance(item, WeightedPrompt) for item in value):
return cast(List[WeightedPrompt], value)
raise ValueError("Invalid prompt input")
@field_validator("prompt", "negative_prompt", mode="after")
@classmethod
def must_have_some_weight(cls, v):
    """Reject a non-empty prompt list whose weights sum to zero.

    Empty (falsy) values are returned untouched.
    """
    if not v:
        return v
    if sum(entry.weight for entry in v) == 0:
        raise ValueError("Total weight of prompts cannot be 0")
    return v
@field_validator("prompt", "negative_prompt", mode="after")
def sort_prompts(cls, v):
    """Order a prompt list in place from highest to lowest weight."""
    if not isinstance(v, list):
        return v
    v.sort(key=lambda entry: entry.weight, reverse=True)
    return v
@property
def default_negative_prompt(self):
    """Negative prompt used when none was supplied.

    The ``"negative_prompt"`` entry of the model weights' ``defaults`` is
    used when present; otherwise ``config.DEFAULT_NEGATIVE_PROMPT``.
    """
    fallback = config.DEFAULT_NEGATIVE_PROMPT
    weights = self.model_weights
    if not weights:
        return fallback
    return weights.defaults.get("negative_prompt", fallback)
@model_validator(mode="after")
def validate_negative_prompt(self):
    """Fill an empty negative prompt with the default negative prompt."""
    if self.negative_prompt == []:
        fallback_text = self.default_negative_prompt
        self.negative_prompt = [WeightedPrompt(text=fallback_text)]
    return self
@field_validator("prompt_strength", mode="before")
def validate_prompt_strength(cls, v):
    """Substitute 7.5 when no prompt strength was given."""
    if v is None:
        return 7.5
    return v
@field_validator("tile_mode", mode="before")
def validate_tile_mode(cls, v):
    """Normalize ``tile_mode`` to one of ``""``, ``"x"``, ``"y"``, ``"xy"``.

    ``True`` maps to ``"xy"``; ``False`` and ``None`` map to ``""``.
    Strings are lower-cased before being checked.

    Raises:
        ValueError: for non-string values or unknown mode strings.
    """
    valid_tile_modes = ("", "x", "y", "xy")

    if v is True:
        return "xy"
    if v is None or v is False:
        return ""

    is_text = isinstance(v, str)
    # Non-strings are reported as given; strings are reported lower-cased.
    candidate = v.lower() if is_text else v
    if not is_text or candidate not in valid_tile_modes:
        msg = f"Invalid tile_mode: '{candidate}'. Valid modes are: {valid_tile_modes}"
        raise ValueError(msg)  # noqa
    return candidate
@field_validator("outpaint", mode="after")
def validate_outpaint(cls, v):
    """Run the outpaint argument parser over ``v`` and return ``v`` unchanged.

    The parser's result is discarded; it is called so that whatever it
    raises for a bad value surfaces during validation.
    """
    from imaginairy.utils.outpaint import (
        outpaint_arg_str_parse as parse_outpaint_arg,
    )

    parse_outpaint_arg(v)
    return v
@field_validator("conditioning", mode="after")
def validate_conditioning(cls, v):
    """Require ``conditioning`` to be ``None`` or a ``torch.Tensor``.

    Raises:
        ValueError: if a non-None value is not a ``torch.Tensor``.
    """
    if v is None:
        return v

    # Imported only once a value is actually present, so the common
    # ``None`` case neither pays for nor requires a torch import.
    from torch import Tensor

    # NOTE(review): the field is annotated `str | None` on the class, which
    # looks at odds with this Tensor check -- confirm the intended type.
    if not isinstance(v, Tensor):
        raise ValueError("conditioning must be a torch.Tensor")  # noqa
    return v
@model_validator(mode="before")
@classmethod
def set_default_composition_strength(cls, data: Any) -> Any:
    """Fill in ``composition_strength`` when the input dict lacks a value.

    Non-dict input, or a dict that already carries a non-None value, is
    returned untouched. Otherwise the value comes from the
    ``"composition_strength"`` entry of the model weights' ``defaults``
    when ``model_weights`` is a ``ModelWeightsConfig``, else ``0.5``.
    The dict is updated in place.
    """
    if not isinstance(data, dict) or data.get("composition_strength") is not None:
        return data

    fallback = 0.5
    weights = data.get("model_weights")
    if isinstance(weights, config.ModelWeightsConfig):
        fallback = weights.defaults.get("composition_strength", fallback)
    data["composition_strength"] = fallback
    return data
# @field_validator("init_image", "mask_image", mode="after")
# def handle_images(cls, v):
# if isinstance(v, str):
# return LazyLoadingImage(filepath=v)
#
# return v
@model_validator(mode="after")
def set_init_from_control_inputs(self):
    """Borrow the first truthy control-input image as ``init_image``.

    Does nothing when ``init_image`` is already set or when no control
    input carries an image.
    """
    if self.init_image is not None:
        return self

    first_image = next(
        (ci.image for ci in self.control_inputs if ci.image),
        None,
    )
    if first_image is not None:
        self.init_image = first_image
    return self
@field_validator("control_inputs", mode="before")
def validate_control_inputs(cls, v):
    """Treat a missing (``None``) control-input list as an empty list."""
    return [] if v is None else v
@field_validator("control_inputs", mode="after")
def set_image_from_init_image(cls, v, info: core_schema.FieldValidationInfo):
    """Give image-less control inputs the prompt's ``init_image``.

    A control input that has neither ``image`` nor ``image_raw`` set gets
    its ``image`` assigned from the already-validated ``init_image``.
    """
    v = v or []
    # ``info.data`` only holds fields that validated successfully. Using
    # .get() avoids a raw KeyError here when ``init_image`` itself was
    # rejected; its own validation error is still reported.
    init_image = info.data.get("init_image")
    for control_input in v:
        if control_input.image is None and control_input.image_raw is None:
            control_input.image = init_image
    return v
@field_validator("mask_image")
def validate_mask_image(cls, v, info: core_schema.FieldValidationInfo):
    """Reject a mask image when a mask prompt has also been supplied.

    Raises:
        ValueError: if both ``mask_image`` and ``mask_prompt`` are set.
    """
    if v is None:
        return v
    if info.data.get("mask_prompt") is not None:
        msg = "You can only set one of `mask_image` and `mask_prompt`"
        raise ValueError(msg)
    return v
@field_validator("mask_prompt", "mask_image", mode="before")
def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo):
    """Require ``init_image`` whenever a mask prompt or mask image is given.

    Raises:
        ValueError: if the value is truthy but no ``init_image`` is set.
    """
    init_image_missing = info.data.get("init_image") is None
    if init_image_missing and v:
        msg = "You must set `init_image` if you want to use a mask"
        raise ValueError(msg)
    return v
@model_validator(mode="before")
def resolve_model_weights(cls, data: Any):
    """Replace ``data["model_weights"]`` with a resolved weights config.

    Non-dict input is returned untouched. A missing/None ``model_weights``
    falls back to ``config.DEFAULT_MODEL_WEIGHTS``. Inpainting weights are
    requested when a mask image, mask prompt or outpaint value is present
    and ``inpaint_method`` equals ``"finetune"``. The dict is updated in
    place.
    """
    if not isinstance(data, dict):
        return data

    from imaginairy.utils.model_manager import resolve_model_weights_config

    requested = data.get("model_weights")
    if requested is None:
        requested = config.DEFAULT_MODEL_WEIGHTS

    wants_inpainting = any(
        data.get(key) for key in ("mask_image", "mask_prompt", "outpaint")
    )
    wants_inpainting_weights = (
        wants_inpainting and data.get("inpaint_method") == "finetune"
    )

    data["model_weights"] = resolve_model_weights_config(
        model_weights=requested,
        default_model_architecture=None,
        for_inpainting=wants_inpainting_weights,
    )
    return data
@field_validator("seed")
def validate_seed(cls, v):
# Pass-through validator: the seed is returned exactly as received, so no
# extra checks are applied here beyond the field's own type validation.
return v
@field_validator("fix_faces_fidelity", mode="before")
def validate_fix_faces_fidelity(cls, v):
    """Substitute 0.5 when no face-fix fidelity was given."""
    return 0.5 if v is None else v
@field_validator("solver_type", mode="after")
def validate_solver_type(cls, v, info: core_schema.FieldValidationInfo):
    """Lower-case the solver name, defaulting to ``config.DEFAULT_SOLVER``.

    Raises:
        ValueError: if the ``"model"`` entry of the validated data is
            ``"edit"`` and the solver is PLMS or DDIM.
    """
    from imaginairy.samplers import SolverName

    solver = (config.DEFAULT_SOLVER if v is None else v).lower()

    # NOTE(review): this class declares no field named "model", so this
    # lookup appears to always yield None and the check never fires --
    # confirm which field was meant.
    is_edit_model = info.data.get("model") == "edit"
    if is_edit_model and solver in (SolverName.PLMS, SolverName.DDIM):
        msg = "PLMS and DDIM solvers are not supported for pix2pix edit model."
        raise ValueError(msg)
    return solver
@field_validator("steps", mode="before")
def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
    """Resolve the step count, falling back through several defaults.

    When ``v`` is None the lookup order is:
      1. ``defaults["steps"]`` of the model weights,
      2. ``defaults["steps"]`` of the weights' architecture,
      3. a per-solver default (ddim: 50, dpmpp: 20, anything else: 50).
    The result is converted with ``int()``.

    Raises:
        ValueError: if the value cannot be converted to an integer.
    """
    weights = info.data.get("model_weights")

    if v is None:
        if weights:
            # First choice: defaults attached to the weights themselves.
            if isinstance(weights, config.ModelWeightsConfig):
                v = weights.defaults.get("steps")
            # Second choice: defaults attached to the architecture.
            if v is None and weights.architecture:
                v = weights.architecture.defaults.get("steps")
        if v is None:
            # Last resort: solver-specific default, 50 for unknown solvers.
            solver_name = info.data.get("solver_type", "ddim").lower()
            v = {"ddim": 50, "dpmpp": 20}.get(solver_name, 50)

    try:
        return int(v)
    except (OverflowError, TypeError) as e:
        raise ValueError("Steps must be an integer") from e
@model_validator(mode="after")
def validate_init_image_strength(self):
    """Choose a default ``init_image_strength`` when none was given.

    The default is 0.0 when control inputs, outpainting, a mask image or a
    mask prompt are in use, and 0.2 otherwise.
    """
    if self.init_image_strength is not None:
        return self

    uses_guidance = (
        self.control_inputs or self.outpaint or self.mask_image or self.mask_prompt
    )
    self.init_image_strength = 0.0 if uses_guidance else 0.2
    return self
@field_validator("size", mode="before")
def validate_image_size(cls, v, info: core_schema.FieldValidationInfo):
    """Turn the requested size into a ``(width, height)`` pair.

    A ``None`` size is replaced by the default image size for the
    architecture of the already-validated model weights; the value is then
    passed through ``normalize_image_size``.
    """
    from imaginairy.utils.model_manager import get_model_default_image_size
    from imaginairy.utils.named_resolutions import normalize_image_size

    requested = v
    if requested is None:
        architecture = info.data["model_weights"].architecture
        requested = get_model_default_image_size(architecture)

    # Unpacking enforces that exactly two values come back.
    width, height = normalize_image_size(requested)
    return (width, height)
@field_validator("size", mode="after")
def validate_image_size_after(cls, v, info: core_schema.FieldValidationInfo):
    """Check that both dimensions lie within 8..100000 inclusive.

    Width is checked before height.

    Raises:
        ValueError: if either dimension is out of range.
    """
    width, height = v
    min_size, max_size = 8, 100_000

    for label, extent in (("Width", width), ("Height", height)):
        if not min_size <= extent <= max_size:
            msg = f"{label} must be between {min_size} and {max_size}. Got: {extent}"
            raise ValueError(msg)
    return v
@field_validator("caption_text", mode="before")
def validate_caption_text(cls, v):
    """Treat a missing (``None``) caption as the empty string."""
    return "" if v is None else v
@property
def prompts(self):
    """Alias for the ``prompt`` field (the list of weighted prompts)."""
    weighted_prompts = self.prompt
    return weighted_prompts
@property
def prompt_text(self) -> str:
    """Prompt rendered as a single string.

    Empty prompt -> ``""``; a single prompt -> its ``text``; several
    prompts -> their ``str()`` forms joined with ``"|"``.
    """
    entries = self.prompt
    if not entries:
        return ""
    if len(entries) == 1:
        return entries[0].text
    return "|".join(str(entry) for entry in entries)
@property
def negative_prompt_text(self) -> str:
    """Negative prompt rendered as a single string.

    Empty prompt -> ``""``; a single prompt -> its ``text``; several
    prompts -> their ``str()`` forms joined with ``"|"``.
    """
    entries = self.negative_prompt
    if not entries:
        return ""
    if len(entries) == 1:
        return entries[0].text
    return "|".join(str(entry) for entry in entries)
@property
def width(self) -> int:
    """First element of ``size``."""
    first = self.size[0]
    return first
@property
def height(self) -> int:
    """Second element of ``size``."""
    second = self.size[1]
    return second
@property
def aspect_ratio(self) -> str:
    """Aspect ratio string computed from ``width`` and ``height``."""
    from imaginairy.utils.img_utils import aspect_ratio as compute_aspect_ratio

    return compute_aspect_ratio(width=self.width, height=self.height)
@property
def should_use_inpainting(self) -> bool:
    """True when outpainting, a mask image or a mask prompt is set."""
    if self.outpaint or self.mask_image or self.mask_prompt:
        return True
    return False
@property
def should_use_inpainting_weights(self) -> bool:
    """True when inpainting is in use and the method is ``"finetune"``."""
    inpainting_active = self.should_use_inpainting
    return inpainting_active and self.inpaint_method == "finetune"
@property
def model_architecture(self) -> config.ModelArchitecture:
    """Architecture attached to the resolved model weights."""
    weights = self.model_weights
    return weights.architecture
def prompt_description(self):
    """Return a multi-line, human-readable summary of this prompt.

    The prompt text is coloured green via ``termcolor``. A negative prompt
    equal to the default one is shown as ``DEFAULT-NEGATIVE-PROMPT``
    instead of being spelled out.
    """
    from termcolor import colored

    uses_default_negative = (
        self.negative_prompt_text == self.default_negative_prompt
    )
    neg_prompt = (
        "DEFAULT-NEGATIVE-PROMPT"
        if uses_default_negative
        else f'"{self.negative_prompt_text}"'
    )
    prompt_text = colored(self.prompt_text, "green")

    return (
        f'"{prompt_text}"\n'
        " "
        f"negative-prompt:{neg_prompt}\n"
        " "
        f"size:{self.width}x{self.height}px-({self.aspect_ratio}) "
        f"seed:{self.seed} "
        f"prompt-strength:{self.prompt_strength} "
        f"steps:{self.steps} solver-type:{self.solver_type} "
        f"init-image-strength:{self.init_image_strength} "
        f"arch:{self.model_architecture.aliases[0]} "
        f"weights:{self.model_weights.aliases[0]}"
    )
def logging_dict(self):
    """Return a dict of the object but with binary data replaced with reprs.

    ``init_image``, ``mask_image`` and ``image_prompt`` are always replaced
    by their ``repr``; ``control_inputs`` only when the list is non-empty.
    """
    data = self.model_dump()
    data.update(
        init_image=repr(self.init_image),
        mask_image=repr(self.mask_image),
        image_prompt=repr(self.image_prompt),
    )
    if self.control_inputs:
        data["control_inputs"] = [repr(ci) for ci in self.control_inputs]
    return data
def full_copy(self, deep=True, update=None):
    """Copy this prompt, apply ``update``, and re-run validation on the copy.

    Args:
        deep: forwarded to ``model_copy``.
        update: optional mapping of field overrides, forwarded to
            ``model_copy``.
    """
    copied = self.model_copy(deep=deep, update=update)
    # Validating the instance directly (self.model_validate(copied)) doesn't
    # work for some reason, so its dict form is validated instead:
    # https://github.com/pydantic/pydantic/issues/7387
    return copied.model_validate(dict(copied))
def make_concrete_copy(self) -> Self:
    """Return a shallow, re-validated copy whose seed is always set.

    An existing seed is kept; a missing one is drawn with
    ``random.randint(1, 1_000_000_000)``.
    """
    seed = self.seed
    if seed is None:
        seed = random.randint(1, 1_000_000_000)
    return self.full_copy(deep=False, update={"seed": seed})