Coverage for /home/runner/work/tket/tket/pytket/pytket/wasm/wasm.py: 90%
136 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-14 11:30 +0000
1# Copyright Quantinuum
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
import base64
import hashlib
from functools import cached_property
from os.path import exists

from qwasm import (  # type: ignore
    LANG_TYPE_EMPTY,
    LANG_TYPE_F32,
    LANG_TYPE_F64,
    LANG_TYPE_I32,
    LANG_TYPE_I64,
    SEC_EXPORT,
    SEC_FUNCTION,
    SEC_IMPORT,
    SEC_TYPE,
    decode_module,
)
from typing_extensions import deprecated
class WasmModuleHandler:
    """Construct and optionally check a wasm module for use in wasm Ops."""

    # True once ``check`` has completed successfully.
    checked: bool
    # Bit width (32 or 64) of the integers pytket exchanges with the module.
    _int_size: int
    # Textual wasm type name ("i32" or "i64") matching ``_int_size``.
    _int_type: str
    # Raw wasm module in binary format.
    _wasm_module: bytes
    # Supported function name -> (number of parameters, number of results).
    _functions: dict[str, tuple[int, int]]
    # Exported functions whose signature pytket cannot use.
    _unsupported_functions: list[str]

    # Maps qwasm value-type codes to the wasm type names used when comparing
    # signatures; LANG_TYPE_EMPTY (no value) maps to None.
    type_lookup = {
        LANG_TYPE_I32: "i32",
        LANG_TYPE_I64: "i64",
        LANG_TYPE_F32: "f32",
        LANG_TYPE_F64: "f64",
        LANG_TYPE_EMPTY: None,
    }

    def __init__(
        self, wasm_module: bytes, check: bool = True, int_size: int = 32
    ) -> None:
        """
        Construct a wasm module handler

        :param wasm_module: A wasm module in binary format.
        :type wasm_module: bytes
        :param check: If ``True`` checks file for compatibility with wasm
            standards. If ``False`` checks are skipped.
        :type check: bool
        :param int_size: length of the integer that is used in the wasm file
        :type int_size: int
        :raises ValueError: if ``int_size`` is neither 32 nor 64, or if
            ``check`` is ``True`` and the module is rejected by :meth:`check`.
        """
        self._int_size = int_size
        if int_size == 32:
            self._int_type = self.type_lookup[LANG_TYPE_I32]
        elif int_size == 64:
            self._int_type = self.type_lookup[LANG_TYPE_I64]
        else:
            raise ValueError(
                "given integer length not valid, only 32 and 64 are allowed"
            )

        # stores the names of the functions mapped
        # to the number of parameters and the number of return values
        self._functions = dict()

        # contains the list of functions that are not allowed
        # to use in pytket (because of types that are not integers
        # of the supplied int_size.)
        self._unsupported_functions = []

        self._wasm_module = wasm_module
        self.checked = False

        if check:
            self.check()

    def _return_types(self, entry) -> list:
        """Decode the result types of one function-type entry.

        :param entry: a function-type entry from the module's type section
            (as produced by ``qwasm.decode_module``).
        :return: the wasm type names of the results, in order (empty for a
            function without results).
        :raises ValueError: if a multi-value result type has a representation
            that is neither a list of matching length nor a single type code.
        """
        if entry.return_count > 1:
            if (
                isinstance(entry.return_type, list)
                and len(entry.return_type) == entry.return_count
            ):
                return [self.type_lookup[rt] for rt in entry.return_type]
            if isinstance(entry.return_type, int):
                # A single type code is replicated for every result value.
                return [self.type_lookup[entry.return_type]] * entry.return_count
            raise ValueError(
                "Only parameter and return values of "
                f"i{self._int_size} types are"
                f" allowed, found type: {entry.return_type}"
            )
        if entry.return_count == 1:
            return [self.type_lookup[entry.return_type]]
        return []

    def check(self) -> None:
        """Collect functions from the module that can be used with pytket.

        Populates the internal list of supported and unsupported functions
        and marks the module as checked so that subsequent checking is not
        required.

        An exported function is supported when every parameter and its (at
        most one) result are integers of the configured ``int_size``.

        :raises ValueError: if the module is malformed, if a multi-value
            result type cannot be interpreted, or if the module does not
            export a supported ``init`` function without parameters and
            results.
        """
        if self.checked:
            return

        # Start from a clean slate so that re-running ``check`` after an
        # earlier failed attempt does not accumulate duplicate entries.
        self._functions = {}
        self._unsupported_functions = []
        # Type index of every function *defined* in the module (function
        # section). Initialised so that a module without a function section
        # is reported as invalid rather than failing with AttributeError.
        self._function_types = []

        function_signatures: list = []
        # exported function name -> index into the function index space
        exported_functions: dict[str, int] = {}
        # Imported functions occupy the first slots of the function index
        # space, so export indices must be shifted by this count before they
        # can index the function section.
        n_imported_functions = 0

        mod_iter = iter(decode_module(self._wasm_module))
        next(mod_iter)  # skip the module header

        for _, cur_sec_data in mod_iter:
            # read in list of function signatures
            if cur_sec_data.id == SEC_TYPE:
                for entry in cur_sec_data.payload.entries:
                    function_signatures.append(
                        {
                            "parameter_types": [
                                self.type_lookup[pt] for pt in entry.param_types
                            ],
                            "return_types": self._return_types(entry),
                        }
                    )

            # count imported functions (external kind 0 == function)
            elif cur_sec_data.id == SEC_IMPORT:
                n_imported_functions += sum(
                    1 for entry in cur_sec_data.payload.entries if entry.kind == 0
                )

            # read in list of exported function names
            elif cur_sec_data.id == SEC_EXPORT:
                for entry in cur_sec_data.payload.entries:
                    if entry.kind == 0:
                        f_name = entry.field_str.tobytes().decode()
                        exported_functions[f_name] = entry.index

            # read in map of function signatures to function names
            elif cur_sec_data.id == SEC_FUNCTION:
                self._function_types = cur_sec_data.payload.types

        for name, func_index in exported_functions.items():
            local_index = func_index - n_imported_functions
            if local_index < 0:
                # Re-exported import: its signature is not read from the
                # import section here, so pytket cannot use it.
                self._unsupported_functions.append(name)
                continue
            if local_index >= len(self._function_types):
                raise ValueError("invalid wasm file")
            type_index = self._function_types[local_index]
            if type_index >= len(function_signatures):
                raise ValueError("invalid wasm file")

            signature = function_signatures[type_index]
            param_types = signature["parameter_types"]
            return_types = signature["return_types"]

            # only integers of the configured size, and at most one result
            supported_function = len(return_types) <= 1 and all(
                t == self._int_type for t in (*param_types, *return_types)
            )
            if supported_function:
                self._functions[name] = (len(param_types), len(return_types))
            else:
                self._unsupported_functions.append(name)

        if "init" not in self._functions:
            raise ValueError("wasm file needs to contain a function named 'init'")

        if self._functions["init"][0] != 0:
            raise ValueError("init function should not have any parameter")

        if self._functions["init"][1] != 0:
            raise ValueError("init function should not have any results")

        # Mark the module as checked, which indicates that function
        # signatures are available and that it does not need
        # to be checked again.
        self.checked = True

    @property
    @deprecated("Use public property `checked` instead.")
    def _check_file(self) -> bool:
        return self.checked

    def __str__(self) -> str:
        """str representation of the wasm module"""
        return self.uid

    def __repr__(self) -> str:
        """str representation of the contents of the wasm file."""
        if not self.checked:
            return f"Unchecked wasm module file with the uid {self.uid}"

        parts = [f"Functions in wasm file with the uid {self.uid}:\n"]
        for name, (n_params, n_returns) in self.functions.items():
            parts.append(
                f"function '{name}' with "
                f"{n_params} i{self._int_size} parameter(s)"
                f" and {n_returns} i{self._int_size} return value(s)\n"
            )
        parts.extend(
            "unsupported function with invalid "
            f"parameter or result type: '{name}' \n"
            for name in self.unsupported_functions
        )
        return "".join(parts)

    def bytecode(self) -> bytes:
        """The wasm content as bytecode"""
        return self._wasm_module

    @cached_property
    def bytecode_base64(self) -> bytes:
        """The wasm content as base64 encoded bytecode."""
        return base64.b64encode(self._wasm_module)

    @property
    @deprecated("Use public property `bytecode_base64` instead.")
    def _wasm_file_encoded(self) -> bytes:
        return self.bytecode_base64

    @cached_property
    def uid(self) -> str:
        """A unique identifier for the module: the SHA-256 hex digest of its
        base64-encoded bytecode."""
        return hashlib.sha256(self.bytecode_base64).hexdigest()

    @property
    @deprecated("Use public property `uid` instead.")
    def _wasmfileuid(self) -> str:
        return self.uid

    def check_function(
        self, function_name: str, number_of_parameters: int, number_of_returns: int
    ) -> bool:
        """
        Check whether the module contains a supported function with the given
        name and signature.

        If the module has not been checked this will raise a ValueError.

        :param function_name: name of the function that is checked
        :type function_name: str
        :param number_of_parameters: number of integer parameters of the function
        :type number_of_parameters: int
        :param number_of_returns: number of integer return values of the function
        :type number_of_returns: int
        :return: true if the signature and the name of the function is correct"""
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )

        return self._functions.get(function_name) == (
            number_of_parameters,
            number_of_returns,
        )

    @property
    def functions(self) -> dict[str, tuple[int, int]]:
        """Retrieve the names of functions with the number of input and out arguments.

        If the module has not been checked this will raise a ValueError.
        """
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )
        return self._functions

    @property
    def unsupported_functions(self) -> list[str]:
        """Retrieve the names of unsupported functions as a list of strings.

        If the module has not been checked this will raise a ValueError.
        """
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )
        return self._unsupported_functions
class WasmFileHandler(WasmModuleHandler):
    """Construct and optionally check a wasm module from a file for use in wasm Ops."""

    def __init__(
        self, filepath: str, check_file: bool = True, int_size: int = 32
    ) -> None:
        """
        Construct a wasm file handler by reading the wasm module stored at
        ``filepath`` into memory.

        :param filepath: Path to the wasm file
        :type filepath: str
        :param check_file: If ``True`` checks file for compatibility with wasm
            standards. If ``False`` checks are skipped.
        :type check_file: bool
        :param int_size: length of the integer that is used in the wasm file
        :type int_size: int
        :raises ValueError: if nothing exists at ``filepath``, or if the
            module handler rejects the contents or ``int_size``.
        """
        if not exists(filepath):
            raise ValueError("wasm file not found at given path")

        with open(filepath, "rb") as wasm_stream:
            raw_module: bytes = wasm_stream.read()

        # The raw file contents are also kept on the instance as
        # ``_wasm_file``, in addition to the handler's own module storage.
        self._wasm_file: bytes = raw_module
        super().__init__(wasm_module=raw_module, check=check_file, int_size=int_size)