Coverage for /home/runner/work/tket/tket/pytket/pytket/wasm/wasm.py: 90%

136 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-09 15:08 +0000

1# Copyright Quantinuum 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import base64 

16import hashlib 

17from functools import cached_property 

18from os.path import exists 

19 

20from qwasm import ( # type: ignore 

21 LANG_TYPE_EMPTY, 

22 LANG_TYPE_F32, 

23 LANG_TYPE_F64, 

24 LANG_TYPE_I32, 

25 LANG_TYPE_I64, 

26 SEC_EXPORT, 

27 SEC_FUNCTION, 

28 SEC_TYPE, 

29 decode_module, 

30) 

31from typing_extensions import deprecated 

32 

33 

class WasmModuleHandler:
    """Construct and optionally check a wasm module for use in wasm Ops."""

    checked: bool
    _int_size: int
    _int_type: str
    _wasm_module: bytes
    _functions: dict[str, tuple[int, int]]
    _unsupported_functions: list[str]

    # Maps qwasm value-type codes to their textual wasm names;
    # ``LANG_TYPE_EMPTY`` maps to ``None``.
    type_lookup = {  # noqa: RUF012
        LANG_TYPE_I32: "i32",
        LANG_TYPE_I64: "i64",
        LANG_TYPE_F32: "f32",
        LANG_TYPE_F64: "f64",
        LANG_TYPE_EMPTY: None,
    }

    def __init__(
        self, wasm_module: bytes, check: bool = True, int_size: int = 32
    ) -> None:
        """
        Construct a wasm module handler

        :param wasm_module: A wasm module in binary format.
        :type wasm_module: bytes
        :param check: If ``True`` checks file for compatibility with wasm
            standards. If ``False`` checks are skipped.
        :type check: bool
        :param int_size: length of the integer that is used in the wasm file
        :type int_size: int
        :raises ValueError: if ``int_size`` is neither 32 nor 64, or if
            ``check`` is ``True`` and :meth:`check` rejects the module.
        """
        self._int_size = int_size
        if int_size == 32:  # noqa: PLR2004
            self._int_type = self.type_lookup[LANG_TYPE_I32]
        elif int_size == 64:  # noqa: PLR2004
            self._int_type = self.type_lookup[LANG_TYPE_I64]
        else:
            raise ValueError(
                "given integer length not valid, only 32 and 64 are allowed"
            )

        # Names of supported functions mapped to
        # (number of parameters, number of return values).
        self._functions = {}

        # Names of exported functions that cannot be used from pytket
        # because their signature uses types other than integers of the
        # supplied int_size, or because they return more than one value.
        self._unsupported_functions = []

        self._wasm_module = wasm_module
        self.checked = False

        if check:
            self.check()

    def check(self) -> None:
        """Collect functions from the module that can be used with pytket.

        Populates the internal list of supported and unsupported functions
        and marks the module as checked so that subsequent checking is not
        required. Re-running after a failed check starts from fresh state,
        so no function is recorded twice.

        :raises ValueError: if an exported function or its type index is
            out of range ("invalid wasm file"), if a multi-value return type
            has an unexpected representation, or if the module does not
            export a supported function ``init`` without parameters and
            results.
        """
        if self.checked:
            return

        function_signatures: list[dict[str, list]] = []
        exported_functions: dict[str, int] = {}
        # Stays empty when the module has no function section, so exported
        # functions are then reported as an invalid file instead of
        # failing with AttributeError.
        self._function_types = []

        mod_iter = iter(decode_module(self._wasm_module))
        # The first decoded item precedes the sections (presumably the
        # module header) and is skipped.
        _, _ = next(mod_iter)

        for _, cur_sec_data in mod_iter:
            if cur_sec_data.id == SEC_TYPE:
                # list of function signatures
                function_signatures = self._read_function_signatures(
                    cur_sec_data.payload.entries
                )
            elif cur_sec_data.id == SEC_EXPORT:
                # names of exported functions
                exported_functions.update(
                    self._read_exported_functions(cur_sec_data.payload.entries)
                )
            elif cur_sec_data.id == SEC_FUNCTION:
                # type index of each function, i.e. the map from functions
                # to their signatures
                self._function_types = cur_sec_data.payload.types

        functions: dict[str, tuple[int, int]] = {}
        unsupported_functions: list[str] = []
        for name, func_idx in exported_functions.items():
            # NOTE(review): per the wasm spec, export indices live in the
            # function index space, which counts imported functions first,
            # while `_function_types` only covers the function section.
            # Modules importing functions may be mis-mapped - TODO confirm.
            if func_idx >= len(self._function_types):
                raise ValueError("invalid wasm file")
            type_idx = self._function_types[func_idx]
            if type_idx >= len(function_signatures):
                raise ValueError("invalid wasm file")

            signature = function_signatures[type_idx]
            param_types = signature["parameter_types"]
            return_types = signature["return_types"]

            # Only integers of the configured size and at most one return
            # value are usable from pytket.
            if len(return_types) <= 1 and all(
                t == self._int_type for t in param_types + return_types
            ):
                functions[name] = (len(param_types), len(return_types))
            else:
                unsupported_functions.append(name)

        # Replace (rather than extend) the stored results so that repeated
        # calls never accumulate duplicates.
        self._functions = functions
        self._unsupported_functions = unsupported_functions

        if "init" not in functions:
            raise ValueError("wasm file needs to contain a function named 'init'")

        if functions["init"][0] != 0:
            raise ValueError("init function should not have any parameter")

        if functions["init"][1] != 0:
            raise ValueError("init function should not have any results")

        # Mark the module as checked, which indicates that function
        # signatures are available and that it does not need
        # to be checked again.
        self.checked = True

    def _read_function_signatures(self, entries: list) -> list[dict[str, list]]:
        """Convert the entries of a type section into signature dictionaries.

        Each signature has the keys ``"parameter_types"`` and
        ``"return_types"``, holding lists of type names taken from
        :attr:`type_lookup`. Type codes missing from :attr:`type_lookup` are
        recorded as ``None``, so the affected functions end up unsupported
        instead of aborting the whole check with a KeyError.

        :param entries: the entries of a decoded type section
        :return: one signature dictionary per entry, in section order
        :raises ValueError: if a multi-value return type has an unexpected
            representation
        """
        lookup = self.type_lookup.get
        signatures: list[dict[str, list]] = []
        for entry in entries:
            param_types = [lookup(pt) for pt in entry.param_types]
            if entry.return_count > 1:
                if (
                    isinstance(entry.return_type, list)
                    and len(entry.return_type) == entry.return_count
                ):
                    return_types = [lookup(rt) for rt in entry.return_type]
                elif isinstance(entry.return_type, int):
                    # a single type code is replicated for every return value
                    return_types = [lookup(entry.return_type)] * entry.return_count
                else:
                    raise ValueError(
                        "Only parameter and return values of "
                        f"i{self._int_size} types are"
                        f" allowed, found type: {entry.return_type}"
                    )
            elif entry.return_count == 1:
                return_types = [lookup(entry.return_type)]
            else:
                return_types = []
            signatures.append(
                {"parameter_types": param_types, "return_types": return_types}
            )
        return signatures

    @staticmethod
    def _read_exported_functions(entries: list) -> dict[str, int]:
        """Map the names of exported functions to their function indices.

        :param entries: the entries of a decoded export section
        :return: export name -> function index, in section order
        """
        exports: dict[str, int] = {}
        for entry in entries:
            # external kind 0 denotes an exported function
            if entry.kind == 0:
                exports[entry.field_str.tobytes().decode()] = entry.index
        return exports

    def _require_checked(self) -> None:
        """Raise ``ValueError`` unless :meth:`check` has completed successfully."""
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )

    @property
    @deprecated("Use public property `checked` instead.")
    def _check_file(self) -> bool:
        return self.checked

    def __str__(self) -> str:
        """str representation of the wasm module (its uid)"""
        return self.uid

    def __repr__(self) -> str:
        """str representation of the contents of the wasm file."""
        if not self.checked:
            return f"Unchecked wasm module file with the uid {self.uid}"

        result = f"Functions in wasm file with the uid {self.uid}:\n"
        for x in self.functions:
            result += f"function '{x}' with "
            result += f"{self.functions[x][0]} i{self._int_size} parameter(s)"
            result += f" and {self.functions[x][1]} i{self._int_size} return value(s)\n"

        for x in self.unsupported_functions:
            result += (
                f"unsupported function with invalid parameter or result type: '{x}' \n"
            )

        return result

    def bytecode(self) -> bytes:
        """The wasm content as bytecode"""
        return self._wasm_module

    @cached_property
    def bytecode_base64(self) -> bytes:
        """The wasm content as base64 encoded bytecode."""
        return base64.b64encode(self._wasm_module)

    @property
    @deprecated("Use public property `bytecode_base64` instead.")
    def _wasm_file_encoded(self) -> bytes:
        return self.bytecode_base64

    @cached_property
    def uid(self) -> str:
        """A unique identifier for the module: the SHA-256 hex digest of its
        base64-encoded bytecode."""
        return hashlib.sha256(self.bytecode_base64).hexdigest()

    @property
    @deprecated("Use public property `uid` instead.")
    def _wasmfileuid(self) -> str:
        return self.uid

    def check_function(
        self, function_name: str, number_of_parameters: int, number_of_returns: int
    ) -> bool:
        """
        Checks whether a function with the given name and signature is
        included in the module. The module must have been checked before.

        If the module has not been checked this function will raise a
        ValueError.

        :param function_name: name of the function that is checked
        :type function_name: str
        :param number_of_parameters: number of integer parameters of the function
        :type number_of_parameters: int
        :param number_of_returns: number of integer return values of the function
        :type number_of_returns: int
        :return: true if the signature and the name of the function is correct"""
        self._require_checked()
        return self._functions.get(function_name) == (
            number_of_parameters,
            number_of_returns,
        )

    @property
    def functions(self) -> dict[str, tuple[int, int]]:
        """Retrieve the names of functions with the number of input and out arguments.

        If the module has not been checked this function will raise a
        ValueError.
        """
        self._require_checked()
        return self._functions

    @property
    def unsupported_functions(self) -> list[str]:
        """Retrieve the names of unsupported functions as a list of strings.

        If the module has not been checked this function will raise a
        ValueError.
        """
        self._require_checked()
        return self._unsupported_functions

305 

306 

class WasmFileHandler(WasmModuleHandler):
    """Construct and optionally check a wasm module from a file for use in wasm Ops."""

    def __init__(self, filepath: str, check_file: bool = True, int_size: int = 32):
        """
        Read a wasm module from disk and wrap it in a module handler.

        The raw file contents are kept in ``_wasm_file`` and passed on to
        :class:`WasmModuleHandler`.

        :param filepath: Path to the wasm file
        :type filepath: str
        :param check_file: If ``True`` checks file for compatibility with wasm
            standards. If ``False`` checks are skipped.
        :type check_file: bool
        :param int_size: length of the integer that is used in the wasm file
        :type int_size: int
        :raises ValueError: if nothing exists at ``filepath``, or as raised by
            :class:`WasmModuleHandler` for an invalid ``int_size`` or a
            failed check.
        """
        if exists(filepath):
            with open(filepath, "rb") as wasm_stream:
                self._wasm_file: bytes = wasm_stream.read()
            super().__init__(self._wasm_file, check_file, int_size)
        else:
            raise ValueError("wasm file not found at given path")