From 96622f3e3c0679ed6308948b07c7d08d1f84dc93 Mon Sep 17 00:00:00 2001
From: Philippe Suter <philippe.suter@gmail.com>
Date: Wed, 23 Jun 2010 15:32:14 +0000
Subject: [PATCH] support for functions with impure bodies

---
 src/funcheck/CodeExtraction.scala | 10 ++++-
 src/funcheck/Extractors.scala     |  7 +++-
 src/funcheck/FunCheckPlugin.scala |  5 ++-
 src/multisets/Main.scala          |  1 -
 src/purescala/Analysis.scala      | 63 ++++++++++++++++++-------------
 src/purescala/Definitions.scala   |  4 +-
 src/purescala/PrettyPrinter.scala |  5 ++-
 src/purescala/Reporter.scala      | 43 ++++++++++++---------
 testcases/IntOperations.scala     | 10 +++++
 9 files changed, 96 insertions(+), 52 deletions(-)

diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala
index 06e25d9d9..6f4a2b8c4 100644
--- a/src/funcheck/CodeExtraction.scala
+++ b/src/funcheck/CodeExtraction.scala
@@ -228,7 +228,13 @@ trait CodeExtraction extends Extractors {
         case _ => ;
       }
       
-      funDef.body = s2ps(realBody)
+      val bodyAttempt = try {
+        Some(scala2PureScala(unit, pluginInstance.silentlyTolerateNonPureBodies)(realBody))
+      } catch {
+        case e: ImpureCodeEncounteredException => None
+      }
+
+      funDef.body = bodyAttempt
       funDef.precondition = reqCont
       funDef.postcondition = ensCont
       funDef
@@ -417,7 +423,7 @@ trait CodeExtraction extends Extractors {
       case _ => {
         if(!silent) {
           println(tr)
-          unit.error(tr.pos, "Could not extract as PureScala.")
+          reporter.info(tr.pos, "Could not extract as PureScala.", true)
         }
         throw ImpureCodeEncounteredException(tree)
       }
diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala
index 5f09c45ff..6236796ef 100644
--- a/src/funcheck/Extractors.scala
+++ b/src/funcheck/Extractors.scala
@@ -44,8 +44,11 @@ trait Extractors {
       /** Extracts the 'require' contract from an expression (only if it's the
        * first call in the block). */
       def unapply(tree: Block): Option[(Tree,Tree)] = tree match {
-        case Block(Apply(ScalaPredef("require"), contractBody :: Nil) :: Nil, body) =>
-          Some((body,contractBody))
+        case Block(Apply(ScalaPredef("require"), contractBody :: Nil) :: rest, body) =>
+          if(rest.isEmpty)
+            Some((body,contractBody))
+          else
+            Some((Block(rest,body),contractBody))
         case _ => None
       }
     }
diff --git a/src/funcheck/FunCheckPlugin.scala b/src/funcheck/FunCheckPlugin.scala
index be396d0a1..5a83136f0 100644
--- a/src/funcheck/FunCheckPlugin.scala
+++ b/src/funcheck/FunCheckPlugin.scala
@@ -13,13 +13,15 @@ class FunCheckPlugin(val global: Global) extends Plugin {
 
   var stopAfterAnalysis: Boolean = true
   var stopAfterExtraction: Boolean = false
+  var silentlyTolerateNonPureBodies: Boolean = false
 
   /** The help message displaying the options for that plugin. */
   override val optionsHelp: Option[String] = Some(
     "  -P:funcheck:uniqid             When pretty-printing funcheck trees, show identifiers IDs" + "\n" +
     "  -P:funcheck:with-code          Allows the compiler to keep running after the static analysis" + "\n" +
     "  -P:funcheck:parse              Checks only whether the program is valid PureScala" + "\n" +
-    "  -P:funcheck:extensions=ex1:... Specifies a list of qualified class names of extensions to be loaded"
+    "  -P:funcheck:extensions=ex1:... Specifies a list of qualified class names of extensions to be loaded" + "\n" +
+    "  -P:funcheck:tolerant           Silently extracts non-pure function bodies as ''unknown''"
   )
 
   /** Processes the command-line options. */
@@ -29,6 +31,7 @@ class FunCheckPlugin(val global: Global) extends Plugin {
         case "with-code" =>                      stopAfterAnalysis = false
         case "uniqid"    =>                      purescala.Settings.showIDs = true
         case "parse"     =>                      stopAfterExtraction = true
+        case "tolerant"  =>                      silentlyTolerateNonPureBodies = true
         case s if s.startsWith("extensions=") => purescala.Settings.extensionNames = s.substring("extensions=".length, s.length)
         case _ => error("Invalid option: " + option)
       }
diff --git a/src/multisets/Main.scala b/src/multisets/Main.scala
index a47962ea3..254b33c46 100644
--- a/src/multisets/Main.scala
+++ b/src/multisets/Main.scala
@@ -6,7 +6,6 @@ import purescala.Trees._
 
 class Main(reporter: Reporter) extends Solver(reporter) {
   val description = "Multiset Solver"
-  println("called!")
 
   def solve(expr: Expr) : Option[Boolean] = {
     reporter.info("Don't know how to solve anything.")
diff --git a/src/purescala/Analysis.scala b/src/purescala/Analysis.scala
index 47cac1a9c..a3dbf7aaf 100644
--- a/src/purescala/Analysis.scala
+++ b/src/purescala/Analysis.scala
@@ -8,10 +8,12 @@ import z3.scala._
 import Extensions._
 
 class Analysis(val program: Program) {
-  val extensions: Seq[Extension] = loadAll(Settings.reporter)
+  val reporter = Settings.reporter
+  val extensions: Seq[Extension] = loadAll(reporter)
+
 
   if(!extensions.isEmpty) {
-    Settings.reporter.info("The following extensions are loaded:\n" + extensions.toList.map(_.description).mkString("  ", ", ", ""))
+    reporter.info("The following extensions are loaded:\n" + extensions.toList.map(_.description).mkString("  ", "\n  ", ""))
   }
 
     // Analysis should check that:
@@ -29,37 +31,46 @@ class Analysis(val program: Program) {
 
         program.mainObject.defs.filter(_.isInstanceOf[FunDef]).foreach(df => {
             val funDef = df.asInstanceOf[FunDef]
+
+            if(funDef.body.isDefined) {
             val vc = postconditionVC(funDef)
-            if(vc != BooleanLiteral(true)) {
-                println("Verification condition (post) for " + funDef.id + ":")
-                println(vc)
-                val (z3f,stupidMap) = toZ3Formula(z3, vc)
-                z3.assertCnstr(z3.mkNot(z3f))
-                //z3.print
-                z3.checkAndGetModel() match {
-                    case (Some(true),m) => {
-                        println("There's a bug! Here's a model for a counter-example:")
-                        m.print
-                    }
-                    case (Some(false),_) => println("Contract satisfied!")
-                    case (None,_) => println("Z3 couldn't run properly or does not know the answer :(")
-                }
+              if(vc != BooleanLiteral(true)) {
+                  reporter.info("Verification condition (post) for " + funDef.id + ":")
+                  reporter.info(vc)
+                  val (z3f,stupidMap) = toZ3Formula(z3, vc)
+                  z3.assertCnstr(z3.mkNot(z3f))
+                  //z3.print
+                  z3.checkAndGetModel() match {
+                      case (Some(true),m) => {
+                          reporter.error("There's a bug! Here's a model for a counter-example:")
+                          m.print
+                      }
+                      case (Some(false),_) => reporter.info("Contract satisfied!")
+                      case (None,_) => reporter.error("Z3 couldn't run properly or does not know the answer :(")
+                  }
+              }
+            } else {
+              if(funDef.postcondition.isDefined) {
+                reporter.warning(funDef, "Could not verify postcondition: function implementation is unknown.")
+              }
             }
         }) 
     }
 
     def postconditionVC(functionDefinition: FunDef) : Expr = {
-        val prec = functionDefinition.precondition
-        val post = functionDefinition.postcondition
+      assert(functionDefinition.body.isDefined)
+      val prec = functionDefinition.precondition
+      val post = functionDefinition.postcondition
+      val body = functionDefinition.body.get
 
-        if(post.isEmpty) {
-            BooleanLiteral(true)
-        } else {
-            if(prec.isEmpty)
-                replaceInExpr(Map(ResultVariable() -> functionDefinition.body), post.get)
-            else
-                Implies(prec.get, replaceInExpr(Map(ResultVariable() -> functionDefinition.body), post.get))
-        }
+      if(post.isEmpty) {
+        BooleanLiteral(true)
+      } else {
+        if(prec.isEmpty)
+          replaceInExpr(Map(ResultVariable() -> body), post.get)
+        else
+          Implies(prec.get, replaceInExpr(Map(ResultVariable() -> body), post.get))
+      }
     }
 
     def flatten(expr: Expr) : (Expr,List[(Variable,Expr)]) = {
diff --git a/src/purescala/Definitions.scala b/src/purescala/Definitions.scala
index 837478e25..42e2807c5 100644
--- a/src/purescala/Definitions.scala
+++ b/src/purescala/Definitions.scala
@@ -74,7 +74,7 @@ object Definitions {
 
   /** Functions (= 'methods' of objects) */
   object FunDef {
-    def unapply(fd: FunDef): Option[(Identifier,TypeTree,VarDecls,Expr,Option[Expr],Option[Expr])] = {
+    def unapply(fd: FunDef): Option[(Identifier,TypeTree,VarDecls,Option[Expr],Option[Expr],Option[Expr])] = {
       if(fd != null) {
         Some((fd.id, fd.returnType, fd.args, fd.body, fd.precondition, fd.postcondition))
       } else {
@@ -83,7 +83,7 @@ object Definitions {
     }
   }
   class FunDef(val id: Identifier, val returnType: TypeTree, val args: VarDecls) extends Definition {
-    var body: Expr = _
+    var body: Option[Expr] = None
     var precondition: Option[Expr] = None
     var postcondition: Option[Expr] = None
   }
diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala
index 479fdf184..c7529cfdf 100644
--- a/src/purescala/PrettyPrinter.scala
+++ b/src/purescala/PrettyPrinter.scala
@@ -289,7 +289,10 @@ object PrettyPrinter {
         nsb.append(") : ")
         nsb = pp(rt, nsb, lvl)
         nsb.append(" = ")
-        pp(body, nsb, lvl)
+        if(body.isDefined)
+          pp(body.get, nsb, lvl)
+        else
+          nsb.append("[unknown function implementation]")
       }
 
       case _ => sb.append("Defn?")
diff --git a/src/purescala/Reporter.scala b/src/purescala/Reporter.scala
index 194409ff4..1836431fa 100644
--- a/src/purescala/Reporter.scala
+++ b/src/purescala/Reporter.scala
@@ -21,26 +21,35 @@ abstract class Reporter {
 }
 
 object DefaultReporter extends Reporter {
-  private val errorPfx = "Error: "
-  private val warningPfx = "Warning: "
-  private val infoPfx = "Info: "
-  private val fatalPfx = "Fatal error: "
+  private val errorPfx   = "[ Error ] "
+  private val warningPfx = "[Warning] "
+  private val infoPfx    = "[ Info  ] "
+  private val fatalPfx   = "[ Fatal ] "
 
   def output(msg: String) : Unit = {
     Console.err.println(msg)
-    Console.err.println("")
   }
 
-  def error(msg: Any) = output(errorPfx + msg.toString)
-  def warning(msg: Any) = output(warningPfx + msg.toString)
-  def info(msg: Any) = output(infoPfx + msg.toString)
-  def fatalError(msg: Any) = { output(fatalPfx + msg.toString); exit(0) }
-  def error(definition: Definition, msg: Any) = output(errorPfx + "\n" + PrettyPrinter(definition) + msg.toString)
-  def warning(definition: Definition, msg: Any) = output(warningPfx + "\n" + PrettyPrinter(definition) + msg.toString)
-  def info(definition: Definition, msg: Any) = output(infoPfx + "\n" + PrettyPrinter(definition) + msg.toString)
-  def fatalError(definition: Definition, msg: Any) = { output(fatalPfx + "\n" + PrettyPrinter(definition) + msg.toString); exit(0) }
-  def error(expr: Expr, msg: Any) = output(errorPfx + "\n" + PrettyPrinter(expr) + msg.toString) 
-  def warning(expr: Expr, msg: Any) = output(warningPfx + "\n" + PrettyPrinter(expr) + msg.toString) 
-  def info(expr: Expr, msg: Any) = output(infoPfx + "\n" + PrettyPrinter(expr) + msg.toString) 
-  def fatalError(expr: Expr, msg: Any) = { output(fatalPfx + "\n" + PrettyPrinter(expr) + msg.toString); exit(0) }
+  private def reline(pfx: String, msg: String) : String = {
+    val color = if(pfx == errorPfx || pfx == warningPfx || pfx == fatalPfx) {
+      Console.RED
+    } else {
+      Console.BLUE
+    }
+    "[" + color + pfx.substring(1, pfx.length-2) + Console.RESET + "] " +
+    msg.trim.replaceAll("\n", "\n" + pfx)
+  }
+
+  def error(msg: Any) = output(reline(errorPfx, msg.toString))
+  def warning(msg: Any) = output(reline(warningPfx, msg.toString))
+  def info(msg: Any) = output(reline(infoPfx, msg.toString))
+  def fatalError(msg: Any) = { output(reline(fatalPfx, msg.toString)); exit(0) }
+  def error(definition: Definition, msg: Any) = output(reline(errorPfx, PrettyPrinter(definition) + "\n" + msg.toString))
+  def warning(definition: Definition, msg: Any) = output(reline(warningPfx, PrettyPrinter(definition) + "\n" + msg.toString))
+  def info(definition: Definition, msg: Any) = output(reline(infoPfx, PrettyPrinter(definition) + "\n" + msg.toString))
+  def fatalError(definition: Definition, msg: Any) = { output(reline(fatalPfx, PrettyPrinter(definition) + "\n" + msg.toString)); exit(0) }
+  def error(expr: Expr, msg: Any) = output(reline(errorPfx, PrettyPrinter(expr) + "\n" + msg.toString)) 
+  def warning(expr: Expr, msg: Any) = output(reline(warningPfx, PrettyPrinter(expr) + "\n" + msg.toString))
+  def info(expr: Expr, msg: Any) = output(reline(infoPfx, PrettyPrinter(expr) + "\n" + msg.toString))
+  def fatalError(expr: Expr, msg: Any) = { output(reline(fatalPfx, PrettyPrinter(expr) + "\n" + msg.toString)); exit(0) }
 }
diff --git a/testcases/IntOperations.scala b/testcases/IntOperations.scala
index 0c79ac4cc..11903892f 100644
--- a/testcases/IntOperations.scala
+++ b/testcases/IntOperations.scala
@@ -3,4 +3,14 @@ object IntOperations {
         require(b >= 0)
         a + b
     } ensuring(_ >= a)
+
+    def factorial(v: Int) : Int = ({
+      require(v >= 0)
+      var c = 2
+      var t = 1
+      while(c <= v) {
+        t = t * c
+      }
+      t
+    }) ensuring(_ >= v)
 }
-- 
GitLab