From cef64cec6ad21c214f1354c4a61f8cdb8438bd76 Mon Sep 17 00:00:00 2001
From: "Emmanouil (Manos) Koukoutos" <emmanouil.koukoutos@epfl.ch>
Date: Thu, 19 Feb 2015 13:30:08 +0100
Subject: [PATCH] leon.collection set operations

---
 library/collection/package.scala              | 24 +++++++++
 .../scala/leon/codegen/CodeGeneration.scala   | 53 +++++++++++++++++++
 .../scala/leon/datagen/VanuatooDataGen.scala  |  9 ++++
 .../leon/evaluators/RecursiveEvaluator.scala  | 10 ++++
 .../frontends/scalac/CodeExtraction.scala     |  1 -
 src/main/scala/leon/utils/Library.scala       |  2 +
 .../leon/test/codegen/CodeGenTests.scala      | 13 +++++
 7 files changed, 111 insertions(+), 1 deletion(-)
 create mode 100644 library/collection/package.scala

diff --git a/library/collection/package.scala b/library/collection/package.scala
new file mode 100644
index 000000000..77d953a10
--- /dev/null
+++ b/library/collection/package.scala
@@ -0,0 +1,24 @@
+/* Copyright 2009-2015 EPFL, Lausanne */
+
+package leon
+
+import leon.annotation._
+import leon.collection.List
+import leon.lang.synthesis.choose
+
+package object collection {
+
+  @library
+  def setToList[A](set: Set[A]): List[A] = choose { 
+    (x: List[A]) => x.content == set
+  }
+
+  @library
+  def setForall[A](set: Set[A], p: A => Boolean): Boolean = 
+    setToList(set).forall(p)
+
+  @library
+  def setExists[A](set: Set[A], p: A => Boolean): Boolean =
+    setToList(set).exists(p)
+
+}
diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala
index bfea7dabe..fd704fd28 100644
--- a/src/main/scala/leon/codegen/CodeGeneration.scala
+++ b/src/main/scala/leon/codegen/CodeGeneration.scala
@@ -396,6 +396,59 @@ trait CodeGeneration {
           case _ =>
         }
         
+      case FunctionInvocation(TypedFunDef(fd, Seq(tp)), Seq(set)) if fd == program.library.setToList.get =>
+
+        val ObjectClass = "java/lang/Object"
+        val IteratorClass = "java/util/Iterator"
+        val nil = CaseClass(CaseClassType(program.library.Nil.get, Seq(tp)), Seq())
+        val cons = program.library.Cons.get
+        val (consName, ccApplySig) = leonClassToJVMInfo(cons).getOrElse {
+          throw CompilationException("Unknown class : " + cons)
+        }
+        
+        mkExpr(nil, ch)
+        mkExpr(set, ch)
+        //if (params.requireMonitor) {
+        //  ch << ALoad(locals.monitorIndex)
+        //}
+
+        // No dynamic dispatching/overriding in Leon, 
+        // so no need to take care of own vs. "super" methods
+        ch << InvokeVirtual(SetClass, "getElements", s"()L$IteratorClass;")
+        
+        val loop = ch.getFreshLabel("loop")
+        val out = ch.getFreshLabel("out")
+        ch << Label(loop)
+        // list, it
+        ch << DUP
+        // list, it, it
+        ch << InvokeInterface(IteratorClass, "hasNext", "()Z")
+        // list, it, hasNext
+        ch << IfEq(out)
+        // list, it
+        ch << DUP2
+        // list, it, list, it
+        ch << InvokeInterface(IteratorClass, "next", s"()L$ObjectClass;") << SWAP
+        // list, it, elem, list
+        ch << New(consName) << DUP << DUP2_X2
+        // list, it, cons, cons, elem, list, cons, cons
+        ch << POP << POP
+        // list, it, cons, cons, elem, list
+        
+        if (params.requireMonitor) {
+          ch << ALoad(locals.monitorIndex) << DUP_X2 << POP
+        }
+        ch << InvokeSpecial(consName, constructorName, ccApplySig)
+        // list, it, newList
+        ch << DUP_X2 << POP << SWAP << POP
+        // newList, it
+        ch << Goto(loop)
+        
+        ch << Label(out)
+        // list, it
+        ch << POP
+        // list
+      
       // Static lazy fields/ functions
       case FunctionInvocation(tfd, as) =>
         val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse {
diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala
index e40af01e9..2ffc66cc0 100644
--- a/src/main/scala/leon/datagen/VanuatooDataGen.scala
+++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala
@@ -66,6 +66,15 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
         cs
       })
 
+    case st @ SetType(sub) =>
+      constructors.getOrElse(st, {
+        val cs = for (size <- List(0, 1, 2, 5)) yield {
+          Constructor[Expr, TypeTree]((1 to size).map(i => sub).toList, st, s => FiniteSet(s.toSet).setType(st), st.toString+"@"+size)
+        }
+        constructors += st -> cs
+        cs
+      })
+    
     case tt @ TupleType(parts) =>
       constructors.getOrElse(tt, {
         val cs = List(Constructor[Expr, TypeTree](parts, tt, s => Tuple(s), tt.toString))
diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
index c90f56768..7db4c90ea 100644
--- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
+++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala
@@ -130,6 +130,16 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
         case _ => throw EvalError(typeErrorMsg(first, BooleanType))
       }
 
+    case FunctionInvocation(TypedFunDef(fd, Seq(tp)), Seq(set)) if fd == program.library.setToList.get =>
+      val els = e(set) match {
+        case FiniteSet(els) => els
+        case _ => throw EvalError(typeErrorMsg(set, SetType(tp)))
+      }
+      val cons = program.library.Cons.get
+      val nil = CaseClass(CaseClassType(program.library.Nil.get, Seq(tp)), Seq())
+      def mkCons(h: Expr, t: Expr) = CaseClass(CaseClassType(cons, Seq(tp)), Seq(h,t))
+      els.foldRight(nil)(mkCons)
+      
     case FunctionInvocation(tfd, args) =>
       if (gctx.stepsLeft < 0) {
         throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")")
diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
index 13aefc5b7..67126c0d5 100644
--- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
+++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala
@@ -1607,7 +1607,6 @@ trait CodeExtraction extends ASTExtractors {
             case (IsTyped(a1, SetType(b1)), "contains", List(a2)) =>
               ElementOfSet(a2, a1)
 
-
             // Multiset methods
             case (IsTyped(a1, MultisetType(b1)), "++", List(IsTyped(a2, MultisetType(b2))))  if b1 == b2 =>
               MultisetUnion(a1, a2)
diff --git a/src/main/scala/leon/utils/Library.scala b/src/main/scala/leon/utils/Library.scala
index 7cc8915d7..88ed6e5d8 100644
--- a/src/main/scala/leon/utils/Library.scala
+++ b/src/main/scala/leon/utils/Library.scala
@@ -12,6 +12,8 @@ case class Library(pgm: Program) {
   lazy val Nil  = lookup("leon.collection.Nil") collect { case ccd : CaseClassDef => ccd }
 
   lazy val String = lookup("leon.lang.string.String") collect { case ccd : CaseClassDef => ccd }
+
+  lazy val setToList = lookup("leon.collection.setToList") collect { case fd : FunDef => fd }
   
   def lookup(name: String): Option[Definition] = {
     searchByFullName(name, pgm)
diff --git a/src/test/scala/leon/test/codegen/CodeGenTests.scala b/src/test/scala/leon/test/codegen/CodeGenTests.scala
index 5640cf1bb..fe876c9e5 100644
--- a/src/test/scala/leon/test/codegen/CodeGenTests.scala
+++ b/src/test/scala/leon/test/codegen/CodeGenTests.scala
@@ -420,7 +420,20 @@ class CodeGenTests extends test.LeonTestSuite {
         def test =  sum(l)
       }""",
       IntLiteral(1 + 2 + 3)
+    ),
+    
+    TestCase("SetToList", """
+      import leon.collection._
+      object SetToList {
+        def test = {
+          val s = Set(1, 2, 3, 4, 5)
+          val s2 = setToList(s).content
+          s == s2
+        }
+      }""",
+      BooleanLiteral(true)
     )
+    
   )
   
   
-- 
GitLab