From 1900e8dd48bd53adadfbbe822c1dc5ef5a20bd23 Mon Sep 17 00:00:00 2001
From: Etienne Kneuss <colder@php.net>
Date: Thu, 4 Apr 2013 15:32:02 +0200
Subject: [PATCH] Implement untrusted solutions, that are then validated

---
 .../scala/leon/synthesis/ParallelSearch.scala |   6 +-
 src/main/scala/leon/synthesis/Rules.scala     |   6 +-
 .../scala/leon/synthesis/SimpleSearch.scala   | 107 +++++++++++++++++-
 .../scala/leon/synthesis/Synthesizer.scala    |  60 +++++++++-
 .../leon/synthesis/search/AndOrGraph.scala    |   4 +
 .../search/AndOrGraphParallelSearch.scala     |   4 +-
 .../search/AndOrGraphPartialSolution.scala    |   6 +-
 .../synthesis/search/AndOrGraphSearch.scala   |  10 +-
 8 files changed, 179 insertions(+), 24 deletions(-)

diff --git a/src/main/scala/leon/synthesis/ParallelSearch.scala b/src/main/scala/leon/synthesis/ParallelSearch.scala
index ca4346ade..dbdbe0517 100644
--- a/src/main/scala/leon/synthesis/ParallelSearch.scala
+++ b/src/main/scala/leon/synthesis/ParallelSearch.scala
@@ -49,13 +49,13 @@ class ParallelSearch(synth: Synthesizer,
     val prefix = "[%-20s] ".format(Option(t.rule).getOrElse("?"))
 
     t.app.apply(sctx) match {
-      case RuleSuccess(sol) =>
+      case RuleSuccess(sol, isTrusted) =>
         synth.synchronized {
           info(prefix+"Got: "+t.problem)
-          info(prefix+"Solved with: "+sol)
+          info(prefix+"Solved"+(if(isTrusted) "" else " (untrusted)")+" with: "+sol)
         }
 
-        ExpandSuccess(sol)
+        ExpandSuccess(sol, isTrusted)
       case RuleDecomposed(sub) =>
         synth.synchronized {
           info(prefix+"Got: "+t.problem)
diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala
index 2f51522c7..2357241c0 100644
--- a/src/main/scala/leon/synthesis/Rules.scala
+++ b/src/main/scala/leon/synthesis/Rules.scala
@@ -57,9 +57,9 @@ abstract class RuleInstantiation(val problem: Problem, val rule: Rule, val onSuc
 }
 
 sealed abstract class RuleApplicationResult
-case class RuleSuccess(solution: Solution)    extends RuleApplicationResult
-case class RuleDecomposed(sub: List[Problem]) extends RuleApplicationResult
-case object RuleApplicationImpossible         extends RuleApplicationResult
+case class RuleSuccess(solution: Solution, isTrusted: Boolean = true)  extends RuleApplicationResult
+case class RuleDecomposed(sub: List[Problem])                          extends RuleApplicationResult
+case object RuleApplicationImpossible                                  extends RuleApplicationResult
 
 object RuleInstantiation {
   def immediateDecomp(problem: Problem, rule: Rule, sub: List[Problem], onSuccess: List[Solution] => Option[Solution], description: String) = {
diff --git a/src/main/scala/leon/synthesis/SimpleSearch.scala b/src/main/scala/leon/synthesis/SimpleSearch.scala
index 06fa7da53..d042fc2e1 100644
--- a/src/main/scala/leon/synthesis/SimpleSearch.scala
+++ b/src/main/scala/leon/synthesis/SimpleSearch.scala
@@ -1,6 +1,8 @@
 package leon
 package synthesis
 
+import purescala.Definitions.FunDef
+
 import synthesis.search._
 
 case class TaskRunRule(app: RuleInstantiation) extends AOAndTask[Solution] {
@@ -48,10 +50,10 @@ class SimpleSearch(synth: Synthesizer,
 
     info(prefix+"Got: "+t.problem)
     t.app.apply(sctx) match {
-      case RuleSuccess(sol) =>
-        info(prefix+"Solved with: "+sol)
+      case RuleSuccess(sol, isTrusted) =>
+        info(prefix+"Solved"+(if(isTrusted) "" else " (untrusted)")+" with: "+sol)
 
-        ExpandSuccess(sol)
+        ExpandSuccess(sol, isTrusted)
       case RuleDecomposed(sub) =>
         info(prefix+"Decomposed into:")
         for(p <- sub) {
@@ -87,6 +89,101 @@ class SimpleSearch(synth: Synthesizer,
     }
   }
 
+  case class SubProgram(p: Problem, fd: FunDef, acc: Set[FunDef])
+
+  def programAt(n: g.Tree): SubProgram = {
+    import purescala.TypeTrees._
+    import purescala.Common._
+    import purescala.TreeOps.replace
+    import purescala.Trees._
+    import purescala.Definitions._
+
+    def programFrom(from: g.AndNode, sp: SubProgram): SubProgram = {
+      if (from.parent.parent eq null) {
+        sp
+      } else {
+        val at = from.parent.parent
+        val res = bestProgramForAnd(at, Map(from.parent -> sp))
+        programFrom(at, res)
+      }
+    }
+
+    def bestProgramForOr(on: g.OrTree): SubProgram = {
+      val problem: Problem = on.task.p
+
+      val fd = problemToFunDef(problem)
+
+      SubProgram(problem, fd, Set(fd))
+    }
+
+    def fundefToSol(p: Problem, fd: FunDef): Solution = {
+      Solution(BooleanLiteral(true), Set(), FunctionInvocation(fd, p.as.map(Variable(_))))
+    }
+
+    def solToSubProgram(p: Problem, s: Solution): SubProgram = {
+      val fd = problemToFunDef(p)
+      fd.precondition = Some(s.pre)
+      fd.body         = Some(s.term)
+
+      SubProgram(p, fd, Set(fd))
+    }
+
+    def bestProgramForAnd(an: g.AndNode, subPrograms: Map[g.OrTree, SubProgram]): SubProgram = {
+      val subSubPrograms = an.subTasks.map(an.subProblems).map( ot =>
+        subPrograms.getOrElse(ot, bestProgramForOr(ot))
+      )
+
+      val allFd = subSubPrograms.flatMap(_.acc)
+      val subSolutions = subSubPrograms.map(ssp => fundefToSol(ssp.p, ssp.fd))
+
+      val sp = solToSubProgram(an.task.problem, an.task.composeSolution(subSolutions).get)
+
+      sp.copy(acc = sp.acc ++ allFd)
+    }
+
+    def problemToFunDef(p: Problem): FunDef = {
+
+      val ret = if (p.xs.size == 1) {
+        p.xs.head.getType
+      } else {
+        TupleType(p.xs.map(_.getType))
+      }
+
+      val freshAs = p.as.map(_.freshen)
+
+      val map = (p.as.map(Variable(_): Expr) zip freshAs.map(Variable(_): Expr)).toMap
+
+      val res = ResultVariable().setType(ret)
+
+      val mapPost: Map[Expr, Expr] = if (p.xs.size > 1) {
+        p.xs.zipWithIndex.map{ case (id, i)  =>
+          Variable(id) -> TupleSelect(res, i+1)
+        }.toMap
+      } else {
+        Map(Variable(p.xs.head) -> ResultVariable().setType(ret))
+      }
+
+      val fd = new FunDef(FreshIdentifier("chimp", true), ret, freshAs.map(id => VarDecl(id, id.getType)))
+      fd.precondition = Some(replace(map, p.pc))
+      fd.postcondition = Some(replace(map++mapPost, p.phi))
+
+      fd
+    }
+
+    n match {
+      case an: g.AndNode =>
+        programFrom(an, bestProgramForAnd(an, Map.empty))
+
+      case on: g.OrNode =>
+        if (on.parent ne null) {
+          programAt(on.parent)
+        } else {
+          bestProgramForOr(on)
+        }
+    }
+  }
+
+
   var shouldStop = false
 
   def searchStep() {
@@ -105,7 +202,7 @@ class SimpleSearch(synth: Synthesizer,
     }
   }
 
-  def search(): Option[Solution] = {
+  def search(): Option[(Solution, Boolean)] = {
     sctx.solver.init()
 
     shouldStop = false
@@ -113,7 +210,7 @@ class SimpleSearch(synth: Synthesizer,
     while (!g.tree.isSolved && !shouldStop) {
       searchStep()
     }
-    g.tree.solution
+    g.tree.solution.map(s => (s, g.tree.isTrustworthy))
   }
 
   override def stop() {
diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala
index ac2acc6f5..36cffecf0 100644
--- a/src/main/scala/leon/synthesis/Synthesizer.scala
+++ b/src/main/scala/leon/synthesis/Synthesizer.scala
@@ -4,9 +4,10 @@ package synthesis
 import purescala.Common._
 import purescala.Definitions.{Program, FunDef}
 import purescala.TreeOps._
-import purescala.Trees.{Expr, Not}
+import purescala.Trees._
 import purescala.ScalaPrinter
 import solvers.z3._
+import solvers.TimeoutSolver
 import sun.misc.{Signal, SignalHandler}
 
 import solvers.Solver
@@ -24,7 +25,8 @@ class Synthesizer(val context : LeonContext,
                   val problem: Problem,
                   val options: SynthesisOptions) {
 
-  val silentContext = context.copy(reporter = new SilentReporter)
+  val silentReporter = new SilentReporter
+  val silentContext = context.copy(reporter = silentReporter)
 
   val rules: Seq[Rule] = options.rules
 
@@ -75,10 +77,60 @@ class Synthesizer(val context : LeonContext,
     }
 
     res match {
-      case Some(solution) =>
+      case Some((solution, true)) =>
+        val ssol = solution.toSimplifiedExpr(context, program)
         (solution, true)
+      case Some((sol, false)) =>
+        val ssol = sol.toSimplifiedExpr(context, program)
+        reporter.info("Solution requires validation")
+
+        val (npr, fds) = solutionToProgram(sol)
+
+        val tsolver = new TimeoutSolver(new FairZ3Solver(silentContext), 10000L)
+        tsolver.setProgram(npr)
+
+        import verification.AnalysisPhase._
+        val vcs = generateVerificationConditions(reporter, npr, fds.map(_.id.name))
+        val vcreport = checkVerificationConditions(silentReporter, Seq(tsolver), vcs)
+
+        if (vcreport.totalValid == vcreport.totalConditions) {
+          (sol, true)
+        } else {
+          reporter.warning("Solution was invalid:")
+          reporter.warning(fds.map(ScalaPrinter(_)).mkString("\n\n"))
+          reporter.warning(vcreport.summaryString)
+          (new AndOrGraphPartialSolution(search.g, (task: TaskRunRule) => Solution.choose(task.problem), false).getSolution, false)
+        }
       case None =>
-        (new AndOrGraphPartialSolution(search.g, (task: TaskRunRule) => Solution.choose(task.problem)).getSolution, false)
+        (new AndOrGraphPartialSolution(search.g, (task: TaskRunRule) => Solution.choose(task.problem), true).getSolution, false)
     }
   }
+
+  // Returns the new program and the new functions generated for this
+  def solutionToProgram(sol: Solution): (Program, Set[FunDef]) = {
+    import purescala.TypeTrees.TupleType
+    import purescala.Definitions.VarDecl
+
+    val mainObject = program.mainObject
+
+    // Create new fundef for the body
+    val ret = TupleType(problem.xs.map(_.getType))
+    val res = ResultVariable().setType(ret)
+
+    val mapPost: Map[Expr, Expr] =
+      problem.xs.zipWithIndex.map{ case (id, i)  =>
+        Variable(id) -> TupleSelect(res, i+1)
+      }.toMap
+
+    val fd = new FunDef(FreshIdentifier("finalTerm", true), ret, problem.as.map(id => VarDecl(id, id.getType)))
+    fd.precondition  = Some(And(problem.pc, sol.pre))
+    fd.postcondition = Some(replace(mapPost, problem.phi))
+    fd.body          = Some(sol.term)
+
+    val newDefs = sol.defs + fd
+
+    val npr = program.copy(mainObject = mainObject.copy(defs = mainObject.defs ++ newDefs))
+
+    (npr, newDefs)
+  }
 }
diff --git a/src/main/scala/leon/synthesis/search/AndOrGraph.scala b/src/main/scala/leon/synthesis/search/AndOrGraph.scala
index a289e40ca..50a7ab859 100644
--- a/src/main/scala/leon/synthesis/search/AndOrGraph.scala
+++ b/src/main/scala/leon/synthesis/search/AndOrGraph.scala
@@ -24,6 +24,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
 
     def minCost: Cost
 
+    var isTrustworthy: Boolean = true
     var solution: Option[S] = None
     var isUnsolvable: Boolean = false
 
@@ -86,6 +87,7 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
       if (subSolutions.size == subProblems.size) {
         task.composeSolution(subTasks.map(subSolutions)) match {
           case Some(sol) =>
+            isTrustworthy = subProblems.forall(_._2.isTrustworthy)
             solution = Some(sol)
             updateMin()
 
@@ -174,12 +176,14 @@ class AndOrGraph[AT <: AOAndTask[S], OT <: AOOrTask[S], S](val root: OT, val cos
     def notifySolution(sub: AndTree, sol: S) {
       solution match {
         case Some(preSol) if (costModel.solutionCost(preSol) < costModel.solutionCost(sol)) =>
+          isTrustworthy  = sub.isTrustworthy
           solution       = Some(sol)
           minAlternative = sub
 
           notifyParent(solution.get)
 
         case None =>
+          isTrustworthy  = sub.isTrustworthy
           solution       = Some(sol)
           minAlternative = sub
 
diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala
index 88f378a8e..4ff625966 100644
--- a/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala
+++ b/src/main/scala/leon/synthesis/search/AndOrGraphParallelSearch.scala
@@ -18,7 +18,7 @@ abstract class AndOrGraphParallelSearch[WC,
 
   var system: ActorSystem = _
 
-  def search(): Option[S] = {
+  def search(): Option[(S, Boolean)] = {
     system = ActorSystem("ParallelSearch")
 
     val master = system.actorOf(Props(new Master), name = "Master")
@@ -38,7 +38,7 @@ abstract class AndOrGraphParallelSearch[WC,
       system = null
     }
 
-    g.tree.solution
+    g.tree.solution.map(s => (s, g.tree.isTrustworthy))
   }
 
   override def stop() {
diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala b/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala
index d4d38dd77..3f9fbdb68 100644
--- a/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala
+++ b/src/main/scala/leon/synthesis/search/AndOrGraphPartialSolution.scala
@@ -2,7 +2,7 @@ package leon.synthesis.search
 
 class AndOrGraphPartialSolution[AT <: AOAndTask[S],
                                 OT <: AOOrTask[S],
-                                S](val g: AndOrGraph[AT, OT, S], missing: AT => S) {
+                                S](val g: AndOrGraph[AT, OT, S], missing: AT => S, includeUntrusted: Boolean) {
 
 
   def getSolution: S = {
@@ -10,7 +10,7 @@ class AndOrGraphPartialSolution[AT <: AOAndTask[S],
   }
 
   def solveAnd(t: g.AndTree): S = {
-    if (t.isSolved) {
+    if (t.isSolved && (includeUntrusted || t.isTrustworthy)) {
       t.solution.get
     } else {
       t match {
@@ -23,7 +23,7 @@ class AndOrGraphPartialSolution[AT <: AOAndTask[S],
   }
 
   def solveOr(t: g.OrTree): S = {
-    if (t.isSolved) {
+    if (t.isSolved && (includeUntrusted || t.isTrustworthy)) {
       t.solution.get
     } else {
       t match {
diff --git a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala
index d4f5e9416..00593edef 100644
--- a/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala
+++ b/src/main/scala/leon/synthesis/search/AndOrGraphSearch.scala
@@ -57,20 +57,21 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S],
 
   abstract class ExpandResult[T <: AOTask[S]]
   case class Expanded[T <: AOTask[S]](sub: List[T]) extends ExpandResult[T]
-  case class ExpandSuccess[T <: AOTask[S]](sol: S) extends ExpandResult[T]
+  case class ExpandSuccess[T <: AOTask[S]](sol: S, isTrustworthy: Boolean) extends ExpandResult[T]
   case class ExpandFailure[T <: AOTask[S]]() extends ExpandResult[T]
 
   def stop() {
 
   }
 
-  def search(): Option[S]
+  def search(): Option[(S, Boolean)]
 
   def onExpansion(al: g.AndLeaf, res: ExpandResult[OT]) {
     res match {
       case Expanded(ls) =>
         al.expandWith(ls)
-      case r @ ExpandSuccess(sol) =>
+      case r @ ExpandSuccess(sol, isTrustworthy) =>
+        al.isTrustworthy = isTrustworthy
         al.solution = Some(sol)
         al.parent.notifySolution(al, sol)
       case _ =>
@@ -89,7 +90,8 @@ abstract class AndOrGraphSearch[AT <: AOAndTask[S],
     res match {
       case Expanded(ls) =>
         ol.expandWith(ls)
-      case r @ ExpandSuccess(sol) =>
+      case r @ ExpandSuccess(sol, isTrustworthy) =>
+        ol.isTrustworthy = isTrustworthy
         ol.solution = Some(sol)
         ol.parent.notifySolution(ol, sol)
       case _ =>
-- 
GitLab