# Copyright 2019-2023 Quantinuum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy, deepcopy
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, List, OrderedDict, Tuple, Union, cast
import networkx as nx # type: ignore
from .graphviz import _taskgraph_to_graphviz
from .mittask import (
IOTask,
MitTask,
Wire,
)
if TYPE_CHECKING:
import graphviz as gv # type: ignore
class TaskGraph:
"""
The TaskGraph class stores a networkx graph in which vertices
are pure functions (tasks) and edges hold data.
Tasks and edges carry no type restrictions, though for the run
method to be successful, the types of the ports that edges are
attached to must match.
:param _label: Name for identification of TaskGraph object.
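
A minimal sketch of building and running a one-task graph. The import path,
the MitTask constructor arguments and the task-method signature shown here
are assumptions about the wider qermit API, used purely for illustration::

    from qermit.taskgraph import MitTask, TaskGraph

    def double(obj, x):
        # assumed convention: the wrapped method receives the task object
        # followed by one argument per in wire, and returns a tuple with
        # one entry per out wire
        return (2 * x,)

    task = MitTask(_label="Double", _n_in_wires=1, _n_out_wires=1, _method=double)
    tg = TaskGraph(_label="Example")
    tg.append(task)
    tg.run([3])  # returns (6,)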
"""
def __init__(
self,
_label: str = "TaskGraph",
) -> None:
# set member variables
self._label = _label
self.G = None
self.characterisation: dict = {}
# default constructor wires a single edge from the Input vertex to the Output vertex
self._task_graph = nx.MultiDiGraph()
self._i, self._o = IOTask.Input, IOTask.Output
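# edge keys are (source out port, target in port) pairs; the run method uses
# them to order the wire data passed to and returned from each task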
self._task_graph.add_edge(self._i, self._o, key=(0, 0), data=None)
# if requested, all data is held in cache and can be accessed after running
self._cache: OrderedDict[str, Tuple[MitTask, List[Wire]]] = OrderedDict()
def from_TaskGraph(self, task_graph: "TaskGraph") -> "TaskGraph":
"""
Copies the tasks, label and characterisation of another TaskGraph object into this object and returns it.
:param task_graph: TaskGraph object to copy tasks from.
:return: Copied TaskGraph
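
For example, to make a copy of an existing graph (existing_graph stands for
any previously constructed TaskGraph)::

    graph_copy = TaskGraph().from_TaskGraph(existing_graph)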
"""
self._task_graph = deepcopy(task_graph._task_graph)
self._label = task_graph._label
self.characterisation = task_graph.characterisation
return self
@property
def tasks(self) -> List[MitTask]:
"""
Returns a list of all tasks with both input and output ports
in the TaskGraph.
"""
return list(self._task_graph)[2:]
def __call__(self, input_wires: List[Wire]) -> Tuple[List[Wire]]:
return self.run(input_wires)
@property
def label(self) -> str:
return self._label
def get_characterisation(self) -> dict:
return self.characterisation
def update_characterisation(self, characterisation: dict):
self.characterisation.update(characterisation)
def set_characterisation(self, characterisation: dict):
self.characterisation = characterisation
@property
def n_in_wires(self) -> int:
"""
The number of in wires to a TaskGraph object is defined as the number
of out edges from the Input vertex: when called, a TaskGraph object runs
its run method, which stores the input arguments as data on the Input
vertex's out edges.
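
A freshly constructed TaskGraph holds a single Input to Output wire, so::

    TaskGraph().n_in_wires  # 1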
"""
return len(self._task_graph.out_edges(self._i))
@property
def n_out_wires(self) -> int:
"""
The number of out wires from a TaskGraph object is defined as the number
of in edges to the Output vertex: when called, a TaskGraph object runs
its run method which, after running all tasks, returns the data on the
Output vertex's in edges as a tuple.
"""
return len(self._task_graph.in_edges(self._o))
def check_prepend_wires(self, task: Union[MitTask, "TaskGraph"]) -> bool:
"""
Confirms that the number of out wires of the proposed task to prepend to the
internal task_graph attribute matches the number of in wires to the graph.
:param task: Wrapped pure function to prepend to graph
:return: True if prepend permitted
"""
return task.n_out_wires == self.n_in_wires
def check_append_wires(self, task: Union[MitTask, "TaskGraph"]) -> bool:
"""
Confirms that the number of in wires of the proposed task to append to the
internal task_graph attribute matches the number of out wires to the graph.
:param task: Wrapped pure function to append to graph
:return: True if append permitted
"""
return task.n_in_wires == self.n_out_wires
def __str__(self):
return f"<TaskGraph::{self._label}>"
def __repr__(self):
return str(self)
def add_n_wires(self, num_wires: int):
"""
Adds num_wires number of edges between the input vertex
and output vertex, with no type restrictions.
:param num_wires: Number of edges to add between input and output vertices.
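
For example, to build a graph that passes three wires straight through
(the constructor already provides one wire)::

    tg = TaskGraph()
    tg.add_n_wires(2)
    assert tg.n_in_wires == 3 and tg.n_out_wires == 3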
"""
for _ in range(num_wires):
in_port = len(self._task_graph.out_edges(self._i, data=True))
out_port = len(self._task_graph.in_edges(self._o, data=True))
self._task_graph.add_edge(
self._i, self._o, key=(in_port, out_port), data=None
)
def add_wire(self):
"""
Adds a single edge between the input vertex and output vertex.
"""
self.add_n_wires(1)
# Add new task to start of TaskGraph
def prepend(self, task: Union[MitTask, "TaskGraph"]):
"""
Inserts a new task at the start of TaskGraph._task_graph.
All out edges from the Input vertex are rewired as out edges from the task, in the same port ordering (types must match).
New edges are then added from the Input vertex to the task (any type permitted), with ports ordered by argument order.
:param task: New task to be prepended.
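
A sketch of prepending a preprocessing step; as in the class-level example,
the MitTask constructor arguments and task-method signature are assumed for
illustration::

    def negate(obj, x):
        return (-x,)

    tg = TaskGraph()
    tg.prepend(MitTask(_label="Negate", _n_in_wires=1, _n_out_wires=1, _method=negate))
    tg.run([5])  # returns (-5,)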
"""
assert self.check_prepend_wires(task)
# A single generated MitTask object could be prepended to several different TaskGraph
# objects, in which case the same task address would expect input wires from different
# graphs. Copying the task here prevents this; task graph construction is not the
# bottleneck when running mitigation schemes, so the extra copy is acceptable.
task_copy = copy(task)
for i, edge in enumerate(list(self._task_graph.out_edges(self._i, keys=True))):
self._task_graph.add_edge(
task_copy, edge[1], key=(i, edge[2][1]), data=None
)
self._task_graph.remove_edge(edge[0], edge[1])
for port in range(task_copy.n_in_wires):
self._task_graph.add_edge(self._i, task_copy, key=(port, port), data=None)
def append(self, task: Union[MitTask, "TaskGraph"]):
"""
Inserts a new task at the end of TaskGraph._task_graph.
All in edges to the Output vertex are rewired as in edges to the task, in the same port ordering (types must match).
New edges are then added from the task to the Output vertex (any type permitted), with ports ordered by argument order.
:param task: New task to be appended.
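
A sketch of composing two steps by appending them in sequence; the MitTask
constructor arguments and task-method signature are assumed for illustration::

    def add_one(obj, x):
        return (x + 1,)

    tg = TaskGraph()
    tg.append(MitTask(_label="AddOne0", _n_in_wires=1, _n_out_wires=1, _method=add_one))
    tg.append(MitTask(_label="AddOne1", _n_in_wires=1, _n_out_wires=1, _method=add_one))
    tg.run([0])  # returns (2,)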
"""
assert self.check_append_wires(task)
# A single generated MitTask object could be appended to several different TaskGraph
# objects, in which case the same task address would expect input wires from different
# graphs. Copying the task here prevents this; task graph construction is not the
# bottleneck when running mitigation schemes, so the extra copy is acceptable.
task_copy = copy(task)
for edge in list(self._task_graph.in_edges(self._o, keys=True)):
self._task_graph.add_edge(edge[0], task_copy, key=edge[2], data=None)
self._task_graph.remove_edge(edge[0], edge[1])
for port in range(task_copy.n_out_wires):
self._task_graph.add_edge(task_copy, self._o, key=(port, port), data=None)
def decompose_TaskGraph_nodes(self):
"""
For each node in self._task_graph: if the node is itself a TaskGraph object, substitutes
that node with the _task_graph structure held inside it.
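
For example, a TaskGraph appended as a node can later be flattened into its
parent; the MitTask constructor arguments and task-method signature are
assumed for illustration::

    def identity(obj, x):
        return (x,)

    inner = TaskGraph(_label="Inner")
    inner.append(MitTask(_label="Id", _n_in_wires=1, _n_out_wires=1, _method=identity))
    outer = TaskGraph(_label="Outer")
    outer.append(inner)                # inner appears as a single node
    outer.decompose_TaskGraph_nodes()  # outer now holds the inner task directly (relabelled "InnerId")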
"""
check_for_decompose = True
while check_for_decompose:
# get all nodes and iterate through them
check_for_decompose = False
node_list = list(nx.topological_sort(self._task_graph))
for task in node_list:
# => TaskGraph object with _task_graph attribute for decomposition
if hasattr(task, "_task_graph"):
# relabel task names for ease of viewing with visualisation methods
for sub_task in list(task._task_graph.nodes):
# in practice only IOTask members lack a _label attribute
if hasattr(sub_task, "_label"):
sub_task._label = task._label + sub_task._label
task_in_edges = list(self._task_graph.in_edges(task, keys=True))
task_out_edges = list(self._task_graph.out_edges(task, keys=True))
task_input_out_edges = list(
task._task_graph.out_edges(task._i, keys=True)
)
task_output_in_edges = list(
task._task_graph.in_edges(task._o, keys=True)
)
if (
len(
set(task_input_out_edges).intersection(
set(task_output_in_edges)
)
)
> 0
):
raise ValueError(
"Decomposition of TaskGraph node {}, not permitted: TaskGraph to be decomposed has edge between input and output vertices.".format(
task
)
)
# These two cases imply a faulty TaskGraph construction.
# The check is made here only because it is a necessary constraint for decomposing
# TaskGraph nodes; faulty construction (including type mismatches) should already be
# caught when the TaskGraph object is constructed.
if len(task_in_edges) != len(task_input_out_edges):
raise TypeError(
"Decomposition of TaskGraph node {} not permitted: node "
"expects {} input wires but receives {}.".format(
task, len(task_input_out_edges), len(task_in_edges)
)
)
if len(task_out_edges) != len(task_output_in_edges):
raise TypeError(
"Decomposition of TaskGraph node {} not permitted: task_graph "
"expects {} output wires but node returns {}.".format(
task, len(task_output_in_edges), len(task_out_edges)
)
)
# remove the Input/Output edges inside task._task_graph and the edges attached to
# task in self._task_graph, then rewire: in edges to task in self._task_graph become
# in edges to the first tasks inside task
for outside_edge, inside_edge in zip(
task_in_edges, task_input_out_edges
):
task._task_graph.remove_edge(inside_edge[0], inside_edge[1])
self._task_graph.remove_edge(outside_edge[0], outside_edge[1])
self._task_graph.add_edge(
outside_edge[0],
inside_edge[1],
key=(outside_edge[2][0], inside_edge[2][1]),
data=None,
)
for outside_edge, inside_edge in zip(
task_out_edges, task_output_in_edges
):
task._task_graph.remove_edge(inside_edge[0], inside_edge[1])
self._task_graph.remove_edge(outside_edge[0], outside_edge[1])
self._task_graph.add_edge(
inside_edge[0],
outside_edge[1],
key=(inside_edge[2][0], outside_edge[2][1]),
data=None,
)
# add all remaining edges, filling in the rest of the substituted graph
self._task_graph.add_edges_from(task._task_graph.edges)
self._task_graph.remove_node(task)
check_for_decompose = True
break
def parallel(self, task: Union[MitTask, "TaskGraph"]):
"""
Adds a new MitTask/TaskGraph object to the TaskGraph in parallel: all of the task's in edges are wired as out edges from the Input vertex, and all of its out edges are wired as in edges to the Output vertex.
:param task: New task to be added in parallel.
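
A sketch of running an extra task alongside the existing pass-through wire;
the MitTask constructor arguments and task-method signature are assumed for
illustration::

    def double(obj, x):
        return (2 * x,)

    tg = TaskGraph()  # one pass-through wire
    tg.parallel(MitTask(_label="Double", _n_in_wires=1, _n_out_wires=1, _method=double))
    tg.run([1, 3])  # returns (1, 6): wire 0 passes through, wire 1 is doubled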
"""
task = copy(task)
base_n_input_outs = len(self._task_graph.out_edges(self._i))
for port in range(task.n_in_wires):
self._task_graph.add_edge(
self._i,
task,
key=(base_n_input_outs + port, port),
data=None,
)
base_n_output_ins = len(self._task_graph.in_edges(self._o))
for port in range(task.n_out_wires):
self._task_graph.add_edge(
task,
self._o,
key=(port, base_n_output_ins + port),
data=None,
)
def run(
self, input_wires: List[Wire], cache: bool = False, characterisation: dict = {}
) -> Tuple[List[Wire]]:
"""
Each task in TaskGraph is a pure function that produces output data
from input data to some specification. Data is stored on edges of the
internal _task_graph object.
The run method first assigns each wire in the list to an out edge
of the Input vertex. If more wires are passed than there are edges,
the surplus wires are ignored. If fewer wires are passed than there are edges,
some edges will have no data and later computation will likely fail.
Nodes (holding either MitTask or TaskGraph callable objects) are ordered by a
topological sort of the graph and then executed sequentially.
To execute a task, the data on its in edges are passed as arguments to the task's
_method, and the data returned by the method are assigned to the task's out edges.
This process is repeated until all tasks have run, at which point all in edges
to the Output vertex hold data, with each data set returned in a tuple.
:param input_wires: Each Wire holds information assigned as data to an output edge
from the input vertex of the _task_graph.
:param cache: If True, each task's output data is stored in an OrderedDict with the
task's _label attribute as its key.
:return: Data from input edges to output vertex, assigned as wires.
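
A sketch of a two-task run with caching enabled; the MitTask constructor
arguments and task-method signature are assumed for illustration::

    def add_one(obj, x):
        return (x + 1,)

    tg = TaskGraph()
    tg.append(MitTask(_label="A", _n_in_wires=1, _n_out_wires=1, _method=add_one))
    tg.append(MitTask(_label="B", _n_in_wires=1, _n_out_wires=1, _method=add_one))
    result = tg.run([0], cache=True)  # returns (2,)
    cache = tg.get_cache()            # OrderedDict with keys "A" and "B"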
"""
for edge, wire in zip(
self._task_graph.out_edges(self._i, data=True), input_wires
):
edge[2]["data"] = wire
# topological_sort resolves dependencies, so we can iterate and assume all
# input wires are realised before a task is reached
node_list = list(nx.topological_sort(self._task_graph))
self.characterisation.update(characterisation)
# clear cache of held data if required
# also check that all MitTask labels are unique, else entries in the cache dict would be overwritten
if cache:
unique_labels = set()
for task in node_list:
if task not in (self._i, self._o):
unique_labels.add(task._label)
if len(unique_labels) != len(self._task_graph) - 2:
raise ValueError(
"Cache can't store all information as not all MitTask labels are unique."
)
else:
self._cache.clear()
for task in node_list:
# nothing to process
if task in (self._i, self._o):
continue
task.characterisation.update(self.characterisation)
# get all input data and store on inputs for task
in_edges = self._task_graph.in_edges(task, data=True, keys=True)
inputs = [None] * len(in_edges)
for _, _, ports, i_data in in_edges:
assert i_data["data"] is not None
inputs[ports[1]] = i_data["data"]
# run held task
outputs = task(inputs)
self.characterisation.update(task.characterisation)
if cache:
self._cache[task._label] = (task, outputs)
# assign outputs to out_edges of task
out_edges = self._task_graph.out_edges(task, data=True, keys=True)
assert len(out_edges) == len(outputs)
for _, _, ports, o_data in out_edges:
o_data["data"] = outputs[ports[0]]
output_wire = [
edge[2]["data"]
for edge in list(self._task_graph.in_edges(self._o, data=True))
]
return cast(Tuple[List[Wire]], tuple(output_wire))
def get_cache(self) -> OrderedDict[str, Tuple[MitTask, List[Wire]]]:
"""
:returns: Dictionary holding the output data from every MitTask.
This is only full after run is called with the cache argument set
to True. Keys are stored in graph topological order.
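
For example, after a run with the cache argument set to True (tg here is any
TaskGraph that has already been run in that way)::

    cache = tg.get_cache()
    labels = list(cache.keys())  # task labels in topological order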
"""
return self._cache
def get_task_graph(self) -> "gv.Digraph":
"""
Return a visual representation of the DAG as a graphviz object.
:returns: Representation of the DAG
"""
return _taskgraph_to_graphviz(self._task_graph, None, self._label)
def view_task_graph(self) -> None:
"""
View the DAG.
"""
G = self.get_task_graph()
file = NamedTemporaryFile()
G.view(file.name, quiet=True)