import pyboolector as bt
from btor2circuit import Btor2Circuit
from typing import Union

## Terminology (avoid misnomers in the code!):
# wire :=   [a nominal] identifier of a wire in a circuit
# index :=  [the unique identifier] of a wire in a BTOR2 file (a.k.a. BTOR2 index, BTOR2 ID) 
# term :=   [a nominal expression] built out of wires in the same circuit
# atom :=   [a nominal] of an atomic bit (some expression out of wires) in the abstract space
# lit :=    [a nominal] of a positive/negative atom
# cube :=   [a nominal] of a conjunction of literals
# clause := [a nominal] of a cube's negation
# node :=   [the real object] that actually exists in the SMT solver

class Wire:
    """
    A nominal identifier of a single wire in one unrolled copy of circuit A or B.

    Equality, hashing and ordering are all defined by the triple
    `(is_a, index, prime)`.
    """
    # `True` for a wire of circuit A, `False` for a wire of circuit B.
    is_a: bool
    # The BTOR2 index (ID) of the wire.
    index: int
    # Number of primes: 0 = current state, 1 = next state, 2 = next-next state, etc.
    prime: int

    def __init__(self, is_a: bool, index: int, prime: int):
        self.is_a, self.index, self.prime = is_a, index, prime

    def _key(self) -> tuple[bool, int, int]:
        "The identity triple shared by `__eq__`, `__hash__` and `__lt__`."
        return (self.is_a, self.index, self.prime)

    def __eq__(self, other):
        # Non-wires are never equal to a wire.
        return isinstance(other, Wire) and self._key() == other._key()

    def __hash__(self):
        # Mix the class in so a Wire never collides with a bare tuple of the same fields.
        return hash((Wire, *self._key()))

    def __lt__(self, other):
        assert isinstance(other, Wire), f"Cannot compare Wire with {type(other)}"
        return self._key() < other._key()

    def __repr__(self) -> str:
        # Rendered as e.g. `A_12''` (circuit A, index 12, two primes).
        side = "A" if self.is_a else "B"
        return f"{side}_{self.index}" + "'" * self.prime

def WireA(index: int, prime: int) -> Wire:
    """Build the wire with BTOR2 index `index` and `prime` primes in circuit A."""
    wire = Wire(True, index, prime)
    return wire
def WireB(index: int, prime: int) -> Wire:
    """Build the wire with BTOR2 index `index` and `prime` primes in circuit B."""
    wire = Wire(False, index, prime)
    return wire


class EqWirePair:
    """
    Establishing that two wires (left & right) are equal.

    The pair is stored in a canonical order, so `EqWirePair(x, y)` and
    `EqWirePair(y, x)` compare and hash equal. After construction, `left` is
    never smaller than `right` (w.r.t. `Wire.__lt__`). Wires order by `is_a`
    first and `False < True`, so a relational pair (one wire from each circuit)
    always has the circuit-A wire on the left and renders as e.g. `A_3 == B_5`.
    """
    # The larger wire of the pair (canonical order).
    left: Wire
    # The smaller wire of the pair (canonical order).
    right: Wire
    def __init__(self, left: Wire, right: Wire):
        # Canonicalize: swap so that `left` is the LARGER wire (left >= right).
        # This order is relied on for stable hashing/ordering and for printing
        # relational pairs A-side first.
        if left < right:
            left, right = right, left
        self.left = left
        self.right = right
    def _is_relational(self) -> bool:
        "An EqWirePair is relational, iff the left and right wires are from different circuits."
        return self.left.is_a != self.right.is_a
    def __eq__(self, other):
        if not isinstance(other, EqWirePair):
            return False
        return (self.left, self.right) == (other.left, other.right)
    def __hash__(self):
        return hash((EqWirePair, self.left, self.right))
    def __lt__(self, other):
        assert isinstance(other, EqWirePair), f"Cannot compare EqWirePair with {type(other)}"
        return (self.left, self.right) < (other.left, other.right)
    def __repr__(self) -> str:
        return f"{self.left} == {self.right}"


# In theory, we could have more classes of Atoms. For now we only consider these two below:
#   - Wire:       a width-1 wire used directly as a boolean;
#   - EqWirePair: the boolean `left == right` over two wires.
Atom = Union[Wire, EqWirePair]
"Identifies a boolean atom, which is just a term with width 1 in the circuit."

def is_relational(atom: Atom) -> bool:
    """
    Tell whether `atom` relates the two circuits, i.e. has the form `a.x == b.y`.

    A plain `Wire` is never relational. An `EqWirePair` is relational only when
    its two wires come from different circuits.
    """
    if not isinstance(atom, EqWirePair):
        return False
    return atom._is_relational()

def prime_of_atom(atom: Atom) -> int:
    """
    Return the number of primes in `atom`.

    For an equality the deeper (larger) prime count of its two wires is used.

    Raises:
        ValueError: if `atom` is neither a `Wire` nor an `EqWirePair`.
    """
    if isinstance(atom, EqWirePair):
        left_prime, right_prime = atom.left.prime, atom.right.prime
        return max(left_prime, right_prime)
    if isinstance(atom, Wire):
        return atom.prime
    raise ValueError(f"Unknown atom: {atom}")


# A literal is a `(polarity, atom)` pair: `(True, a)` stands for `a`, and
# `(False, a)` stands for `not a`.
# NOTE: this name shadows `typing.Literal`; do not import that one into this module.
Literal = tuple[bool, Atom]
"Represents a literal: a positive/negative atom."

# Cubes and clauses share one representation (a list of literals); only the
# interpretation differs. An empty cube is converted to True and an empty clause
# to False by the corresponding `ProductMachine.*_to_node` converters.
Cube = list[Literal]
"Represents a cube: a conjunction of literals."

Clause = list[Literal]
"Represents a clause: a disjunction of literals."


class ProductMachine:
    """
    The product machine of two circuits, each with unrolled copies
    (can be lazily built and added to the solver).

    Copy `k` of a circuit (`circuit_a[k]` / `circuit_b[k]`) provides the nodes of
    wires carrying `k` primes. Whenever `grow` adds a copy `k > 0`, it asserts in
    the solver that every next-state node of copy `k - 1` equals the
    same-named current-state node of copy `k`.

    NOTE: Constructor has side effect: grows the product machine to the given heights (now 2, 2)
    """
    # The underlying SMT solver (Boolector) that owns every node and assertion.
    solver: bt.Boolector
    # The suffix for the first circuit, e.g. `@A`.
    suffix_a: str
    # The suffix for the second circuit, e.g. `@B`.
    suffix_b: str
    # The lines of the first circuit's BTOR2 file.
    btor2_lines_a: list[str]
    # The lines of the second circuit's BTOR2 file.
    btor2_lines_b: list[str]
    # The unrolling copies of the first circuit (list index = number of primes).
    circuit_a: list[Btor2Circuit]
    # The unrolling copies of the second circuit (list index = number of primes).
    circuit_b: list[Btor2Circuit]

    def __init__(self, solver: bt.Boolector, suffix_a: str, suffix_b: str, btor2_lines_a: list[str], btor2_lines_b: list[str]):
        """
        Args:
            solver: Boolector instance in which all nodes are built and constraints asserted.
            suffix_a, suffix_b: name suffixes of circuits A / B; each copy receives
                `<suffix><copy index>` (e.g. `@A0`, `@A1`).
            btor2_lines_a, btor2_lines_b: BTOR2 source lines of circuits A / B.
        """
        self.solver = solver
        self.suffix_a = suffix_a
        self.suffix_b = suffix_b
        self.btor2_lines_a = btor2_lines_a
        self.btor2_lines_b = btor2_lines_b
        self.circuit_a = []
        self.circuit_b = []
        self.grow(2, 2)  # initially, involved nodes are states + P only, 2 is enough.

    def I(self) -> list[Cube]:
        """
        Returns the (default) initial relation of the product machine. Assumed M-RDNF.

        The result is a single cube over the un-primed copies (copy 0):
          - for every local state (name starting with `_`) of A or B that has an
            init value: `init == current`;
          - for every interface state of A: `A.state == B.state` (matched by name).

        NOTE(review): an interface state of A missing from B raises `KeyError`, and
        interface states present only in B are not constrained -- this presumes
        both circuits share the same interface states; confirm.
        """
        ## below is just a default implementation, can change if needed (but must be an M-RDNF).
        literals = []
        a, b = self.circuit_a[0], self.circuit_b[0]
        for name in a.state_names():
            if name.startswith("_"):  # local state
                if name in a.name_to_init_state_id:
                    wire_init = WireA(a.name_to_init_state_id[name], 0)
                    wire_curr = WireA(a.name_to_state_id[name], 0)
                    literals.append((True, EqWirePair(wire_init, wire_curr)))
            else:  # interface state
                wire_a = WireA(a.name_to_state_id[name], 0)
                wire_b = WireB(b.name_to_state_id[name], 0)
                literals.append((True, EqWirePair(wire_a, wire_b)))
        for name in b.state_names():
            if name.startswith("_"):  # local state
                if name in b.name_to_init_state_id:
                    wire_init = WireB(b.name_to_init_state_id[name], 0)
                    wire_curr = WireB(b.name_to_state_id[name], 0)
                    literals.append((True, EqWirePair(wire_init, wire_curr)))
            # skip the interface states, they are already handled.
        return [literals]

    def P(self) -> list[Cube]:
        """
        Returns the (default) safety property (ObEq) of the product machine. Assumed M-RDNF.

        Two cubes over copy 0:
          1. not observable: neither A's nor B's valid signal is set;
          2. observable: both valid signals are set AND every output of A equals
             the same-named output of B.
        Exactly one valid signal being set therefore violates the property.
        """
        ## below is just a default implementation, can change if needed (but must be an M-RDNF).
        a, b = self.circuit_a[0], self.circuit_b[0]
        literals_nob = []  # when not observable
        literals_nob.append((False, WireA(a.valid_signal_id, 0)))
        literals_nob.append((False, WireB(b.valid_signal_id, 0)))
        literals_ob = []  # when observable
        literals_ob.append((True, WireA(a.valid_signal_id, 0)))
        literals_ob.append((True, WireB(b.valid_signal_id, 0)))
        for name in a.output_names():
            wire_a = WireA(a.name_to_output_id[name], 0)
            wire_b = WireB(b.name_to_output_id[name], 0)
            literals_ob.append((True, EqWirePair(wire_a, wire_b)))
        return [literals_nob, literals_ob]

    def _join_states(self, c_curr: Btor2Circuit, c_next: Btor2Circuit):
        "(private method) Joins the next states of `c_curr` with the current states of `c_next`."
        for name in c_curr.state_names():
            # `==` on Boolector nodes builds an equality node, which is asserted.
            self.solver.Assert(c_curr.next_state_by_name(name) == c_next.curr_state_by_name(name))

    def grow(self, height_a: int, height_b: int):
        """
        Grows the product machine to the given heights (number of unrolled copies
        of A and B). Idempotent: does nothing if grown enough; copies are never removed.

        Each new copy `k > 0` is chained to copy `k - 1` via `_join_states`,
        which adds assertions to the solver.
        """
        while len(self.circuit_a) < height_a:
            index = len(self.circuit_a)
            suffix = f"{self.suffix_a}{index}"
            self.circuit_a.append(Btor2Circuit(self.solver, suffix, self.btor2_lines_a))
            if index > 0:
                self._join_states(self.circuit_a[index - 1], self.circuit_a[index])
        while len(self.circuit_b) < height_b:
            index = len(self.circuit_b)
            suffix = f"{self.suffix_b}{index}"
            self.circuit_b.append(Btor2Circuit(self.solver, suffix, self.btor2_lines_b))
            if index > 0:
                self._join_states(self.circuit_b[index - 1], self.circuit_b[index])

    def all_state_wires(self) -> list[Wire]:
        "Returns all state variables (as wires, un-primed) of both circuits: A's first, then B's."
        wires = []
        a, b = self.circuit_a[0], self.circuit_b[0]
        for name in a.state_names():
            wires.append(WireA(a.name_to_state_id[name], 0))
        for name in b.state_names():
            wires.append(WireB(b.name_to_state_id[name], 0))
        return wires

    def wire_to_node(self, wire: Wire, ex_prime_a: int, ex_prime_b: int) -> bt.BoolectorNode:
        """
        Converts a wire to a Boolector node, considering extra primes.

        The node is taken from the unrolled copy whose index is the wire's own
        prime count plus the extra primes of its circuit (`ex_prime_a` for A,
        `ex_prime_b` for B).

        Raises:
            IndexError: if that copy does not exist -- either the index is
                negative, or the machine has not been grown far enough
                (call `grow` first).
        """
        if wire.is_a:
            copies, depth = self.circuit_a, wire.prime + ex_prime_a
        else:
            copies, depth = self.circuit_b, wire.prime + ex_prime_b
        # Check explicitly: a negative depth would otherwise silently select a
        # copy counted from the END of the list (Python negative indexing).
        if not 0 <= depth < len(copies):
            raise IndexError(
                f"Wire {wire} needs unrolled copy #{depth}, but only "
                f"{len(copies)} copies exist (call `grow` to add more)."
            )
        return copies[depth].id_to_node[wire.index]

    def atom_to_node(self, atom: Atom, ex_prime_a: int, ex_prime_b: int) -> bt.BoolectorNode:
        """
        Converts an atom to a Boolector node, considering extra primes.

        Raises:
            ValueError: if `atom` is neither a `Wire` nor an `EqWirePair`.
        """
        if isinstance(atom, Wire):
            return self.wire_to_node(atom, ex_prime_a, ex_prime_b)
        elif isinstance(atom, EqWirePair):
            left = self.wire_to_node(atom.left, ex_prime_a, ex_prime_b)
            right = self.wire_to_node(atom.right, ex_prime_a, ex_prime_b)
            return self.solver.Eq(left, right)
        else:
            raise ValueError(f"Unknown atom: {atom}")

    def lit_to_node(self, lit: Literal, ex_prime_a: int, ex_prime_b: int) -> bt.BoolectorNode:
        "Converts a literal to a Boolector node (negated when its polarity is `False`), considering extra primes."
        pos, atom = lit
        node = self.atom_to_node(atom, ex_prime_a, ex_prime_b)
        return node if pos else self.solver.Not(node)

    def cube_to_node(self, cube: Cube, ex_prime_a: int, ex_prime_b: int) -> bt.BoolectorNode:
        "Converts a cube to a Boolector node (conjunction; an empty cube is True), considering extra primes."
        node = self.solver.Const(True)
        for lit in cube:
            node = self.solver.And(node, self.lit_to_node(lit, ex_prime_a, ex_prime_b))
        return node

    def clause_to_node(self, clause: Clause, ex_prime_a: int, ex_prime_b: int) -> bt.BoolectorNode:
        "Converts a clause to a Boolector node (disjunction; an empty clause is False), considering extra primes."
        node = self.solver.Const(False)
        for lit in clause:
            node = self.solver.Or(node, self.lit_to_node(lit, ex_prime_a, ex_prime_b))
        return node

    def dnf_to_node(self, dnf: list[Cube], ex_prime_a: int, ex_prime_b: int) -> bt.BoolectorNode:
        "Converts a DNF (disjunction of cubes; an empty DNF is False) to a Boolector node, considering extra primes."
        node = self.solver.Const(False)
        for cube in dnf:
            node = self.solver.Or(node, self.cube_to_node(cube, ex_prime_a, ex_prime_b))
        return node

    def get_all_wires(self, primes_a: list[int], primes_b: list[int]) -> list[Wire]:
        """
        Returns the wires of every BTOR2 index known to copy 0, once per requested
        prime: all A-wires for each prime in `primes_a`, then all B-wires for each
        prime in `primes_b`.
        """
        wires = []
        for prime in primes_a:
            wires.extend([WireA(index, prime) for index in self.circuit_a[0].id_to_node])
        for prime in primes_b:
            wires.extend([WireB(index, prime) for index in self.circuit_b[0].id_to_node])
        return wires