Coverage for /home/runner/work/tket/tket/pytket/pytket/wasm/wasm.py: 90%

136 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-02 12:44 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import base64 

16import hashlib 

17from functools import cached_property 

18from os.path import exists 

19 

20from qwasm import ( # type: ignore 

21 LANG_TYPE_EMPTY, 

22 LANG_TYPE_F32, 

23 LANG_TYPE_F64, 

24 LANG_TYPE_I32, 

25 LANG_TYPE_I64, 

26 SEC_EXPORT, 

27 SEC_FUNCTION, 

28 SEC_TYPE, 

29 decode_module, 

30) 

31from typing_extensions import deprecated 

32 

33 

class WasmModuleHandler:
    """Construct and optionally check a wasm module for use in wasm Ops."""

    # ``True`` once :py:meth:`check` has completed successfully.
    checked: bool
    # Bit width of the integers the module is expected to use (32 or 64).
    _int_size: int
    # Raw binary content of the wasm module.
    _wasm_module: bytes
    # Supported exported functions:
    # name -> (number of parameters, number of return values).
    _functions: dict[str, tuple[int, int]]
    # Names of exported functions that cannot be used from pytket.
    _unsupported_functions: list[str]
    # Payload of the function section; its elements are used as indices into
    # the signatures read from the type section.
    _function_types: list[int]

    # Maps qwasm value-type constants to their textual names.
    type_lookup = {  # noqa: RUF012
        LANG_TYPE_I32: "i32",
        LANG_TYPE_I64: "i64",
        LANG_TYPE_F32: "f32",
        LANG_TYPE_F64: "f64",
        LANG_TYPE_EMPTY: None,
    }

    def __init__(
        self, wasm_module: bytes, check: bool = True, int_size: int = 32
    ) -> None:
        """
        Construct a wasm module handler

        :param wasm_module: A wasm module in binary format.
        :param check: If ``True`` checks file for compatibility with wasm
          standards. If ``False`` checks are skipped.
        :param int_size: length of the integer that is used in the wasm file
        :raises ValueError: If ``int_size`` is neither 32 nor 64, or if
          ``check`` is ``True`` and the module fails the checks.
        """
        self._int_size = int_size
        if int_size == 32:  # noqa: PLR2004
            self._int_type = self.type_lookup[LANG_TYPE_I32]
        elif int_size == 64:  # noqa: PLR2004
            self._int_type = self.type_lookup[LANG_TYPE_I64]
        else:
            raise ValueError(
                "given integer length not valid, only 32 and 64 are allowed"
            )

        # stores the names of the functions mapped
        # to the number of parameters and the number of return values
        self._functions = {}

        # contains the list of functions that are not allowed
        # to use in pytket (because of types that are not integers
        # of the supplied int_size.)
        self._unsupported_functions = []

        # filled from the function section of the module by ``check``
        self._function_types = []

        self._wasm_module = wasm_module
        self.checked = False

        if check:
            self.check()

    def check(self) -> None:  # noqa: PLR0912, PLR0915
        """Collect functions from the module that can be used with pytket.

        Populates the internal list of supported and unsupported functions
        and marks the module as checked so that subsequent checking is not
        required.

        :raises ValueError: If the module is inconsistent (an exported
          function refers to a missing function or signature), if a
          multi-value return type cannot be interpreted, or if the module
          does not export a supported function named ``init`` without
          parameters and without results.
        """
        if self.checked:
            return

        # Start from a clean state so that a repeated call after a failed
        # check does not accumulate duplicate entries, and so that a module
        # without a function section is reported as invalid below instead of
        # failing with an AttributeError.
        self._functions = {}
        self._unsupported_functions = []
        self._function_types = []

        # (parameter types, return types) per entry of the type section
        function_signatures: list[tuple[list, list]] = []
        function_names: list[str] = []
        # exported function name -> index stored in its export entry
        exported_index: dict[str, int] = {}

        mod_iter = iter(decode_module(self._wasm_module))
        # The first item yielded by the decoder is skipped; only the
        # following items are treated as sections.
        _, _ = next(mod_iter)

        for _, cur_sec_data in mod_iter:
            # read in list of function signatures
            if cur_sec_data.id == SEC_TYPE:
                for entry in cur_sec_data.payload.entries:
                    parameter_types = [
                        self.type_lookup[pt] for pt in entry.param_types
                    ]
                    if entry.return_count > 1:
                        if (
                            isinstance(entry.return_type, list)
                            and len(entry.return_type) == entry.return_count
                        ):
                            return_types = [
                                self.type_lookup[rt] for rt in entry.return_type
                            ]
                        elif isinstance(entry.return_type, int):
                            # a single type code stands for every return value
                            return_types = [
                                self.type_lookup[entry.return_type]
                            ] * entry.return_count
                        else:
                            raise ValueError(
                                "Only parameter and return values of "
                                f"i{self._int_size} types are"
                                f" allowed, found type: {entry.return_type}"
                            )
                    elif entry.return_count == 1:
                        return_types = [self.type_lookup[entry.return_type]]
                    else:
                        return_types = []
                    function_signatures.append((parameter_types, return_types))

            # read in list of function names
            elif cur_sec_data.id == SEC_EXPORT:
                for entry in cur_sec_data.payload.entries:
                    # export kind 0 denotes a function
                    if entry.kind == 0:
                        f_name = entry.field_str.tobytes().decode()
                        function_names.append(f_name)
                        exported_index[f_name] = entry.index

            # read in map of function signatures to function names
            elif cur_sec_data.id == SEC_FUNCTION:
                self._function_types = cur_sec_data.payload.types

        for name in function_names:
            # NOTE(review): the export index is used directly as a position
            # in the function section; confirm this also holds for modules
            # that import functions.
            idx = exported_index[name]
            if idx >= len(self._function_types):
                raise ValueError("invalid wasm file")

            type_idx = self._function_types[idx]
            if type_idx >= len(function_signatures):
                raise ValueError("invalid wasm file")

            parameter_types, return_types = function_signatures[type_idx]

            # A function is usable only if every parameter and return value
            # is an integer of the configured size and it returns at most
            # one value.
            supported_function = (
                all(t == self._int_type for t in parameter_types)
                and all(t == self._int_type for t in return_types)
                and len(return_types) <= 1
            )

            if supported_function:
                self._functions[name] = (len(parameter_types), len(return_types))
            else:
                self._unsupported_functions.append(name)

        if "init" not in self._functions:
            raise ValueError("wasm file needs to contain a function named 'init'")

        if self._functions["init"][0] != 0:
            raise ValueError("init function should not have any parameter")

        if self._functions["init"][1] != 0:
            raise ValueError("init function should not have any results")

        # Mark the module as checked, which indicates that function
        # signatures are available and that it does not need
        # to be checked again.
        self.checked = True

    @property
    @deprecated("Use public property `checked` instead.")
    def _check_file(self) -> bool:
        return self.checked

    def __str__(self) -> str:
        """str representation of the wasm module"""
        return self.uid

    def __repr__(self) -> str:
        """str representation of the contents of the wasm file."""
        if not self.checked:
            return f"Unchecked wasm module file with the uid {self.uid}"

        int_type = f"i{self._int_size}"
        parts = [f"Functions in wasm file with the uid {self.uid}:\n"]
        parts.extend(
            f"function '{name}' with {n_params} {int_type} parameter(s)"
            f" and {n_returns} {int_type} return value(s)\n"
            for name, (n_params, n_returns) in self.functions.items()
        )
        parts.extend(
            f"unsupported function with invalid parameter or result type: '{name}' \n"
            for name in self.unsupported_functions
        )
        return "".join(parts)

    def bytecode(self) -> bytes:
        """The wasm content as bytecode"""
        return self._wasm_module

    @cached_property
    def bytecode_base64(self) -> bytes:
        """The wasm content as base64 encoded bytecode."""
        return base64.b64encode(self._wasm_module)

    @property
    @deprecated("Use public property `bytecode_base64` instead.")
    def _wasm_file_encoded(self) -> bytes:
        return self.bytecode_base64

    @cached_property
    def uid(self) -> str:
        """A unique identifier for the module: the SHA-256 hex digest of its
        base64 encoded bytecode."""
        return hashlib.sha256(self.bytecode_base64).hexdigest()

    @property
    @deprecated("Use public property `uid` instead.")
    def _wasmfileuid(self) -> str:
        return self.uid

    def check_function(
        self, function_name: str, number_of_parameters: int, number_of_returns: int
    ) -> bool:
        """
        Checks a given function name and signature if it is included and the
        module has previously been checked.

        If the module has not been checked this function will raise a
        ValueError.

        :param function_name: name of the function that is checked
        :param number_of_parameters: number of integer parameters of the function
        :param number_of_returns: number of integer return values of the function
        :return: true if the signature and the name of the function is correct
        :raises ValueError: If the module has not been checked.
        """
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )

        # An unknown name yields ``None``, which never equals the tuple.
        return self._functions.get(function_name) == (
            number_of_parameters,
            number_of_returns,
        )

    @property
    def functions(self) -> dict[str, tuple[int, int]]:
        """Retrieve the names of functions with the number of input and out arguments.

        If the module has not been checked this property will raise a
        ValueError.
        """
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )
        return self._functions

    @property
    def unsupported_functions(self) -> list[str]:
        """Retrieve the names of unsupported functions as a list of strings.

        If the module has not been checked this property will raise a
        ValueError.
        """
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )
        return self._unsupported_functions

299 

300 

class WasmFileHandler(WasmModuleHandler):
    """Wasm module handler whose module content is loaded from a file on disk."""

    def __init__(self, filepath: str, check_file: bool = True, int_size: int = 32):
        """
        Read the wasm module stored at ``filepath`` into memory and hand it to
        the module handler.

        :param filepath: Path to the wasm file
        :param check_file: If ``True`` checks file for compatibility with wasm
          standards. If ``False`` checks are skipped.
        :param int_size: length of the integer that is used in the wasm file
        :raises ValueError: If nothing exists at ``filepath``.
        """
        path_is_present = exists(filepath)
        if not path_is_present:
            raise ValueError("wasm file not found at given path")

        with open(filepath, "rb") as wasm_stream:
            content = wasm_stream.read()

        # Keep the raw file content on the instance before delegating to the
        # base class.
        self._wasm_file: bytes = content

        super().__init__(self._wasm_file, check=check_file, int_size=int_size)

318 self._wasm_file: bytes = file.read() 

319 

320 super().__init__(self._wasm_file, check_file, int_size)