FlowMolBaseModel
- class FlowMolBaseModel(model_name: str, scheduler_params: tuple[float, float, float, float], device: device)
Pre-trained FlowMol on GEOM-Drugs and QM9.
- forward(x: DDGraph, t: Tensor, **kwargs: Any) → DDGraph
Compute the endpoint vector field \(\hat{x_1}(x, t)\).
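As a minimal sketch of what an endpoint prediction \(\hat{x_1}(x, t)\) can be used for: under the assumption of a linear interpolation path \(x_t = (1 - t) x_0 + t x_1\), the predicted endpoint can be converted into a velocity. The schedule actually used by FlowMolBaseModel is controlled by scheduler_params and may differ. The helper name `endpoint_to_velocity` and the use of plain lists in place of a DDGraph are illustrative assumptions, not part of the documented API.

```python
def endpoint_to_velocity(x_t, x1_hat, t, eps=1e-5):
    """Convert an endpoint prediction into a velocity (hypothetical helper).

    Assumes the linear path x_t = (1 - t) * x_0 + t * x_1, for which
    dx_t/dt = x_1 - x_0 = (x_1 - x_t) / (1 - t).
    """
    denom = max(1.0 - t, eps)  # guard against division by zero as t approaches 1
    return [(x1 - xt) / denom for xt, x1 in zip(x_t, x1_hat)]


# A point halfway (t = 0.5) between x_0 = [0, 0] and x_1 = [2, 4].
velocity = endpoint_to_velocity(x_t=[1.0, 2.0], x1_hat=[2.0, 4.0], t=0.5)
```

With an exact endpoint prediction, the recovered velocity equals \(x_1 - x_0\), which is constant along a linear path.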
- preprocess(x: DDGraph, **kwargs: Any) → tuple[DDGraph, dict[str, Any]]
Preprocess data and keyword arguments for the base model.
- Parameters:
x (D) – Input data to preprocess.
**kwargs (dict) – Additional keyword arguments to preprocess.
- Returns:
output (D) – Preprocessed data.
kwargs (dict) – Preprocessed keyword arguments.
- sample_p0(n: int, **kwargs: Any) → tuple[DDGraph, dict[str, Any]]
Sample n datapoints from the base distribution \(p_0\).
- Parameters:
n (int) – Number of samples to draw.
- Returns:
samples – Samples from the base distribution \(p_0\).
- Return type:
tuple[DDGraph, dict[str, Any]]
Notes
The base distribution \(p_0\) is a standard Gaussian distribution.
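The following is a rough sketch of drawing n samples from a standard Gaussian and returning them together with a keyword-argument dictionary, mirroring the documented return shape. It is not the library's implementation: the function name `sample_standard_gaussian`, the `dim` and `seed` parameters, and the nested lists standing in for a DDGraph are all assumptions made for illustration.

```python
import random


def sample_standard_gaussian(n, dim, seed=None, **kwargs):
    """Draw n points of dimension dim from N(0, I) (hypothetical helper).

    Returns (samples, kwargs) to mirror the documented tuple return shape;
    nested lists stand in for the DDGraph used by the real model.
    """
    rng = random.Random(seed)  # local generator so the sketch is reproducible
    samples = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]
    return samples, kwargs


samples, extra = sample_standard_gaussian(n=4, dim=3, seed=0)
```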
- train_loss(x1: DDGraph, xt: DDGraph | None = None, t: Tensor | None = None, pred: DDGraph | None = None, **kwargs: Any) → Tensor
Compute loss for a single batch training step.
- Parameters:
x1 (DDGraph) – Target data points.
xt (Optional[DDGraph], default=None) – Noisy data points at time t. If None, will be sampled.
t (Optional[torch.Tensor], shape (len(x1),), default=None) – Time steps. If None, will be sampled.
pred (Optional[DDGraph], default=None) – Model predictions. If None, will be computed by the model.
**kwargs (dict) – Additional keyword arguments.
- Returns:
loss – Computed loss for the training step.
- Return type:
torch.Tensor, shape (len(x1),)
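The default handling described above (sample t and xt when they are None, compute pred when it is None, return one loss value per data point) can be sketched as follows. This is only an illustration under stated assumptions: the source does not specify the loss, so a per-sample mean squared error against x1 is assumed, along with a linear interpolation path and Gaussian noise. The name `train_loss_sketch`, the `model` and `rng` parameters, and the nested lists standing in for DDGraph and tensors are hypothetical.

```python
import random


def train_loss_sketch(x1, xt=None, t=None, pred=None, *, model=None, rng=None):
    """Per-sample endpoint-regression loss (assumed MSE, hypothetical helper)."""
    rng = rng or random.Random()
    if t is None:
        # One time step per data point, sampled uniformly from [0, 1).
        t = [rng.random() for _ in x1]
    if xt is None:
        # Assumed linear path between Gaussian noise x_0 and the target x_1.
        xt = [[(1.0 - ti) * rng.gauss(0.0, 1.0) + ti * v for v in row]
              for row, ti in zip(x1, t)]
    if pred is None:
        if model is None:
            raise ValueError("either pred or model must be provided")
        pred = [model(row, ti) for row, ti in zip(xt, t)]
    # Mean squared error per data point, giving one value per entry of x1.
    return [sum((p - v) ** 2 for p, v in zip(p_row, row)) / len(row)
            for p_row, row in zip(pred, x1)]


x1 = [[1.0, 2.0], [3.0, 4.0]]
# Supplying exact predictions gives zero loss for every data point.
exact_loss = train_loss_sketch(x1, pred=x1)
# Letting the sketch sample t and xt, with a stand-in model that echoes its input.
sampled_loss = train_loss_sketch(x1, model=lambda row, ti: row, rng=random.Random(0))
```

Returning one value per data point, rather than a scalar, matches the documented shape (len(x1),) and leaves the reduction to the caller.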