Environment

class Environment(base_model: BaseModel[D], reward: Reward[D], discretization_steps: int, reward_scale: float = 1.0)[source]

Bases: ABC, Generic[D]

Abstract base class for all environments.

Parameters:
  • base_model (BaseModel[D]) – The base generative model used in the environment.

  • reward (Reward[D]) – The reward function used to compute the final reward.

  • discretization_steps (int) – The number of discretization steps to use when sampling trajectories.

  • reward_scale (float, default=1.0) – Scale of the reward (can be negative). Controls the trade-off between achieving high rewards and staying close to the base model.
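As a hedged illustration of the subclassing contract, the sketch below uses a simplified stand-in for `Environment` (the real class also takes a base model, a reward, and generic data type `D`; the toy Ornstein–Uhlenbeck dynamics and all names here are assumptions, not the library's implementation). A concrete subclass must supply the two abstract methods, `drift` and `pred_final`:

```python
import math
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

D = TypeVar("D")

class ToyEnvironment(ABC, Generic[D]):
    """Simplified stand-in for Environment (base model and reward omitted)."""

    def __init__(self, discretization_steps: int, reward_scale: float = 1.0):
        self.discretization_steps = discretization_steps
        self.reward_scale = reward_scale

    @abstractmethod
    def drift(self, x: D, t: float) -> tuple[D, float]: ...

    @abstractmethod
    def pred_final(self, x: D, t: float) -> D: ...

class OUEnv(ToyEnvironment[float]):
    """Ornstein-Uhlenbeck toy dynamics on a scalar state."""

    def drift(self, x: float, t: float) -> tuple[float, float]:
        return -x, 0.0  # mean-reverting drift, zero running cost

    def pred_final(self, x: float, t: float) -> float:
        return x * math.exp(-(1.0 - t))  # expected state at t = 1

env = OUEnv(discretization_steps=100)
```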

batch_sample(n: int, batch_size: int, pbar: bool = False, **kwargs: Any) Sample[D][source]

Sample n trajectories from the environment in batches.

Parameters:
  • n (int) – Number of trajectories to sample.

  • batch_size (int) – Batch size for sampling.

  • pbar (bool, default=False) – Whether to display a progress bar.

  • **kwargs (dict) – Additional keyword arguments to pass to the base model at every timestep (e.g. text embedding or class label).
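For intuition, batched sampling presumably splits the n requested trajectories into chunks of at most batch_size and samples each chunk in turn; the helper below is a hedged sketch of that chunking logic (the name `batch_sizes` is an illustration, not part of the API):

```python
def batch_sizes(n: int, batch_size: int) -> list[int]:
    """Split n trajectories into batches; the last batch may be smaller."""
    full, rest = divmod(n, batch_size)
    return [batch_size] * full + ([rest] if rest else [])
```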

property control_policy: Policy[D] | None

Current control policy \(u(x, t)\) of the environment.

property device: device

Get the device of the base model.

diffusion(x: D, t: Tensor) D[source]

Compute the diffusion term of the environment’s dynamics.

Parameters:
  • x (D) – The current state.

  • t (torch.Tensor, shape (n,)) – The current time step in [0, 1].

Returns:

diffusion – The diffusion term at state x and time t.

Return type:

D
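As a concrete (assumed) example of a diffusion term: diffusion-model environments often use a state-independent VP-SDE coefficient \(g(t) = \sqrt{\beta(t)}\) with a linear schedule. The schedule and its constants below are illustrative assumptions, not this class's actual implementation:

```python
import math

def vp_diffusion(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    """Diffusion coefficient g(t) = sqrt(beta(t)) for a linear beta schedule."""
    beta_t = beta_min + t * (beta_max - beta_min)
    return math.sqrt(beta_t)
```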

abstract drift(x: D, t: Tensor, **kwargs: Any) tuple[D, Tensor][source]

Compute the drift term of the environment’s dynamics.

Parameters:
  • x (D) – The current state.

  • t (torch.Tensor, shape (n,)) – The current time step in [0, 1].

  • **kwargs (dict) – Keyword arguments to the model.

Returns:

  • drift (D) – The drift term at state x and time t.

  • running_cost (torch.Tensor, shape (n,)) – Running cost \(L(x_t, t)\) of the policy for the given (state, timestep)-pair.
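In stochastic optimal control, the controlled drift is commonly the base drift plus the diffusion-scaled control, \(b(x, t) + g(t)\,u(x, t)\), with quadratic running cost \(\tfrac{1}{2}\lVert u \rVert^2\). The scalar sketch below illustrates that shape; it is an assumption about the typical form, not this method's guaranteed behavior:

```python
def controlled_drift(base_drift: float, g: float, u: float) -> tuple[float, float]:
    """Return (drift, running_cost) for a diffusion-scaled control u."""
    drift = base_drift + g * u
    running_cost = 0.5 * u * u
    return drift, running_cost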

property is_control_policy_set: bool

Whether a custom control policy has been set.

property is_policy_set: bool

Whether a custom policy has been set.

property policy: Policy[D]

Current policy (replacing the base model) of the environment.

abstract pred_final(x: D, t: Tensor, **kwargs: Any) D[source]

Compute the final state prediction from the current state.

Parameters:
  • x (D) – The current state.

  • t (torch.Tensor, shape (n,)) – The current time step in [0, 1].

  • **kwargs (dict) – Keyword arguments to the model.

Returns:

final – The predicted final state from state x and time t.

Return type:

D
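For a VP diffusion model, the final (clean) state is typically predicted from the current state and the model's noise estimate via \(\hat{x}_0 = (x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon) / \sqrt{\bar{\alpha}_t}\). The scalar sketch below shows that formula; `alpha_bar` and `eps` are assumptions about the underlying model, not part of this API:

```python
import math

def pred_x0(x_t: float, eps: float, alpha_bar: float) -> float:
    """Predict the clean state from a noisy state and a noise estimate."""
    return (x_t - math.sqrt(1.0 - alpha_bar) * eps) / math.sqrt(alpha_bar)
```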

sample(n: int, pbar: bool = True, x0: D | None = None, **kwargs: Any) Sample[D][source]

Sample n trajectories from the environment.

Parameters:
  • n (int) – Number of trajectories to sample.

  • pbar (bool, default=True) – Whether to display a progress bar.

  • x0 (D, optional) – Initial states to start the trajectories from. If None, samples from \(p_0\).

  • **kwargs (dict) – Additional keyword arguments to pass to the base model at every timestep (e.g. text embedding or class label).

Returns:

A Sample object containing the sampled trajectories and associated data.

Return type:

Sample[D]
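The sampler presumably integrates the SDE with an Euler–Maruyama loop over `discretization_steps` steps: \(x_{t+\Delta t} = x_t + b(x_t, t)\,\Delta t + g(x_t, t)\sqrt{\Delta t}\,z\) with \(z \sim \mathcal{N}(0, 1)\). The scalar sketch below illustrates that scheme; the concrete drift, diffusion, and return structure are assumptions:

```python
import math
import random

def sample_trajectory(drift, diffusion, x0: float, steps: int,
                      rng: random.Random = random.Random(0)) -> list[float]:
    """Euler-Maruyama integration of dx = drift dt + diffusion dW on [0, 1]."""
    dt = 1.0 / steps
    xs = [x0]
    for i in range(steps):
        t = i * dt
        z = rng.gauss(0.0, 1.0)
        xs.append(xs[-1] + drift(xs[-1], t) * dt
                  + diffusion(xs[-1], t) * math.sqrt(dt) * z)
    return xs
```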

property scheduler: Scheduler[D]

Get the scheduler of the base model.