Source code for diffusiongym.molecules.types

"""Types for molecular graphs in diffusiongym."""

from typing import Optional, Sequence, Union

import dgl
import torch
from typing_extensions import Self

from diffusiongym.types import BinaryOp, DDMixin, UnaryOp


def construct_ue_mask(g: dgl.DGLGraph) -> torch.Tensor:
    """Construct a mask indicating upper edges in the graph."""
    ul_pattern = torch.tensor([1, 0], device=g.device).repeat(g.batch_size)
    n_edges_pattern = (g.batch_num_edges() / 2).int().repeat_interleave(2)
    return ul_pattern.repeat_interleave(n_edges_pattern).bool()
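

# A minimal sanity check, not part of the original module. The [1, 0] pattern
# above marks the first half of each graph's edges as "upper", so it assumes
# edges are stored with all forward directions first, then their mirrors.
def _demo_construct_ue_mask() -> None:
    # Two bonds: 0 -> 1 and 1 -> 2, followed by their mirrors 1 -> 0 and 2 -> 1
    g = dgl.graph((torch.tensor([0, 1, 1, 2]), torch.tensor([1, 2, 0, 1])))
    assert construct_ue_mask(g).tolist() == [True, True, False, False]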


def construct_n_idx(g: dgl.DGLGraph) -> torch.Tensor:
    """Construct a tensor which maps each node to its graph index in the batch."""
    return torch.repeat_interleave(
        torch.arange(g.batch_size, device=g.device),
        g.batch_num_nodes(),
    )


def construct_e_idx(g: dgl.DGLGraph) -> torch.Tensor:
    """Construct a tensor which maps each edge to its graph index in the batch."""
    return torch.repeat_interleave(
        torch.arange(g.batch_size, device=g.device),
        g.batch_num_edges(),
    )
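

# A small worked example, not part of the original module: in a batch of two
# graphs with 2 and 3 nodes (1 and 2 edges), the index tensors label every
# node and edge with the graph it came from.
def _demo_batch_index_tensors() -> None:
    g1 = dgl.graph((torch.tensor([0]), torch.tensor([1])), num_nodes=2)
    g2 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])), num_nodes=3)
    g = dgl.batch([g1, g2])
    assert construct_n_idx(g).tolist() == [0, 0, 1, 1, 1]
    assert construct_e_idx(g).tolist() == [0, 1, 1]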


class DDGraph(DDMixin):
    """A wrapper around DGLGraph that supports required factory methods.

    Parameters
    ----------
    graph : dgl.DGLGraph
        The graph to wrap.
    ue_mask : Optional[torch.Tensor], optional
        Mask indicating upper edges in the graph, by default None
    n_idx : Optional[torch.Tensor], optional
        Tensor mapping each node to its graph index in the batch, by default None
    e_idx : Optional[torch.Tensor], optional
        Tensor mapping each edge to its graph index in the batch, by default None
    """

    def __init__(
        self,
        graph: dgl.DGLGraph,
        ue_mask: Optional[torch.Tensor] = None,
        n_idx: Optional[torch.Tensor] = None,
        e_idx: Optional[torch.Tensor] = None,
    ):
        if ue_mask is None:
            ue_mask = construct_ue_mask(graph)
        if n_idx is None:
            n_idx = construct_n_idx(graph)
        if e_idx is None:
            e_idx = construct_e_idx(graph)

        self.graph = graph
        self.ue_mask = ue_mask
        self.n_idx = n_idx
        self.e_idx = e_idx

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"num_nodes={self.graph.num_nodes()}, "
            f"num_edges={self.graph.num_edges()}, "
            f"batch_size={len(self)})"
        )

    @property
    def device(self) -> torch.device:
        return self.graph.device

    def to(self, device: torch.device | str) -> Self:
        return self.__class__(
            self.graph.to(device),
            self.ue_mask.to(device),
            self.n_idx.to(device),
            self.e_idx.to(device),
        )

    def __len__(self) -> int:
        return int(self.graph.batch_size)

    def __getitem__(self, idx: Union[int, slice]) -> Self:
        if isinstance(idx, int):
            n = len(self)
            # Faster to use slice_batch if we only want one item
            if idx < 0:
                idx += n
            if idx < 0 or idx >= n:
                raise IndexError(f"Index {idx} out of range for batch size {n}")
            return self.__class__(dgl.slice_batch(self.graph, idx))
        if isinstance(idx, slice):
            graphs = dgl.unbatch(self.graph)
            selected_graphs = graphs[idx]
            if not selected_graphs:
                raise ValueError("The slice resulted in an empty graph sequence.")
            return self.__class__(dgl.batch(selected_graphs))
        raise TypeError(f"Invalid index type: {type(idx)}")
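
    # Indexing semantics, as a sketch (not in the original source). Assuming
    # `batched` is a DDGraph wrapping a dgl.batch of three graphs:
    #
    #     batched[0]    # single item: fast path via dgl.slice_batch
    #     batched[-1]   # negative indices are normalised against len(batched)
    #     batched[1:]   # slice: unbatch, select, re-batch
    #     batched[5]    # IndexError: out of range for batch size 3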

    @classmethod
    def collate(cls, items: Sequence[Self]) -> Self:
        if not items:
            raise ValueError("Cannot collate an empty sequence")
        return cls(dgl.batch([item.graph for item in items]))

    def _get_empty_graph(self) -> dgl.DGLGraph:
        """Get a featureless graph with the same structure as this graph."""
        # Clone the graph structure
        empty_graph = dgl.graph(self.graph.edges(), num_nodes=self.graph.num_nodes())
        # Preserve batch information
        if self.graph.batch_size > 1:
            empty_graph.set_batch_num_nodes(self.graph.batch_num_nodes())
            empty_graph.set_batch_num_edges(self.graph.batch_num_edges())
        return empty_graph

    def apply(self, op: UnaryOp) -> Self:
        res = self._get_empty_graph()
        for key, val in self.graph.ndata.items():
            if isinstance(val, torch.Tensor):
                res.ndata[key] = op(val)
        for key, val in self.graph.edata.items():
            if isinstance(val, torch.Tensor):
                res.edata[key] = op(val)
        return self.__class__(res, self.ue_mask, self.n_idx, self.e_idx)
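
    # Usage sketch (not in the original source): `apply` maps an element-wise
    # op over every node and edge feature, e.g. halving all features:
    #
    #     halved = ddg.apply(lambda t: t * 0.5)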

    def combine(self, other: Union[Self, float, torch.Tensor], op: BinaryOp) -> Self:
        res = self._get_empty_graph()
        if isinstance(other, DDGraph):
            for key, val in self.graph.ndata.items():
                if key in other.graph.ndata:
                    res.ndata[key] = op(val, other.graph.ndata[key])  # type: ignore
                else:
                    res.ndata[key] = val
            for key, val in self.graph.edata.items():
                if key in other.graph.edata:
                    res.edata[key] = op(val, other.graph.edata[key])  # type: ignore
                else:
                    res.edata[key] = val
        else:
            for key, val in self.graph.ndata.items():
                res.ndata[key] = op(val, other)  # type: ignore
            for key, val in self.graph.edata.items():
                res.edata[key] = op(val, other)  # type: ignore
        return self.__class__(res, self.ue_mask, self.n_idx, self.e_idx)
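
    # Usage sketch (not in the original source): `combine` broadcasts a scalar
    # or tensor over every feature, or pairs features key-by-key with another
    # DDGraph. Keys missing from `other` are copied through unchanged; keys
    # present only in `other` are ignored.
    #
    #     scaled = ddg.combine(2.0, torch.mul)        # every feature doubled
    #     summed = ddg.combine(other_ddg, torch.add)  # matching keys added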

    def aggregate(self, reduction: str = "mean") -> torch.Tensor:
        batch_size = len(self)
        summed = torch.zeros(batch_size, device=self.graph.device)

        # Initialize counts if we need to calculate the mean later
        counts = None
        if reduction == "mean":
            counts = torch.zeros(batch_size, device=self.graph.device)

        for _, val in self.graph.ndata.items():
            if isinstance(val, torch.Tensor):
                aggregated = torch.zeros(
                    batch_size, *val.shape[1:], device=val.device, dtype=val.dtype
                )
                aggregated.index_add_(0, self.n_idx, val)
                # Flatten the feature dims before reducing so 1-D and >2-D
                # features also collapse to one value per graph
                summed += aggregated.reshape(batch_size, -1).sum(dim=-1)

                # Track number of elements added
                if counts is not None:
                    num_elements = val[0].numel()
                    item_counts = torch.zeros(batch_size, device=val.device)
                    ones = torch.ones(val.size(0), device=val.device)
                    item_counts.index_add_(0, self.n_idx, ones)
                    counts += item_counts * num_elements

        for _, val in self.graph.edata.items():
            if isinstance(val, torch.Tensor):
                aggregated = torch.zeros(
                    batch_size, *val.shape[1:], device=val.device, dtype=val.dtype
                )
                aggregated.index_add_(0, self.e_idx, val)
                summed += aggregated.reshape(batch_size, -1).sum(dim=-1)

                # Track number of elements added
                if counts is not None:
                    num_elements = val[0].numel()
                    item_counts = torch.zeros(batch_size, device=val.device)
                    ones = torch.ones(val.size(0), device=val.device)
                    item_counts.index_add_(0, self.e_idx, ones)
                    counts += item_counts * num_elements

        if counts is not None:
            # Avoid division by zero for empty graphs
            return summed / counts.clamp(min=1)
        return summed
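
    # Worked example (sketch, not in the original source): a single graph with
    # a [2, 3] node feature of ones and a [2, 1] edge feature of ones sums to
    # 8.0 per graph; "mean" divides by the 8 contributing elements, giving 1.0.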

    def randn_like(self) -> Self:
        out = super().randn_like()

        # Remove the per-graph centre of mass (COM) from the coordinates
        init_coms = dgl.readout_nodes(out.graph, feat="x_t", op="mean")
        out.graph.ndata["x_t"] = out.graph.ndata["x_t"] - init_coms[out.n_idx]

        # Also make sure that both sides of edges are equivalent
        out.graph.edata["e_t"][~out.ue_mask] = out.graph.edata["e_t"][out.ue_mask]
        return out
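

# End-to-end sketch, not part of the original module: build two tiny graphs
# whose first half of edges are one direction and whose second half are the
# mirrors, collate them, and check `apply`/`aggregate`. The feature names "h"
# and "w" are placeholders, not keys the library requires.
def _demo_ddgraph_roundtrip() -> None:
    def make_graph() -> dgl.DGLGraph:
        # One undirected bond stored as the edge pair 0 -> 1 and 1 -> 0
        g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])))
        g.ndata["h"] = torch.ones(2, 3)
        g.edata["w"] = torch.ones(2, 1)
        return g

    batched = DDGraph.collate([DDGraph(make_graph()), DDGraph(make_graph())])
    assert len(batched) == 2

    doubled = batched.apply(lambda t: t * 2)
    # Per graph: 2 nodes * 3 dims * 2.0 + 2 edges * 1 dim * 2.0 = 16.0
    assert torch.allclose(doubled.aggregate("sum"), torch.tensor([16.0, 16.0]))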