CIFARBaseModel
- class CIFARBaseModel(device: device | None)
Pre-trained diffusion model for 32x32 CIFAR-10 images.
Uses the google/ddpm-cifar10-32 model from the diffusers library.
Examples
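```python
import copy

import torch

# CIFARBaseModel, CompressionReward and EpsilonEnvironment are provided by this package.
device = torch.device("cpu")
base_model = CIFARBaseModel(device=device).to(device)
reward = CompressionReward()
env = EpsilonEnvironment(base_model, reward, discretization_steps=100)
# The policy starts as an independent copy of the base model.
policy = copy.deepcopy(base_model)
env.policy = policy
```
- forward(x: DDTensor, t: Tensor, **kwargs: Any) → DDTensor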
Forward pass of the model, outputting \(\epsilon(x_t, t)\).
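To make the role of \(\epsilon(x_t, t)\) concrete, here is a minimal sketch that assumes the standard DDPM parameterization, where the model predicts the noise in \(x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon\), together with the linear schedule from the original DDPM paper (1000 steps, \(\beta_t\) from 1e-4 to 0.02). The schedule and the helper names `add_noise` and `predict_x0` are assumptions for illustration, not part of this package, and NumPy arrays stand in for DDTensor.

```python
import numpy as np

# Assumption: linear DDPM schedule (1000 steps, beta from 1e-4 to 0.02).
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas_cumprod = np.cumprod(1.0 - betas)  # \bar{alpha}_t


def add_noise(x0, eps, t):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    abar = alphas_cumprod[t]
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * eps


def predict_x0(x_t, eps_hat, t):
    """Invert the forward process using a noise prediction eps_hat."""
    abar = alphas_cumprod[t]
    return (x_t - np.sqrt(1.0 - abar) * eps_hat) / np.sqrt(abar)


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(2, 3, 32, 32))  # images scaled to [-1, 1]
eps = rng.standard_normal(x0.shape)
x_t = add_noise(x0, eps, t=500)
# With a perfect noise prediction, x0 is recovered up to floating-point error.
x0_hat = predict_x0(x_t, eps, t=500)
```

A good \(\epsilon\)-predictor is therefore enough to estimate the clean image at any step, which is what DDPM-style samplers rely on.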
- sample_p0(n: int, **kwargs: Any) → tuple[DDTensor, dict[str, Any]]
Sample n datapoints from the base distribution \(p_0\).
- Parameters:
n (int) – Number of samples to draw.
- Returns:
samples – Samples from the base distribution \(p_0\).
- Return type:
tuple[DDTensor, dict[str, Any]], where the DDTensor has shape (n, 3, 32, 32)
Notes
The base distribution \(p_0\) is a standard Gaussian distribution.
- property scheduler: DiffusionScheduler
Scheduler used for sampling.
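As an illustration of how a scheduler turns repeated noise predictions into samples, here is a minimal DDPM ancestral sampler. It assumes the linear DDPM schedule (1000 steps, \(\beta_t\) from 1e-4 to 0.02) and the variance choice \(\sigma_t^2 = \beta_t\); it starts from \(x_T \sim \mathcal{N}(0, I)\), matching the standard Gaussian \(p_0\), and calls an \(\epsilon\)-model with the same `(x, t)` inputs as `forward`. This is not the DiffusionScheduler API: `ddpm_step`, `sample` and `oracle_eps` are hypothetical, and NumPy arrays stand in for DDTensor.

```python
import numpy as np

# Assumption: linear DDPM schedule with 1000 steps; the package's
# DiffusionScheduler may use a different schedule or update rule.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)


def ddpm_step(x_t, eps_hat, t, rng):
    """One ancestral step x_t -> x_{t-1} given a noise prediction eps_hat."""
    coef = betas[t] / np.sqrt(1.0 - alphas_cumprod[t])
    mean = (x_t - coef * eps_hat) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # no noise is added at the final step
    # Variance choice sigma_t^2 = beta_t (one of the two options in DDPM).
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)


def sample(eps_model, n, rng):
    """Run the reverse chain from x_T ~ N(0, I) (CIFAR shape) down to x_0."""
    x = rng.standard_normal((n, 3, 32, 32))
    for t in reversed(range(T)):
        x = ddpm_step(x, eps_model(x, t), t, rng)
    return x


def oracle_eps(x_t, t):
    """Exact noise predictor when all data sit at the origin (x_0 = 0)."""
    return x_t / np.sqrt(1.0 - alphas_cumprod[t])


out = sample(oracle_eps, n=2, rng=np.random.default_rng(0))
```

With the oracle predictor for data concentrated at the origin, the chain ends at (numerically) zero, which is a quick sanity check of the update rule; in practice the trained model's `forward` output would take the place of `oracle_eps`.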