CIFARBaseModel

class CIFARBaseModel(device: device | None)

Bases: BaseModel[DDTensor]

Pre-trained diffusion model for 32x32 CIFAR-10 images.

Uses the google/ddpm-cifar10-32 model from the diffusers library.

Examples

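```python
import copy

import torch

# CIFARBaseModel, CompressionReward and EpsilonEnvironment come from this
# package; their module paths are not shown on this page.
device = torch.device("cpu")
base_model = CIFARBaseModel().to(device)
reward = CompressionReward()
env = EpsilonEnvironment(base_model, reward, discretization_steps=100)
# Initialise the policy to be fine-tuned as a copy of the pre-trained model.
policy = copy.deepcopy(base_model)
env.policy = policy
```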

forward(x: DDTensor, t: Tensor, **kwargs: Any) → DDTensor

Forward pass of the model, returning the predicted noise \(\epsilon(x_t, t)\).

Parameters:
  • x (DDTensor, shape (n, 3, 32, 32)) – Input data.

  • t (torch.Tensor, shape (n,)) – Time steps, values in [0, 1].

  • **kwargs (dict) – Additional keyword arguments passed to the UNet model.
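The forward pass takes continuous times t in [0, 1], while a diffusers DDPM UNet is conditioned on integer timesteps. The sketch below shows one plausible way a wrapper could bridge the two. It assumes 1000 training timesteps (the DDPM default) and, because \(p_0\) here is the Gaussian noise distribution, that t = 0 is pure noise and t = 1 is clean data, which is the reverse of DDPM's step counting. The helper `to_discrete_timesteps` is hypothetical; the mapping actually used inside CIFARBaseModel is not documented here.

```python
import numpy as np

NUM_TRAIN_TIMESTEPS = 1000  # assumption: DDPM default number of training steps


def to_discrete_timesteps(t: np.ndarray, num_steps: int = NUM_TRAIN_TIMESTEPS) -> np.ndarray:
    """Hypothetical helper: map continuous t in [0, 1] to integer DDPM timesteps.

    Assumes t = 0 is pure noise and t = 1 is clean data, whereas DDPM uses
    step num_steps - 1 for the noisiest state and step 0 for the cleanest.
    """
    t = np.asarray(t, dtype=np.float64)
    if np.any((t < 0) | (t > 1)):
        raise ValueError("t must lie in [0, 1]")
    # Reverse the time direction, scale to [0, num_steps - 1], round to integers.
    steps = np.rint((1.0 - t) * (num_steps - 1))
    return steps.astype(np.int64)


steps = to_discrete_timesteps(np.array([0.0, 0.5, 1.0]))
```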

Returns:

output – Predicted noise \(\epsilon(x_t, t)\), with the same shape as x.

Return type:

DDTensor, shape (n, 3, 32, 32)
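Because the model predicts \(\epsilon\) rather than a clean image, a clean-sample estimate can be recovered by inverting the standard DDPM forward relation \(x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon\). The NumPy sketch below assumes the linear \(\beta\) schedule from 1e-4 to 0.02 over 1000 steps of the original DDPM setup and indexes by discrete DDPM step. `add_noise` and `predict_x0` are hypothetical names, not part of this API.

```python
import numpy as np

# Assumption: linear beta schedule of the original DDPM setup (1000 steps).
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bars = np.cumprod(1.0 - betas)  # \bar{alpha}_t for each discrete step


def add_noise(x0, eps, step):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    abar = alpha_bars[step]
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * eps


def predict_x0(xt, eps_pred, step):
    """Invert the forward process using a predicted noise eps_pred."""
    abar = alpha_bars[step]
    return (xt - np.sqrt(1.0 - abar) * eps_pred) / np.sqrt(abar)


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(2, 3, 32, 32))  # images scaled to [-1, 1]
eps = rng.standard_normal(x0.shape)
xt = add_noise(x0, eps, step=500)
# With the true noise in place of a model prediction, x0 is recovered
# up to floating-point error.
x0_hat = predict_x0(xt, eps, step=500)
```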

output_type: OutputType = 'epsilon'

Quantity predicted by forward: the noise \(\epsilon\).

postprocess(x: DDTensor) → DDTensor

Map a final sample from [-1, 1] to [0, 1].

Parameters:

x (DDTensor, shape (n, 3, 32, 32)) – Final sample in [-1, 1].

Returns:

decoded – Final sample in [0, 1].

Return type:

DDTensor, shape (n, 3, 32, 32)
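The conversion amounts to an affine map from [-1, 1] onto [0, 1]. A minimal NumPy sketch, assuming the usual (x + 1) / 2 rescaling followed by clipping; the clipping step and the name `to_unit_interval` are assumptions, since diffusion samples can slightly overshoot the nominal range.

```python
import numpy as np


def to_unit_interval(x: np.ndarray) -> np.ndarray:
    """Map samples from [-1, 1] to [0, 1]; out-of-range values are clipped (assumption)."""
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)


x = np.array([-1.0, 0.0, 1.0, 1.2])
y = to_unit_interval(x)
```

To save images, the [0, 1] values can then be scaled by 255 and cast to an 8-bit integer type.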

sample_p0(n: int, **kwargs: Any) → tuple[DDTensor, dict[str, Any]]

Sample n datapoints from the base distribution \(p_0\).

Parameters:

n (int) – Number of samples to draw.

Returns:

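samples – Samples from the base distribution \(p_0\), returned as the first element of a tuple whose second element is a dict[str, Any].

Return type:

tuple[DDTensor, dict[str, Any]], where the DDTensor has shape (n, 3, 32, 32)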

Notes

The base distribution \(p_0\) is a standard Gaussian distribution.
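Since \(p_0\) is a standard Gaussian over CIFAR-sized tensors, drawing n samples amounts to filling an (n, 3, 32, 32) array with independent N(0, 1) entries. The NumPy sketch below illustrates this; `sample_standard_gaussian` is a hypothetical stand-in, and because the contents of the dictionary returned by sample_p0 are not documented here, it returns an empty dict as a placeholder.

```python
import numpy as np


def sample_standard_gaussian(n, rng=None):
    """Hypothetical stand-in for sample_p0: n i.i.d. N(0, I) images of shape (3, 32, 32)."""
    rng = np.random.default_rng() if rng is None else rng
    samples = rng.standard_normal((n, 3, 32, 32))
    extras = {}  # placeholder for the extra outputs returned by sample_p0
    return samples, extras


samples, extras = sample_standard_gaussian(64, rng=np.random.default_rng(0))
```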

property scheduler: DiffusionScheduler

Scheduler used for sampling.
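The scheduler turns the model's \(\epsilon\) predictions into a sampling trajectory. The DiffusionScheduler interface is not documented in this section, so the sketch below instead shows a generic DDPM ancestral update, \(x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\,\epsilon\right) + \sigma_t z\) with \(\sigma_t^2 = \beta_t\), under an assumed linear \(\beta\) schedule. `ddpm_step`, `sample`, and the zero-noise dummy predictor are hypothetical and only stand in for the real UNet and scheduler.

```python
import numpy as np

# Assumption: linear DDPM beta schedule over 1000 steps.
betas = np.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)


def ddpm_step(xt, eps_pred, step, rng):
    """One ancestral DDPM update from `step` to `step - 1`, using sigma_t^2 = beta_t."""
    coef = betas[step] / np.sqrt(1.0 - alpha_bars[step])
    mean = (xt - coef * eps_pred) / np.sqrt(alphas[step])
    if step == 0:
        return mean  # no noise is added on the final step
    return mean + np.sqrt(betas[step]) * rng.standard_normal(xt.shape)


def sample(eps_model, n, rng):
    """Run the reverse chain from pure noise with a noise predictor eps_model(x, step)."""
    x = rng.standard_normal((n, 3, 32, 32))
    for step in reversed(range(len(betas))):
        x = ddpm_step(x, eps_model(x, step), step, rng)
    return x


# Dummy predictor standing in for the real UNet: always predicts zero noise.
rng = np.random.default_rng(0)
out = sample(lambda x, step: np.zeros_like(x), n=2, rng=rng)
```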