Environment
- class Environment(base_model: BaseModel[D], reward: Reward[D], discretization_steps: int, reward_scale: float = 1.0)
Bases: ABC, Generic[D]
Abstract base class for all environments.
- Parameters:
base_model (BaseModel[D]) – The base generative model used in the environment.
reward (Reward[D]) – The reward function used to compute the final reward.
discretization_steps (int) – The number of discretization steps to use when sampling trajectories.
reward_scale (float, default=1.0) – Scale of the reward (can be negative). It controls the trade-off between reaching high rewards and staying close to the base model.
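As a rough illustration of what `discretization_steps` controls, the sketch below builds a time grid over [0, 1] with that many steps. The uniform spacing and the helper name `time_grid` are assumptions for illustration only; the class may use a different schedule internally.

```python
def time_grid(discretization_steps: int) -> list[float]:
    """Hypothetical uniform grid t_0 = 0 < t_1 < ... < t_N = 1 with step 1/N."""
    if discretization_steps < 1:
        raise ValueError("discretization_steps must be positive")
    n = discretization_steps
    # Computing k / n directly keeps both endpoints exact (0.0 and 1.0)
    # instead of accumulating rounding error by repeatedly adding 1/n.
    return [k / n for k in range(n + 1)]


print(time_grid(4))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

With N steps there are N + 1 grid points, and each step has length 1/N.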
- batch_sample(n: int, batch_size: int, pbar: bool = False, **kwargs: Any) → Sample[D]
Sample n trajectories from the environment in batches.
- Parameters:
n (int) – Number of trajectories to sample.
batch_size (int) – Batch size for sampling.
pbar (bool, default=False) – Whether to display progress bars.
**kwargs (dict) – Additional keyword arguments to pass to the base model at every timestep (e.g. text embedding or class label).
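One plausible way to split `n` trajectories into batches is sketched below: full batches of `batch_size`, followed by one smaller batch for any remainder. The helper name `batch_sizes` and the "smaller last batch" behavior are assumptions; this page does not say how `batch_sample` handles a remainder.

```python
def batch_sizes(n: int, batch_size: int) -> list[int]:
    """Hypothetical split of n samples into chunks of at most batch_size."""
    if n < 0 or batch_size < 1:
        raise ValueError("need n >= 0 and batch_size >= 1")
    full, rest = divmod(n, batch_size)
    # `full` complete batches, plus one smaller batch for the remainder (if any).
    return [batch_size] * full + ([rest] if rest else [])


print(batch_sizes(10, 4))  # [4, 4, 2]
```

Each chunk would then be sampled separately and the results concatenated, which bounds peak memory by `batch_size` rather than `n`.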
- property control_policy: Policy[D] | None
Current control policy \(u(x, t)\) of the environment.
- abstract drift(x: D, t: Tensor, **kwargs: Any) → tuple[D, Tensor]
Compute the drift term of the environment’s dynamics.
- Parameters:
x (D) – The current state.
t (torch.Tensor, shape (n,)) – The current time step in [0, 1].
**kwargs (dict) – Keyword arguments to the model.
- Returns:
drift (D) – The drift term at state x and time t.
running_cost (torch.Tensor, shape (n,)) – Running cost \(L(x_t, t)\) of the policy for the given (state, timestep)-pair.
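To show how a sampler might consume the `(drift, running_cost)` pair, the sketch below integrates a deterministic rollout with explicit Euler steps on [0, 1] and accumulates \(\int_0^1 L(x_t, t)\,dt\) as a left Riemann sum. It uses scalar states instead of tensors, and the combination `reward_scale * reward - total_cost` in `objective` is an assumed formulation of the reward/proximity trade-off, not something this page confirms. All function names are hypothetical.

```python
from typing import Callable

# Hypothetical signature mirroring Environment.drift: (x, t) -> (drift, running_cost).
DriftFn = Callable[[float, float], tuple[float, float]]


def euler_rollout(drift_fn: DriftFn, x0: float, steps: int) -> tuple[float, float]:
    """Integrate dx = drift dt on [0, 1]; return (x_1, accumulated running cost)."""
    dt = 1.0 / steps
    x, total_cost = x0, 0.0
    for k in range(steps):
        t = k * dt
        drift, running_cost = drift_fn(x, t)
        x = x + drift * dt                 # explicit Euler step (no noise term)
        total_cost += running_cost * dt    # left Riemann sum of L(x_t, t)
    return x, total_cost


def objective(final_reward: float, total_cost: float, reward_scale: float = 1.0) -> float:
    """Assumed trade-off: scaled terminal reward minus accumulated running cost."""
    return reward_scale * final_reward - total_cost


x1, cost = euler_rollout(lambda x, t: (1.0, 0.5), x0=0.0, steps=4)
print(x1, cost)               # 1.0 0.5
print(objective(2.0, cost))   # 1.5
```

Under this assumed objective, a negative `reward_scale` turns reward maximization into minimization while the running cost still penalizes deviation from the base model.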
- property policy: Policy[D]
Current policy of the environment, which replaces the base model.
- abstract pred_final(x: D, t: Tensor, **kwargs: Any) → D
Compute the final state prediction from the current state.
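As one concrete instance of a final-state prediction, suppose (as an assumption, not something stated here) that the base model is a flow-matching model on the linear path \(x_t = (1 - t)\,x_0 + t\,x_1\). Its velocity is \(v = x_1 - x_0\), so \(x_t + (1 - t)\,v = x_1\) exactly. The sketch uses plain lists and a scalar `t`; the function name is hypothetical.

```python
def pred_final_linear_path(x_t: list[float], velocity: list[float], t: float) -> list[float]:
    """Predict x_1 from x_t under the assumed path x_t = (1 - t) x_0 + t x_1.

    On that path the velocity is v = x_1 - x_0, so x_t + (1 - t) * v recovers x_1.
    """
    return [xi + (1.0 - t) * vi for xi, vi in zip(x_t, velocity)]


# x_0 = [0.0, 2.0], x_1 = [1.0, -2.0], t = 0.5  ->  x_t = [0.5, 0.0], v = [1.0, -4.0]
print(pred_final_linear_path([0.5, 0.0], [1.0, -4.0], 0.5))  # [1.0, -2.0]
```

At \(t = 1\) the prediction is just the current state, as expected.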
- sample(n: int, pbar: bool = True, x0: D | None = None, **kwargs: Any) → Sample[D]
Sample n trajectories from the environment.
- Parameters:
n (int) – Number of trajectories to sample.
pbar (bool, default=True) – Whether to display a progress bar.
x0 (D, optional) – Initial states to start the trajectories from. If None, samples from \(p_0\).
**kwargs (dict) – Additional keyword arguments to pass to the base model at every timestep (e.g. text embedding or class label).
- Returns:
A Sample object containing the sampled trajectories and associated data.
- Return type:
Sample[D]
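The subclassing pattern can be sketched with a stripped-down, hypothetical stand-in: an abstract base whose concrete `sample` loop calls the abstract `drift`, and a subclass that supplies it. The sketch uses scalar states, returns only final states instead of a `Sample` object, ignores the running cost, and assumes \(p_0\) is a standard normal when `x0` is `None`. `ToyEnvironment` and `ConstantShift` are illustrative names, not part of the library.

```python
import random
from abc import ABC, abstractmethod
from typing import Optional


class ToyEnvironment(ABC):
    """Hypothetical stand-in for Environment (scalar states, no Sample object)."""

    def __init__(self, discretization_steps: int) -> None:
        self.discretization_steps = discretization_steps

    @abstractmethod
    def drift(self, x: float, t: float) -> tuple[float, float]:
        """Return (drift, running_cost) at state x and time t."""

    def sample(self, n: int, x0: Optional[list[float]] = None, seed: int = 0) -> list[float]:
        """Roll out n trajectories with explicit Euler steps; return final states."""
        if x0 is None:
            # Assumption: p_0 is a standard normal distribution.
            rng = random.Random(seed)
            x0 = [rng.gauss(0.0, 1.0) for _ in range(n)]
        if len(x0) != n:
            raise ValueError("x0 must contain n initial states")
        dt = 1.0 / self.discretization_steps
        xs = list(x0)
        for k in range(self.discretization_steps):
            t = k * dt
            # Only the drift is used here; the running cost is ignored.
            xs = [x + self.drift(x, t)[0] * dt for x in xs]
        return xs


class ConstantShift(ToyEnvironment):
    """Concrete subclass: moves every state by +1 over [0, 1] at zero running cost."""

    def drift(self, x: float, t: float) -> tuple[float, float]:
        return 1.0, 0.0


env = ConstantShift(discretization_steps=4)
print(env.sample(3, x0=[0.0, 1.0, -1.0]))  # [1.0, 2.0, 0.0]
```

Because `drift` is abstract, `ToyEnvironment` itself cannot be instantiated; only subclasses that implement it can.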