Skip to content

imagine()

Generates images based on the provided prompts using the ImaginAIry API.

Parameters:

Name Type Description Default
prompts list[ImaginePrompt] | str | ImaginePrompt

A prompt or list of prompts for image generation. Can be a string, a single ImaginePrompt instance, or a list of ImaginePrompt instances.

required
precision str

The precision mode for image generation, defaults to 'autocast'. Note: the function body shown below does not read this parameter.

'autocast'
debug_img_callback Callable

Callback function for debugging images, defaults to None.

None
progress_img_callback Callable

Callback function called at intervals with progress images, defaults to None.

None
progress_img_interval_steps int

Number of steps between each progress image callback, defaults to 3.

3
progress_img_interval_min_s float

Minimum seconds between each progress image callback, defaults to 0.1.

0.1
half_mode

Whether to generate with half precision (torch.float16) instead of torch.float32. If None (the default), half precision is enabled automatically when the device is CUDA or MPS.

None
add_caption bool

Flag to add captions to the generated images, defaults to False.

False
unsafe_retry_count int

Number of retries for generating an image if it is deemed unsafe, defaults to 1.

1

Yields:

Type Description

The generation result for each prompt, yielded one at a time in the order the prompts were provided (after any unsafe-image retries).

Source code in imaginairy/api/generate.py
def imagine(
    prompts: "list[ImaginePrompt] | str | ImaginePrompt",
    precision: str = "autocast",
    debug_img_callback: Callable | None = None,
    progress_img_callback: Callable | None = None,
    progress_img_interval_steps: int = 3,
    progress_img_interval_min_s: float = 0.1,
    half_mode: bool | None = None,
    add_caption: bool = False,
    unsafe_retry_count: int = 1,
):
    """
    Generates images based on the provided prompts using the ImaginAIry API.

    This is a generator: prompts are processed lazily, one at a time, and one
    result is yielded per prompt.

    Args:
        prompts (list[ImaginePrompt] | str | ImaginePrompt): A prompt or list of prompts for image generation.
            Can be a string, a single ImaginePrompt instance, or a list of ImaginePrompt instances.
            Any other iterable of ImaginePrompt also works; if it has no length the
            progress text shows "?" as the total.
        precision (str, optional): The precision mode for image generation, defaults to 'autocast'.
            Not currently read by this function.
        debug_img_callback (Callable, optional): Callback function for debugging images, defaults to None.
        progress_img_callback (Callable, optional): Callback function called at intervals with progress images, defaults to None.
        progress_img_interval_steps (int, optional): Number of steps between each progress image callback, defaults to 3.
        progress_img_interval_min_s (float, optional): Minimum seconds between each progress image callback, defaults to 0.1.
        half_mode (bool, optional): Whether to generate with torch.float16 instead of torch.float32.
            If None (the default), half precision is used when the device is CUDA or MPS.
        add_caption (bool, optional): Flag to add captions to the generated images, defaults to False.
        unsafe_retry_count (int, optional): Number of retries for generating an image if it is deemed unsafe, defaults to 1.
            Negative values are treated as 0, so every prompt gets at least one attempt.

    Yields:
        The result of the last generation attempt for each prompt. If every attempt
        was filtered as unsafe, the final (still filtered) result is yielded.
    """
    import torch.nn

    from imaginairy.api.generate_refiners import generate_single_image
    from imaginairy.schema import ImaginePrompt
    from imaginairy.utils import (
        check_torch_version,
        fix_torch_group_norm,
        fix_torch_nn_layer_norm,
        get_device,
    )

    check_torch_version()

    # Normalize the two single-prompt forms into a one-element list.
    if isinstance(prompts, str):
        prompts = [ImaginePrompt(prompts)]
    elif isinstance(prompts, ImaginePrompt):
        prompts = [prompts]

    # Unsized iterables (e.g. generators) have no len(); show "?" as the total.
    try:
        num_prompts = str(len(prompts))
    except TypeError:
        num_prompts = "?"

    # Resolve the device once instead of once per check.
    device = get_device()

    if device == "cpu":
        logger.warning("Running in CPU mode. It's gonna be slooooooow.")
        from imaginairy.utils.torch_installer import torch_version_check

        torch_version_check()

    if half_mode is None:
        half_mode = "cuda" in device or device == "mps"
    dtype = torch.float16 if half_mode else torch.float32

    # Clamp so that a negative retry count still produces one attempt; without
    # this the attempt loop would be empty and `result` would never be bound.
    max_retries = max(unsafe_retry_count, 0)

    with torch.no_grad(), fix_torch_nn_layer_norm(), fix_torch_group_norm():
        for i, prompt in enumerate(prompts):
            concrete_prompt = prompt.make_concrete_copy()
            prog_text = f"{i + 1}/{num_prompts}"
            logger.info(f"🖼  {prog_text} {concrete_prompt.prompt_description()}")
            for attempt in range(max_retries + 1):
                # On a retry, shift an integer seed so the next attempt differs
                # from the one that was filtered.
                if attempt > 0 and isinstance(concrete_prompt.seed, int):
                    concrete_prompt.seed += 100_000_000 + attempt
                result = generate_single_image(
                    concrete_prompt,
                    debug_img_callback=debug_img_callback,
                    progress_img_callback=progress_img_callback,
                    progress_img_interval_steps=progress_img_interval_steps,
                    progress_img_interval_min_s=progress_img_interval_min_s,
                    add_caption=add_caption,
                    dtype=dtype,
                    output_perf=True,
                )
                if not result.safety_score.is_filtered:
                    break
                # Only announce a retry when another attempt will actually run.
                if attempt < max_retries:
                    logger.info("    Image was unsafe, retrying with new seed...")

            yield result