# Copyright 2021-2024 Cambridge Quantum Computing Ltd.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Numerical Backend=================Module unifying the use of numerical backends for lambeq. This module isused to provide a common interface to different numerical backends,such as NumPy, JAX, PyTorch, and TensorFlow."""from__future__importannotationsfromcontextlibimportcontextmanagerfromtypesimportModuleTypefromtypingimportCallable,Generator
class Backend:
    """
    A matrix backend.

    Parameters:
        module : The main module of the backend.
        array : The array class of the backend.

    """
    # NOTE(review): only the docstring of this class is visible in this
    # view; the constructor that would accept ``module`` and ``array`` is
    # presumably defined in the full source -- confirm before relying on
    # the parameter list above.
@contextmanager
def backend(name: str | None = None,
            _stack=['numpy'],  # noqa: B006
            _cache=dict()) -> Generator[Backend, None, None]:  # noqa: B006
    """
    Context manager for matrix backend.

    While the ``with`` block is running, the selected backend is the
    innermost entry of the backend stack; it is removed again when the
    block exits, whether normally or through an exception.

    Parameters:
        name : The name of the backend, default is ``"numpy"``.
            When omitted (or falsy), the name currently on top of the
            stack is reused.

    Yields:
        The backend instance registered under the selected name.
        Instances are created on first use and reused afterwards.

    """
    # ``_stack`` and ``_cache`` are intentionally mutable defaults: they
    # hold state shared by every call, and ``set_backend`` rewrites the
    # stack through ``__defaults__``.
    selected = name if name else _stack[-1]
    _stack.append(selected)
    try:
        if selected in _cache:
            instance = _cache[selected]
        else:
            # First request for this backend: build it and remember it.
            instance = _cache[selected] = BACKENDS[selected]()
        yield instance
    finally:
        # Always undo the push, even if instantiation or the body failed.
        _stack.pop()
def set_backend(name: str) -> None:
    """
    Override the default backend.

    The name is written into the last slot of the stack that
    ``backend`` keeps as its ``_stack`` default argument, so later
    ``backend()`` calls without an explicit name pick it up.

    Parameters:
        name : The name of the backend.

    """
    # ``contextmanager`` wraps the generator function; the undecorated
    # function (and therefore its default arguments) is reachable via
    # ``__wrapped__``.
    defaults = backend.__wrapped__.__defaults__  # type: ignore[attr-defined]
    stack = defaults[1]  # defaults are (name, _stack, _cache)
    stack[-1] = name
def get_backend() -> Backend:
    """
    Get the current backend.

    Enters ``backend()`` without a name, which selects whatever backend
    is currently on top of the stack, and returns the yielded instance.

    Example
    -------
    >>> set_backend('jax')
    >>> assert isinstance(get_backend(), JAX)
    >>> set_backend('numpy')
    >>> assert isinstance(get_backend(), NumPy)

    """
    with backend() as active:
        current = active
    return current