Source code for hugr.std.int

"""HUGR integer types and operations."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar

from typing_extensions import Self

import hugr.model as model
from hugr import ext, tys, val
from hugr.ops import AsExtOp, DataflowOp, ExtOp, RegisteredOp
from hugr.std import _load_extension

if TYPE_CHECKING:
    from hugr.ops import Command, ComWire

# Extension object returned by `_load_extension` for "arithmetic.conversions".
CONVERSIONS_EXTENSION = _load_extension("arithmetic.conversions")

# Extension object returned by `_load_extension` for "arithmetic.int.types";
# the `int` type definition below is looked up in it.
INT_TYPES_EXTENSION = _load_extension("arithmetic.int.types")
# Type parameter attached to width type variables (see `_int_tv`).
# NOTE(review): the bound 7 presumably restricts log widths to 0..6
# (i.e. up to 64 bits) -- confirm against `tys.BoundedNatParam` semantics.
_INT_PARAM = tys.BoundedNatParam(7)

# The "int" type definition, instantiated by `int_t` and `_int_tv`.
INT_T_DEF = INT_TYPES_EXTENSION.types["int"]


def int_t(width: int) -> tys.ExtType:
    """Create an integer type with a fixed log bit width.

    Args:
        width: The log bit width of the integer (e.g. ``5`` for 32 bits).

    Returns:
        The ``int`` type definition instantiated with ``width`` as its
        single bounded-natural argument.

    Examples:
        >>> int_t(5).type_def.name  # 32 bit integer
        'int'
    """
    width_arg = tys.BoundedNatArg(n=width)
    return INT_T_DEF.instantiate([width_arg])
def _int_tv(index: int) -> tys.ExtType:
    """Return the integer type whose width is the type variable at ``index``.

    The variable argument is declared with the module-level ``_INT_PARAM``.
    """
    width_var = tys.VariableArg(idx=index, param=_INT_PARAM)
    return INT_T_DEF.instantiate([width_var])


#: HUGR 32-bit integer type.
INT_T = int_t(5)


def _to_unsigned(val: int, bits: int) -> int:
    """Convert a signed integer to its unsigned twos-complement representation.

    Non-negative integers are returned unchanged; negative integers are
    mapped by adding ``2**bits``.

    Args:
        val: The signed value to convert.
        bits: The bit width of the target representation.

    Returns:
        The unsigned representation of ``val``.

    Raises:
        ValueError: If ``val`` lies outside
            ``[-2**(bits-1), 2**(bits-1) - 1]``.
    """
    half_range = 1 << (bits - 1)
    lowest, highest = -half_range, half_range - 1
    if val < lowest or val > highest:
        msg = f"Value {val} out of range for {bits}-bit signed integer."
        raise ValueError(msg)

    # Negative values wrap around by one full period of 2**bits.
    return (1 << bits) + val if val < 0 else val
@dataclass
class IntVal(val.ExtensionValue):
    """Custom value for a signed integer."""

    # The signed integer value.
    v: int
    # Log bit width of the integer; the bit count used for encoding is
    # ``1 << width`` (default 5, i.e. 32 bits).
    width: int = field(default=5)

    def to_value(self) -> val.Extension:
        """Build the ``ConstInt`` extension value for this integer.

        The payload stores the log width and the unsigned twos-complement
        encoding of ``v``.

        Raises:
            ValueError: If ``v`` does not fit in ``1 << width`` signed bits.
        """
        encoded = _to_unsigned(self.v, 1 << self.width)
        return val.Extension(
            "ConstInt",
            typ=int_t(self.width),
            val={"log_width": self.width, "value": encoded},
        )

    def __str__(self) -> str:
        """Return the signed value formatted as a string."""
        return f"{self.v}"

    def to_model(self) -> model.Term:
        """Build the ``arithmetic.int.const`` model term for this integer.

        The term is applied to two literals: the log width and the unsigned
        twos-complement encoding of ``v``.

        Raises:
            ValueError: If ``v`` does not fit in ``1 << width`` signed bits.
        """
        # Encode first so an out-of-range value fails before any term is built.
        encoded = _to_unsigned(self.v, 1 << self.width)
        term_args = [model.Literal(self.width), model.Literal(encoded)]
        return model.Apply("arithmetic.int.const", term_args)
INT_OPS_EXTENSION = _load_extension("arithmetic.int") @dataclass(frozen=True) class _DivModDef(RegisteredOp): """DivMod operation, has two inputs and two outputs.""" width: int = 5 const_op_def: ClassVar[ext.OpDef] = INT_OPS_EXTENSION.operations["idivmod_u"] def type_args(self) -> list[tys.TypeArg]: return [tys.BoundedNatArg(n=self.width)] def cached_signature(self) -> tys.FunctionType | None: row: list[tys.Type] = [int_t(self.width)] * 2 return tys.FunctionType.endo(row) @classmethod def from_ext(cls, custom: ExtOp) -> Self | None: if custom.op_def() != cls.op_def(): return None match custom.args: case [tys.BoundedNatArg(n=a1)]: return cls(width=a1) case _: msg = f"Invalid args: {custom.args}" raise AsExtOp.InvalidExtOp(msg) def __call__(self, a: ComWire, b: ComWire) -> Command: return DataflowOp.__call__(self, a, b) #: DivMod operation. DivMod = _DivModDef()