Source code for lambeq.training.tn_path_optimizer

# Copyright 2021-2025 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Opt-Einsum contractor for reusing previously computed paths.
"""

import functools
from pathlib import Path
import pickle
from typing import (
    Collection, Iterable, Sequence, TypeVar
)

import opt_einsum as oe     # type: ignore[import-untyped]
from tensornetwork import (
    AbstractNode, contract_between, contract_parallel,
    Edge, get_all_edges, get_subgraph_dangling
)
from tensornetwork.contractors.opt_einsum_paths import utils

from lambeq.training.checkpoint import Checkpoint


T = TypeVar('T')
ContractionPath = Collection[tuple[int, ...]]
ContractionKey = tuple[
    tuple[tuple[str, ...], ...],
    tuple[str, ...],
    tuple[tuple[str, int], ...]
]
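
# For orientation (illustrative values, not part of the API): the key for a
# matrix product A(i, j) @ B(j, k) with sizes i=2, j=3, k=4 would be
#     ((('a', 'b'), ('b', 'c')),        # input nodes as sorted symbol tuples
#      ('a', 'c'),                      # output (dangling) symbols
#      (('a', 2), ('b', 3), ('c', 4)))  # (symbol, dimension) per edge
# where 'a', 'b', 'c' are the opt_einsum symbols assigned to the edges in
# CachedTnPathOptimizer.__call__ below.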


def _dedup(seq: Iterable[T]) -> list[T]:
    """
    Remove duplicates from a list while maintaining the order of the
    items. Note: this method was taken from
    https://stackoverflow.com/questions/480214/how-do-i-remove-duplicates-
    from-a-list-while-preserving-order
    """
    seen: set[T] = set()
    # Alias add() so Python doesn't have to look it up on each call
    seen_add = seen.add
    return [x for x in seq if not (x in seen or seen_add(x))]
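# Example: _dedup([3, 1, 3, 2, 1]) returns [3, 1, 2]; the first occurrence
# of each item is kept.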


class TnPathOptimizer(oe.paths.PathOptimizer):
    """Opt-einsum custom optimizer."""

    _optimizer: functools.partial[ContractionPath]
    memory_limit: int | None

    def __init__(
        self,
        algorithm: str = 'auto',
        memory_limit: int | None = None,
        **kwargs
    ):
        """
        Parameters
        ----------
        algorithm: :py:class:`str`, default: 'auto'
            Algorithm to use when the path is not already cached.
            Recommended options are:

            - ``auto`` (default): fast; finds decent enough paths.
            - ``auto-hq``: slow, but finds very high quality paths.
              Not recommended.
            - ``random-greedy``: highly configurable; usually faster
              than auto-hq and finds better paths than auto.
        memory_limit: :py:class:`int`
            Limit the memory usage of the intermediate tensors. This
            is not recommended, as it generally makes the path finding
            much slower. If size is a concern, use the random-greedy
            algorithm with ``minimize='size'`` instead.
        kwargs:
            Extra keyword arguments to pass to the fallback algorithm
            initializer. These depend on the chosen fallback algorithm.
            For random-greedy, the following kwargs are available:

            - ``max_repeats``: int = 32
            - ``max_time``: float
            - ``minimize``: 'size' | 'flops' = 'flops'
            - ``parallel``: bool | int = False - whether to run trials
              in parallel. If a number is specified, use that many
              processes at once; otherwise use all available CPU
              cores.

        """
        self.memory_limit = memory_limit
        self._optimizer = functools.partial(
            oe.paths.get_path_fn(algorithm),
            **kwargs
        )

    def store_to_checkpoint(self, checkpoint: Checkpoint) -> Checkpoint:
        return checkpoint

    def restore_from_checkpoint(self, checkpoint: Checkpoint):
        """Load any information saved to a checkpoint.

        We do not expect to recover init kwargs from the checkpoint -
        the TnPathOptimizer should be initialised separately.
        """
        pass

    def __call__(
        self,
        inputs: list[set[Edge]],
        output: set[Edge],
        size_dict: dict[Edge, int],
        memory_limit: int | None = None,
        edge_list: list[Edge] | None = None
    ):
        return self._optimizer(
            inputs,
            output,
            size_dict,
            memory_limit=memory_limit or self.memory_limit
        )
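
# Illustrative usage (a sketch; the kwarg values are arbitrary, and the
# option names are opt_einsum's random-greedy options documented above):
#
#     optimizer = TnPathOptimizer(
#         algorithm='random-greedy',
#         max_repeats=64,
#         minimize='size',
#     )
#
# The resulting object can be passed wherever an opt_einsum PathOptimizer
# is expected, e.g. to `ordered_nodes_contractor` below.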

class CachedTnPathOptimizer(TnPathOptimizer):
    """Opt-einsum custom optimizer.

    Caches computed paths for quick path lookup.
    """

    cached_paths: dict[ContractionKey, ContractionPath]
    filepath: Path | None
    save_checkpoints: bool
    checkpoint_cached_paths_key = 'tn_path_optimizer_cached_paths'

    def __init__(
        self,
        algorithm: str = 'auto-hq',
        save_file: Path | None = None,
        save_checkpoints: bool = False,
        **kwargs
    ):
        """
        Parameters
        ----------
        algorithm: :py:class:`str`, default: 'auto-hq'
            Fallback algorithm to use when the path is not already
            cached. Recommended options are:

            - ``auto``: fast, but the paths are not that efficient.
              Use if you do not expect paths to be reused.
            - ``auto-hq`` (default): slow, but finds very high quality
              paths.
            - ``random-greedy``: highly configurable; usually faster
              than auto-hq and finds better paths than auto. Preferred
              option for larger networks where auto-hq is too slow.
        save_file: :py:class:`Path` or :py:class:`None`
            Optional filepath to save the cached paths to. The file
            contents are updated with each new path.
        save_checkpoints: :py:class:`bool`, default: False
            Whether to include the cached paths in the checkpoints.
        kwargs:
            Extra keyword arguments to pass to the fallback algorithm
            initializer. These depend on the chosen fallback algorithm.
            All fallback algorithms above accept a memory limit:

            - ``memory_limit``: (optional) int

            However, it is preferable to use random-greedy with
            ``minimize='size'`` for better performance. For
            random-greedy, the following kwargs are available:

            - ``max_repeats``: int = 32
            - ``max_time``: float
            - ``minimize``: 'size' | 'flops' = 'flops'
            - ``parallel``: bool | int = False - whether to run trials
              in parallel. If a number is specified, use that many
              processes at once; otherwise use all available CPU
              cores.

        """
        super().__init__(algorithm, **kwargs)
        self.cached_paths = {}
        self.filepath = save_file
        self.save_checkpoints = save_checkpoints

        if self.filepath is not None:
            try:
                with open(self.filepath, 'rb') as f:
                    self.cached_paths = pickle.load(f)
            except FileNotFoundError:
                # Make sure the parent directory exists so the file
                # can be written later. (Calling mkdir on the filepath
                # itself would create a directory in the file's place.)
                self.filepath.parent.mkdir(parents=True, exist_ok=True)
            except EOFError:
                # No previous paths present; continue.
                pass

    def __call__(
        self,
        inputs: list[set[Edge]],
        output: set[Edge],
        size_dict: dict[Edge, int],
        memory_limit: int | None = None,
        edge_list: list[Edge] | None = None
    ) -> ContractionPath:
        if edge_list is None:
            raise ValueError(
                'Edge list must be supplied in order to cache paths.'
            )

        # Tensornetwork gives us edges with wacky names and arbitrary
        # order. Make them a single symbol instead, starting from 'a'
        # for the first input item.
        edge_map = {
            edge: oe.get_symbol(i) for i, edge in enumerate(edge_list)
        }

        # Get a hashable key for this contraction.
        sizes = tuple(
            (edge_map[edge], size_dict[edge]) for edge in edge_list
        )
        input_tuples = tuple(
            tuple(sorted(edge_map[e] for e in s)) for s in inputs
        )
        output_tuple = tuple(sorted(edge_map[e] for e in output))
        key = (input_tuples, output_tuple, sizes)

        if key not in self.cached_paths:
            # Couldn't find a saved path.
            # Use the fallback and save it for next time.
            path = super().__call__(
                inputs, output, size_dict, memory_limit, edge_list
            )
            self.add_paths({key: path})
        return self.cached_paths[key]

    def add_paths(self, new_paths: dict[ContractionKey, ContractionPath]):
        """Append new paths to the current cache."""
        self.cached_paths = {
            **self.cached_paths,
            **new_paths
        }
        if self.filepath is not None:
            with open(self.filepath, 'wb') as f:
                pickle.dump(self.cached_paths, f)
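
# Note that `add_paths` can also be called directly, e.g. to merge the
# caches of two optimizers (illustrative):
#
#     opt_a.add_paths(opt_b.cached_paths)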

    def store_to_checkpoint(self, checkpoint: Checkpoint) -> Checkpoint:
        checkpoint = super().store_to_checkpoint(checkpoint)
        if self.save_checkpoints:
            checkpoint.add_many({
                self.checkpoint_cached_paths_key: self.cached_paths,
            })
        return checkpoint

    def restore_from_checkpoint(self, checkpoint: Checkpoint):
        super().restore_from_checkpoint(checkpoint)
        if self.checkpoint_cached_paths_key in checkpoint:
            self.cached_paths = checkpoint[self.checkpoint_cached_paths_key]
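
# Illustrative usage (a sketch; 'paths.pkl' is a made-up filename):
#
#     opt = CachedTnPathOptimizer(save_file=Path('paths.pkl'))
#
# The first contraction of a given network computes a path with the
# fallback algorithm and pickles the updated cache to disk; later runs
# pointing at the same file reuse the stored paths.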

def ordered_nodes_contractor(
    nodes: list[AbstractNode],
    algorithm: TnPathOptimizer,
    output_edge_order: Sequence[Edge] | None = None,
    ignore_edge_order: bool = False
) -> AbstractNode:
    """Copy of tensornetwork's base contractor that preserves the node
    order for caching purposes.

    Parameters
    ----------
    nodes: :py:class:`list[AbstractNode]`
        A collection of connected nodes.
    algorithm: :py:class:`TnPathOptimizer`
        `opt_einsum` contraction method to use.
    output_edge_order: :py:class:`Sequence[Edge] | None`
        An optional list of edges. Edges of the final node in
        `nodes_set` are reordered into `output_edge_order`; if the
        final node has more than one edge, `output_edge_order` must be
        provided.
    ignore_edge_order: :py:class:`bool`
        An option to ignore the output edge order.

    Returns
    -------
    :py:class:`AbstractNode`
        Final node after full contraction.
    """
    nodes_set = set(nodes)
    edges = get_all_edges(nodes_set)

    # The output edge order has to be determined before any
    # contraction (edges are refreshed after contractions).
    if not ignore_edge_order:
        if output_edge_order is None:
            output_edge_order = list(get_subgraph_dangling(nodes))
            if len(output_edge_order) > 1:
                raise ValueError(
                    'The final node after contraction has more than '
                    'one remaining edge. In this case `output_edge_order` '
                    'has to be provided.'
                )
        if set(output_edge_order) != get_subgraph_dangling(nodes):
            raise ValueError(
                'output edges are not equal to the remaining '
                'non-contracted edges of the final node.'
            )

    for edge in edges:
        if not edge.is_disabled:  # if disabled, we already contracted it
            if edge.is_trace():
                nodes_set.remove(edge.node1)
                nodes_set.add(contract_parallel(edge))

    if len(nodes_set) == 1:
        # There's nothing to contract.
        if ignore_edge_order:
            return list(nodes_set)[0]
        return list(nodes_set)[0].reorder_edges(output_edge_order)

    # Then apply `opt_einsum`'s algorithm.
    path, nodes = _get_path_with_ordered_nodes(
        _dedup([n for n in nodes if n in nodes_set]),
        algorithm
    )
    for a, b in path:
        new_node = contract_between(
            nodes[a], nodes[b], allow_outer_product=True
        )
        nodes.append(new_node)
        nodes = utils.multi_remove(nodes, [a, b])

    # If the final node has more than one edge,
    # output_edge_order has to be specified.
    final_node = nodes[0]  # nodes were connected; we checked this
    if not ignore_edge_order:
        final_node.reorder_edges(output_edge_order)
    return final_node


def _get_path_with_ordered_nodes(
    nodes: list[AbstractNode],
    algorithm: TnPathOptimizer
) -> tuple[Collection[tuple[int, ...]], list[AbstractNode]]:
    """Calculate the contraction path using `opt_einsum` methods.

    A copy of the tensornetwork implementation that uses a consistent
    node ordering.

    Parameters
    ----------
    nodes:
        An iterable of nodes.
    algorithm:
        `opt_einsum` method to use for calculating the contraction
        path.

    Returns
    -------
    The contraction path as returned by `opt_einsum`, together with
    the ordered list of nodes it refers to.
    """
    input_sets = [set(node.edges) for node in nodes]
    output_set = get_subgraph_dangling(nodes)
    size_dict = {edge: edge.dimension for edge in get_all_edges(nodes)}

    # Fix an edge order, so we can find this same contraction again
    # later.
    edge_list = _dedup(e for n in nodes for e in n.edges)

    return algorithm(
        input_sets, output_set, size_dict, edge_list=edge_list
    ), nodes
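
# End-to-end sketch (assumes numpy and tensornetwork are available; the
# shapes are made up):
#
#     import numpy as np
#     import tensornetwork as tn
#
#     a = tn.Node(np.random.rand(2, 3))
#     b = tn.Node(np.random.rand(3, 4))
#     a[1] ^ b[0]  # connect the shared edge of dimension 3
#
#     opt = CachedTnPathOptimizer()
#     result = ordered_nodes_contractor(
#         [a, b], opt, output_edge_order=[a[0], b[1]]
#     )
#
# Here `result.tensor` is the 2x4 matrix product; repeating the same-shaped
# contraction hits the cached path instead of rerunning the path search.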