FlowMolBaseModel

class FlowMolBaseModel(model_name: str, scheduler_params: tuple[float, float, float, float], device: device)[source]

Bases: BaseModel[DDGraph]

FlowMol base model, pre-trained on GEOM-Drugs and QM9.

forward(x: DDGraph, t: Tensor, **kwargs: Any) DDGraph[source]

Compute the endpoint prediction \(\hat{x}_1(x, t)\), i.e. the model's estimate of the clean data point given the noisy input x at time t.
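Because output_type is 'endpoint', a sampler has to turn the prediction \(\hat{x}_1\) into an update direction. The sketch below shows one common way to do that, assuming the linear interpolant \(x_t = (1 - t)\,x_0 + t\,x_1\); this is an assumption for illustration and may differ from the schedule FlowMolScheduler actually uses. The helper names endpoint_to_velocity and euler_step are hypothetical and operate on plain lists rather than DDGraph.

```python
def endpoint_to_velocity(x_t, x1_hat, t, eps=1e-8):
    """Velocity implied by an endpoint prediction.

    Assumes the linear interpolant x_t = (1 - t) * x_0 + t * x_1, for which
    v(x_t, t) = (x1_hat - x_t) / (1 - t). Not taken from the FlowMol source.
    """
    scale = 1.0 / max(1.0 - t, eps)  # guard against division by zero at t = 1
    return [(x1 - xt) * scale for xt, x1 in zip(x_t, x1_hat)]


def euler_step(x_t, x1_hat, t, dt):
    """One explicit Euler step of size dt along the implied velocity."""
    v = endpoint_to_velocity(x_t, x1_hat, t)
    return [xt + dt * vi for xt, vi in zip(x_t, v)]


x_t = [0.0, 2.0]      # current noisy coordinates (toy values)
x1_hat = [1.0, 0.0]   # hypothetical endpoint prediction
x_next = euler_step(x_t, x1_hat, t=0.5, dt=0.5)
```

Under this assumption, a step of size dt = 1 - t lands exactly on the predicted endpoint, which is a quick sanity check for any endpoint-based sampler.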

output_type: OutputType = 'endpoint'

postprocess(x: DDGraph) DDGraph[source]

Rename features from x_t to x_1, so that the model output is labeled as an endpoint rather than as a noisy state.
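As a rough illustration of this kind of renaming, the sketch below rewrites dictionary keys ending in "_t" so that they end in "_1". The feature names ("x_t", "a_t", "batch") and the helper rename_t_to_1 are hypothetical; the real method operates on a DDGraph, whose feature layout is not described here.

```python
def rename_t_to_1(features):
    """Return a copy of `features` where keys ending in '_t' end in '_1'.

    Keys without the '_t' suffix are copied unchanged.
    """
    return {
        (key[:-2] + "_1" if key.endswith("_t") else key): value
        for key, value in features.items()
    }


# Hypothetical feature dictionary standing in for a graph's node data.
state = {"x_t": [[0.0, 0.1, 0.2]], "a_t": [6], "batch": [0]}
renamed = rename_t_to_1(state)
```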

preprocess(x: DDGraph, **kwargs: Any) tuple[DDGraph, dict[str, Any]][source]

Preprocess data and keyword arguments for the base model.

Parameters:
  • x (DDGraph) – Input data to preprocess.

  • **kwargs (dict) – Additional keyword arguments to preprocess.

Returns:

  • output (DDGraph) – Preprocessed data.

  • kwargs (dict) – Preprocessed keyword arguments.

sample_p0(n: int, **kwargs: Any) tuple[DDGraph, dict[str, Any]][source]

Sample n data points from the base distribution \(p_0\).

Parameters:

n (int) – Number of samples to draw.

Returns:

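samples – Samples from the base distribution \(p_0\), returned together with a dictionary of keyword arguments as the second tuple element.

Return type:

tuple[DDGraph, dict[str, Any]]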

Notes

The base distribution \(p_0\) is a standard Gaussian distribution.
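The following sketch mirrors the shape of this method on plain Python lists: it draws n points with independent standard Gaussian coordinates and returns them together with the pass-through keyword arguments. The function sample_p0_sketch and its dim and seed parameters are hypothetical; the real method returns a DDGraph, and how it chooses graph sizes or non-coordinate features is not covered here.

```python
import random


def sample_p0_sketch(n, dim=3, seed=None, **kwargs):
    """Draw n points of dimension `dim` from a standard Gaussian.

    Returns (samples, kwargs) to mirror the tuple returned by sample_p0.
    """
    rng = random.Random(seed)  # local generator so the sketch is reproducible
    samples = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]
    return samples, kwargs


samples, extra = sample_p0_sketch(4, dim=3, seed=0, note="passed through")
```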

property scheduler: FlowMolScheduler

Scheduler used for sampling.

train_loss(x1: DDGraph, xt: DDGraph | None = None, t: Tensor | None = None, pred: DDGraph | None = None, **kwargs: Any) Tensor[source]

Compute loss for a single batch training step.

Parameters:
  • x1 (DDGraph) – Target data points.

  • xt (Optional[DDGraph], default=None) – Noisy data points at time t. If None, will be sampled.

  • t (Optional[torch.Tensor], shape (len(x1),), default=None) – Time steps. If None, will be sampled.

  • pred (Optional[DDGraph], default=None) – Model predictions. If None, will be computed by the model.

  • **kwargs (dict) – Additional keyword arguments.

Returns:

loss – Computed loss for the training step.

Return type:

torch.Tensor, shape (len(x1),)
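The optional-argument logic described above (sample t and xt when missing, compute pred when missing, return one loss value per data point) can be sketched as follows. This is a minimal illustration on plain lists, assuming a linear interpolant and a per-sample mean squared error against x1; the actual FlowMol loss and noise schedule may differ. The function train_loss_sketch and its predict argument are hypothetical stand-ins for the model.

```python
import random


def interpolate(x0, x1, t):
    """Linear interpolant (1 - t) * x0 + t * x1 (an assumed noise schedule)."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]


def train_loss_sketch(x1, xt=None, t=None, pred=None,
                      predict=lambda xt_i, t_i: xt_i, rng=None):
    """Per-sample squared error between an endpoint prediction and x1."""
    rng = rng or random.Random(0)
    if t is None:
        # One time value per data point, uniform on [0, 1).
        t = [rng.random() for _ in x1]
    if xt is None:
        # Noise each target by interpolating from a standard Gaussian draw.
        xt = [
            interpolate([rng.gauss(0.0, 1.0) for _ in point], point, t_i)
            for point, t_i in zip(x1, t)
        ]
    if pred is None:
        # `predict` stands in for the model's endpoint prediction.
        pred = [predict(xt_i, t_i) for xt_i, t_i in zip(xt, t)]
    # One loss value per data point, matching the shape (len(x1),).
    return [
        sum((p - y) ** 2 for p, y in zip(pred_i, x1_i)) / len(x1_i)
        for pred_i, x1_i in zip(pred, x1)
    ]


x1 = [[1.0, 2.0], [0.0, -1.0], [3.0, 3.0]]
perfect = train_loss_sketch(x1, pred=x1)
zeros = train_loss_sketch(x1, predict=lambda xt_i, t_i: [0.0] * len(xt_i))
```

Passing pred=x1 gives a zero loss for every data point, while the all-zeros predictor gives each point's mean squared coordinate, which makes the per-sample return shape easy to check.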