diff --git a/Game.scala b/Game.scala new file mode 100644 index 0000000000000000000000000000000000000000..391594f760b122aeb5d3de71f287a3eec4b13ffe --- /dev/null +++ b/Game.scala @@ -0,0 +1,233 @@ +import leon.lang._ +import leon.annotation._ +import leon.lang.synthesis._ +import leon.collection._ + +object Test { + case class Pos(x: Int, y: Int) { + def up = Pos(x, y-1) + def down = Pos(x, y+1) + def left = Pos(x-1, y) + def right = Pos(x+1, y) + + def isValid(s: State) = { + x >= 0 && y >= 0 && + x < s.map.size.x && y < s.map.size.y && + !(s.map.walls contains this) + } + + def distance(o: Pos) = { + (if (o.x < x) (x-o.x) else (o.x-x)) + + (if (o.y < y) (y-o.y) else (o.y-y)) + } + } + + case class Map(walls: Set[Pos], size: Pos) + + abstract class Action; + case object MoveUp extends Action + case object MoveDown extends Action + case object MoveLeft extends Action + case object MoveRight extends Action + case object Quit extends Action + + case class State(pos: Pos, + monster: Pos, + stop: Boolean, + map: Map) { + + def isValid = { + pos.isValid(this) && monster.isValid(this) + } + + } + + def step(s: State)(implicit o: Oracle[Action]): State = { + require(s.isValid) + + val u = display(s) + + stepMonster(stepPlayer(s)(o.left)) + } + + def stepMonster(s: State) = { + require(s.isValid) + if (s.pos == s.monster) { + State(s.pos, s.monster, true, s.map) + } else { + val mp = choose { + (res: Pos) => + res.distance(s.monster) <= 1 && + res.distance(s.pos) <= s.monster.distance(s.pos) && + res.isValid(s) + } + State(s.pos, mp, mp != s.pos, s.map) + } + } ensuring { _.isValid } + + def stepPlayer(s: State)(implicit o: Oracle[Action]) = { + val action: Action = ??? + + val ns = action match { + case Quit => + State(s.pos, s.monster, true, s.map) + case _ if s.stop => + s + case MoveUp if s.pos.y > 0 => + State(s.pos.up, s.monster, s.stop, s.map) + case MoveDown => + State(s.pos.down, s.monster, s.stop, s.map) + case MoveLeft if s.pos.x > 0 => + State(s.pos.left, s.monster, s.stop, s.map) + case MoveRight => + State(s.pos.right, s.monster, s.stop, s.map) + case _ => + s + } + + if (ns.isValid) ns else s + } + + def steps(s: State, b: Int)(implicit o: Oracle[Action]): State = { + if (b == 0 || s.stop) { + s + } else { + steps(step(s)(o), b-1)(o.right) + } + } + + def play(s: State)(implicit o: Oracle[Action]): State = { + steps(s, -1) + } + + @extern + def display(s: State): Int = { + print('â•”') + for (x <- 0 until s.map.size.x) { + print('â•') + } + println('â•—') + for (y <- 0 until s.map.size.y) { + print('â•‘') + for (x <- 0 until s.map.size.x) { + val c = Pos(x,y) + if (s.map.walls contains c) { + print('X') + } else if (s.pos == c) { + print('o') + } else if (s.monster == c) { + print('m') + } else { + print(" ") + } + } + println('â•‘') + } + print('â•š') + for (x <- 0 until s.map.size.x) { + print('â•') + } + println('â•') + + 42 + } + + @extern + def main(args: Array[String]) { + + abstract class OracleSource[T] extends Oracle[T] { + def branch: OracleSource[T] + def value: T + + lazy val v: T = value + lazy val l: OracleSource[T] = branch + lazy val r: OracleSource[T] = branch + + override def head = v + override def left = l + override def right = r + } + + class Keyboard extends OracleSource[Action] { + def branch = new Keyboard + def value = { + var askAgain = false + var action: Action = Quit + do { + if (askAgain) println("?") + askAgain = false + print("> ") + readLine().trim match { + case "up" => + action = MoveUp + case "down" => + action = MoveDown + case "left" => + action = MoveLeft + case "right" => + action = MoveRight + case "quit" => + action = Quit + case _ => + askAgain = true + } + } while(askAgain) + + action + } + } + + class Random extends OracleSource[Action] { + def value = { + readLine() + scala.util.Random.nextInt(4) match { + case 0 => + MoveUp + case 1 => + MoveDown + case 2 => + MoveLeft + case 3 => + MoveRight + case _ => + MoveUp + } + } + + def branch = new Random + } + + val map = Map(Set(Pos(2,2), Pos(2,3), Pos(4,4), Pos(5,5)), Pos(10,10)) + val init = State(Pos(0,0), Pos(4,5), false, map) + + play(init)(new Random) + } + + def test1() = { + withOracle{ o: Oracle[Action] => + { + val map = Map(Set(Pos(2,2), Pos(2,3), Pos(4,4), Pos(5,5)), Pos(10,10)) + val init = State(Pos(0,0), Pos(4,4), false, map) + + steps(init, 5)(o) + } ensuring { + _.pos == Pos(0,3) + } + } + } + + def validStep(s: State) = { + require(s.map.size.x > 3 && s.map.size.y > 3 && s.pos != s.monster && s.isValid && !s.stop) + + val ns = withOracle { o: Oracle[Action] => + { + stepPlayer(s)(o) + } ensuring { + res => res.isValid && !res.stop + } + } + stepMonster(ns) + } ensuring { + res => res.isValid && !res.stop + } +} diff --git a/library/annotation/package.scala b/library/annotation/package.scala index 1aa564cbd6b0d5aa13d3ba20e5444d4013480aad..ca65c78bd6efd112cc73bbe134b582a530b33077 100644 --- a/library/annotation/package.scala +++ b/library/annotation/package.scala @@ -17,7 +17,7 @@ package object annotation { @ignore class main extends StaticAnnotation @ignore - class proxy extends StaticAnnotation + class extern extends StaticAnnotation @ignore class ignore extends StaticAnnotation diff --git a/library/lang/synthesis/package.scala b/library/lang/synthesis/package.scala index abf314dc838a064a7699cbd9ae85358f30094764..540054899b37dc540f0c359f51fe6e32dc553c82 100644 --- a/library/lang/synthesis/package.scala +++ b/library/lang/synthesis/package.scala @@ -6,25 +6,28 @@ import leon.annotation._ package object synthesis { @ignore - private def noChoose = throw new RuntimeException("Implementation not supported") + private def noImpl = throw new RuntimeException("Implementation not supported") @ignore - def choose[A](predicate: A => Boolean): A = noChoose + def choose[A](predicate: A => Boolean): A = noImpl @ignore - def choose[A, B](predicate: (A, B) => Boolean): (A, B) = noChoose + def choose[A, B](predicate: (A, B) => Boolean): (A, B) = noImpl @ignore - def choose[A, B, C](predicate: (A, B, C) => Boolean): (A, B, C) = noChoose + def choose[A, B, C](predicate: (A, B, C) => Boolean): (A, B, C) = noImpl @ignore - def choose[A, B, C, D](predicate: (A, B, C, D) => Boolean): (A, B, C, D) = noChoose + def choose[A, B, C, D](predicate: (A, B, C, D) => Boolean): (A, B, C, D) = noImpl @ignore - def choose[A, B, C, D, E](predicate: (A, B, C, D, E) => Boolean): (A, B, C, D, E) = noChoose + def choose[A, B, C, D, E](predicate: (A, B, C, D, E) => Boolean): (A, B, C, D, E) = noImpl @library def ???[T](implicit o: Oracle[T]): T = o.head @library - def ?[T](e1: T)(implicit o: Oracle[Boolean], o2: Oracle[T]): T = if(???[Boolean]) e1 else ???[T] + def ?[T](e1: T)(implicit o1: Oracle[Boolean], o2: Oracle[T]): T = if(???[Boolean](o1)) e1 else ???[T](o2) @ignore - def ?[T](e1: T, es: T*)(implicit o: Oracle[Boolean]): T = noChoose + def ?[T](e1: T, es: T*)(implicit o: Oracle[Boolean]): T = noImpl + + @ignore + def withOracle[A, R](body: Oracle[A] => R): R = noImpl } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index f3abb7513d9800035d95b13a7c231b42672ecfea..bbade12a57bc0748faa385f592f511e50ee9bdda 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -421,9 +421,6 @@ trait CodeGeneration { ch << InvokeSpecial(ErrorClass, constructorName, "(Ljava/lang/String;)V") ch << ATHROW - case hole @ Hole(oracle) => - mkExpr(OracleTraverser(oracle, hole.getType, program).value, ch) - case choose @ Choose(_, _) => val prob = synthesis.Problem.fromChoose(choose) diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index c018e48fa47232d62a655f93123d2952148b7a24..5c988ebd73615d970d59de8478e2dbb6828a073a 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -289,11 +289,21 @@ trait ASTExtractors { } object ExChooseExpression { - def unapply(tree: Apply) : Option[(List[(Tree, Symbol)], Tree, Tree, Tree)] = tree match { + def unapply(tree: Apply) : Option[(List[(Tree, Symbol)], Tree)] = tree match { case a @ Apply( TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "choose"), types), Function(vds, predicateBody) :: Nil) => - Some(((types zip vds.map(_.symbol)).toList, a, predicateBody, s)) + Some(((types zip vds.map(_.symbol)).toList, predicateBody)) + case _ => None + } + } + + object ExWithOracleExpression { + def unapply(tree: Apply) : Option[(List[(Tree, Symbol)], Tree)] = tree match { + case a @ Apply( + TypeApply(s @ ExSymbol("leon", "lang", "synthesis", "withOracle"), types), + Function(vds, body) :: Nil) => + Some(((types zip vds.map(_.symbol)).toList, body)) case _ => None } } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 3474f6d7c4734690360003f6175566c8ee4b8abe..e77b25d5058980483c802112104163e2a8d59c78 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -35,7 +35,9 @@ trait CodeExtraction extends ASTExtractors { val reporter = self.ctx.reporter def annotationsOf(s: Symbol): Set[String] = { - (for(a <- s.annotations ++ s.owner.annotations) yield { + val actualSymbol = s.accessedOrSelf + + (for(a <- actualSymbol.annotations ++ actualSymbol.owner.annotations) yield { val name = a.atp.safeToString.replaceAll("\\.package\\.", ".") if (name startsWith "leon.annotation.") { Some(name.split("\\.", 3)(2)) @@ -136,14 +138,14 @@ trait CodeExtraction extends ASTExtractors { tparams: Map[Symbol, TypeParameter] = Map(), vars: Map[Symbol, () => LeonExpr] = Map(), mutableVars: Map[Symbol, () => LeonExpr] = Map(), - isProxy: Boolean = false + isExtern: Boolean = false ) { def union(that: DefContext) = { copy(this.tparams ++ that.tparams, this.vars ++ that.vars, this.mutableVars ++ that.mutableVars, - this.isProxy || that.isProxy) + this.isExtern || that.isExtern) } def isVariable(s: Symbol) = (vars contains s) || (mutableVars contains s) @@ -165,8 +167,8 @@ trait CodeExtraction extends ASTExtractors { (annotationsOf(s) contains "ignore") || (s.fullName.toString.endsWith(".main")) } - def isProxy(s: Symbol) = { - annotationsOf(s) contains "proxy" + def isExtern(s: Symbol) = { + annotationsOf(s) contains "extern" } def extractModules: List[LeonModuleDef] = { @@ -239,6 +241,17 @@ trait CodeExtraction extends ASTExtractors { private var classesToClasses = Map[Symbol, LeonClassDef]() + def oracleType(pos: Position, tpe: LeonType) = { + classesToClasses.find { + case (sym, cl) => (sym.fullName.toString == "leon.lang.synthesis.Oracle") + } match { + case Some((_, cd)) => + classDefToClassType(cd, List(tpe)) + case None => + outOfSubsetError(pos, "Could not find class Oracle") + } + } + def libraryMethod(classname: String, methodName: String): Option[(LeonClassDef, FunDef)] = { classesToClasses.values.find(_.id.name == classname).flatMap { cl => cl.methods.find(_.id.name == methodName).map { fd => (cl, fd) } @@ -289,7 +302,7 @@ trait CodeExtraction extends ASTExtractors { if (sym.name.toString == "synthesis") { new Exception().printStackTrace() } - outOfSubsetError(pos, "Class "+sym.name+" not defined?") + outOfSubsetError(pos, "Class "+sym.fullName+" not defined?") } } } @@ -407,7 +420,7 @@ trait CodeExtraction extends ASTExtractors { // Type params of the function itself val tparams = extractTypeParams(sym.typeParams.map(_.tpe)) - val nctx = dctx.copy(tparams = dctx.tparams ++ tparams.toMap, isProxy = isProxy(sym)) + val nctx = dctx.copy(tparams = dctx.tparams ++ tparams.toMap, isExtern = isExtern(sym)) val newParams = sym.info.paramss.flatten.map{ sym => val ptpe = toPureScalaType(sym.tpe)(nctx, sym.pos) @@ -499,7 +512,7 @@ trait CodeExtraction extends ASTExtractors { val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap - extractFunBody(fd, params, body)(DefContext(tparamsMap, isProxy = isProxy(sym))) + extractFunBody(fd, params, body)(DefContext(tparamsMap, isExtern = isExtern(sym))) case _ => } @@ -578,10 +591,10 @@ trait CodeExtraction extends ASTExtractors { } } catch { case e: ImpureCodeEncounteredException => - if (!dctx.isProxy) { + if (!dctx.isExtern) { e.emit() if (ctx.settings.strictCompilation) { - reporter.error(funDef.getPos, "Function "+funDef.id.name+" could not be extracted. (Forgot @proxy ?)") + reporter.error(funDef.getPos, "Function "+funDef.id.name+" could not be extracted. (Forgot @extern ?)") } else { reporter.warning(funDef.getPos, "Function "+funDef.id.name+" is not fully unavailable to Leon.") } @@ -708,7 +721,7 @@ trait CodeExtraction extends ASTExtractors { val b = try { extractTree(body) } catch { - case (e: ImpureCodeEncounteredException) if dctx.isProxy => + case (e: ImpureCodeEncounteredException) if dctx.isExtern => NoTree(toPureScalaType(current.tpe)(dctx, current.pos)) } @@ -721,7 +734,7 @@ trait CodeExtraction extends ASTExtractors { val b = try { extractTree(body) } catch { - case (e: ImpureCodeEncounteredException) if dctx.isProxy => + case (e: ImpureCodeEncounteredException) if dctx.isExtern => NoTree(toPureScalaType(current.tpe)(dctx, current.pos)) } @@ -741,7 +754,7 @@ trait CodeExtraction extends ASTExtractors { val b = try { rest.map(extractTree).getOrElse(UnitLiteral()) } catch { - case (e: ImpureCodeEncounteredException) if dctx.isProxy => + case (e: ImpureCodeEncounteredException) if dctx.isExtern => NoTree(toPureScalaType(current.tpe)(dctx, current.pos)) } @@ -812,7 +825,7 @@ trait CodeExtraction extends ASTExtractors { fd.addAnnotation(annotationsOf(d.symbol).toSeq : _*) - val newDctx = dctx.copy(tparams = dctx.tparams ++ tparamsMap, isProxy = isProxy(sym)) + val newDctx = dctx.copy(tparams = dctx.tparams ++ tparamsMap, isExtern = isExtern(sym)) val oldCurrentFunDef = currentFunDef @@ -949,36 +962,58 @@ trait CodeExtraction extends ASTExtractors { } case hole @ ExHoleExpression(tpt, exprs, os) => - val leonTpe = extractType(tpt) val leonExprs = exprs.map(extractTree) val leonOracles = os.map(extractTree) + def rightOf(o: LeonExpr): LeonExpr = { + val Some((cl, fd)) = libraryMethod("Oracle", "right") + MethodInvocation(o, cl, fd.typed(Nil), Nil) + } + + def valueOf(o: LeonExpr): LeonExpr = { + val Some((cl, fd)) = libraryMethod("Oracle", "head") + MethodInvocation(o, cl, fd.typed(Nil), Nil) + } + + leonExprs match { case Nil => - Hole(leonOracles(0)).setType(leonTpe) + valueOf(leonOracles(0)) case List(e) => - IfExpr(Hole(leonOracles(0)).setType(BooleanType), e, Hole(leonOracles(1)).setType(leonTpe)) + IfExpr(valueOf(leonOracles(0)), e, valueOf(leonOracles(1))) case exs => val l = exs.last var o = leonOracles(0) - def rightOf(o: LeonExpr): LeonExpr = { - val Some((cl, fd)) = libraryMethod("Oracle", "right") - MethodInvocation(o, cl, fd.typed(Nil), Nil) - } - exs.init.foldRight(l)({ (e: LeonExpr, r: LeonExpr) => - val res = IfExpr(Hole(o).setType(BooleanType), e, r) + val res = IfExpr(valueOf(leonOracles(0)), e, r) o = rightOf(o) res }) } - case chs @ ExChooseExpression(args, tpt, body, select) => - val cTpe = extractType(tpt) + case ops @ ExWithOracleExpression(oracles, body) => + val newOracles = oracles map { case (tpt, sym) => + val aTpe = extractType(tpt) + val oTpe = oracleType(ops.pos, aTpe) + val newID = FreshIdentifier(sym.name.toString).setType(oTpe) + owners += (newID -> None) + newID + } + + val newVars = (oracles zip newOracles).map { + case ((_, sym), id) => + sym -> (() => Variable(id)) + } + + val cBody = extractTree(body)(dctx.withNewVars(newVars)) + + WithOracle(newOracles, cBody) + + case chs @ ExChooseExpression(args, body) => val vars = args map { case (tpt, sym) => val aTpe = extractType(tpt) val newID = FreshIdentifier(sym.name.toString).setType(aTpe) @@ -1362,7 +1397,7 @@ trait CodeExtraction extends ASTExtractors { if (seenClasses contains sym) { classDefToClassType(getClassDef(sym, NoPosition), tps) } else { - if (dctx.isProxy) { + if (dctx.isExtern) { unknownsToTP.getOrElse(sym, { val tp = TypeParameter(FreshIdentifier(sym.name.toString, true)) unknownsToTP += sym -> tp diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 1ae66eac7c2f4402b05d615593027cc5508f3d71..7873fbaa138fb09dafc4d4bd3ba27b1f056896a4 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -169,6 +169,11 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe | (${typed(id)}) => $post |}""" + case c @ WithOracle(vars, pred) => + p"""|withOracle { (${typed(vars)}) => + | $pred + |}""" + case CaseClass(cct, args) => if (cct.classDef.isCaseObject) { p"$cct" @@ -190,7 +195,6 @@ class PrettyPrinter(opts: PrinterOptions, val sb: StringBuffer = new StringBuffe case GenericValue(tp, id) => p"$tp#$id" case Tuple(exprs) => p"($exprs)" case TupleSelect(t, i) => p"${t}._$i" - case h @ Hole(o) => p"???[${h.getType}]($o)" case Choose(vars, pred) => p"choose(($vars) => $pred)" case e @ Error(err) => p"""error[${e.getType}]("$err")""" case CaseClassInstanceOf(cct, e) => diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 87090a80da5b9712adb197317599fc83c4952bb0..4807224e5197e6fbfe0feff272c668cff3542c05 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -16,7 +16,7 @@ class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) ex override def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = { tree match { case Not(Equals(l, r)) => p"$l != $r" - case Iff(l,r) => pp(Equals(l, r)) + case Iff(l,r) => p"$l == $r" case Implies(l,r) => pp(Or(Not(l), r)) case Choose(vars, pred) => p"choose((${typed(vars)}) => $pred)" case s @ FiniteSet(rs) => { diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 9bd1b04ed022474a56a88dc4e5f9a8d616eedbf5..6e0a85e0f41643e0040c1af2d04b1f3aea30a67c 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -415,7 +415,7 @@ object TreeOps { val eval = new DefaultEvaluator(ctx, program) def isGround(e: Expr): Boolean = { - variablesOf(e).isEmpty && !usesHoles(e) && !containsChoose(e) + variablesOf(e).isEmpty && !containsChoose(e) } def rec(e: Expr): Option[Expr] = e match { @@ -1341,38 +1341,6 @@ object TreeOps { false } - def containsHoles(e: Expr): Boolean = { - preTraversal{ - case Hole(_) => return true - case _ => - }(e) - false - } - - /** - * Returns true if the expression directly or indirectly relies on a Hole - */ - def usesHoles(e: Expr): Boolean = { - var cache = Map[FunDef, Boolean]() - - def callsHolesExpr(e: Expr): Boolean = { - containsHoles(e) || functionCallsOf(e).exists(fi => callsHoles(fi.tfd.fd)) - } - - def callsHoles(fd: FunDef): Boolean = cache.get(fd) match { - case Some(r) => r - case None => - cache += fd -> false - - val res = fd.body.map(callsHolesExpr _).getOrElse(false) - - cache += fd -> res - res - } - - callsHolesExpr(e) - } - /** * Returns the value for an identifier given a model. */ diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 2a56870b90183c225f39a3a1cfa111f9e207da9a..fd0b1f4d0cf8ccec5dc2cace3c9ffa6ba66a3b84 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -53,10 +53,14 @@ object Trees { } } - // A hole is like a all-seeing choose - case class Hole(oracle: Expr) extends Expr with UnaryExtractable { + // Provide an oracle (synthesizable, all-seeing choose) + case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with FixedType with UnaryExtractable { + assert(!oracles.isEmpty) + + val fixedType = body.getType + def extract = { - Some((oracle, (o: Expr) => Hole(o).copiedFrom(this))) + Some((body, (e: Expr) => WithOracle(oracles, e).setPos(this))) } } diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 6ad14ff61b72883f270c1a1cac909d83209ab418..06a94d7ddf23ac487ea74c758960d4e71c7097bc 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -243,14 +243,21 @@ trait AbstractZ3Solver case class UntranslatableTypeException(msg: String) extends Exception(msg) - def rootType(ct: ClassType): ClassType = ct.parent match { - case Some(p) => rootType(p) - case None => ct + def rootType(ct: TypeTree): TypeTree = ct match { + case ct: ClassType => + ct.parent match { + case Some(p) => rootType(p) + case None => ct + } + case t => t } def declareADTSort(ct: ClassType): Z3Sort = { import Z3Context.{ADTSortReference, RecursiveType, RegularSort} + //println("///"*40) + //println("Declaring for: "+ct) + def getHierarchy(ct: ClassType): (ClassType, Seq[CaseClassType]) = ct match { case act: AbstractClassType => (act, act.knownCCDescendents) @@ -263,57 +270,63 @@ trait AbstractZ3Solver } } - var newHierarchiesMap = Map[ClassType, Seq[CaseClassType]]() + def resolveTypes(ct: ClassType) = { + var newHierarchiesMap = Map[ClassType, Seq[CaseClassType]]() - def findDependencies(ct: ClassType): Unit = { - val (root, sub) = getHierarchy(ct) + def findDependencies(ct: ClassType): Unit = { + val (root, sub) = getHierarchy(ct) - if (!(newHierarchiesMap contains root) && !(sorts containsLeon root)) { - newHierarchiesMap += root -> sub + if (!(newHierarchiesMap contains root) && !(sorts containsLeon root)) { + newHierarchiesMap += root -> sub - // look for dependencies - for (ct <- root +: sub; f <- ct.fields) f.tpe match { - case fct: ClassType => - findDependencies(fct) - case _ => + // look for dependencies + for (ct <- root +: sub; f <- ct.fields) f.tpe match { + case fct: ClassType => + findDependencies(fct) + case _ => + } } } - } - // Populates the dependencies of the ADT to define. - findDependencies(ct) + // Populates the dependencies of the ADT to define. + findDependencies(ct) + + //println("Dependencies: ") + //for ((r, sub) <- newHierarchiesMap) { + // println(s" - $r >: $sub") + //} - val newHierarchies = newHierarchiesMap.toSeq + val newHierarchies = newHierarchiesMap.toSeq - val indexMap: Map[ClassType, Int] = Map()++newHierarchies.map(_._1).zipWithIndex + val indexMap: Map[ClassType, Int] = Map()++newHierarchies.map(_._1).zipWithIndex - def typeToSortRef(tt: TypeTree): ADTSortReference = tt match { - case ct: ClassType if sorts containsLeon rootType(ct) => - RegularSort(sorts.toZ3(rootType(ct))) + def typeToSortRef(tt: TypeTree): ADTSortReference = rootType(tt) match { + case ct: ClassType if sorts containsLeon ct => + RegularSort(sorts.toZ3(ct)) - case act : AbstractClassType => - // It has to be here - RecursiveType(indexMap(act)) + case act: ClassType => + // It has to be here + RecursiveType(indexMap(act)) - case cct: CaseClassType => cct.parent match { - case Some(p) => - typeToSortRef(p) - case None => - RecursiveType(indexMap(cct)) + case _=> + RegularSort(typeToSort(tt)) } - case _=> - RegularSort(typeToSort(tt)) + // Define stuff + val defs = for ((root, childrenList) <- newHierarchies) yield { + ( + root.id.uniqueName, + childrenList.map(ccd => ccd.id.uniqueName), + childrenList.map(ccd => ccd.fields.map(f => (f.id.uniqueName, typeToSortRef(f.tpe)))) + ) + } + (defs, newHierarchies) } - // Define stuff - val defs = for ((root, childrenList) <- newHierarchies) yield { - ( - root.id.uniqueName, - childrenList.map(ccd => ccd.id.uniqueName), - childrenList.map(ccd => ccd.fields.map(f => (f.id.uniqueName, typeToSortRef(f.tpe)))) - ) - } + // @EK: the first step is needed to introduce ADT sorts referenced inside Sets of this CT + // When defining Map(s: Set[Pos], p: Pos), it will need Pos, but Pos will be defined through Set[Pos] in the first pass + resolveTypes(ct) + val (defs, newHierarchies) = resolveTypes(ct) //for ((n, sub, cstrs) <- defs) { // println(n+":") @@ -343,6 +356,8 @@ trait AbstractZ3Solver } } + //println("\\\\\\"*40) + sorts.toZ3(ct) } @@ -403,7 +418,6 @@ trait AbstractZ3Solver case tt @ SetType(base) => sorts.toZ3OrCompute(tt) { val newSetSort = z3.mkSetSort(typeToSort(base)) - val card = z3.mkFreshFuncDecl("card", Seq(newSetSort), typeToSort(Int32Type)) setCardDecls += tt -> card @@ -566,7 +580,7 @@ trait AbstractZ3Solver z3.mkApp(functionDefToDecl(tfd), args.map(rec(_)): _*) case SetEquals(s1, s2) => z3.mkEq(rec(s1), rec(s2)) - case ElementOfSet(e, s) => z3.mkSetSubset(z3.mkSetAdd(z3.mkEmptySet(typeToSort(e.getType)), rec(e)), rec(s)) + case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) @@ -650,9 +664,6 @@ trait AbstractZ3Solver case gv @ GenericValue(tp, id) => z3.mkApp(genericValueToDecl(gv)) - case h @ Hole(o) => - rec(OracleTraverser(o, h.getType, program).value) - case _ => { reporter.warning(ex.getPos, "Can't handle this in translation to Z3: " + ex) throw new CantTranslateException diff --git a/src/main/scala/leon/synthesis/ChooseInfo.scala b/src/main/scala/leon/synthesis/ChooseInfo.scala index 85f51ee663fb940d036594eb80f040c483b721fc..a9e49f10d136be5d3bc1369ff783e08f2d1d2ec1 100644 --- a/src/main/scala/leon/synthesis/ChooseInfo.scala +++ b/src/main/scala/leon/synthesis/ChooseInfo.scala @@ -35,26 +35,6 @@ object ChooseInfo { } } - - if (options.allSeeing) { - // Functions that call holes are also considered for (all-seeing) synthesis - - val holesFd = prog.definedFunctions.filter(fd => fd.hasBody && containsHoles(fd.body.get)).toSet - - val callers = prog.callGraph.transitiveCallers(holesFd) ++ holesFd - - for (f <- callers if f.hasPostcondition && f.hasBody) { - val path = f.precondition.getOrElse(BooleanLiteral(true)) - - val x = FreshIdentifier("x", true).setType(f.returnType) - val (pid, pex) = f.postcondition.get - - val ch = Choose(List(x), And(Equals(x.toVariable, f.body.get), replaceFromIDs(Map(pid -> x.toVariable), pex))) - - results = ChooseInfo(ctx, prog, f, path, f.body.get, ch, options) :: results - } - } - results.sortBy(_.fd.id.toString) } } diff --git a/src/main/scala/leon/synthesis/ConvertWithOracles.scala b/src/main/scala/leon/synthesis/ConvertWithOracles.scala new file mode 100644 index 0000000000000000000000000000000000000000..d7c1b6853d7c554b656983c1aabb2f58a08f098b --- /dev/null +++ b/src/main/scala/leon/synthesis/ConvertWithOracles.scala @@ -0,0 +1,91 @@ +/* Copyright 2009-2014 EPFL, Lausanne */ + +package leon +package synthesis + +import purescala.Common._ +import purescala.Trees._ +import purescala.TreeOps._ +import purescala.Definitions._ + +object ConvertWithOracle extends LeonPhase[Program, Program] { + val name = "Convert WithOracle to Choose" + val description = "Convert WithOracle found in bodies to equivalent Choose" + + /** + * This phase converts a body with "withOracle{ .. }" into a choose construct: + * + * def foo(a: T) = { + * require(..a..) + * withOracle { o => + * expr(a,o) ensuring { x => post(x) } + * } + * } + * + * gets converted into: + * + * def foo(a: T) { + * require(..a..) + * val o = choose { (o) => { + * val res = expr(a, o) + * pred(res) + * } + * expr(a,o) + * } ensuring { res => + * pred(res) + * } + * + */ + def run(ctx: LeonContext)(pgm: Program): Program = { + + pgm.definedFunctions.foreach(fd => { + if (fd.hasBody) { + val body = preMap { + case wo @ WithOracle(os, b) => + withoutSpec(b) match { + case Some(body) => + val chooseOs = os.map(_.freshen) + + val pred = postconditionOf(b) match { + case Some((id, post)) => + replaceFromIDs((os zip chooseOs.map(_.toVariable)).toMap, Let(id, body, post)) + case None => + BooleanLiteral(true) + } + + if (chooseOs.size > 1) { + Some(LetTuple(os, Choose(chooseOs, pred), b)) + } else { + Some(Let(os.head, Choose(chooseOs, pred), b)) + } + case None => + None + } + case _ => None + }(fd.body.get) + + fd.body = Some(body) + } + + // Ensure that holes are not found in pre and/or post conditions + fd.precondition.foreach { + preTraversal{ + case _: WithOracle => + ctx.reporter.error("WithOracle expressions are not supported in preconditions. (function "+fd.id.asString(ctx)+")") + case _ => + } + } + + fd.postcondition.foreach { case (id, post) => + preTraversal{ + case _: WithOracle => + ctx.reporter.error("WithOracle expressions are not supported in postconditions. (function "+fd.id.asString(ctx)+")") + case _ => + }(post) + } + + }) + + pgm + } +} diff --git a/src/main/scala/leon/synthesis/CostModel.scala b/src/main/scala/leon/synthesis/CostModel.scala index 85e5285eedd9e14f5d28042949df8db62fb40b8d..a133d7dee7fb7d0714dee560abbc26f68fef7f6a 100644 --- a/src/main/scala/leon/synthesis/CostModel.scala +++ b/src/main/scala/leon/synthesis/CostModel.scala @@ -98,11 +98,7 @@ case object WeightedBranchesCostModel extends CostModel("WeightedBranches") { def problemCost(p: Problem): Cost = new Cost { val value = { - if (usesHoles(p.phi)) { - p.xs.size + 50 - } else { - p.xs.size - } + p.xs.size } } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 85c8b60245e0f36626f52031b8498e13c3ad5e6f..5f75db65091be80f8e0c5f3f3301c655fb4a8f93 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -31,8 +31,8 @@ object Rules { ADTSplit, InlineHoles, IntegerEquation, - IntegerInequalities, - AngelicHoles + IntegerInequalities + //AngelicHoles // @EK: Disabled now as it is explicit with withOracle { .. } ) def getInstantiations(sctx: SynthesisContext, problem: Problem) = { diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 00a00766103071a2ffe041b4e4532a2761ba8f06..1de672ad9355eef54aacef3d6ffb4c330f735429 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -36,10 +36,15 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr) { // // Indices are 0-indexed def project(indices: Seq[Int]): Solution = { - val t = FreshIdentifier("t", true).setType(term.getType) - val newTerm = Let(t, term, Tuple(indices.map(i => TupleSelect(t.toVariable, i+1)))) - - Solution(pre, defs, newTerm) + term.getType match { + case TupleType(ts) => + val t = FreshIdentifier("t", true).setType(term.getType) + val newTerm = Let(t, term, Tuple(indices.map(i => TupleSelect(t.toVariable, i+1)))) + + Solution(pre, defs, newTerm) + case _ => + this + } } diff --git a/src/main/scala/leon/synthesis/SynthesisOptions.scala b/src/main/scala/leon/synthesis/SynthesisOptions.scala index 706fc69fff03d9cd600ad88c9739029346f70eb5..a92550ac9cb0401d7c4e6dcf15bd33babb87364c 100644 --- a/src/main/scala/leon/synthesis/SynthesisOptions.scala +++ b/src/main/scala/leon/synthesis/SynthesisOptions.scala @@ -26,5 +26,8 @@ case class SynthesisOptions( cegisUseCETests: Boolean = true, cegisUseCEPruning: Boolean = true, cegisUseBPaths: Boolean = true, - cegisUseVanuatoo: Boolean = false + cegisUseVanuatoo: Boolean = false, + + // Oracles and holes + distreteHoles: Boolean = false ) diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index e68dbf292e97d34d3f971753d8554ec894e236ca..ad404e220b2575d64fe44f274eb6f4091413c3d5 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -32,7 +32,8 @@ object SynthesisPhase extends LeonPhase[Program, Program] { LeonFlagOptionDef( "cegis:bssfilter", "--cegis:bssfilter", "Filter non-det programs when tests pruning works well", true), LeonFlagOptionDef( "cegis:unsatcores", "--cegis:unsatcores", "Use UNSAT-cores in pruning", true), LeonFlagOptionDef( "cegis:opttimeout", "--cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true), - LeonFlagOptionDef( "cegis:vanuatoo", "--cegis:vanuatoo", "Generate inputs using new korat-style generator", false) + LeonFlagOptionDef( "cegis:vanuatoo", "--cegis:vanuatoo", "Generate inputs using new korat-style generator", false), + LeonFlagOptionDef( "holes:discrete", "--holes:discrete", "Oracles get split", false) ) def processOptions(ctx: LeonContext): SynthesisOptions = { @@ -110,6 +111,9 @@ object SynthesisPhase extends LeonPhase[Program, Program] { case LeonFlagOption("cegis:vanuatoo", v) => options = options.copy(cegisUseVanuatoo = v) + case LeonFlagOption("holes:discrete", v) => + options = options.copy(distreteHoles = v) + case _ => } diff --git a/src/main/scala/leon/synthesis/rules/AngelicHoles.scala b/src/main/scala/leon/synthesis/rules/AngelicHoles.scala deleted file mode 100644 index 07c5f687fb9a16a9c825995d07c84cc7a93aa7c8..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/synthesis/rules/AngelicHoles.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2009-2014 EPFL, Lausanne */ - -package leon -package synthesis -package rules - -import purescala.Common._ -import purescala.Definitions._ -import purescala.Trees._ -import purescala.TreeOps._ -import purescala.TypeTrees._ -import purescala.Extractors._ - -// Synthesizing a function with Hole is actually synthesizing an Oracle, so Oracle becomes output: -// [[ a,o < Phi(a,o,x) > x ]] ---> [[ a < Phi(a,o,x) > x, o ]] -case object AngelicHoles extends NormalizingRule("Angelic Holes") { - def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { - val oracleClass = sctx.program.definedClasses.find(_.id.name == "Oracle").getOrElse { - sctx.reporter.fatalError("Can't find Oracle class") - } - - def isOracle(i: Identifier) = { - i.getType match { - case AbstractClassType(acd, _) if acd == oracleClass => true - case _ => false - } - } - - if (usesHoles(p.phi)) { - val (oracles, as) = p.as.partition(isOracle) - - if (oracles.nonEmpty) { - val sub = p.copy(as = as, xs = p.xs ++ oracles) - List(RuleInstantiation.immediateDecomp(p, this, List(sub), { - case List(s) => - // We ignore the last output params that are oracles - Some(s.project(0 until p.xs.size)) - - case _ => - None - }, "Hole Semantics")) - } else { - Nil - } - } else { - Nil - } - } -} - diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 83c5a3bf4d223386a5dbe73f18838e02b56e8cda..99cdcda04fc1299328af82d712d618a20874b0ef 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -112,7 +112,7 @@ case object CEGIS extends Rule("CEGIS") { val isNotSynthesizable = fd.body match { case Some(b) => - !containsChoose(b) && !usesHoles(b) + !containsChoose(b) case None => false diff --git a/src/main/scala/leon/synthesis/rules/InlineHoles.scala b/src/main/scala/leon/synthesis/rules/InlineHoles.scala index 97961429fefb1bdfdb7cf84f25da51a18b209578..8d7071648b83de79dc14e74d2e01edbe76ab50d9 100644 --- a/src/main/scala/leon/synthesis/rules/InlineHoles.scala +++ b/src/main/scala/leon/synthesis/rules/InlineHoles.scala @@ -11,6 +11,7 @@ import leon.utils._ import solvers._ import purescala.Common._ +import purescala.Definitions._ import purescala.Trees._ import purescala.TreeOps._ import purescala.TypeTrees._ @@ -20,6 +21,46 @@ case object InlineHoles extends Rule("Inline-Holes") { override val priority = RulePriorityHoles def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + // When true: withOracle gets converted into a big choose() on result. + val discreteHoles = sctx.options.distreteHoles + + if (!discreteHoles) { + return Nil; + } + + val Some(oracleHead) = sctx.program.definedFunctions.find(_.id.name == "Oracle.head") + + def containsHoles(e: Expr): Boolean = { + preTraversal{ + case FunctionInvocation(TypedFunDef(`oracleHead`, _), _) => return true + case _ => + }(e) + false + } + + /** + * Returns true if the expression directly or indirectly relies on a Hole + */ + def usesHoles(e: Expr): Boolean = { + var cache = Map[FunDef, Boolean]() + + def callsHolesExpr(e: Expr): Boolean = { + containsHoles(e) || functionCallsOf(e).exists(fi => callsHoles(fi.tfd.fd)) + } + + def callsHoles(fd: FunDef): Boolean = cache.get(fd) match { + case Some(r) => r + case None => + cache += fd -> false + + val res = fd.body.map(callsHolesExpr _).getOrElse(false) + + cache += fd -> res + res + } + + callsHolesExpr(e) + } @tailrec def inlineUntilHoles(e: Expr): Expr = { @@ -41,16 +82,15 @@ case object InlineHoles extends Rule("Inline-Holes") { } def inlineHoles(phi: Expr): (List[Identifier], Expr) = { + var newXs = List[Identifier]() val res = preMap { - case h @ Hole(o) => - val tpe = h.getType + case h @ FunctionInvocation(TypedFunDef(`oracleHead`, Seq(tpe)), Seq(o)) => val x = FreshIdentifier("h", true).setType(tpe) newXs ::= x Some(x.toVariable) - case _ => None }(phi) @@ -87,7 +127,6 @@ case object InlineHoles extends Rule("Inline-Holes") { None } - // 2) a version with holes reachable to continue applying itself val newPhi = inlineUntilHoles(p.phi) val (newXs, newPhiInlined) = inlineHoles(newPhi) diff --git a/src/main/scala/leon/synthesis/rules/OnePoint.scala b/src/main/scala/leon/synthesis/rules/OnePoint.scala index 2091e67dc3a98a214620e264a8cd390751e9bfd8..7deff8fe62a2df6ed1d0d91935251edb0795097d 100644 --- a/src/main/scala/leon/synthesis/rules/OnePoint.scala +++ b/src/main/scala/leon/synthesis/rules/OnePoint.scala @@ -14,7 +14,7 @@ case object OnePoint extends NormalizingRule("One-point") { val TopLevelAnds(exprs) = p.phi def validOnePoint(x: Identifier, e: Expr) = { - !(variablesOf(e) contains x) && !usesHoles(e) + !(variablesOf(e) contains x) } val candidates = exprs.collect { diff --git a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala index 17080def2769fffa3f165e64e4c89fbbd1f46267..1e5a3a3318e3ac3f9b3d050bbd81fa81395514e4 100644 --- a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala +++ b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala @@ -17,7 +17,14 @@ case object UnconstrainedOutput extends NormalizingRule("Unconstr.Output") { val onSuccess: List[Solution] => Option[Solution] = { case List(s) => - Some(Solution(s.pre, s.defs, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(id.getType) else Variable(id)))))) + val term = if (sub.xs.size > 1) { + LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(id.getType) else Variable(id)))) + } else if (sub.xs.size == 1) { + Let(sub.xs.head, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(id.getType) else Variable(id)))) + } else { + Tuple(p.xs.map(id => simplestValue(id.getType))) + } + Some(Solution(s.pre, s.defs, term)) case _ => None } diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index a7e0b46a0832042f4acc989e1dc762f5eae0ff99..d5719090ec749321821eef874e9b612aef525880 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -6,6 +6,7 @@ package utils import purescala.Definitions.Program import purescala.{MethodLifting, CompleteAbstractDefinitions} +import synthesis.{ConvertWithOracle} object PreprocessingPhase extends TransformationPhase { @@ -17,6 +18,7 @@ object PreprocessingPhase extends TransformationPhase { val phases = MethodLifting andThen TypingPhase andThen + ConvertWithOracle andThen CompleteAbstractDefinitions andThen InjectAsserts diff --git a/src/main/scala/leon/utils/TypingPhase.scala b/src/main/scala/leon/utils/TypingPhase.scala index a3e1f3e61961fc5472ea66e18799f15caaa457d5..9e366a5eac457327cbc431677abca16cdf2d20c9 100644 --- a/src/main/scala/leon/utils/TypingPhase.scala +++ b/src/main/scala/leon/utils/TypingPhase.scala @@ -33,7 +33,7 @@ object TypingPhase extends LeonPhase[Program, Program] { // Part (1) fd.precondition = { val argTypesPreconditions = fd.params.flatMap(arg => arg.tpe match { - case cct : CaseClassType => Seq(CaseClassInstanceOf(cct, arg.id.toVariable)) + case cct : CaseClassType if cct.parent.isDefined => Seq(CaseClassInstanceOf(cct, arg.id.toVariable)) case _ => Seq() }) argTypesPreconditions match { @@ -46,7 +46,7 @@ object TypingPhase extends LeonPhase[Program, Program] { } fd.postcondition = fd.returnType match { - case cct : CaseClassType => { + case cct : CaseClassType if cct.parent.isDefined => { fd.postcondition match { case Some((id, p)) =>