# Natural Language Toolkit: Discourse Representation Theory (DRT)
#
# Author: Dan Garrette <dhgarrette@gmail.com>
#
# Copyright (C) 2001-2023 NLTK Project
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT

import operator
from functools import reduce
from itertools import chain

from nltk.sem.logic import (
    APP,
    AbstractVariableExpression,
    AllExpression,
    AndExpression,
    ApplicationExpression,
    BinaryExpression,
    BooleanExpression,
    ConstantExpression,
    EqualityExpression,
    EventVariableExpression,
    ExistsExpression,
    Expression,
    FunctionVariableExpression,
    ImpExpression,
    IndividualVariableExpression,
    LambdaExpression,
    LogicParser,
    NegatedExpression,
    OrExpression,
    Tokens,
    Variable,
    is_eventvar,
    is_funcvar,
    is_indvar,
    unique_variable,
)

# Import Tkinter-based modules if they are available
try:
    from tkinter import Canvas, Tk
    from tkinter.font import Font

    from nltk.util import in_idle

except ImportError:
    # No need to print a warning here, nltk.draw has already printed one.
    pass


class DrtTokens(Tokens):
    """Token vocabulary for DRT: everything in ``Tokens`` plus the DRS syntax."""

    # Keyword that introduces a DRS, and the name of the pronoun predicate
    # that ``DrtExpression.is_pronoun_function`` looks for.
    DRS = "DRS"
    PRONOUN = "PRO"

    # Punctuation that only exists in the DRT syntax.
    DRS_CONC = "+"
    OPEN_BRACKET = "["
    CLOSE_BRACKET = "]"
    COLON = ":"

    PUNCT = [DRS_CONC, OPEN_BRACKET, CLOSE_BRACKET, COLON]

    # The inherited lists, extended with the DRT additions.
    SYMBOLS = [*Tokens.SYMBOLS, *PUNCT]

    TOKENS = [*Tokens.TOKENS, DRS, *PUNCT]


class DrtParser(LogicParser):
    """A parser for Discourse Representation Theory (DRT) expressions.

    Extends ``LogicParser`` with the DRS syntax (``DRS([refs],[conds])`` or
    the bare ``([refs],[conds])`` form), DRS concatenation (``+``) and
    labelled propositions (``label:expression``), and builds ``Drt*``
    expression objects through the ``make_*`` hooks.
    """

    def __init__(self):
        LogicParser.__init__(self)

        # Operator groups in rank order: every token of the n-th group gets
        # rank n (starting at 1).  Later groups win if a token is repeated.
        precedence_groups = [
            DrtTokens.LAMBDA_LIST,
            DrtTokens.NOT_LIST,
            [APP],
            DrtTokens.EQ_LIST + DrtTokens.NEQ_LIST,
            [DrtTokens.COLON],
            [DrtTokens.DRS_CONC],
            DrtTokens.OR_LIST,
            DrtTokens.IMP_LIST,
            [None],
        ]
        self.operator_precedence = {
            tok: rank
            for rank, group in enumerate(precedence_groups, start=1)
            for tok in group
        }

    def get_all_symbols(self):
        """Return the list of symbols recognised by this parser."""
        return DrtTokens.SYMBOLS

    def isvariable(self, tok):
        """Return True when ``tok`` is not one of the reserved DRT tokens."""
        return tok not in DrtTokens.TOKENS

    def handle(self, tok, context):
        """Dispatch ``tok`` to the handler for the construct it starts.

        :param tok: the token that has just been consumed
        :param context: the parsing context, passed through to the handlers
        :return: the parsed expression, or None when no handler applies
        """
        if tok in DrtTokens.NOT_LIST:
            return self.handle_negation(tok, context)

        if tok in DrtTokens.LAMBDA_LIST:
            return self.handle_lambda(tok, context)

        if tok == DrtTokens.OPEN:
            # "([" starts a DRS written without the leading "DRS" keyword.
            if self.inRange(0) and self.token(0) == DrtTokens.OPEN_BRACKET:
                return self.handle_DRS(tok, context)
            return self.handle_open(tok, context)

        if tok.upper() == DrtTokens.DRS:
            # The keyword is matched case-insensitively and must be
            # followed by an opening parenthesis.
            self.assertNextToken(DrtTokens.OPEN)
            return self.handle_DRS(tok, context)

        if self.isvariable(tok):
            # "name:" introduces a labelled proposition.
            if self.inRange(0) and self.token(0) == DrtTokens.COLON:
                return self.handle_prop(tok, context)
            return self.handle_variable(tok, context)

        return None

    def make_NegatedExpression(self, expression):
        """Build the DRT negation of ``expression``."""
        return DrtNegatedExpression(expression)

    def handle_DRS(self, tok, context):
        """Parse ``[refs],[conds])`` (the comma is optional) into a ``DRS``.

        The opening parenthesis has already been consumed by the caller.
        """
        refs = self.handle_refs()
        if (
            self.inRange(0) and self.token(0) == DrtTokens.COMMA
        ):  # if there is a comma (it's optional)
            self.token()  # swallow the comma
        conds = self.handle_conds(context)
        self.assertNextToken(DrtTokens.CLOSE)
        return DRS(refs, conds, None)

    def handle_refs(self):
        """Parse the bracketed referent list and return the variables."""
        self.assertNextToken(DrtTokens.OPEN_BRACKET)
        refs = []
        while self.inRange(0) and self.token(0) != DrtTokens.CLOSE_BRACKET:
            # Support expressions like: DRS([x y],C) == DRS([x,y],C)
            if refs and self.token(0) == DrtTokens.COMMA:
                self.token()  # swallow the comma
            refs.append(self.get_next_token_variable("quantified"))
        self.assertNextToken(DrtTokens.CLOSE_BRACKET)
        return refs

    def handle_conds(self, context):
        """Parse the bracketed condition list and return the expressions."""
        self.assertNextToken(DrtTokens.OPEN_BRACKET)
        conds = []
        while self.inRange(0) and self.token(0) != DrtTokens.CLOSE_BRACKET:
            # Separating commas between conditions are optional too.
            if conds and self.token(0) == DrtTokens.COMMA:
                self.token()  # swallow the comma
            conds.append(self.process_next_expression(context))
        self.assertNextToken(DrtTokens.CLOSE_BRACKET)
        return conds

    def handle_prop(self, tok, context):
        """Parse ``tok:expression`` into a ``DrtProposition`` labelled ``tok``."""
        variable = self.make_VariableExpression(tok)
        self.assertNextToken(DrtTokens.COLON)
        # The colon is passed as the context of the labelled expression.
        drs = self.process_next_expression(DrtTokens.COLON)
        return DrtProposition(variable, drs)

    def make_EqualityExpression(self, first, second):
        """This method serves as a hook for other logic parsers that
        have different equality expression classes"""
        return DrtEqualityExpression(first, second)

    def get_BooleanExpression_factory(self, tok):
        """Return a two-argument factory for the boolean operator ``tok``.

        :return: a callable ``(first, second) -> expression`` for DRS
            concatenation, disjunction and implication, or None for any
            other token
        """
        if tok == DrtTokens.DRS_CONC:
            return lambda first, second: DrtConcatenation(first, second, None)
        elif tok in DrtTokens.OR_LIST:
            return DrtOrExpression
        elif tok in DrtTokens.IMP_LIST:

            def make_imp_expression(first, second):
                # An implication is stored as the consequent of its
                # antecedent, which therefore has to be a DRS or a
                # concatenation of DRSs.
                if isinstance(first, DRS):
                    return DRS(first.refs, first.conds, second)
                if isinstance(first, DrtConcatenation):
                    return DrtConcatenation(first.first, first.second, second)
                raise Exception("Antecedent of implication must be a DRS")

            return make_imp_expression
        else:
            return None

    def make_BooleanExpression(self, factory, first, second):
        """Apply ``factory`` to the two operands."""
        return factory(first, second)

    def make_ApplicationExpression(self, function, argument):
        """Build the DRT application of ``function`` to ``argument``."""
        return DrtApplicationExpression(function, argument)

    def make_VariableExpression(self, name):
        """Build the DRT variable expression for the variable called ``name``."""
        return DrtVariableExpression(Variable(name))

    def make_LambdaExpression(self, variables, term):
        """Build the DRT lambda abstraction of ``term`` over ``variables``."""
        return DrtLambdaExpression(variables, term)


class DrtExpression:
    """
    Abstract base shared by every DRT expression class.

    It supplies the operator overloads, the anaphora and equality helpers
    and the pretty-printing entry points.  Structure-specific methods such
    as ``get_refs`` are left to the subclasses.
    """

    _drt_parser = DrtParser()

    @classmethod
    def fromstring(cls, s):
        """Parse the string ``s`` with the shared ``DrtParser`` instance."""
        parser = cls._drt_parser
        return parser.parse(s)

    def applyto(self, other):
        """Return the application of ``self`` to ``other``."""
        return DrtApplicationExpression(self, other)

    def __neg__(self):
        """``-expr`` builds the DRT negation of ``expr``."""
        return DrtNegatedExpression(self)

    def __and__(self, other):
        """``&`` is not supported between DRT expressions."""
        return NotImplemented

    def __or__(self, other):
        """``a | b`` builds the DRT disjunction of ``a`` and ``b``."""
        assert isinstance(other, DrtExpression)
        return DrtOrExpression(self, other)

    def __gt__(self, other):
        """``a > b`` builds the implication with antecedent ``a``.

        The implication is represented by attaching ``other`` as the
        consequent of a copy of ``self``, so ``self`` has to be a ``DRS``
        or a ``DrtConcatenation``; anything else raises ``Exception``.
        """
        assert isinstance(other, DrtExpression)
        if isinstance(self, DRS):
            implication = DRS(self.refs, self.conds, other)
        elif isinstance(self, DrtConcatenation):
            implication = DrtConcatenation(self.first, self.second, other)
        else:
            raise Exception("Antecedent of implication must be a DRS")
        return implication

    def equiv(self, other, prover=None):
        """
        Check for logical equivalence.

        Both expressions are simplified and converted to first order logic;
        the equivalence check itself is delegated to the ``equiv`` method of
        the resulting formula.

        :param other: a ``DrtExpression`` to check equality against
        :param prover: a ``nltk.inference.api.Prover``
        """
        assert isinstance(other, DrtExpression)

        own_fol, other_fol = self.simplify().fol(), other.simplify().fol()
        return own_fol.equiv(other_fol, prover)

    @property
    def type(self):
        """Always raises ``AttributeError``: this class exposes no ``type``."""
        class_name = self.__class__.__name__
        raise AttributeError("'%s' object has no attribute 'type'" % class_name)

    def typecheck(self, signature=None):
        """Not implemented in this base class."""
        raise NotImplementedError()

    def __add__(self, other):
        """``a + b`` builds the DRS concatenation of ``a`` and ``b``."""
        return DrtConcatenation(self, other, None)

    def get_refs(self, recursive=False):
        """
        Return the discourse referents of this expression.

        :param recursive: bool Also find discourse referents in subterms?
        :return: list of ``Variable`` objects
        """
        raise NotImplementedError()

    def is_pronoun_function(self):
        """Is self of the form "PRO(x)"?"""
        if not isinstance(self, DrtApplicationExpression):
            return False
        function = self.function
        if not isinstance(function, DrtAbstractVariableExpression):
            return False
        return function.variable.name == DrtTokens.PRONOUN and isinstance(
            self.argument, DrtIndividualVariableExpression
        )

    def make_EqualityExpression(self, first, second):
        """Build the DRT equality ``first = second``."""
        return DrtEqualityExpression(first, second)

    def make_VariableExpression(self, variable):
        """Build the DRT variable expression wrapping ``variable``."""
        return DrtVariableExpression(variable)

    def resolve_anaphora(self):
        """Delegate to the module-level ``resolve_anaphora`` function."""
        return resolve_anaphora(self)

    def eliminate_equality(self):
        """Rebuild this expression with ``eliminate_equality`` applied to
        each sub-expression handed out by ``visit_structured``."""

        def recurse(subexpression):
            return subexpression.eliminate_equality()

        return self.visit_structured(recurse, self.__class__)

    def pretty_format(self):
        """
        Draw the DRS
        :return: the pretty print string
        """
        lines = self._pretty()
        return "\n".join(lines)

    def pretty_print(self):
        """Print the drawing produced by ``pretty_format``."""
        drawing = self.pretty_format()
        print(drawing)

    def draw(self):
        """Draw this expression with a ``DrsDrawer``."""
        drawer = DrsDrawer(self)
        drawer.draw()


class DRS(DrtExpression, Expression):
    """A Discourse Representation Structure.

    A DRS holds a list of discourse referents (``refs``), a list of
    conditions (``conds``) and an optional ``consequent``.  When a
    consequent is present the DRS is rendered as the implication
    ``(DRS -> consequent)``.
    """

    def __init__(self, refs, conds, consequent=None):
        """
        :param refs: list of ``Variable`` for the discourse referents (the
            methods below store ``expression.variable`` and fresh variables
            in this list)
        :param conds: list of ``Expression`` for the conditions
        :param consequent: optional expression implied by this DRS
        """
        self.refs = refs
        self.conds = conds
        self.consequent = consequent

    def replace(self, variable, expression, replace_bound=False, alpha_convert=True):
        """Replace all instances of variable v with expression E in self,
        where v is free in self.

        :param variable: the ``Variable`` to replace
        :param expression: the replacement expression
        :param replace_bound: if True, a referent of this DRS equal to
            ``variable`` is itself renamed to ``expression.variable``
        :param alpha_convert: if True, referents that occur free in
            ``expression`` are renamed first to avoid capture
        :return: a new ``DRS``, or ``self`` when ``variable`` is bound here
            and ``replace_bound`` is False
        """
        if variable in self.refs:
            # if a bound variable is the thing being replaced
            if not replace_bound:
                return self
            i = self.refs.index(variable)
            if self.consequent:
                consequent = self.consequent.replace(
                    variable, expression, True, alpha_convert
                )
            else:
                consequent = None
            return DRS(
                self.refs[:i] + [expression.variable] + self.refs[i + 1 :],
                [
                    cond.replace(variable, expression, True, alpha_convert)
                    for cond in self.conds
                ],
                consequent,
            )

        drs = self
        if alpha_convert:
            # any bound variable that appears in the expression must
            # be alpha converted to avoid a conflict
            for ref in set(self.refs) & expression.free():
                newvar = unique_variable(ref)
                newvarex = DrtVariableExpression(newvar)
                i = drs.refs.index(ref)
                if drs.consequent:
                    consequent = drs.consequent.replace(
                        ref, newvarex, True, alpha_convert
                    )
                else:
                    consequent = None
                drs = DRS(
                    drs.refs[:i] + [newvar] + drs.refs[i + 1 :],
                    [
                        cond.replace(ref, newvarex, True, alpha_convert)
                        for cond in drs.conds
                    ],
                    consequent,
                )

        # replace in the conditions and in the consequent
        if drs.consequent:
            consequent = drs.consequent.replace(
                variable, expression, replace_bound, alpha_convert
            )
        else:
            consequent = None
        return DRS(
            drs.refs,
            [
                cond.replace(variable, expression, replace_bound, alpha_convert)
                for cond in drs.conds
            ],
            consequent,
        )

    def free(self):
        """:see: Expression.free()"""
        conds_free = reduce(operator.or_, [c.free() for c in self.conds], set())
        if self.consequent:
            conds_free.update(self.consequent.free())
        return conds_free - set(self.refs)

    def get_refs(self, recursive=False):
        """:see: AbstractExpression.get_refs()"""
        if recursive:
            conds_refs = self.refs + list(
                chain.from_iterable(c.get_refs(True) for c in self.conds)
            )
            if self.consequent:
                conds_refs.extend(self.consequent.get_refs(True))
            return conds_refs
        else:
            return self.refs

    def visit(self, function, combinator):
        """:see: Expression.visit()"""
        parts = list(map(function, self.conds))
        if self.consequent:
            parts.append(function(self.consequent))
        return combinator(parts)

    def visit_structured(self, function, combinator):
        """:see: Expression.visit_structured()"""
        consequent = function(self.consequent) if self.consequent else None
        return combinator(self.refs, list(map(function, self.conds)), consequent)

    def eliminate_equality(self):
        """Return a DRS in which variable equalities have been eliminated.

        Every condition ``a = b`` between two variable expressions is
        dropped, ``b`` is removed from the referents and, when the two
        variables differ, substituted by ``a`` in the rest of the DRS.
        Equalities are then eliminated recursively in the remaining
        conditions and in the consequent; conditions that simplify to an
        empty DRS are discarded.
        """
        drs = self
        i = 0
        while i < len(drs.conds):
            cond = drs.conds[i]
            if (
                isinstance(cond, EqualityExpression)
                and isinstance(cond.first, AbstractVariableExpression)
                and isinstance(cond.second, AbstractVariableExpression)
            ):
                dropped = cond.second.variable
                # Remove the referent while keeping the others in their
                # original order (duplicates collapsed): referent order is
                # observable through fol() and the positional __eq__.
                drs = DRS(
                    [ref for ref in dict.fromkeys(drs.refs) if ref != dropped],
                    drs.conds[:i] + drs.conds[i + 1 :],
                    drs.consequent,
                )
                if dropped != cond.first.variable:
                    drs = drs.replace(dropped, cond.first, False, False)
                    # The substitution may have produced new equalities
                    # among earlier conditions: rescan from the start.
                    i = 0
                # Otherwise stay at index i, which now holds the condition
                # that followed the removed one.
                continue
            i += 1

        conds = []
        for cond in drs.conds:
            new_cond = cond.eliminate_equality()
            new_cond_simp = new_cond.simplify()
            # Keep the condition unless it simplifies to a completely
            # empty DRS.
            if (
                not isinstance(new_cond_simp, DRS)
                or new_cond_simp.refs
                or new_cond_simp.conds
                or new_cond_simp.consequent
            ):
                conds.append(new_cond)

        consequent = drs.consequent.eliminate_equality() if drs.consequent else None
        return DRS(drs.refs, conds, consequent)

    def fol(self):
        """Convert this DRS to a first order logic expression.

        With a consequent the referents are universally quantified over
        ``conds -> consequent`` (or over the consequent alone when there
        are no conditions).  Without one they are existentially quantified
        over the conjunction of the conditions.

        :raise Exception: if there is neither a consequent nor a condition
        """
        if self.consequent:
            accum = None
            if self.conds:
                accum = reduce(AndExpression, [c.fol() for c in self.conds])

            if accum:
                accum = ImpExpression(accum, self.consequent.fol())
            else:
                accum = self.consequent.fol()

            # Wrap from the last referent outwards so that the first
            # referent ends up as the outermost quantifier.
            for ref in self.refs[::-1]:
                accum = AllExpression(ref, accum)

            return accum

        else:
            if not self.conds:
                raise Exception("Cannot convert DRS with no conditions to FOL.")
            accum = reduce(AndExpression, [c.fol() for c in self.conds])
            for ref in map(Variable, self._order_ref_strings(self.refs)[::-1]):
                accum = ExistsExpression(ref, accum)
            return accum

    def _pretty(self):
        """Return the lines of the box drawing of this DRS."""
        refs_line = " ".join(self._order_ref_strings(self.refs))

        # Flatten the drawings of all conditions, dropping blank lines.
        cond_lines = [
            cond
            for cond_line in [
                filter(lambda s: s.strip(), cond._pretty()) for cond in self.conds
            ]
            for cond in cond_line
        ]
        length = max([len(refs_line)] + list(map(len, cond_lines)))
        drs = (
            [
                " _" + "_" * length + "_ ",
                "| " + refs_line.ljust(length) + " |",
                "|-" + "-" * length + "-|",
            ]
            + ["| " + line.ljust(length) + " |" for line in cond_lines]
            + ["|_" + "_" * length + "_|"]
        )
        if self.consequent:
            return DrtBinaryExpression._assemble_pretty(
                drs, DrtTokens.IMP, self.consequent._pretty()
            )
        return drs

    def _order_ref_strings(self, refs):
        """Return the string forms of ``refs`` in display order.

        The result lists the names that match none of the variable
        predicates first, then event, function and individual variables,
        each group sorted on its own key.
        """
        strings = ["%s" % ref for ref in refs]
        ind_vars = []
        func_vars = []
        event_vars = []
        other_vars = []
        for s in strings:
            if is_indvar(s):
                ind_vars.append(s)
            elif is_funcvar(s):
                func_vars.append(s)
            elif is_eventvar(s):
                event_vars.append(s)
            else:
                other_vars.append(s)
        # The keys below sort on the trailing number of the name, using -1
        # when there is none.
        # NOTE(review): the event-variable key skips two leading characters
        # where the others skip one -- confirm against is_eventvar.
        return (
            sorted(other_vars)
            + sorted(event_vars, key=lambda v: int([v[2:], -1][len(v[2:]) == 0]))
            + sorted(func_vars, key=lambda v: (v[0], int([v[1:], -1][len(v[1:]) == 0])))
            + sorted(ind_vars, key=lambda v: (v[0], int([v[1:], -1][len(v[1:]) == 0])))
        )

    def __eq__(self, other):
        """Defines equality modulo alphabetic variance.

        The referents of ``other`` are renamed, position by position, to
        the referents of ``self``; the two DRSs are equal when their
        consequents and their conditions then match pairwise.
        """
        if isinstance(other, DRS):
            if len(self.refs) == len(other.refs):
                converted_other = other
                for (r1, r2) in zip(self.refs, converted_other.refs):
                    varex = self.make_VariableExpression(r1)
                    converted_other = converted_other.replace(r2, varex, True)
                if self.consequent == converted_other.consequent and len(
                    self.conds
                ) == len(converted_other.conds):
                    for c1, c2 in zip(self.conds, converted_other.conds):
                        if not (c1 == c2):
                            return False
                    return True
        return False

    def __ne__(self, other):
        return not self == other

    __hash__ = Expression.__hash__

    def __str__(self):
        """Return the ``([refs],[conds])`` form, wrapped as an implication
        when there is a consequent."""
        drs = "([{}],[{}])".format(
            ",".join(self._order_ref_strings(self.refs)),
            ", ".join("%s" % cond for cond in self.conds),
        )
        if self.consequent:
            return (
                DrtTokens.OPEN
                + drs
                + " "
                + DrtTokens.IMP
                + " "
                + "%s" % self.consequent
                + DrtTokens.CLOSE
            )
        return drs


def DrtVariableExpression(variable):
    """
    Factory returning the ``DrtAbstractVariableExpression`` subtype that
    matches the name of ``variable``: individual, function or event
    variable (tested in that order), and a constant otherwise.
    """
    dispatch = (
        (is_indvar, DrtIndividualVariableExpression),
        (is_funcvar, DrtFunctionVariableExpression),
        (is_eventvar, DrtEventVariableExpression),
    )
    for name_matches, expression_class in dispatch:
        if name_matches(variable.name):
            return expression_class(variable)
    return DrtConstantExpression(variable)


class DrtAbstractVariableExpression(DrtExpression, AbstractVariableExpression):
    """Common behaviour of the DRT variable and constant expressions."""

    def fol(self):
        """A variable is its own first order logic form."""
        return self

    def get_refs(self, recursive=False):
        """:see: AbstractExpression.get_refs()"""
        return []

    def _pretty(self):
        """Return four lines with the name on the third one; the other
        lines are blanks of the same width."""
        text = str(self)
        padding = " " * len(text)
        return [padding, padding, text, padding]

    def eliminate_equality(self):
        """Nothing to eliminate in a bare variable."""
        return self


class DrtIndividualVariableExpression(
    DrtAbstractVariableExpression, IndividualVariableExpression
):
    """DRT individual variable; adds nothing beyond its two base classes."""


class DrtFunctionVariableExpression(
    DrtAbstractVariableExpression, FunctionVariableExpression
):
    """DRT function variable; adds nothing beyond its two base classes."""


class DrtEventVariableExpression(
    DrtIndividualVariableExpression, EventVariableExpression
):
    """DRT event variable; adds nothing beyond its two base classes."""


class DrtConstantExpression(DrtAbstractVariableExpression, ConstantExpression):
    """DRT constant; adds nothing beyond its two base classes."""


class DrtProposition(DrtExpression, Expression):
    """A labelled expression, printed as ``prop(label, drs)``."""

    def __init__(self, variable, drs):
        """
        :param variable: the label of the proposition
        :param drs: the expression being labelled
        """
        self.variable = variable
        self.drs = drs

    def replace(self, variable, expression, replace_bound=False, alpha_convert=True):
        """Replace ``variable`` by ``expression`` in the labelled expression.

        When the label itself equals ``variable`` it is replaced too, which
        requires ``expression`` to be a variable expression.
        """
        if self.variable == variable:
            assert isinstance(
                expression, DrtAbstractVariableExpression
            ), "Can only replace a proposition label with a variable"
            label = expression.variable
        else:
            label = self.variable
        return DrtProposition(
            label,
            self.drs.replace(variable, expression, replace_bound, alpha_convert),
        )

    def eliminate_equality(self):
        """Eliminate equalities inside the labelled expression."""
        reduced = self.drs.eliminate_equality()
        return DrtProposition(self.variable, reduced)

    def get_refs(self, recursive=False):
        """Referents of the labelled expression, reported only when
        ``recursive`` is True."""
        if recursive:
            return self.drs.get_refs(True)
        return []

    def __eq__(self, other):
        """Equal when class, label and labelled expression all match."""
        same_class = self.__class__ == other.__class__
        return same_class and self.variable == other.variable and self.drs == other.drs

    def __ne__(self, other):
        return not self == other

    __hash__ = Expression.__hash__

    def fol(self):
        """The first order form is that of the labelled expression; the
        label is dropped."""
        return self.drs.fol()

    def _pretty(self):
        """Draw the labelled expression with ``label:`` in front of its
        second line; the other lines are indented by the same width."""
        label = "%s" % self.variable
        indent = " " * len(label) + " "
        lines = []
        for row, line in enumerate(self.drs._pretty()):
            prefix = label + ":" if row == 1 else indent
            lines.append(prefix + line)
        return lines

    def visit(self, function, combinator):
        """:see: Expression.visit()"""
        visited = function(self.drs)
        return combinator([visited])

    def visit_structured(self, function, combinator):
        """:see: Expression.visit_structured()"""
        visited = function(self.drs)
        return combinator(self.variable, visited)

    def __str__(self):
        return f"prop({self.variable}, {self.drs})"


class DrtNegatedExpression(DrtExpression, NegatedExpression):
    """The negation of a DRT expression."""

    def fol(self):
        """Negate the first order form of the negated term."""
        inner = self.term.fol()
        return NegatedExpression(inner)

    def get_refs(self, recursive=False):
        """:see: AbstractExpression.get_refs()"""
        return self.term.get_refs(recursive)

    def _pretty(self):
        """Draw the term with a negation hook in front of its third and
        fourth lines and a blank gutter of the same width elsewhere."""
        gutters = {2: "__  ", 3: "  | "}
        return [
            gutters.get(row, "    ") + line
            for row, line in enumerate(self.term._pretty())
        ]


class DrtLambdaExpression(DrtExpression, LambdaExpression):
    """A lambda abstraction over a DRT expression."""

    def alpha_convert(self, newvar):
        """Rename all occurrences of the variable introduced by this variable
        binder in the expression to ``newvar``.
        :param newvar: ``Variable``, for the new variable
        """
        renamed_term = self.term.replace(
            self.variable, DrtVariableExpression(newvar), True
        )
        return self.__class__(newvar, renamed_term)

    def fol(self):
        """Abstract the same variable over the first order form of the term."""
        body = self.term.fol()
        return LambdaExpression(self.variable, body)

    def _pretty(self):
        """Draw the body with a lambda sign and the bound variables in
        front of it; directly nested lambdas of the same class are merged
        into one variable list."""
        bound = [self.variable]
        body = self.term
        while body.__class__ == self.__class__:
            bound.append(body.variable)
            body = body.term
        header = " ".join("%s" % v for v in bound) + DrtTokens.DOT
        indent = " " * len(header)
        lines = []
        for row, line in enumerate(body._pretty()):
            if row == 1:
                lines.append(r" \  " + indent + line)
            elif row == 2:
                # The variable list sits on the third line.
                lines.append(r" /\ " + header + line)
            else:
                lines.append("    " + indent + line)
        return lines

    def get_refs(self, recursive=False):
        """:see: AbstractExpression.get_refs()"""
        own = [self.variable]
        if not recursive:
            return own
        return own + self.term.get_refs(True)


class DrtBinaryExpression(DrtExpression, BinaryExpression):
    """Base class of the DRT expressions with two operands."""

    def get_refs(self, recursive=False):
        """:see: AbstractExpression.get_refs()"""
        if not recursive:
            return []
        return self.first.get_refs(True) + self.second.get_refs(True)

    def _pretty(self):
        """Draw both operands side by side around the operator."""
        left = self._pretty_subex(self.first)
        operator_symbol = self.getOp()
        right = self._pretty_subex(self.second)
        return DrtBinaryExpression._assemble_pretty(left, operator_symbol, right)

    @staticmethod
    def _assemble_pretty(first_lines, op, second_lines):
        """Join two drawings line by line.

        Both drawings are padded to the same height.  The third line
        carries the operator and the surrounding parentheses; every other
        line gets blanks of the same width.
        """
        height = max(len(first_lines), len(second_lines))
        left = _pad_vertically(first_lines, height)
        right = _pad_vertically(second_lines, height)
        gap = " " * len(op)
        rows = []
        for row, (lhs, rhs) in enumerate(zip(left, right)):
            if row == 2:
                rows.append("(" + lhs + " " + op + " " + rhs + ")")
            else:
                rows.append(" " + lhs + " " + gap + " " + rhs + " ")
        return rows

    def _pretty_subex(self, subex):
        """Return the drawing of one operand."""
        return subex._pretty()


class DrtBooleanExpression(DrtBinaryExpression, BooleanExpression):
    """DRT boolean expression; adds nothing beyond its two base classes."""


class DrtOrExpression(DrtBooleanExpression, OrExpression):
    """The disjunction of two DRT expressions."""

    def fol(self):
        """Disjoin the first order forms of both operands."""
        left = self.first.fol()
        right = self.second.fol()
        return OrExpression(left, right)

    def _pretty_subex(self, subex):
        """Draw one operand; a nested disjunction loses the first and last
        character of every line."""
        if not isinstance(subex, DrtOrExpression):
            return DrtBooleanExpression._pretty_subex(self, subex)
        return [line[1:-1] for line in subex._pretty()]


class DrtEqualityExpression(DrtBinaryExpression, EqualityExpression):
    """An equality between two DRT expressions."""

    def fol(self):
        """Equate the first order forms of both sides."""
        left = self.first.fol()
        right = self.second.fol()
        return EqualityExpression(left, right)


class DrtConcatenation(DrtBooleanExpression):
    """DRS of the form '(DRS + DRS)'

    An optional ``consequent`` turns the concatenation into the antecedent
    of an implication, in the same way as for ``DRS``.
    """

    def __init__(self, first, second, consequent=None):
        """
        :param first: the left operand of the concatenation
        :param second: the right operand of the concatenation
        :param consequent: optional expression implied by the concatenation
        """
        DrtBooleanExpression.__init__(self, first, second)
        self.consequent = consequent

    def replace(self, variable, expression, replace_bound=False, alpha_convert=True):
        """Replace all instances of variable v with expression E in self,
        where v is free in self.

        :param variable: the variable to replace
        :param expression: the replacement expression
        :param replace_bound: if True, also replace ``variable`` when it is
            one of the referents of this concatenation
        :param alpha_convert: if True, referents that occur free in
            ``expression`` are renamed first to avoid capture
        :return: a new instance of this class
        """
        first = self.first
        second = self.second
        consequent = self.consequent

        # If variable is bound
        if variable in self.get_refs():
            # A bound variable is only touched on explicit request;
            # otherwise the parts are passed through unchanged.
            if replace_bound:
                first = first.replace(
                    variable, expression, replace_bound, alpha_convert
                )
                second = second.replace(
                    variable, expression, replace_bound, alpha_convert
                )
                if consequent:
                    consequent = consequent.replace(
                        variable, expression, replace_bound, alpha_convert
                    )
        else:
            if alpha_convert:
                # alpha convert every ref that is free in 'expression'
                for ref in set(self.get_refs(True)) & expression.free():
                    v = DrtVariableExpression(unique_variable(ref))
                    first = first.replace(ref, v, True, alpha_convert)
                    second = second.replace(ref, v, True, alpha_convert)
                    if consequent:
                        consequent = consequent.replace(ref, v, True, alpha_convert)

            # The actual substitution, applied to the (possibly renamed) parts.
            first = first.replace(variable, expression, replace_bound, alpha_convert)
            second = second.replace(variable, expression, replace_bound, alpha_convert)
            if consequent:
                consequent = consequent.replace(
                    variable, expression, replace_bound, alpha_convert
                )

        return self.__class__(first, second, consequent)

    def eliminate_equality(self):
        """Simplify, then eliminate equalities in the result.

        The simplified form must no longer be a concatenation; this is
        enforced with an assertion.
        """
        # TODO: at some point.  for now, simplify.
        drs = self.simplify()
        assert not isinstance(drs, DrtConcatenation)
        return drs.eliminate_equality()

    def simplify(self):
        """Simplify both operands and merge them when both are ``DRS``.

        :return: a single ``DRS`` holding the referents and conditions of
            both operands when both simplify to a ``DRS``, otherwise a new
            instance of this class built from the simplified parts
        """
        first = self.first.simplify()
        second = self.second.simplify()
        consequent = self.consequent.simplify() if self.consequent else None

        if isinstance(first, DRS) and isinstance(second, DRS):
            # For any ref that is in both 'first' and 'second'
            for ref in set(first.get_refs(True)) & set(second.get_refs(True)):
                # alpha convert the ref in 'second' to prevent collision
                newvar = DrtVariableExpression(unique_variable(ref))
                second = second.replace(ref, newvar, True)

            return DRS(first.refs + second.refs, first.conds + second.conds, consequent)
        else:
            return self.__class__(first, second, consequent)

    def get_refs(self, recursive=False):
        """:see: AbstractExpression.get_refs()"""
        refs = self.first.get_refs(recursive) + self.second.get_refs(recursive)
        # The referents of the consequent are only reported in recursive mode.
        if self.consequent and recursive:
            refs.extend(self.consequent.get_refs(True))
        return refs

    def getOp(self):
        """Return the concatenation operator token."""
        return DrtTokens.DRS_CONC

    def __eq__(self, other):
        r"""Defines equality modulo alphabetic variance.

        The referents of ``other`` are renamed, position by position, to
        the referents of ``self``; the two concatenations are equal when
        both operands and the consequent then match."""
        if isinstance(other, DrtConcatenation):
            self_refs = self.get_refs()
            other_refs = other.get_refs()
            if len(self_refs) == len(other_refs):
                converted_other = other
                for (r1, r2) in zip(self_refs, other_refs):
                    varex = self.make_VariableExpression(r1)
                    converted_other = converted_other.replace(r2, varex, True)
                return (
                    self.first == converted_other.first
                    and self.second == converted_other.second
                    and self.consequent == converted_other.consequent
                )
        return False

    def __ne__(self, other):
        return not self == other

    __hash__ = DrtBooleanExpression.__hash__

    def fol(self):
        """Conjoin the first order forms of both operands; with a
        consequent the conjunction becomes the antecedent of an
        implication."""
        e = AndExpression(self.first.fol(), self.second.fol())
        if self.consequent:
            e = ImpExpression(e, self.consequent.fol())
        return e

    def _pretty(self):
        """Draw both operands around the operator, followed by the
        implication to the consequent when there is one."""
        drs = DrtBinaryExpression._assemble_pretty(
            self._pretty_subex(self.first),
            self.getOp(),
            self._pretty_subex(self.second),
        )
        if self.consequent:
            drs = DrtBinaryExpression._assemble_pretty(
                drs, DrtTokens.IMP, self.consequent._pretty()
            )
        return drs

    def _pretty_subex(self, subex):
        """Draw one operand; a nested concatenation loses the first and
        last character of every line."""
        if isinstance(subex, DrtConcatenation):
            return [line[1:-1] for line in subex._pretty()]
        return DrtBooleanExpression._pretty_subex(self, subex)

    def visit(self, function, combinator):
        """:see: Expression.visit()"""
        if self.consequent:
            return combinator(
                [function(self.first), function(self.second), function(self.consequent)]
            )
        else:
            return combinator([function(self.first), function(self.second)])

    def __str__(self):
        """Return ``(first + second)``, wrapped as an implication when
        there is a consequent."""
        first = self._str_subex(self.first)
        second = self._str_subex(self.second)
        drs = Tokens.OPEN + first + " " + self.getOp() + " " + second + Tokens.CLOSE
        if self.consequent:
            return (
                DrtTokens.OPEN
                + drs
                + " "
                + DrtTokens.IMP
                + " "
                + "%s" % self.consequent
                + DrtTokens.CLOSE
            )
        return drs

    def _str_subex(self, subex):
        """Return the string of one operand; a nested concatenation without
        a consequent loses its first and last character."""
        s = "%s" % subex
        if isinstance(subex, DrtConcatenation) and subex.consequent is None:
            return s[1:-1]
        return s


class DrtApplicationExpression(DrtExpression, ApplicationExpression):
    """DRT counterpart of ``ApplicationExpression``."""

    def fol(self):
        """Return an ``ApplicationExpression`` built from the ``fol()`` forms
        of this expression's function and argument."""
        fol_function = self.function.fol()
        fol_argument = self.argument.fol()
        return ApplicationExpression(fol_function, fol_argument)

    def get_refs(self, recursive=False):
        """:see: AbstractExpression.get_refs()"""
        # An application contributes no referents of its own; only a
        # recursive query collects those of its function and argument.
        if not recursive:
            return []
        return self.function.get_refs(True) + self.argument.get_refs(True)

    def _pretty(self):
        """
        Return the pretty-printed lines for ``function(arg1,arg2,...)``.

        The uncurried function and every argument are rendered separately,
        padded to a common number of lines and then joined row by row.
        """
        function, args = self.uncurry()
        function_block = function._pretty()
        arg_blocks = [arg._pretty() for arg in args]

        # Pad every block to the height of the tallest one so the rows line up.
        height = max(len(block) for block in [function_block] + arg_blocks)
        function_block = _pad_vertically(function_block, height)
        arg_blocks = [_pad_vertically(block, height) for block in arg_blocks]

        rendered = []
        rows = zip(function_block, zip(*arg_blocks))
        for index, (func_row, arg_rows) in enumerate(rows):
            if index == 2:
                # Only the third row gets the parentheses and commas.
                rendered.append(func_row + "(" + ",".join(arg_rows) + ")")
            else:
                # Every other row uses spaces of the same width instead.
                rendered.append(func_row + " " + " ".join(arg_rows) + " ")
        return rendered


def _pad_vertically(lines, max_lines):
    """
    Return a new list consisting of ``lines`` followed by blank lines so that
    it holds ``max_lines`` entries.

    :param lines: non-empty list of strings; the width of the blank lines is
        the length of ``lines[0]``
    :param max_lines: the desired number of lines; nothing is appended when
        ``lines`` already has at least this many entries
    :return: list of str
    """
    blank = " " * len(lines[0])
    missing = max_lines - len(lines)
    return lines + [blank for _ in range(missing)]


class PossibleAntecedents(list, DrtExpression, Expression):
    """
    A list of expressions that a pronoun could resolve to.

    ``resolve_anaphora`` in this module puts an instance on one side of an
    equality when a pronoun does not have exactly one candidate antecedent.
    """

    def free(self):
        """Set of free variables."""
        return set(self)

    def replace(self, variable, expression, replace_bound=False, alpha_convert=True):
        """Replace all instances of variable v with expression E in self,
        where v is free in self.

        :param variable: the item to look for (compared with ``==``)
        :param expression: the replacement for every matching item
        :param replace_bound: accepted but not used by this implementation
        :param alpha_convert: accepted but not used by this implementation
        :return: a new ``PossibleAntecedents``; ``self`` is left unmodified
        """
        result = PossibleAntecedents()
        for item in self:
            # Collect into ``result``.  Appending to ``self`` here would grow
            # the list being iterated (so the loop would never finish) and
            # would leave the returned list empty.
            result.append(expression if item == variable else item)
        return result

    def _pretty(self):
        """Return three lines of equal width: two blank lines followed by the
        string form of this list."""
        s = "%s" % self
        blank = " " * len(s)
        return [blank, blank, s]

    def __str__(self):
        """Return the items, comma-separated, inside square brackets."""
        return "[" + ",".join("%s" % it for it in self) + "]"


class AnaphoraResolutionException(Exception):
    """Raised by ``resolve_anaphora`` when a pronoun variable has no
    possible antecedent."""


def resolve_anaphora(expression, trail=None):
    """
    Rebuild ``expression`` with every pronoun application replaced by an
    equality between the pronoun's argument and its antecedent(s).

    Candidate antecedents are the referents (``get_refs()``) of the
    expressions in ``trail`` whose variable expression has exactly the same
    class as the pronoun's argument and is not equal to that argument.  With
    exactly one candidate the equality's right-hand side is that candidate;
    otherwise it is a ``PossibleAntecedents`` list holding all of them.

    :param expression: the expression to resolve
    :param trail: list of the expressions enclosing ``expression``, outermost
        first; ``None`` (the default) means an empty trail
    :return: the resolved expression; variable expressions are returned
        as-is, and ``None`` is returned for expression types that none of the
        branches below handles
    :raise AnaphoraResolutionException: if a condition of a DRS resolves to
        an equality with an empty list of possible antecedents
    """
    # A ``None`` sentinel replaces the former mutable ``[]`` default, so no
    # list object is shared between calls.
    if trail is None:
        trail = []

    if isinstance(expression, ApplicationExpression):
        if expression.is_pronoun_function():
            possible_antecedents = PossibleAntecedents()
            for ancestor in trail:
                for ref in ancestor.get_refs():
                    refex = expression.make_VariableExpression(ref)

                    # ==========================================================
                    # Don't allow resolution to itself or other types
                    # ==========================================================
                    if refex.__class__ == expression.argument.__class__ and not (
                        refex == expression.argument
                    ):
                        possible_antecedents.append(refex)

            if len(possible_antecedents) == 1:
                resolution = possible_antecedents[0]
            else:
                resolution = possible_antecedents
            return expression.make_EqualityExpression(expression.argument, resolution)

        inner_trail = trail + [expression]
        r_function = resolve_anaphora(expression.function, inner_trail)
        r_argument = resolve_anaphora(expression.argument, inner_trail)
        return expression.__class__(r_function, r_argument)

    elif isinstance(expression, DRS):
        inner_trail = trail + [expression]
        r_conds = []
        for cond in expression.conds:
            r_cond = resolve_anaphora(cond, inner_trail)

            # if the condition is of the form '(x = [])' then raise exception
            if isinstance(r_cond, EqualityExpression):
                if isinstance(r_cond.first, PossibleAntecedents):
                    # Reverse the order so that the variable is on the left
                    r_cond.first, r_cond.second = r_cond.second, r_cond.first
                if isinstance(r_cond.second, PossibleAntecedents):
                    if not r_cond.second:
                        raise AnaphoraResolutionException(
                            "Variable '%s' does not "
                            "resolve to anything." % r_cond.first
                        )

            r_conds.append(r_cond)
        if expression.consequent:
            consequent = resolve_anaphora(expression.consequent, inner_trail)
        else:
            consequent = None
        return expression.__class__(expression.refs, r_conds, consequent)

    elif isinstance(expression, AbstractVariableExpression):
        return expression

    elif isinstance(expression, NegatedExpression):
        return expression.__class__(
            resolve_anaphora(expression.term, trail + [expression])
        )

    elif isinstance(expression, DrtConcatenation):
        inner_trail = trail + [expression]
        if expression.consequent:
            consequent = resolve_anaphora(expression.consequent, inner_trail)
        else:
            consequent = None
        return expression.__class__(
            resolve_anaphora(expression.first, inner_trail),
            resolve_anaphora(expression.second, inner_trail),
            consequent,
        )

    elif isinstance(expression, BinaryExpression):
        inner_trail = trail + [expression]
        return expression.__class__(
            resolve_anaphora(expression.first, inner_trail),
            resolve_anaphora(expression.second, inner_trail),
        )

    elif isinstance(expression, LambdaExpression):
        return expression.__class__(
            expression.variable, resolve_anaphora(expression.term, trail + [expression])
        )

    # NOTE(review): any other expression type falls through and yields None,
    # exactly as before; confirm whether further types need a branch here.
    return None


class DrsDrawer:
    """
    Lay out and draw a DRT expression on a Tkinter ``Canvas``.

    Layout and drawing share one traversal (``_handle``), parameterised by a
    "command": ``_visit_command`` only measures, while ``_draw_command``
    draws and then measures.  Measured sizes are cached on the expression
    objects as ``_drawing_width`` / ``_drawing_height``.

    Throughout the class ``x`` is a horizontal canvas coordinate (a left
    edge) and ``y`` is a vertical one (a top edge).
    """

    BUFFER = 3  # Space between elements
    TOPSPACE = 10  # Space above whole DRS
    OUTERSPACE = 6  # Space to the left, right, and bottom of the whole DRS

    def __init__(self, drs, size_canvas=True, canvas=None):
        """
        :param drs: ``DrtExpression``, The DRS to be drawn
        :param size_canvas: bool, True if the canvas size should be the exact size of the DRS
        :param canvas: ``Canvas`` The canvas on which to draw the DRS.  If none is given, create a new canvas.
        """
        # NOTE: a caller-supplied canvas is used as-is; no ``font`` attribute
        # is attached to it here, although the measuring code reads
        # ``self.canvas.font``.
        master = None
        if not canvas:
            master = Tk()
            master.title("DRT")

            font = Font(family="helvetica", size=12)

            if size_canvas:
                # Measure the DRS on a throwaway zero-size canvas first
                # (``_visit`` needs ``self.canvas.font``), then create the
                # real canvas with the measured size.
                canvas = Canvas(master, width=0, height=0)
                canvas.font = font
                self.canvas = canvas
                (right, bottom) = self._visit(drs, self.OUTERSPACE, self.TOPSPACE)

                # The canvas is at least 100 wide
                width = max(right + self.OUTERSPACE, 100)
                height = bottom + self.OUTERSPACE
                canvas = Canvas(master, width=width, height=height)  # , bg='white')
            else:
                canvas = Canvas(master, width=300, height=300)

            canvas.pack()
            canvas.font = font

        self.canvas = canvas
        self.drs = drs
        # ``master`` stays None when the caller supplied the canvas
        self.master = master

    def _get_text_height(self):
        """Get the height of a line of text"""
        return self.canvas.font.metrics("linespace")

    def draw(self, x=OUTERSPACE, y=TOPSPACE):
        """
        Draw the DRS

        :param x: the left side of the drawing area
        :param y: the top of the drawing area
        :return: None after the Tk main loop has ended, when this drawer
            created its own window and is not running inside IDLE; otherwise
            the bottom-rightmost point of the drawing
        """
        self._handle(self.drs, self._draw_command, x, y)

        if self.master and not in_idle():
            self.master.mainloop()
        else:
            return self._visit(self.drs, x, y)

    def _visit(self, expression, x, y):
        """
        Return the bottom-rightmost point without actually drawing the item

        :param expression: the item to visit
        :param x: the left side of the current drawing area
        :param y: the top of the current drawing area
        :return: the bottom-rightmost point
        """
        return self._handle(expression, self._visit_command, x, y)

    def _draw_command(self, item, x, y):
        """
        Draw the given item at the given location

        :param item: the item to draw: a str is drawn as text; a tuple is
            the lower-right corner of a box whose upper-left corner is (x, y)
        :param x: the left side of the current drawing area
        :param y: the top of the current drawing area
        :return: the bottom-rightmost point
        """
        if isinstance(item, str):
            self.canvas.create_text(x, y, anchor="nw", font=self.canvas.font, text=item)
        elif isinstance(item, tuple):
            # item is the lower-right of a box
            (right, bottom) = item
            self.canvas.create_rectangle(x, y, right, bottom)
            horiz_line_y = (
                y + self._get_text_height() + (self.BUFFER * 2)
            )  # the line separating refs from conds
            self.canvas.create_line(x, horiz_line_y, right, horiz_line_y)

        return self._visit_command(item, x, y)

    def _visit_command(self, item, x, y):
        """
        Return the bottom-rightmost point without actually drawing the item

        :param item: the item to visit: a str is measured with the canvas
            font; a tuple is returned unchanged.  Any other type yields None.
        :param x: the left side of the current drawing area
        :param y: the top of the current drawing area
        :return: the bottom-rightmost point
        """
        if isinstance(item, str):
            return (x + self.canvas.font.measure(item), y + self._get_text_height())
        elif isinstance(item, tuple):
            return item

    def _handle(self, expression, command, x=0, y=0):
        """
        Dispatch ``expression`` to the handler for its type and cache the
        resulting size on the expression.

        :param expression: the expression to handle
        :param command: the function to apply, either _draw_command or _visit_command
        :param x: the left side of the current drawing area
        :param y: the top of the current drawing area
        :return: the bottom-rightmost point
        :raise Exception: if no handler exists for the expression's type
        """
        if command == self._visit_command:
            # if we don't need to draw the item, then we can use the cached values
            try:
                # attempt to retrieve cached values
                right = expression._drawing_width + x
                bottom = expression._drawing_height + y
                return (right, bottom)
            except AttributeError:
                # the values have not been cached yet, so compute them
                pass

        if isinstance(expression, DrtAbstractVariableExpression):
            factory = self._handle_VariableExpression
        elif isinstance(expression, DRS):
            factory = self._handle_DRS
        elif isinstance(expression, DrtNegatedExpression):
            factory = self._handle_NegatedExpression
        elif isinstance(expression, DrtLambdaExpression):
            factory = self._handle_LambdaExpression
        elif isinstance(expression, BinaryExpression):
            factory = self._handle_BinaryExpression
        elif isinstance(expression, DrtApplicationExpression):
            factory = self._handle_ApplicationExpression
        elif isinstance(expression, PossibleAntecedents):
            # drawn as plain text, like a variable
            factory = self._handle_VariableExpression
        elif isinstance(expression, DrtProposition):
            factory = self._handle_DrtProposition
        else:
            raise Exception(expression.__class__.__name__)

        (right, bottom) = factory(expression, command, x, y)

        # cache the values
        expression._drawing_width = right - x
        expression._drawing_height = bottom - y

        return (right, bottom)

    def _handle_VariableExpression(self, expression, command, x, y):
        """Apply ``command`` to the string form of ``expression`` at (x, y)."""
        return command("%s" % expression, x, y)

    def _handle_NegatedExpression(self, expression, command, x, y):
        """Handle the negation symbol followed by the negated term; the
        symbol is vertically centred against the term."""
        # Find the width of the negation symbol
        right = self._visit_command(DrtTokens.NOT, x, y)[0]

        # Handle term
        (right, bottom) = self._handle(expression.term, command, right, y)

        # Handle the negation symbol now that we know the height of the term
        command(
            DrtTokens.NOT,
            x,
            self._get_centered_top(y, bottom - y, self._get_text_height()),
        )

        return (right, bottom)

    def _handle_DRS(self, expression, command, x, y):
        """Handle a DRS box: the referents on one line, the conditions
        stacked below them, and finally the surrounding box."""
        left = x + self.BUFFER  # indent the left side
        bottom = y + self.BUFFER  # indent the top

        # Handle Discourse Referents
        if expression.refs:
            refs = " ".join("%s" % r for r in expression.refs)
        else:
            # placeholder text so that an empty referent line still has width
            refs = "     "
        (max_right, bottom) = command(refs, left, bottom)
        bottom += self.BUFFER * 2

        # Handle Conditions
        if expression.conds:
            for cond in expression.conds:
                (right, bottom) = self._handle(cond, command, left, bottom)
                max_right = max(max_right, right)
                bottom += self.BUFFER
        else:
            # no conditions: leave room for one line of text
            bottom += self._get_text_height() + self.BUFFER

        # Handle Box
        max_right += self.BUFFER
        return command((max_right, bottom), x, y)

    def _handle_ApplicationExpression(self, expression, command, x, y):
        """Handle ``function(arg, ...)`` on one line, with every part
        vertically centred against the tallest part."""
        function, args = expression.uncurry()
        if not isinstance(function, DrtAbstractVariableExpression):
            # It's not a predicate expression ("P(x,y)"), so leave arguments curried
            function = expression.function
            args = [expression.argument]

        # Get the max bottom of any element on the line
        # (these _visit calls also cache each part's _drawing_height,
        # which is read below)
        function_bottom = self._visit(function, x, y)[1]
        max_bottom = max(
            [function_bottom] + [self._visit(arg, x, y)[1] for arg in args]
        )

        line_height = max_bottom - y

        # Handle 'function'
        function_drawing_top = self._get_centered_top(
            y, line_height, function._drawing_height
        )
        right = self._handle(function, command, x, function_drawing_top)[0]

        # Handle open paren
        centred_string_top = self._get_centered_top(
            y, line_height, self._get_text_height()
        )
        right = command(DrtTokens.OPEN, right, centred_string_top)[0]

        # Handle each arg
        for (i, arg) in enumerate(args):
            arg_drawing_top = self._get_centered_top(
                y, line_height, arg._drawing_height
            )
            right = self._handle(arg, command, right, arg_drawing_top)[0]

            if i + 1 < len(args):
                # since it's not the last arg, add a comma
                right = command(DrtTokens.COMMA + " ", right, centred_string_top)[0]

        # Handle close paren
        right = command(DrtTokens.CLOSE, right, centred_string_top)[0]

        return (right, max_bottom)

    def _handle_LambdaExpression(self, expression, command, x, y):
        """Handle the lambda prefix followed by the term; the prefix is
        vertically centred against the term."""
        # Find the width of the lambda symbol and abstracted variables
        variables = DrtTokens.LAMBDA + "%s" % expression.variable + DrtTokens.DOT
        right = self._visit_command(variables, x, y)[0]

        # Handle term
        (right, bottom) = self._handle(expression.term, command, right, y)

        # Handle variables now that we know the y-coordinate
        command(
            variables, x, self._get_centered_top(y, bottom - y, self._get_text_height())
        )

        return (right, bottom)

    def _handle_BinaryExpression(self, expression, command, x, y):
        """Handle ``(first op second)`` on one line, with both operands
        vertically centred against the taller one."""
        # Get the full height of the line, based on the operands
        # (visiting at the origin makes the returned bottom equal the height,
        # and caches each operand's _drawing_height, which is read below)
        first_height = self._visit(expression.first, 0, 0)[1]
        second_height = self._visit(expression.second, 0, 0)[1]
        line_height = max(first_height, second_height)

        # Handle open paren
        centred_string_top = self._get_centered_top(
            y, line_height, self._get_text_height()
        )
        right = command(DrtTokens.OPEN, x, centred_string_top)[0]

        # Handle the first operand
        first_height = expression.first._drawing_height
        (right, first_bottom) = self._handle(
            expression.first,
            command,
            right,
            self._get_centered_top(y, line_height, first_height),
        )

        # Handle the operator
        right = command(" %s " % expression.getOp(), right, centred_string_top)[0]

        # Handle the second operand
        second_height = expression.second._drawing_height
        (right, second_bottom) = self._handle(
            expression.second,
            command,
            right,
            self._get_centered_top(y, line_height, second_height),
        )

        # Handle close paren
        right = command(DrtTokens.CLOSE, right, centred_string_top)[0]

        return (right, max(first_bottom, second_bottom))

    def _handle_DrtProposition(self, expression, command, x, y):
        """Handle a proposition: its variable label followed by its term."""
        # Handle the proposition's variable and find its right edge
        # NOTE(review): the commands only accept str or tuple items, so this
        # assumes ``expression.variable`` is a str -- confirm.
        right = command(expression.variable, x, y)[0]

        # Handle term
        (right, bottom) = self._handle(expression.term, command, right, y)

        return (right, bottom)

    def _get_centered_top(self, top, full_height, item_height):
        """Get the y-coordinate of the point that a figure should start at if
        its height is 'item_height' and it needs to be centered in an area that
        starts at 'top' and is 'full_height' tall."""
        # true division: the result may be a float
        return top + (full_height - item_height) / 2


def demo():
    """Print a short tour of the module: parsing, ``fol()``, alpha
    conversion, anaphora resolution and pretty-printing of sample DRSs."""

    def banner(title):
        # Section heading framed by twenty '=' characters on each side.
        print("=" * 20 + title + "=" * 20)

    banner("TEST PARSE")
    dexpr = DrtExpression.fromstring
    parse_samples = [
        r"([x,y],[sees(x,y)])",
        r"([x],[man(x), walks(x)])",
        r"\x.\y.([],[sees(x,y)])",
        r"\x.([],[walks(x)])(john)",
        r"(([x],[walks(x)]) + ([y],[runs(y)]))",
        r"(([],[walks(x)]) -> ([],[runs(x)]))",
        r"([x],[PRO(x), sees(John,x)])",
        r"([x],[man(x), -([],[walks(x)])])",
        r"([],[(([x],[man(x)]) -> ([],[walks(x)]))])",
    ]
    for sample in parse_samples:
        print(dexpr(sample))

    banner("Test fol()")
    print(dexpr(r"([x,y],[sees(x,y)])").fol())

    banner("Test alpha conversion and lambda expression equality")
    e1 = dexpr(r"\x.([],[P(x)])")
    print(e1)
    e2 = e1.alpha_convert(Variable("z"))
    print(e2)
    print(e1 == e2)

    banner("Test resolve_anaphora()")
    anaphora_samples = [
        r"([x,y,z],[dog(x), cat(y), walks(z), PRO(z)])",
        r"([],[(([x],[dog(x)]) -> ([y],[walks(y), PRO(y)]))])",
        r"(([x,y],[]) + ([],[PRO(x)]))",
    ]
    for sample in anaphora_samples:
        print(resolve_anaphora(dexpr(sample)))

    banner("Test pretty_print()")
    pretty_samples = [
        r"([],[])",
        r"([],[([x],[big(x), dog(x)]) -> ([],[bark(x)]) -([x],[walk(x)])])",
        r"([x,y],[x=y]) + ([z],[dog(z), walk(z)])",
        r"([],[([x],[]) | ([y],[]) | ([z],[dog(z), walk(z)])])",
        r"\P.\Q.(([x],[]) + P(x) + Q(x))(\x.([],[dog(x)]))",
    ]
    for sample in pretty_samples:
        dexpr(sample).pretty_print()


def test_draw():
    """
    Parse a fixed list of sample expressions and call ``draw()`` on each one
    in turn.

    :raise ValueError: if ``tkinter`` cannot be imported
    """
    try:
        # Imported only to check that tkinter is available.
        from tkinter import Tk
    except ImportError:
        raise ValueError("tkinter is required, but it's not available.")

    samples = (
        r"x",
        r"([],[])",
        r"([x],[])",
        r"([x],[man(x)])",
        r"([x,y],[sees(x,y)])",
        r"([x],[man(x), walks(x)])",
        r"\x.([],[man(x), walks(x)])",
        r"\x y.([],[sees(x,y)])",
        r"([],[(([],[walks(x)]) + ([],[runs(x)]))])",
        r"([x],[man(x), -([],[walks(x)])])",
        r"([],[(([x],[man(x)]) -> ([],[walks(x)]))])",
    )

    for text in samples:
        DrtExpression.fromstring(text).draw()


# Run the text demonstration when this module is executed as a script.
if __name__ == "__main__":
    demo()
