Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions interwhen/utils/zebralogic_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections import defaultdict
from typing import Dict, List, Tuple, Optional, Set
from importlib import resources
from z3 import Solver, And, Or, Not, Bool, PbEq
from z3 import Solver, And, Or, Not, Bool, PbEq, sat


# ============== ZebraLogicProblem ==============
Expand All @@ -32,6 +32,8 @@ def __init__(self, problem: dict):
self.houses = list(range(self.n_houses))
self.features = self._make_features(problem['features'])
self.clue_texts = problem['clues']
for clue_ir in problem['clue_irs']:
self.apply_ir(clue_ir)

def _make_features(self, features_domains: dict) -> dict:
"""Create Z3 boolean variables for each feature/value/house combination.
Expand Down Expand Up @@ -73,7 +75,6 @@ def compile_constraint(self, ir: dict):
same_house: Two entities at same position
pos_relation: Spatial relationship between entities
"""
from z3 import And, Or, Not

t = ir["type"]

Expand Down Expand Up @@ -106,11 +107,15 @@ def B(h): return self.features[f2][v2][h]

def k_to_left(k):
clauses = []

# A is somewhere to the left of B
if k == "?":
for h1 in self.houses:
for h2 in self.houses:
if h1 < h2:
clauses.append(And(A(h1), B(h2)))

# A at h, B at h + k
else:
k = int(k)
for h in range(0, self.n_houses - k):
Expand All @@ -119,11 +124,15 @@ def k_to_left(k):

def k_to_right(k):
clauses = []

# A is somewhere to the right of B
if k == "?":
for h1 in self.houses:
for h2 in self.houses:
if h1 > h2:
clauses.append(And(A(h1), B(h2)))

# A at h, B at h - k
else:
k = int(k)
for h in range(k, self.n_houses):
Expand Down Expand Up @@ -154,12 +163,15 @@ def apply_ir(self, ir):
compiled_disjuncts = []
for disjunct in ir:
assert isinstance(disjunct, list), "Each disjunct must be a list of conjuncts"
compiled_conjuncts = [self.compile_constraint(c) for c in disjunct]
compiled_conjuncts = []
for conjunct in disjunct:
assert isinstance(conjunct, dict), "Each conjunct must be a dict representing an IR"
compiled = self.compile_constraint(conjunct)
compiled_conjuncts.append(compiled)
compiled_disjuncts.append(And(*compiled_conjuncts))
self.solver.add(Or(*compiled_disjuncts))

@property
def is_satisfiable(self) -> bool:
"""Check if the current constraints are satisfiable."""
from z3 import sat
return self.solver.check() == sat