Skip to content

ImaginePrompt

Bases: BaseModel

Source code in imaginairy/schema.py
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
class ImaginePrompt(BaseModel, protected_namespaces=()):
    model_config = ConfigDict(extra="forbid", validate_assignment=True)

    prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)  # type: ignore
    negative_prompt: List[WeightedPrompt] = Field(
        default_factory=list, validate_default=True
    )
    prompt_strength: float = Field(default=7.5, le=50, ge=-50, validate_default=True)
    init_image: LazyLoadingImage | None = Field(
        None, description="base64 encoded image", validate_default=True
    )
    init_image_strength: float | None = Field(
        ge=0, le=1, default=None, validate_default=True
    )
    image_prompt: List[LazyLoadingImage] | None = Field(None, validate_default=True)
    image_prompt_strength: float = Field(ge=0, le=1, default=0.0)
    control_inputs: List[ControlInput] = Field(
        default_factory=list, validate_default=True
    )
    mask_prompt: str | None = Field(
        default=None,
        description="text description of the things to be masked",
        validate_default=True,
    )
    mask_image: LazyLoadingImage | None = Field(default=None, validate_default=True)
    mask_mode: MaskMode = MaskMode.REPLACE
    mask_modify_original: bool = True
    outpaint: str | None = ""
    model_weights: config.ModelWeightsConfig = Field(  # type: ignore
        default=config.DEFAULT_MODEL_WEIGHTS, validate_default=True
    )
    solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True)
    seed: int | None = Field(default=None, validate_default=True)
    steps: int = Field(validate_default=True)
    size: tuple[int, int] = Field(validate_default=True)
    upscale: bool = False
    fix_faces: bool = False
    fix_faces_fidelity: float | None = Field(0.5, ge=0, le=1, validate_default=True)
    conditioning: str | None = None
    tile_mode: str = ""
    allow_compose_phase: bool = True
    is_intermediate: bool = False
    collect_progress_latents: bool = False
    caption_text: str = Field(
        "", description="text to be overlaid on the image", validate_default=True
    )
    composition_strength: float = Field(ge=0, le=1, validate_default=True)
    inpaint_method: InpaintMethod = "finetune"

    def __init__(
        self,
        prompt: PromptInput = "",
        *,
        negative_prompt: PromptInput = None,
        prompt_strength: float | None = 7.5,
        init_image: LazyLoadingImage | None = None,
        init_image_strength: float | None = None,
        image_prompt: LazyLoadingImage | List[LazyLoadingImage] | None = None,
        image_prompt_strength: float | None = 0.35,
        control_inputs: List[ControlInput] | None = None,
        mask_prompt: str | None = None,
        mask_image: LazyLoadingImage | None = None,
        mask_mode: MaskInput = MaskMode.REPLACE,
        mask_modify_original: bool = True,
        outpaint: str | None = "",
        model_weights: str | config.ModelWeightsConfig = config.DEFAULT_MODEL_WEIGHTS,
        solver_type: str = config.DEFAULT_SOLVER,
        seed: int | None = None,
        steps: int | None = None,
        size: int | str | tuple[int, int] | None = None,
        upscale: bool = False,
        fix_faces: bool = False,
        fix_faces_fidelity: float | None = 0.2,
        conditioning: str | None = None,
        tile_mode: str = "",
        allow_compose_phase: bool = True,
        is_intermediate: bool = False,
        collect_progress_latents: bool = False,
        caption_text: str = "",
        composition_strength: float | None = 0.5,
        inpaint_method: InpaintMethod = "finetune",
    ):
        if image_prompt and not isinstance(image_prompt, list):
            image_prompt = [image_prompt]

        if not image_prompt_strength:
            image_prompt_strength = 0.35

        super().__init__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_strength=prompt_strength,
            init_image=init_image,
            init_image_strength=init_image_strength,
            image_prompt=image_prompt,
            image_prompt_strength=image_prompt_strength,
            control_inputs=control_inputs,
            mask_prompt=mask_prompt,
            mask_image=mask_image,
            mask_mode=mask_mode,
            mask_modify_original=mask_modify_original,
            outpaint=outpaint,
            model_weights=model_weights,
            solver_type=solver_type,
            seed=seed,
            steps=steps,
            size=size,
            upscale=upscale,
            fix_faces=fix_faces,
            fix_faces_fidelity=fix_faces_fidelity,
            conditioning=conditioning,
            tile_mode=tile_mode,
            allow_compose_phase=allow_compose_phase,
            is_intermediate=is_intermediate,
            collect_progress_latents=collect_progress_latents,
            caption_text=caption_text,
            composition_strength=composition_strength,
            inpaint_method=inpaint_method,
        )
        self._default_negative_prompt = None

    @field_validator("prompt", "negative_prompt", mode="before")
    def make_into_weighted_prompts(
        cls,
        value: PromptInput,
    ) -> list[WeightedPrompt]:
        match value:
            case None:
                return []

            case str():
                if value is not None:
                    return [WeightedPrompt(text=value)]
                else:
                    return []
            case WeightedPrompt():
                return [value]
            case list():
                if all(isinstance(item, str) for item in value):
                    return [WeightedPrompt(text=str(p)) for p in value]
                elif all(isinstance(item, WeightedPrompt) for item in value):
                    return cast(List[WeightedPrompt], value)
        raise ValueError("Invalid prompt input")

    @field_validator("prompt", "negative_prompt", mode="after")
    @classmethod
    def must_have_some_weight(cls, v):
        if v:
            total_weight = sum(p.weight for p in v)
            if total_weight == 0:
                raise ValueError("Total weight of prompts cannot be 0")
        return v

    @field_validator("prompt", "negative_prompt", mode="after")
    def sort_prompts(cls, v):
        if isinstance(v, list):
            v.sort(key=lambda p: p.weight, reverse=True)
        return v

    @property
    def default_negative_prompt(self):
        default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
        if self.model_weights:
            default_negative_prompt = self.model_weights.defaults.get(
                "negative_prompt", default_negative_prompt
            )
        return default_negative_prompt

    @model_validator(mode="after")
    def validate_negative_prompt(self):
        if self.negative_prompt == []:
            self.negative_prompt = [WeightedPrompt(text=self.default_negative_prompt)]

        return self

    @field_validator("prompt_strength", mode="before")
    def validate_prompt_strength(cls, v):
        return 7.5 if v is None else v

    @field_validator("tile_mode", mode="before")
    def validate_tile_mode(cls, v):
        valid_tile_modes = ("", "x", "y", "xy")
        if v is True:
            return "xy"

        if v is False or v is None:
            return ""

        if not isinstance(v, str):
            msg = f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
            raise ValueError(msg)  # noqa

        v = v.lower()
        if v not in valid_tile_modes:
            msg = f"Invalid tile_mode: '{v}'. Valid modes are: {valid_tile_modes}"
            raise ValueError(msg)
        return v

    @field_validator("outpaint", mode="after")
    def validate_outpaint(cls, v):
        from imaginairy.utils.outpaint import outpaint_arg_str_parse

        outpaint_arg_str_parse(v)
        return v

    @field_validator("conditioning", mode="after")
    def validate_conditioning(cls, v):
        from torch import Tensor

        if v is None:
            return v

        if not isinstance(v, Tensor):
            raise ValueError("conditioning must be a torch.Tensor")  # noqa
        return v

    @model_validator(mode="before")
    @classmethod
    def set_default_composition_strength(cls, data: Any) -> Any:
        if not isinstance(data, dict):
            return data
        comp_strength = data.get("composition_strength")
        default_comp_strength = 0.5
        if comp_strength is None:
            model_weights = data.get("model_weights")
            if isinstance(model_weights, config.ModelWeightsConfig):
                default_comp_strength = model_weights.defaults.get(
                    "composition_strength", default_comp_strength
                )
            data["composition_strength"] = default_comp_strength

        return data

    # @field_validator("init_image", "mask_image", mode="after")
    # def handle_images(cls, v):
    #     if isinstance(v, str):
    #         return LazyLoadingImage(filepath=v)
    #
    #     return v

    @model_validator(mode="after")
    def set_init_from_control_inputs(self):
        if self.init_image is None:
            for control_input in self.control_inputs:
                if control_input.image:
                    self.init_image = control_input.image
                    break

        return self

    @field_validator("control_inputs", mode="before")
    def validate_control_inputs(cls, v):
        if v is None:
            v = []
        return v

    @field_validator("control_inputs", mode="after")
    def set_image_from_init_image(cls, v, info: core_schema.FieldValidationInfo):
        v = v or []
        for control_input in v:
            if control_input.image is None and control_input.image_raw is None:
                control_input.image = info.data["init_image"]
        return v

    @field_validator("mask_image")
    def validate_mask_image(cls, v, info: core_schema.FieldValidationInfo):
        if v is not None and info.data.get("mask_prompt") is not None:
            msg = "You can only set one of `mask_image` and `mask_prompt`"
            raise ValueError(msg)
        return v

    @field_validator("mask_prompt", "mask_image", mode="before")
    def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo):
        if info.data.get("init_image") is None and v:
            msg = "You must set `init_image` if you want to use a mask"
            raise ValueError(msg)
        return v

    @model_validator(mode="before")
    def resolve_model_weights(cls, data: Any):
        if not isinstance(data, dict):
            return data

        model_weights = data.get("model_weights")
        if model_weights is None:
            model_weights = config.DEFAULT_MODEL_WEIGHTS
        from imaginairy.utils.model_manager import resolve_model_weights_config

        should_use_inpainting = bool(
            data.get("mask_image") or data.get("mask_prompt") or data.get("outpaint")
        )
        should_use_inpainting_weights = (
            should_use_inpainting and data.get("inpaint_method") == "finetune"
        )
        model_weights_config = resolve_model_weights_config(
            model_weights=model_weights,
            default_model_architecture=None,
            for_inpainting=should_use_inpainting_weights,
        )
        data["model_weights"] = model_weights_config

        return data

    @field_validator("seed")
    def validate_seed(cls, v):
        return v

    @field_validator("fix_faces_fidelity", mode="before")
    def validate_fix_faces_fidelity(cls, v):
        if v is None:
            return 0.5

        return v

    @field_validator("solver_type", mode="after")
    def validate_solver_type(cls, v, info: core_schema.FieldValidationInfo):
        from imaginairy.samplers import SolverName

        if v is None:
            v = config.DEFAULT_SOLVER

        v = v.lower()

        if info.data.get("model") == "edit" and v in (
            SolverName.PLMS,
            SolverName.DDIM,
        ):
            msg = "PLMS and DDIM solvers are not supported for pix2pix edit model."
            raise ValueError(msg)
        return v

    @field_validator("steps", mode="before")
    def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
        model_weights = info.data.get("model_weights")

        # Try to get steps from model weights defaults
        if (
            v is None
            and model_weights
            and isinstance(model_weights, config.ModelWeightsConfig)
        ):
            v = model_weights.defaults.get("steps")

        # If not found in model weights, try model architecture defaults
        if v is None and model_weights and model_weights.architecture:
            v = model_weights.architecture.defaults.get("steps")

        # If still not found, use solver-specific defaults
        if v is None:
            solver_type = info.data.get("solver_type", "ddim").lower()
            steps_lookup = {"ddim": 50, "dpmpp": 20}
            v = steps_lookup.get(
                solver_type, 50
            )  # Default to 50 if solver not recognized

        try:
            return int(v)
        except (OverflowError, TypeError) as e:
            raise ValueError("Steps must be an integer") from e

    @model_validator(mode="after")
    def validate_init_image_strength(self):
        if self.init_image_strength is None:
            if self.control_inputs:
                self.init_image_strength = 0.0
            elif self.outpaint or self.mask_image or self.mask_prompt:
                self.init_image_strength = 0.0
            else:
                self.init_image_strength = 0.2

        return self

    @field_validator("size", mode="before")
    def validate_image_size(cls, v, info: core_schema.FieldValidationInfo):
        from imaginairy.utils.model_manager import get_model_default_image_size
        from imaginairy.utils.named_resolutions import normalize_image_size

        if v is None:
            v = get_model_default_image_size(info.data["model_weights"].architecture)

        width, height = normalize_image_size(v)

        return width, height

    @field_validator("size", mode="after")
    def validate_image_size_after(cls, v, info: core_schema.FieldValidationInfo):
        width, height = v
        min_size = 8
        max_size = 100_000
        if not min_size <= width <= max_size:
            msg = f"Width must be between {min_size} and {max_size}. Got: {width}"
            raise ValueError(msg)

        if not min_size <= height <= max_size:
            msg = f"Height must be between {min_size} and {max_size}. Got: {height}"
            raise ValueError(msg)
        return v

    @field_validator("caption_text", mode="before")
    def validate_caption_text(cls, v):
        if v is None:
            v = ""

        return v

    @property
    def prompts(self):
        return self.prompt

    @property
    def prompt_text(self) -> str:
        if not self.prompt:
            return ""
        if len(self.prompt) == 1:
            return self.prompt[0].text
        return "|".join(str(p) for p in self.prompt)

    @property
    def negative_prompt_text(self) -> str:
        if not self.negative_prompt:
            return ""
        if len(self.negative_prompt) == 1:
            return self.negative_prompt[0].text
        return "|".join(str(p) for p in self.negative_prompt)

    @property
    def width(self) -> int:
        return self.size[0]

    @property
    def height(self) -> int:
        return self.size[1]

    @property
    def aspect_ratio(self) -> str:
        from imaginairy.utils.img_utils import aspect_ratio

        return aspect_ratio(width=self.width, height=self.height)

    @property
    def should_use_inpainting(self) -> bool:
        return bool(self.outpaint or self.mask_image or self.mask_prompt)

    @property
    def should_use_inpainting_weights(self) -> bool:
        return self.should_use_inpainting and self.inpaint_method == "finetune"

    @property
    def model_architecture(self) -> config.ModelArchitecture:
        return self.model_weights.architecture

    def prompt_description(self):
        if self.negative_prompt_text == self.default_negative_prompt:
            neg_prompt = "DEFAULT-NEGATIVE-PROMPT"
        else:
            neg_prompt = f'"{self.negative_prompt_text}"'

        from termcolor import colored

        prompt_text = colored(self.prompt_text, "green")

        return (
            f'"{prompt_text}"\n'
            "    "
            f"negative-prompt:{neg_prompt}\n"
            "    "
            f"size:{self.width}x{self.height}px-({self.aspect_ratio}) "
            f"seed:{self.seed} "
            f"prompt-strength:{self.prompt_strength} "
            f"steps:{self.steps} solver-type:{self.solver_type} "
            f"init-image-strength:{self.init_image_strength} "
            f"arch:{self.model_architecture.aliases[0]} "
            f"weights:{self.model_weights.aliases[0]}"
        )

    def logging_dict(self):
        """Return a dict of the object but with binary data replaced with reprs."""
        data = self.model_dump()
        data["init_image"] = repr(self.init_image)
        data["mask_image"] = repr(self.mask_image)
        data["image_prompt"] = repr(self.image_prompt)
        if self.control_inputs:
            data["control_inputs"] = [repr(ci) for ci in self.control_inputs]
        return data

    def full_copy(self, deep=True, update=None):
        new_prompt = self.model_copy(
            deep=deep,
            update=update,
        )
        # new_prompt = self.model_validate(new_prompt) doesn't work for some reason https://github.com/pydantic/pydantic/issues/7387
        new_prompt = new_prompt.model_validate(dict(new_prompt))
        return new_prompt

    def make_concrete_copy(self) -> Self:
        seed = self.seed if self.seed is not None else random.randint(1, 1_000_000_000)
        return self.full_copy(
            deep=False,
            update={
                "seed": seed,
            },
        )

logging_dict()

Return a dict of the object but with binary data replaced with reprs.

Source code in imaginairy/schema.py
def logging_dict(self):
    """Dump the model to a dict, substituting reprs for the image-bearing fields."""
    dumped = self.model_dump()
    # Images are swapped for their repr so the dict stays small and printable.
    for field_name in ("init_image", "mask_image", "image_prompt"):
        dumped[field_name] = repr(getattr(self, field_name))
    # Only overwrite the dumped control inputs when there are any.
    if self.control_inputs:
        dumped["control_inputs"] = list(map(repr, self.control_inputs))
    return dumped