from typing import Union
from btor2nodenameprinter import Btor2NodeNamePrinter

Node = tuple[int, int]
"Identifies an internal node in the circuit: (copy_id, local_id), or (0, 0) for ObEq."
Literal = Union[tuple[bool, Node], tuple[bool, Node, Node]]
"Represents a literal: (pos_or_neg, node) or (eq_or_neq, node_low, node_high)"
Clause = tuple[Literal, ...]
"Represents a clause: a variable-length tuple of literals, kept sorted (see `make_clause`)."

# `copy_id` is 1, 2, 3, ... for circuit A, and -1, -2, -3, ... for circuit B.
# `copy_id` is 0 for the ObEq signal.
OB_EQ_NODE: Node = (0, 0)

def make_node_a(copy_index: int, local_id: int) -> Node:
    """Returns the identifier of node `local_id` in copy `copy_index` of circuit A.

    Circuit A uses positive copy ids, so copy index 0 maps to copy id 1.
    """
    copy_id = copy_index + 1
    return copy_id, local_id

def make_node_b(copy_index: int, local_id: int) -> Node:
    """Returns the identifier of node `local_id` in copy `copy_index` of circuit B.

    Circuit B uses negative copy ids, so copy index 0 maps to copy id -1.
    """
    copy_id = -(copy_index + 1)
    return copy_id, local_id

def make_p_lit(node: Node, pos: bool) -> Literal:
    """Returns the predicate literal `node` (pos=True) or `~node` (pos=False)."""
    literal = (pos, node)
    return literal

def make_e_lit(node1: Node, node2: Node, eq: bool) -> Literal:
    """Returns the literal `node1 == node2` (eq=True) or `node1 != node2` (eq=False).

    The two nodes are stored in ascending order, so both argument orders yield
    the same canonical literal.
    """
    low, high = (node1, node2) if node1 < node2 else (node2, node1)
    return eq, low, high

def make_clause(literals: list[Literal]) -> Clause:
    """Returns the canonical clause for `literals`: a tuple sorted in ascending order.

    Any iterable of literals is accepted; it is consumed exactly once.
    """
    ordered = list(literals)
    ordered.sort()
    return tuple(ordered)


def neg_of_lit(lit: Literal) -> Literal:
    """Returns `lit` with its polarity flag (pos/neg or eq/neq) flipped; the nodes are unchanged."""
    polarity, *nodes = lit
    return (not polarity, *nodes)

def neg_of_clause(clause: Clause) -> list[Literal]:
    """Returns the negation of a clause: the cube (list) of its negated literals, in the same order."""
    return list(map(neg_of_lit, clause))

def neg_of_cube(cube: list[Literal]) -> Clause:
    """Returns the negation of a cube as a canonical (sorted) clause of its negated literals."""
    negated = [neg_of_lit(lit) for lit in cube]
    return make_clause(negated)


def is_strict_subclause(clause_1: Clause, clause_2: Clause) -> bool:
    """Returns whether `clause_1` is a strict subclause of `clause_2`.

    `clause_1` must be strictly shorter and appear in `clause_2` as an ordered
    subsequence; for clauses sorted as by `make_clause` this is a subset test.
    """
    if len(clause_1) < len(clause_2):
        remaining = iter(clause_2)
        # `in` on an iterator consumes it up to and including the first match,
        # so later literals of `clause_1` must occur later in `clause_2`.
        return all(lit in remaining for lit in clause_1)
    return False


# Experimental optimization; so far it does not seem to help.
def expand_clause(clause: Clause) -> Clause:
    """Expands a clause by adding every (in)equality literal implied by its own literals.

    A clause is the negation of a cube, so the clause's `!=` literals are the
    cube's equalities and the clause's `==` literals are the cube's disequalities.
    The cube's equalities are closed transitively into equivalence classes
    (union-find). Then:
      * every pair of nodes in one class yields a `!=` literal in the clause;
      * every cube disequality `x != y` yields `==` literals between all members
        of the classes of `x` and `y`.
    Predicate literals are kept unchanged. Added literals are implied by the
    cube, so the result is equivalent to `clause` under the theory of equality.
    (Intended to help SAT solvers; the benefit is unconfirmed.)
    """
    from itertools import combinations, product

    parent: dict[Node, Node] = {}  # union-find forest; a root is its class's smallest node

    def find(node: Node) -> Node:
        # Returns the root of `node`'s class, registering `node` on first sight.
        parent.setdefault(node, node)
        while parent[node] != node:
            parent[node] = parent[parent[node]]  # path halving
            node = parent[node]
        return node

    literals: list[Literal] = []
    neq_pairs: list[tuple[Node, Node]] = []  # cube disequalities (`==` in the clause)
    for lit in clause:
        if len(lit) == 2:  # predicate literal: keep it
            literals.append(lit)
            continue
        eq, node_low, node_high = lit  # in a clause, eq/neq is flipped w.r.t. the cube
        root_low, root_high = find(node_low), find(node_high)
        if eq:  # `==` in clause; corresponding to `!=` in cube
            neq_pairs.append((node_low, node_high))
        elif root_low != root_high:  # `!=` in clause; `==` in cube: merge the classes
            # Attach the larger root under the smaller, so roots stay class minima.
            parent[max(root_low, root_high)] = min(root_low, root_high)

    groups: dict[Node, list[Node]] = {}
    for node in list(parent):
        groups.setdefault(find(node), []).append(node)

    # A set removes duplicates that overlapping disequalities would otherwise emit.
    implied: set[Literal] = set()
    for group in groups.values():
        for node_1, node_2 in combinations(group, 2):
            implied.add(make_e_lit(node_1, node_2, False))
    for node_low, node_high in neq_pairs:
        for node_1, node_2 in product(groups[find(node_low)], groups[find(node_high)]):
            implied.add(make_e_lit(node_1, node_2, True))
    return make_clause(literals + list(implied))


LIT_OB_EQ = (True, OB_EQ_NODE)
"The positive ObEq literal (ObEq holds)."
LIT_NEG_OB_EQ = (not LIT_OB_EQ[0], *LIT_OB_EQ[1:])
"The negated ObEq literal (ObEq does not hold)."


class Printer:
    """
    For printing nodes, literals, clauses and invariants.

    Nodes with a positive copy id are rendered by `printer_a`, nodes with a
    negative copy id by `printer_b`, and the ObEq node (copy id 0) as "ObEq".
    """
    printer_a: Btor2NodeNamePrinter
    "The printer for circuit A."
    printer_b: Btor2NodeNamePrinter
    "The printer for circuit B."

    # Backslashes are escaped explicitly: `"\/"` and `"/\ "` are invalid escape
    # sequences (DeprecationWarning, SyntaxWarning since Python 3.12).
    _OR = " \\/ "  # separator between the literals of a clause, prints ` \/ `
    _AND = " /\\ "  # separator between the literals of a cube, prints ` /\ `

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str]):
        self.printer_a = Btor2NodeNamePrinter(btor2_lines_a)
        self.printer_b = Btor2NodeNamePrinter(btor2_lines_b)

    def node(self, node: Node) -> str:
        "Returns the string representation of a node."
        copy_id, local_id = node
        if copy_id > 0:  # circuit A; copy index is copy_id - 1
            return self.printer_a.get_str(local_id, f"@A{copy_id - 1}")
        if copy_id < 0:  # circuit B; copy index is -copy_id - 1
            return self.printer_b.get_str(local_id, f"@B{-copy_id - 1}")
        return "ObEq"

    def literal(self, literal: Literal) -> str:
        "Returns the string representation of a literal (`x`, `~x`, `x == y` or `x != y`)."
        if len(literal) == 2:
            pos, node = literal
            return f"{'' if pos else '~'}{self.node(node)}"
        eq, node1, node2 = literal
        return f"{self.node(node1)} {'==' if eq else '!='} {self.node(node2)}"

    def clause(self, clause: Clause) -> str:
        "Returns the string representation of a clause (literals joined by ` \\/ `)."
        return self._OR.join(self.literal(lit) for lit in clause)

    def invariant(self, clauses: list[Clause]) -> list[str]:
        r"""Returns the string representation of an invariant (negated clauses).

        Each kept clause is negated into a cube and printed as one disjunct line
        `\/ (l1 /\ l2 /\ ...)`. A clause that has a strict subclause in `clauses`
        is skipped: the negation of that subclause already covers it. Repeated
        copies of an already printed clause are skipped as well.
        """
        lines = []
        printed: list[Clause] = []
        for clause in clauses:
            if clause in printed:
                continue  # exact duplicate of a clause already printed
            if any(is_strict_subclause(other_clause, clause) for other_clause in clauses):
                continue  # skip clauses subsumed by a strict subclause
            printed.append(clause)
            cube = neg_of_clause(clause)
            lines.append(f"\\/ ({self._AND.join(self.literal(lit) for lit in cube)})")
        return lines
