From 20a8c17aa59b5f6b800fc277ad3ba74304f6c462 Mon Sep 17 00:00:00 2001 From: Sander Vandenhaute Date: Wed, 7 Aug 2024 14:44:59 -0400 Subject: [PATCH] began adding docstrings --- psiflow/data/utils.py | 327 ++++++++++++++++++++++++++++++++++++++++++ psiflow/geometry.py | 222 ++++++++++++++++++++++++++++ 2 files changed, 549 insertions(+) diff --git a/psiflow/data/utils.py b/psiflow/data/utils.py index 80ac9f7..9078de1 100644 --- a/psiflow/data/utils.py +++ b/psiflow/data/utils.py @@ -18,6 +18,21 @@ def _write_frames( extra_states: Union[Geometry, list[Geometry], None] = None, outputs: list = [], ) -> None: + """ + Write Geometry instances to a file. + + Args: + *states: Variable number of Geometry instances to write. + extra_states: Additional Geometry instance(s) to write. + outputs: List of Parsl futures. The first element should be a DataFuture + representing the output file path. + + Returns: + None + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ all_states = list(states) if extra_states is not None: if isinstance(extra_states, list): @@ -38,6 +53,24 @@ def _read_frames( inputs: list = [], outputs: list = [], ) -> Optional[list[Geometry]]: + """ + Read Geometry instances from a file. + + Args: + indices: Indices of frames to read. Can be None (read all), a slice, a list of integers, or a single integer. + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing the geometry data. + outputs: List of Parsl futures. If provided, the first element should be + a DataFuture representing the output file path where the selected + geometries will be written. + + Returns: + Optional[list[Geometry]]: List of read Geometry instances if no output + is specified, otherwise None. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ frame_index = 0 frame_regex = re.compile(r"^\d+$") length = _count_frames(inputs=inputs) @@ -95,6 +128,23 @@ def _extract_quantities( *extra_data: Geometry, inputs: list = [], ) -> tuple[np.ndarray, ...]: + """ + Extract specified quantities from Geometry instances. + + Args: + quantities: Tuple of quantity names to extract. + atom_indices: List of atom indices to consider. + elements: List of element symbols to consider. + *extra_data: Additional Geometry instances. + inputs: List of Parsl futures. If provided, the first element should be a DataFuture + representing the input file path containing geometry data. + + Returns: + tuple[np.ndarray, ...]: Tuple of arrays containing extracted quantities. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ if not len(extra_data): assert len(inputs) == 1 data = _read_frames(inputs=inputs) @@ -161,6 +211,24 @@ def _insert_quantities( inputs: list = [], outputs: list = [], ) -> None: + """ + Insert quantities into Geometry instances. + + Args: + quantities: Tuple of quantity names to insert. + arrays: List of arrays containing the quantities to insert. + data: List of Geometry instances to update. + inputs: List of Parsl futures. If provided, the first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. If provided, the first element should be a DataFuture + representing the output file path where updated geometries will be written. + + Returns: + None + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ if data is None: assert len(inputs) == 1 data = _read_frames(inputs=inputs) @@ -214,6 +282,19 @@ def _insert_quantities( @typeguard.typechecked def _check_distances(state: Geometry, threshold: float) -> Geometry: + """ + Check if all interatomic distances in a Geometry are above a threshold. + + Args: + state: Geometry instance to check. + threshold: Minimum allowed interatomic distance. + + Returns: + Geometry: The input Geometry if all distances are above the threshold, otherwise NullState. + + Note: + This function is wrapped as a Parsl app and executed using the default_htex executor. + """ from ase.geometry.geometry import find_mic if state == NullState: @@ -244,6 +325,22 @@ def _assign_identifiers( inputs: list = [], outputs: list = [], ) -> int: + """ + Assign identifiers to Geometry instances in a file. + + Args: + identifier: Starting identifier value. + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. The first element should be a DataFuture + representing the output file path where updated geometries will be written. + + Returns: + int: Next available identifier. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ data = _read_frames(slice(None), inputs=[inputs[0]]) states = [] if identifier is None: # do not assign but look for max @@ -274,6 +371,21 @@ def _join_frames( inputs: list = [], outputs: list = [], ): + """ + Join multiple frame files into a single file. + + Args: + inputs: List of Parsl futures. Each element should be a DataFuture + representing an input file path containing geometry data. + outputs: List of Parsl futures. The first element should be a DataFuture + representing the output file path where joined frames will be written. + + Returns: + None + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ assert len(outputs) == 1 with open(outputs[0], "wb") as destination: @@ -287,6 +399,19 @@ def _join_frames( @typeguard.typechecked def _count_frames(inputs: list = []) -> int: + """ + Count the number of frames in a file. + + Args: + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + + Returns: + int: Number of frames in the file. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ nframes = 0 frame_regex = re.compile(r"^\d+$") with open(inputs[0], "r") as f: @@ -303,6 +428,21 @@ def _count_frames(inputs: list = []) -> int: @typeguard.typechecked def _reset_frames(inputs: list = [], outputs: list = []) -> None: + """ + Reset all frames in a file. + + Args: + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. The first element should be a DataFuture + representing the output file path where reset frames will be written. + + Returns: + None + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ data = _read_frames(inputs=[inputs[0]]) for geometry in data: geometry.reset() @@ -314,6 +454,21 @@ def _reset_frames(inputs: list = [], outputs: list = []) -> None: @typeguard.typechecked def _clean_frames(inputs: list = [], outputs: list = []) -> None: + """ + Clean all frames in a file. + + Args: + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. The first element should be a DataFuture + representing the output file path where cleaned frames will be written. + + Returns: + None + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ data = _read_frames(inputs=[inputs[0]]) for geometry in data: geometry.clean() @@ -330,6 +485,23 @@ def _apply_offset( outputs: list = [], **atomic_energies: float, ) -> None: + """ + Apply an energy offset to all frames in a file. + + Args: + subtract: Whether to subtract or add the offset. + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. The first element should be a DataFuture + representing the output file path where updated frames will be written. + **atomic_energies: Atomic energies for each element. + + Returns: + None + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ assert len(inputs) == 1 assert len(outputs) == 1 data = _read_frames(inputs=[inputs[0]]) @@ -361,6 +533,19 @@ def _apply_offset( @typeguard.typechecked def _get_elements(inputs: list = []) -> set[str]: + """ + Get the set of elements present in all frames of a file. + + Args: + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + + Returns: + set[str]: Set of element symbols. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ data = _read_frames(inputs=[inputs[0]]) return set([chemical_symbols[n] for g in data for n in g.per_atom.numbers]) @@ -370,6 +555,21 @@ def _get_elements(inputs: list = []) -> set[str]: @typeguard.typechecked def _align_axes(inputs: list = [], outputs: list = []) -> None: + """ + Align axes for all frames in a file. + + Args: + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. The first element should be a DataFuture + representing the output file path where aligned frames will be written. + + Returns: + None + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ data = _read_frames(inputs=[inputs[0]]) for geometry in data: geometry.align_axes() @@ -381,6 +581,21 @@ def _align_axes(inputs: list = [], outputs: list = []) -> None: @typeguard.typechecked def _not_null(inputs: list = [], outputs: list = []) -> list[bool]: + """ + Check which frames in a file are not null states. + + Args: + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. If provided, the first element should be a DataFuture + representing the output file path where non-null frames will be written. + + Returns: + list[bool]: List of boolean values indicating non-null states. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ frame_regex = re.compile(r"^\d+$") data = [] @@ -414,6 +629,22 @@ def _app_filter( inputs: list = [], outputs: list = [], ) -> None: + """ + Filter frames based on a specified quantity. + + Args: + quantity: The quantity to filter on. + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. The first element should be a DataFuture + representing the output file path where filtered frames will be written. + + Returns: + None + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ data = _read_frames(inputs=[inputs[0]]) i = 0 while i < len(data): @@ -451,6 +682,21 @@ def _shuffle( inputs: list = [], outputs: list = [], ) -> None: + """ + Shuffle the order of frames in a file. + + Args: + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. The first element should be a DataFuture + representing the output file path where shuffled frames will be written. + + Returns: + None + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ data = _read_frames(inputs=[inputs[0]]) indices = np.arange(len(data)) np.random.shuffle(indices) @@ -468,6 +714,20 @@ def _train_valid_indices( train_valid_split: float, shuffle: bool, ) -> tuple[list[int], list[int]]: + """ + Generate indices for train and validation splits. + + Args: + effective_nstates: Total number of states. + train_valid_split: Fraction of states to use for training. + shuffle: Whether to shuffle the indices. + + Returns: + tuple[list[int], list[int]]: Lists of indices for training and validation sets. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ ntrain = int(np.floor(effective_nstates * train_valid_split)) nvalid = effective_nstates - ntrain assert ntrain > 0 @@ -489,6 +749,17 @@ def get_train_valid_indices( train_valid_split: float, shuffle: bool, ) -> tuple[AppFuture, AppFuture]: + """ + Get futures for train and validation indices. + + Args: + effective_nstates: Future representing the total number of states. + train_valid_split: Fraction of states to use for training. + shuffle: Whether to shuffle the indices. + + Returns: + tuple[AppFuture, AppFuture]: Futures for training and validation indices. + """ future = train_valid_indices(effective_nstates, train_valid_split, shuffle) return unpack_i(future, 0), unpack_i(future, 1) @@ -500,6 +771,18 @@ def get_index_element_mask( elements: Optional[list[str]], natoms_padded: Optional[int] = None, ) -> np.ndarray: + """ + Generate a mask for atom indices and elements. + + Args: + numbers: Array of atomic numbers. + atom_indices: List of atom indices to include. + elements: List of element symbols to include. + natoms_padded: Total number of atoms including padding. + + Returns: + np.ndarray: Boolean mask array. + """ mask = np.array([True] * len(numbers)) if elements is not None: @@ -528,6 +811,20 @@ def _compute_rmse( array1: np.ndarray, reduce: bool = True, ) -> Union[float, np.ndarray]: + """ + Compute the Root Mean Square Error (RMSE) between two arrays. + + Args: + array0: First array. + array1: Second array. + reduce: Whether to reduce the result to a single value. + + Returns: + Union[float, np.ndarray]: RMSE value(s). + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ assert array0.shape == array1.shape assert np.all(np.isnan(array0) == np.isnan(array1)) @@ -561,6 +858,20 @@ def _compute_mae( array1, reduce: bool = True, ) -> Union[float, np.ndarray]: + """ + Compute the Mean Absolute Error (MAE) between two arrays. + + Args: + array0: First array. + array1: Second array. + reduce: Whether to reduce the result to a single value. + + Returns: + Union[float, np.ndarray]: MAE value(s). + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ assert array0.shape == array1.shape mask0 = np.logical_not(np.isnan(array0)) mask1 = np.logical_not(np.isnan(array1)) @@ -584,6 +895,22 @@ def _batch_frames( inputs: list = [], outputs: list = [], ) -> Optional[list[Geometry]]: + """ + Split frames into batches. + + Args: + batch_size: Number of frames per batch. + inputs: List of Parsl futures. The first element should be a DataFuture + representing the input file path containing geometry data. + outputs: List of Parsl futures. Each element should be a DataFuture + representing an output file path for each batch. + + Returns: + Optional[list[Geometry]]: List of Geometry instances if no outputs are specified. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ frame_regex = re.compile(r"^\d+$") data = [] diff --git a/psiflow/geometry.py b/psiflow/geometry.py index 143e070..bd9e271 100644 --- a/psiflow/geometry.py +++ b/psiflow/geometry.py @@ -37,6 +37,25 @@ @typeguard.typechecked class Geometry: + """ + Represents an atomic structure with associated properties. + + This class encapsulates the atomic structure, including atom positions, cell parameters, + and various physical properties such as energy and forces. + + Attributes: + per_atom (np.recarray): Record array containing per-atom properties. + cell (np.ndarray): 3x3 array representing the unit cell vectors. + order (dict): Dictionary to store custom ordering information. + energy (Optional[float]): Total energy of the system. + stress (Optional[np.ndarray]): Stress tensor of the system. + delta (Optional[float]): Delta value, if applicable. + phase (Optional[str]): Phase information, if applicable. + logprob (Optional[np.ndarray]): Log probability values, if applicable. + stdout (Optional[str]): Standard output information, if applicable. + identifier (Optional[int]): Unique identifier for the geometry. + """ + per_atom: np.recarray cell: np.ndarray order: dict @@ -61,6 +80,22 @@ def __init__( stdout: Optional[str] = None, identifier: Optional[int] = None, ): + """ + Initialize a Geometry instance, though the preferred way of instantiating + proceeds via the `from_data` or `from_atoms` class methods + + Args: + per_atom (np.recarray): Record array containing per-atom properties. + cell (np.ndarray): 3x3 array representing the unit cell vectors. + order (Optional[dict], optional): Custom ordering information. Defaults to None. + energy (Optional[float], optional): Total energy of the system. Defaults to None. + stress (Optional[np.ndarray], optional): Stress tensor of the system. Defaults to None. + delta (Optional[float], optional): Delta value. Defaults to None. + phase (Optional[str], optional): Phase information. Defaults to None. + logprob (Optional[np.ndarray], optional): Log probability values. Defaults to None. + stdout (Optional[str], optional): Standard output information. Defaults to None. + identifier (Optional[int], optional): Unique identifier for the geometry. Defaults to None. + """ self.per_atom = per_atom.astype(per_atom_dtype) # copies data self.cell = cell.astype(np.float32) assert self.cell.shape == (3, 3) @@ -76,6 +111,9 @@ def __init__( self.identifier = identifier def reset(self): + """ + Reset all computed properties of the geometry to their default values. + """ self.energy = None self.stress = None self.delta = None @@ -84,12 +122,24 @@ def reset(self): self.per_atom.forces[:] = np.nan def clean(self): + """ + Clean the geometry by resetting properties and removing additional information. + """ self.reset() self.order = {} self.stdout = None self.identifier = None def __eq__(self, other) -> bool: + """ + Check if two Geometry instances are equal. + + Args: + other: The other object to compare with. + + Returns: + bool: True if the geometries are equal, False otherwise. + """ if not isinstance(other, Geometry): return False # have to check separately for np.allclose due to different dtypes @@ -104,6 +154,9 @@ def __eq__(self, other) -> bool: return bool(equal) def align_axes(self): + """ + Align the axes of the unit cell to a canonical representation for periodic systems. + """ if self.periodic: # only do something if periodic: positions = self.per_atom.positions cell = self.cell @@ -111,9 +164,21 @@ def align_axes(self): reduce_box_vectors(cell) def __len__(self): + """ + Get the number of atoms in the geometry. + + Returns: + int: The number of atoms. + """ return len(self.per_atom) def to_string(self) -> str: + """ + Convert the Geometry instance to a string representation in extended XYZ format. + + Returns: + str: String representation of the geometry. + """ if self.periodic: comment = 'Lattice="' comment += " ".join([str(x) for x in np.reshape(self.cell.T, 9, order="F")]) @@ -162,15 +227,37 @@ def to_string(self) -> str: return "\n".join(lines) def save(self, path_xyz: Union[Path, str]): + """ + Save the Geometry instance to an XYZ file. + + Args: + path_xyz (Union[Path, str]): Path to save the XYZ file. + """ path_xyz = psiflow.resolve_and_check(path_xyz) with open(path_xyz, "w") as f: f.write(self.to_string()) def copy(self) -> Geometry: + """ + Create a deep copy of the Geometry instance. + + Returns: + Geometry: A new Geometry instance with the same data. + """ return Geometry.from_string(self.to_string()) @classmethod def from_string(cls, s: str, natoms: Optional[int] = None) -> Optional[Geometry]: + """ + Create a Geometry instance from a string representation in extended XYZ format. + + Args: + s (str): String representation of the geometry. + natoms (Optional[int], optional): Number of atoms (if known). Defaults to None. + + Returns: + Optional[Geometry]: A new Geometry instance, or None if the string is empty. + """ if len(s) == 0: return None if not natoms: # natoms in s @@ -216,6 +303,15 @@ def from_string(cls, s: str, natoms: Optional[int] = None) -> Optional[Geometry] @classmethod def load(cls, path_xyz: Union[Path, str]) -> Geometry: + """ + Load a Geometry instance from an XYZ file. + + Args: + path_xyz (Union[Path, str]): Path to the XYZ file. + + Returns: + Geometry: A new Geometry instance loaded from the file. + """ path_xyz = psiflow.resolve_and_check(Path(path_xyz)) assert path_xyz.exists() with open(path_xyz, "r") as f: @@ -224,10 +320,22 @@ def load(cls, path_xyz: Union[Path, str]) -> Geometry: @property def periodic(self): + """ + Check if the geometry is periodic. + + Returns: + bool: True if the geometry is periodic, False otherwise. + """ return np.any(self.cell) @property def per_atom_energy(self): + """ + Calculate the energy per atom. + + Returns: + Optional[float]: Energy per atom if total energy is available, None otherwise. + """ if self.energy is None: return None else: @@ -235,6 +343,12 @@ def per_atom_energy(self): @property def volume(self): + """ + Calculate the volume of the unit cell. + + Returns: + float: Volume of the unit cell for periodic systems, np.nan for non-periodic systems. + """ if not self.periodic: return np.nan else: @@ -247,6 +361,17 @@ def from_data( positions: np.ndarray, cell: Optional[np.ndarray], ) -> Geometry: + """ + Create a Geometry instance from atomic numbers, positions, and cell data. + + Args: + numbers (np.ndarray): Array of atomic numbers. + positions (np.ndarray): Array of atomic positions. + cell (Optional[np.ndarray]): Unit cell vectors (or None for non-periodic systems). + + Returns: + Geometry: A new Geometry instance. + """ per_atom = np.recarray(len(numbers), dtype=per_atom_dtype) per_atom.numbers[:] = numbers per_atom.positions[:] = positions @@ -259,6 +384,15 @@ def from_data( @classmethod def from_atoms(cls, atoms: Atoms) -> Geometry: + """ + Create a Geometry instance from an ASE Atoms object. + + Args: + atoms (Atoms): ASE Atoms object. + + Returns: + Geometry: A new Geometry instance. + """ per_atom = np.recarray(len(atoms), dtype=per_atom_dtype) per_atom.numbers[:] = atoms.numbers.astype(np.uint8) per_atom.positions[:] = atoms.get_positions() @@ -279,6 +413,12 @@ def from_atoms(cls, atoms: Atoms) -> Geometry: def new_nullstate(): + """ + Create a new null state Geometry. + + Returns: + Geometry: A Geometry instance representing a null state. + """ return Geometry.from_data(np.zeros(1), np.zeros((1, 3)), None) @@ -287,6 +427,15 @@ def new_nullstate(): def is_lower_triangular(cell: np.ndarray) -> bool: + """ + Check if a cell matrix is lower triangular. + + Args: + cell (np.ndarray): 3x3 cell matrix. + + Returns: + bool: True if the cell matrix is lower triangular, False otherwise. + """ return ( cell[0, 0] > 0 and cell[1, 1] > 0 # positive volumes @@ -298,6 +447,15 @@ def is_lower_triangular(cell: np.ndarray) -> bool: def is_reduced(cell: np.ndarray) -> bool: + """ + Check if a cell matrix is in reduced form. + + Args: + cell (np.ndarray): 3x3 cell matrix. + + Returns: + bool: True if the cell matrix is in reduced form, False otherwise. + """ return ( cell[0, 0] > abs(2 * cell[1, 0]) and cell[0, 0] > abs(2 * cell[2, 0]) # b mostly along y axis @@ -318,6 +476,10 @@ def transform_lower_triangular( keyword. The box vector lengths and angles remain exactly the same. + Args: + pos (np.ndarray): Array of atomic positions. + cell (np.ndarray): 3x3 cell matrix. + reorder (bool, optional): Whether to reorder lattice vectors. Defaults to False. """ if reorder: # reorder box vectors as k, l, m with |k| >= |l| >= |m| norms = np.linalg.norm(cell, axis=1) @@ -360,6 +522,15 @@ def reduce_box_vectors(cell: np.ndarray): @typeguard.typechecked def get_mass_matrix(geometry: Geometry) -> np.ndarray: + """ + Compute the mass matrix for a given geometry. + + Args: + geometry (Geometry): Input geometry. + + Returns: + np.ndarray: Mass matrix. + """ masses = np.repeat( np.array([atomic_masses[n] for n in geometry.per_atom.numbers]), 3, @@ -370,6 +541,16 @@ def get_mass_matrix(geometry: Geometry) -> np.ndarray: @typeguard.typechecked def mass_weight(hessian: np.ndarray, geometry: Geometry) -> np.ndarray: + """ + Apply mass-weighting to a Hessian matrix. + + Args: + hessian (np.ndarray): Input Hessian matrix. + geometry (Geometry): Geometry associated with the Hessian. + + Returns: + np.ndarray: Mass-weighted Hessian matrix. + """ assert hessian.shape[0] == hessian.shape[1] assert len(geometry) * 3 == hessian.shape[0] return hessian * get_mass_matrix(geometry) @@ -377,12 +558,32 @@ def mass_weight(hessian: np.ndarray, geometry: Geometry) -> np.ndarray: @typeguard.typechecked def mass_unweight(hessian: np.ndarray, geometry: Geometry) -> np.ndarray: + """ + Remove mass-weighting from a Hessian matrix. + + Args: + hessian (np.ndarray): Input mass-weighted Hessian matrix. + geometry (Geometry): Geometry associated with the Hessian. + + Returns: + np.ndarray: Unweighted Hessian matrix. + """ assert hessian.shape[0] == hessian.shape[1] assert len(geometry) * 3 == hessian.shape[0] return hessian / get_mass_matrix(geometry) def create_outputs(quantities: list[str], data: list[Geometry]) -> list[np.ndarray]: + """ + Create output arrays for specified quantities from a list of Geometry instances. + + Args: + quantities (list[str]): List of quantity names to extract. + data (list[Geometry]): List of Geometry instances. + + Returns: + list[np.ndarray]: List of arrays containing the requested quantities. + """ order_names = list(set([k for g in data for k in g.order])) assert all([q in QUANTITIES + order_names for q in quantities]) natoms = np.array([len(geometry) for geometry in data], dtype=int) @@ -433,6 +634,17 @@ def _assign_identifier( identifier: int, discard: bool = False, ) -> tuple[Geometry, int]: + """ + Assign an identifier to a Geometry instance. + + Args: + state (Geometry): Input Geometry instance. + identifier (int): Identifier to assign. + discard (bool, optional): Whether to discard the state. Defaults to False. + + Returns: + tuple[Geometry, int]: Updated Geometry and next available identifier. + """ if (state == NullState) or discard: return state, identifier else: @@ -449,6 +661,16 @@ def _check_equality( state0: Geometry, state1: Geometry, ) -> bool: + """ + Check if two Geometry instances are equal. + + Args: + state0 (Geometry): First Geometry instance. + state1 (Geometry): Second Geometry instance. + + Returns: + bool: True if the Geometry instances are equal, False otherwise. + """ return state0 == state1