diff --git a/src/main/scala/leon/codegen/CodeGenPhase.scala b/src/main/scala/leon/codegen/CodeGenPhase.scala index 12f31cf2f86f72de964002c3e6c6d4bccccea9de..2d9abd0b2a403ffed57480b8cb6dbdb4eefcda4d 100644 --- a/src/main/scala/leon/codegen/CodeGenPhase.scala +++ b/src/main/scala/leon/codegen/CodeGenPhase.scala @@ -19,7 +19,7 @@ object CodeGenPhase extends LeonPhase[Program,CompilationResult] { def run(ctx : LeonContext)(p : Program) : CompilationResult = { try { val unit = new CompilationUnit(ctx, p); - unit.writeClassFiles() + unit.writeClassFiles("./") CompilationResult(successful = true) } catch { case NonFatal(e) => CompilationResult(successful = false) diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 44140ba7a6f2f5996e9e0a81b2ed2c43aa0e1f84..0540745f193243f55b0330749cf757f3cba94933 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -675,7 +675,10 @@ trait CodeGeneration { case rh: RepairHole => mkExpr(simplestValue(rh.getType), ch) // It is expected to be invalid, we want to repair it - case choose @ Choose(_, _) => + case Choose(_, _, Some(e)) => + mkExpr(e, ch) + + case choose @ Choose(_, _, None) => val prob = synthesis.Problem.fromChoose(choose) val id = runtime.ChooseEntryPoint.register(prob, this); @@ -702,10 +705,12 @@ trait CodeGeneration { ch << Ldc(id) ch << InvokeStatic(GenericValuesClass, "get", "(I)Ljava/lang/Object;") - case NoTree( tp@(Int32Type | BooleanType | UnitType | CharType)) => + case nt @ NoTree( tp@(Int32Type | BooleanType | UnitType | CharType)) => + println("COMPILING "+nt+" TO "+simplestValue(tp)) mkExpr(simplestValue(tp), ch) - case NoTree(_) => + case nt @ NoTree(_) => + println("COMPILING "+nt+" TO NULL") ch << ACONST_NULL case This(ct) => diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index a29175fe11618832b18127cf9c0d2e21f9e0c762..a1fc0c4fe6de2fedacdebe0383fd6985eef6c854 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -414,9 +414,9 @@ class CompilationUnit(val ctx: LeonContext, classes.values.foreach(loader.register _) } - def writeClassFiles() { + def writeClassFiles(prefix: String) { for ((d, cl) <- classes) { - cl.writeToFile(cl.className + ".class") + cl.writeToFile(prefix+cl.className + ".class") } } diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index 5d1756153f9f4d480e5472bc4ad4094772e1366b..851731bb926f603d187acb3a3154a54d273735d5 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -22,6 +22,7 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) exte var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) val unit = new CompilationUnit(ctx, prog, params) + val isCompiled = prog.definedFunctions.toSet case class DefaultRecContext(mappings: Map[Identifier, Expr], needJVMRef: Boolean = false) extends RecContext { diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 68ce7c5c80f98b93db1281c488d701c1c808927c..747a5d52b6a62a8cfa002537b4aa2b71ea968b93 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -340,7 +340,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val rDefault = e(default) val rLength = e(length) val IntLiteral(iLength) = rLength - FiniteArray((1 to iLength).map(_ => rDefault).toSeq) + FiniteArray((1 to iLength).map(_ => rDefault).toSeq).setType(ArrayType(rDefault.getType)) case ArrayLength(a) => var ra = e(a) @@ -355,7 +355,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val IntLiteral(index) = ri val FiniteArray(exprs) = ra - FiniteArray(exprs.updated(index, rv)) + FiniteArray(exprs.updated(index, rv)).setType(ra.getType) case ArraySelect(a, i) => val IntLiteral(index) = e(i) @@ -367,7 +367,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } case FiniteArray(exprs) => - FiniteArray(exprs.map(ex => e(ex))) + FiniteArray(exprs.map(ex => e(ex))).setType(expr.getType) case f @ FiniteMap(ss) => FiniteMap(ss.map{ case (k, v) => (e(k), e(v)) }.distinct).setType(f.getType) case g @ MapGet(m,k) => (e(m), e(k)) match { @@ -402,7 +402,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case p : Passes => e(p.asConstraint) - case choose: Choose => + case choose @ Choose(_, _, Some(impl)) => + e(impl) + + case choose @ Choose(_, _, None) => import purescala.TreeOps.simplestValue implicit val debugSection = utils.DebugSectionSynthesis diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 19c2daf00af172190286c5f914c32d4009a5e7d7..f55c5d548ac53884951e735d0eaf037b328a04e5 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -24,7 +24,7 @@ import purescala.Extractors._ import purescala.Constructors._ import purescala.TreeOps._ import purescala.TypeTreeOps._ -import purescala.DefOps.{inPackage, inUnit} +import purescala.DefOps.packageOf import xlang.Trees.{Block => LeonBlock, _} import xlang.TreeOps._ @@ -292,7 +292,7 @@ trait CodeExtraction extends ASTExtractors { val _ = Program(FreshIdentifier("__program"), withoutImports map { _._1 }) // With the aid of formed units, we extract the correct imports - val objPerPack = objects2Objects map { _._2 } groupBy { inPackage(_)} + val objPerPack = objects2Objects map { _._2 } groupBy { packageOf(_)} withoutImports map { case (u, imps) => u.copy(imports = { // Extract imports from source diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index 73bc8c16c3fb30574f178c22cdf70b43323bde18..1b1dfbf06ae56e489691a088c79b3597c172ffe9 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -23,7 +23,7 @@ object Constructors { case Nil => body case x :: Nil => - if (value.getType == x.getType) { + if (value.getType == x.getType || !value.getType.isInstanceOf[TupleType]) { // This is for cases where we build it like: letTuple(List(x), tupleWrap(List(z))) Let(x, value, body) } else { diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 36448d7c7a2e663f9771e4eddae0132a2005bab7..c8ea48d591135d0b4f45bb12380857f125d522d5 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -2,58 +2,59 @@ package leon.purescala import Common._ import Definitions._ +import Trees._ object DefOps { - - def inPackage(df : Definition) : PackageRef = { + + def packageOf(df: Definition): PackageRef = { df match { case _ : Program => List() case u : UnitDef => u.pack - case _ => df.owner map inPackage getOrElse List() + case _ => df.owner map packageOf getOrElse List() } } - - def inUnit(df : Definition) : Option[UnitDef] = df match { + + def unitOf(df: Definition): Option[UnitDef] = df match { case p : Program => None case u : UnitDef => Some(u) - case other => other.owner flatMap inUnit + case other => other.owner flatMap unitOf } - def inModule(df : Definition) : Option[ModuleDef] = df match { + def moduleOf(df: Definition): Option[ModuleDef] = df match { case p : Program => None case m : ModuleDef => Some(m) - case other => other.owner flatMap inModule + case other => other.owner flatMap moduleOf } - - def inProgram(df : Definition) : Option[Program] = { + + def programOf(df: Definition): Option[Program] = { df match { case p : Program => Some(p) - case other => other.owner flatMap inProgram + case other => other.owner flatMap programOf } } - - def pathFromRoot (df : Definition): List[Definition] ={ + + def pathFromRoot(df: Definition): List[Definition] ={ def rec(df : Definition) : List[Definition] = df.owner match { case Some(owner) => df :: rec(owner) case None => List(df) } rec(df).reverse } - + def unitsInPackage(p: Program, pack : PackageRef) = p.units filter { _.pack == pack } - + def isImportedBy(df : Definition, i : Import) : Boolean = i.importedDefs contains df - + def isImportedBy(df : Definition, is: Seq[Import]) : Boolean = is exists {isImportedBy(df,_)} - + def leastCommonAncestor(df1 : Definition, df2 : Definition) : Definition = { (pathFromRoot(df1) zip pathFromRoot(df2)) .takeWhile{case (df1,df2) => df1 eq df2} .last._1 } - + /** Returns the set of definitions directly visible from the current definition * Definitions that are shadowed by others are not returned. @@ -62,13 +63,13 @@ object DefOps { var toRet = Map[String,Definition]() val asList = (pathFromRoot(df).reverse flatMap { _.subDefinitions }) ++ { - inProgram(df) match { + programOf(df) match { case None => List() - case Some(p) => unitsInPackage(p, inPackage(df)) flatMap { _.subDefinitions } + case Some(p) => unitsInPackage(p, packageOf(df)) flatMap { _.subDefinitions } } } ++ - inProgram(df).toList ++ - ( for ( u <- inUnit(df).toSeq; + programOf(df).toList ++ + ( for ( u <- unitOf(df).toSeq; imp <- u.imports; impDf <- imp.importedDefs ) yield impDf @@ -107,7 +108,7 @@ object DefOps { def packageAsVisibleFrom(df : Definition, p : PackageRef) = { val visiblePacks = - inPackage(df) +: (inUnit(df).toSeq.flatMap(_.imports) collect { case PackageImport(pack) => pack }) + packageOf(df) +: (unitOf(df).toSeq.flatMap(_.imports) collect { case PackageImport(pack) => pack }) val bestSuper = visiblePacks filter { pack => pack == p || isSuperPackageOf(pack,p)} match { case Nil => Nil case other => other maxBy { _.length } @@ -121,14 +122,14 @@ object DefOps { val ancestor = leastCommonAncestor(base, target) val pth = rootPath dropWhile { _.owner != Some(ancestor) } val pathFromAncestor = if (pth.isEmpty) List(target) else pth - val index = rootPath lastIndexWhere { isImportedBy(_,inUnit(base).toSeq.flatMap { _.imports }) } + val index = rootPath lastIndexWhere { isImportedBy(_, unitOf(base).toSeq.flatMap { _.imports }) } val pathFromImport = rootPath drop scala.math.max(index, 0) val finalPath = if (pathFromAncestor.length <= pathFromImport.length) pathFromAncestor else pathFromImport assert(!finalPath.isEmpty) // Package val pack = if (finalPath.head.isInstanceOf[UnitDef]) { - packageAsVisibleFrom(base, inPackage(target)) + packageAsVisibleFrom(base, packageOf(target)) } else Nil @@ -136,7 +137,7 @@ object DefOps { } def fullName(df: Definition, fromProgram: Option[Program] = None): String = - fromProgram orElse inProgram(df) match { + fromProgram orElse programOf(df) match { case None => df.id.name case Some(p) => val (pr, ds) = pathAsVisibleFrom(p, df) @@ -163,7 +164,7 @@ object DefOps { exploreStandalones : Boolean = true // Unset this if your path already includes standalone object names ) : Option[Definition] = { - require(inProgram(base).isDefined) + require(programOf(base).isDefined) val fullNameList = fullName.split("\\.").toList map scala.reflect.NameTransformer.encode require(!fullNameList.isEmpty) @@ -207,8 +208,8 @@ object DefOps { df <- descendDefs(startingPoint,path) ) yield df ) orElse { - val program = inProgram(base).get - val currentPack = inPackage(base) + val program = programOf(base).get + val currentPack = packageOf(base) val knownPacks = program.units map { _.pack } // The correct package has the maximum identifiers @@ -239,7 +240,7 @@ object DefOps { else { val point = program.modules find { mod => mod.id.toString == objectPart.head && - inPackage(mod) == packagePart + packageOf(mod) == packagePart } orElse { onCondition (exploreStandalones) { // Search in standalone objects @@ -247,7 +248,7 @@ object DefOps { case ModuleDef(_,subDefs,true) => subDefs }.flatten.find { df => df.id.toString == objectPart.head && - inPackage(df) == packagePart + packageOf(df) == packagePart } } } @@ -291,6 +292,100 @@ object DefOps { def postMapOnFunDef(repl : Expr => Option[Expr], applyRec : Boolean = false )(funDef : FunDef) : FunDef = { applyOnFunDef(postMap(repl, applyRec))(funDef) } - + + private def defaultFiMap(fi: FunctionInvocation, nfd: FunDef): Option[FunctionInvocation] = (fi, nfd) match { + case (FunctionInvocation(old, args), newfd) if old.fd != newfd => + Some(FunctionInvocation(newfd.typed(old.tps), args)) + case _ => + None + } + + def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], + fiMapF: (FunctionInvocation, FunDef) => Option[FunctionInvocation] = defaultFiMap) = { + + var fdMapCache = Map[FunDef, Option[FunDef]]() + def fdMap(fd: FunDef): FunDef = { + if (!(fdMapCache contains fd)) { + fdMapCache += fd -> fdMapF(fd) + } + + fdMapCache(fd).getOrElse(fd) + } + + def replaceCalls(e: Expr): Expr = { + preMap { + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + fiMapF(fi, fdMap(fd)).map(_.setPos(fi)) + case _ => + None + }(e) + } + + val newP = p.copy(units = for (u <- p.units) yield { + u.copy( + modules = for (m <- u.modules) yield { + m.copy(defs = for (df <- m.defs) yield { + df match { + case f : FunDef => + val newF = fdMap(f) + newF.fullBody = replaceCalls(newF.fullBody) + newF + case c : ClassDef => + // val oldMethods = c.methods + // c.clearMethods() + // for (m <- oldMethods) { + // c.registerMethod(functionToFunction.get(m).map{_.to}.getOrElse(m)) + // } + c + case d => + d + } + }) + }, + imports = u.imports map { + case SingleImport(fd : FunDef) => + SingleImport(fdMap(fd)) + case other => other + } + ) + }) + + (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) + } + + def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = { + var found = false + val res = p.copy(units = for (u <- p.units) yield { + u.copy( + modules = for (m <- u.modules) yield { + var newdefs = for (df <- m.defs) yield { + df match { + case `after` => + found = true + after +: fds.toSeq + case d => + Seq(d) + } + } + + m.copy(defs = newdefs.flatten) + } + ) + }) + if (!found) { + println("addFunDefs could not find anchor function!") + } + res + } + + def mapFunDefs(e: Expr, fdMap: PartialFunction[FunDef, FunDef]): Expr = { + preMap { + case FunctionInvocation(tfd, args) => + fdMap.lift.apply(tfd.fd).map { + nfd => FunctionInvocation(nfd.typed(tfd.tps), args) + } + case _ => None + }(e) + } } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 2455eff7e94e8fc260aeb12c2ec1e96b776c0e12..c7a72fa21fb31d4efb51647806cc0329eb61b179 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -81,26 +81,6 @@ object Definitions { copy(units = units.map{_.duplicate}) } - // Use only to add functions around functions - def addDefinition(d: Definition, around: Definition): Program = { - val Some(m) = inModule(around) - val nm = m.copy(defs = d +: m.defs) - - m.owner match { - case Some(u: UnitDef) => - val nu = u.copy(modules = u.modules.filterNot(_ == m) :+ nm) - - u.owner match { - case Some(p: Program) => - p.copy(units = p.units.filterNot(_ == u) :+ nu) - case _ => - this - } - case _ => - this - } - } - lazy val library = Library(this) def writeScalaFile(filename: String) { @@ -137,7 +117,7 @@ object Definitions { case PackageImport(pack) => { import DefOps._ // Ignore standalone modules, assume there are extra imports for them - inProgram(this) map { unitsInPackage(_,pack) } getOrElse List() + programOf(this) map { unitsInPackage(_,pack) } getOrElse List() } case SingleImport(imported) => List(imported) case WildcardImport(imported) => imported.subDefinitions @@ -401,7 +381,7 @@ object Definitions { def subDefinitions = params ++ tparams ++ nestedFuns.toList def duplicate: FunDef = { - val fd = new FunDef(id, tparams, returnType, params, defType) + val fd = new FunDef(id.freshen, tparams, returnType, params, defType) fd.copyContentFrom(this) fd.copiedFrom(this) } diff --git a/src/main/scala/leon/purescala/FunctionMapping.scala b/src/main/scala/leon/purescala/FunctionMapping.scala index 17c2cdfbf9350622e7a3b6019f34b616819975b6..8efa25c866eb363adc8bfb830186d5eebc6147cf 100644 --- a/src/main/scala/leon/purescala/FunctionMapping.scala +++ b/src/main/scala/leon/purescala/FunctionMapping.scala @@ -49,6 +49,8 @@ abstract class FunctionMapping extends TransformationPhase { // c.registerMethod(functionToFunction.get(m).map{_.to}.getOrElse(m)) // } c + case d => + d } }) }, diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index cda5084b9ee65d025512148e75a297623334f244..fa16449fac0b5e37cc9b232990b0a4a321989aef 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -291,7 +291,14 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case GenericValue(tp, id) => p"$tp#$id" case Tuple(exprs) => p"($exprs)" case TupleSelect(t, i) => p"${t}._$i" - case Choose(vars, pred) => p"choose(($vars) => $pred)" + case NoTree(tpe) => p"???($tpe)" + case Choose(vars, pred, oimpl) => + oimpl match { + case Some(e) => + p"$e /* choose: $vars => $pred */" + case None => + p"choose(($vars) => $pred)" + } case e @ Error(tpe, err) => p"""error[$tpe]("$err")""" case CaseClassInstanceOf(cct, e) => if (cct.classDef.isCaseObject) { @@ -482,7 +489,7 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe import DefOps._ val newPack = ( for ( scope <- ctx.scope; - unit <- inUnit(scope); + unit <- unitOf(scope); currentPack = unit.pack ) yield { if (isSuperPackageOf(currentPack,pack)) diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index b48def12d12d13469e11c178f642f1c5062c2f10..d9fd8e2966224b60b847f896be84af3e861bac05 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -25,7 +25,7 @@ class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) ex tree match { case Not(Equals(l, r)) => p"$l != $r" case Implies(l,r) => pp(or(not(l), r)) - case Choose(vars, pred) => p"choose((${typed(vars)}) => $pred)" + case Choose(vars, pred, None) => p"choose((${typed(vars)}) => $pred)" case s @ FiniteSet(rss) => { val rs = rss.toSeq s.getType match { diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index e493ecbb3571edf3e987d8b06ea75ebdc1195b9c..5cf5924cb25f72a5f66eaf4231949caaf52b0ef5 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -361,7 +361,7 @@ object TreeOps { case Variable(i) => subvs + i case LetDef(fd,_) => subvs -- fd.params.map(_.id) -- fd.postcondition.map(_._1) case Let(i,_,_) => subvs - i - case Choose(is,_) => subvs -- is + case Choose(is,_,_) => subvs -- is case MatchLike(_, cses, _) => subvs -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) case Passes(_, _ , cses) => subvs -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) case Lambda(args, body) => subvs -- args.map(_.id) @@ -1383,7 +1383,7 @@ object TreeOps { def containsChoose(e: Expr): Boolean = { preTraversal{ - case Choose(_, _) => return true + case Choose(_, _, None) => return true case _ => }(e) false @@ -1391,7 +1391,7 @@ object TreeOps { def isDeterministic(e: Expr): Boolean = { preTraversal{ - case Choose(_, _) => return false + case Choose(_, _, None) => return false case Hole(_, _) => return false case RepairHole(_, _) => return false case Gives(_,_) => return false @@ -1637,7 +1637,7 @@ object TreeOps { case (Variable(i1), Variable(i2)) => idHomo(i1, i2) - case (Choose(ids1, e1), Choose(ids2, e2)) => + case (Choose(ids1, e1, _), Choose(ids2, e2, _)) => isHomo(e1, e2)(map ++ (ids1 zip ids2)) case (Let(id1, v1, e1), Let(id2, v2, e2)) => @@ -2096,7 +2096,7 @@ object TreeOps { def isStringLiteral(e: Expr): Option[String] = e match { case CaseClass(cct, args) => - val p = inProgram(cct.classDef) + val p = programOf(cct.classDef) require(p.isDefined) val lib = p.get.library @@ -2126,7 +2126,7 @@ object TreeOps { def isListLiteral(e: Expr): Option[(TypeTree, List[Expr])] = e match { case CaseClass(cct, args) => - val p = inProgram(cct.classDef) + val p = programOf(cct.classDef) require(p.isDefined) val lib = p.get.library diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 70410503470c814559c266769489e21a06a7fd1f..171d8cc04280aacd79832671ba2caec875549072 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -45,13 +45,13 @@ object Trees { def getType = body.getType } - case class Choose(vars: List[Identifier], pred: Expr) extends Expr with UnaryExtractable { + case class Choose(vars: List[Identifier], pred: Expr, var impl: Option[Expr] = None) extends Expr with NAryExtractable { require(!vars.isEmpty) def getType = if (vars.size > 1) TupleType(vars.map(_.getType)) else vars.head.getType def extract = { - Some((pred, (e: Expr) => Choose(vars, e).setPos(this))) + Some((Seq(pred)++impl, (es: Seq[Expr]) => Choose(vars, es.head, es.tail.headOption).setPos(this))) } } diff --git a/src/main/scala/leon/purescala/TypeTreeOps.scala b/src/main/scala/leon/purescala/TypeTreeOps.scala index 306ba571fec959967293270f0d116a7a32842efb..0cd0f7ba2ac08211ed84820d5bff649cbbdb098a 100644 --- a/src/main/scala/leon/purescala/TypeTreeOps.scala +++ b/src/main/scala/leon/purescala/TypeTreeOps.scala @@ -216,9 +216,9 @@ object TypeTreeOps { val newIds = ids.map(id => freshId(id, tpeSub(id.getType))) LetTuple(newIds, srec(value), rec(idsMap ++ (ids zip newIds))(body)).copiedFrom(l) - case c @ Choose(xs, pred) => + case c @ Choose(xs, pred, oimpl) => val newXs = xs.map(id => freshId(id, tpeSub(id.getType))) - Choose(newXs, rec(idsMap ++ (xs zip newXs))(pred)).copiedFrom(c) + Choose(newXs, rec(idsMap ++ (xs zip newXs))(pred), oimpl.map(srec)).copiedFrom(c) case l @ Lambda(args, body) => val newArgs = args.map { arg => diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index e2f9bbe3cd154aab86b64223ce71177eb49a8ece..93d0c59362d44cc3a0c3e754ad7df377c5ec5c9b 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -43,9 +43,10 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout reporter.info(ASCIIHelpers.title("2. Locating/Focusing synthesis problem")) - val t2 = new Timer().start - val (synth, ci) = getSynthesizer(passingTests, failingTests) - val p = synth.problem + val t2 = new Timer().start + val synth = getSynthesizer(passingTests, failingTests) + val ci = synth.ci + val p = synth.problem var solutions = List[Solution]() @@ -110,7 +111,7 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout } } - def getSynthesizer(passingTests: List[Example], failingTests: List[Example]): (Synthesizer , ChooseInfo)= { + def getSynthesizer(passingTests: List[Example], failingTests: List[Example]): Synthesizer = { val body = fd.body.get; @@ -137,12 +138,12 @@ class Repairman(ctx: LeonContext, initProgram: Program, fd: FunDef, verifTimeout ); // extract chooses from fd - val Seq(ci) = ChooseInfo.extractFromFunction(ctx, program, fd, soptions) + val Seq(ci) = ChooseInfo.extractFromFunction(program, fd) val nci = ci.copy(pc = and(ci.pc, guide)) val p = nci.problem - (new Synthesizer(ctx, fd, program, p, soptions), nci) + new Synthesizer(ctx, program, nci, soptions) } private def focusRepair(program: Program, fd: FunDef, passingTests: List[Example], failingTests: List[Example]): (Expr, Expr) = { diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index 9387da05dc93a4e7a20c20785433b7c59607ff73..09ce38e76a3637588753334a4c05e25de2a4846c 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -262,7 +262,16 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T]) { } } - case c @ Choose(ids, cond) => + case h @ RepairHole(_, _) => + val hid = FreshIdentifier("hole", true).setType(h.getType) + exprVars += hid + Variable(hid) + + case c @ Choose(ids, cond, Some(impl)) => + rec(pathVar, impl) + + + case c @ Choose(ids, cond, None) => val cid = FreshIdentifier("choose", true).setType(c.getType) storeExpr(cid) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index ba1945ae43cae18e8f9f331c39bf95270679b017..ac0719cd63ff41723c7ee4d0edd583ead42d0e39 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -638,7 +638,7 @@ trait AbstractZ3Solver typeToSort(at) val meta = arrayMetaDecls(normalizeType(at)) - val ar = z3.mkConstArray(typeToSort(base), rec(default)) + val ar = z3.mkConstArray(typeToSort(Int32Type), rec(default)) val res = meta.cons(ar, rec(length)) res diff --git a/src/main/scala/leon/synthesis/ChooseInfo.scala b/src/main/scala/leon/synthesis/ChooseInfo.scala index 142498c5916d51a0546882eaadf5738710d31226..841247dd2b2bc15785f99e2c626708dc31e75f70 100644 --- a/src/main/scala/leon/synthesis/ChooseInfo.scala +++ b/src/main/scala/leon/synthesis/ChooseInfo.scala @@ -11,37 +11,33 @@ import purescala.TreeOps._ import purescala.DefOps._ import Witnesses._ -case class ChooseInfo(ctx: LeonContext, - prog: Program, - fd: FunDef, +case class ChooseInfo(fd: FunDef, pc: Expr, source: Expr, - ch: Choose, - settings: SynthesisSettings) { + ch: Choose) { - val problem = Problem.fromChoose(ch, pc) - val synthesizer = new Synthesizer(ctx, fd, prog, problem, settings) + val problem = Problem.fromChoose(ch, pc) } object ChooseInfo { - def extractFromProgram(ctx: LeonContext, prog: Program, options: SynthesisSettings): List[ChooseInfo] = { + def extractFromProgram(prog: Program): List[ChooseInfo] = { // Look for choose() val results = for (f <- prog.definedFunctions if f.body.isDefined; - ci <- extractFromFunction(ctx, prog, f, options)) yield { + ci <- extractFromFunction(prog, f)) yield { ci } results.sortBy(_.source.getPos) } - def extractFromFunction(ctx: LeonContext, prog: Program, fd: FunDef, options: SynthesisSettings): Seq[ChooseInfo] = { + def extractFromFunction(prog: Program, fd: FunDef): Seq[ChooseInfo] = { val actualBody = and(fd.precondition.getOrElse(BooleanLiteral(true)), fd.body.get) val term = Terminating(fd.typedWithDef, fd.params.map(_.id.toVariable)) for ((ch, path) <- new ChooseCollectorWithPaths().traverse(actualBody)) yield { - ChooseInfo(ctx, prog, fd, and(path, term), ch, ch, options) + ChooseInfo(fd, and(path, term), ch, ch) } } } diff --git a/src/main/scala/leon/synthesis/PartialSolution.scala b/src/main/scala/leon/synthesis/PartialSolution.scala index 42b2d7cfcffd9ad1b68fa13eb5939538851b6ff0..13c9fc99d080bc1662e64dd22b92f2dc8edc7629 100644 --- a/src/main/scala/leon/synthesis/PartialSolution.scala +++ b/src/main/scala/leon/synthesis/PartialSolution.scala @@ -54,10 +54,13 @@ class PartialSolution(g: Graph, includeUntrusted: Boolean = false) { case s @ Solution(pre, defs, term) => (e: Expr) => Solution(replaceFromIDs(Map(anchor -> e), pre), - defs.map(preMapOnFunDef({ + defs.map { d => + d.fullBody = preMap({ case Variable(`anchor`) => Some(e) case _ => None - })), + })(d.fullBody) + d + }, replaceFromIDs(Map(anchor -> e), term), s.isTrusted) } diff --git a/src/main/scala/leon/synthesis/SearchContext.scala b/src/main/scala/leon/synthesis/SearchContext.scala index 39f85eaeea538177087810b444cee1559e2d1f6b..853314f08e2b3fd226f1a8edd29af2b2a7bb5ca0 100644 --- a/src/main/scala/leon/synthesis/SearchContext.scala +++ b/src/main/scala/leon/synthesis/SearchContext.scala @@ -3,6 +3,7 @@ package leon package synthesis +import purescala.Trees.Choose import graph._ /** @@ -11,6 +12,7 @@ import graph._ */ case class SearchContext ( sctx: SynthesisContext, + ci: ChooseInfo, currentNode: Node, search: Search ) { @@ -18,7 +20,6 @@ case class SearchContext ( val reporter = sctx.reporter val program = sctx.program - def searchDepth = { def depthOf(n: Node): Int = n.parent match { case Some(n2) => 1+depthOf(n2) diff --git a/src/main/scala/leon/synthesis/SynthesisContext.scala b/src/main/scala/leon/synthesis/SynthesisContext.scala index b6e74e2f2e02bb6f7abb6291eb36895ee938095f..e59e1eaadacf9edd5f50c7a0f21f6c80796bfe62 100644 --- a/src/main/scala/leon/synthesis/SynthesisContext.scala +++ b/src/main/scala/leon/synthesis/SynthesisContext.scala @@ -60,7 +60,7 @@ object SynthesisContext { SynthesisContext( synth.context, synth.settings, - synth.functionContext, + synth.ci.fd, synth.program, synth.reporter) } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 17abfd860da1e0ab48ee1f9bac0afbe0b08b1092..a0aff1783163babb86c821e014af7141e0b7ef65 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -85,9 +85,6 @@ object SynthesisPhase extends LeonPhase[Program, Program] { case LeonFlagOption("derivtrees", v) => options = options.copy(generateDerivationTrees = v) - case LeonFlagOption("cegis:unintprobe", v) => - options = options.copy(cegisUseUninterpretedProbe = v) - case LeonFlagOption("cegis:unsatcores", v) => options = options.copy(cegisUseUnsatCores = v) @@ -132,16 +129,17 @@ object SynthesisPhase extends LeonPhase[Program, Program] { filterInclusive(options.filterFuns.map(fdMatcher), Some(excludeByDefault _)) compose ciTofd } - var chooses = ChooseInfo.extractFromProgram(ctx, p, options).filter(fdFilter) + var chooses = ChooseInfo.extractFromProgram(p).filter(fdFilter) var functions = Set[FunDef]() chooses.foreach { ci => - val (search, solutions) = ci.synthesizer.validate(ci.synthesizer.synthesize()) + val synthesizer = new Synthesizer(ctx, p, ci, options) + val (search, solutions) = synthesizer.validate(synthesizer.synthesize()) val fd = ci.fd - if (ci.synthesizer.settings.generateDerivationTrees) { + if (options.generateDerivationTrees) { val dot = new DotGenerator(search.g) dot.writeFile("derivation"+DotGenerator.nextId()+".dot") } diff --git a/src/main/scala/leon/synthesis/SynthesisSettings.scala b/src/main/scala/leon/synthesis/SynthesisSettings.scala index 90713c20ba54a15ebafc7c4ebf0bd25afe27e538..126f7994311615377d6ffe137e360a71f0595316 100644 --- a/src/main/scala/leon/synthesis/SynthesisSettings.scala +++ b/src/main/scala/leon/synthesis/SynthesisSettings.scala @@ -22,7 +22,6 @@ case class SynthesisSettings( functionsToIgnore: Set[FunDef] = Set(), // Cegis related options - cegisUseUninterpretedProbe: Boolean = false, cegisUseUnsatCores: Boolean = true, cegisUseOptTimeout: Boolean = true, cegisUseBssFiltering: Boolean = true, diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index cca8d8177903f6d9f01c9012aa50c1155b69e4b5..b9bae7dad53c9159453e157d1129298b2d717085 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -19,21 +19,22 @@ import java.io.File import synthesis.graph._ class Synthesizer(val context : LeonContext, - val functionContext: FunDef, val program: Program, - val problem: Problem, + val ci: ChooseInfo, val settings: SynthesisSettings) { + val problem = ci.problem + val reporter = context.reporter def getSearch(): Search = { if (settings.manualSearch) { - new ManualSearch(context, problem, settings.costModel) + new ManualSearch(context, ci, problem, settings.costModel) } else if (settings.searchWorkers > 1) { ??? //new ParallelSearch(this, problem, options.searchWorkers) } else { - new SimpleSearch(context, problem, settings.costModel, settings.searchBound) + new SimpleSearch(context, ci, problem, settings.costModel, settings.searchBound) } } @@ -115,7 +116,7 @@ class Synthesizer(val context : LeonContext, }.toMap } - val fd = new FunDef(FreshIdentifier(functionContext.id.name+"_final", true), Nil, ret, problem.as.map(id => ValDef(id, id.getType)), DefType.MethodDef) + val fd = new FunDef(FreshIdentifier(ci.fd.id.name+"_final", true), Nil, ret, problem.as.map(id => ValDef(id, id.getType)), DefType.MethodDef) fd.precondition = Some(and(problem.pc, sol.pre)) fd.postcondition = Some((res.id, replace(mapPost, problem.phi))) fd.body = Some(sol.term) diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala index b9bfa13cb324a092893a8a7d759651c3a786f5ce..4e851ef70160b59bcba45e121324dc82ea3ca68d 100644 --- a/src/main/scala/leon/synthesis/graph/Search.scala +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -4,13 +4,14 @@ package graph import scala.annotation.tailrec +import purescala.Trees.Choose import leon.utils.StreamUtils.cartesianProduct import scala.collection.mutable.ArrayBuffer import leon.utils.Interruptible import java.util.concurrent.atomic.AtomicBoolean -abstract class Search(ctx: LeonContext, p: Problem, costModel: CostModel) extends Interruptible { +abstract class Search(ctx: LeonContext, ci: ChooseInfo, p: Problem, costModel: CostModel) extends Interruptible { val g = new Graph(costModel, p); def findNodeToExpandFrom(n: Node): Option[Node] @@ -22,11 +23,11 @@ abstract class Search(ctx: LeonContext, p: Problem, costModel: CostModel) extend n match { case an: AndNode => ctx.timers.synthesis.applications.get(an.ri.toString).timed { - an.expand(SearchContext(sctx, an, this)) + an.expand(SearchContext(sctx, ci, an, this)) } case on: OrNode => - on.expand(SearchContext(sctx, on, this)) + on.expand(SearchContext(sctx, ci, on, this)) } } } @@ -83,7 +84,7 @@ abstract class Search(ctx: LeonContext, p: Problem, costModel: CostModel) extend ctx.interruptManager.registerForInterrupts(this) } -class SimpleSearch(ctx: LeonContext, p: Problem, costModel: CostModel, bound: Option[Int]) extends Search(ctx, p, costModel) { +class SimpleSearch(ctx: LeonContext, ci: ChooseInfo, p: Problem, costModel: CostModel, bound: Option[Int]) extends Search(ctx, ci, p, costModel) { val expansionBuffer = ArrayBuffer[Node]() def findIn(n: Node) { @@ -123,7 +124,7 @@ class SimpleSearch(ctx: LeonContext, p: Problem, costModel: CostModel, bound: Op } } -class ManualSearch(ctx: LeonContext, problem: Problem, costModel: CostModel) extends Search(ctx, problem, costModel) { +class ManualSearch(ctx: LeonContext, ci: ChooseInfo, problem: Problem, costModel: CostModel) extends Search(ctx, ci, problem, costModel) { import ctx.reporter._ abstract class Command diff --git a/src/main/scala/leon/synthesis/rules/CegisLike.scala b/src/main/scala/leon/synthesis/rules/CegisLike.scala index 53d9a4bde5a184ef39e5d7286db39e2ae091aa71..be01424025cc154e461b44d03f0629c1b3621ab3 100644 --- a/src/main/scala/leon/synthesis/rules/CegisLike.scala +++ b/src/main/scala/leon/synthesis/rules/CegisLike.scala @@ -8,6 +8,7 @@ import leon.utils.StreamUtils import solvers._ import solvers.z3._ +import verification._ import purescala.Trees._ import purescala.Common._ import purescala.Definitions._ @@ -18,6 +19,7 @@ import purescala.TypeTreeOps._ import purescala.Extractors._ import purescala.Constructors._ import purescala.ScalaPrinter +import purescala.PrinterOptions import utils.Helpers._ import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} @@ -41,22 +43,25 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + def debugPrinter(t: Tree) = ScalaPrinter(t, PrinterOptions(printUniqueIds = true)) + + val exSolverTo = 2000L + val cexSolverTo = 2000L + val sctx = hctx.sctx + val ctx = sctx.context // CEGIS Flags to activate or deactivate features - val useUninterpretedProbe = sctx.settings.cegisUseUninterpretedProbe val useUnsatCores = sctx.settings.cegisUseUnsatCores val useOptTimeout = sctx.settings.cegisUseOptTimeout val useVanuatoo = sctx.settings.cegisUseVanuatoo val useCETests = sctx.settings.cegisUseCETests val useCEPruning = sctx.settings.cegisUseCEPruning - // Limits the number of programs CEGIS will specifically test for instead of reasoning symbolically - val testUpTo = 5 + // Limits the number of programs CEGIS will specifically validate individually + val validateUpTo = 5 val useBssFiltering = sctx.settings.cegisUseBssFiltering val filterThreshold = 1.0/2 - val evalParams = CodeGenParams(maxFunctionInvocations = 2000) - lazy val evaluator = new CodeGenEvaluator(sctx.context, sctx.program, evalParams) val interruptManager = sctx.context.interruptManager @@ -66,341 +71,526 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { return Nil } - class NonDeterministicProgram(val p: Problem, - val initGuard: Identifier) { - - val grammar = params.grammar - - // b -> (c, ex) means the clause b => c == ex - var mappings: Map[Identifier, (Identifier, Expr)] = Map() + class NonDeterministicProgram(val p: Problem) { - // b -> Set(c1, c2) means c1 and c2 are uninterpreted behind b, requires b to be closed - private var guardedTerms: Map[Identifier, Set[Identifier]] = Map(initGuard -> p.xs.toSet) - - private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> params.rootLabel(x.getType)) - - def isBClosed(b: Identifier) = guardedTerms.contains(b) + private val grammar = params.grammar /** - * Stores which b's guard which c's, which then are represented by which - * b's: + * Different view of the tree of expressions: + * + * Case used to illustrate the different views, assuming encoding: * - * b -> Map(c1 -> Set(b2, b3), - * c2 -> Set(b4, b5)) + * b1 => c1 == F(c2, c3) + * b2 => c1 == G(c4, c5) + * b3 => c6 == H(c4, c5) * - * means b protects c1 (with sub alternatives b2/b3), and c2 (with sub b4/b5) + * c1 -> Seq( + * (b1, F(c2, c3), Set(c2, c3)) + * (b2, G(c4, c5), Set(c4, c5)) + * ) + * c6 -> Seq( + * (b3, H(c7, c8), Set(c7, c8)) + * ) */ - private var bTree = Map[Identifier, Map[Identifier, Set[Identifier]]]( initGuard -> p.xs.map(_ -> Set[Identifier]()).toMap) + private var cTree: Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]] = Map() /** - * Stores which c's are guarded by which b's - * - * c1 -> Set(b2, b3) + * Computes dependencies of c's * - * means c1 is protected by b2 and b3 + * c1 -> Set(c2, c3, c4, c5) */ - private var revBTree : Map[Identifier, Set[Identifier]] = Map() - + private var cDeps: Map[Identifier, Set[Identifier]] = Map() + /** - * Computes dependencies of c's + * Keeps track of blocked Bs and which C are affected, assuming cs are undefined: * - * Assuming encoding: - * b1 => c == F(c2, c3) - * b2 => c == F(c4, c5) + * b2 -> Set(c4) + * b3 -> Set(c4) + */ + private var closedBs: Map[Identifier, Set[Identifier]] = Map() + + /** + * Maps c identifiers to grammar labels * - * will be represented here as c -> Set(c2, c3, c4, c5) + * Labels allows us to use grammars that are not only type-based */ - private var cChildren: Map[Identifier, Set[Identifier]] = Map().withDefaultValue(Set()) + private var labels: Map[Identifier, T] = Map() ++ p.xs.map(x => x -> params.rootLabel(x.getType)) + + private var bs: Set[Identifier] = Set() + private var bsOrdered: Seq[Identifier] = Seq() + + /** + * Checks if 'b' is closed (meaning it depends on uninterpreted terms) + */ + def isBActive(b: Identifier) = !closedBs.contains(b) - // Returns all possible assignments to Bs in order to enumerate all possible programs + /** + * Returns all possible assignments to Bs in order to enumerate all possible programs + */ def allPrograms(): Traversable[Set[Identifier]] = { - def allChildPaths(b: Identifier): Stream[Set[Identifier]] = { - if (isBClosed(b)) { - Stream.empty - } else { - bTree.get(b) match { - case Some(cToBs) => - val streams = cToBs.values.toSeq.map { children => - children.toStream.flatMap(c => allChildPaths(c).map(l => l + b)) - } + import StreamUtils._ - StreamUtils.cartesianProduct(streams).map { ls => - ls.foldLeft(Set[Identifier]())(_ ++ _) - } - case None => - Stream.cons(Set(b), Stream.empty) + def allProgramsFor(cs: Set[Identifier]): Stream[Set[Identifier]] = { + val streams = for (c <- cs.toSeq) yield { + val subs = for ((b, _, subcs) <- cTree(c) if isBActive(b); + p <- allProgramsFor(subcs)) yield { + + p + b } + + subs.toStream + } + StreamUtils.cartesianProduct(streams).map { ls => + ls.foldLeft(Set[Identifier]())(_ ++ _) } } - allChildPaths(initGuard).toSet + allProgramsFor(p.xs.toSet) } - /* - * Compilation/Execution of programs - */ + private def debugCExpr(cTree: Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]], + markedBs: Set[Identifier] = Set()): Unit = { + println(" -- -- -- -- -- ") + for ((c, alts) <- cTree) { + println + println(f"$c%-4s :=") + for ((b, ex, cs) <- alts ) { + val active = if (isBActive(b)) " " else "тип" + val markS = if (markedBs(b)) Console.GREEN else "" + val markE = if (markedBs(b)) Console.RESET else "" + + println(f" $markS$active $b%-4s => $ex%-40s [$cs]$markE") + } + } + } + + private def computeCExpr(): Expr = { - type EvaluationResult = EvaluationResults.Result + val lets = (for ((c, alts) <- cTree) yield { + val activeAlts = alts.filter(a => isBActive(a._1)) - private var triedCompilation = false - private var progEvaluator: Option[(Seq[Expr], Seq[Expr]) => EvaluationResult] = None + val expr = activeAlts.foldLeft(simplestValue(c.getType): Expr){ + case (e, (b, ex, _)) => IfExpr(b.toVariable, ex, e) + } + + (c, expr) + }).toMap - def canTest(): Boolean = { - if (!triedCompilation) { - progEvaluator = compile() + // We order the lets base don dependencies + def defFor(c: Identifier): Expr = { + cDeps(c).filter(lets contains _).foldLeft(lets(c)) { + case (e, c) => Let(c, defFor(c), e) + } } - progEvaluator.isDefined - } + val resLets = p.xs.map(defFor(_)) + val res = tupleWrap(p.xs.map(defFor)) - private var bssOrdered: Seq[Identifier] = Seq() + val substMap : Map[Expr,Expr] = (bsOrdered.zipWithIndex.map { + case (b,i) => Variable(b) -> ArraySelect(bArrayId.toVariable, IntLiteral(i)) + }).toMap - def testForProgram(bss: Set[Identifier])(ins: Seq[Expr]): Boolean = { - if (canTest()) { - val bssValues : Seq[Expr] = bssOrdered.map(i => BooleanLiteral(bss(i))) + val simplerRes = simplifyLets(res) - val evalResult = progEvaluator.get.apply(bssValues, ins) + replace(substMap, simplerRes) + } - evalResult match { - case EvaluationResults.Successful(res) => - res == BooleanLiteral(true) - case EvaluationResults.RuntimeError(err) => - false - case EvaluationResults.EvaluatorError(err) => - sctx.reporter.error("Error testing CE: "+err) - false - } - } else { - true - } + /** + * Information about the final Program representing CEGIS solutions at + * the current unfolding level + */ + private val outerSolution = new PartialSolution(hctx.search.g).solutionAround(hctx.currentNode).getOrElse { + sctx.reporter.fatalError("Unable to create outer solution") } - def compile(): Option[(Seq[Expr], Seq[Expr]) => EvaluationResult] = { - var unreachableCs: Set[Identifier] = guardedTerms.flatMap(_._2).toSet + private val bArrayId = FreshIdentifier("bArray", true).setType(ArrayType(BooleanType)) - val cToExprs = mappings.groupBy(_._2._1).map { - case (c, maps) => - // We only keep cases within the current unfoldings closedBs - val cases = maps.flatMap{ case (b, (_, ex)) => if (isBClosed(b)) None else Some(b -> ex) } + private var cTreeFd = new FunDef(FreshIdentifier("cTree", true), + Seq(), + p.outType, + p.as.map(id => ValDef(id, id.getType)), + DefType.MethodDef + ) - // We compute the IF expression corresponding to each c - val ifExpr = if (cases.isEmpty) { - // This can happen with ADTs with only cases with arguments - Error(c.getType, "No valid clause available") - } else { - cases.tail.foldLeft(cases.head._2) { - case (elze, (b, thenn)) => IfExpr(Variable(b), thenn, elze) - } + private var phiFd = new FunDef(FreshIdentifier("phiFd", true), + Seq(), + BooleanType, + p.as.map(id => ValDef(id, id.getType)), + DefType.MethodDef + ) + + private var programCTree: Program = _ + + // Map functions from original program to cTree program + private var fdMapCTree: Map[FunDef, FunDef] = _ + + private var tester: (Seq[Expr], Set[Identifier]) => EvaluationResults.Result = _ + + private def initializeCTreeProgram(): Unit = { + + // CEGIS is solved by called cTree function (without bs yet) + val fullSol = outerSolution(FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable))) + + + val chFd = hctx.ci.fd + val prog0 = hctx.program + + val affected = prog0.callGraph.transitiveCallers(chFd).toSet ++ Set(chFd, cTreeFd, phiFd) ++ fullSol.defs + + //println("Affected:") + //for (fd <- affected) { + // println(" - "+debugPrinter(fd.id)) + //} + + + cTreeFd.body = None + phiFd.body = Some( + letTuple(p.xs, + FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), + p.phi) + ) + + val prog1 = addFunDefs(prog0, Seq(cTreeFd, phiFd) ++ fullSol.defs, chFd) + + val (prog2, fdMap2) = replaceFunDefs(prog1)({ + case fd if affected(fd) => + // Add the b array argument to all affected functions + val nfd = new FunDef(fd.id.freshen, + fd.tparams, + fd.returnType, + fd.params :+ ValDef(bArrayId, bArrayId.getType), + fd.defType) + nfd.copyContentFrom(fd) + nfd.copiedFrom(fd) + + if (fd == chFd) { + nfd.fullBody = replace(Map(hctx.ci.ch -> fullSol.guardedTerm), nfd.fullBody) } - c -> ifExpr - }.toMap + Some(nfd) + + case _ => + None + }, { + case (FunctionInvocation(old, args), newfd) if old.fd != newfd => + Some(FunctionInvocation(newfd.typed(old.tps), args :+ bArrayId.toVariable)) + case _ => + None + }) + //println("FunDef Map:") + //for ((f, t) <- fdMap2) { + // println("- "+debugPrinter(f.id)+" -> "+debugPrinter(t.id)) + //} + + //println("Program:") + //println(debugPrinter(prog2)) + + programCTree = prog2 + cTreeFd = fdMap2(cTreeFd) + phiFd = fdMap2(phiFd) + fdMapCTree = fdMap2 + } - // Map each x generated by the program to fresh xs passed as argument - val newXs = p.xs.map(x => x -> FreshIdentifier(x.name, true).setType(x.getType)) + private def setCExpr(cTree: Expr): Unit = { - val baseExpr = p.phi + cTreeFd.body = Some(preMap{ + case FunctionInvocation(TypedFunDef(fd, tps), args) if fdMapCTree contains fd => + Some(FunctionInvocation(fdMapCTree(fd).typed(tps), args :+ bArrayId.toVariable)) + case _ => + None + }(cTree)) - bssOrdered = bss.toSeq.sortBy(_.id) + val evalParams = CodeGenParams(maxFunctionInvocations = -1, doInstrument = false) - var res = baseExpr - def composeWith(c: Identifier) { - cToExprs.get(c) match { - case Some(value) => - val guards = (revBTree.getOrElse(c,Set()) - initGuard ).toSeq map { _.toVariable } - res = Let(c, if(guards.isEmpty) cToExprs(c) else IfExpr(orJoin(guards), cToExprs(c), NoTree(c.getType)), res) - case None => - res = Let(c, Error(c.getType, "No value available"), res) - } + val evaluator = new DualEvaluator(sctx.context, programCTree, evalParams) + + //println("-- "*30) + //println(debugPrinter(programCTree)) + //println(".. "*30) + + tester = + { (ins: Seq[Expr], bValues: Set[Identifier]) => + val bsValue = FiniteArray(bsOrdered.map(b => BooleanLiteral(bValues(b)))).setType(ArrayType(BooleanType)) + val args = ins :+ bsValue + + val fi = FunctionInvocation(phiFd.typed, args) - for (dep <- cChildren(c) if !unreachableCs(dep)) { - composeWith(dep) + evaluator.eval(fi, Map()) } + } - } - for (c <- p.xs) { - composeWith(c) + private def updateCTree() { + if (programCTree eq null) { + initializeCTreeProgram() } - val simplerRes = simplifyLets(res) + setCExpr(computeCExpr()) + } - def compileWithArray(): Option[(Seq[Expr], Seq[Expr]) => EvaluationResult] = { - val ba = FreshIdentifier("bssArray").setType(ArrayType(BooleanType)) - val bav = Variable(ba) - val substMap : Map[Expr,Expr] = (bssOrdered.zipWithIndex.map { - case (b,i) => Variable(b) -> ArraySelect(bav, IntLiteral(i)) - }).toMap - val forArray = replace(substMap, simplerRes) + def testForProgram(bValues: Set[Identifier])(ins: Seq[Expr]): Boolean = { + tester(ins, bValues) match { + case EvaluationResults.Successful(res) => + res == BooleanLiteral(true) - // We trust arrays to be fast... - val eval = evaluator.compile(forArray, ba +: p.as) + case EvaluationResults.RuntimeError(err) => + false - eval.map{e => { case (bss, ins) => - e(FiniteArray(bss).setType(ArrayType(BooleanType)) +: ins) - }} + case EvaluationResults.EvaluatorError(err) => + sctx.reporter.error("Error testing CE: "+err) + false } + } - def compileWithArgs(): Option[(Seq[Expr], Seq[Expr]) => EvaluationResult] = { - val eval = evaluator.compile(simplerRes, bssOrdered ++ p.as) - - eval.map{e => { case (bss, ins) => - e(bss ++ ins) - }} - } - triedCompilation = true - val localVariables = bss.size + cToExprs.size + p.as.size + def getExpr(bValues: Set[Identifier]): Expr = { + def getCValue(c: Identifier): Expr = { + cTree(c).find(i => bValues(i._1)).map { + case (b, ex, cs) => + val map = for (c <- cs) yield { + c -> getCValue(c) + } - if (localVariables < 128) { - compileWithArgs().orElse(compileWithArray()) - } else { - compileWithArray() + substAll(map.toMap, ex) + }.getOrElse { + simplestValue(c.getType) + } } + + tupleWrap(p.xs.map(c => getCValue(c))) } - def determinize(bss: Set[Identifier]): Expr = { - val cClauses = mappings.filterKeys(bss).map(_._2).toMap + def validatePrograms(bss: Set[Set[Identifier]]): Either[Stream[Solution], Seq[Seq[Expr]]] = { + try { + val cexs = for (bs <- bss.toSeq) yield { + val sol = getExpr(bs) - def getCValue(c: Identifier): Expr = { - val map = for (dep <- cChildren(c) if cClauses contains dep) yield { - dep -> getCValue(dep) + val fullSol = outerSolution(sol) + + val prog = addFunDefs(hctx.program, fullSol.defs, hctx.ci.fd) + + hctx.ci.ch.impl = Some(fullSol.guardedTerm) + + //println("Validating Solution "+sol) + //println(debugPrinter(prog)) + + val cnstr = and(p.pc, letTuple(p.xs, sol, Not(p.phi))) + //println("Solving for: "+cnstr) + + val solver = (new FairZ3Solver(ctx, prog) with TimeoutSolver).setTimeout(cexSolverTo) + try { + solver.assertCnstr(cnstr) + solver.check match { + case Some(true) => + excludeProgram(bs) + val model = solver.getModel + Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))) + + case Some(false) => + // UNSAT, valid program + return Left(Stream(Solution(BooleanLiteral(true), Set(), sol, true))) + + case None => + None + } + } finally { + solver.free + } } - substAll(map.toMap, cClauses(c)) + Right(cexs.flatten) + } finally { + hctx.ci.ch.impl = None } + } - tupleWrap(p.xs.map(c => getCValue(c))) + var excludedPrograms = ArrayBuffer[Set[Identifier]]() + // Explicitly remove program computed by bValues from the search space + def excludeProgram(bValues: Set[Identifier]): Unit = { + excludedPrograms += bValues } /** * Shrinks the non-deterministic program to the provided set of * alternatives only */ - def filterFor(remainingBss: Set[Identifier]): Seq[Expr] = { - val filteredBss = remainingBss + initGuard - - // The following code is black-magic, read with caution - mappings = mappings.filterKeys(filteredBss) - guardedTerms = Map() - bTree = bTree.filterKeys(filteredBss) - bTree = bTree.mapValues(cToBs => cToBs.mapValues(bs => bs & filteredBss)) - - val filteredCss = mappings.map(_._2._1).toSet - cChildren = cChildren.filterKeys(filteredCss) - cChildren = cChildren.mapValues(css => css & filteredCss) - for (c <- filteredCss) { - if (!(cChildren contains c)) { - cChildren += c -> Set() + def shrinkTo(remainingBs: Set[Identifier], finalUnfolding: Boolean): Unit = { + //println("Shrinking!") + + val initialBs = remainingBs ++ (if (finalUnfolding) Set() else closedBs.keySet) + + var cParent = Map[Identifier, Identifier](); + var cOfB = Map[Identifier, Identifier](); + var underBs = Map[Identifier, Set[Identifier]]() + + for ((cparent, alts) <- cTree; + (b, _, cs) <- alts) { + + cOfB += b -> cparent + + for (cchild <- cs) { + underBs += cchild -> (underBs.getOrElse(cchild, Set()) + b) + cParent += cchild -> cparent } } - // Finally, we reset the state of the evaluator - triedCompilation = false - progEvaluator = None + def bParents(b: Identifier): Set[Identifier] = { + val parentBs = underBs.getOrElse(cOfB(b), Set()) + Set(b) ++ parentBs.flatMap(bParents) + } + + // include parents + val keptBs = initialBs.flatMap(bParents) - var cGroups = Map[Identifier, (Set[Identifier], Set[Identifier])]() + //println("Initial Bs: "+initialBs) + //println("Keeping Bs: "+keptBs) - for ((parentGuard, cToBs) <- bTree; (c, bss) <- cToBs) { - val (ps, bs) = cGroups.getOrElse(c, (Set[Identifier](), Set[Identifier]())) + //debugCExpr(cTree, keptBs) - cGroups += c -> (ps + parentGuard, bs ++ bss) + var newCTree = Map[Identifier, Seq[(Identifier, Expr, Set[Identifier])]]() + + for ((c, alts) <- cTree) yield { + newCTree += c -> alts.filter(a => keptBs(a._1)) } - // We need to regenerate clauses for each b - val pathConstraints = for ((_, (parentGuards, bs)) <- cGroups) yield { - val bvs = bs.toList.map(Variable(_)) + def removeDeadAlts(c: Identifier, deadC: Identifier) { + if (newCTree contains c) { + val alts = newCTree(c) + val newAlts = alts.filterNot(a => a._3 contains deadC) + + if (newAlts.isEmpty) { + for (cp <- cParent.get(c)) { + removeDeadAlts(cp, c) + } + newCTree -= c + } else { + newCTree += c -> newAlts + } + } + } - // Represents the case where all parents guards are false, indicating - // that this C should not be considered at all - val failedPath = andJoin(parentGuards.toSeq.map(p => Not(p.toVariable))) + //println("BETWEEN") + //debugCExpr(newCTree, keptBs) - val distinct = bvs.combinations(2).collect { - case List(a, b) => - or(not(a), not(b)) + for ((c, alts) <- newCTree if alts.isEmpty) { + for (cp <- cParent.get(c)) { + removeDeadAlts(cp, c) } + newCTree -= c + } + + var newCDeps = Map[Identifier, Set[Identifier]]() + + for ((c, alts) <- cTree) yield { + newCDeps += c -> alts.map(_._3).toSet.flatten + } + + cTree = newCTree + cDeps = newCDeps + closedBs = closedBs.filterKeys(keptBs) - andJoin(Seq(orJoin(failedPath :: bvs), implies(failedPath, andJoin(bvs.map(Not(_))))) ++ distinct) + bs = cTree.map(_._2.map(_._1)).flatten.toSet + bsOrdered = bs.toSeq.sortBy(_.id) + + //debugCExpr(cTree) + updateCTree() + } + + class CGenerator { + private var buffers = Map[T, Stream[Identifier]]() + + private var slots = Map[T, Int]().withDefaultValue(0) + + private def streamOf(t: T): Stream[Identifier] = { + FreshIdentifier("c", true).setType(t.getType) #:: streamOf(t) } - // Generate all the b => c = ... - val impliess = mappings.map { case (bid, (recId, ex)) => - implies(Variable(bid), Equals(Variable(recId), ex)) + def reset(): Unit = { + slots = Map[T, Int]().withDefaultValue(0) } - (pathConstraints ++ impliess).toSeq + def getNext(t: T) = { + if (!(buffers contains t)) { + buffers += t -> streamOf(t) + } + + val n = slots(t) + slots += t -> (n+1) + + buffers(t)(n) + } } - def unfold(finalUnfolding: Boolean): (List[Expr], Set[Identifier]) = { - var newClauses = List[Expr]() - var newGuardedTerms = Map[Identifier, Set[Identifier]]() - var newMappings = Map[Identifier, (Identifier, Expr)]() + def unfold(finalUnfolding: Boolean): Boolean = { + var newBs = Set[Identifier]() + var unfoldedSomething = false; - var cGroups = Map[Identifier, Set[Identifier]]() + def freshB() = { + val id = FreshIdentifier("B", true).setType(BooleanType) + newBs += id + id + } - for ((parentGuard, recIds) <- guardedTerms; recId <- recIds) { - cGroups += recId -> (cGroups.getOrElse(recId, Set()) + parentGuard) + val unfoldBehind = if (cTree.isEmpty) { + p.xs + } else { + closedBs.flatMap(_._2).toSet } - for ((recId, parentGuards) <- cGroups) { + closedBs = Map[Identifier, Set[Identifier]]() + + for (c <- unfoldBehind) { + var alts = grammar.getProductions(labels(c)) - var alts = grammar.getProductions(labels(recId)) if (finalUnfolding) { alts = alts.filter(_.subTrees.isEmpty) } - val altsWithBranches = alts.map(alt => FreshIdentifier("B", true).setType(BooleanType) -> alt) + val cGen = new CGenerator() - val bvs = altsWithBranches.map(alt => Variable(alt._1)) + val cTreeInfos = if (alts.nonEmpty) { + for (gen <- alts) yield { + val b = freshB() - // Represents the case where all parents guards are false, indicating - // that this C should not be considered at all - val failedPath = andJoin(parentGuards.toSeq.map(p => not(p.toVariable))) + // Optimize labels + cGen.reset() - val distinct = bvs.combinations(2).collect { - case List(a, b) => - or(not(a), not(b)) - } - - val pre = andJoin(Seq(orJoin(failedPath +: bvs), implies(failedPath, andJoin(bvs.map(Not(_))))) ++ distinct) + val cToLabel = for (t <- gen.subTrees) yield { + cGen.getNext(t) -> t + } - var cBankCache = Map[T, Stream[Identifier]]() - def freshC(t: T): Stream[Identifier] = Stream.cons(FreshIdentifier("c", true).setType(t.getType), freshC(t)) - def getC(t: T, index: Int) = cBankCache.getOrElse(t, { - cBankCache += t -> freshC(t) - cBankCache(t) - })(index) - val cases = for((bid, gen) <- altsWithBranches.toList) yield { // b1 => E(gen1, gen2) [b1 -> {gen1, gen2}] - val newLabels = for ((t, i) <- gen.subTrees.zipWithIndex) yield { getC(t, i) -> t } - labels ++= newLabels + labels ++= cToLabel - val rec = newLabels.map(_._1) - val ex = gen.builder(rec.map(_.toVariable)) + val cs = cToLabel.map(_._1) + val ex = gen.builder(cs.map(_.toVariable)) - if (!rec.isEmpty) { - newGuardedTerms += bid -> rec.toSet - cChildren += recId -> (cChildren(recId) ++ rec) - } + if (!cs.isEmpty) { + closedBs += b -> cs.toSet + } - newMappings += bid -> (recId -> ex) + //println(" + "+b+" => "+c+" = "+ex) - implies(Variable(bid), Equals(Variable(recId), ex)) - } + unfoldedSomething = true - val newBIds = altsWithBranches.map(_._1).toSet + (b, ex, cs.toSet) + } + } else { + // Happens in final unfolding when no alts have ground terms + val b = freshB() + closedBs += b -> Set() - for (parentGuard <- parentGuards) { - bTree += parentGuard -> (bTree.getOrElse(parentGuard, Map()) + (recId -> newBIds)) + Seq((b, simplestValue(c.getType), Set[Identifier]())) } - newClauses = newClauses ::: pre :: cases + cTree += c -> cTreeInfos + cDeps += c -> cTreeInfos.map(_._3).toSet.flatten } sctx.reporter.ifDebug { printer => @@ -408,23 +598,98 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { grammar.printProductions(printer) } - //program = And(program :: newClauses) + bs = bs ++ newBs + bsOrdered = bs.toSeq.sortBy(_.id) + + updateCTree() - mappings = mappings ++ newMappings + unfoldedSomething + } - guardedTerms = newGuardedTerms + def solveForTentativeProgram(): Option[Option[Set[Identifier]]] = { + val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(exSolverTo) + val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) - // Finally, we reset the state of the evaluator - triedCompilation = false - progEvaluator = None + val fixedBs = FiniteArray(bsOrdered.map(_.toVariable)).setType(ArrayType(BooleanType)) + val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr) - revBTree ++= cGroups - - (newClauses, newGuardedTerms.keySet) + val toFind = and(p.pc, cnstrFixed) + //println(" --- Constraints ---") + //println(" - "+toFind) + solver.assertCnstr(toFind) + + // oneOfBs + for ((c, alts) <- cTree) { + val activeBs = alts.map(_._1).filter(isBActive) + + if (activeBs.nonEmpty) { + val oneOf = orJoin(activeBs.map(_.toVariable)); + //println(" - "+oneOf) + solver.assertCnstr(oneOf) + } + } + + for (ex <- excludedPrograms) { + val notThisProgram = Not(andJoin(ex.map(_.toVariable).toSeq)) + //println(" - "+notThisProgram) + solver.assertCnstr(notThisProgram) + } + + try { + solver.check match { + case Some(true) => + val model = solver.getModel + + val bModel = bs.filter(b => model.get(b).map(_ == BooleanLiteral(true)).getOrElse(false)) + + //println("Found potential expr: "+getExpr(bModel)+" under inputs: "+model) + Some(Some(bModel)) + + case Some(false) => + println("No Model!") + Some(None) + + case None => + println("Timeout!") + None + } + } finally { + solver.free + } } - def bss = mappings.keySet - def css : Set[Identifier] = mappings.values.map(_._1).toSet ++ guardedTerms.flatMap(_._2) + def solveForCounterExample(bs: Set[Identifier]): Option[Option[Seq[Expr]]] = { + val solver = (new FairZ3Solver(ctx, programCTree) with TimeoutSolver).setTimeout(cexSolverTo) + val cnstr = FunctionInvocation(phiFd.typed, phiFd.params.map(_.id.toVariable)) + + val fixedBs = FiniteArray(bsOrdered.map(b => BooleanLiteral(bs(b)))).setType(ArrayType(BooleanType)) + val cnstrFixed = replaceFromIDs(Map(bArrayId -> fixedBs), cnstr) + + solver.assertCnstr(p.pc) + solver.assertCnstr(Not(cnstrFixed)) + + try { + solver.check match { + case Some(true) => + val model = solver.getModel + val cex = p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) + + Some(Some(cex)) + + case Some(false) => + Some(None) + + case None => + None + } + } finally { + solver.free + } + } + + def free(): Unit = { + + } } List(new RuleInstantiation(this.name) { @@ -435,17 +700,12 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { var ass = p.as.toSet var xss = p.xs.toSet - val initGuard = FreshIdentifier("START", true).setType(BooleanType) - - val ndProgram = new NonDeterministicProgram(p, initGuard) + val ndProgram = new NonDeterministicProgram(p) var unfolding = 1 val maxUnfoldings = params.maxUnfoldings sctx.reporter.debug(s"maxUnfoldings=$maxUnfoldings") - val exSolverTo = 2000L - val cexSolverTo = 2000L - var baseExampleInputs: ArrayBuffer[Seq[Expr]] = new ArrayBuffer[Seq[Expr]]() // We populate the list of examples with a predefined one @@ -494,6 +754,8 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) { new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, pc, 20, 3000) } else { + val evalParams = CodeGenParams(maxFunctionInvocations = -1, doInstrument = false) + val evaluator = new CodeGenEvaluator(sctx.context, sctx.program, evalParams) new NaiveDataGen(sctx.context, sctx.program, evaluator).generateFor(p.as, pc, 20, 1000) } @@ -520,46 +782,19 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { baseExampleInputs.iterator ++ cachedInputIterator } - def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplication = { - for (prog <- programs) { - val expr = ndProgram.determinize(prog) - val res = Equals(tupleWrap(p.xs.map(Variable(_))), expr) - - val solver3 = sctx.newSolver.setTimeout(cexSolverTo) - solver3.assertCnstr(and(pc, res, not(p.phi))) - - try { - solver3.check match { - case Some(false) => - return RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = true)) - case None => - return RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false)) - case Some(true) => - // invalid program, we skip - } - } finally { - solver3.free() - } - } - - RuleFailed() - } - // Keep track of collected cores to filter programs to test var collectedCores = Set[Set[Identifier]]() - val initExClause = and(pc, p.phi, Variable(initGuard)) - val initCExClause = and(pc, not(p.phi), Variable(initGuard)) + //val initExClause = and(pc, p.phi, Variable(initGuard)) + //val initCExClause = and(pc, not(p.phi), Variable(initGuard)) - // solver1 is used for the initial SAT queries - var solver1 = sctx.newSolver.setTimeout(exSolverTo) - solver1.assertCnstr(initExClause) + //// solver1 is used for the initial SAT queries + //var solver1 = sctx.newSolver.setTimeout(exSolverTo) + //solver1.assertCnstr(initExClause) - // solver2 is used for validating a candidate program, or finding new inputs - var solver2 = sctx.newSolver.setTimeout(cexSolverTo) - solver2.assertCnstr(initCExClause) - - var didFilterAlready = false + //// solver2 is used for validating a candidate program, or finding new inputs + //var solver2 = sctx.newSolver.setTimeout(cexSolverTo) + //solver2.assertCnstr(initCExClause) val tpe = tupleTypeWrap(p.xs.map(_.getType)) @@ -567,29 +802,11 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { do { var skipCESearch = false - var bssAssumptions = Set[Identifier]() - - if (!didFilterAlready) { - val (clauses, closedBs) = ndProgram.unfold(unfolding == maxUnfoldings) - - bssAssumptions = closedBs - - sctx.reporter.ifDebug { debug => - debug("UNFOLDING: ") - for (c <- clauses) { - debug(" - " + c.asString(sctx.context)) - } - debug("CLOSED Bs "+closedBs) - } - - val clause = andJoin(clauses) + // Unfold formula + val unfoldSuccess = ndProgram.unfold(unfolding == maxUnfoldings) - solver1.assertCnstr(clause) - solver2.assertCnstr(clause) - - if (clauses.isEmpty) { - unfolding = maxUnfoldings - } + if (!unfoldSuccess) { + unfolding = maxUnfoldings } // Compute all programs that have not been excluded yet @@ -599,15 +816,21 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { Set() } - val allPrograms = prunedPrograms.size - - sctx.reporter.debug("#Programs: "+prunedPrograms.size) + val nInitial = prunedPrograms.size + sctx.reporter.debug("#Programs: "+nInitial) + //sctx.reporter.ifDebug{ printer => + // for (p <- prunedPrograms.take(10)) { + // printer(" - "+ndProgram.getExpr(p)) + // } + // if(nPassing > 10) { + // printer(" - ...") + // } + //} var wrongPrograms = Set[Set[Identifier]](); // We further filter the set of working programs to remove those that fail on known examples - if (useCEPruning && hasInputExamples() && ndProgram.canTest()) { - + if (useCEPruning && hasInputExamples()) { for (bs <- prunedPrograms if !interruptManager.isInterrupted()) { var valid = true val examples = allInputExamples() @@ -615,7 +838,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val e = examples.next() if (!ndProgram.testForProgram(bs)(e)) { failedTestsStats(e) += 1 - //sctx.reporter.debug(" Program: "+ndProgram.determinize(bs)+" failed on "+e) + //sctx.reporter.debug(" Program: "+ndProgram.getExpr(bs)+" failed on "+e) wrongPrograms += bs prunedPrograms -= bs @@ -623,7 +846,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { } } - if (wrongPrograms.size % 1000 == 0) { + if (wrongPrograms.size+1 % 1000 == 0) { sctx.reporter.debug("..."+wrongPrograms.size) } } @@ -631,71 +854,54 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { val nPassing = prunedPrograms.size sctx.reporter.debug("#Programs passing tests: "+nPassing) + sctx.reporter.ifDebug{ printer => + for (p <- prunedPrograms.take(10)) { + printer(" - "+ndProgram.getExpr(p)) + } + if(nPassing > 10) { + printer(" - ...") + } + } if (nPassing == 0 || interruptManager.isInterrupted()) { + // No test passed, we can skip solver and unfold again, if possible skipCESearch = true; - } else if (nPassing <= testUpTo) { - // Immediate Test - checkForPrograms(prunedPrograms) match { - case rs: RuleClosed if rs.solutions.nonEmpty => - result = Some(rs) - case _ => - wrongPrograms.foreach { p => - solver1.assertCnstr(Not(andJoin(p.map(Variable(_)).toSeq))) - } + } else if (nPassing <= validateUpTo) { + // Small enough number of programs to try them individually + ndProgram.validatePrograms(prunedPrograms) match { + case Left(sols) if sols.nonEmpty => + result = Some(RuleClosed(sols)) + case Right(cexs) => + // All programs failed verification, we filter everything out and unfold + //ndProgram.shrinkTo(Set(), unfolding == maxUnfoldings) + skipCESearch = true; } - } else if (((nPassing < allPrograms*filterThreshold) || didFilterAlready) && useBssFiltering) { - // We filter the Bss so that the formula we give to z3 is much smalled + } else if (((nPassing < nInitial*filterThreshold)) && useBssFiltering) { + // We shrink the program to only use the bs mentionned val bssToKeep = prunedPrograms.foldLeft(Set[Identifier]())(_ ++ _) - - // Cannot unfold normally after having filtered, so we need to - // repeat the filtering procedure at next unfolding. - didFilterAlready = true - - // Freshening solvers - solver1.free() - solver1 = sctx.newSolver.setTimeout(exSolverTo) - solver1.assertCnstr(initExClause) - - solver2.free() - solver2 = sctx.newSolver.setTimeout(cexSolverTo) - solver2.assertCnstr(initCExClause) - - val clauses = ndProgram.filterFor(bssToKeep) - val clause = andJoin(clauses) - - solver1.assertCnstr(clause) - solver2.assertCnstr(clause) + ndProgram.shrinkTo(bssToKeep, unfolding == maxUnfoldings) } else { - wrongPrograms.foreach { p => - solver1.assertCnstr(not(andJoin(p.map(_.toVariable).toSeq))) + wrongPrograms.foreach { + ndProgram.excludeProgram(_) } } - val bss = ndProgram.bss - + // CEGIS Loop at a given unfolding level while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted()) { - solver1.checkAssumptions(bssAssumptions.map(id => Not(Variable(id)))) match { - case Some(true) => - val satModel = solver1.getModel - - val bssAssumptions: Set[Expr] = bss.map(b => satModel(b) match { - case BooleanLiteral(true) => Variable(b) - case BooleanLiteral(false) => Not(Variable(b)) - }) + ndProgram.solveForTentativeProgram() match { + case Some(Some(bs)) => + // Should we validate this program with Z3? - val validateWithZ3 = if (useCETests && hasInputExamples() && ndProgram.canTest()) { + val validateWithZ3 = if (useCETests && hasInputExamples()) { - val p = bssAssumptions.collect { case Variable(b) => b } - - if (allInputExamples().forall(ndProgram.testForProgram(p))) { + if (allInputExamples().forall(ndProgram.testForProgram(bs))) { // All valid inputs also work with this, we need to // make sure by validating this candidate with z3 true } else { // One valid input failed with this candidate, we can skip - solver1.assertCnstr(not(andJoin(p.map(_.toVariable).toSeq))) + ndProgram.excludeProgram(bs) false } } else { @@ -703,95 +909,42 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { true } - if (validateWithZ3) { - solver2.checkAssumptions(bssAssumptions) match { - case Some(true) => - val invalidModel = solver2.getModel - - val fixedAss = andJoin(ass.collect { - case a if invalidModel contains a => Equals(Variable(a), invalidModel(a)) - }.toSeq) - - val newCE = p.as.map(valuateWithModel(invalidModel)) - - baseExampleInputs += newCE + if (true || validateWithZ3) { + ndProgram.solveForCounterExample(bs) match { + case Some(Some(inputsCE)) => + // Found counter example! + baseExampleInputs += inputsCE // Retest whether the newly found C-E invalidates all programs - if (useCEPruning && ndProgram.canTest) { - if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(newCE))) { + if (useCEPruning) { + if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(inputsCE))) { skipCESearch = true } } - val unsatCore = if (useUnsatCores) { - solver1.push() - solver1.assertCnstr(fixedAss) - - val core = solver1.checkAssumptions(bssAssumptions) match { - case Some(false) => - // Core might be empty if unfolding level is - // insufficient, it becomes unsat no matter what - // the assumptions are. - solver1.getUnsatCore - - case Some(true) => - // Can't be! - bssAssumptions - - case None => - return RuleFailed() - } - - solver1.pop() - - collectedCores += core.collect{ case Variable(id) => id } - - core - } else { - bssAssumptions - } - - if (unsatCore.isEmpty) { - skipCESearch = true - } else { - solver1.assertCnstr(not(andJoin(unsatCore.toSeq))) - } - - case Some(false) => - - val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) - + case Some(None) => + // Found no counter example! Program is a valid solution + val expr = ndProgram.getExpr(bs) result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) - case _ => + case None => + // We are not sure if (useOptTimeout) { // Interpret timeout in CE search as "the candidate is valid" sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - val expr = ndProgram.determinize(satModel.filter(_._2 == BooleanLiteral(true)).keySet) + val expr = ndProgram.getExpr(bs) result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) } else { - return RuleFailed() + result = Some(RuleFailed()) } + } } - } - - - case Some(false) => - if (useUninterpretedProbe) { - solver1.check match { - case Some(false) => - // Unsat even without blockers (under which fcalls are then uninterpreted) - return RuleFailed() - - case _ => - } - } + case Some(None) => skipCESearch = true - case _ => - // Last chance, we test first few programs - result = Some(checkForPrograms(prunedPrograms.take(testUpTo))) + case None => + result = Some(RuleFailed()) } } @@ -806,8 +959,7 @@ abstract class CEGISLike[T <% Typed](name: String) extends Rule(name) { e.printStackTrace RuleFailed() } finally { - solver1.free() - solver2.free() + ndProgram.free() } } }) diff --git a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala b/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala index 95dd54157bb8c0ec292c79185c2dcb6d112180f4..0034e0588f5344eda2a415f93b3ad80585a9a7de 100644 --- a/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala +++ b/src/main/scala/leon/synthesis/utils/SynthesisProblemExtractionPhase.scala @@ -10,29 +10,17 @@ import purescala.Definitions._ import solvers.z3._ import solvers.Solver -object SynthesisProblemExtractionPhase extends LeonPhase[Program, (Program, Map[FunDef, Seq[Problem]])] { +object SynthesisProblemExtractionPhase extends LeonPhase[Program, (Program, Map[FunDef, Seq[ChooseInfo]])] { val name = "Synthesis Problem Extraction" val description = "Synthesis Problem Extraction" - def run(ctx: LeonContext)(p: Program): (Program, Map[FunDef, Seq[Problem]]) = { - var results = Map[FunDef, Seq[Problem]]() - def noop(u:Expr, u2: Expr) = u - - - def actOnChoose(f: FunDef)(e: Expr) = e match { - case ch @ Choose(vars, pred) => - val problem = Problem.fromChoose(ch) - - results += f -> (results.getOrElse(f, Seq()) :+ problem) - case _ => - } - + def run(ctx: LeonContext)(p: Program): (Program, Map[FunDef, Seq[ChooseInfo]]) = { // Look for choose() - for (f <- p.definedFunctions.sortBy(_.id.toString) if f.body.isDefined) { - preTraversal(actOnChoose(f))(f.body.get) + val results = for (f <- p.definedFunctions.sortBy(_.id.toString) if f.body.isDefined) yield { + f -> ChooseInfo.extractFromFunction(p, f) } - (p, results) + (p, results.toMap) } } diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index 99506a14bc2e4d17af8fc4343d632c87911f47ce..6fd90151d8e46bd931373961d95de83196716283 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -14,6 +14,8 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { override val description = "Induction tactic for suitable functions" override val shortDescription = "induction" + val reporter = vctx.reporter + private def firstAbsClassDef(args: Seq[ValDef]): Option[(AbstractClassType, ValDef)] = { args.map(vd => (vd.getType, vd)).collect { case (act: AbstractClassType, vd) => (act, vd) diff --git a/src/main/scala/leon/verification/Tactic.scala b/src/main/scala/leon/verification/Tactic.scala index 26eef79c142095b1c1cba0132596de51cf1382f9..8d80f5ffbfa0315f5a7b14175d3d38cfb0fdc7cf 100644 --- a/src/main/scala/leon/verification/Tactic.scala +++ b/src/main/scala/leon/verification/Tactic.scala @@ -11,9 +11,6 @@ abstract class Tactic(vctx: VerificationContext) { val description : String val shortDescription : String - val program = vctx.program - val reporter = vctx.reporter - def generateVCs(fdUnsafe: FunDef): Seq[VerificationCondition] = { val fd = fdUnsafe.duplicate fd.fullBody = matchToIfThenElse(fd.fullBody) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index 33056d81bd70247c1f1cc906b81f7e65eae11288..93ba69d154269af7b69e27355ed2286761a9eee9 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -250,7 +250,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val (bodyRes, bodyScope, bodyFun) = toFunction(b) (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)).copiedFrom(expr), bodyFun) } - case c @ Choose(ids, b) => { + case c @ Choose(ids, b, _) => { //Recall that Choose cannot mutate variables from the scope (c, (b2: Expr) => b2, Map()) } diff --git a/src/test/scala/leon/test/synthesis/StablePrintingSuite.scala b/src/test/scala/leon/test/synthesis/StablePrintingSuite.scala index ca2821c687601fe947a1a365de3bfdf86f12735a..a933080a960196ed98c3e34a8362c1772eb86f44 100644 --- a/src/test/scala/leon/test/synthesis/StablePrintingSuite.scala +++ b/src/test/scala/leon/test/synthesis/StablePrintingSuite.scala @@ -34,7 +34,7 @@ class StablePrintingSuite extends LeonTestSuite { private def testIterativeSynthesis(cat: String, f: File, depth: Int) { - def getChooses(ctx: LeonContext, content: String): (Program, Seq[ChooseInfo]) = { + def getChooses(ctx: LeonContext, content: String): (Program, SynthesisSettings, Seq[ChooseInfo]) = { val opts = SynthesisSettings() val pipeline = leon.utils.TemporaryInputPhase andThen frontends.scalac.ExtractionPhase andThen @@ -43,7 +43,7 @@ class StablePrintingSuite extends LeonTestSuite { val program = pipeline.run(ctx)((content, Nil)) - (program, ChooseInfo.extractFromProgram(ctx, program, opts)) + (program, opts, ChooseInfo.extractFromProgram(program)) } case class Job(content: String, choosesToProcess: Set[Int], rules: List[String]) { @@ -69,7 +69,7 @@ class StablePrintingSuite extends LeonTestSuite { info(j.info("compilation")) - val (pgm, chooses) = try { + val (pgm, opts, chooses) = try { getChooses(ctx, j.content) } catch { case e: Throwable => @@ -84,9 +84,10 @@ class StablePrintingSuite extends LeonTestSuite { if (j.rules.size < depth) { for ((ci, i) <- chooses.zipWithIndex if j.choosesToProcess(i) || j.choosesToProcess.isEmpty) { - val sctx = SynthesisContext.fromSynthesizer(ci.synthesizer) - val search = ci.synthesizer.getSearch() - val hctx = SearchContext(sctx, search.g.root, search) + val synthesizer = new Synthesizer(ctx, pgm, ci, opts) + val sctx = SynthesisContext.fromSynthesizer(synthesizer) + val search = synthesizer.getSearch() + val hctx = SearchContext(sctx, ci, search.g.root, search) val problem = ci.problem info(j.info("synthesis "+problem)) val apps = sctx.rules flatMap { _.instantiateOn(hctx, problem)} diff --git a/src/test/scala/leon/test/synthesis/SynthesisRegressionSuite.scala b/src/test/scala/leon/test/synthesis/SynthesisRegressionSuite.scala index 1e6ed1bd758b2580b044edbde73dabf5c784fffd..9ad67c8a03240cded8e0843370dca3743c85e44b 100644 --- a/src/test/scala/leon/test/synthesis/SynthesisRegressionSuite.scala +++ b/src/test/scala/leon/test/synthesis/SynthesisRegressionSuite.scala @@ -26,22 +26,26 @@ class SynthesisRegressionSuite extends LeonTestSuite { private def testSynthesis(cat: String, f: File, bound: Int) { var chooses = List[ChooseInfo]() + var program: Program = null + var ctx: LeonContext = null + var opts: SynthesisSettings = null test(cat+": "+f.getName()+" Compilation") { - val ctx = createLeonContext("--synthesis") + ctx = createLeonContext("--synthesis") - val opts = SynthesisSettings(searchBound = Some(bound), allSeeing = true) + opts = SynthesisSettings(searchBound = Some(bound), allSeeing = true) val pipeline = leon.frontends.scalac.ExtractionPhase andThen leon.utils.PreprocessingPhase - val program = pipeline.run(ctx)(f.getAbsolutePath :: Nil) + program = pipeline.run(ctx)(f.getAbsolutePath :: Nil) - chooses = ChooseInfo.extractFromProgram(ctx, program, opts) + chooses = ChooseInfo.extractFromProgram(program) } for (ci <- chooses) { test(cat+": "+f.getName()+" - "+ci.fd.id.name) { - val (search, sols) = ci.synthesizer.synthesize() + val synthesizer = new Synthesizer(ctx, program, ci, opts) + val (search, sols) = synthesizer.synthesize() if (sols.isEmpty) { fail("Solution was not found. (Search bound: "+bound+")") } diff --git a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala index 73f01af8a05dbdf5c78efa700ac48909e3056ab6..5fe0a94ebe50136c32e1962e1696c3269fcfad8a 100644 --- a/src/test/scala/leon/test/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/test/synthesis/SynthesisSuite.scala @@ -37,17 +37,18 @@ class SynthesisSuite extends LeonTestSuite { val (program, results) = pipeline.run(ctx)((content, Nil)) - for ((f, ps) <- results; p <- ps) { - info("%-20s".format(f.id.toString)) + for ((f,cis) <- results; ci <- cis) { + info("%-20s".format(ci.fd.id.toString)) val sctx = SynthesisContext(ctx, opts, - f, + ci.fd, program, ctx.reporter) - val search = new SimpleSearch(ctx, p, opts.costModel, None) - val hctx = SearchContext(sctx, search.g.root, search) + val p = ci.problem + val search = new SimpleSearch(ctx, ci, p, opts.costModel, None) + val hctx = SearchContext(sctx, ci, search.g.root, search) block(hctx, f, p) }