From 2fc9b822a27080f297542608b01d8e64ec2c9da2 Mon Sep 17 00:00:00 2001
From: Katja Goltsova <14182676+cache-nez@users.noreply.github.com>
Date: Fri, 28 Oct 2022 17:45:58 +0200
Subject: [PATCH] Support equivalent names of predicates (#88)

---
 .../src/main/scala/lisa/utils/Parser.scala    | 61 +++++++++++++++----
 .../test/scala/lisa/utils/ParserTest.scala    |  5 +-
 .../test/scala/lisa/utils/PrinterTest.scala   |  6 ++
 3 files changed, 57 insertions(+), 15 deletions(-)

diff --git a/lisa-utils/src/main/scala/lisa/utils/Parser.scala b/lisa-utils/src/main/scala/lisa/utils/Parser.scala
index c0524e54..a795dafa 100644
--- a/lisa-utils/src/main/scala/lisa/utils/Parser.scala
+++ b/lisa-utils/src/main/scala/lisa/utils/Parser.scala
@@ -10,6 +10,30 @@ import scallion.util.Unfolds.unfoldRight
 import silex.*
 
 object Parser {
+  enum Notation {
+    case Prefix, Infix
+  }
+
+  abstract class Label(val id: String, val notation: Notation)
+
+  case class InfixLabel(override val id: String) extends Label(id, Notation.Infix)
+
+  case class PrefixLabel(override val id: String) extends Label(id, Notation.Prefix)
+
+  case class CanonicalLabel(print: Label, internal: Label)
+
+  def equivalentLabelsToMap(labels: List[String], print: Label, internal: Label): Map[String, CanonicalLabel] =
+    labels.map(_ -> CanonicalLabel(print, internal)).toMap
+
+  private val mapping: Map[String, CanonicalLabel] =
+    equivalentLabelsToMap("elem" :: "in" :: "∊" :: Nil, InfixLabel("∊"), PrefixLabel("elem")) ++
+      equivalentLabelsToMap("subset_of" :: "subset" :: "⊆" :: Nil, InfixLabel("⊆"), PrefixLabel("subset_of")) ++
+      equivalentLabelsToMap("sim" :: "same_cardinality" :: "≈" :: Nil, InfixLabel("≈"), PrefixLabel("same_cardinality")) ++
+      equivalentLabelsToMap("=" :: Nil, InfixLabel("="), InfixLabel("="))
+
+  def getPrintName(id: String): Option[Label] = mapping.get(id).map(_.print)
+
+  def getInternalName(id: String): String = mapping.get(id).map(_.internal.id).getOrElse(id)
 
   class ParserException(msg: String) extends Exception(msg)
   object UnreachableException extends ParserException("Internal error: expected unreachable")
@@ -282,11 +306,13 @@ object Parser {
     val INFIX_ARITY = 2
     val infixPredicateLabel: Syntax[ConstantPredicateLabel] = accept(InfixPredicateKind)(
       {
-        case InfixPredicateToken(id) => ConstantPredicateLabel(id, INFIX_ARITY)
+        case InfixPredicateToken(id) => ConstantPredicateLabel(getInternalName(id), INFIX_ARITY)
         case _ => throw UnreachableException
       },
       {
-        case ConstantPredicateLabel(id, INFIX_ARITY) if isInfixPredicate(id) => Seq(InfixPredicateToken(id))
+        case ConstantPredicateLabel(id, INFIX_ARITY) if getPrintName(id).map(_.notation == Notation.Infix).getOrElse(isInfixPredicate(id)) =>
+          val printName = getPrintName(id).map(_.id).getOrElse(id)
+          Seq(InfixPredicateToken(printName))
         case _ => throw UnreachableException
       }
     )
@@ -374,27 +400,36 @@ object Parser {
         // predicate application
         case ConstantToken(id) ~ maybeArgs ~ None =>
           val args = maybeArgs.getOrElse(Seq())
-          PredicateFormula(ConstantPredicateLabel(id, args.size), args)
-        case SchematicToken(id) ~ Some(args) ~ None => PredicateFormula(SchematicPredicateLabel(id, args.size), args)
+          PredicateFormula(ConstantPredicateLabel(getInternalName(id), args.size), args)
+        case SchematicToken(id) ~ Some(args) ~ None => PredicateFormula(SchematicPredicateLabel(getInternalName(id), args.size), args)
         case SchematicToken(id) ~ None ~ None =>
-          PredicateFormula(VariableFormulaLabel(id), Seq())
+          PredicateFormula(VariableFormulaLabel(getInternalName(id)), Seq())
 
-        // equality of two function applications
+        // infix relation of two function applications
         case fun1 ~ args1 ~ Some(pred ~ term2) =>
           PredicateFormula(pred, Seq(createTerm(fun1, args1.getOrElse(Seq())), term2))
 
         case _ => throw UnreachableException
       },
       {
-        case PredicateFormula(label @ ConstantPredicateLabel(id, INFIX_ARITY), Seq(first, second)) if isInfixPredicate(id) => Seq(invertTerm(first) ~ Some(label ~ second))
+        case PredicateFormula(label @ ConstantPredicateLabel(id, INFIX_ARITY), Seq(first, second)) if getPrintName(id).map(_.notation == Notation.Infix).getOrElse(isInfixPredicate(id)) =>
+          Seq(invertTerm(first) ~ Some(label ~ second))
         case PredicateFormula(label, args) =>
-          val prefixApp = label match {
-            case VariableFormulaLabel(id) => SchematicToken(id) ~ None
-            case SchematicPredicateLabel(id, _) => SchematicToken(id) ~ Some(args)
-            case ConstantPredicateLabel(id, 0) => ConstantToken(id) ~ None
-            case ConstantPredicateLabel(id, _) => ConstantToken(id) ~ Some(args)
+          val (canonicalId, isInfix) = getPrintName(label.id).map(l => (l.id, l.notation == Notation.Infix)).getOrElse(label.id, false)
+          if (isInfix && args.size == INFIX_ARITY) {
+            args match {
+              case Seq(first, second) => Seq(invertTerm(first) ~ Some(ConstantPredicateLabel(canonicalId, INFIX_ARITY) ~ second))
+              case _ => throw UnreachableException
+            }
+          } else {
+            val prefixApp = label match {
+              case VariableFormulaLabel(id) => SchematicToken(canonicalId) ~ None
+              case SchematicPredicateLabel(id, _) => SchematicToken(canonicalId) ~ Some(args)
+              case ConstantPredicateLabel(id, 0) => ConstantToken(canonicalId) ~ None
+              case ConstantPredicateLabel(id, _) => ConstantToken(canonicalId) ~ Some(args)
+            }
+            Seq(prefixApp ~ None)
           }
-          Seq(prefixApp ~ None)
       }
     )
 
diff --git a/lisa-utils/src/test/scala/lisa/utils/ParserTest.scala b/lisa-utils/src/test/scala/lisa/utils/ParserTest.scala
index a8e45608..36d5520b 100644
--- a/lisa-utils/src/test/scala/lisa/utils/ParserTest.scala
+++ b/lisa-utils/src/test/scala/lisa/utils/ParserTest.scala
@@ -220,12 +220,13 @@ class ParserTest extends AnyFunSuite with TestUtils {
     )
   }
 
-  test("infix predicates") {
-    val in = ConstantPredicateLabel("∊", 2)
+  test("equivalent names") {
+    val in = ConstantPredicateLabel("elem", 2)
     assert(Parser.parseFormula("x∊y") == PredicateFormula(in, Seq(cx, cy)))
     assert(Parser.parseFormula("x ∊ y") == PredicateFormula(in, Seq(cx, cy)))
     assert(Parser.parseFormula("'x ∊ 'y") == PredicateFormula(in, Seq(x, y)))
     assert(Parser.parseFormula("('x ∊ 'y) /\\ a") == ConnectorFormula(And, Seq(PredicateFormula(in, Seq(x, y)), a)))
     assert(Parser.parseFormula("a \\/ ('x ∊ 'y)") == ConnectorFormula(Or, Seq(a, PredicateFormula(in, Seq(x, y)))))
+
   }
 }
diff --git a/lisa-utils/src/test/scala/lisa/utils/PrinterTest.scala b/lisa-utils/src/test/scala/lisa/utils/PrinterTest.scala
index a43739e9..32b1d2af 100644
--- a/lisa-utils/src/test/scala/lisa/utils/PrinterTest.scala
+++ b/lisa-utils/src/test/scala/lisa/utils/PrinterTest.scala
@@ -260,9 +260,15 @@ class PrinterTest extends AnyFunSuite with TestUtils {
 
   test("infix predicates") {
     val in = ConstantPredicateLabel("∊", 2)
+    val prefixIn = ConstantPredicateLabel("elem", 2)
     assert(Parser.printFormula(PredicateFormula(in, Seq(cx, cy))) == "x ∊ y")
     assert(Parser.printFormula(PredicateFormula(in, Seq(x, y))) == "'x ∊ 'y")
     assert(Parser.printFormula(ConnectorFormula(And, Seq(PredicateFormula(in, Seq(x, y)), a))) == "'x ∊ 'y ∧ a")
     assert(Parser.printFormula(ConnectorFormula(Or, Seq(a, PredicateFormula(in, Seq(x, y))))) == "a ∨ 'x ∊ 'y")
+
+    assert(Parser.printFormula(PredicateFormula(prefixIn, Seq(cx, cy))) == "x ∊ y")
+    assert(Parser.printFormula(PredicateFormula(prefixIn, Seq(x, y))) == "'x ∊ 'y")
+    assert(Parser.printFormula(ConnectorFormula(And, Seq(PredicateFormula(prefixIn, Seq(x, y)), a))) == "'x ∊ 'y ∧ a")
+    assert(Parser.printFormula(ConnectorFormula(Or, Seq(a, PredicateFormula(prefixIn, Seq(x, y))))) == "a ∨ 'x ∊ 'y")
   }
 }
-- 
GitLab