# Source code for tierkreis.pyruntime.python_runtime

"""Implementation of simple python-only runtime."""

import asyncio
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Tuple, cast

import networkx as nx
import requests

from tierkreis.client.runtime_client import RuntimeClient
from tierkreis.core import Labels
from tierkreis.core.function import FunctionName
from tierkreis.core.protos.tierkreis.v1alpha1.graph import Output, OutputStream
from tierkreis.core.signature import Signature
from tierkreis.core.tierkreis_graph import (
    BoxNode,
    ConstNode,
    FunctionNode,
    GraphValue,
    IncomingWireType,
    InputNode,
    MatchNode,
    OutputNode,
    TagNode,
    TierkreisEdge,
    TierkreisGraph,
)
from tierkreis.core.type_errors import TierkreisTypeErrors
from tierkreis.core.type_inference import _TYPE_CHECK, infer_graph_types
from tierkreis.core.utils import map_vals
from tierkreis.core.values import StructValue, TierkreisValue, VariantValue, VecValue
from tierkreis.pyruntime import python_builtin

if TYPE_CHECKING:
    from tierkreis.core.tierkreis_graph import _EdgeData
    from tierkreis.worker.namespace import Namespace


class _ValueNotFound(Exception):
    """Base error raised when an edge of the graph carries no value.

    The offending edge is kept on the ``edge`` attribute so handlers can
    inspect it.
    """

    def __init__(self, edge: TierkreisEdge) -> None:
        self.edge = edge
        message = "Value not found on edge {} -> {}".format(edge.source, edge.target)
        super().__init__(message)


class OutputNotFound(_ValueNotFound):
    """Raised when a node did not produce an output that an outgoing edge
    expects."""
class InputNotFound(_ValueNotFound):
    """Raised when a node's incoming edge has no value although one is
    expected."""
class FunctionNotFound(Exception):
    """Raised when a requested function is missing from the namespace.

    The name that was looked up is kept on the ``function`` attribute.
    """

    def __init__(self, fname: FunctionName) -> None:
        self.function = fname
        message = f"Function {fname} not found in namespace."
        super().__init__(message)
class PyRuntime(RuntimeClient):
    """A simplified python-only Tierkreis runtime.

    Can be used with builtin operations and python only namespaces that are
    locally available."""

    def __init__(self, roots: Iterable["Namespace"], num_workers: int = 1):
        """Initialise with locally available namespaces, and the number of
        workers (asyncio tasks) to use in execution."""
        # Start from a private copy of the builtin namespace so that merging
        # the user-provided roots does not mutate the shared module-level one.
        self.root = deepcopy(python_builtin.namespace)
        for root in roots:
            self.root.merge_namespace(root)
        self.num_workers = num_workers
        self._callback: Optional[Callable[[TierkreisEdge, TierkreisValue], None]] = None
        self.set_callback(None)

    def set_callback(
        self, callback: Optional[Callable[[TierkreisEdge, TierkreisValue], None]]
    ):
        """Set a callback function that takes a TierkreisEdge and
        TierkreisValue, which will be called every time an edge receives an
        output. Can be used to inspect intermediate values.

        Pass ``None`` to remove a previously set callback."""
        self._callback = callback

    def callback(
        self,
        edge: TierkreisEdge,
        val: TierkreisValue,
    ):
        """If a callback function is set, call it with an edge and the value
        on the edge."""
        if self._callback:
            self._callback(edge, val)

    async def run_graph(
        self,
        run_g: TierkreisGraph,
        /,
        **py_inputs: Any,
    ) -> dict[str, TierkreisValue]:
        """Run a tierkreis graph using the python runtime, and provided inputs.

        Nodes are queued in topological order and executed by
        ``self.num_workers`` asyncio tasks. Values produced by a node are
        stored per outgoing edge until the target node consumes them.

        Returns the outputs of the graph, keyed by output port.

        Raises ``InputNotFound`` / ``OutputNotFound`` when an edge has no
        value where one is expected, ``FunctionNotFound`` when a function
        node names a function missing from the namespace, and
        ``RuntimeError`` for an unrecognised node type. An exception raised
        while running any node is propagated to the caller.
        """
        total_nodes = run_g._graph.number_of_nodes()
        # Values currently "in flight": written when a node finishes, popped
        # when the node at the other end of the edge starts.
        runtime_state: dict["_EdgeData", TierkreisValue] = {}

        async def run_node(node: int) -> dict[str, TierkreisValue]:
            """Execute a single node and return its outputs keyed by port."""
            tk_node = run_g[node]

            if isinstance(tk_node, OutputNode):
                return {}
            if isinstance(tk_node, InputNode):
                return map_vals(py_inputs, TierkreisValue.from_python)
            if isinstance(tk_node, ConstNode):
                return {Labels.VALUE: tk_node.value}

            in_edges = list(run_g.in_edges(node))
            while not all(e.to_edge_handle() in runtime_state for e in in_edges):
                # wait for inputs to become available
                # only useful if there are other workers that can do things
                # while this one waits
                assert self.num_workers > 1
                await asyncio.sleep(0)

            # Consume the inputs eagerly, edge by edge, so that a missing
            # value is reported as InputNotFound for that specific edge.
            inps: dict[str, TierkreisValue] = {}
            for e in in_edges:
                try:
                    inps[e.target.port] = runtime_state.pop(e.to_edge_handle())
                except KeyError as key_e:
                    raise InputNotFound(e) from key_e

            if isinstance(tk_node, FunctionNode):
                fname = tk_node.function_name
                # eval / loop / map in the root namespace are handled by the
                # runtime itself rather than looked up in ``self.root``.
                if fname.namespaces == [] and fname.name == "eval":
                    return await self._run_eval(inps)
                elif fname.namespaces == [] and fname.name == "loop":
                    return await self._run_loop(inps)
                elif fname.namespaces == [] and fname.name == "map":
                    return await self._run_map(inps)
                else:
                    function = self.root.get_function(fname)
                    if function is None:
                        raise FunctionNotFound(fname)
                    # For now the PyRuntime does not provide a stack trace
                    return (await function.run(self, dict(), StructValue(inps))).values

            elif isinstance(tk_node, BoxNode):
                return await self.run_graph(
                    tk_node.graph,
                    **inps,
                )

            elif isinstance(tk_node, MatchNode):
                return self._run_match(inps)

            elif isinstance(tk_node, TagNode):
                return {
                    Labels.VALUE: VariantValue(tk_node.tag_name, inps[Labels.VALUE])
                }

            else:
                raise RuntimeError("Unknown node type.")

        async def worker(queue: asyncio.Queue[int]):
            """Repeatedly take the next node off the queue and run it."""
            # each worker gets the next node in the queue
            while True:
                node = await queue.get()
                # If the node is not yet runnable,
                # wait/block until it is, do not try to run any other node
                outs = await run_node(node)

                # assign outputs to edges
                for out_edge in run_g.out_edges(node):
                    try:
                        val = outs.pop(out_edge.source.port)
                    except KeyError as key_e:
                        raise OutputNotFound(out_edge) from key_e
                    tkval = TierkreisValue.from_python(val)
                    self.callback(out_edge, tkval)
                    runtime_state[out_edge.to_edge_handle()] = tkval
                # signal this node is now done
                queue.task_done()

        que: asyncio.Queue[int] = asyncio.Queue(total_nodes)
        for node in nx.topological_sort(run_g._graph):
            # add all node names to the queue in topsort order
            # if there are fewer workers than nodes, and the queue is populated
            # in a non-topsort order, some worker may just wait forever for its
            # node's inputs to become available.
            que.put_nowait(node)

        workers = [asyncio.create_task(worker(que)) for _ in range(self.num_workers)]
        queue_complete = asyncio.create_task(que.join())

        try:
            # wait for either all nodes to complete, or for a worker to return
            await asyncio.wait(
                [queue_complete, *workers], return_when=asyncio.FIRST_COMPLETED
            )
            if not queue_complete.done():
                # If the queue hasn't completed, it means one of the workers has
                # raised - find it and propagate the exception.
                # even if the rest of the graph has not completed
                for t in workers:
                    if t.done():
                        t.result()  # this will raise
        finally:
            # Always tear down the helper tasks, including on the error path:
            # otherwise workers still waiting for inputs of a failed node
            # would keep spinning in the event loop after we have returned.
            queue_complete.cancel()
            for task in workers:
                task.cancel()
            # Wait until all tasks are actually cancelled / finished.
            await asyncio.gather(queue_complete, *workers, return_exceptions=True)

        return {
            e.target.port: runtime_state.pop(e.to_edge_handle())
            for e in run_g.in_edges(run_g.output_node_idx)
        }

    async def _run_eval(
        self, ins: dict[str, TierkreisValue]
    ) -> dict[str, TierkreisValue]:
        """Run the graph given on the ``Labels.THUNK`` input, passing the
        remaining inputs to it."""
        thunk = cast(GraphValue, ins.pop(Labels.THUNK)).value
        return await self.run_graph(thunk, **ins)

    async def _run_loop(
        self, ins: dict[str, TierkreisValue]
    ) -> dict[str, TierkreisValue]:
        """Run the ``body`` graph repeatedly until it returns a variant tagged
        ``Labels.BREAK``; the variant's value is fed back in as ``value`` on
        every other iteration and returned as ``value`` at the end."""
        body = cast(GraphValue, ins.pop("body")).value
        while True:
            outs = await self.run_graph(
                body,
                **ins,
            )
            out = cast(
                VariantValue,
                outs[Labels.VALUE],
            )
            nxt = {"value": out.value}
            if out.tag == Labels.BREAK:
                return nxt
            else:
                ins = nxt

    async def _run_map(
        self, ins: dict[str, TierkreisValue]
    ) -> dict[str, TierkreisValue]:
        """Run the ``thunk`` graph once per element of the ``value`` vector,
        as concurrent asyncio tasks, and collect the results in a vector."""
        body = cast(GraphValue, ins.pop("thunk")).value
        inputs = cast(VecValue, ins.pop("value")).values

        async def task(x):
            return (await self.run_graph(body, value=x))[Labels.VALUE]

        tasks = [asyncio.create_task(task(x)) for x in inputs]
        out = await asyncio.gather(*tasks)
        ret = {"value": cast(TierkreisValue, VecValue(out))}
        return ret

    def _run_match(self, ins: dict[str, TierkreisValue]) -> dict[str, TierkreisValue]:
        """Build the graph to run for a match node.

        The input thunk whose port name equals the variant's tag is boxed in
        a new graph, with the variant's value wired in as a constant on its
        ``Labels.VALUE`` input. The new graph is returned (not run) on the
        ``Labels.THUNK`` output."""
        variant = cast(VariantValue, ins[Labels.VARIANT_VALUE])
        thunk = cast(GraphValue, ins[variant.tag]).value
        newg = TierkreisGraph()
        boxinps: dict[str, IncomingWireType] = {
            inp: newg.input[inp] for inp in thunk.inputs()
        }
        boxinps[Labels.VALUE] = newg.add_const(variant.value)
        box = newg.add_box(thunk, **boxinps)
        newg.set_outputs(**{out: box[out] for out in thunk.outputs()})
        return {Labels.THUNK: GraphValue(newg)}

    async def get_signature(self) -> Signature:
        """Return the signature extracted from the runtime's root namespace."""
        return self.root.extract_signature(True)

    async def type_check_graph(self, graph: TierkreisGraph) -> TierkreisGraph:
        """Type check ``graph`` against the runtime signature and return the
        result of type inference."""
        return infer_graph_types(graph, await self.get_signature())

    async def type_check_graph_with_inputs(
        self, tg: TierkreisGraph, inputs: StructValue
    ) -> Tuple[TierkreisGraph, StructValue]:
        """Type check ``tg`` together with ``inputs`` against the runtime
        signature and return the result of type inference for both."""
        return infer_graph_types(tg, await self.get_signature(), inputs)

    @property
    def can_type_check(self) -> bool:
        """Whether type checking is available (the ``_TYPE_CHECK`` flag of
        ``tierkreis.core.type_inference``)."""
        return _TYPE_CHECK
class VizRuntime(PyRuntime):
    """Child class of ``PyRuntime`` that can interact with a tierkreis-viz
    instance for live graph visualization."""

    def __init__(
        self,
        url: str,
        roots: Iterable["Namespace"],
        num_workers: int = 1,
        *,
        timeout: Optional[float] = 10.0,
    ):
        """`url` is the address of the running tierkreis-viz instance.

        `timeout` is the number of seconds to wait for the tierkreis-viz
        instance on each request; ``None`` waits indefinitely.

        See `PyRuntime` for remaining parameters.
        """
        self.url = url
        self.timeout = timeout
        self.outputs = OutputStream()
        super().__init__(roots, num_workers)

    def _post(self, endpoint: str, data):
        """POST ``data`` to ``endpoint`` of the viz instance as protobuf bytes.

        Objects providing ``to_proto`` are converted with it first; anything
        else is passed to ``bytes`` as is.
        """
        proto_dat = data.to_proto() if hasattr(data, "to_proto") else data
        requests.post(
            self.url + endpoint,
            data=bytes(proto_dat),
            headers={"content-type": "application/protobuf"},
            # Without a timeout an unresponsive server would block this call,
            # and the event loop it is made from, forever.
            timeout=self.timeout,
        )

    async def type_check_graph(self, graph: TierkreisGraph) -> TierkreisGraph:
        """See ``PyRuntime.type_check_graph``.

        Additionally updates visualized graph with type annotations. Type
        errors are sent to the visualizer before being re-raised."""
        try:
            typedg = await super().type_check_graph(graph)
        except TierkreisTypeErrors as e:
            self._post("/api/typeErrors", e)
            raise
        self.viz_graph(typedg)
        return typedg

    def viz_graph(self, tg: TierkreisGraph):
        """Send graph to be visualized.

        Also posts an empty output stream and an empty list of type errors."""
        self._post("/api/graph", tg)
        self._post("/api/streamList", OutputStream())
        self._post("/api/typeErrors", TierkreisTypeErrors([]))

    async def run_viz_graph(
        self,
        run_g: TierkreisGraph,
        /,
        **py_inputs: Any,
    ) -> dict[str, TierkreisValue]:
        """See ``PyRuntime.run_graph``.

        Additionally updates the visualization with the outputs of each node
        when they are available."""
        # Start each run with a fresh stream of outputs.
        self.outputs = OutputStream()
        return await self.run_graph(run_g, **py_inputs)

    def callback(
        self,
        edge: TierkreisEdge,
        val: TierkreisValue,
    ):
        """Call the user callback (if any), then record the new edge value
        and post the accumulated output stream to the visualizer."""
        super().callback(edge, val)
        self.outputs.stream.append(Output(edge=edge.to_proto(), value=val.to_proto()))
        self._post("/api/streamList", self.outputs)