StableDiffusionBaseModel
- class StableDiffusionBaseModel(model_name: str, cfg_scale: float = 0.0, min_cfg_scale: float | None = None, max_cfg_scale: float | None = None, prompts: list[str] | None = None, device: device | None = None)
Stable Diffusion base model.
- Parameters:
model_name (str) – Name of the Stable Diffusion diffusers model to use, e.g., “sd-legacy/stable-diffusion-v1-5”.
cfg_scale (float) – Classifier-free guidance scale to use during sampling. If 0.0, no CFG is used.
min_cfg_scale (float, optional) – Minimum CFG scale to use during sampling. If provided, it is used together with max_cfg_scale to draw a random CFG scale for each sample.
max_cfg_scale (float, optional) – Maximum CFG scale to use during sampling. If provided, it is used together with min_cfg_scale to draw a random CFG scale for each sample.
prompts (list[str], optional) – List of prompts to use for sampling if no prompt is provided through the input. If None, a default set of prompts is used.
device (torch.device, optional) – Device to load the model on. If None, the default device is used.
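The interaction between cfg_scale, min_cfg_scale, and max_cfg_scale can be illustrated with a minimal sketch. The helper name choose_cfg_scale, the uniform distribution, and the fallback to cfg_scale when only one bound is given are all assumptions made for illustration; the parameter descriptions above only state that a random scale is drawn per sample when both bounds are provided.

```python
import random


def choose_cfg_scale(cfg_scale=0.0, min_cfg_scale=None, max_cfg_scale=None, rng=None):
    """Pick the guidance scale for one sample (hypothetical helper, not the library API)."""
    rng = rng if rng is not None else random
    if min_cfg_scale is not None and max_cfg_scale is not None:
        if min_cfg_scale > max_cfg_scale:
            raise ValueError("min_cfg_scale must not exceed max_cfg_scale")
        # Assumption: the scale is drawn uniformly from [min_cfg_scale, max_cfg_scale].
        return rng.uniform(min_cfg_scale, max_cfg_scale)
    # Assumption: without both bounds, the fixed cfg_scale applies (0.0 means no CFG).
    return cfg_scale


# One scale per sample in a batch of four.
rng = random.Random(0)
scales = [choose_cfg_scale(min_cfg_scale=1.0, max_cfg_scale=7.5, rng=rng) for _ in range(4)]
```

Passing a seeded random.Random instance keeps the per-sample scales reproducible across runs.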
- forward(x: DDTensor, t: Tensor, **kwargs: Any) → DDTensor
Forward pass of the model, outputting \(\epsilon(x_t, t)\).
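How a guidance scale could enter the predicted noise can be sketched as follows. This page does not state which formula the class uses; the sketch assumes one common parameterization, \(\tilde{\epsilon} = (1 + w)\,\epsilon_{\text{cond}} - w\,\epsilon_{\text{uncond}}\), chosen because it reduces to the plain conditional prediction at \(w = 0\), consistent with "If 0.0, no CFG is used". The function name apply_cfg and the use of flat Python lists in place of tensors are illustrative only.

```python
def apply_cfg(eps_cond, eps_uncond, scale):
    """Combine conditional and unconditional noise predictions (illustrative sketch).

    Assumption: eps = (1 + scale) * eps_cond - scale * eps_uncond,
    so scale == 0.0 returns the conditional prediction unchanged.
    """
    if scale == 0.0:
        # No guidance: the unconditional prediction is not needed at all.
        return list(eps_cond)
    return [(1.0 + scale) * c - scale * u for c, u in zip(eps_cond, eps_uncond)]


# Toy values standing in for flattened noise predictions.
guided = apply_cfg([1.0, 2.0], [0.5, 1.0], scale=1.0)
```

Under this parameterization, larger scales push the output further from the unconditional prediction along the direction of the conditional one.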
- preprocess(x: DDTensor, **kwargs: Any) → tuple[DDTensor, dict[str, Any]]
Encode the prompt, if a prompt is provided in place of encoder_hidden_states.
- Parameters:
x (DDTensor, shape (n, 4, 64, 64)) – Input data to preprocess.
**kwargs (dict) – Additional keyword arguments to preprocess.
- Returns:
output (DDTensor) – Preprocessed data.
kwargs (dict) – Preprocessed keyword arguments.
- sample_p0(n: int, **kwargs: Any) → tuple[DDTensor, dict[str, Any]]
Sample n latent datapoints from the base distribution \(p_0\).
- Parameters:
n (int) – Number of samples to draw.
**kwargs (dict) – Additional keyword arguments.
- Returns:
samples (DDTensor, shape (n, 4, 64, 64)) – Samples from the base distribution \(p_0\).
kwargs (dict) – Additional keyword arguments, including a randomly selected prompt if none was provided through the input.
Notes
The base distribution \(p_0\) is a standard Gaussian distribution.
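The behaviour described for sample_p0 can be approximated in a few lines of standard-library Python. This is a sketch under stated assumptions, not the library's implementation: the function name sample_p0_sketch, the "prompt" keyword key, and the choice of a single prompt per call (rather than one per sample) are hypothetical, and each latent is stored as a flat list of 4 × 64 × 64 values where a real implementation would presumably return a tensor of shape (n, 4, 64, 64).

```python
import random

LATENT_SHAPE = (4, 64, 64)  # per-sample latent shape taken from the Returns entry above


def sample_p0_sketch(n, prompts, seed=None, **kwargs):
    """Draw n standard-Gaussian latents and fill in a prompt if none was given."""
    rng = random.Random(seed)
    size = LATENT_SHAPE[0] * LATENT_SHAPE[1] * LATENT_SHAPE[2]
    # Each latent is a flat list of i.i.d. N(0, 1) draws.
    samples = [[rng.gauss(0.0, 1.0) for _ in range(size)] for _ in range(n)]
    out_kwargs = dict(kwargs)
    if out_kwargs.get("prompt") is None:
        # Assumption: one prompt is picked at random for the whole call.
        out_kwargs["prompt"] = rng.choice(prompts)
    return samples, out_kwargs


default_prompts = ["a photo of a cat", "a watercolor landscape"]
latents, extra = sample_p0_sketch(2, default_prompts, seed=0)
```

A prompt passed by the caller, for example sample_p0_sketch(1, default_prompts, prompt="a red bicycle"), is returned unchanged in the keyword arguments.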
- property scheduler: DiffusionScheduler
Base model-dependent scheduler used for sampling.