BaseModel

class BaseModel(device: device | None)

Bases: ABC, Module, Generic[D]

Abstract base class for base models used in flow matching and diffusion.

abstract forward(x: D, t: Tensor, **kwargs: Any) → D

Forward pass of the base model.

Parameters:
  • x (D) – Input data.

  • t (torch.Tensor, shape (n,)) – Time steps, values in [0, 1].

Returns:

output – Output of the model.

Return type:

D
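
A sketch of what a concrete forward might look like. To run on its own, it does not use torch or the library: lists of floats stand in for D and torch.Tensor, and ToyBaseModel and ShrinkModel are hypothetical names. Only forward is mirrored here; the real class also declares sample_p0 and scheduler as abstract.

```python
from abc import ABC, abstractmethod
from typing import Any


class ToyBaseModel(ABC):
    """Hypothetical stand-in mirroring the abstract forward() of BaseModel."""

    @abstractmethod
    def forward(self, x: list[float], t: list[float], **kwargs: Any) -> list[float]:
        ...


class ShrinkModel(ToyBaseModel):
    """Toy subclass: its prediction shrinks each input toward 0 as t grows."""

    def forward(self, x: list[float], t: list[float], **kwargs: Any) -> list[float]:
        # t carries one time step per sample (shape (n,)), each in [0, 1].
        if len(t) != len(x):
            raise ValueError("t must have one entry per sample")
        if not all(0.0 <= ti <= 1.0 for ti in t):
            raise ValueError("time steps must lie in [0, 1]")
        return [(1.0 - ti) * xi for xi, ti in zip(x, t)]
```

For example, ShrinkModel().forward([2.0, 4.0], [0.0, 0.5]) returns [2.0, 2.0], while instantiating ToyBaseModel directly raises TypeError because forward is abstract.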

output_type: Literal['epsilon', 'endpoint', 'velocity', 'score']
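
The four output_type values name different parameterizations of a prediction. As an illustration only, assume the linear path x_t = (1 − t)·x_0 + t·x_1 with x_0 drawn from a standard normal p0 (the library's scheduler may use a different path). Under that assumption, each parameterization can be converted to an estimate of the endpoint x_1. The helper to_endpoint below is hypothetical and works on plain floats.

```python
from typing import Literal

OutputType = Literal["epsilon", "endpoint", "velocity", "score"]


def to_endpoint(pred: float, xt: float, t: float, output_type: OutputType) -> float:
    """Convert a prediction to an estimate of x_1.

    Assumes x_t = (1 - t) * x_0 + t * x_1 with x_0 ~ N(0, 1).
    """
    if output_type == "endpoint":
        return pred
    if output_type == "velocity":
        # v = x_1 - x_0, so x_1 = x_t + (1 - t) * v
        return xt + (1.0 - t) * pred
    if output_type == "epsilon":
        # epsilon is the noise sample x_0; solve x_t for x_1 (undefined at t = 0)
        return (xt - (1.0 - t) * pred) / t
    if output_type == "score":
        # Gaussian conditional score: -(x_t - t * x_1) / (1 - t)^2 = -x_0 / (1 - t)
        x0 = -(1.0 - t) * pred
        return (xt - (1.0 - t) * x0) / t
    raise ValueError(f"unknown output_type: {output_type}")
```

With x_0 = 0.5, x_1 = 2.0 and t = 0.25, we get x_t = 0.875. The velocity 1.5, the noise 0.5 and the score −0.5/0.75 then all map back to an endpoint of (approximately) 2.0.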

postprocess(x: D) → D

Postprocess samples x_1 (e.g., decode them with a VAE).

Parameters:

x (D) – Input data to postprocess.

Returns:

output – Postprocessed output.

Return type:

D

preprocess(x: D, **kwargs: Any) → tuple[D, dict[str, Any]]

Preprocess data and keyword arguments for the base model.

Parameters:
  • x (D) – Input data to preprocess.

  • **kwargs (dict) – Additional keyword arguments to preprocess.

Returns:

  • output (D) – Preprocessed data.

  • kwargs (dict) – Preprocessed keyword arguments.
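
preprocess and postprocess can be read as an approximate inverse pair: preprocess maps data into the model's working space, and postprocess maps generated samples back out. The minimal sketch below uses a fixed scale factor in place of something like a VAE encoder/decoder. The class ScaledSpace, the scale parameter and the pass-through handling of keyword arguments are all hypothetical.

```python
from typing import Any


class ScaledSpace:
    """Hypothetical model whose working space is the data multiplied by `scale`."""

    def __init__(self, scale: float = 0.5) -> None:
        self.scale = scale

    def preprocess(self, x: list[float], **kwargs: Any) -> tuple[list[float], dict[str, Any]]:
        # Map data into the model's space; pass keyword arguments through unchanged.
        return [xi * self.scale for xi in x], dict(kwargs)

    def postprocess(self, x: list[float]) -> list[float]:
        # Map samples x_1 back to data space (the inverse of preprocess).
        return [xi / self.scale for xi in x]
```

For example, ScaledSpace(0.5).preprocess([2.0, -4.0], label=3) returns ([1.0, -2.0], {"label": 3}). Passing [1.0, -2.0] to postprocess recovers [2.0, -4.0].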

abstract sample_p0(n: int, **kwargs: Any) → tuple[D, dict[str, Any]]

Sample n data points from the base distribution p0.

Parameters:
  • n (int) – Number of samples to draw.

  • **kwargs (dict) – Additional keyword arguments.

Returns:

  • samples (D) – Samples from the base distribution p0.

  • kwargs (dict) – Additional keyword arguments.
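
A common choice of p0 is a standard normal. The sketch below assumes that choice. sample_p0 here is a hypothetical free function on nested lists, and its dim and seed parameters are illustrative additions. It returns the keyword arguments next to the samples, matching the documented return pair.

```python
import random
from typing import Any, Optional


def sample_p0(
    n: int, dim: int = 2, seed: Optional[int] = None, **kwargs: Any
) -> tuple[list[list[float]], dict[str, Any]]:
    """Draw n points from an (assumed) standard normal in `dim` dimensions."""
    rng = random.Random(seed)  # local generator, so a fixed seed gives reproducible draws
    samples = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]
    # Return the keyword arguments alongside the samples.
    return samples, kwargs
```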

abstract property scheduler: Scheduler[D]

Scheduler used for sampling; it depends on the base model.
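
The interface of Scheduler is not documented in this section. To illustrate what sampling involves, the sketch below uses a uniform time grid and plain Euler steps in place of a scheduler. It integrates a velocity-predicting model from p0 at t = 0 to samples x_1 at t = 1. That time convention is inferred from sample_p0 and postprocess, and euler_sample is a hypothetical helper.

```python
from typing import Callable

# velocity(x, t) -> predicted velocity, with one time step per sample in t
Velocity = Callable[[list[float], list[float]], list[float]]


def euler_sample(velocity: Velocity, x0: list[float], steps: int = 100) -> list[float]:
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with Euler steps."""
    x = list(x0)
    dt = 1.0 / steps
    for k in range(steps):
        t = [k * dt] * len(x)  # same time for every sample, shape (n,)
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x
```

As an example, the field (3.0 − x)/(1 − t) is the velocity of the straight path from any start point to 3.0. Calling euler_sample(lambda x, t: [(3.0 - xi) / (1.0 - ti) for xi, ti in zip(x, t)], [0.0, -1.0], steps=50) therefore returns values very close to [3.0, 3.0]. Because the last step is taken at t = 1 − dt, the division never reaches t = 1.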

train_loss(x1: D, xt: D | None = None, t: Tensor | None = None, pred: D | None = None, **kwargs: Any) → Tensor

Compute the loss for a single training step on one batch.

Parameters:
  • x1 (D) – Target data points.

  • xt (Optional[D], default=None) – Noisy data points at time t. If None, they are sampled.

  • t (Optional[torch.Tensor], shape (len(x1),), default=None) – Time steps. If None, they are sampled.

  • pred (Optional[D], default=None) – Model predictions. If None, they are computed by the model.

  • **kwargs (dict) – Additional keyword arguments.

Returns:

loss – Computed loss for the training step.

Return type:

torch.Tensor, shape (len(x1),)
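
The sketch below shows a per-sample flow-matching loss for a velocity-predicting model. It follows the documented defaults: sample t and xt when they are missing, call the model when pred is missing, and return one value per element of x1. It assumes the linear path x_t = (1 − t)·x_0 + t·x_1 with a standard-normal p0. Unlike the documented method, it also requires t whenever xt is given, so that the noise sample can be recovered. flow_matching_loss is hypothetical, and model is any callable standing in for forward.

```python
import random
from typing import Callable, Optional

Model = Callable[[list[float], list[float]], list[float]]


def flow_matching_loss(
    model: Model,
    x1: list[float],
    xt: Optional[list[float]] = None,
    t: Optional[list[float]] = None,
    pred: Optional[list[float]] = None,
) -> list[float]:
    """Per-sample squared error between predicted and target velocity."""
    if xt is not None and t is None:
        raise ValueError("t is required when xt is given")  # restriction of this sketch
    n = len(x1)
    if t is None:
        t = [random.random() for _ in range(n)]  # t ~ U[0, 1)
    if xt is None:
        x0 = [random.gauss(0.0, 1.0) for _ in range(n)]  # noise from the assumed p0
        xt = [(1.0 - ti) * a + ti * b for ti, a, b in zip(t, x0, x1)]
    else:
        # Recover the noise sample from x_t = (1 - t) * x_0 + t * x_1 (needs t < 1).
        x0 = [(a - ti * b) / (1.0 - ti) for a, ti, b in zip(xt, t, x1)]
    if pred is None:
        pred = model(xt, t)
    target = [b - a for a, b in zip(x0, x1)]  # velocity of the straight path
    return [(p - v) ** 2 for p, v in zip(pred, target)]
```

With x1 = [2.0], t = [0.25] and xt = [0.875], the recovered noise is 0.5 and the target velocity is 1.5. A model that predicts 1.5 therefore gives a loss of [0.0], and passing pred = [2.5] gives [1.0].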