diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index cb491b80..75bc51ab 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -111,14 +111,16 @@ def write_vcd_model(steps): print("%s Writing model to VCD file." % smt.timestamp()) vcd = mkvcd(open(vcdfile, "w")) + for netpath in sorted(smt.hiernets(topmod)): - width = len(smt.get_net_bin(topmod, netpath, "s0")) - vcd.add_net([topmod] + netpath, width) + vcd.add_net([topmod] + netpath, smt.net_width(topmod, netpath)) for i in range(steps): vcd.set_time(i) - for netpath in sorted(smt.hiernets(topmod)): - vcd.set_net([topmod] + netpath, smt.get_net_bin(topmod, netpath, "s%d" % i)) + path_list = sorted(smt.hiernets(topmod)) + value_list = smt.get_net_bin_list(topmod, path_list, "s%d" % i) + for path, value in zip(path_list, value_list): + vcd.set_net([topmod] + path, value) vcd.set_time(steps) diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index 53d2ec57..1b3944eb 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -157,7 +157,7 @@ class smtio: print("< %s" % line) if count_brackets == 0: break - if not self.p.poll(): + if self.p.poll(): print("SMT Solver terminated unexpectedly: %s" % "".join(stmt)) sys.exit(1) @@ -297,33 +297,51 @@ class smtio: self.write("(get-value (%s))" % (expr)) return self.parse(self.read())[0][1] - def get_net(self, mod_name, net_path, state_name): - def mkexpr(mod, base, path): - if len(path) == 1: - assert mod in self.modinfo - assert path[0] in self.modinfo[mod].wsize - return "(|%s_n %s| %s)" % (mod, path[0], base) + def get_list(self, expr_list): + self.write("(get-value (%s))" % " ".join(expr_list)) + return [n[1] for n in self.parse(self.read())] + def net_expr(self, mod, base, path): + if len(path) == 1: assert mod in self.modinfo - assert path[0] in self.modinfo[mod].cells + assert path[0] in self.modinfo[mod].wsize + return "(|%s_n %s| %s)" % (mod, path[0], base) - nextmod = self.modinfo[mod].cells[path[0]] - nextbase = "(|%s_h %s| %s)" % (mod, path[0], base) - return mkexpr(nextmod, nextbase, path[1:]) + assert mod in self.modinfo + assert path[0] in self.modinfo[mod].cells - return self.get(mkexpr(mod_name, state_name, net_path)) + nextmod = self.modinfo[mod].cells[path[0]] + nextbase = "(|%s_h %s| %s)" % (mod, path[0], base) + return self.net_expr(nextmod, nextbase, path[1:]) - def get_net_bool(self, mod_name, net_path, state_name): - v = self.get_net(mod_name, net_path, state_name) - assert v in ["true", "false"] - return 1 if v == "true" else 0 + def net_width(self, mod, net_path): + for i in range(len(net_path)-1): + assert mod in self.modinfo + assert net_path[i] in self.modinfo[mod].cells + mod = self.modinfo[mod].cells[net_path[i]] + + assert mod in self.modinfo + assert net_path[-1] in self.modinfo[mod].wsize + return self.modinfo[mod].wsize[net_path[-1]] + + def get_net(self, mod_name, net_path, state_name): + return self.get(self.net_expr(mod_name, state_name, net_path)) + + def get_net_list(self, mod_name, net_path_list, state_name): + return self.get_list([self.net_expr(mod_name, state_name, n) for n in net_path_list]) def get_net_hex(self, mod_name, net_path, state_name): return self.bv2hex(self.get_net(mod_name, net_path, state_name)) + def get_net_hex_list(self, mod_name, net_path_list, state_name): + return [self.bv2hex(v) for v in self.get_net_list(mod_name, net_path_list, state_name)] + def get_net_bin(self, mod_name, net_path, state_name): return self.bv2bin(self.get_net(mod_name, net_path, state_name)) + def get_net_bin_list(self, mod_name, net_path_list, state_name): + return [self.bv2bin(v) for v in self.get_net_list(mod_name, net_path_list, state_name)] + def wait(self): self.p.wait()