Source code for hugr.std.int
"""HUGR integer types and operations."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar
from typing_extensions import Self
from hugr import ext, tys, val
from hugr.ops import AsExtOp, DataflowOp, ExtOp, RegisteredOp
from hugr.std import _load_extension
if TYPE_CHECKING:
from hugr.ops import Command, ComWire
# Extension loaded by name via `_load_extension`. It is not referenced again
# in this part of the module — presumably exported for other users; verify.
CONVERSIONS_EXTENSION = _load_extension("arithmetic.conversions")
# Extension that provides the "int" type definition used below.
INT_TYPES_EXTENSION = _load_extension("arithmetic.int.types")
# Bounded-nat type parameter for an integer's log bit width; `_int_tv` uses it
# to declare width type variables. NOTE(review): assumes the bound 7 is
# exclusive (log widths 0..6, i.e. 1- to 64-bit integers) — confirm.
_INT_PARAM = tys.BoundedNatParam(7)
# The "int" type definition; `int_t` and `_int_tv` instantiate it with a
# single log-width argument.
INT_T_DEF = INT_TYPES_EXTENSION.types["int"]
def int_t(width: int) -> tys.ExtType:
    """Build the HUGR integer type whose bit width is ``2**width``.

    Args:
        width: The log bit width of the integer.

    Returns:
        The integer type, an instantiation of the extension's "int" type.

    Examples:
        >>> int_t(5).type_def.name # 32 bit integer
        'int'
    """
    # The "int" type definition takes exactly one argument: the log width.
    width_arg = tys.BoundedNatArg(n=width)
    return INT_T_DEF.instantiate([width_arg])
def _int_tv(index: int) -> tys.ExtType:
    """Build an integer type whose log width is the type variable ``index``.

    The variable is declared with the module's ``_INT_PARAM`` bounded-nat
    parameter.
    """
    var_arg = tys.VariableArg(idx=index, param=_INT_PARAM)
    return INT_T_DEF.instantiate([var_arg])
#: HUGR 32-bit integer type (log width 5, i.e. 2**5 bits).
INT_T = int_t(width=5)
def _to_unsigned(value: int, bits: int) -> int:
    """Return the unsigned two's-complement encoding of ``value`` in ``bits`` bits.

    Non-negative values are returned unchanged; a negative value ``v`` maps
    to ``2**bits + v``. No range check is performed.
    """
    if value < 0:
        return (1 << bits) + value
    return value


@dataclass
class IntVal(val.ExtensionValue):
    """Custom value for an integer.

    Attributes:
        v: The integer value. Negative values are serialized in two's
            complement over the full bit width.
        width: The log bit width of the integer (default 5, i.e. 32 bits).
    """

    v: int
    width: int = field(default=5)

    def to_value(self) -> val.Extension:
        """Serialize this value as a ``ConstInt`` extension value.

        Returns:
            A ``val.Extension`` typed ``int_t(self.width)`` whose payload
            holds the log width and the (unsigned-encoded) value.
        """
        name = "ConstInt"
        # The ConstInt payload's "value" is unsigned (u64 in the hugr Rust
        # ConstInt), so a negative Python int would not deserialize.
        # Encode it as two's complement over 2**width bits instead.
        unsigned = _to_unsigned(self.v, 1 << self.width)
        payload = {"log_width": self.width, "value": unsigned}
        return val.Extension(
            name,
            typ=int_t(self.width),
            val=payload,
            extensions=[INT_TYPES_EXTENSION.name],
        )

    def __str__(self) -> str:
        # Show the signed Python value, not the serialized encoding.
        return f"{self.v}"
# Extension defining the integer arithmetic operations; `_DivModDef` takes its
# "idivmod_u" operation definition from here.
INT_OPS_EXTENSION = _load_extension("arithmetic.int")
@dataclass(frozen=True)
class _DivModDef(RegisteredOp):
"""DivMod operation, has two inputs and two outputs."""
width: int = 5
const_op_def: ClassVar[ext.OpDef] = INT_OPS_EXTENSION.operations["idivmod_u"]
def type_args(self) -> list[tys.TypeArg]:
return [tys.BoundedNatArg(n=self.width)]
def cached_signature(self) -> tys.FunctionType | None:
row: list[tys.Type] = [int_t(self.width)] * 2
return tys.FunctionType.endo(row)
@classmethod
def from_ext(cls, custom: ExtOp) -> Self | None:
if custom.op_def() != cls.op_def():
return None
match custom.args:
case [tys.BoundedNatArg(n=a1)]:
return cls(width=a1)
case _:
msg = f"Invalid args: {custom.args}"
raise AsExtOp.InvalidExtOp(msg)
def __call__(self, a: ComWire, b: ComWire) -> Command:
return DataflowOp.__call__(self, a, b)
#: DivMod operation on 32-bit integers (log width 5).
DivMod = _DivModDef(width=5)