VelocityEnvironment

class VelocityEnvironment(base_model: BaseModel[D], reward: Reward[D], discretization_steps: int, reward_scale: float = 1.0)[source]

Bases: Environment[D]

Environment with tensor samples and a base model that predicts the velocity \(v(x, t)\).

Parameters:
  • base_model (BaseModel) – The base generative model used in the environment.

  • reward (Reward) – The reward function used to compute the final reward.

  • discretization_steps (int) – The number of discretization steps to use when sampling trajectories.

  • reward_scale (float, default=1.0) – Scale of the reward (can be negative). This is used to control the trade-off between high rewards and proximity to the base model.
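As a rough illustration of the role of discretization_steps, a velocity-based environment can be rolled out by Euler integration of \(dx/dt = v(x, t)\) from \(t = 0\) to \(t = 1\). The sketch below is a minimal stand-in, not the library's implementation; the lambda velocity field is a toy in place of a trained base model.

```python
import torch

def sample_trajectory(velocity, x0, discretization_steps):
    # Hypothetical sketch: Euler integration of dx/dt = v(x, t) over
    # `discretization_steps` uniform steps on [0, 1].
    x = x0
    dt = 1.0 / discretization_steps
    for i in range(discretization_steps):
        t = torch.full((x.shape[0],), i * dt)  # per-sample time, shape (n,)
        x = x + dt * velocity(x, t)
    return x

# Toy velocity field v(x, t) = -x (not a trained base model); the exact
# solution at t = 1 is x0 * exp(-1).
x1 = sample_trajectory(lambda x, t: -x, torch.ones(4, 2), discretization_steps=100)
```

Larger values of discretization_steps reduce the integration error at the cost of more base-model evaluations.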

drift(x: D, t: Tensor, **kwargs: Any) tuple[D, Tensor][source]

Compute the drift term of the environment’s dynamics.

Parameters:
  • x (D) – The current state.

  • t (torch.Tensor, shape (n,)) – The current time step in [0, 1].

  • **kwargs (dict) – Additional keyword arguments to pass to the base model (e.g. text embedding or class label).

Returns:

  • drift (D) – The drift term at state x and time t.

  • running_cost (torch.Tensor, shape (n,)) – Running cost \(L(x_t, t)\) of the policy for the given (state, timestep)-pair.
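The shapes of the two return values can be illustrated with a toy stand-in. The quadratic running cost below is an assumption for illustration only; the library's actual cost depends on its control formulation.

```python
import torch

def toy_drift(x, t):
    # Illustrative stand-in for drift (an assumption, not the library's
    # implementation): the drift is a velocity field v(x, t), and the
    # running cost is a hypothetical quadratic penalty 0.5 * ||v||^2
    # per sample.
    v = -x * t.unsqueeze(-1)                  # toy velocity, not a trained model
    running_cost = 0.5 * v.pow(2).sum(dim=-1)  # shape (n,), one cost per sample
    return v, running_cost

drift, cost = toy_drift(torch.randn(8, 3), torch.rand(8))
# drift has the same shape as x; cost has shape (n,), matching the docs.
```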

pred_final(x: D, t: Tensor, **kwargs: Any) D[source]

Compute the final state prediction from the current state.

Parameters:
  • x (D) – The current state.

  • t (torch.Tensor, shape (n,)) – The current time step in [0, 1].

  • **kwargs (dict) – Additional keyword arguments to pass to the base model (e.g. text embedding or class label).

Returns:

final – The predicted final state from state x and time t.

Return type:

D
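For intuition, under the common assumption of a linear (rectified-flow) interpolation path \(x_t = (1 - t)\,x_0 + t\,x_1\) with \(v(x_t, t) = x_1 - x_0\), the final-state prediction has the closed form \(x_1 = x_t + (1 - t)\,v(x_t, t)\). The sketch below uses that assumed path; the actual formula depends on the base model's probability path.

```python
import torch

def pred_final_linear(velocity, x, t):
    # Sketch assuming a linear interpolation path x_t = (1-t)*x0 + t*x1
    # with v(x_t, t) = x1 - x0, which gives x1 = x_t + (1-t) * v(x_t, t).
    return x + (1.0 - t).unsqueeze(-1) * velocity(x, t)

# With the exact velocity of a straight path from x0 = 0 to x1 = ones,
# v(x, t) = 1, the prediction recovers x1 from any intermediate state.
x_t = 0.5 * torch.ones(4, 2)  # state at t = 0.5 on that path
x1 = pred_final_linear(lambda x, t: torch.ones_like(x), x_t, torch.full((4,), 0.5))
```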