DiTBaseModel

class DiTBaseModel(cfg_scale: float = 0.0, device: device | None = None)[source]

Bases: BaseModel[DDTensor]

Pre-trained 512x512 ImageNet transformer diffusion model.

Uses the facebook/DiT-XL-2-512 checkpoint from the diffusers library.

property do_cfg: bool

Whether classifier-free guidance is applied (i.e. cfg_scale > 0).
forward(x: DDTensor, t: Tensor, **kwargs: Any) DDTensor[source]

Forward pass of the model, outputting \(\epsilon(x_t, t)\).

Parameters:
  • x (DDTensor, shape (n, 4, 64, 64)) – Input data.

  • t (torch.Tensor, shape (n,)) – Time steps, values in [0, 1].

  • **kwargs (dict) – Additional keyword arguments passed to the transformer model.

Returns:

output – Predicted noise \(\epsilon(x_t, t)\).

Return type:

DDTensor, shape (n, 4, 64, 64)
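Since the model outputs \(\epsilon(x_t, t)\) and supports a cfg_scale, classifier-free guidance is the standard way a wrapper like this combines conditional and unconditional predictions when do_cfg is true. A minimal NumPy sketch of that standard combination (cfg_epsilon is a hypothetical helper for illustration, not part of this class's API):

```python
import numpy as np

def cfg_epsilon(eps_cond, eps_uncond, cfg_scale):
    # Classifier-free guidance: move the prediction away from the
    # unconditional output in the direction of the conditional one.
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)

eps_cond = np.ones((1, 4, 64, 64))    # toy conditional prediction
eps_uncond = np.zeros((1, 4, 64, 64)) # toy unconditional prediction
guided = cfg_epsilon(eps_cond, eps_uncond, cfg_scale=4.0)
print(guided.shape)  # (1, 4, 64, 64)
```

With cfg_scale = 0 the unconditional prediction is returned unchanged, which matches the default cfg_scale of 0.0 disabling guidance.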

output_type: OutputType = 'epsilon'
postprocess(x: DDTensor) DDTensor[source]

Decode the images from the latent space.

Parameters:

x (DDTensor, shape (n, 4, 64, 64)) – Final sample in latent space.

Returns:

decoded – Decoded images in pixel space.

Return type:

DDTensor, shape (n, 3, 512, 512)
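postprocess yields pixel-space images; decoders in this model family conventionally emit values in [-1, 1], so a common follow-up step (the value range here is an assumption, not documented above) converts the output to uint8 images:

```python
import numpy as np

def to_uint8_images(decoded):
    # Assumes decoder output lies in [-1, 1]; map to [0, 255] uint8
    # and move channels last for image libraries.
    img = np.clip(decoded / 2 + 0.5, 0.0, 1.0)
    img = (img * 255).round().astype(np.uint8)
    return np.transpose(img, (0, 2, 3, 1))  # (n, 3, H, W) -> (n, H, W, 3)

decoded = np.zeros((2, 3, 512, 512))  # stand-in for a postprocess result
print(to_uint8_images(decoded).shape)  # (2, 512, 512, 3)
```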

preprocess(x: DDTensor, **kwargs: Any) tuple[DDTensor, dict[str, Any]][source]

Prepare the conditioning inputs (e.g. encode a class_label keyword argument into the form expected by the model).

Parameters:
  • x (DDTensor, shape (n, 4, 64, 64)) – Input data to preprocess.

  • **kwargs (dict) – Additional keyword arguments to preprocess.

Returns:

  • output (DDTensor, shape (n, 4, 64, 64)) – Preprocessed data.

  • kwargs (dict) – Preprocessed keyword arguments.

sample_p0(n: int, **kwargs: Any) tuple[DDTensor, dict[str, Any]][source]

Sample n latent datapoints from the base distribution \(p_0\).

Parameters:
  • n (int) – Number of samples to draw.

  • **kwargs (dict) – Additional keyword arguments.

Returns:

  • samples (DDTensor, shape (n, 4, 64, 64)) – Samples from the base distribution \(p_0\).

  • kwargs (dict) – Additional keyword arguments; a randomly selected class label is added under "class_label" if one is not provided.

Notes

The base distribution \(p_0\) is a standard Gaussian distribution.
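The behavior described above can be sketched as a NumPy stand-in that mirrors the documented contract (standard Gaussian latents, plus a random ImageNet-1k class label under "class_label" when none is given); this is an illustration, not the class's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_p0(n, **kwargs):
    # Standard Gaussian base distribution over the latent shape.
    samples = rng.standard_normal((n, 4, 64, 64))
    # Mirror the documented kwarg handling: pick a random ImageNet-1k
    # class label if the caller did not supply one.
    kwargs.setdefault("class_label", int(rng.integers(0, 1000)))
    return samples, kwargs

samples, out_kwargs = sample_p0(3)
print(samples.shape)  # (3, 4, 64, 64)
```

A caller-supplied label passes through untouched, e.g. sample_p0(1, class_label=7) keeps class_label == 7.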

property scheduler: DiffusionScheduler

Scheduler used for sampling.