Scheduler

class Scheduler

Bases: ABC, Generic[D]

Abstract base class for schedulers of flow matching models.

By default, \(\beta_t = 1 - \alpha_t\), but subclasses can redefine it.
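A minimal sketch of a concrete subclass follows. It assumes that D is a plain torch.Tensor of shape (n, d) and that the schedule is linear, \(\alpha_t = t\). SchedulerSketch and LinearSchedulerSketch are hypothetical stand-ins that only mirror the signatures documented on this page. They are not the library's classes, and the real base class may compute its defaults differently. The subclass implements only alpha and alpha_dot, so beta and beta_dot fall back to \(\beta_t = 1 - \alpha_t\) and \(\dot{\beta}_t = -\dot{\alpha}_t\).

```python
from abc import ABC, abstractmethod

import torch


class SchedulerSketch(ABC):
    """Hypothetical stand-in mirroring the documented Scheduler signatures."""

    @abstractmethod
    def alpha(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def alpha_dot(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: ...

    def beta(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Default relation: beta_t = 1 - alpha_t.
        return 1 - self.alpha(x, t)

    def beta_dot(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Differentiating beta_t = 1 - alpha_t gives beta_dot_t = -alpha_dot_t.
        return -self.alpha_dot(x, t)


class LinearSchedulerSketch(SchedulerSketch):
    """Hypothetical linear schedule alpha_t = t (assumes x has shape (n, d))."""

    def alpha(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Reshape t from (n,) to (n, 1) and expand it to the shape of x.
        return t.view(-1, 1).expand_as(x)

    def alpha_dot(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # d/dt t = 1 everywhere.
        return torch.ones_like(x)


sched = LinearSchedulerSketch()
x = torch.randn(4, 3)
t = torch.tensor([0.0, 0.25, 0.5, 1.0])
print(sched.beta(x, t)[:, 0].tolist())  # [1.0, 0.75, 0.5, 0.0]
```

Keeping beta and beta_dot concrete in the base class means a subclass only has to supply the two abstract methods. It can still override beta when \(\beta_t \neq 1 - \alpha_t\).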

abstract alpha(x: D, t: Tensor) → D

\(\alpha_t\).

Can be overridden if \(\alpha_t\) is data-dependent.

Parameters:
  • x (D) – Data tensor.

  • t (torch.Tensor, shape (n,)) – Time tensor with values in [0, 1].

Returns:

alpha – Values of \(\alpha_t\) at the given times.

Return type:

D, same data shape as x
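The return type requires alpha to match the shape of x, while t has shape (n,). The following sketch shows one way to broadcast. It assumes x carries a leading batch dimension of size n and uses a linear schedule \(\alpha_t = t\) for illustration. expand_time is a hypothetical helper, not part of this API.

```python
import torch


def expand_time(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reshape t of shape (n,) to (n, 1, ..., 1) so it broadcasts against x."""
    return t.reshape(-1, *([1] * (x.dim() - 1)))


def alpha(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Linear schedule alpha_t = t, returned with the same shape as x.
    return expand_time(t, x).expand_as(x)


x = torch.randn(2, 3, 8, 8)  # a batch of n = 2 image-shaped samples
t = torch.tensor([0.1, 0.9])
a = alpha(x, t)
print(a.shape)  # torch.Size([2, 3, 8, 8])
```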

abstract alpha_dot(x: D, t: Tensor) → D

\(\dot{\alpha}_t\).

Can be overridden if \(\dot{\alpha}_t\) is data-dependent.

Parameters:
  • x (D) – Data tensor.

  • t (torch.Tensor, shape (n,)) – Time tensor with values in [0, 1].

Returns:

alpha_dot – Values of \(\dot{\alpha}_t\) at the given times.

Return type:

D, same data shape as x
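When alpha_dot is implemented by hand, comparing it against automatic differentiation is a cheap sanity check. The sketch below assumes a data-independent, hypothetical schedule \(\alpha_t = \sin(\pi t / 2)\) and drops the x argument for brevity.

```python
import math

import torch


def alpha(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical schedule alpha_t = sin(pi * t / 2).
    return torch.sin(math.pi * t / 2)


def alpha_dot(t: torch.Tensor) -> torch.Tensor:
    # Analytic derivative: (pi / 2) * cos(pi * t / 2).
    return (math.pi / 2) * torch.cos(math.pi * t / 2)


t = torch.linspace(0, 1, 11, dtype=torch.float64, requires_grad=True)
# Each alpha(t_i) depends only on t_i, so the gradient of the sum
# equals the elementwise derivative d alpha / d t at every t_i.
(autograd_dot,) = torch.autograd.grad(alpha(t).sum(), t)
print(torch.allclose(autograd_dot, alpha_dot(t.detach())))  # True
```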

beta(x: D, t: Tensor) → D

\(\beta_t\). Defaults to \(1 - \alpha_t\).

Parameters:
  • x (D) – Data tensor.

  • t (torch.Tensor, shape (n,)) – Time tensor with values in [0, 1].

Returns:

beta – Values of \(\beta_t\) at the given times.

Return type:

D, same data shape as x

beta_dot(x: D, t: Tensor) → D

\(\dot{\beta}_t\).

Can be overridden if \(\dot{\beta}_t\) is data-dependent.

Parameters:
  • x (D) – Data tensor.

  • t (torch.Tensor, shape (n,)) – Time tensor with values in [0, 1].

Returns:

beta_dot – Values of \(\dot{\beta}_t\) at the given times.

Return type:

D, same data shape as x

eta(x: D, t: Tensor) → D

\(\eta_t\) as defined in [Adjoint Matching](https://arxiv.org/abs/2409.08861).

kappa(x: D, t: Tensor) → D

\(\kappa_t\) as defined in [Adjoint Matching](https://arxiv.org/abs/2409.08861).
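The following is a sketch of both quantities, based on our reading of the Adjoint Matching paper. The paper considers the interpolant \(\beta_t x_0 + \alpha_t x_1\), with noise \(x_0\) and data \(x_1\). For that interpolant it defines \(\kappa_t = \dot{\alpha}_t / \alpha_t\) and \(\eta_t = \beta_t (\dot{\alpha}_t \beta_t / \alpha_t - \dot{\beta}_t)\). These formulas are taken from the paper, not from this library's source. The hypothetical functions below also take precomputed schedule values rather than the documented (x, t) signature. For the linear schedule \(\alpha_t = t\), \(\beta_t = 1 - t\), they reduce to \(\kappa_t = 1/t\) and \(\eta_t = (1 - t)/t\).

```python
import torch


def kappa(alpha: torch.Tensor, alpha_dot: torch.Tensor) -> torch.Tensor:
    # kappa_t = alpha_dot_t / alpha_t (undefined where alpha_t = 0, e.g. t = 0 here).
    return alpha_dot / alpha


def eta(alpha: torch.Tensor, alpha_dot: torch.Tensor,
        beta: torch.Tensor, beta_dot: torch.Tensor) -> torch.Tensor:
    # eta_t = beta_t * (alpha_dot_t / alpha_t * beta_t - beta_dot_t)
    return beta * (alpha_dot / alpha * beta - beta_dot)


t = torch.tensor([0.25, 0.5, 0.75], dtype=torch.float64)
alpha, alpha_dot = t, torch.ones_like(t)      # linear schedule alpha_t = t
beta, beta_dot = 1 - t, -torch.ones_like(t)   # default beta_t = 1 - alpha_t

print(kappa(alpha, alpha_dot).tolist())  # [4.0, 2.0, 1.3333333333333333]
print(eta(alpha, alpha_dot, beta, beta_dot).tolist())
```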

model_input(t: Tensor) → Tensor

Input to the model at time t.

Defaults to t; override it if the model expects a different time parameterization.
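The sketch below overrides the default for a hypothetical model that was trained on time steps in [0, 1000] rather than [0, 1]. IdentityTime and ScaledTime are illustrative names, not classes from this library.

```python
import torch


class IdentityTime:
    def model_input(self, t: torch.Tensor) -> torch.Tensor:
        # Documented default: pass t through unchanged.
        return t


class ScaledTime(IdentityTime):
    """Hypothetical override for a model trained on time steps in [0, 1000]."""

    def model_input(self, t: torch.Tensor) -> torch.Tensor:
        # Rescale scheduler time in [0, 1] to the model's own time range.
        return 1000 * t


t = torch.tensor([0.0, 0.5, 1.0])
print(IdentityTime().model_input(t).tolist())  # [0.0, 0.5, 1.0]
print(ScaledTime().model_input(t).tolist())    # [0.0, 500.0, 1000.0]
```

With this override the scheduler can keep using t in [0, 1] for alpha, beta, and their derivatives, and only the value fed to the network changes.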

property noise_schedule: NoiseSchedule[D]

Get the current noise schedule.

sigma(x: D, t: Tensor) → D

The noise schedule \(\sigma(t)\).

Parameters:
  • x (D) – Data tensor.

  • t (torch.Tensor, shape (n,)) – Time tensor with values in [0, 1].

Returns:

sigma – Values of \(\sigma(t)\) at the given times.

Return type:

D, same data shape as x
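One possible choice for \(\sigma(t)\) is the memoryless noise schedule \(\sigma(t) = \sqrt{2 \eta_t}\) from the Adjoint Matching paper. This page does not say which schedule the library uses, so treat the following as a sketch only. It assumes the linear schedule \(\alpha_t = t\), \(\beta_t = 1 - t\), for which \(\eta_t = (1 - t)/t\). memoryless_sigma is a hypothetical function, and its clamp value is an arbitrary choice.

```python
import torch


def memoryless_sigma(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical sigma(t) = sqrt(2 * eta_t) for the linear schedule alpha_t = t.

    For that schedule eta_t = (1 - t) / t diverges as t -> 0, so t is clamped
    away from zero (1e-4 is an arbitrary illustrative choice).
    """
    t = t.clamp(min=1e-4)
    eta = (1 - t) / t
    sigma = torch.sqrt(2 * eta)
    # Broadcast from shape (n,) to the data shape, as the return type requires.
    return sigma.reshape(-1, *([1] * (x.dim() - 1))).expand_as(x)


x = torch.zeros(3, 2)
t = torch.tensor([0.2, 0.5, 1.0])
print(memoryless_sigma(x, t)[:, 0].tolist())
```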