BaseModel
- class BaseModel(device: device | None)
Bases: ABC, Module, Generic[D]

Abstract base class for base models used in flow matching and diffusion.
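`BaseModel` combines `abc.ABC`, `torch.nn.Module` and `typing.Generic[D]`. Subclasses are therefore PyTorch modules, are parameterized by a data type `D`, and cannot be instantiated until every abstract method (at least `sample_p0`) is implemented. The sketch below reproduces that combination with a hypothetical stand-in, `SketchBaseModel`, instead of importing the real class, whose module path is not given here. `ZeroP0` is likewise a made-up subclass.

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, TypeVar

import torch
from torch import nn

D = TypeVar("D")  # the data type the model operates on, e.g. torch.Tensor


class SketchBaseModel(ABC, nn.Module, Generic[D]):
    """Hypothetical stand-in with the same bases as BaseModel."""

    def __init__(self, device: Optional[torch.device]) -> None:
        # Resolves to nn.Module.__init__, so parameters and submodules can be registered.
        super().__init__()
        self.device = device

    @abstractmethod
    def sample_p0(self, n: int, **kwargs: Any) -> tuple[D, dict[str, Any]]:
        """Subclasses must define how to sample the base distribution p0."""


class ZeroP0(SketchBaseModel[torch.Tensor]):
    """Hypothetical concrete subclass: p0 is a point mass at the origin."""

    def __init__(self, dim: int, device: Optional[torch.device] = None) -> None:
        super().__init__(device)
        self.dim = dim

    def sample_p0(self, n: int, **kwargs: Any) -> tuple[torch.Tensor, dict[str, Any]]:
        # Every sample is the zero vector; keyword arguments are passed through unchanged.
        return torch.zeros(n, self.dim, device=self.device), kwargs
```

Trying to instantiate `SketchBaseModel` itself raises `TypeError`, which is how the abstract `sample_p0` forces each concrete model to define its own base distribution.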
- preprocess(x: D, **kwargs: Any) → tuple[D, dict[str, Any]]
Preprocess data and keyword arguments for the base model.
- Parameters:
x (D) – Input data to preprocess.
**kwargs (dict) – Additional keyword arguments to preprocess.
- Returns:
output (D) – Preprocessed data.
kwargs (dict) – Preprocessed keyword arguments.
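The documentation does not say what the default `preprocess` does. As a hedged illustration only, a subclass might override it to convert raw input into a float tensor on the model's device and move tensor-valued keyword arguments along with it. `PreprocessSketch` and this behavior are hypothetical, not part of the library.

```python
from typing import Any, Optional

import torch


class PreprocessSketch:
    """Hypothetical class showing one possible preprocess override."""

    def __init__(self, device: Optional[torch.device] = None) -> None:
        # Assumption: fall back to the CPU when no device is given.
        self.device = device if device is not None else torch.device("cpu")

    def preprocess(self, x: Any, **kwargs: Any) -> tuple[torch.Tensor, dict[str, Any]]:
        # Convert the input to a float32 tensor on the model's device.
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        # Move tensor-valued keyword arguments to the same device; pass others through unchanged.
        kwargs = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }
        return x, kwargs
```

Returning a new kwargs dict rather than mutating the caller's matches the documented `tuple[D, dict[str, Any]]` return type and lets callers forward the processed arguments directly.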
- abstract sample_p0(n: int, **kwargs: Any) → tuple[D, dict[str, Any]]
Sample n data points from the base distribution p0.
- Parameters:
n (int) – Number of samples to draw.
**kwargs (dict) – Additional keyword arguments.
- Returns:
samples (D) – Samples from the base distribution p0.
kwargs (dict) – Keyword arguments returned alongside the samples.
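A common choice of base distribution in flow matching is a standard Gaussian. The sketch below assumes that choice and a fixed per-sample shape; `GaussianP0Sketch` is hypothetical, and the library's concrete subclasses may sample p0 differently.

```python
from typing import Any, Optional

import torch


class GaussianP0Sketch:
    """Hypothetical sampler for p0 = N(0, I) over tensors of a fixed shape."""

    def __init__(self, shape: tuple[int, ...], device: Optional[torch.device] = None) -> None:
        self.shape = tuple(shape)
        self.device = device if device is not None else torch.device("cpu")

    def sample_p0(self, n: int, **kwargs: Any) -> tuple[torch.Tensor, dict[str, Any]]:
        # Draw n i.i.d. samples from N(0, I); the result has shape (n, *self.shape).
        samples = torch.randn((n, *self.shape), device=self.device)
        # Pass keyword arguments through unchanged so callers can forward them.
        return samples, kwargs
```

In a generation loop, these samples would serve as the starting points x0 that an ODE or SDE solver integrates toward the data distribution.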
- train_loss(x1: D, xt: D | None = None, t: Tensor | None = None, pred: D | None = None, **kwargs: Any) → Tensor
Compute loss for a single batch training step.
- Parameters:
x1 (D) – Target data points.
xt (Optional[D], default=None) – Noisy data points at time t. If None, they are sampled.
t (Optional[torch.Tensor], shape (len(x1),), default=None) – Time steps, one per data point. If None, they are sampled.
pred (Optional[D], default=None) – Model predictions. If None, they are computed by the model.
**kwargs (dict) – Additional keyword arguments.
- Returns:
loss – Per-sample loss for the training step.
- Return type:
torch.Tensor, shape (len(x1),)
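One concrete instance of this contract is conditional flow matching with a straight-line path x_t = (1 - t) * x0 + t * x1, with x0 ~ N(0, I), where the network regresses the velocity x1 - x0. The sketch below assumes that path, a Gaussian p0 and flat inputs of shape (n, dim). The class `FlowMatchingSketch`, its `velocity` helper and the tiny MLP are hypothetical, and the library's own `train_loss` may use a different path, target or weighting. It mirrors the documented optional arguments: `t`, `xt` and `pred` are filled in when not supplied, and the result has shape (len(x1),).

```python
from typing import Any, Optional

import torch
from torch import nn


class FlowMatchingSketch(nn.Module):
    """Hypothetical linear-path flow matching model on flat vectors."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        # Tiny velocity network v(xt, t); t is appended as an extra input feature.
        self.net = nn.Sequential(nn.Linear(dim + 1, 64), nn.SiLU(), nn.Linear(64, dim))

    def velocity(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([xt, t[:, None]], dim=-1))

    def train_loss(
        self,
        x1: torch.Tensor,
        xt: Optional[torch.Tensor] = None,
        t: Optional[torch.Tensor] = None,
        pred: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        if xt is not None and t is None:
            raise ValueError("this sketch needs t whenever xt is given")
        if t is None:
            # One time per data point, uniform in [0, 1).
            t = torch.rand(len(x1), device=x1.device)
        if xt is None:
            # Sample x0 ~ N(0, I) and interpolate along the straight path.
            x0 = torch.randn_like(x1)
            xt = (1 - t[:, None]) * x0 + t[:, None] * x1
            target = x1 - x0
        else:
            # On the straight path, x1 - x0 = (x1 - xt) / (1 - t), valid for t < 1.
            target = (x1 - xt) / (1 - t[:, None])
        if pred is None:
            pred = self.velocity(xt, t)
        # Per-sample mean squared error, shape (len(x1),).
        return ((pred - target) ** 2).mean(dim=-1)
```

Because the loss is returned per sample, a training loop would reduce it, for example with `loss.mean().backward()`, before stepping an optimizer; leaving it unreduced keeps room for per-sample weighting.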