diff --git a/src/main/scala/leon/LeonOption.scala b/src/main/scala/leon/LeonOption.scala index a0a9c9d92ee78397cce95dd0a52042fa8bc63330..f079399f9995e409b6b72881ef8f81d92c273556 100644 --- a/src/main/scala/leon/LeonOption.scala +++ b/src/main/scala/leon/LeonOption.scala @@ -25,8 +25,10 @@ abstract class LeonOptionDef[+A] { try { parser(s) } catch { case _ : IllegalArgumentException => - reporter.error(s"Invalid option usage: $usageDesc") - Main.displayHelp(reporter, error = true) + reporter.fatalError( + s"Invalid option usage: --$name\n" + + "Try 'leon --help' for more information." + ) } } diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index f238d4a854bd881c74a3205ef5b727a220e34ae1..db49d0425e39ac888badb90c97aa7c8432d0ee96 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -116,8 +116,10 @@ object Main { } // Find respective LeonOptionDef, or report an unknown option val df = allOptions.find(_. name == name).getOrElse{ - initReporter.error(s"Unknown option: $name") - displayHelp(initReporter, error = true) + initReporter.fatalError( + s"Unknown option: $name\n" + + "Try 'leon --help' for more information." + ) } df.parse(value)(initReporter) } diff --git a/src/main/scala/leon/SharedOptions.scala b/src/main/scala/leon/SharedOptions.scala index 839dda206b4a702ad5011049a38c76d7dcd21478..b68e64c83a6fd42e445dd4b5fa3b3bd42a935de4 100644 --- a/src/main/scala/leon/SharedOptions.scala +++ b/src/main/scala/leon/SharedOptions.scala @@ -5,12 +5,11 @@ package leon import leon.utils.{DebugSections, DebugSection} import OptionParsers._ -/* - * This object contains options that are shared among different modules of Leon. - * - * Options that determine the pipeline of Leon are not stored here, - * but in MainComponent in Main.scala. - */ +/** This object contains options that are shared among different modules of Leon. + * + * Options that determine the pipeline of Leon are not stored here, + * but in [[Main.MainComponent]] instead. + */ object SharedOptions extends LeonComponent { val name = "sharedOptions" @@ -45,7 +44,7 @@ object SharedOptions extends LeonComponent { val name = "debug" val description = { val sects = DebugSections.all.toSeq.map(_.name).sorted - val (first, second) = sects.splitAt(sects.length/2) + val (first, second) = sects.splitAt(sects.length/2 + 1) "Enable detailed messages per component.\nAvailable:\n" + " " + first.mkString(", ") + ",\n" + " " + second.mkString(", ") @@ -61,8 +60,6 @@ object SharedOptions extends LeonComponent { Set(rs) case None => throw new IllegalArgumentException - //initReporter.error("Section "+s+" not found, available: "+DebugSections.all.map(_.name).mkString(", ")) - //Set() } } } diff --git a/src/main/scala/leon/datagen/GrammarDataGen.scala b/src/main/scala/leon/datagen/GrammarDataGen.scala index cd86c707ddb893918e512f6d7101cc4cc92b6405..04541e78a6e63d1fa8670f4d4aeb12dd9c4417a2 100644 --- a/src/main/scala/leon/datagen/GrammarDataGen.scala +++ b/src/main/scala/leon/datagen/GrammarDataGen.scala @@ -4,14 +4,17 @@ package leon package datagen import purescala.Expressions._ -import purescala.Types.TypeTree +import purescala.Types._ import purescala.Common._ import purescala.Constructors._ import purescala.Extractors._ +import purescala.ExprOps._ import evaluators._ import bonsai.enumerators._ import grammars._ +import utils.UniqueCounter +import utils.SeqUtils.cartesianProduct /** Utility functions to generate values of a given type. * In fact, it could be used to generate *terms* of a given type, @@ -19,9 +22,40 @@ import grammars._ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar[TypeTree] = ValueGrammar) extends DataGenerator { implicit val ctx = evaluator.context + // Assume e contains generic values with index 0. + // Return a series of expressions with all normalized combinations of generic values. + private def expandGenerics(e: Expr): Seq[Expr] = { + val c = new UniqueCounter[TypeParameter] + val withUniqueCounters: Expr = postMap { + case GenericValue(t, _) => + Some(GenericValue(t, c.next(t))) + case _ => None + }(e) + + val indices = c.current + + val (tps, substInt) = (for { + tp <- indices.keySet.toSeq + } yield tp -> (for { + from <- 0 to indices(tp) + to <- 0 to from + } yield (from, to))).unzip + + val combos = cartesianProduct(substInt) + + val substitutions = combos map { subst => + tps.zip(subst).map { case (tp, (from, to)) => + (GenericValue(tp, from): Expr) -> (GenericValue(tp, to): Expr) + }.toMap + } + + substitutions map (replace(_, withUniqueCounters)) + + } + def generate(tpe: TypeTree): Iterator[Expr] = { - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](grammar.getProductions) - enum.iterator(tpe) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree, Expr]](grammar.getProductions) + enum.iterator(tpe).flatMap(expandGenerics) } def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]] = { @@ -51,4 +85,8 @@ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar[TypeTree] } } + def generateMapping(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int) = { + generateFor(ins, satisfying, maxValid, maxEnumerated) map (ins zip _) + } + } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 62e69fa62131980cb63a2888d9a2aeea724424d2..193361f0e85a43eef47e1f2376cfcef1e8fffc56 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -8,7 +8,7 @@ import purescala.Constructors._ import purescala.ExprOps._ import purescala.Expressions.Pattern import purescala.Extractors._ -import purescala.TypeOps._ +import purescala.TypeOps.isSubtypeOf import purescala.Types._ import purescala.Common._ import purescala.Expressions._ diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/leon/evaluators/StreamEvaluator.scala index 9cc6dd132036ffdde4a5ef70e2be87e6276f9f3c..385157e4d64c835841279f0970705f6a5aa10b1c 100644 --- a/src/main/scala/leon/evaluators/StreamEvaluator.scala +++ b/src/main/scala/leon/evaluators/StreamEvaluator.scala @@ -6,7 +6,7 @@ package evaluators import purescala.Constructors._ import purescala.ExprOps._ import purescala.Extractors._ -import purescala.TypeOps._ +import purescala.TypeOps.{leastUpperBound, isSubtypeOf} import purescala.Types._ import purescala.Common.Identifier import purescala.Definitions.{TypedFunDef, Program} diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 3e93664008654edef6fa057c56451db0a108bc6b..e3fb828617ceca48f6b5287cafe3494d9692b570 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -22,7 +22,7 @@ import Common._ import Extractors._ import Constructors._ import ExprOps._ -import TypeOps._ +import TypeOps.{leastUpperBound, typesCompatible, typeParamsOf, canBeSubtypeOf} import xlang.Expressions.{Block => LeonBlock, _} import xlang.ExprOps._ @@ -137,10 +137,6 @@ trait CodeExtraction extends ASTExtractors { private var currentFunDef: FunDef = null - //This is a bit misleading, if an expr is not mapped then it has no owner, if it is mapped to None it means - //that it can have any owner - private var owners: Map[Identifier, Option[FunDef]] = Map() - // This one never fails, on error, it returns Untyped def leonType(tpt: Type)(implicit dctx: DefContext, pos: Position): LeonType = { try { @@ -639,7 +635,6 @@ trait CodeExtraction extends ASTExtractors { val ptpe = leonType(sym.tpe)(nctx, sym.pos) val tpe = if (sym.isByNameParam) FunctionType(Seq(), ptpe) else ptpe val newID = FreshIdentifier(sym.name.toString, tpe).setPos(sym.pos) - owners += (newID -> None) val vd = LeonValDef(newID).setPos(sym.pos) if (sym.isByNameParam) { @@ -798,21 +793,7 @@ trait CodeExtraction extends ASTExtractors { }} else body0 val finalBody = try { - flattenBlocks(extractTreeOrNoTree(body)(fctx)) match { - case e if e.getType.isInstanceOf[ArrayType] => - getOwner(e) match { - case Some(Some(fd)) if fd == funDef => - e - - case None => - e - - case _ => - outOfSubsetError(body, "Function cannot return an array that is not locally defined") - } - case e => - e - } + flattenBlocks(extractTreeOrNoTree(body)(fctx)) } catch { case e: ImpureCodeEncounteredException => e.emit() @@ -1090,15 +1071,6 @@ trait CodeExtraction extends ASTExtractors { val newID = FreshIdentifier(vs.name.toString, binderTpe) val valTree = extractTree(bdy) - if(valTree.getType.isInstanceOf[ArrayType]) { - getOwner(valTree) match { - case None => - owners += (newID -> Some(currentFunDef)) - case _ => - outOfSubsetError(tr, "Cannot alias array") - } - } - val restTree = rest match { case Some(rst) => val nctx = dctx.withNewVar(vs -> (() => Variable(newID))) @@ -1138,7 +1110,7 @@ trait CodeExtraction extends ASTExtractors { case _ => (Nil, restTree) } - LetDef(funDefWithBody +: other_fds, block) + letDef(funDefWithBody +: other_fds, block) // FIXME case ExDefaultValueFunction @@ -1151,15 +1123,6 @@ trait CodeExtraction extends ASTExtractors { val newID = FreshIdentifier(vs.name.toString, binderTpe) val valTree = extractTree(bdy) - if(valTree.getType.isInstanceOf[ArrayType]) { - getOwner(valTree) match { - case None => - owners += (newID -> Some(currentFunDef)) - case Some(_) => - outOfSubsetError(tr, "Cannot alias array") - } - } - val restTree = rest match { case Some(rst) => { val nv = vs -> (() => Variable(newID)) @@ -1178,9 +1141,6 @@ trait CodeExtraction extends ASTExtractors { case Some(fun) => val Variable(id) = fun() val rhsTree = extractTree(rhs) - if(rhsTree.getType.isInstanceOf[ArrayType] && getOwner(rhsTree).isDefined) { - outOfSubsetError(tr, "Cannot alias array") - } Assignment(id, rhsTree) case None => @@ -1223,18 +1183,6 @@ trait CodeExtraction extends ASTExtractors { outOfSubsetError(tr, "Array update only works on variables") } - getOwner(lhsRec) match { - // case Some(Some(fd)) if fd != currentFunDef => - // outOfSubsetError(tr, "cannot update an array that is not defined locally") - - // case Some(None) => - // outOfSubsetError(tr, "cannot update an array that is not defined locally") - - case Some(_) => - - case None => sys.error("This array: " + lhsRec + " should have had an owner") - } - val indexRec = extractTree(index) val newValueRec = extractTree(newValue) ArrayUpdate(lhsRec, indexRec, newValueRec) @@ -1309,7 +1257,6 @@ trait CodeExtraction extends ASTExtractors { val aTpe = extractType(tpt) val oTpe = oracleType(ops.pos, aTpe) val newID = FreshIdentifier(sym.name.toString, oTpe) - owners += (newID -> None) newID } @@ -1331,7 +1278,6 @@ trait CodeExtraction extends ASTExtractors { val vds = args map { vd => val aTpe = extractType(vd.tpt) val newID = FreshIdentifier(vd.symbol.name.toString, aTpe) - owners += (newID -> None) LeonValDef(newID) } @@ -1347,7 +1293,6 @@ trait CodeExtraction extends ASTExtractors { val vds = args map { case (tpt, sym) => val aTpe = extractType(tpt) val newID = FreshIdentifier(sym.name.toString, aTpe) - owners += (newID -> None) LeonValDef(newID) } @@ -1908,34 +1853,6 @@ trait CodeExtraction extends ASTExtractors { } } - private def getReturnedExpr(expr: LeonExpr): Seq[LeonExpr] = expr match { - case Let(_, _, rest) => getReturnedExpr(rest) - case LetVar(_, _, rest) => getReturnedExpr(rest) - case LeonBlock(_, rest) => getReturnedExpr(rest) - case IfExpr(_, thenn, elze) => getReturnedExpr(thenn) ++ getReturnedExpr(elze) - case MatchExpr(_, cses) => cses.flatMap{ cse => getReturnedExpr(cse.rhs) } - case _ => Seq(expr) - } - - def getOwner(exprs: Seq[LeonExpr]): Option[Option[FunDef]] = { - val exprOwners: Seq[Option[Option[FunDef]]] = exprs.map { - case Variable(id) => - owners.get(id) - case _ => - None - } - - if(exprOwners.contains(None)) - None - else if(exprOwners.contains(Some(None))) - Some(None) - else if(exprOwners.exists(o1 => exprOwners.exists(o2 => o1 != o2))) - Some(None) - else - exprOwners.head - } - - def getOwner(expr: LeonExpr): Option[Option[FunDef]] = getOwner(getReturnedExpr(expr)) } def containsLetDef(expr: LeonExpr): Boolean = { diff --git a/src/main/scala/leon/grammars/BaseGrammar.scala b/src/main/scala/leon/grammars/BaseGrammar.scala index f11f937498051eb47c2c522a0faa2a1499545175..6e0a2ee5e6842255aac5755c9ee27005e5360eb5 100644 --- a/src/main/scala/leon/grammars/BaseGrammar.scala +++ b/src/main/scala/leon/grammars/BaseGrammar.scala @@ -7,56 +7,65 @@ import purescala.Types._ import purescala.Expressions._ import purescala.Constructors._ +/** The basic grammar for Leon expressions. + * Generates the most obvious expressions for a given type, + * without regard of context (variables in scope, current function etc.) + * Also does some trivial simplifications. + */ case object BaseGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)), - nonTerminal(List(BooleanType), { case Seq(a) => not(a) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), + terminal(BooleanLiteral(false), Tags.BooleanC), + terminal(BooleanLiteral(true), Tags.BooleanC), + nonTerminal(List(BooleanType), { case Seq(a) => not(a) }, Tags.Not), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }, Tags.And), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }, Tags.Or ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessEquals(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessEquals(a, b) }) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }, Tags.Times) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }, Tags.Times), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Modulo(a, b) }, Tags.Mod), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Division(a, b) }, Tags.Div) ) case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(isTerminal = false)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), nonTerminal(List(st, st), { case Seq(a, b) => SetUnion(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetDifference(a, b) }) @@ -64,7 +73,7 @@ case object BaseGrammar extends ExpressionGrammar[TypeTree] { case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/Constants.scala b/src/main/scala/leon/grammars/Constants.scala new file mode 100644 index 0000000000000000000000000000000000000000..81c55346052668e0d82b05ee240867eb1e5c468c --- /dev/null +++ b/src/main/scala/leon/grammars/Constants.scala @@ -0,0 +1,33 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Expressions._ +import purescala.Types.TypeTree +import purescala.ExprOps.collect +import purescala.Extractors.IsTyped + +/** Generates constants found in an [[leon.purescala.Expressions.Expr]]. + * Some constants that are generated by other grammars (like 0, 1) will be excluded + */ +case class Constants(e: Expr) extends ExpressionGrammar[TypeTree] { + + private val excluded: Set[Expr] = Set( + InfiniteIntegerLiteral(1), + InfiniteIntegerLiteral(0), + IntLiteral(1), + IntLiteral(0), + BooleanLiteral(true), + BooleanLiteral(false) + ) + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { + val literals = collect[Expr]{ + case IsTyped(l:Literal[_], `t`) => Set(l) + case _ => Set() + }(e) + + (literals -- excluded map (terminal(_, Tags.Constant))).toSeq + } +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/DepthBoundedGrammar.scala deleted file mode 100644 index fc999be644bf2c4a7a20a73403cf7b1001bb9b68..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -case class DepthBoundedGrammar[T](g: ExpressionGrammar[NonTerminal[T]], bound: Int) extends ExpressionGrammar[NonTerminal[T]] { - def computeProductions(l: NonTerminal[T])(implicit ctx: LeonContext): Seq[Gen] = g.computeProductions(l).flatMap { - case gen => - if (l.depth == Some(bound) && gen.subTrees.nonEmpty) { - Nil - } else if (l.depth.exists(_ > bound)) { - Nil - } else { - List ( - nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) - ) - } - } -} diff --git a/src/main/scala/leon/grammars/Empty.scala b/src/main/scala/leon/grammars/Empty.scala index 70ebddc98f21fc872aef8635fe36de7e9ba9bbce..737f9cdf389454f403a6581e13eec7fafa383f34 100644 --- a/src/main/scala/leon/grammars/Empty.scala +++ b/src/main/scala/leon/grammars/Empty.scala @@ -5,6 +5,7 @@ package grammars import purescala.Types.Typed +/** The empty expression grammar */ case class Empty[T <: Typed]() extends ExpressionGrammar[T] { - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = Nil + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = Nil } diff --git a/src/main/scala/leon/grammars/EqualityGrammar.scala b/src/main/scala/leon/grammars/EqualityGrammar.scala index e9463a771204d877a4d748c373b6d198e2c2591b..a2f9c41360ada03334ace63eca3ca46f9f6d5ff7 100644 --- a/src/main/scala/leon/grammars/EqualityGrammar.scala +++ b/src/main/scala/leon/grammars/EqualityGrammar.scala @@ -6,13 +6,15 @@ package grammars import purescala.Types._ import purescala.Constructors._ -import bonsai._ - +/** A grammar of equalities + * + * @param types The set of types for which equalities will be generated + */ case class EqualityGrammar(types: Set[TypeTree]) extends ExpressionGrammar[TypeTree] { - override def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => types.toList map { tp => - nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }) + nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }, Tags.Equals) } case _ => Nil diff --git a/src/main/scala/leon/grammars/ExpressionGrammar.scala b/src/main/scala/leon/grammars/ExpressionGrammar.scala index ac394ab840bddf0d498080a04e447ce66de07caa..3179312b7f65444eb3e8c39357fd449e13339c8f 100644 --- a/src/main/scala/leon/grammars/ExpressionGrammar.scala +++ b/src/main/scala/leon/grammars/ExpressionGrammar.scala @@ -6,23 +6,37 @@ package grammars import purescala.Expressions._ import purescala.Types._ import purescala.Common._ +import transformers.Union +import utils.Timer import scala.collection.mutable.{HashMap => MutableMap} +/** Represents a context-free grammar of expressions + * + * @tparam T The type of nonterminal symbols for this grammar + */ abstract class ExpressionGrammar[T <: Typed] { - type Gen = Generator[T, Expr] - private[this] val cache = new MutableMap[T, Seq[Gen]]() + type Prod = ProductionRule[T, Expr] - def terminal(builder: => Expr) = { - Generator[T, Expr](Nil, { (subs: Seq[Expr]) => builder }) + private[this] val cache = new MutableMap[T, Seq[Prod]]() + + /** Generates a [[ProductionRule]] without nonterminal symbols */ + def terminal(builder: => Expr, tag: Tags.Tag = Tags.Top, cost: Int = 1) = { + ProductionRule[T, Expr](Nil, { (subs: Seq[Expr]) => builder }, tag, cost) } - def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr)): Generator[T, Expr] = { - Generator[T, Expr](subs, builder) + /** Generates a [[ProductionRule]] with nonterminal symbols */ + def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr), tag: Tags.Tag = Tags.Top, cost: Int = 1): ProductionRule[T, Expr] = { + ProductionRule[T, Expr](subs, builder, tag, cost) } - def getProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = { + /** The list of production rules for this grammar for a given nonterminal. + * This is the cached version of [[getProductions]] which clients should use. + * + * @param t The nonterminal for which production rules will be generated + */ + def getProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = { cache.getOrElse(t, { val res = computeProductions(t) cache += t -> res @@ -30,9 +44,13 @@ abstract class ExpressionGrammar[T <: Typed] { }) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] + /** The list of production rules for this grammar for a given nonterminal + * + * @param t The nonterminal for which production rules will be generated + */ + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] - def filter(f: Gen => Boolean) = { + def filter(f: Prod => Boolean) = { new ExpressionGrammar[T] { def computeProductions(t: T)(implicit ctx: LeonContext) = ExpressionGrammar.this.computeProductions(t).filter(f) } @@ -44,14 +62,19 @@ abstract class ExpressionGrammar[T <: Typed] { final def printProductions(printer: String => Unit)(implicit ctx: LeonContext) { - for ((t, gs) <- cache; g <- gs) { - val subs = g.subTrees.map { t => - FreshIdentifier(Console.BOLD+t.asString+Console.RESET, t.getType).toVariable - } + for ((t, gs) <- cache) { + val lhs = f"${Console.BOLD}${t.asString}%50s${Console.RESET} ::=" + if (gs.isEmpty) { + printer(s"$lhs ε") + } else for (g <- gs) { + val subs = g.subTrees.map { t => + FreshIdentifier(Console.BOLD + t.asString + Console.RESET, t.getType).toVariable + } - val gen = g.builder(subs).asString + val gen = g.builder(subs).asString - printer(f"${Console.BOLD}${t.asString}%30s${Console.RESET} ::= $gen") + printer(s"$lhs $gen") + } } } } diff --git a/src/main/scala/leon/grammars/FunctionCalls.scala b/src/main/scala/leon/grammars/FunctionCalls.scala index 14f92393934c18804bdb130e9c1617b915a347bd..1233fb1931a83b5ca674019be0c85144339dd19f 100644 --- a/src/main/scala/leon/grammars/FunctionCalls.scala +++ b/src/main/scala/leon/grammars/FunctionCalls.scala @@ -10,8 +10,14 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Expressions._ +/** Generates non-recursive function calls + * + * @param currentFunction The currend function for which no calls will be generated + * @param types The candidate real type parameters for [[currentFunction]] + * @param exclude An additional set of functions for which no calls will be generated + */ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree], exclude: Set[FunDef]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { def getCandidates(fd: FunDef): Seq[TypedFunDef] = { // Prevents recursive calls @@ -73,7 +79,7 @@ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[Type val funcs = visibleFunDefsFromMain(prog).toSeq.sortBy(_.id).flatMap(getCandidates).filterNot(filter) funcs.map{ tfd => - nonTerminal(tfd.params.map(_.getType), { sub => FunctionInvocation(tfd, sub) }) + nonTerminal(tfd.params.map(_.getType), FunctionInvocation(tfd, _), Tags.tagOf(tfd.fd, isSafe = false)) } } } diff --git a/src/main/scala/leon/grammars/Generator.scala b/src/main/scala/leon/grammars/Generator.scala deleted file mode 100644 index 18d132e2c25ea222324dc05809220f12d0fb7100..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/Generator.scala +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import bonsai.{Generator => Gen} - -object GrammarTag extends Enumeration { - val Top = Value -} -import GrammarTag._ - -class Generator[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value) extends Gen[T,R](subTrees, builder) -object Generator { - def apply[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value = Top) = new Generator(subTrees, builder, tag) -} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/Grammars.scala b/src/main/scala/leon/grammars/Grammars.scala index 23b1dd5a14cfeb82dd4555832e777597615b337e..06aba3d5f5343cc7e2807854f0b4665bfa1a602c 100644 --- a/src/main/scala/leon/grammars/Grammars.scala +++ b/src/main/scala/leon/grammars/Grammars.scala @@ -7,6 +7,7 @@ import purescala.Expressions._ import purescala.Definitions._ import purescala.Types._ import purescala.TypeOps._ +import transformers.OneOf import synthesis.{SynthesisContext, Problem} @@ -16,6 +17,7 @@ object Grammars { BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ inputs.map { _.getType }) || OneOf(inputs) || + Constants(currentFunction.fullBody) || FunctionCalls(prog, currentFunction, inputs.map(_.getType), exclude) || SafeRecursiveCalls(prog, ws, pc) } @@ -28,3 +30,4 @@ object Grammars { g.filter(g => g.subTrees.forall(t => typeDepth(t.getType) <= b)) } } + diff --git a/src/main/scala/leon/grammars/NonTerminal.scala b/src/main/scala/leon/grammars/NonTerminal.scala index 7492ffac5c17df326084f846857c6ac3bebe1775..600189ffa06378841f6bf3285f7f1bd7bb6116f5 100644 --- a/src/main/scala/leon/grammars/NonTerminal.scala +++ b/src/main/scala/leon/grammars/NonTerminal.scala @@ -5,7 +5,14 @@ package grammars import purescala.Types._ -case class NonTerminal[T](t: TypeTree, l: T, depth: Option[Int] = None) extends Typed { +/** A basic non-terminal symbol of a grammar. + * + * @param t The type of which expressions will be generated + * @param l A label that characterizes this [[NonTerminal]] + * @param depth The optional depth within the syntax tree where this [[NonTerminal]] is. + * @tparam L The type of label for this NonTerminal. + */ +case class NonTerminal[L](t: TypeTree, l: L, depth: Option[Int] = None) extends Typed { val getType = t override def asString(implicit ctx: LeonContext) = t.asString+"#"+l+depth.map(d => "@"+d).getOrElse("") diff --git a/src/main/scala/leon/grammars/ProductionRule.scala b/src/main/scala/leon/grammars/ProductionRule.scala new file mode 100644 index 0000000000000000000000000000000000000000..fc493a7d9d17557a26a8e54ff4615d39ba922190 --- /dev/null +++ b/src/main/scala/leon/grammars/ProductionRule.scala @@ -0,0 +1,18 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import bonsai.Generator + +/** Represents a production rule of a non-terminal symbol of an [[ExpressionGrammar]]. + * + * @param subTrees The nonterminals that are used in the right-hand side of this [[ProductionRule]] + * (and will generate deeper syntax trees). + * @param builder A function that builds the syntax tree that this [[ProductionRule]] represents from nested trees. + * @param tag Gives information about the nature of this production rule. + * @tparam T The type of nonterminal symbols of the grammar + * @tparam R The type of syntax trees of the grammar + */ +case class ProductionRule[T, R](override val subTrees: Seq[T], override val builder: Seq[R] => R, tag: Tags.Tag, cost: Int = 1) + extends Generator[T,R](subTrees, builder) diff --git a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala index 1bbcb0523158ac95713f5a0d4a16f0f35e14edf4..f3234176a8c17378a7a5f027f38cbd42069ae7d6 100644 --- a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala +++ b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala @@ -9,15 +9,25 @@ import purescala.ExprOps._ import purescala.Expressions._ import synthesis.utils.Helpers._ +/** Generates recursive calls that will not trivially result in non-termination. + * + * @param ws An expression that contains the known set [[synthesis.Witnesses.Terminating]] expressions + * @param pc The path condition for the generated [[Expr]] by this grammar + */ case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Expr) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { val calls = terminatingCalls(prog, t, ws, pc) calls.map { - case (e, free) => + case (fi, free) => val freeSeq = free.toSeq - nonTerminal(freeSeq.map(_.getType), { sub => replaceFromIDs(freeSeq.zip(sub).toMap, e) }) + nonTerminal( + freeSeq.map(_.getType), + { sub => replaceFromIDs(freeSeq.zip(sub).toMap, fi) }, + Tags.tagOf(fi.tfd.fd, isSafe = true), + 2 + ) } } } diff --git a/src/main/scala/leon/grammars/SimilarTo.scala b/src/main/scala/leon/grammars/SimilarTo.scala index 77e912792965d860fc934eb016370c8f2b57fd8f..3a7708e9a77960ffbfde98d478d2ca7c73c713d0 100644 --- a/src/main/scala/leon/grammars/SimilarTo.scala +++ b/src/main/scala/leon/grammars/SimilarTo.scala @@ -3,21 +3,24 @@ package leon package grammars +import transformers._ import purescala.Types._ import purescala.TypeOps._ import purescala.Extractors._ import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.DefOps._ import purescala.Expressions._ import synthesis._ +/** A grammar that generates expressions by inserting small variations in [[e]] + * @param e The [[Expr]] to which small variations will be inserted + * @param terminals A set of [[Expr]]s that may be inserted into [[e]] as small variations + */ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisContext, p: Problem) extends ExpressionGrammar[NonTerminal[String]] { val excludeFCalls = sctx.settings.functionsToIgnore - val normalGrammar = DepthBoundedGrammar(EmbeddedGrammar( + val normalGrammar: ExpressionGrammar[NonTerminal[String]] = DepthBoundedGrammar(EmbeddedGrammar( BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ terminals.map { _.getType }) || OneOf(terminals.toSeq :+ e) || @@ -37,9 +40,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - private[this] var similarCache: Option[Map[L, Seq[Gen]]] = None + private[this] var similarCache: Option[Map[L, Seq[Prod]]] = None - def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Prod] = { t match { case NonTerminal(_, "B", _) => normalGrammar.computeProductions(t) case _ => @@ -54,7 +57,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Gen)] = { + def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Prod)] = { def getLabel(t: TypeTree) = { val tpe = bestRealType(t) @@ -67,9 +70,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte case _ => false } - def rec(e: Expr, gl: L): Seq[(L, Gen)] = { + def rec(e: Expr, gl: L): Seq[(L, Prod)] = { - def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Gen)] = { + def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Prod)] = { val subGls = subs.map { s => getLabel(s.getType) } // All the subproductions for sub gl @@ -81,8 +84,8 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val swaps = if (subs.size > 1 && !isCommutative(e)) { - (for (i <- 0 until subs.size; - j <- i+1 until subs.size) yield { + (for (i <- subs.indices; + j <- i+1 until subs.size) yield { if (subs(i).getType == subs(j).getType) { val swapSubs = subs.updated(i, subs(j)).updated(j, subs(i)) @@ -98,18 +101,18 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte allSubs ++ injectG ++ swaps } - def cegis(gl: L): Seq[(L, Gen)] = { + def cegis(gl: L): Seq[(L, Prod)] = { normalGrammar.getProductions(gl).map(gl -> _) } - def int32Variations(gl: L, e : Expr): Seq[(L, Gen)] = { + def int32Variations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(BVMinus(e, IntLiteral(1))), gl -> terminal(BVPlus (e, IntLiteral(1))) ) } - def intVariations(gl: L, e : Expr): Seq[(L, Gen)] = { + def intVariations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(Minus(e, InfiniteIntegerLiteral(1))), gl -> terminal(Plus (e, InfiniteIntegerLiteral(1))) @@ -118,7 +121,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte // Find neighbor case classes that are compatible with the arguments: // Turns And(e1, e2) into Or(e1, e2)... - def ccVariations(gl: L, cc: CaseClass): Seq[(L, Gen)] = { + def ccVariations(gl: L, cc: CaseClass): Seq[(L, Prod)] = { val CaseClass(cct, args) = cc val neighbors = cct.root.knownCCDescendants diff Seq(cct) @@ -129,7 +132,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val funFilter = (fd: FunDef) => fd.isSynthetic || (excludeFCalls contains fd) - val subs: Seq[(L, Gen)] = e match { + val subs: Seq[(L, Prod)] = e match { case _: Terminal | _: Let | _: LetDef | _: MatchExpr => gens(e, gl, Nil, { _ => e }) ++ cegis(gl) diff --git a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/SizeBoundedGrammar.scala deleted file mode 100644 index 1b25e30f61aa74598feb255366fe10a153bc9e30..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import purescala.Types._ -import leon.utils.SeqUtils.sumTo - -case class SizedLabel[T <: Typed](underlying: T, size: Int) extends Typed { - val getType = underlying.getType - - override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" -} - -case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[SizedLabel[T]] { - def computeProductions(sl: SizedLabel[T])(implicit ctx: LeonContext): Seq[Gen] = { - if (sl.size <= 0) { - Nil - } else if (sl.size == 1) { - g.getProductions(sl.underlying).filter(_.subTrees.isEmpty).map { gen => - terminal(gen.builder(Seq())) - } - } else { - g.getProductions(sl.underlying).filter(_.subTrees.nonEmpty).flatMap { gen => - val sizes = sumTo(sl.size-1, gen.subTrees.size) - - for (ss <- sizes) yield { - val subSizedLabels = (gen.subTrees zip ss) map (s => SizedLabel(s._1, s._2)) - - nonTerminal(subSizedLabels, gen.builder) - } - } - } - } -} diff --git a/src/main/scala/leon/grammars/Tags.scala b/src/main/scala/leon/grammars/Tags.scala new file mode 100644 index 0000000000000000000000000000000000000000..4a6b6fca491b8db6f74622edd9298ec5cd6053b0 --- /dev/null +++ b/src/main/scala/leon/grammars/Tags.scala @@ -0,0 +1,65 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Types.CaseClassType +import purescala.Definitions.FunDef + +object Tags { + /** A class for tags that tag a [[ProductionRule]] with the kind of expression in generates. */ + abstract class Tag + case object Top extends Tag // Tag for the top-level of the grammar (default) + case object Zero extends Tag // Tag for 0 + case object One extends Tag // Tag for 1 + case object BooleanC extends Tag // Tag for boolean constants + case object Constant extends Tag // Tag for other constants + case object And extends Tag // Tags for boolean operations + case object Or extends Tag + case object Not extends Tag + case object Plus extends Tag // Tags for arithmetic operations + case object Minus extends Tag + case object Times extends Tag + case object Mod extends Tag + case object Div extends Tag + case object Variable extends Tag // Tag for variables + case object Equals extends Tag // Tag for equality + /** Constructors like Tuple, CaseClass... + * + * @param isTerminal If true, this constructor represents a terminal symbol + * (in practice, case class with 0 fields) + */ + case class Constructor(isTerminal: Boolean) extends Tag + /** Tag for function calls + * + * @param isMethod Whether the function called is a method + * @param isSafe Whether this constructor represents a safe function call. + * We need this because this call implicitly contains a variable, + * so we want to allow constants in all arguments. + */ + case class FunCall(isMethod: Boolean, isSafe: Boolean) extends Tag + + /** The set of tags that represent constants */ + val isConst: Set[Tag] = Set(Zero, One, Constant, BooleanC, Constructor(true)) + + /** The set of tags that represent commutative operations */ + val isCommut: Set[Tag] = Set(Plus, Times, Equals) + + /** The set of tags which have trivial results for equal arguments */ + val symmetricTrivial = Set(Minus, And, Or, Equals, Div, Mod) + + /** Tags which allow constants in all their operands + * + * In reality, the current version never allows that: it is only allowed in safe function calls + * which by construction contain a hidden reference to a variable. + * TODO: Experiment with different conditions, e.g. are constants allowed in + * top-level/ general function calls/ constructors/...? + */ + def allConstArgsAllowed(t: Tag) = t match { + case FunCall(_, true) => true + case _ => false + } + + def tagOf(cct: CaseClassType) = Constructor(cct.fields.isEmpty) + def tagOf(fd: FunDef, isSafe: Boolean) = FunCall(fd.methodOwner.isDefined, isSafe) +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/ValueGrammar.scala b/src/main/scala/leon/grammars/ValueGrammar.scala index 98850c8df4adcf3e776970c176cf37c251823917..d3c42201728f4b03d9518b3db503cff9189dcc8b 100644 --- a/src/main/scala/leon/grammars/ValueGrammar.scala +++ b/src/main/scala/leon/grammars/ValueGrammar.scala @@ -6,62 +6,64 @@ package grammars import purescala.Types._ import purescala.Expressions._ +/** A grammar of values (ground terms) */ case object ValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)) + terminal(BooleanLiteral(true), Tags.One), + terminal(BooleanLiteral(false), Tags.Zero) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - terminal(IntLiteral(5)) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One), + terminal(IntLiteral(5), Tags.Constant) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - terminal(InfiniteIntegerLiteral(5)) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One), + terminal(InfiniteIntegerLiteral(5), Tags.Constant) ) case StringType => List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("foo")), - terminal(StringLiteral("bar")) + terminal(StringLiteral(""), Tags.Constant), + terminal(StringLiteral("a"), Tags.Constant), + terminal(StringLiteral("foo"), Tags.Constant), + terminal(StringLiteral("bar"), Tags.Constant) ) case tp: TypeParameter => - for (ind <- (1 to 3).toList) yield { - terminal(GenericValue(tp, ind)) - } + List( + terminal(GenericValue(tp, 0)) + ) case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(stps.isEmpty)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), - nonTerminal(List(base, base), { case elems => FiniteSet(elems.toSet, base) }) + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), + nonTerminal(List(base, base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)) ) case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..02e045497a28db205d4a33a300ac0b742510920a --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala @@ -0,0 +1,21 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +/** Limits a grammar to a specific expression depth */ +case class DepthBoundedGrammar[L](g: ExpressionGrammar[NonTerminal[L]], bound: Int) extends ExpressionGrammar[NonTerminal[L]] { + def computeProductions(l: NonTerminal[L])(implicit ctx: LeonContext): Seq[Prod] = g.computeProductions(l).flatMap { + case gen => + if (l.depth == Some(bound) && gen.isNonTerminal) { + Nil + } else if (l.depth.exists(_ > bound)) { + Nil + } else { + List ( + nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) + ) + } + } +} diff --git a/src/main/scala/leon/grammars/EmbeddedGrammar.scala b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala similarity index 74% rename from src/main/scala/leon/grammars/EmbeddedGrammar.scala rename to src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala index 8dcbc6ec10f9aa42895e5f876cdd4d72479de229..d989a8804b32f62697b7f31e498e61393a12c35b 100644 --- a/src/main/scala/leon/grammars/EmbeddedGrammar.scala +++ b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala @@ -2,10 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.Constructors._ +import leon.purescala.Types.Typed /** * Embed a grammar Li->Expr within a grammar Lo->Expr @@ -13,9 +12,9 @@ import purescala.Constructors._ * We rely on a bijection between Li and Lo labels */ case class EmbeddedGrammar[Ti <: Typed, To <: Typed](innerGrammar: ExpressionGrammar[Ti], iToo: Ti => To, oToi: To => Ti) extends ExpressionGrammar[To] { - def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Prod] = { innerGrammar.computeProductions(oToi(t)).map { innerGen => - nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder) + nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder, innerGen.tag) } } } diff --git a/src/main/scala/leon/grammars/OneOf.scala b/src/main/scala/leon/grammars/transformers/OneOf.scala similarity index 56% rename from src/main/scala/leon/grammars/OneOf.scala rename to src/main/scala/leon/grammars/transformers/OneOf.scala index 0e10c096151c1fdf83d3c7e7f10c4a4a6518215b..5c57c6a1a48179e2d813398aa651022df6cae35a 100644 --- a/src/main/scala/leon/grammars/OneOf.scala +++ b/src/main/scala/leon/grammars/transformers/OneOf.scala @@ -2,14 +2,15 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.TypeOps._ -import purescala.Constructors._ +import purescala.Expressions.Expr +import purescala.Types.TypeTree +import purescala.TypeOps.isSubtypeOf +/** Generates one production rule for each expression in a sequence that has compatible type */ case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { inputs.collect { case i if isSubtypeOf(i.getType, t) => terminal(i) diff --git a/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..1b605359fdf18d02f08d105a9cccc58757b99262 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala @@ -0,0 +1,59 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import purescala.Types.Typed +import utils.SeqUtils._ + +/** Adds information about size to a nonterminal symbol */ +case class SizedNonTerm[T <: Typed](underlying: T, size: Int) extends Typed { + val getType = underlying.getType + + override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" +} + +/** Limits a grammar by producing expressions of size bounded by the [[SizedNonTerm.size]] of a given [[SizedNonTerm]]. + * + * In case of commutative operations, the grammar will produce trees skewed to the right + * (i.e. the right subtree will always be larger). Notice we do not lose generality in case of + * commutative operations. + */ +case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T], optimizeCommut: Boolean) extends ExpressionGrammar[SizedNonTerm[T]] { + def computeProductions(sl: SizedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + if (sl.size <= 0) { + Nil + } else if (sl.size == 1) { + g.getProductions(sl.underlying).filter(_.isTerminal).map { gen => + terminal(gen.builder(Seq()), gen.tag) + } + } else { + g.getProductions(sl.underlying).filter(_.isNonTerminal).flatMap { gen => + + // Ad-hoc equality that does not take into account position etc.of TaggedNonTerminal's + // TODO: Ugly and hacky + def characteristic(t: T): Typed = t match { + case TaggedNonTerm(underlying, _, _, _) => + underlying + case other => + other + } + + // Optimization: When we have a commutative operation and all the labels are the same, + // we can skew the expression to always be right-heavy + val sizes = if(optimizeCommut && Tags.isCommut(gen.tag) && gen.subTrees.map(characteristic).toSet.size == 1) { + sumToOrdered(sl.size-gen.cost, gen.arity) + } else { + sumTo(sl.size-gen.cost, gen.arity) + } + + for (ss <- sizes) yield { + val subSizedLabels = (gen.subTrees zip ss) map (s => SizedNonTerm(s._1, s._2)) + + nonTerminal(subSizedLabels, gen.builder, gen.tag) + } + } + } + } +} diff --git a/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..43ce13e850ed1b52460ef1a74d7b039adacbd519 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala @@ -0,0 +1,111 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import leon.purescala.Types.Typed +import Tags._ + +/** Adds to a nonterminal information about about the tag of its parent's [[leon.grammars.ProductionRule.tag]] + * and additional information. + * + * @param underlying The underlying nonterminal + * @param tag The tag of the parent of this nonterminal + * @param pos The index of this nonterminal in its father's production rule + * @param isConst Whether this nonterminal is obliged to generate/not generate constants. + * + */ +case class TaggedNonTerm[T <: Typed](underlying: T, tag: Tag, pos: Int, isConst: Option[Boolean]) extends Typed { + val getType = underlying.getType + + private val cString = isConst match { + case Some(true) => "↓" + case Some(false) => "↑" + case None => "○" + } + + /** [[isConst]] is printed as follows: ↓ for constants only, ↑ for nonconstants only, + * ○ for anything allowed. + */ + override def asString(implicit ctx: LeonContext): String = s"$underlying%$tag@$pos$cString" +} + +/** Constraints a grammar to reduce redundancy by utilizing information provided by the [[TaggedNonTerm]]. + * + * 1) In case of associative operations, right associativity is enforced. + * 2) Does not generate + * - neutral and absorbing elements (incl. boolean equality) + * - nested negations + * 3) Excludes method calls on nullary case objects, e.g. Nil().size + * 4) Enforces that no constant trees are generated (and recursively for each subtree) + * + * @param g The underlying untagged grammar + */ +case class TaggedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[TaggedNonTerm[T]] { + + private def exclude(tag: Tag, pos: Int): Set[Tag] = (tag, pos) match { + case (Top, _) => Set() + case (And, 0) => Set(And, BooleanC) + case (And, 1) => Set(BooleanC) + case (Or, 0) => Set(Or, BooleanC) + case (Or, 1) => Set(BooleanC) + case (Plus, 0) => Set(Plus, Zero, One) + case (Plus, 1) => Set(Zero) + case (Minus, 1) => Set(Zero) + case (Not, _) => Set(Not, BooleanC) + case (Times, 0) => Set(Times, Zero, One) + case (Times, 1) => Set(Zero, One) + case (Equals,_) => Set(Not, BooleanC) + case (Div | Mod, 0 | 1) => Set(Zero, One) + case (FunCall(true, _), 0) => Set(Constructor(true)) // Don't allow Nil().size etc. + case _ => Set() + } + + def computeProductions(t: TaggedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + + // Point (4) for this level + val constFilter: g.Prod => Boolean = t.isConst match { + case Some(b) => + innerGen => isConst(innerGen.tag) == b + case None => + _ => true + } + + g.computeProductions(t.underlying) + // Include only constants iff constants are forced, only non-constants iff they are forced + .filter(constFilter) + // Points (1), (2). (3) + .filterNot { innerGen => exclude(t.tag, t.pos)(innerGen.tag) } + .flatMap { innerGen => + + def nt(isConst: Int => Option[Boolean]) = nonTerminal( + innerGen.subTrees.zipWithIndex.map { + case (t, pos) => TaggedNonTerm(t, innerGen.tag, pos, isConst(pos)) + }, + innerGen.builder, + innerGen.tag + ) + + def powerSet[A](t: Set[A]): Set[Set[A]] = { + @scala.annotation.tailrec + def pwr(t: Set[A], ps: Set[Set[A]]): Set[Set[A]] = + if (t.isEmpty) ps + else pwr(t.tail, ps ++ (ps map (_ + t.head))) + + pwr(t, Set(Set.empty[A])) + } + + // Allow constants everywhere if this is allowed, otherwise demand at least 1 variable. + // Aka. tag subTrees correctly so point (4) is enforced in the lower level + // (also, make sure we treat terminals correctly). + if (innerGen.isTerminal || allConstArgsAllowed(innerGen.tag)) { + Seq(nt(_ => None)) + } else { + val indices = innerGen.subTrees.indices.toSet + (powerSet(indices) - indices) map (indices => nt(x => Some(indices(x)))) + } + } + } + +} diff --git a/src/main/scala/leon/grammars/Or.scala b/src/main/scala/leon/grammars/transformers/Union.scala similarity index 73% rename from src/main/scala/leon/grammars/Or.scala rename to src/main/scala/leon/grammars/transformers/Union.scala index e691a245984eaeb11277b9278505b49cf623fed3..471625ac3c22c22456f49f366ed26e5195b5f4ab 100644 --- a/src/main/scala/leon/grammars/Or.scala +++ b/src/main/scala/leon/grammars/transformers/Union.scala @@ -2,8 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ +import purescala.Types.Typed case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGrammar[T] { val subGrammars: Seq[ExpressionGrammar[T]] = gs.flatMap { @@ -11,6 +12,6 @@ case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGr case g => Seq(g) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = subGrammars.flatMap(_.getProductions(t)) } diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 119543797434a9bb90eee8bb2b507e12de3acb3e..4efc97986400f312afeb19674615e9822c56c8e3 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -275,6 +275,15 @@ object DefOps { None } + /** Clones the given program by replacing some functions by other functions. + * + * @param p The original program + * @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept. + * May be called once each time a function appears (definition and invocation), + * so make sure to output the same if the argument is the same. + * @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use. + * By default it is the function invocation using the new function definition. + * @return the new program with a map from the old functions to the new functions */ def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) : (Program, Map[FunDef, FunDef])= { @@ -297,7 +306,6 @@ object DefOps { df match { case f : FunDef => val newF = fdMap(f) - newF.fullBody = replaceFunCalls(newF.fullBody, fdMap, fiMapF) newF case d => d @@ -307,7 +315,11 @@ object DefOps { } ) }) - + for(fd <- newP.definedFunctions) { + if(ExprOps.exists{ case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache.getOrElse(fd, None) != None case _ => false }(fd.fullBody)) { + fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) + } + } (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) } @@ -320,32 +332,38 @@ object DefOps { }(e) } - def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = { + def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { var found = false val res = p.copy(units = for (u <- p.units) yield { u.copy( - defs = u.defs.map { + defs = u.defs.flatMap { case m: ModuleDef => val newdefs = for (df <- m.defs) yield { df match { case `after` => found = true - after +: fds.toSeq - case d => - Seq(d) + after +: cds.toSeq + case d => Seq(d) } } - m.copy(defs = newdefs.flatten) - case d => d + Seq(m.copy(defs = newdefs.flatten)) + case `after` => + found = true + after +: cds.toSeq + case d => Seq(d) } ) }) if (!found) { - println("addFunDefs could not find anchor function!") + println("addDefs could not find anchor definition!") } res } + + def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = addDefs(p, fds, after) + + def addClassDefs(p: Program, fds: Traversable[ClassDef], after: ClassDef): Program = addDefs(p, fds, after) // @Note: This function does not filter functions in classdefs def filterFunDefs(p: Program, fdF: FunDef => Boolean): Program = { diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 5ce8a11ec204c4fde241797dd2706d1a5be20dc0..19d20604ddc9ff9c3187c2fe24ee50f46e658850 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -19,285 +19,19 @@ import solvers._ * * The generic operations lets you apply operations on a whole tree * expression. You can look at: - * - [[ExprOps.fold foldRight]] - * - [[ExprOps.preTraversal preTraversal]] - * - [[ExprOps.postTraversal postTraversal]] - * - [[ExprOps.preMap preMap]] - * - [[ExprOps.postMap postMap]] - * - [[ExprOps.genericTransform genericTransform]] + * - [[SubTreeOps.fold foldRight]] + * - [[SubTreeOps.preTraversal preTraversal]] + * - [[SubTreeOps.postTraversal postTraversal]] + * - [[SubTreeOps.preMap preMap]] + * - [[SubTreeOps.postMap postMap]] + * - [[SubTreeOps.genericTransform genericTransform]] * * These operations usually take a higher order function that gets applied to the * expression tree in some strategy. They provide an expressive way to build complex * operations on Leon expressions. * */ -object ExprOps { - - /* ======== - * Core API - * ======== - * - * All these functions should be stable, tested, and used everywhere. Modify - * with care. - */ - - - /** Does a right tree fold - * - * A right tree fold applies the input function to the subnodes first (from left - * to right), and combine the results along with the current node value. - * - * @param f a function that takes the current node and the seq - * of results form the subtrees. - * @param e The Expr on which to apply the fold. - * @return The expression after applying `f` on all subtrees. - * @note the computation is lazy, hence you should not rely on side-effects of `f` - */ - def fold[T](f: (Expr, Seq[T]) => T)(e: Expr): T = { - val rec = fold(f) _ - val Operator(es, _) = e - - //Usages of views makes the computation lazy. (which is useful for - //contains-like operations) - f(e, es.view.map(rec)) - } - - /** Pre-traversal of the tree. - * - * Invokes the input function on every node '''before''' visiting - * children. Traverse children from left to right subtrees. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) - * }}} - * will yield, in order: - * {{{ - * f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c) - * }}} - * - * @param f a function to apply on each node of the expression - * @param e the expression to traverse - */ - def preTraversal(f: Expr => Unit)(e: Expr): Unit = { - val rec = preTraversal(f) _ - val Operator(es, _) = e - f(e) - es.foreach(rec) - } - - /** Post-traversal of the tree. - * - * Invokes the input function on every node '''after''' visiting - * children. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) - * }}} - * will yield, in order: - * {{{ - * f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c))) - * }}} - * - * @param f a function to apply on each node of the expression - * @param e the expression to traverse - */ - def postTraversal(f: Expr => Unit)(e: Expr): Unit = { - val rec = postTraversal(f) _ - val Operator(es, _) = e - es.foreach(rec) - f(e) - } - - /** Pre-transformation of the tree. - * - * Takes a partial function of replacements and substitute - * '''before''' recursing down the trees. - * - * Supports two modes : - * - * - If applyRec is false (default), will only substitute once on each level. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, d) // And not Add(a, f) because it only substitute once for each level. - * }}} - * - * - If applyRec is true, it will substitute multiple times on each level: - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, f) - * }}} - * - * @note The mode with applyRec true can diverge if f is not well formed - */ - def preMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { - val rec = preMap(f, applyRec) _ - - val newV = if (applyRec) { - // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (e) - } else { - f(e) getOrElse e - } - - val Operator(es, builder) = newV - val newEs = es.map(rec) - - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(newV) - } else { - newV - } - } - - /** Post-transformation of the tree. - * - * Takes a partial function of replacements. - * Substitutes '''after''' recursing down the trees. - * - * Supports two modes : - * - * - If applyRec is false (default), will only substitute once on each level. - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e - * }}} - * will yield: - * {{{ - * Add(a, Minus(e, c)) - * }}} - * - * - If applyRec is true, it will substitute multiple times on each level: - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, f) - * }}} - * - * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) - */ - def postMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { - val rec = postMap(f, applyRec) _ - - val Operator(es, builder) = e - val newEs = es.map(rec) - val newV = { - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(e) - } else { - e - } - } - - if (applyRec) { - // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (newV) - } else { - f(newV) getOrElse newV - } - - } - - - /** Applies functions and combines results in a generic way - * - * Start with an initial value, and apply functions to nodes before - * and after the recursion in the children. Combine the results of - * all children and apply a final function on the resulting node. - * - * @param pre a function applied on a node before doing a recursion in the children - * @param post a function applied to the node built from the recursive application to - all children - * @param combiner a function to combine the resulting values from all children with - the current node - * @param init the initial value - * @param expr the expression on which to apply the transform - * - * @see [[simpleTransform]] - * @see [[simplePreTransform]] - * @see [[simplePostTransform]] - */ - def genericTransform[C](pre: (Expr, C) => (Expr, C), - post: (Expr, C) => (Expr, C), - combiner: (Expr, Seq[C]) => C)(init: C)(expr: Expr) = { - - def rec(eIn: Expr, cIn: C): (Expr, C) = { - - val (expr, ctx) = pre(eIn, cIn) - val Operator(es, builder) = expr - val (newExpr, newC) = { - val (nes, cs) = es.map{ rec(_, ctx)}.unzip - val newE = builder(nes).copiedFrom(expr) - - (newE, combiner(newE, cs)) - } - - post(newExpr, newC) - } - - rec(expr, init) - } - - /* - * ============= - * Auxiliary API - * ============= - * - * Convenient methods using the Core API. - */ - - /** Checks if the predicate holds in some sub-expression */ - def exists(matcher: Expr => Boolean)(e: Expr): Boolean = { - fold[Boolean]({ (e, subs) => matcher(e) || subs.contains(true) } )(e) - } - - /** Collects a set of objects from all sub-expressions */ - def collect[T](matcher: Expr => Set[T])(e: Expr): Set[T] = { - fold[Set[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) - } - - def collectPreorder[T](matcher: Expr => Seq[T])(e: Expr): Seq[T] = { - fold[Seq[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) - } - - /** Returns a set of all sub-expressions matching the predicate */ - def filter(matcher: Expr => Boolean)(e: Expr): Set[Expr] = { - collect[Expr] { e => Set(e) filter matcher }(e) - } - - /** Counts how many times the predicate holds in sub-expressions */ - def count(matcher: Expr => Int)(e: Expr): Int = { - fold[Int]({ (e, subs) => matcher(e) + subs.sum } )(e) - } - - /** Replaces bottom-up sub-expressions by looking up for them in a map */ - def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = { - postMap(substs.lift)(expr) - } - - /** Replaces bottom-up sub-expressions by looking up for them in the provided order */ - def replaceSeq(substs: Seq[(Expr, Expr)], expr: Expr): Expr = { - var res = expr - for (s <- substs) { - res = replace(Map(s), res) - } - res - } - +object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { /** Replaces bottom-up sub-identifiers by looking up for them in a map */ def replaceFromIDs(substs: Map[Identifier, Expr], expr: Expr) : Expr = { postMap({ @@ -332,7 +66,7 @@ object ExprOps { Lambda(args, rec(binders ++ args.map(_.id), bd)) case Forall(args, bd) => Forall(args, rec(binders ++ args.map(_.id), bd)) - case Operator(subs, builder) => + case Deconstructor(subs, builder) => builder(subs map (rec(binders, _))) }).copiedFrom(e) @@ -341,7 +75,7 @@ object ExprOps { /** Returns the set of free variables in an expression */ def variablesOf(expr: Expr): Set[Identifier] = { - import leon.xlang.Expressions.LetVar + import leon.xlang.Expressions._ fold[Set[Identifier]] { case (e, subs) => val subvs = subs.flatten.toSet @@ -375,6 +109,13 @@ object ExprOps { case _ => Set() }(expr) } + + def nestedFunDefsOf(expr: Expr): Set[FunDef] = { + collect[FunDef] { + case LetDef(fds, _) => fds.toSet + case _ => Set() + }(expr) + } /** Returns functions in directly nested LetDefs */ def directlyNestedFunDefs(e: Expr): Set[FunDef] = { @@ -442,7 +183,7 @@ object ExprOps { case l @ Let(i,e,b) => val newID = FreshIdentifier(i.name, i.getType, alwaysShowUniqueID = true).copiedFrom(i) - Some(Let(newID, e, replace(Map(Variable(i) -> Variable(newID)), b))) + Some(Let(newID, e, replaceFromIDs(Map(i -> Variable(newID)), b))) case _ => None }(expr) @@ -597,7 +338,7 @@ object ExprOps { def simplerLet(t: Expr) : Option[Expr] = t match { case letExpr @ Let(i, t: Terminal, b) if isDeterministic(b) => - Some(replace(Map(Variable(i) -> t), b)) + Some(replaceFromIDs(Map(i -> t), b)) case letExpr @ Let(i,e,b) if isDeterministic(b) => { val occurrences = count { @@ -608,7 +349,7 @@ object ExprOps { if(occurrences == 0) { Some(b) } else if(occurrences == 1) { - Some(replace(Map(Variable(i) -> e), b)) + Some(replaceFromIDs(Map(i -> e), b)) } else { None } @@ -619,7 +360,7 @@ object ExprOps { val (remIds, remExprs) = (ids zip exprs).filter { case (id, value: Terminal) => - newBody = replace(Map(Variable(id) -> value), newBody) + newBody = replaceFromIDs(Map(id -> value), newBody) //we replace, so we drop old false case (id, value) => @@ -695,7 +436,7 @@ object ExprOps { case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)) case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) case p @ Passes(in, out, cses) => Passes(rec(in, s), rec(out,s), cses.map(inCase(_, s))).setPos(p) - case n @ Operator(args, recons) => { + case n @ Deconstructor(args, recons) => { var change = false val rargs = args.map(a => { val ra = rec(a, s) @@ -1204,7 +945,7 @@ object ExprOps { def transform(expr: Expr): Option[Expr] = expr match { case IfExpr(c, t, e) => None - case nop@Operator(ts, op) => { + case nop@Deconstructor(ts, op) => { val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } if(iteIndex == -1) None else { val (beforeIte, startIte) = ts.splitAt(iteIndex) @@ -1355,7 +1096,7 @@ object ExprOps { formulaSize(rhs) + og.map(formulaSize).getOrElse(0) + patternSize(p) }.sum - case Operator(es, _) => + case Deconstructor(es, _) => es.map(formulaSize).sum+1 } @@ -1851,7 +1592,7 @@ object ExprOps { case (v1, v2) if isValue(v1) && isValue(v2) => v1 == v2 - case Same(Operator(es1, _), Operator(es2, _)) => + case Same(Deconstructor(es1, _), Deconstructor(es2, _)) => (es1.size == es2.size) && (es1 zip es2).forall{ case (e1, e2) => isHomo(e1, e2) } @@ -2190,7 +1931,7 @@ object ExprOps { f(e, initParent) - val Operator(es, _) = e + val Deconstructor(es, _) = e es foreach rec } @@ -2283,7 +2024,7 @@ object ExprOps { case l @ Lambda(args, body) => val newBody = rec(body, true) extract(Lambda(args, newBody), build) - case Operator(es, recons) => recons(es.map(rec(_, build))) + case Deconstructor(es, recons) => recons(es.map(rec(_, build))) } rec(lift(expr), true) @@ -2311,7 +2052,7 @@ object ExprOps { fds ++= nfds - Some(LetDef(nfds.map(_._2), b)) + Some(letDef(nfds.map(_._2), b)) case FunctionInvocation(tfd, args) => if (fds contains tfd.fd) { diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index bb70a6676923db89648df770318f8c04129bad3e..88edef68586474c24e7a13b70a22e9beed7441b8 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -76,10 +76,6 @@ object Expressions { val getType = tpe } - case class Old(id: Identifier) extends Expr with Terminal { - val getType = id.getType - } - /** Precondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *require* * * @param pred The precondition formula inside ``require(...)`` @@ -165,7 +161,7 @@ object Expressions { * @param body The body of the expression after the function */ case class LetDef(fds: Seq[FunDef], body: Expr) extends Expr { - assert(fds.nonEmpty) + require(fds.nonEmpty) val getType = body.getType } @@ -364,6 +360,22 @@ object Expressions { someValue.id ) } + + object PatternExtractor extends SubTreeOps.Extractor[Pattern] { + def unapply(e: Pattern): Option[(Seq[Pattern], (Seq[Pattern]) => Pattern)] = e match { + case (_: InstanceOfPattern) | (_: WildcardPattern) | (_: LiteralPattern[_]) => + Some(Seq(), es => e) + case CaseClassPattern(binder, ct, subpatterns) => + Some(subpatterns, es => CaseClassPattern(binder, ct, es)) + case TuplePattern(binder, subpatterns) => + Some(subpatterns, es => TuplePattern(binder, es)) + case UnapplyPattern(binder, unapplyFun, subpatterns) => + Some(subpatterns, es => UnapplyPattern(binder, unapplyFun, es)) + case _ => None + } + } + + object PatternOps extends { val Deconstructor = PatternExtractor } with SubTreeOps[Pattern] /** Symbolic I/O examples as a match/case. * $encodingof `out == (in match { cases; case _ => out })` @@ -578,7 +590,10 @@ object Expressions { /** $encodingof `lhs.subString(start, end)` for strings */ case class SubString(expr: Expr, start: Expr, end: Expr) extends Expr { val getType = { - if (expr.getType == StringType && (start == IntegerType || start == Int32Type) && (end == IntegerType || end == Int32Type)) StringType + val ext = expr.getType + val st = start.getType + val et = end.getType + if (ext == StringType && (st == IntegerType || st == Int32Type) && (et == IntegerType || et == Int32Type)) StringType else Untyped } } @@ -770,7 +785,7 @@ object Expressions { * * [[exprs]] should always contain at least 2 elements. * If you are not sure about this requirement, you should use - * [[purescala.Constructors#tupleWrap purescala's constructor tupleWrap]] + * [[leon.purescala.Constructors#tupleWrap purescala's constructor tupleWrap]] * * @param exprs The expressions in the tuple */ @@ -783,7 +798,7 @@ object Expressions { * * Index is 1-based, first element of tuple is 1. * If you are not sure that [[tuple]] is indeed of a TupleType, - * you should use [[purescala.Constructors$.tupleSelect(t:leon\.purescala\.Expressions\.Expr,index:Int,isTuple:Boolean):leon\.purescala\.Expressions\.Expr* purescala's constructor tupleSelect]] + * you should use [[leon.purescala.Constructors.tupleSelect(t:leon\.purescala\.Expressions\.Expr,index:Int,isTuple:Boolean):leon\.purescala\.Expressions\.Expr* purescala's constructor tupleSelect]] */ case class TupleSelect(tuple: Expr, index: Int) extends Expr { require(index >= 1) diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index e2581dd8cdb33e3d025e8206d72dd32fc4ca59f7..cfa4780efccad338b5df44400b53da7264a2c2d7 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -7,12 +7,11 @@ import Expressions._ import Common._ import Types._ import Constructors._ -import ExprOps._ -import Definitions.Program +import Definitions.{Program, AbstractClassDef, CaseClassDef} object Extractors { - object Operator { + object Operator extends SubTreeOps.Extractor[Expr] { def unapply(expr: Expr): Option[(Seq[Expr], (Seq[Expr]) => Expr)] = expr match { /* Unary operators */ case Not(t) => @@ -250,6 +249,8 @@ object Extractors { None } } + + // Extractors for types are available at Types.NAryType trait Extractable { def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] @@ -367,7 +368,7 @@ object Extractors { def unapply(me : MatchExpr) : Option[(Pattern, Expr, Expr)] = { Option(me) collect { - case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, variablesOf(scrut)) => + case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, ExprOps.variablesOf(scrut)) => ( pattern, scrut, body ) } } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 039bef507339294874ad59e7e38dfe335b8f6a2f..44d454a9c0a4f1127ed0533f872c66f1ec6bc067 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -81,15 +81,12 @@ class PrettyPrinter(opts: PrinterOptions, } p"$name" - case Old(id) => - p"old($id)" - case Variable(id) => p"$id" case Let(b,d,e) => - p"""|val $b = $d - |$e""" + p"""|val $b = $d + |$e""" case LetDef(a::q,body) => p"""|$a diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 67cc994649e6b454ee8389fbfdd7674da4aebb6a..11f0c187e144873c386a01702326056e636e1225 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -9,14 +9,12 @@ import Common._ import Expressions._ import Types._ import Definitions._ -import org.apache.commons.lang3.StringEscapeUtils -/** This pretty-printer only print valid scala syntax */ +/** This pretty-printer only prints valid scala syntax */ class ScalaPrinter(opts: PrinterOptions, opgm: Option[Program], sb: StringBuffer = new StringBuffer) extends PrettyPrinter(opts, opgm, sb) { - private val dbquote = "\"" override def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = { tree match { diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index f3e7c15886fbce8b4e7c2854f0f83fc2a74e26c7..e06055dc4d9dfce8b24d2d8ddb698ebbbc781079 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -8,6 +8,7 @@ import Common._ import Definitions._ import Expressions._ import Extractors._ +import Constructors.letDef class ScopeSimplifier extends Transformer { case class Scope(inScope: Set[Identifier] = Set(), oldToNew: Map[Identifier, Identifier] = Map(), funDefs: Map[FunDef, FunDef] = Map()) { @@ -65,7 +66,7 @@ class ScopeSimplifier extends Transformer { for((newFd, localScopeToRegister, fd) <- fds_mapping) { newFd.fullBody = rec(fd.fullBody, newScope.register(localScopeToRegister)) } - LetDef(fds_mapping.map(_._1), rec(body, newScope)) + letDef(fds_mapping.map(_._1), rec(body, newScope)) case MatchExpr(scrut, cases) => val rs = rec(scrut, scope) diff --git a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala index 32474bcb918d6cd6a2fe2d0ec4cba53fcc5c63a3..0257a0b7a7c5592215206eefbb6928c46cc9995f 100644 --- a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala +++ b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala @@ -98,7 +98,6 @@ class SelfPrettyPrinter { this.excluded = excluded val s = prettyPrintersForType(v.getType) // TODO: Included the variable excluded if necessary. if(s.isEmpty) { - println("Could not find pretty printer for type " + v.getType) orElse } else { val l: Lambda = s.head diff --git a/src/main/scala/leon/purescala/SubTreeOps.scala b/src/main/scala/leon/purescala/SubTreeOps.scala new file mode 100644 index 0000000000000000000000000000000000000000..140bd5edc2ff5f316a7afb0d90442df12b78ced8 --- /dev/null +++ b/src/main/scala/leon/purescala/SubTreeOps.scala @@ -0,0 +1,327 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Expressions.Expr +import Types.TypeTree +import Common._ +import utils._ + +object SubTreeOps { + trait Extractor[SubTree <: Tree] { + def unapply(e: SubTree): Option[(Seq[SubTree], (Seq[SubTree]) => SubTree)] + } +} +trait SubTreeOps[SubTree <: Tree] { + val Deconstructor: SubTreeOps.Extractor[SubTree] + + /* ======== + * Core API + * ======== + * + * All these functions should be stable, tested, and used everywhere. Modify + * with care. + */ + + /** Does a right tree fold + * + * A right tree fold applies the input function to the subnodes first (from left + * to right), and combine the results along with the current node value. + * + * @param f a function that takes the current node and the seq + * of results form the subtrees. + * @param e The value on which to apply the fold. + * @return The expression after applying `f` on all subtrees. + * @note the computation is lazy, hence you should not rely on side-effects of `f` + */ + def fold[T](f: (SubTree, Seq[T]) => T)(e: SubTree): T = { + val rec = fold(f) _ + val Deconstructor(es, _) = e + + //Usages of views makes the computation lazy. (which is useful for + //contains-like operations) + f(e, es.view.map(rec)) + } + + + /** Pre-traversal of the tree. + * + * Invokes the input function on every node '''before''' visiting + * children. Traverse children from left to right subtrees. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) + * }}} + * will yield, in order: + * {{{ + * f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c) + * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse + */ + def preTraversal(f: SubTree => Unit)(e: SubTree): Unit = { + val rec = preTraversal(f) _ + val Deconstructor(es, _) = e + f(e) + es.foreach(rec) + } + + /** Post-traversal of the tree. + * + * Invokes the input function on every node '''after''' visiting + * children. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) + * }}} + * will yield, in order: + * {{{ + * f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c))) + * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse + */ + def postTraversal(f: SubTree => Unit)(e: SubTree): Unit = { + val rec = postTraversal(f) _ + val Deconstructor(es, _) = e + es.foreach(rec) + f(e) + } + + /** Pre-transformation of the tree. + * + * Takes a partial function of replacements and substitute + * '''before''' recursing down the trees. + * + * Supports two modes : + * + * - If applyRec is false (default), will only substitute once on each level. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, d) // And not Add(a, f) because it only substitute once for each level. + * }}} + * + * - If applyRec is true, it will substitute multiple times on each level: + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, f) + * }}} + * + * @note The mode with applyRec true can diverge if f is not well formed + */ + def preMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = { + val rec = preMap(f, applyRec) _ + + val newV = if (applyRec) { + // Apply f as long as it returns Some() + fixpoint { e : SubTree => f(e) getOrElse e } (e) + } else { + f(e) getOrElse e + } + + val Deconstructor(es, builder) = newV + val newEs = es.map(rec) + + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(newV) + } else { + newV + } + } + + + /** Post-transformation of the tree. + * + * Takes a partial function of replacements. + * Substitutes '''after''' recursing down the trees. + * + * Supports two modes : + * + * - If applyRec is false (default), will only substitute once on each level. + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e + * }}} + * will yield: + * {{{ + * Add(a, Minus(e, c)) + * }}} + * + * - If applyRec is true, it will substitute multiple times on each level: + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, f) + * }}} + * + * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) + */ + def postMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = { + val rec = postMap(f, applyRec) _ + + val Deconstructor(es, builder) = e + val newEs = es.map(rec) + val newV = { + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + } + + if (applyRec) { + // Apply f as long as it returns Some() + fixpoint { e : SubTree => f(e) getOrElse e } (newV) + } else { + f(newV) getOrElse newV + } + + } + + + /** Applies functions and combines results in a generic way + * + * Start with an initial value, and apply functions to nodes before + * and after the recursion in the children. Combine the results of + * all children and apply a final function on the resulting node. + * + * @param pre a function applied on a node before doing a recursion in the children + * @param post a function applied to the node built from the recursive application to + all children + * @param combiner a function to combine the resulting values from all children with + the current node + * @param init the initial value + * @param expr the expression on which to apply the transform + * + * @see [[simpleTransform]] + * @see [[simplePreTransform]] + * @see [[simplePostTransform]] + */ + def genericTransform[C](pre: (SubTree, C) => (SubTree, C), + post: (SubTree, C) => (SubTree, C), + combiner: (SubTree, Seq[C]) => C)(init: C)(expr: SubTree) = { + + def rec(eIn: SubTree, cIn: C): (SubTree, C) = { + + val (expr, ctx) = pre(eIn, cIn) + val Deconstructor(es, builder) = expr + val (newExpr, newC) = { + val (nes, cs) = es.map{ rec(_, ctx)}.unzip + val newE = builder(nes).copiedFrom(expr) + + (newE, combiner(newE, cs)) + } + + post(newExpr, newC) + } + + rec(expr, init) + } + + /** Pre-transformation of the tree, with a context value from "top-down". + * + * Takes a partial function of replacements. + * Substitutes '''before''' recursing down the trees. The function returns + * an option of the new value, as well as the new context to be used for + * the recursion in its children. The context is "lost" when going back up, + * so changes made by one node will not be see by its siblings. + */ + def preMapWithContext[C](f: (SubTree, C) => (Option[SubTree], C), applyRec: Boolean = false) + (e: SubTree, c: C): SubTree = { + + def rec(expr: SubTree, context: C): SubTree = { + + val (newV, newCtx) = { + if(applyRec) { + var ctx = context + val finalV = fixpoint{ e: SubTree => { + val res = f(e, ctx) + ctx = res._2 + res._1.getOrElse(e) + }} (expr) + (finalV, ctx) + } else { + val res = f(expr, context) + (res._1.getOrElse(expr), res._2) + } + } + + val Deconstructor(es, builder) = newV + val newEs = es.map(e => rec(e, newCtx)) + + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(newV) + } else { + newV + } + + } + + rec(e, c) + } + + /* + * ============= + * Auxiliary API + * ============= + * + * Convenient methods using the Core API. + */ + + /** Checks if the predicate holds in some sub-expression */ + def exists(matcher: SubTree => Boolean)(e: SubTree): Boolean = { + fold[Boolean]({ (e, subs) => matcher(e) || subs.contains(true) } )(e) + } + + /** Collects a set of objects from all sub-expressions */ + def collect[T](matcher: SubTree => Set[T])(e: SubTree): Set[T] = { + fold[Set[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) + } + + def collectPreorder[T](matcher: SubTree => Seq[T])(e: SubTree): Seq[T] = { + fold[Seq[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) + } + + /** Returns a set of all sub-expressions matching the predicate */ + def filter(matcher: SubTree => Boolean)(e: SubTree): Set[SubTree] = { + collect[SubTree] { e => Set(e) filter matcher }(e) + } + + /** Counts how many times the predicate holds in sub-expressions */ + def count(matcher: SubTree => Int)(e: SubTree): Int = { + fold[Int]({ (e, subs) => matcher(e) + subs.sum } )(e) + } + + /** Replaces bottom-up sub-expressions by looking up for them in a map */ + def replace(substs: Map[SubTree,SubTree], expr: SubTree) : SubTree = { + postMap(substs.lift)(expr) + } + + /** Replaces bottom-up sub-expressions by looking up for them in the provided order */ + def replaceSeq(substs: Seq[(SubTree, SubTree)], expr: SubTree): SubTree = { + var res = expr + for (s <- substs) { + res = replace(Map(s), res) + } + res + } + +} \ No newline at end of file diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index db655365c24a7304831e818849238bba0a849de9..14cf3b8250682a71285fffb3fd34e2e25608cdb1 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -11,16 +11,16 @@ import Extractors._ import Constructors._ import ExprOps.preMap -object TypeOps { +object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree] { def typeDepth(t: TypeTree): Int = t match { - case NAryType(tps, builder) => 1+ (0 +: (tps map typeDepth)).max + case NAryType(tps, builder) => 1 + (0 +: (tps map typeDepth)).max } - def typeParamsOf(t: TypeTree): Set[TypeParameter] = t match { - case tp: TypeParameter => Set(tp) - case _ => - val NAryType(subs, _) = t - subs.flatMap(typeParamsOf).toSet + def typeParamsOf(t: TypeTree): Set[TypeParameter] = { + collect[TypeParameter]({ + case tp: TypeParameter => Set(tp) + case _ => Set.empty + })(t) } def canBeSubtypeOf( @@ -313,7 +313,7 @@ object TypeOps { val returnType = tpeSub(fd.returnType) val params = fd.params map (vd => vd.copy(id = freshId(vd.id, tpeSub(vd.getType)))) val newFd = fd.duplicate(id, tparams, params, returnType) - val subCalls = preMap { + val subCalls = ExprOps.preMap { case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd => Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi)) case _ => @@ -335,7 +335,7 @@ object TypeOps { } val newBd = srec(subCalls(bd)).copiedFrom(bd) - LetDef(newFds, newBd).copiedFrom(l) + letDef(newFds, newBd).copiedFrom(l) case l @ Lambda(args, body) => val newArgs = args.map { arg => diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 3a0a85bb24045df18ab65b0afa488a34c8921315..9ec0e4b41f33aa0895ff160a5379d16a7cb88d87 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -133,7 +133,7 @@ object Types { case class AbstractClassType(classDef: AbstractClassDef, tps: Seq[TypeTree]) extends ClassType case class CaseClassType(classDef: CaseClassDef, tps: Seq[TypeTree]) extends ClassType - object NAryType { + object NAryType extends SubTreeOps.Extractor[TypeTree] { def unapply(t: TypeTree): Option[(Seq[TypeTree], Seq[TypeTree] => TypeTree)] = t match { case CaseClassType(ccd, ts) => Some((ts, ts => CaseClassType(ccd, ts))) case AbstractClassType(acd, ts) => Some((ts, ts => AbstractClassType(acd, ts))) @@ -142,6 +142,7 @@ object Types { case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) + /* n-ary operators */ case t => Some(Nil, _ => t) } } diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 37f187679897bbec908ffeaf18f5385198f915c9..9dcd782a3323138e5bfe4c68822fba614e2065e7 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -3,6 +3,7 @@ package leon package repair +import leon.datagen.GrammarDataGen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ @@ -25,7 +26,6 @@ import synthesis.Witnesses._ import synthesis.graph.{dotGenIds, DotGenerator} import rules._ -import grammars._ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeoutMs: Option[Long], repairTimeoutMs: Option[Long]) { implicit val ctx = ctx0 @@ -155,7 +155,7 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou }(DebugSectionReport) if (synth.settings.generateDerivationTrees) { - val dot = new DotGenerator(search.g) + val dot = new DotGenerator(search) dot.writeFile("derivation"+ dotGenIds.nextGlobal + ".dot") } @@ -236,29 +236,10 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou def discoverTests(): ExamplesBank = { - import bonsai.enumerators._ - val maxEnumerated = 1000 val maxValid = 400 val evaluator = new CodeGenEvaluator(ctx, program, CodeGenParams.default) - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) - - val inputs = enum.iterator(tupleTypeWrap(fd.params map { _.getType})).map(unwrapTuple(_, fd.params.size)) - - val filtering: Seq[Expr] => Boolean = fd.precondition match { - case None => - _ => true - case Some(pre) => - val argIds = fd.paramIds - evaluator.compile(pre, argIds) match { - case Some(evalFun) => - val sat = EvaluationResults.Successful(BooleanLiteral(true)); - { (es: Seq[Expr]) => evalFun(new solvers.Model((argIds zip es).toMap)) == sat } - case None => - { _ => false } - } - } val inputsToExample: Seq[Expr] => Example = { ins => evaluator.eval(functionInvocation(fd, ins)) match { @@ -269,10 +250,10 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou } } - val generatedTests = inputs - .take(maxEnumerated) - .filter(filtering) - .take(maxValid) + val dataGen = new GrammarDataGen(evaluator) + + val generatedTests = dataGen + .generateFor(fd.paramIds, fd.precOrTrue, maxValid, maxEnumerated) .map(inputsToExample) .toList diff --git a/src/main/scala/leon/solvers/Model.scala b/src/main/scala/leon/solvers/Model.scala index 07bdee913f21605fbc41f660af608c492e5ee1b5..060cb7fc6fcf83df0785002a1dbfc35f40918113 100644 --- a/src/main/scala/leon/solvers/Model.scala +++ b/src/main/scala/leon/solvers/Model.scala @@ -68,6 +68,7 @@ class Model(protected val mapping: Map[Identifier, Expr]) def isDefinedAt(id: Identifier): Boolean = mapping.isDefinedAt(id) def get(id: Identifier): Option[Expr] = mapping.get(id) def getOrElse[E >: Expr](id: Identifier, e: E): E = get(id).getOrElse(e) + def ids = mapping.keys def apply(id: Identifier): Expr = get(id).getOrElse { throw new IllegalArgumentException } } diff --git a/src/main/scala/leon/solvers/QuantificationSolver.scala b/src/main/scala/leon/solvers/QuantificationSolver.scala index fa11ab6613bd65b196cce87ee062c3c56f0b95f9..4f56903c578da6c2d1b71f19268e8885cb93e99e 100644 --- a/src/main/scala/leon/solvers/QuantificationSolver.scala +++ b/src/main/scala/leon/solvers/QuantificationSolver.scala @@ -25,7 +25,7 @@ class HenkinModelBuilder(domains: HenkinDomains) override def result = new HenkinModel(mapBuilder.result, domains) } -trait QuantificationSolver { +trait QuantificationSolver extends Solver { val program: Program def getModel: HenkinModel diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 67d28f877019a3e4741df6d268c68fc2d86d17ee..42d7ead0235a13ad3465dcd1fe64fd80db33ed25 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -79,10 +79,12 @@ object SolverFactory { def getFromName(ctx: LeonContext, program: Program)(name: String): SolverFactory[TimeoutSolver] = name match { case "fairz3" => - SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver) + // Previously: new FairZ3Solver(ctx, program) with TimeoutSolver + SolverFactory(() => new Z3StringFairZ3Solver(ctx, program) with TimeoutSolver) case "unrollz3" => - SolverFactory(() => new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) + // Previously: new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver + SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) case "enum" => SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver) @@ -91,10 +93,12 @@ object SolverFactory { SolverFactory(() => new GroundSolver(ctx, program) with TimeoutSolver) case "smt-z3" => - SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) + // Previously: new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver + SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) case "smt-z3-q" => - SolverFactory(() => new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) + // Previously: new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver + SolverFactory(() => new Z3StringSMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) case "smt-cvc4" => SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBCVC4Solver(ctx, program)) with TimeoutSolver) diff --git a/src/main/scala/leon/solvers/SolverUnsupportedError.scala b/src/main/scala/leon/solvers/SolverUnsupportedError.scala index 5d519160d7aed9fce7a42584c8d53806e53e265a..2efc8ea39b0da8494b2cd1309b3dcf9c2ca9cec3 100644 --- a/src/main/scala/leon/solvers/SolverUnsupportedError.scala +++ b/src/main/scala/leon/solvers/SolverUnsupportedError.scala @@ -7,7 +7,7 @@ import purescala.Common.Tree object SolverUnsupportedError { def msg(t: Tree, s: Solver, reason: Option[String]) = { - s" is unsupported by solver ${s.name}" + reason.map(":\n " + _ ).getOrElse("") + s"(of ${t.getClass}) is unsupported by solver ${s.name}" + reason.map(":\n " + _ ).getOrElse("") } } diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..602873aba9df1859650ee94486c59d9487ec7fbd --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -0,0 +1,239 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Quantification._ +import purescala.Constructors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import purescala.DefOps +import purescala.TypeOps +import purescala.Extractors._ +import utils._ +import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre, optNoChecks, optUnfoldFactor} +import templates._ +import evaluators._ +import Template._ +import leon.solvers.z3.Z3StringConversion +import leon.utils.Bijection +import leon.solvers.z3.StringEcoSystem + +object Z3StringCapableSolver { + def convert(p: Program, force: Boolean = false): (Program, Option[Z3StringConversion]) = { + val converter = new Z3StringConversion(p) + import converter.Forward._ + var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() + var hasStrings = false + val program_with_strings = converter.getProgram + val (new_program, fdMap) = DefOps.replaceFunDefs(program_with_strings)((fd: FunDef) => { + globalFdMap.get(fd).map(_._2).orElse( + if( fd.body.map(exists(e => TypeOps.exists{ _== StringType }(e.getType))).getOrElse(false) || + fd.paramIds.exists(id => TypeOps.exists(_ == StringType)(id.getType))) { + val idMap = fd.params.map(vd => vd.id -> convertId(vd.id)).toMap + val newFdId = convertId(fd.id) + val newFd = fd.duplicate(newFdId, + fd.tparams, + fd.params.map(vd => ValDef(idMap(vd.id))), + convertType(fd.returnType)) + globalFdMap += fd -> ((idMap, newFd)) + hasStrings = hasStrings || (program_with_strings.library.escape.get != fd) + Some(newFd) + } else None + ) + }) + if(!hasStrings && !force) { + (p, None) + } else { + converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2)) + for((fd, (idMap, newFd)) <- globalFdMap) { + implicit val idVarMap = idMap.mapValues(id => Variable(id)) + newFd.fullBody = convertExpr(newFd.fullBody) + } + (new_program, Some(converter)) + } + } +} +trait ForcedProgramConversion { self: Z3StringCapableSolver[_] => + override def convertProgram(p: Program): (Program, Option[Z3StringConversion]) = { + Z3StringCapableSolver.convert(p, true) + } +} + +abstract class Z3StringCapableSolver[+TUnderlying <: Solver](val context: LeonContext, val program: Program, + val underlyingConstructor: (Program, Option[Z3StringConversion]) => TUnderlying) +extends Solver { + def convertProgram(p: Program): (Program, Option[Z3StringConversion]) = Z3StringCapableSolver.convert(p) + protected val (new_program, someConverter) = convertProgram(program) + + val underlying = underlyingConstructor(new_program, someConverter) + + def getModel: leon.solvers.Model = { + val model = underlying.getModel + someConverter match { + case None => model + case Some(converter) => + println("Conversion") + val ids = model.ids.toSeq + val exprs = ids.map(model.apply) + import converter.Backward._ + val original_ids = ids.map(convertId) + val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) } + new Model(original_ids.zip(original_exprs).toMap) + } + } + + // Members declared in leon.utils.Interruptible + def interrupt(): Unit = underlying.interrupt() + def recoverInterrupt(): Unit = underlying.recoverInterrupt() + + // Members declared in leon.solvers.Solver + def assertCnstr(expression: Expr): Unit = { + someConverter.map{converter => + import converter.Forward._ + val newExpression = convertExpr(expression)(Map()) + underlying.assertCnstr(newExpression) + }.getOrElse(underlying.assertCnstr(expression)) + } + def getUnsatCore: Set[Expr] = { + someConverter.map{converter => + import converter.Backward._ + underlying.getUnsatCore map (e => convertExpr(e)(Map())) + }.getOrElse(underlying.getUnsatCore) + } + def check: Option[Boolean] = underlying.check + def free(): Unit = underlying.free() + def pop(): Unit = underlying.pop() + def push(): Unit = underlying.push() + def reset(): Unit = underlying.reset() + def name: String = underlying.name +} + +import z3._ + +trait Z3StringAbstractZ3Solver[TUnderlying <: Solver] extends AbstractZ3Solver { self: Z3StringCapableSolver[TUnderlying] => +} + +trait Z3StringNaiveAssumptionSolver[TUnderlying <: Solver] extends NaiveAssumptionSolver { self: Z3StringCapableSolver[TUnderlying] => +} + +trait Z3StringEvaluatingSolver[TUnderlying <: EvaluatingSolver] extends EvaluatingSolver{ self: Z3StringCapableSolver[TUnderlying] => + // Members declared in leon.solvers.EvaluatingSolver + val useCodeGen: Boolean = underlying.useCodeGen +} + +trait Z3StringQuantificationSolver[TUnderlying <: QuantificationSolver] extends QuantificationSolver { self: Z3StringCapableSolver[TUnderlying] => + // Members declared in leon.solvers.QuantificationSolver + override def getModel: leon.solvers.HenkinModel = { + val model = underlying.getModel + someConverter map { converter => + val ids = model.ids.toSeq + val exprs = ids.map(model.apply) + import converter.Backward._ + val original_ids = ids.map(convertId) + val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) } + + val new_domain = new HenkinDomains( + model.doms.lambdas.map(kv => + (convertExpr(kv._1)(Map()).asInstanceOf[Lambda], + kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap, + model.doms.tpes.map(kv => + (convertType(kv._1), + kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap + ) + + new HenkinModel(original_ids.zip(original_exprs).toMap, new_domain) + } getOrElse model + } +} + +trait EvaluatorCheckConverter extends DeterministicEvaluator { + def converter: Z3StringConversion + abstract override def check(expression: Expr, model: solvers.Model) : CheckResult = { + val c = converter + import c.Backward._ // Because the evaluator is going to be called by the underlying solver, but it will use the original program + super.check(convertExpr(expression)(Map()), convertModel(model)) + } +} + +class ConvertibleCodeGenEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) + extends CodeGenEvaluator(context, originalProgram) with EvaluatorCheckConverter { + override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { + import converter._ + super.compile(Backward.convertExpr(expression)(Map()), args.map(Backward.convertId)) + .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m))) + ) + } +} + +class ConvertibleDefaultEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) extends DefaultEvaluator(context, originalProgram) with EvaluatorCheckConverter { + override def eval(ex: Expr, model: Model): EvaluationResults.Result[Expr] = { + import converter._ + Forward.convertResult(super.eval(Backward.convertExpr(ex)(Map()), Backward.convertModel(model))) + } +} + + +class FairZ3SolverWithBackwardEvaluator(context: LeonContext, program: Program, + originalProgram: Program, someConverter: Option[Z3StringConversion]) extends FairZ3Solver(context, program) { + override lazy val evaluator: DeterministicEvaluator = { // We evaluate expressions using the original evaluator + someConverter match { + case Some(converter) => + if (useCodeGen) { + new ConvertibleCodeGenEvaluator(context, originalProgram, converter) + } else { + new ConvertibleDefaultEvaluator(context, originalProgram, converter) + } + case None => + if (useCodeGen) { + new CodeGenEvaluator(context, program) + } else { + new DefaultEvaluator(context, program) + } + } + } +} + + +class Z3StringFairZ3Solver(context: LeonContext, program: Program) + extends Z3StringCapableSolver(context, program, + (prgm: Program, someConverter: Option[Z3StringConversion]) => + new FairZ3SolverWithBackwardEvaluator(context, prgm, program, someConverter)) + with Z3StringEvaluatingSolver[FairZ3Solver] + with Z3StringQuantificationSolver[FairZ3Solver] { + // Members declared in leon.solvers.z3.AbstractZ3Solver + protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } + } +} + +class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlyingSolverConstructor: Program => Solver) + extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => + new UnrollingSolver(context, program, underlyingSolverConstructor(program))) + with Z3StringNaiveAssumptionSolver[UnrollingSolver] + with Z3StringEvaluatingSolver[UnrollingSolver] + with Z3StringQuantificationSolver[UnrollingSolver] { + override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore +} + +class Z3StringSMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) + extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => + new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) { + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } + } +} + diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 47017bcf1471770f1b1e9fb574a81bed34f4c515..d05b45c35e2e3847be4e1a6650c779de91da91d9 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -104,6 +104,7 @@ trait SMTLIBTarget extends Interruptible { interpreter.eval(cmd) match { case err @ ErrorResponse(msg) if !hasError && !interrupted && !rawOut => reporter.warning(s"Unexpected error from $targetName solver: $msg") + //println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n")) // Store that there was an error. Now all following check() // invocations will return None addError() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index 1731b94ae5f4f91b87248deddc50db5339915552..3d4a06a838a5057a693d85754bb5113e1ce7d0ae 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -8,15 +8,15 @@ import purescala.Common._ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ -import purescala.Definitions._ + import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.interpreters.Z3Interpreter import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} import _root_.smtlib.theories.ArraysEx -import leon.solvers.z3.Z3StringConversion -trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { +trait SMTLIBZ3Target extends SMTLIBTarget { + def targetName = "z3" def interpreterOps(ctx: LeonContext) = { @@ -40,11 +40,11 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { override protected def declareSort(t: TypeTree): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { - convertType(tpe) match { + tpe match { case SetType(base) => super.declareSort(BooleanType) declareSetSort(base) - case t => + case _ => super.declareSort(t) } } @@ -69,13 +69,9 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) } - override protected def fromSMT(t: Term, expected_otpe: Option[TypeTree] = None) + override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - val otpe = expected_otpe match { - case Some(StringType) => Some(listchar) - case _ => expected_otpe - } - val res = (t, otpe) match { + (t, otpe) match { case (SimpleSymbol(s), Some(tp: TypeParameter)) => val n = s.name.split("!").toList.last GenericValue(tp, n.toInt) @@ -100,16 +96,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { case _ => super.fromSMT(t, otpe) } - expected_otpe match { - case Some(StringType) => - StringLiteral(convertToString(res)(program)) - case _ => res - } - } - - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = toSMT(e) - def targetApplication(tfd: TypedFunDef, args: Seq[Term])(implicit bindings: Map[Identifier, Term]): Term = { - FunctionApplication(declareFunction(tfd), args) } override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { @@ -146,7 +132,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { case SetIntersection(l, r) => ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) - case StringConverted(result) => result case _ => super.toSMT(e) } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index aa0e5dad6506be945919cb22361cce649f40c077..59be52775406a78c22a55773dfd161db003b21bb 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -9,7 +9,7 @@ import purescala.Expressions._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.TypeOps._ +import purescala.TypeOps.bestRealType import purescala.Definitions._ import purescala.Constructors._ import purescala.Quantification._ @@ -207,7 +207,11 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], // id => expr && ... && expr var guardedExprs = Map[Identifier, Seq[Expr]]() def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { - assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean") + assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean." + ( + purescala.ExprOps.fold[String]{ (e, se) => + s"$e is of type ${e.getType}" + se.map(child => "\n " + "\n".r.replaceAllIn(child, "\n ")).mkString + }(expr) + )) val prev = guardedExprs.getOrElse(guardVar, Nil) diff --git a/src/main/scala/leon/solvers/templates/TemplateManager.scala b/src/main/scala/leon/solvers/templates/TemplateManager.scala index 2b75f08f0480cf272515bb8d8393e01e29d4dbf1..cdfe0c9ed6462b322b760e635122f5c3b11d2923 100644 --- a/src/main/scala/leon/solvers/templates/TemplateManager.scala +++ b/src/main/scala/leon/solvers/templates/TemplateManager.scala @@ -11,7 +11,7 @@ import purescala.Quantification._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.TypeOps._ +import purescala.TypeOps.bestRealType import utils._ diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index ac1e8855a4a53ed5ef2f66c34d4b015d9a26ba3f..7c3486ff533a630bbd1bce6bde31d86969163e7c 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -264,322 +264,311 @@ trait AbstractZ3Solver extends Solver { case other => throw SolverUnsupportedError(other, this) } - - protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier, Z3AST] = Map.empty): Z3AST = { - implicit var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { + var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { initialMap } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of // variable has to remain in a map etc. variables.aToB.collect{ case (Variable(id), p2) => id -> p2 } } - new Z3StringConversion[Z3AST] { - def getProgram = AbstractZ3Solver.this.program - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { - rec(e) - } - def targetApplication(tfd: TypedFunDef, args: Seq[Z3AST])(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { - z3.mkApp(functionDefToDecl(tfd), args: _*) + + def rec(ex: Expr): Z3AST = ex match { + + // TODO: Leave that as a specialization? + case LetTuple(ids, e, b) => { + z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => + val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) + entry } - def rec(ex: Expr): Z3AST = ex match { - - // TODO: Leave that as a specialization? - case LetTuple(ids, e, b) => { - z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => - val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) - entry - } - val rb = rec(b) - z3Vars = z3Vars -- ids - rb - } - - case p @ Passes(_, _, _) => - rec(p.asConstraint) - - case me @ MatchExpr(s, cs) => - rec(matchToIfThenElse(me)) - - case Let(i, e, b) => { - val re = rec(e) - z3Vars = z3Vars + (i -> re) - val rb = rec(b) - z3Vars = z3Vars - i - rb - } - - case Waypoint(_, e, _) => rec(e) - case a @ Assert(cond, err, body) => - rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) - - case e @ Error(tpe, _) => { - val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) - // Might introduce dupplicates (e), but no worries here - variables += (e -> newAST) - newAST - } - case v @ Variable(id) => z3Vars.get(id) match { - case Some(ast) => + val rb = rec(b) + z3Vars = z3Vars -- ids + rb + } + + case p @ Passes(_, _, _) => + rec(p.asConstraint) + + case me @ MatchExpr(s, cs) => + rec(matchToIfThenElse(me)) + + case Let(i, e, b) => { + val re = rec(e) + z3Vars = z3Vars + (i -> re) + val rb = rec(b) + z3Vars = z3Vars - i + rb + } + + case Waypoint(_, e, _) => rec(e) + case a @ Assert(cond, err, body) => + rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) + + case e @ Error(tpe, _) => { + val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) + // Might introduce dupplicates (e), but no worries here + variables += (e -> newAST) + newAST + } + case v @ Variable(id) => z3Vars.get(id) match { + case Some(ast) => + ast + case None => { + variables.getB(v) match { + case Some(ast) => ast - case None => { - variables.getB(v) match { - case Some(ast) => - ast - - case None => - val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) - z3Vars = z3Vars + (id -> newAST) - variables += (v -> newAST) - newAST - } - } - } - - case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) - case And(exs) => z3.mkAnd(exs.map(rec): _*) - case Or(exs) => z3.mkOr(exs.map(rec): _*) - case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) - case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) - case Not(e) => z3.mkNot(rec(e)) - case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) - case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) - case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) - case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) - case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() - case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) - case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) - case Minus(l, r) => z3.mkSub(rec(l), rec(r)) - case Times(l, r) => z3.mkMul(rec(l), rec(r)) - case Division(l, r) => { - val rl = rec(l) - val rr = rec(r) - z3.mkITE( - z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), - z3.mkDiv(rl, rr), - z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) - ) - } - case Remainder(l, r) => { - val q = rec(Division(l, r)) - z3.mkSub(rec(l), z3.mkMul(rec(r), q)) - } - case Modulo(l, r) => { - z3.mkMod(rec(l), rec(r)) - } - case UMinus(e) => z3.mkUnaryMinus(rec(e)) - - case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) - case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) - case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) - case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) - case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) - - case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) - case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) - case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) - case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) - case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) - case BVUMinus(e) => z3.mkBVNeg(rec(e)) - case BVNot(e) => z3.mkBVNot(rec(e)) - case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) - case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) - case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) - case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) - case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) - case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) - case LessThan(l, r) => l.getType match { - case IntegerType => z3.mkLT(rec(l), rec(r)) - case RealType => z3.mkLT(rec(l), rec(r)) - case Int32Type => z3.mkBVSlt(rec(l), rec(r)) - case CharType => z3.mkBVSlt(rec(l), rec(r)) - } - case LessEquals(l, r) => l.getType match { - case IntegerType => z3.mkLE(rec(l), rec(r)) - case RealType => z3.mkLE(rec(l), rec(r)) - case Int32Type => z3.mkBVSle(rec(l), rec(r)) - case CharType => z3.mkBVSle(rec(l), rec(r)) - //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") - } - case GreaterThan(l, r) => l.getType match { - case IntegerType => z3.mkGT(rec(l), rec(r)) - case RealType => z3.mkGT(rec(l), rec(r)) - case Int32Type => z3.mkBVSgt(rec(l), rec(r)) - case CharType => z3.mkBVSgt(rec(l), rec(r)) - } - case GreaterEquals(l, r) => l.getType match { - case IntegerType => z3.mkGE(rec(l), rec(r)) - case RealType => z3.mkGE(rec(l), rec(r)) - case Int32Type => z3.mkBVSge(rec(l), rec(r)) - case CharType => z3.mkBVSge(rec(l), rec(r)) - } - - case StringConverted(result) => - result - - case u : UnitLiteral => - val tpe = normalizeType(u.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor() - - case t @ Tuple(es) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor(es.map(rec): _*) - - case ts @ TupleSelect(t, i) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val selector = selectors.toB((tpe, i-1)) - selector(rec(t)) - - case c @ CaseClass(ct, args) => - typeToSort(ct) // Making sure the sort is defined - val constructor = constructors.toB(ct) - constructor(args.map(rec): _*) - - case c @ CaseClassSelector(cct, cc, sel) => - typeToSort(cct) // Making sure the sort is defined - val selector = selectors.toB(cct, c.selectorIndex) - selector(rec(cc)) - - case AsInstanceOf(expr, ct) => - rec(expr) - - case IsInstanceOf(e, act: AbstractClassType) => - act.knownCCDescendants match { - case Seq(cct) => - rec(IsInstanceOf(e, cct)) - case more => - val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) - rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) - } - - case IsInstanceOf(e, cct: CaseClassType) => - typeToSort(cct) // Making sure the sort is defined - val tester = testers.toB(cct) - tester(rec(e)) - - case al @ ArraySelect(a, i) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val content = selectors.toB((tpe, 1))(sa) - - z3.mkSelect(content, rec(i)) - - case al @ ArrayUpdated(a, i, e) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val ssize = selectors.toB((tpe, 0))(sa) - val scontent = selectors.toB((tpe, 1))(sa) - - val newcontent = z3.mkStore(scontent, rec(i), rec(e)) - - val constructor = constructors.toB(tpe) - - constructor(ssize, newcontent) - - case al @ ArrayLength(a) => - val tpe = normalizeType(a.getType) - val sa = rec(a) - selectors.toB((tpe, 0))(sa) - - case arr @ FiniteArray(elems, oDefault, length) => - val at @ ArrayType(base) = normalizeType(arr.getType) - typeToSort(at) - - val default = oDefault.getOrElse(simplestValue(base)) - - val ar = rec(RawArrayValue(Int32Type, elems.map { - case (i, e) => IntLiteral(i) -> e - }, default)) - - constructors.toB(at)(rec(length), ar) - - case f @ FunctionInvocation(tfd, args) => - z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) - - case fa @ Application(caller, args) => - val ft @ FunctionType(froms, to) = normalizeType(caller.getType) - val funDecl = lambdas.cachedB(ft) { - val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) - val returnSort = typeToSort(to) - - val name = FreshIdentifier("dynLambda").uniqueName - z3.mkFreshFuncDecl(name, sortSeq, returnSort) - } - z3.mkApp(funDecl, (caller +: args).map(rec): _*) - - case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) - case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) - case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) - case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) - case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) - case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) - - case RawArrayValue(keyTpe, elems, default) => - val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) - - elems.foldLeft(ar) { - case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) - } - - /** - * ===== Map operations ===== - */ - case m @ FiniteMap(elems, from, to) => - val MapType(_, t) = normalizeType(m.getType) - - rec(RawArrayValue(from, elems.map{ - case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) - }.toMap, CaseClass(library.noneType(t), Seq()))) - - case MapApply(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - // Really ?!? We don't check that it is actually != None? - selectors.toB(library.someType(t), 0)(el) - - case MapIsDefinedAt(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - testers.toB(library.someType(t))(el) - - case MapUnion(m1, FiniteMap(elems, _, _)) => - val mt @ MapType(_, t) = normalizeType(m1.getType) - typeToSort(mt) - - elems.foldLeft(rec(m1)) { case (m, (k,v)) => - z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) - } - - - case gv @ GenericValue(tp, id) => - z3.mkApp(genericValueToDecl(gv)) - - case other => - unsupported(other) + + case None => + val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) + z3Vars = z3Vars + (id -> newAST) + variables += (v -> newAST) + newAST } - }.rec(expr) + } + } + + case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) + case And(exs) => z3.mkAnd(exs.map(rec): _*) + case Or(exs) => z3.mkOr(exs.map(rec): _*) + case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) + case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) + case Not(e) => z3.mkNot(rec(e)) + case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) + case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) + case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) + case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) + case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() + case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) + case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) + case Minus(l, r) => z3.mkSub(rec(l), rec(r)) + case Times(l, r) => z3.mkMul(rec(l), rec(r)) + case Division(l, r) => { + val rl = rec(l) + val rr = rec(r) + z3.mkITE( + z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), + z3.mkDiv(rl, rr), + z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) + ) + } + case Remainder(l, r) => { + val q = rec(Division(l, r)) + z3.mkSub(rec(l), z3.mkMul(rec(r), q)) + } + case Modulo(l, r) => { + z3.mkMod(rec(l), rec(r)) + } + case UMinus(e) => z3.mkUnaryMinus(rec(e)) + + case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) + case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) + case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) + case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) + case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) + + case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) + case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) + case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) + case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) + case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) + case BVUMinus(e) => z3.mkBVNeg(rec(e)) + case BVNot(e) => z3.mkBVNot(rec(e)) + case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) + case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) + case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) + case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) + case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) + case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) + case LessThan(l, r) => l.getType match { + case IntegerType => z3.mkLT(rec(l), rec(r)) + case RealType => z3.mkLT(rec(l), rec(r)) + case Int32Type => z3.mkBVSlt(rec(l), rec(r)) + case CharType => z3.mkBVSlt(rec(l), rec(r)) + } + case LessEquals(l, r) => l.getType match { + case IntegerType => z3.mkLE(rec(l), rec(r)) + case RealType => z3.mkLE(rec(l), rec(r)) + case Int32Type => z3.mkBVSle(rec(l), rec(r)) + case CharType => z3.mkBVSle(rec(l), rec(r)) + //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") + } + case GreaterThan(l, r) => l.getType match { + case IntegerType => z3.mkGT(rec(l), rec(r)) + case RealType => z3.mkGT(rec(l), rec(r)) + case Int32Type => z3.mkBVSgt(rec(l), rec(r)) + case CharType => z3.mkBVSgt(rec(l), rec(r)) + } + case GreaterEquals(l, r) => l.getType match { + case IntegerType => z3.mkGE(rec(l), rec(r)) + case RealType => z3.mkGE(rec(l), rec(r)) + case Int32Type => z3.mkBVSge(rec(l), rec(r)) + case CharType => z3.mkBVSge(rec(l), rec(r)) + } + + case u : UnitLiteral => + val tpe = normalizeType(u.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor() + + case t @ Tuple(es) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor(es.map(rec): _*) + + case ts @ TupleSelect(t, i) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val selector = selectors.toB((tpe, i-1)) + selector(rec(t)) + + case c @ CaseClass(ct, args) => + typeToSort(ct) // Making sure the sort is defined + val constructor = constructors.toB(ct) + constructor(args.map(rec): _*) + + case c @ CaseClassSelector(cct, cc, sel) => + typeToSort(cct) // Making sure the sort is defined + val selector = selectors.toB(cct, c.selectorIndex) + selector(rec(cc)) + + case AsInstanceOf(expr, ct) => + rec(expr) + + case IsInstanceOf(e, act: AbstractClassType) => + act.knownCCDescendants match { + case Seq(cct) => + rec(IsInstanceOf(e, cct)) + case more => + val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) + rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) + } + + case IsInstanceOf(e, cct: CaseClassType) => + typeToSort(cct) // Making sure the sort is defined + val tester = testers.toB(cct) + tester(rec(e)) + + case al @ ArraySelect(a, i) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val content = selectors.toB((tpe, 1))(sa) + + z3.mkSelect(content, rec(i)) + + case al @ ArrayUpdated(a, i, e) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val ssize = selectors.toB((tpe, 0))(sa) + val scontent = selectors.toB((tpe, 1))(sa) + + val newcontent = z3.mkStore(scontent, rec(i), rec(e)) + + val constructor = constructors.toB(tpe) + + constructor(ssize, newcontent) + + case al @ ArrayLength(a) => + val tpe = normalizeType(a.getType) + val sa = rec(a) + selectors.toB((tpe, 0))(sa) + + case arr @ FiniteArray(elems, oDefault, length) => + val at @ ArrayType(base) = normalizeType(arr.getType) + typeToSort(at) + + val default = oDefault.getOrElse(simplestValue(base)) + + val ar = rec(RawArrayValue(Int32Type, elems.map { + case (i, e) => IntLiteral(i) -> e + }, default)) + + constructors.toB(at)(rec(length), ar) + + case f @ FunctionInvocation(tfd, args) => + z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) + + case fa @ Application(caller, args) => + val ft @ FunctionType(froms, to) = normalizeType(caller.getType) + val funDecl = lambdas.cachedB(ft) { + val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) + val returnSort = typeToSort(to) + + val name = FreshIdentifier("dynLambda").uniqueName + z3.mkFreshFuncDecl(name, sortSeq, returnSort) + } + z3.mkApp(funDecl, (caller +: args).map(rec): _*) + + case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) + case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) + case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) + case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) + case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) + case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) + + case RawArrayValue(keyTpe, elems, default) => + val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) + + elems.foldLeft(ar) { + case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) + } + + /** + * ===== Map operations ===== + */ + case m @ FiniteMap(elems, from, to) => + val MapType(_, t) = normalizeType(m.getType) + + rec(RawArrayValue(from, elems.map{ + case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) + }, CaseClass(library.noneType(t), Seq()))) + + case MapApply(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + // Really ?!? We don't check that it is actually != None? + selectors.toB(library.someType(t), 0)(el) + + case MapIsDefinedAt(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + testers.toB(library.someType(t))(el) + + case MapUnion(m1, FiniteMap(elems, _, _)) => + val mt @ MapType(_, t) = normalizeType(m1.getType) + typeToSort(mt) + + elems.foldLeft(rec(m1)) { case (m, (k,v)) => + z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) + } + + + case gv @ GenericValue(tp, id) => + z3.mkApp(genericValueToDecl(gv)) + + case other => + unsupported(other) + } + + rec(expr) } protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree): Expr = { - def rec(t: Z3AST, expected_tpe: TypeTree): Expr = { + + def rec(t: Z3AST, tpe: TypeTree): Expr = { val kind = z3.getASTKind(t) - val tpe = Z3StringTypeConversion.convert(expected_tpe)(program) - val res = kind match { + kind match { case Z3NumeralIntAST(Some(v)) => val leading = t.toString.substring(0, 2 min t.toString.length) if(leading == "#x") { @@ -769,11 +758,6 @@ trait AbstractZ3Solver extends Solver { } case _ => unsound(t, "unexpected AST") } - expected_tpe match { - case StringType => - StringLiteral(Z3StringTypeConversion.convertToString(res)(program)) - case _ => res - } } rec(tree, normalizeType(tpe)) @@ -790,8 +774,7 @@ trait AbstractZ3Solver extends Solver { } def idToFreshZ3Id(id: Identifier): Z3AST = { - val correctType = Z3StringTypeConversion.convert(id.getType)(program) - z3.mkFreshConst(id.uniqueName, typeToSort(correctType)) + z3.mkFreshConst(id.uniqueName, typeToSort(id.getType)) } def reset() = { diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala index 3daf1ad4964ad73e8c4d9701ae4e65d0f4170897..21df018db70935ad63b6c22e7d6bc77894005b23 100644 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -7,51 +7,132 @@ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ import purescala.Definitions._ -import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} -import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} -import _root_.smtlib.interpreters.Z3Interpreter -import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} -import _root_.smtlib.theories.ArraysEx import leon.utils.Bijection +import leon.purescala.DefOps +import leon.purescala.TypeOps +import leon.purescala.Extractors.Operator +import leon.evaluators.EvaluationResults -object Z3StringTypeConversion { - def convert(t: TypeTree)(implicit p: Program) = new Z3StringTypeConversion { def getProgram = p }.convertType(t) - def convertToString(e: Expr)(implicit p: Program) = new Z3StringTypeConversion{ def getProgram = p }.convertToString(e) -} - -trait Z3StringTypeConversion { - val stringBijection = new Bijection[String, Expr]() +object StringEcoSystem { + private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = { + val id = FreshIdentifier(name, tpe) + f(id) + } + private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { + withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) + } - lazy val conschar = program.lookupCaseClass("leon.collection.Cons") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Cons in Z3 solver") + val StringList = AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) + val StringListTyped = StringList.typed + val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => + val d = CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) + d.setFields(Seq(ValDef(head), ValDef(tail))) + d } - lazy val nilchar = program.lookupCaseClass("leon.collection.Nil") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Nil in Z3 solver") + StringList.registerChild(StringCons) + val StringConsTyped = StringCons.typed + val StringNil = CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) + val StringNilTyped = StringNil.typed + StringList.registerChild(StringNil) + + val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => + val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) + fd.body = Some(withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(lengthArg), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, InfiniteIntegerLiteral(BigInt(0))), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + Plus(InfiniteIntegerLiteral(BigInt(1)), FunctionInvocation(fd.typed, Seq(Variable(t))))) + )) + }) + fd } - lazy val listchar = program.lookupAbstractClass("leon.collection.List") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find List in Z3 solver") + val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => + val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(x), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, Variable(y)), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + CaseClass(StringConsTyped, Seq(Variable(h), FunctionInvocation(fd.typed, Seq(Variable(t), Variable(y))))) + ))) + } + ) + fd } - def lookupFunDef(s: String): FunDef = program.lookupFunDef(s) match { - case Some(fd) => fd - case _ => throw new Exception("Could not find function "+s+" in program") + + val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => + val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) + fd.body = Some{ + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(tt), Variable(it))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringNilTyped, Seq()), + CaseClass(StringConsTyped, Seq(Variable(h), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))))) + )))) + } + } + } + fd + } + + val StringDrop = withIdentifiers("td", StringListTyped, "id", IntegerType) { (td, id) => + val fd = new FunDef(FreshIdentifier("StringDrop"), Seq(), Seq(ValDef(td), ValDef(id)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(td), Variable(id))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringConsTyped, Seq(Variable(h), Variable(t))), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))) + )))) + }} + ) + fd } - lazy val list_size = lookupFunDef("leon.collection.List.size").typed(Seq(CharType)) - lazy val list_++ = lookupFunDef("leon.collection.List.++").typed(Seq(CharType)) - lazy val list_take = lookupFunDef("leon.collection.List.take").typed(Seq(CharType)) - lazy val list_drop = lookupFunDef("leon.collection.List.drop").typed(Seq(CharType)) - lazy val list_slice = lookupFunDef("leon.collection.List.slice").typed(Seq(CharType)) - private lazy val program = getProgram + val StringSlice = withIdentifier("s", StringListTyped) { s => withIdentifiers("from", IntegerType, "to", IntegerType) { (from, to) => + val fd = new FunDef(FreshIdentifier("StringSlice"), Seq(), Seq(ValDef(s), ValDef(from), ValDef(to)), StringListTyped) + fd.body = Some( + FunctionInvocation(StringTake.typed, + Seq(FunctionInvocation(StringDrop.typed, Seq(Variable(s), Variable(from))), + Minus(Variable(to), Variable(from))))) + fd + } } - def getProgram: Program + val classDefs = Seq(StringList, StringCons, StringNil) + val funDefs = Seq(StringSize, StringListConcat, StringTake, StringDrop, StringSlice) +} + +class Z3StringConversion(val p: Program) extends Z3StringConverters { + val stringBijection = new Bijection[String, Expr]() + + import StringEcoSystem._ + + lazy val listchar = StringList.typed + lazy val conschar = StringCons.typed + lazy val nilchar = StringNil.typed + + lazy val list_size = StringSize.typed + lazy val list_++ = StringListConcat.typed + lazy val list_take = StringTake.typed + lazy val list_drop = StringDrop.typed + lazy val list_slice = StringSlice.typed - def convertType(t: TypeTree): TypeTree = t match { - case StringType => listchar - case _ => t + def getProgram = program_with_string_methods + + lazy val program_with_string_methods = { + val p2 = DefOps.addClassDefs(p, StringEcoSystem.classDefs, p.library.Nil.get) + DefOps.addFunDefs(p2, StringEcoSystem.funDefs, p2.library.escape.get) } + def convertToString(e: Expr)(implicit p: Program): String = stringBijection.cachedA(e) { e match { @@ -59,7 +140,7 @@ trait Z3StringTypeConversion { case CaseClass(_, Seq()) => "" } } - def convertFromString(v: String) = + def convertFromString(v: String): Expr = stringBijection.cachedB(v) { v.toList.foldRight(CaseClass(nilchar, Seq())){ case (char, l) => CaseClass(conschar, Seq(CharLiteral(char), l)) @@ -67,28 +148,226 @@ trait Z3StringTypeConversion { } } -trait Z3StringConversion[TargetType] extends Z3StringTypeConversion { - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, TargetType]): TargetType - def targetApplication(fd: TypedFunDef, args: Seq[TargetType])(implicit bindings: Map[Identifier, TargetType]): TargetType +trait Z3StringConverters { self: Z3StringConversion => + import StringEcoSystem._ + val mappedVariables = new Bijection[Identifier, Identifier]() + + val globalFdMap = new Bijection[FunDef, FunDef]() - object StringConverted { - def unapply(e: Expr)(implicit bindings: Map[Identifier, TargetType]): Option[TargetType] = e match { + trait BidirectionalConverters { + def convertFunDef(fd: FunDef): FunDef + def hasIdConversion(id: Identifier): Boolean + def convertId(id: Identifier): Identifier + def isTypeToConvert(tpe: TypeTree): Boolean + def convertType(tpe: TypeTree): TypeTree + def convertPattern(pattern: Pattern): Pattern + def convertExpr(expr: Expr)(implicit bindings: Map[Identifier, Expr]): Expr + + object PatternConverted { + def unapply(e: Pattern): Option[Pattern] = Some(e match { + case InstanceOfPattern(binder, ct) => + InstanceOfPattern(binder.map(convertId), convertType(ct).asInstanceOf[ClassType]) + case WildcardPattern(binder) => + WildcardPattern(binder.map(convertId)) + case CaseClassPattern(binder, ct, subpatterns) => + CaseClassPattern(binder.map(convertId), convertType(ct).asInstanceOf[CaseClassType], subpatterns map convertPattern) + case TuplePattern(binder, subpatterns) => + TuplePattern(binder.map(convertId), subpatterns map convertPattern) + case UnapplyPattern(binder, TypedFunDef(fd, tpes), subpatterns) => + UnapplyPattern(binder.map(convertId), TypedFunDef(convertFunDef(fd), tpes map convertType), subpatterns map convertPattern) + case PatternExtractor(es, builder) => + builder(es map convertPattern) + }) + } + + object ExprConverted { + def unapply(e: Expr)(implicit bindings: Map[Identifier, Expr]): Option[Expr] = Some(e match { + case Variable(id) if bindings contains id => bindings(id).copiedFrom(e) + case Variable(id) if hasIdConversion(id) => Variable(convertId(id)).copiedFrom(e) + case Variable(id) => e + case pl@PartialLambda(mappings, default, tpe) => + PartialLambda( + mappings.map(kv => (kv._1.map(argtpe => convertExpr(argtpe)), + convertExpr(kv._2))), + default.map(d => convertExpr(d)), convertType(tpe).asInstanceOf[FunctionType]) + case Lambda(args, body) => + println("Converting Lambda :" + e) + val new_bindings = scala.collection.mutable.ListBuffer[(Identifier, Identifier)]() + val new_args = for(arg <- args) yield { + val in = arg.getType + val new_id = convertId(arg.id) + if(new_id ne arg.id) { + new_bindings += (arg.id -> new_id) + ValDef(new_id) + } else arg + } + val res = Lambda(new_args, convertExpr(body)(bindings ++ new_bindings.map(t => (t._1, Variable(t._2))))).copiedFrom(e) + res + case Let(a, expr, body) if isTypeToConvert(a.getType) => + val new_a = convertId(a) + val new_bindings = bindings + (a -> Variable(new_a)) + val expr2 = convertExpr(expr)(new_bindings) + val body2 = convertExpr(body)(new_bindings) + Let(new_a, expr2, body2).copiedFrom(e) + case CaseClass(CaseClassType(ccd, tpes), args) => + CaseClass(CaseClassType(ccd, tpes map convertType), args map convertExpr).copiedFrom(e) + case CaseClassSelector(CaseClassType(ccd, tpes), caseClass, selector) => + CaseClassSelector(CaseClassType(ccd, tpes map convertType), convertExpr(caseClass), selector).copiedFrom(e) + case MethodInvocation(rec: Expr, cd: ClassDef, TypedFunDef(fd, tpes), args: Seq[Expr]) => + MethodInvocation(convertExpr(rec), cd, TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) + case FunctionInvocation(TypedFunDef(fd, tpes), args) => + FunctionInvocation(TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) + case This(ct: ClassType) => + This(convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case IsInstanceOf(expr, ct) => + IsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case AsInstanceOf(expr, ct) => + AsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case Tuple(args) => + Tuple(for(arg <- args) yield convertExpr(arg)).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { + MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) + }) + case Operator(es, builder) => + val rec = convertExpr _ + val newEs = es.map(rec) + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + case e => e + }) + } + + def convertModel(model: Model): Model = { + new Model(model.ids.map{i => + val id = convertId(i) + id -> convertExpr(model(i))(Map()) + }.toMap) + } + + def convertResult(result: EvaluationResults.Result[Expr]) = { + result match { + case EvaluationResults.Successful(e) => EvaluationResults.Successful(convertExpr(e)(Map())) + case result => result + } + } + } + + object Forward extends BidirectionalConverters { + /* The conversion between functions should already have taken place */ + def convertFunDef(fd: FunDef): FunDef = { + globalFdMap.getBorElse(fd, fd) + } + def hasIdConversion(id: Identifier): Boolean = { + mappedVariables.containsA(id) + } + def convertId(id: Identifier): Identifier = { + mappedVariables.getB(id) match { + case Some(idB) => idB + case None => + if(isTypeToConvert(id.getType)) { + val new_id = FreshIdentifier(id.name, convertType(id.getType)) + mappedVariables += (id -> new_id) + new_id + } else id + } + } + def isTypeToConvert(tpe: TypeTree): Boolean = + TypeOps.exists(StringType == _)(tpe) + def convertType(tpe: TypeTree): TypeTree = + TypeOps.preMap{ case StringType => Some(StringList.typed) case e => None}(tpe) + def convertPattern(e: Pattern): Pattern = e match { + case LiteralPattern(binder, StringLiteral(s)) => + s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { + case (elem, pattern) => + CaseClassPattern(None, StringConsTyped, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) + } + case PatternConverted(e) => e + } + + /** Method which can use recursively StringConverted in its body in unapply positions */ + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { + case Variable(id) if isTypeToConvert(id.getType) => Variable(convertId(id)).copiedFrom(e) case StringLiteral(v) => // No string support for z3 at this moment. val stringEncoding = convertFromString(v) - Some(convertToTarget(stringEncoding)) + convertExpr(stringEncoding).copiedFrom(e) case StringLength(a) => - Some(targetApplication(list_size, Seq(convertToTarget(a)))) + FunctionInvocation(list_size, Seq(convertExpr(a))).copiedFrom(e) case StringConcat(a, b) => - Some(targetApplication(list_++, Seq(convertToTarget(a), convertToTarget(b)))) + FunctionInvocation(list_++, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) case SubString(a, start, Plus(start2, length)) if start == start2 => - Some(targetApplication(list_take, - Seq(targetApplication(list_drop, Seq(convertToTarget(a), convertToTarget(start))), convertToTarget(length)))) + FunctionInvocation(list_take, + Seq(FunctionInvocation(list_drop, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) case SubString(a, start, end) => - Some(targetApplication(list_slice, Seq(convertToTarget(a), convertToTarget(start), convertToTarget(end)))) - case _ => None + FunctionInvocation(list_slice, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { + MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) + }) + case ExprConverted(e) => e + } + } + + object Backward extends BidirectionalConverters { + def convertFunDef(fd: FunDef): FunDef = { + globalFdMap.getAorElse(fd, fd) + } + def hasIdConversion(id: Identifier): Boolean = { + mappedVariables.containsB(id) + } + def convertId(id: Identifier): Identifier = { + mappedVariables.getA(id) match { + case Some(idA) => idA + case None => + if(isTypeToConvert(id.getType)) { + val old_type = convertType(id.getType) + val old_id = FreshIdentifier(id.name, old_type) + mappedVariables += (old_id -> id) + old_id + } else id + } + } + def convertIdToMapping(id: Identifier): (Identifier, Variable) = { + id -> Variable(convertId(id)) } + def isTypeToConvert(tpe: TypeTree): Boolean = + TypeOps.exists(t => TypeOps.isSubtypeOf(t, StringListTyped))(tpe) + def convertType(tpe: TypeTree): TypeTree = { + TypeOps.preMap{ + case StringList | StringCons | StringNil => Some(StringType) + case e => None}(tpe) + } + def convertPattern(e: Pattern): Pattern = e match { + case CaseClassPattern(b, StringNilTyped, Seq()) => + LiteralPattern(b.map(convertId), StringLiteral("")) + case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => + convertPattern(subpattern) match { + case LiteralPattern(_, StringLiteral(s)) + => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) + case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) + } + case PatternConverted(e) => e + } - def apply(t: TypeTree): TypeTree = convertType(t) + + + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = + e match { + case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> + StringLiteral(convertToString(cc)(self.p)) + case FunctionInvocation(StringSize, Seq(a)) => + StringLength(convertExpr(a)).copiedFrom(e) + case FunctionInvocation(StringListConcat, Seq(a, b)) => + StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) + case FunctionInvocation(StringTake, + Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => + val rstart = convertExpr(start) + SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) + case ExprConverted(e) => e + } } } \ No newline at end of file diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 42d6c41ab82fd80e28776b5ce0deeb082212612c..78a483446624a28482772efe5752a4f6f44e1995 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -6,13 +6,10 @@ package synthesis import purescala.Expressions._ import purescala.Definitions._ import purescala.ExprOps._ -import purescala.Types.TypeTree import purescala.Common._ import purescala.Constructors._ -import purescala.Extractors._ import evaluators._ import grammars._ -import bonsai.enumerators._ import codegen._ import datagen._ import solvers._ @@ -123,9 +120,9 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { val datagen = new GrammarDataGen(evaluator, ValueGrammar) val solverDataGen = new SolverDataGen(ctx, program, (ctx, pgm) => SolverFactory(() => new FairZ3Solver(ctx, pgm))) - val generatedExamples = datagen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample(_)) + val generatedExamples = datagen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample) - val solverExamples = solverDataGen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample(_)) + val solverExamples = solverDataGen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample) ExamplesBank(generatedExamples.toSeq ++ solverExamples.toList, Nil) } @@ -196,6 +193,9 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { case (a, b, c) => None }) getOrElse { + + // If the input contains free variables, it does not provide concrete examples. + // We will instantiate them according to a simple grammar to get them. if(this.keepAbstractExamples) { cs.optGuard match { case Some(BooleanLiteral(false)) => @@ -206,34 +206,16 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { Seq((Require(pred, pattExpr), cs.rhs)) } } else { - // If the input contains free variables, it does not provide concrete examples. - // We will instantiate them according to a simple grammar to get them. - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) - val values = enum.iterator(tupleTypeWrap(freeVars.map { _.getType })) - val instantiations = values.map { - v => freeVars.zip(unwrapTuple(v, freeVars.size)).toMap - } - - def filterGuard(e: Expr, mapping: Map[Identifier, Expr]): Boolean = cs.optGuard match { - case Some(guard) => - // in -> e should be enough. We shouldn't find any subexpressions of in. - evaluator.eval(replace(Map(in -> e), guard), mapping) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => true - case _ => false - } - - case None => - true + val dataGen = new GrammarDataGen(evaluator) + + val theGuard = replace(Map(in -> pattExpr), cs.optGuard.getOrElse(BooleanLiteral(true))) + + dataGen.generateFor(freeVars, theGuard, examplesPerCase, 1000).toSeq map { vals => + val inst = freeVars.zip(vals).toMap + val inR = replaceFromIDs(inst, pattExpr) + val outR = replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs)) + (inR, outR) } - - if(cs.optGuard == Some(BooleanLiteral(false))) { - Nil - } else (for { - inst <- instantiations.toSeq - inR = replaceFromIDs(inst, pattExpr) - outR = replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs)) - if filterGuard(inR, inst) - } yield (inR, outR)).take(examplesPerCase) } } } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index cd27e272d53e9669f4c2d1f2c8e07356819b4827..3a86ca64a79238aec4e59b197e001fd4c24660b4 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -35,8 +35,10 @@ abstract class PreprocessingRule(name: String) extends Rule(name) { /** Contains the list of all available rules for synthesis */ object Rules { + + def all: List[Rule] = all(false) /** Returns the list of all available rules for synthesis */ - def all = List[Rule]( + def all(naiveGrammar: Boolean): List[Rule] = List[Rule]( StringRender, Unification.DecompTrivialClash, Unification.OccursCheck, // probably useless @@ -54,8 +56,8 @@ object Rules { OptimisticGround, EqualitySplit, InequalitySplit, - CEGIS, - TEGIS, + if(naiveGrammar) NaiveCEGIS else CEGIS, + //TEGIS, //BottomUpTEGIS, rules.Assert, DetupleOutput, diff --git a/src/main/scala/leon/synthesis/SourceInfo.scala b/src/main/scala/leon/synthesis/SourceInfo.scala index 4bb10d38c9ffc7a7667d165b84e4f65c1edc9e0c..8ab07929d78479656f18ce1fd652cfa7ef870e17 100644 --- a/src/main/scala/leon/synthesis/SourceInfo.scala +++ b/src/main/scala/leon/synthesis/SourceInfo.scala @@ -45,6 +45,10 @@ object SourceInfo { ci } + if (results.isEmpty) { + ctx.reporter.warning("No 'choose' found. Maybe the functions you chose do not exist?") + } + results.sortBy(_.source.getPos) } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index b9ba6df1f688e01edf631e995ba2f8623bcfc5fe..ac4d30614d8269ce78a92a95e232855eb40d9fbd 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -3,13 +3,11 @@ package leon package synthesis -import purescala.ExprOps._ - +import purescala.ExprOps.replace import purescala.ScalaPrinter -import leon.utils._ import purescala.Definitions.{Program, FunDef} -import leon.utils.ASCIIHelpers +import leon.utils._ import graph._ object SynthesisPhase extends TransformationPhase { @@ -21,11 +19,13 @@ object SynthesisPhase extends TransformationPhase { val optDerivTrees = LeonFlagOptionDef( "derivtrees", "Generate derivation trees", false) // CEGIS options - val optCEGISOptTimeout = LeonFlagOptionDef( "cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true) - val optCEGISVanuatoo = LeonFlagOptionDef( "cegis:vanuatoo", "Generate inputs using new korat-style generator", false) + val optCEGISOptTimeout = LeonFlagOptionDef("cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true ) + val optCEGISVanuatoo = LeonFlagOptionDef("cegis:vanuatoo", "Generate inputs using new korat-style generator", false) + val optCEGISNaiveGrammar = LeonFlagOptionDef("cegis:naive", "Use the old naive grammar for CEGIS", false) + val optCEGISMaxSize = LeonLongOptionDef("cegis:maxsize", "Maximum size of expressions synthesized by CEGIS", 5L, "N") override val definedOptions : Set[LeonOptionDef[Any]] = - Set(optManual, optCostModel, optDerivTrees, optCEGISOptTimeout, optCEGISVanuatoo) + Set(optManual, optCostModel, optDerivTrees, optCEGISOptTimeout, optCEGISVanuatoo, optCEGISNaiveGrammar, optCEGISMaxSize) def processOptions(ctx: LeonContext): SynthesisSettings = { val ms = ctx.findOption(optManual) @@ -53,11 +53,13 @@ object SynthesisPhase extends TransformationPhase { timeoutMs = timeout map { _ * 1000 }, generateDerivationTrees = ctx.findOptionOrDefault(optDerivTrees), costModel = costModel, - rules = Rules.all ++ (if(ms.isDefined) Seq(rules.AsChoose, rules.SygusCVC4) else Seq()), + rules = Rules.all(ctx.findOptionOrDefault(optCEGISNaiveGrammar)) ++ + (if(ms.isDefined) Seq(rules.AsChoose, rules.SygusCVC4) else Seq()), manualSearch = ms, functions = ctx.findOption(SharedOptions.optFunctions) map { _.toSet }, - cegisUseOptTimeout = ctx.findOption(optCEGISOptTimeout), - cegisUseVanuatoo = ctx.findOption(optCEGISVanuatoo) + cegisUseOptTimeout = ctx.findOptionOrDefault(optCEGISOptTimeout), + cegisUseVanuatoo = ctx.findOptionOrDefault(optCEGISVanuatoo), + cegisMaxSize = ctx.findOptionOrDefault(optCEGISMaxSize).toInt ) } @@ -80,7 +82,7 @@ object SynthesisPhase extends TransformationPhase { try { if (options.generateDerivationTrees) { - val dot = new DotGenerator(search.g) + val dot = new DotGenerator(search) dot.writeFile("derivation"+dotGenIds.nextGlobal+".dot") } diff --git a/src/main/scala/leon/synthesis/SynthesisSettings.scala b/src/main/scala/leon/synthesis/SynthesisSettings.scala index 5202818e18765ebf4086ef41d1685967a14940d0..61dc24ece71081c0f02f5bdcb38d9d9eeb0fee14 100644 --- a/src/main/scala/leon/synthesis/SynthesisSettings.scala +++ b/src/main/scala/leon/synthesis/SynthesisSettings.scala @@ -16,7 +16,8 @@ case class SynthesisSettings( functionsToIgnore: Set[FunDef] = Set(), // Cegis related options - cegisUseOptTimeout: Option[Boolean] = None, - cegisUseVanuatoo: Option[Boolean] = None + cegisUseOptTimeout: Boolean = true, + cegisUseVanuatoo : Boolean = false, + cegisMaxSize: Int = 5 ) diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index bafed6ec2bab51539bfc0547563bbad2aeea873e..efd1ad13e0538f855487e94b7b1a35d7d893627f 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -70,21 +70,19 @@ class Synthesizer(val context : LeonContext, // Print out report for synthesis, if necessary reporter.ifDebug { printer => - import java.io.FileWriter import java.text.SimpleDateFormat import java.util.Date val categoryName = ci.fd.getPos.file.toString.split("/").dropRight(1).lastOption.getOrElse("?") val benchName = categoryName+"."+ci.fd.id.name - var time = lastTime/1000.0; + val time = lastTime/1000.0 val defs = visibleDefsFrom(ci.fd)(program).collect { case cd: ClassDef => 1 + cd.fields.size case fd: FunDef => 1 + fd.params.size + formulaSize(fd.fullBody) } - val psize = defs.sum; - + val psize = defs.sum val (size, calls, proof) = result.headOption match { case Some((sol, trusted)) => diff --git a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala b/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala index ed9f44768752c716b2acf0f6db9863cde1353068..9303b4d1ef9813a38109b35c17adadf546297252 100644 --- a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala @@ -3,16 +3,13 @@ package leon package synthesis package disambiguation -import leon.LeonContext -import leon.purescala.Expressions._ import purescala.Types.FunctionType import purescala.Common.FreshIdentifier import purescala.Constructors.{ and, tupleWrap } import purescala.Definitions.{ FunDef, Program, ValDef } import purescala.ExprOps -import purescala.Expressions.{ BooleanLiteral, Equals, Expr, Lambda, MatchCase, Passes, Variable, WildcardPattern } import purescala.Extractors.TopLevelAnds -import leon.purescala.Expressions._ +import purescala.Expressions._ /** * @author Mikael diff --git a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala index bb3dc45ffd8d3a0fa4a3dd87fc436a7b3dfde113..81f98f86432dc54ffb446f473fee3a1afcf358ca 100644 --- a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala @@ -1,21 +1,18 @@ package leon package synthesis.disambiguation +import datagen.GrammarDataGen import synthesis.Solution import evaluators.DefaultEvaluator import purescala.Expressions._ import purescala.ExprOps -import purescala.Constructors._ -import purescala.Extractors._ import purescala.Types.{StringType, TypeTree} import purescala.Common.Identifier import purescala.Definitions.Program import purescala.DefOps -import grammars.ValueGrammar -import bonsai.enumerators.MemoizedEnumerator +import grammars._ import solvers.ModelBuilder import scala.collection.mutable.ListBuffer -import grammars._ import evaluators.AbstractEvaluator import scala.annotation.tailrec @@ -71,15 +68,15 @@ object QuestionBuilder { /** Specific enumeration of strings, which can be used with the QuestionBuilder#setValueEnumerator method */ object SpecialStringValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { - case StringType => - List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("\"'\n\t")), - terminal(StringLiteral("Lara 2007")) - ) - case _ => ValueGrammar.computeProductions(t) + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { + case StringType => + List( + terminal(StringLiteral("")), + terminal(StringLiteral("a")), + terminal(StringLiteral("\"'\n\t")), + terminal(StringLiteral("Lara 2007")) + ) + case _ => ValueGrammar.computeProductions(t) } } } @@ -94,11 +91,9 @@ object QuestionBuilder { * * @tparam T A subtype of Expr that will be the type used in the Question[T] results. * @param input The identifier of the unique function's input. Must be typed or the type should be defined by setArgumentType - * @param ruleApplication The set of solutions for the body of f * @param filter A function filtering which outputs should be considered for comparison. - * It takes as input the sequence of outputs already considered for comparison, and the new output. - * It should return Some(result) if the result can be shown, and None else. - * @return An ordered + * It takes as input the sequence of outputs already considered for comparison, and the new output. + * It should return Some(result) if the result can be shown, and None else. * */ class QuestionBuilder[T <: Expr]( @@ -178,25 +173,25 @@ class QuestionBuilder[T <: Expr]( /** Returns a list of input/output questions to ask to the user. */ def result(): List[Question[T]] = { if(solutions.isEmpty) return Nil - - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree,Expr]](value_enumerator.getProductions) - val values = enum.iterator(tupleTypeWrap(_argTypes)) - val instantiations = values.map(makeGenericValuesUnique _).map { - v => input.zip(unwrapTuple(v, input.size)) - } - - val enumerated_inputs = instantiations.take(expressionsToTake).toList - + + val datagen = new GrammarDataGen(new DefaultEvaluator(c, p), value_enumerator) + val enumerated_inputs = datagen.generateMapping(input, BooleanLiteral(true), expressionsToTake, expressionsToTake) + .map(inputs => + inputs.map(id_expr => + (id_expr._1, makeGenericValuesUnique(id_expr._2)))).toList + val solution = solutions.head val alternatives = solutions.drop(1).take(solutionsToTake).toList val questions = ListBuffer[Question[T]]() - for{possible_input <- enumerated_inputs - current_output_nonfiltered <- run(solution, possible_input) - current_output <- filter(Seq(), current_output_nonfiltered)} { + for { + possibleInput <- enumerated_inputs + currentOutputNonFiltered <- run(solution, possibleInput) + currentOutput <- filter(Seq(), currentOutputNonFiltered) + } { - val alternative_outputs = ((ListBuffer[T](current_output) /: alternatives) { (prev, alternative) => - run(alternative, possible_input) match { - case Some(alternative_output) if alternative_output != current_output => + val alternative_outputs = (ListBuffer[T](currentOutput) /: alternatives) { (prev, alternative) => + run(alternative, possibleInput) match { + case Some(alternative_output) if alternative_output != currentOutput => filter(prev, alternative_output) match { case Some(alternative_output_filtered) => prev += alternative_output_filtered @@ -204,11 +199,11 @@ class QuestionBuilder[T <: Expr]( } case _ => prev } - }).drop(1).toList.distinct - if(alternative_outputs.nonEmpty || keepEmptyAlternativeQuestions(current_output)) { - questions += Question(possible_input.map(_._2), current_output, alternative_outputs.sortWith((e,f) => _alternativeSortMethod.compare(e, f) <= 0)) + }.drop(1).toList.distinct + if(alternative_outputs.nonEmpty || keepEmptyAlternativeQuestions(currentOutput)) { + questions += Question(possibleInput.map(_._2), currentOutput, alternative_outputs.sortWith((e,f) => _alternativeSortMethod.compare(e, f) <= 0)) } } questions.toList.sortBy(_questionSorMethod(_)) } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala index 7da38716116f51d89e751a8aa12d709be776e17c..78ef7b371487a6711d3508b9712f7806e9c551e0 100644 --- a/src/main/scala/leon/synthesis/graph/DotGenerator.scala +++ b/src/main/scala/leon/synthesis/graph/DotGenerator.scala @@ -6,7 +6,11 @@ import leon.utils.UniqueCounter import java.io.{File, FileWriter, BufferedWriter} -class DotGenerator(g: Graph) { +class DotGenerator(search: Search) { + + implicit val ctx = search.ctx + + val g = search.g private val idCounter = new UniqueCounter[Unit] idCounter.nextGlobal // Start with 1 @@ -80,12 +84,14 @@ class DotGenerator(g: Graph) { } def nodeDesc(n: Node): String = n match { - case an: AndNode => an.ri.toString - case on: OrNode => on.p.toString + case an: AndNode => an.ri.asString + case on: OrNode => on.p.asString } def drawNode(res: StringBuffer, name: String, n: Node) { + val index = n.parent.map(_.descendants.indexOf(n) + " ").getOrElse("") + def escapeHTML(str: String) = str.replaceAll("&", "&").replaceAll("<", "<").replaceAll(">", ">") val color = if (n.isSolved) { @@ -109,10 +115,10 @@ class DotGenerator(g: Graph) { res append "<TR><TD BORDER=\"0\">"+escapeHTML(n.cost.asString)+"</TD></TR>" } - res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(nodeDesc(n)))+"</TD></TR>" + res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(index + nodeDesc(n)))+"</TD></TR>" if (n.isSolved) { - res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.toString))+"</TD></TR>" + res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.asString))+"</TD></TR>" } res append "</TABLE>>, shape = \"none\" ];\n" @@ -126,4 +132,4 @@ class DotGenerator(g: Graph) { } } -object dotGenIds extends UniqueCounter[Unit] \ No newline at end of file +object dotGenIds extends UniqueCounter[Unit] diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala index 98554a5ae492972e0b7b3915979d9af829d81555..c630e315d9777110b5dcde7adc42cf6172161af3 100644 --- a/src/main/scala/leon/synthesis/graph/Search.scala +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -10,7 +10,7 @@ import scala.collection.mutable.ArrayBuffer import leon.utils.Interruptible import java.util.concurrent.atomic.AtomicBoolean -abstract class Search(ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel) extends Interruptible { +abstract class Search(val ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel) extends Interruptible { val g = new Graph(costModel, p) def findNodeToExpandFrom(n: Node): Option[Node] diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index df2c44193412a55af004dfa7695901044a4b5b53..d3dc6347280a45642935e6ea3c314246a3cb6958 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -65,7 +65,7 @@ case object ADTSplit extends Rule("ADT Split.") { case Some((id, act, cases)) => val oas = p.as.filter(_ != id) - val subInfo = for(ccd <- cases) yield { + val subInfo0 = for(ccd <- cases) yield { val cct = CaseClassType(ccd, act.tps) val args = cct.fields.map { vd => FreshIdentifier(vd.id.name, vd.getType, true) }.toList @@ -89,6 +89,10 @@ case object ADTSplit extends Rule("ADT Split.") { (cct, subProblem, subPattern) } + val subInfo = subInfo0.sortBy{ case (cct, _, _) => + cct.fieldsTypes.count { t => t == act } + } + val onSuccess: List[Solution] => Option[Solution] = { case sols => diff --git a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala index 2f3869af16b71f9635e36d27774f55a7cee7140c..4c12f58224427c1d74654638e28965d746f93d54 100644 --- a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala +++ b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala @@ -14,7 +14,6 @@ import codegen.CodeGenParams import grammars._ import bonsai.enumerators._ -import bonsai.{Generator => Gen} case object BottomUpTEGIS extends BottomUpTEGISLike[TypeTree]("BU TEGIS") { def getGrammar(sctx: SynthesisContext, p: Problem) = { @@ -51,13 +50,13 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val nTests = tests.size - var compiled = Map[Generator[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() + var compiled = Map[ProductionRule[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() /** * Compile Generators to functions from Expr to Expr. The compiled * generators will be passed to the enumerator */ - def compile(gen: Generator[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { + def compile(gen: ProductionRule[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { compiled.getOrElse(gen, { val executor = if (gen.subTrees.isEmpty) { @@ -108,7 +107,7 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val targetType = tupleTypeWrap(p.xs.map(_.getType)) val wrappedTests = tests.map { case (is, os) => (is, tupleWrap(os))} - val enum = new BottomUpEnumerator[T, Expr, Expr, Generator[T, Expr]]( + val enum = new BottomUpEnumerator[T, Expr, Expr, ProductionRule[T, Expr]]( grammar.getProductions, wrappedTests, { (vecs, gen) => diff --git a/src/main/scala/leon/synthesis/rules/CEGIS.scala b/src/main/scala/leon/synthesis/rules/CEGIS.scala index 1fcf01d52088ea9d4d25d184a673ef8335a8d260..b0de64ed05458d22cc113170dc850e2c1e2f6a3b 100644 --- a/src/main/scala/leon/synthesis/rules/CEGIS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGIS.scala @@ -4,16 +4,31 @@ package leon package synthesis package rules -import purescala.Types._ - import grammars._ -import utils._ +import grammars.transformers._ +import purescala.Types.TypeTree -case object CEGIS extends CEGISLike[TypeTree]("CEGIS") { +/** Basic implementation of CEGIS that uses a naive grammar */ +case object NaiveCEGIS extends CEGISLike[TypeTree]("Naive CEGIS") { def getParams(sctx: SynthesisContext, p: Problem) = { CegisParams( grammar = Grammars.typeDepthBound(Grammars.default(sctx, p), 2), // This limits type depth - rootLabel = {(tpe: TypeTree) => tpe } + rootLabel = {(tpe: TypeTree) => tpe }, + optimizations = false + ) + } +} + +/** More advanced implementation of CEGIS that uses a less permissive grammar + * and some optimizations + */ +case object CEGIS extends CEGISLike[TaggedNonTerm[TypeTree]]("CEGIS") { + def getParams(sctx: SynthesisContext, p: Problem) = { + val base = NaiveCEGIS.getParams(sctx,p).grammar + CegisParams( + grammar = TaggedGrammar(base), + rootLabel = TaggedNonTerm(_, Tags.Top, 0, None), + optimizations = true ) } } diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index d577f7f9fe1f260f4af9ca4d3cb20ca868e37fcc..291e485d70b80b580095a484e02448166de9e18c 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -4,10 +4,6 @@ package leon package synthesis package rules -import leon.utils.SeqUtils -import solvers._ -import grammars._ - import purescala.Expressions._ import purescala.Common._ import purescala.Definitions._ @@ -16,44 +12,59 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Constructors._ -import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} +import solvers._ +import grammars._ +import grammars.transformers._ +import leon.utils._ import evaluators._ import datagen._ import codegen.CodeGenParams +import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} + abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case class CegisParams( grammar: ExpressionGrammar[T], rootLabel: TypeTree => T, - maxUnfoldings: Int = 5 + optimizations: Boolean, + maxSize: Option[Int] = None ) def getParams(sctx: SynthesisContext, p: Problem): CegisParams def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + val exSolverTo = 2000L val cexSolverTo = 2000L - // Track non-deterministic programs up to 10'000 programs, or give up + // Track non-deterministic programs up to 100'000 programs, or give up val nProgramsLimit = 100000 val sctx = hctx.sctx val ctx = sctx.context + val timers = ctx.timers.synthesis.cegis + // CEGIS Flags to activate or deactivate features - val useOptTimeout = sctx.settings.cegisUseOptTimeout.getOrElse(true) - val useVanuatoo = sctx.settings.cegisUseVanuatoo.getOrElse(false) + val useOptTimeout = sctx.settings.cegisUseOptTimeout + val useVanuatoo = sctx.settings.cegisUseVanuatoo // Limits the number of programs CEGIS will specifically validate individually val validateUpTo = 3 + val passingRatio = 10 val interruptManager = sctx.context.interruptManager val params = getParams(sctx, p) - if (params.maxUnfoldings == 0) { + // If this CEGISLike forces a maxSize, take it, otherwise find it in the settings + val maxSize = params.maxSize.getOrElse(sctx.settings.cegisMaxSize) + + ctx.reporter.debug(s"This is $name. Settings: optimizations = ${params.optimizations}, maxSize = $maxSize, vanuatoo=$useVanuatoo, optTimeout=$useOptTimeout") + + if (maxSize == 0) { return Nil } @@ -61,13 +72,13 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { private var termSize = 0 - val grammar = SizeBoundedGrammar(params.grammar) + val grammar = SizeBoundedGrammar(params.grammar, params.optimizations) - def rootLabel = SizedLabel(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) + def rootLabel = SizedNonTerm(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) - var nAltsCache = Map[SizedLabel[T], Int]() + var nAltsCache = Map[SizedNonTerm[T], Int]() - def countAlternatives(l: SizedLabel[T]): Int = { + def countAlternatives(l: SizedNonTerm[T]): Int = { if (!(nAltsCache contains l)) { val count = grammar.getProductions(l).map { gen => gen.subTrees.map(countAlternatives).product @@ -91,18 +102,18 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { * b3 => c6 == H(c4, c5) * * c1 -> Seq( - * (b1, F(c2, c3), Set(c2, c3)) - * (b2, G(c4, c5), Set(c4, c5)) + * (b1, F(_, _), Seq(c2, c3)) + * (b2, G(_, _), Seq(c4, c5)) * ) * c6 -> Seq( - * (b3, H(c7, c8), Set(c7, c8)) + * (b3, H(_, _), Seq(c7, c8)) * ) */ private var cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]] = Map() // C identifiers corresponding to p.xs - private var rootC: Identifier = _ + private var rootC: Identifier = _ private var bs: Set[Identifier] = Set() @@ -110,19 +121,19 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { class CGenerator { - private var buffers = Map[SizedLabel[T], Stream[Identifier]]() + private var buffers = Map[SizedNonTerm[T], Stream[Identifier]]() - private var slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + private var slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) - private def streamOf(t: SizedLabel[T]): Stream[Identifier] = Stream.continually( + private def streamOf(t: SizedNonTerm[T]): Stream[Identifier] = Stream.continually( FreshIdentifier(t.asString, t.getType, true) ) def rewind(): Unit = { - slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) } - def getNext(t: SizedLabel[T]) = { + def getNext(t: SizedNonTerm[T]) = { if (!(buffers contains t)) { buffers += t -> streamOf(t) } @@ -140,13 +151,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { def updateCTree(): Unit = { + ctx.timers.synthesis.cegis.updateCTree.start() def freshB() = { val id = FreshIdentifier("B", BooleanType, true) bs += id id } - def defineCTreeFor(l: SizedLabel[T], c: Identifier): Unit = { + def defineCTreeFor(l: SizedNonTerm[T], c: Identifier): Unit = { if (!(cTree contains c)) { val cGen = new CGenerator() @@ -182,11 +194,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.ifDebug { printer => printer("Grammar so far:") grammar.printProductions(printer) + printer("") } bsOrdered = bs.toSeq.sorted + excludedPrograms = ArrayBuffer() setCExpr(computeCExpr()) + ctx.timers.synthesis.cegis.updateCTree.stop() } /** @@ -233,9 +248,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { cache(c) } - SeqUtils.cartesianProduct(seqs).map { ls => - ls.foldLeft(Set[Identifier]())(_ ++ _) - } + SeqUtils.cartesianProduct(seqs).map(_.flatten.toSet) } allProgramsFor(Seq(rootC)) @@ -287,7 +300,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case (e, alt) => IfExpr(alt._1.toVariable, exprOf(alt), e) } } else { - Error(c.getType, "Impossibru") + Error(c.getType, s"Empty production rule: $c") } cToFd(c).fullBody = body @@ -325,11 +338,10 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { solFd.fullBody = Ensuring( FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), - Lambda(p.xs.map(ValDef(_)), p.phi) + Lambda(p.xs.map(ValDef), p.phi) ) - - phiFd.body = Some( + phiFd.body = Some( letTuple(p.xs, FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), p.phi) @@ -373,46 +385,56 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { private val innerPhi = outerExprToInnerExpr(p.phi) private var programCTree: Program = _ - private var tester: (Example, Set[Identifier]) => EvaluationResults.Result[Expr] = _ + + private var evaluator: DefaultEvaluator = _ private def setCExpr(cTreeInfo: (Expr, Seq[FunDef])): Unit = { val (cTree, newFds) = cTreeInfo cTreeFd.body = Some(cTree) programCTree = addFunDefs(innerProgram, newFds, cTreeFd) + evaluator = new DefaultEvaluator(sctx.context, programCTree) //println("-- "*30) //println(programCTree.asString) //println(".. "*30) + } - //val evaluator = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default) - val evaluator = new DefaultEvaluator(sctx.context, programCTree) - - tester = - { (ex: Example, bValues: Set[Identifier]) => - // TODO: Test output value as well - val envMap = bs.map(b => b -> BooleanLiteral(bValues(b))).toMap - - ex match { - case InExample(ins) => - val fi = FunctionInvocation(phiFd.typed, ins) - evaluator.eval(fi, envMap) + def testForProgram(bValues: Set[Identifier])(ex: Example): Boolean = { - case InOutExample(ins, outs) => - val fi = FunctionInvocation(cTreeFd.typed, ins) - val eq = equality(fi, tupleWrap(outs)) - evaluator.eval(eq, envMap) - } - } - } + val origImpl = cTreeFd.fullBody + val outerSol = getExpr(bValues) + val innerSol = outerExprToInnerExpr(outerSol) + val cnstr = letTuple(p.xs, innerSol, innerPhi) + cTreeFd.fullBody = innerSol + + timers.testForProgram.start() + val res = ex match { + case InExample(ins) => + evaluator.eval(cnstr, p.as.zip(ins).toMap) + + case InOutExample(ins, outs) => + val eq = equality(innerSol, tupleWrap(outs)) + evaluator.eval(eq, p.as.zip(ins).toMap) + } + timers.testForProgram.stop() + cTreeFd.fullBody = origImpl - def testForProgram(bValues: Set[Identifier])(ex: Example): Boolean = { - tester(ex, bValues) match { + res match { case EvaluationResults.Successful(res) => res == BooleanLiteral(true) case EvaluationResults.RuntimeError(err) => + /*if (err.contains("Empty production rule")) { + println(programCTree.asString) + println(bValues) + println(ex) + println(this.getExpr(bValues)) + (new Throwable).printStackTrace() + println(err) + println() + }*/ sctx.reporter.debug("RE testing CE: "+err) false @@ -420,18 +442,18 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.debug("Error testing CE: "+err) false } - } - + } // Returns the outer expression corresponding to a B-valuation def getExpr(bValues: Set[Identifier]): Expr = { + def getCValue(c: Identifier): Expr = { cTree(c).find(i => bValues(i._1)).map { case (b, builder, cs) => builder(cs.map(getCValue)) }.getOrElse { - simplestValue(c.getType) + Error(c.getType, "Impossible assignment of bs") } } @@ -445,60 +467,70 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { def validatePrograms(bss: Set[Set[Identifier]]): Either[Stream[Solution], Seq[Seq[Expr]]] = { val origImpl = cTreeFd.fullBody - val cexs = for (bs <- bss.toSeq) yield { + var cexs = Seq[Seq[Expr]]() + + for (bs <- bss.toSeq) { val outerSol = getExpr(bs) val innerSol = outerExprToInnerExpr(outerSol) - + //println(s"Testing $outerSol") cTreeFd.fullBody = innerSol val cnstr = and(innerPc, letTuple(p.xs, innerSol, Not(innerPhi))) - //println("Solving for: "+cnstr.asString) + val eval = new DefaultEvaluator(ctx, innerProgram) - val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) - val solver = solverf.getNewSolver() - try { - solver.assertCnstr(cnstr) - solver.check match { - case Some(true) => - excludeProgram(bs, true) - val model = solver.getModel - //println("Found counter example: ") - //for ((s, v) <- model) { - // println(" "+s.asString+" -> "+v.asString) - //} - - //val evaluator = new DefaultEvaluator(ctx, prog) - //println(evaluator.eval(cnstr, model)) - - Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))) - - case Some(false) => - // UNSAT, valid program - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) + if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) { + //println(s"Program $outerSol fails!") + excludeProgram(bs, true) + cTreeFd.fullBody = origImpl + } else { + //println("Solving for: "+cnstr.asString) + + val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) + val solver = solverf.getNewSolver() + try { + solver.assertCnstr(cnstr) + solver.check match { + case Some(true) => + excludeProgram(bs, true) + val model = solver.getModel + //println("Found counter example: ") + //for ((s, v) <- model) { + // println(" "+s.asString+" -> "+v.asString) + //} + + //val evaluator = new DefaultEvaluator(ctx, prog) + //println(evaluator.eval(cnstr, model)) + //println(s"Program $outerSol fails with cex ${p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))}") + cexs +:= p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) + + case Some(false) => + // UNSAT, valid program + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) - case None => - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - // Optimistic valid solution - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) - } else { - None - } + case None => + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") + // Optimistic valid solution + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) + } + } + } finally { + solverf.reclaim(solver) + solverf.shutdown() + cTreeFd.fullBody = origImpl } - } finally { - solverf.reclaim(solver) - solverf.shutdown() - cTreeFd.fullBody = origImpl } } - Right(cexs.flatten) + Right(cexs) } var excludedPrograms = ArrayBuffer[Set[Identifier]]() + def allProgramsClosed = allProgramsCount() <= excludedPrograms.size + // Explicitly remove program computed by bValues from the search space // // If the bValues comes from models, we make sure the bValues we exclude @@ -542,9 +574,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { //println(" --- Constraints ---") //println(" - "+toFind.asString) try { - //TODO: WHAT THE F IS THIS? - //val bsOrNotBs = andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable))) - //solver.assertCnstr(bsOrNotBs) solver.assertCnstr(toFind) for ((c, alts) <- cTree) { @@ -660,9 +689,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ndProgram.init() var unfolding = 1 - val maxUnfoldings = params.maxUnfoldings - - sctx.reporter.debug(s"maxUnfoldings=$maxUnfoldings") var baseExampleInputs: ArrayBuffer[Example] = new ArrayBuffer[Example]() @@ -670,7 +696,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ndProgram.grammar.printProductions(printer) } - // We populate the list of examples with a predefined one + // We populate the list of examples with a defined one sctx.reporter.debug("Acquiring initial list of examples") baseExampleInputs ++= p.eb.examples @@ -708,7 +734,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } - /** * We generate tests for discarding potential programs */ @@ -738,8 +763,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { try { do { - var skipCESearch = false - // Unfold formula ndProgram.unfold() @@ -748,6 +771,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { val nInitial = prunedPrograms.size sctx.reporter.debug("#Programs: "+nInitial) + //sctx.reporter.ifDebug{ printer => // val limit = 100 @@ -764,34 +788,33 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { // We further filter the set of working programs to remove those that fail on known examples if (hasInputExamples) { + timers.filter.start() for (bs <- prunedPrograms if !interruptManager.isInterrupted) { - var valid = true val examples = allInputExamples() - while(valid && examples.hasNext) { - val e = examples.next() - if (!ndProgram.testForProgram(bs)(e)) { - failedTestsStats(e) += 1 - sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}") - wrongPrograms += bs - prunedPrograms -= bs - - valid = false - } + examples.find(e => !ndProgram.testForProgram(bs)(e)).foreach { e => + failedTestsStats(e) += 1 + sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}") + wrongPrograms += bs + prunedPrograms -= bs } if (wrongPrograms.size+1 % 1000 == 0) { sctx.reporter.debug("..."+wrongPrograms.size) } } + timers.filter.stop() } val nPassing = prunedPrograms.size - sctx.reporter.debug("#Programs passing tests: "+nPassing) + val nTotal = ndProgram.allProgramsCount() + //println(s"Iotal: $nTotal, passing: $nPassing") + + sctx.reporter.debug(s"#Programs passing tests: $nPassing out of $nTotal") sctx.reporter.ifDebug{ printer => - for (p <- prunedPrograms.take(10)) { + for (p <- prunedPrograms.take(100)) { printer(" - "+ndProgram.getExpr(p).asString) } - if(nPassing > 10) { + if(nPassing > 100) { printer(" - ...") } } @@ -805,94 +828,86 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } + // We can skip CE search if - we have excluded all programs or - we do so with validatePrograms + var skipCESearch = nPassing == 0 || interruptManager.isInterrupted || { + // If the number of pruned programs is very small, or by far smaller than the number of total programs, + // we hypothesize it will be easier to just validate them individually. + // Otherwise, we validate a small number of programs just in case we are lucky FIXME is this last clause useful? + val (programsToValidate, otherPrograms) = if (nTotal / nPassing > passingRatio || nPassing < 10) { + (prunedPrograms, Nil) + } else { + prunedPrograms.splitAt(validateUpTo) + } - if (nPassing == 0 || interruptManager.isInterrupted) { - // No test passed, we can skip solver and unfold again, if possible - skipCESearch = true - } else { - var doFilter = true - - if (validateUpTo > 0) { - // Validate the first N programs individualy - ndProgram.validatePrograms(prunedPrograms.take(validateUpTo)) match { - case Left(sols) if sols.nonEmpty => - doFilter = false - result = Some(RuleClosed(sols)) - case Right(cexs) => - baseExampleInputs ++= cexs.map(InExample) - - if (nPassing <= validateUpTo) { - // All programs failed verification, we filter everything out and unfold - doFilter = false - skipCESearch = true + ndProgram.validatePrograms(programsToValidate) match { + case Left(sols) if sols.nonEmpty => + // Found solution! Exit CEGIS + result = Some(RuleClosed(sols)) + true + case Right(cexs) => + // Found some counterexamples + val newCexs = cexs.map(InExample) + baseExampleInputs ++= newCexs + // Retest whether the newly found C-E invalidates some programs + for (p <- otherPrograms if !interruptManager.isInterrupted) { + // Exclude any programs that fail at least one new cex + newCexs.find { cex => !ndProgram.testForProgram(p)(cex) }.foreach { cex => + failedTestsStats(cex) += 1 + ndProgram.excludeProgram(p, true) } - } + } + // If we excluded all programs, we can skip CE search + programsToValidate.size >= nPassing } + } - if (doFilter) { - sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") - wrongPrograms.foreach { - ndProgram.excludeProgram(_, true) - } + if (!skipCESearch) { + sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") + wrongPrograms.foreach { + ndProgram.excludeProgram(_, true) } } // CEGIS Loop at a given unfolding level - while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted) { + while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted && !ndProgram.allProgramsClosed) { + timers.loop.start() ndProgram.solveForTentativeProgram() match { case Some(Some(bs)) => - // Should we validate this program with Z3? - - val validateWithZ3 = if (hasInputExamples) { - - if (allInputExamples().forall(ndProgram.testForProgram(bs))) { - // All valid inputs also work with this, we need to - // make sure by validating this candidate with z3 - true - } else { - println("testing failed ?!") - // One valid input failed with this candidate, we can skip + // No inputs to test or all valid inputs also work with this. + // We need to make sure by validating this candidate with z3 + sctx.reporter.debug("Found tentative model, need to validate!") + ndProgram.solveForCounterExample(bs) match { + case Some(Some(inputsCE)) => + sctx.reporter.debug("Found counter-example:" + inputsCE) + val ce = InExample(inputsCE) + // Found counter example! Exclude this program + baseExampleInputs += ce ndProgram.excludeProgram(bs, false) - false - } - } else { - // No inputs or capability to test, we need to ask Z3 - true - } - sctx.reporter.debug("Found tentative model (Validate="+validateWithZ3+")!") - - if (validateWithZ3) { - ndProgram.solveForCounterExample(bs) match { - case Some(Some(inputsCE)) => - sctx.reporter.debug("Found counter-example:"+inputsCE) - val ce = InExample(inputsCE) - // Found counter example! - baseExampleInputs += ce - - // Retest whether the newly found C-E invalidates all programs - if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(ce))) { - skipCESearch = true - } else { - ndProgram.excludeProgram(bs, false) - } - - case Some(None) => - // Found no counter example! Program is a valid solution + + // Retest whether the newly found C-E invalidates some programs + prunedPrograms.foreach { p => + if (!ndProgram.testForProgram(p)(ce)) ndProgram.excludeProgram(p, true) + } + + case Some(None) => + // Found no counter example! Program is a valid solution + val expr = ndProgram.getExpr(bs) + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) + + case None => + // We are not sure + sctx.reporter.debug("Unknown") + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) - - case None => - // We are not sure - sctx.reporter.debug("Unknown") - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) - } else { - result = Some(RuleFailed()) - } - } + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) + } else { + // Ok, we failed to validate, exclude this program + ndProgram.excludeProgram(bs, false) + // TODO: Make CEGIS fail early when it fails on 1 program? + // result = Some(RuleFailed()) + } } case Some(None) => @@ -901,11 +916,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case None => result = Some(RuleFailed()) } + + timers.loop.stop() } unfolding += 1 - } while(unfolding <= maxUnfoldings && result.isEmpty && !interruptManager.isInterrupted) + } while(unfolding <= maxSize && result.isEmpty && !interruptManager.isInterrupted) + if (interruptManager.isInterrupted) interruptManager.recoverInterrupt() result.getOrElse(RuleFailed()) } catch { diff --git a/src/main/scala/leon/synthesis/rules/CEGLESS.scala b/src/main/scala/leon/synthesis/rules/CEGLESS.scala index c12edac075bc8525d395d5f792ef4579c0d109f1..36cc7f9e65dae8af9d8c17d4db936dc4400c0ece 100644 --- a/src/main/scala/leon/synthesis/rules/CEGLESS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGLESS.scala @@ -4,10 +4,10 @@ package leon package synthesis package rules +import leon.grammars.transformers.Union import purescala.ExprOps._ import purescala.Types._ import purescala.Extractors._ -import utils._ import grammars._ import Witnesses._ @@ -24,7 +24,7 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { val inputs = p.as.map(_.toVariable) sctx.reporter.ifDebug { printer => - printer("Guides available:") + printer("Guides available:") for (g <- guides) { printer(" - "+g.asString(ctx)) } @@ -35,7 +35,8 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { CegisParams( grammar = guidedGrammar, rootLabel = { (tpe: TypeTree) => NonTerminal(tpe, "G0") }, - maxUnfoldings = (0 +: guides.map(depth(_) + 1)).max + optimizations = false, + maxSize = Some((0 +: guides.map(depth(_) + 1)).max) ) } } diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index 2ae2b1d5d0292a6ed725055e61b1b4af4100a63c..d3b4c823dd7110763316d121407bcf94820c5826 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -83,7 +83,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { } } - var eb = p.qeb.mapIns { info => + val eb = p.qeb.mapIns { info => List(info.flatMap { case (id, v) => ebMapInfo.get(id) match { case Some(m) => @@ -103,7 +103,8 @@ case object DetupleInput extends NormalizingRule("Detuple In") { case CaseClass(ct, args) => val (cts, es) = args.zip(ct.fields).map { case (CaseClassSelector(ct, e, id), field) if field.id == id => (ct, e) - case _ => return e + case _ => + return e }.unzip if (cts.distinct.size == 1 && es.distinct.size == 1) { @@ -126,7 +127,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb) - val s = {substAll(reverseMap, _:Expr)} andThen { simplePostTransform(recompose) } + val s = (substAll(reverseMap, _:Expr)) andThen simplePostTransform(recompose) Some(decomp(List(sub), forwardMap(s), s"Detuple ${reverseMap.keySet.mkString(", ")}")) } else { diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index 05a49a9064142bf537c3789ff28df948511f8dea..62878dd31761050b839f794beb066bceca051a56 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -7,35 +7,32 @@ package rules import scala.annotation.tailrec import scala.collection.mutable.ListBuffer import bonsai.enumerators.MemoizedEnumerator -import leon.evaluators.DefaultEvaluator -import leon.evaluators.AbstractEvaluator -import leon.synthesis.programsets.DirectProgramSet -import leon.synthesis.programsets.JoinProgramSet -import leon.purescala.Common.FreshIdentifier -import leon.purescala.Common.Identifier -import leon.purescala.DefOps -import leon.purescala.Definitions.FunDef -import leon.purescala.Definitions.FunDef -import leon.purescala.Definitions.ValDef -import leon.purescala.ExprOps -import leon.solvers.Model -import leon.solvers.ModelBuilder -import leon.solvers.string.StringSolver -import leon.utils.DebugSectionSynthesis +import evaluators.DefaultEvaluator +import evaluators.AbstractEvaluator +import purescala.Definitions.{FunDef, ValDef, Program, TypedFunDef, CaseClassDef, AbstractClassDef} +import purescala.Common._ +import purescala.Types._ import purescala.Constructors._ -import purescala.Definitions._ -import purescala.ExprOps._ import purescala.Expressions._ import purescala.Extractors._ import purescala.TypeOps -import purescala.Types._ -import leon.purescala.SelfPrettyPrinter +import purescala.DefOps +import purescala.ExprOps +import purescala.SelfPrettyPrinter +import solvers.Model +import solvers.ModelBuilder +import solvers.string.StringSolver +import synthesis.programsets.DirectProgramSet +import synthesis.programsets.JoinProgramSet +import leon.utils.DebugSectionSynthesis + + /** A template generator for a given type tree. * Extend this class using a concrete type tree, * Then use the apply method to get a hole which can be a placeholder for holes in the template. - * Each call to the ``.instantiate` method of the subsequent Template will provide different instances at each position of the hole. + * Each call to the `.instantiate` method of the subsequent Template will provide different instances at each position of the hole. */ abstract class TypedTemplateGenerator(t: TypeTree) { import StringRender.WithIds diff --git a/src/main/scala/leon/synthesis/rules/TEGIS.scala b/src/main/scala/leon/synthesis/rules/TEGIS.scala index d7ec34617ee7dc50745c3b6839511e2c00a6037e..3d496d0597e1947af0eb83504be5af449d7854f1 100644 --- a/src/main/scala/leon/synthesis/rules/TEGIS.scala +++ b/src/main/scala/leon/synthesis/rules/TEGIS.scala @@ -6,7 +6,6 @@ package rules import purescala.Types._ import grammars._ -import utils._ case object TEGIS extends TEGISLike[TypeTree]("TEGIS") { def getParams(sctx: SynthesisContext, p: Problem) = { diff --git a/src/main/scala/leon/synthesis/rules/TEGISLike.scala b/src/main/scala/leon/synthesis/rules/TEGISLike.scala index 91084ae4f6d69d055c36f0ce2c75bc4b41bfa763..93e97de6f1ad97c40def5b77c0d79fbb60282633 100644 --- a/src/main/scala/leon/synthesis/rules/TEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/TEGISLike.scala @@ -12,6 +12,7 @@ import datagen._ import evaluators._ import codegen.CodeGenParams import grammars._ +import leon.utils.GrowableIterable import scala.collection.mutable.{HashMap => MutableMap} @@ -40,7 +41,7 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val nTests = if (p.pc == BooleanLiteral(true)) 50 else 20 - val useVanuatoo = sctx.settings.cegisUseVanuatoo.getOrElse(false) + val useVanuatoo = sctx.settings.cegisUseVanuatoo val inputGenerator: Iterator[Seq[Expr]] = if (useVanuatoo) { new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, nTests, 3000) @@ -53,8 +54,6 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val failedTestsStats = new MutableMap[Seq[Expr], Int]().withDefaultValue(0) - def hasInputExamples = gi.nonEmpty - var n = 1 def allInputExamples() = { if (n == 10 || n == 50 || n % 500 == 0) { @@ -64,14 +63,12 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { gi.iterator } - var tests = p.eb.valids.map(_.ins).distinct - if (gi.nonEmpty) { - val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) - val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) + val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) + val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) - val enum = new MemoizedEnumerator[T, Expr, Generator[T, Expr]](grammar.getProductions) + val enum = new MemoizedEnumerator[T, Expr, ProductionRule[T, Expr]](grammar.getProductions) val targetType = tupleTypeWrap(p.xs.map(_.getType)) @@ -80,7 +77,6 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val allExprs = enum.iterator(params.rootLabel(targetType)) var candidate: Option[Expr] = None - var n = 1 def findNext(): Option[Expr] = { candidate = None @@ -111,14 +107,9 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { candidate } - def toStream: Stream[Solution] = { - findNext() match { - case Some(e) => - Stream.cons(Solution(BooleanLiteral(true), Set(), e, isTrusted = false), toStream) - case None => - Stream.empty - } - } + val toStream = Stream.continually(findNext()).takeWhile(_.nonEmpty).map( e => + Solution(BooleanLiteral(true), Set(), e.get, isTrusted = false) + ) RuleClosed(toStream) } else { diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala index acd285a4570f93ee9dd85ba3dd29a7e4b120c25a..4bfedc4acbe59440ac7f3382c8187ae201775f02 100644 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ b/src/main/scala/leon/synthesis/utils/Helpers.scala @@ -34,7 +34,18 @@ object Helpers { } } - def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(Expr, Set[Identifier])] = { + /** Given an initial set of function calls provided by a list of [[Terminating]], + * returns function calls that will hopefully be safe to call recursively from within this initial function calls. + * + * For each returned call, one argument is substituted by a "smaller" one, while the rest are left as holes. + * + * @param prog The current program + * @param tpe The expected type for the returned function calls + * @param ws Helper predicates that contain [[Terminating]]s with the initial calls + * @param pc The path condition + * @return A list of pairs of (safe function call, holes), where holes stand for the rest of the arguments of the function. + */ + def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(FunctionInvocation, Set[Identifier])] = { val TopLevelAnds(wss) = ws val TopLevelAnds(clauses) = pc diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index 57a62b665c797b10fab2d099fabd3a722f6e7d27..3680930639a2cfba46490d4a21bab7772d7fd0c8 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -11,8 +11,13 @@ class Bijection[A, B] { b2a += b -> a } - def +=(t: (A,B)): Unit = { - this += (t._1, t._2) + def +=(t: (A,B)): this.type = { + +=(t._1, t._2) + this + } + + def ++=(t: Iterable[(A,B)]) = { + (this /: t){ case (b, elem) => b += elem } } def clear(): Unit = { @@ -22,6 +27,9 @@ class Bijection[A, B] { def getA(b: B) = b2a.get(b) def getB(a: A) = a2b.get(a) + + def getAorElse(b: B, orElse: =>A) = b2a.getOrElse(b, orElse) + def getBorElse(a: A, orElse: =>B) = a2b.getOrElse(a, orElse) def toA(b: B) = getA(b).get def toB(a: A) = getB(a).get diff --git a/src/main/scala/leon/utils/GrowableIterable.scala b/src/main/scala/leon/utils/GrowableIterable.scala index d05a9f06576a9e3748ba0ba5fdd33656cc9ac457..0b32fe6261b3bd41cf6bb8ad11fcc6161b47d44b 100644 --- a/src/main/scala/leon/utils/GrowableIterable.scala +++ b/src/main/scala/leon/utils/GrowableIterable.scala @@ -1,4 +1,4 @@ -package leon +package leon.utils import scala.collection.mutable.ArrayBuffer diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/leon/utils/InliningPhase.scala index 8053a8dc1e9956c9ef264c247d41f8d96ffbb934..17fdf48bc88d88cca9c49e63672af6f3c76afa48 100644 --- a/src/main/scala/leon/utils/InliningPhase.scala +++ b/src/main/scala/leon/utils/InliningPhase.scala @@ -5,7 +5,7 @@ package leon.utils import leon._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.TypeOps._ +import purescala.TypeOps.instantiateType import purescala.ExprOps._ import purescala.DefOps._ import purescala.Constructors.caseClassSelector diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/leon/utils/SeqUtils.scala index 002f2ebedc8a6dfb265fbf101c2185b3bfa17ce1..f2290a68d11bc668c348af3954af16b95f0f7d88 100644 --- a/src/main/scala/leon/utils/SeqUtils.scala +++ b/src/main/scala/leon/utils/SeqUtils.scala @@ -34,7 +34,10 @@ object SeqUtils { } def sumTo(sum: Int, arity: Int): Seq[Seq[Int]] = { - if (arity == 1) { + require(arity >= 1) + if (sum < arity) { + Nil + } else if (arity == 1) { Seq(Seq(sum)) } else { (1 until sum).flatMap{ n => @@ -42,6 +45,20 @@ object SeqUtils { } } } + + def sumToOrdered(sum: Int, arity: Int): Seq[Seq[Int]] = { + def rec(sum: Int, arity: Int): Seq[Seq[Int]] = { + require(arity > 0) + if (sum < 0) Nil + else if (arity == 1) Seq(Seq(sum)) + else for { + n <- 0 to sum / arity + rest <- rec(sum - arity * n, arity - 1) + } yield n +: rest.map(n + _) + } + + rec(sum, arity) filterNot (_.head == 0) + } } class CartesianView[+A](views: Seq[SeqView[A, Seq[A]]]) extends SeqView[Seq[A], Seq[Seq[A]]] { diff --git a/src/main/scala/leon/utils/UniqueCounter.scala b/src/main/scala/leon/utils/UniqueCounter.scala index 06a6c0bb4b1badd63df38c3285c5fd8514d249fb..7c7862747271a67d899b9a590bc2d9c5fbb7de40 100644 --- a/src/main/scala/leon/utils/UniqueCounter.scala +++ b/src/main/scala/leon/utils/UniqueCounter.scala @@ -17,4 +17,5 @@ class UniqueCounter[K] { globalId } + def current = nameIds } diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index f4f603393728dd5a7b748c990486533b1cd18db6..45fa8bea46c71643c68a5a69f5f18e9318c4c449 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -125,7 +125,7 @@ object UnitElimination extends TransformationPhase { } } - LetDef(newFds, rest) + letDef(newFds, rest) } case ite@IfExpr(cond, tExpr, eExpr) => diff --git a/src/main/scala/leon/verification/InjectAsserts.scala b/src/main/scala/leon/verification/InjectAsserts.scala index 4e126827bd6cf352692c43e8433857b8894615d4..1bd9a695788877bd2a0034ec151294daddd5ab59 100644 --- a/src/main/scala/leon/verification/InjectAsserts.scala +++ b/src/main/scala/leon/verification/InjectAsserts.scala @@ -8,7 +8,6 @@ import Expressions._ import ExprOps._ import Definitions._ import Constructors._ -import xlang.Expressions._ object InjectAsserts extends SimpleLeonPhase[Program, Program] { diff --git a/src/main/scala/leon/xlang/AntiAliasingPhase.scala b/src/main/scala/leon/xlang/AntiAliasingPhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..7eb391088ea64c276aa0bbf7836e4a9511aa978c --- /dev/null +++ b/src/main/scala/leon/xlang/AntiAliasingPhase.scala @@ -0,0 +1,383 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ +package leon.xlang + +import leon.TransformationPhase +import leon.LeonContext +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.Expressions._ +import leon.purescala.ExprOps._ +import leon.purescala.DefOps._ +import leon.purescala.Types._ +import leon.purescala.Constructors._ +import leon.purescala.Extractors._ +import leon.xlang.Expressions._ + +object AntiAliasingPhase extends TransformationPhase { + + val name = "Anti-Aliasing" + val description = "Make aliasing explicit" + + override def apply(ctx: LeonContext, pgm: Program): Program = { + val fds = allFunDefs(pgm) + fds.foreach(fd => checkAliasing(fd)(ctx)) + + var updatedFunctions: Map[FunDef, FunDef] = Map() + + val effects = effectsAnalysis(pgm) + + //for each fun def, all the vars the the body captures. Only + //mutable types. + val varsInScope: Map[FunDef, Set[Identifier]] = (for { + fd <- fds + } yield { + val allFreeVars = fd.body.map(bd => variablesOf(bd)).getOrElse(Set()) + val freeVars = allFreeVars -- fd.params.map(_.id) + val mutableFreeVars = freeVars.filter(id => id.getType.isInstanceOf[ArrayType]) + (fd, mutableFreeVars) + }).toMap + + /* + * The first pass will introduce all new function definitions, + * so that in the next pass we can update function invocations + */ + for { + fd <- fds + } { + updatedFunctions += (fd -> updateFunDef(fd, effects)(ctx)) + } + + for { + fd <- fds + } { + updateBody(fd, effects, updatedFunctions, varsInScope)(ctx) + } + + val res = replaceFunDefs(pgm)(fd => updatedFunctions.get(fd), (fi, fd) => None) + //println(res._1) + res._1 + } + + /* + * Create a new FunDef for a given FunDef in the program. + * Adapt the signature to express its effects. In case the + * function has no effect, this will still introduce a fresh + * FunDef as the body might have to be updated anyway. + */ + private def updateFunDef(fd: FunDef, effects: Effects)(ctx: LeonContext): FunDef = { + + val ownEffects = effects(fd) + val aliasedParams: Seq[Identifier] = fd.params.zipWithIndex.flatMap{ + case (vd, i) if ownEffects.contains(i) => Some(vd) + case _ => None + }.map(_.id) + + fd.body.foreach(body => getReturnedExpr(body).foreach{ + case v@Variable(id) if aliasedParams.contains(id) => + ctx.reporter.fatalError(v.getPos, "Cannot return a shared reference to a mutable object") + case _ => () + }) + //val allBodies: Set[Expr] = + // fd.body.toSet.flatMap((bd: Expr) => nestedFunDefsOf(bd).flatMap(_.body)) ++ fd.body + //allBodies.foreach(body => getReturnedExpr(body).foreach{ + // case v@Variable(id) if aliasedParams.contains(id) => + // ctx.reporter.fatalError(v.getPos, "Cannot return a shared reference to a mutable object: "k+ v) + // case _ => () + //}) + + val newReturnType: TypeTree = if(aliasedParams.isEmpty) fd.returnType else { + tupleTypeWrap(fd.returnType +: aliasedParams.map(_.getType)) + } + val newFunDef = new FunDef(fd.id.freshen, fd.tparams, fd.params, newReturnType) + newFunDef.addFlags(fd.flags) + newFunDef.setPos(fd) + newFunDef + } + + private def updateBody(fd: FunDef, effects: Effects, updatedFunDefs: Map[FunDef, FunDef], varsInScope: Map[FunDef, Set[Identifier]]) + (ctx: LeonContext): Unit = { + + val ownEffects = effects(fd) + val aliasedParams: Seq[Identifier] = fd.params.zipWithIndex.flatMap{ + case (vd, i) if ownEffects.contains(i) => Some(vd) + case _ => None + }.map(_.id) + + val newFunDef = updatedFunDefs(fd) + + if(aliasedParams.isEmpty) { + val newBody = fd.body.map(body => { + makeSideEffectsExplicit(body, Seq(), effects, updatedFunDefs, varsInScope)(ctx) + }) + newFunDef.body = newBody + newFunDef.precondition = fd.precondition + newFunDef.postcondition = fd.postcondition + } else { + val freshLocalVars: Seq[Identifier] = aliasedParams.map(v => v.freshen) + val rewritingMap: Map[Identifier, Identifier] = aliasedParams.zip(freshLocalVars).toMap + + val newBody = fd.body.map(body => { + + val freshBody = replaceFromIDs(rewritingMap.map(p => (p._1, p._2.toVariable)), body) + val explicitBody = makeSideEffectsExplicit(freshBody, freshLocalVars, effects, updatedFunDefs, varsInScope)(ctx) + + //WARNING: only works if side effects in Tuples are extracted from left to right, + // in the ImperativeTransformation phase. + val finalBody: Expr = Tuple(explicitBody +: freshLocalVars.map(_.toVariable)) + + freshLocalVars.zip(aliasedParams).foldLeft(finalBody)((bd, vp) => { + LetVar(vp._1, Variable(vp._2), bd) + }) + + }) + + val newPostcondition = fd.postcondition.map(post => { + val Lambda(Seq(res), postBody) = post + val newRes = ValDef(FreshIdentifier(res.id.name, newFunDef.returnType)) + val newBody = + replace( + aliasedParams.zipWithIndex.map{ case (id, i) => + (id.toVariable, TupleSelect(newRes.toVariable, i+2)): (Expr, Expr)}.toMap ++ + aliasedParams.map(id => + (Old(id), id.toVariable): (Expr, Expr)).toMap + + (res.toVariable -> TupleSelect(newRes.toVariable, 1)), + postBody) + Lambda(Seq(newRes), newBody).setPos(post) + }) + + newFunDef.body = newBody + newFunDef.precondition = fd.precondition + newFunDef.postcondition = newPostcondition + } + } + + //We turn all local val of mutable objects into vars and explicit side effects + //using assignments. We also make sure that no aliasing is being done. + private def makeSideEffectsExplicit + (body: Expr, aliasedParams: Seq[Identifier], effects: Effects, updatedFunDefs: Map[FunDef, FunDef], varsInScope: Map[FunDef, Set[Identifier]]) + (ctx: LeonContext): Expr = { + preMapWithContext[Set[Identifier]]((expr, bindings) => expr match { + + case up@ArrayUpdate(a, i, v) => { + val ra@Variable(id) = a + if(bindings.contains(id)) + (Some(Assignment(id, ArrayUpdated(ra, i, v).setPos(up)).setPos(up)), bindings) + else + (None, bindings) + } + + case l@Let(id, IsTyped(v, ArrayType(_)), b) => { + val varDecl = LetVar(id, v, b).setPos(l) + (Some(varDecl), bindings + id) + } + + case l@LetVar(id, IsTyped(v, ArrayType(_)), b) => { + (None, bindings + id) + } + + //we need to replace local fundef by the new updated fun defs. + case l@LetDef(fds, body) => { + //this might be traversed several time in case of doubly nested fundef, + //so we need to ignore the second times by checking if updatedFunDefs + //contains a mapping or not + val nfds = fds.map(fd => updatedFunDefs.get(fd).getOrElse(fd)) + (Some(LetDef(nfds, body)), bindings) + } + + case fi@FunctionInvocation(fd, args) => { + + val vis: Set[Identifier] = varsInScope.get(fd.fd).getOrElse(Set()) + args.find({ + case Variable(id) => vis.contains(id) + case _ => false + }).foreach(aliasedArg => + ctx.reporter.fatalError(aliasedArg.getPos, "Illegal passing of aliased parameter: " + aliasedArg)) + + updatedFunDefs.get(fd.fd) match { + case None => (None, bindings) + case Some(nfd) => { + val nfi = FunctionInvocation(nfd.typed(fd.tps), args).setPos(fi) + val fiEffects = effects.getOrElse(fd.fd, Set()) + if(fiEffects.nonEmpty) { + val modifiedArgs: Seq[Variable] = + args.zipWithIndex.filter{ case (arg, i) => fiEffects.contains(i) } + .map(_._1.asInstanceOf[Variable]) + + val duplicatedParams = modifiedArgs.diff(modifiedArgs.distinct).distinct + if(duplicatedParams.nonEmpty) + ctx.reporter.fatalError(fi.getPos, "Illegal passing of aliased parameter: " + duplicatedParams.head) + + val freshRes = FreshIdentifier("res", nfd.returnType) + + val extractResults = Block( + modifiedArgs.zipWithIndex.map(p => Assignment(p._1.id, TupleSelect(freshRes.toVariable, p._2 + 2))), + TupleSelect(freshRes.toVariable, 1)) + + + val newExpr = Let(freshRes, nfi, extractResults) + (Some(newExpr), bindings) + } else { + (Some(nfi), bindings) + } + } + } + } + + case _ => (None, bindings) + + })(body, aliasedParams.toSet) + } + + //TODO: in the future, any object with vars could be aliased and so + // we will need a general property + + private type Effects = Map[FunDef, Set[Int]] + + /* + * compute effects for each function in the program, including any nested + * functions (LetDef). + */ + private def effectsAnalysis(pgm: Program): Effects = { + + //currently computed effects (incremental) + var effects: Map[FunDef, Set[Int]] = Map() + //missing dependencies for a function for its effects to be complete + var missingEffects: Map[FunDef, Set[FunctionInvocation]] = Map() + + def effectsFullyComputed(fd: FunDef): Boolean = !missingEffects.isDefinedAt(fd) + + for { + fd <- allFunDefs(pgm) + } { + fd.body match { + case None => + effects += (fd -> Set()) + case Some(body) => { + val mutableParams = fd.params.filter(vd => vd.getType match { + case ArrayType(_) => true + case _ => false + }) + val mutatedParams = mutableParams.filter(vd => exists { + case ArrayUpdate(Variable(a), _, _) => a == vd.id + case _ => false + }(body)) + val mutatedParamsIndices = fd.params.zipWithIndex.flatMap{ + case (vd, i) if mutatedParams.contains(vd) => Some(i) + case _ => None + }.toSet + effects = effects + (fd -> mutatedParamsIndices) + + val missingCalls: Set[FunctionInvocation] = functionCallsOf(body).filterNot(fi => fi.tfd.fd == fd) + if(missingCalls.nonEmpty) + missingEffects += (fd -> missingCalls) + } + } + } + + def rec(): Unit = { + val previousMissingEffects = missingEffects + + for{ (fd, calls) <- missingEffects } { + var newMissingCalls: Set[FunctionInvocation] = calls + for(fi <- calls) { + val mutatedArgs = invocEffects(fi) + val mutatedFunParams: Set[Int] = fd.params.zipWithIndex.flatMap{ + case (vd, i) if mutatedArgs.contains(vd.id) => Some(i) + case _ => None + }.toSet + effects += (fd -> (effects(fd) ++ mutatedFunParams)) + + if(effectsFullyComputed(fi.tfd.fd)) { + newMissingCalls -= fi + } + } + if(newMissingCalls.isEmpty) + missingEffects = missingEffects - fd + else + missingEffects += (fd -> newMissingCalls) + } + + if(missingEffects != previousMissingEffects) { + rec() + } + } + + def invocEffects(fi: FunctionInvocation): Set[Identifier] = { + //TODO: the require should be fine once we consider nested functions as well + //require(effects.isDefinedAt(fi.tfd.fd) + val mutatedParams: Set[Int] = effects.get(fi.tfd.fd).getOrElse(Set()) + fi.args.zipWithIndex.flatMap{ + case (Variable(id), i) if mutatedParams.contains(i) => Some(id) + case _ => None + }.toSet + } + + rec() + effects + } + + + def checkAliasing(fd: FunDef)(ctx: LeonContext): Unit = { + def checkReturnValue(body: Expr, bindings: Set[Identifier]): Unit = { + getReturnedExpr(body).foreach{ + case IsTyped(v@Variable(id), ArrayType(_)) if bindings.contains(id) => + ctx.reporter.fatalError(v.getPos, "Cannot return a shared reference to a mutable object: " + v) + case _ => () + } + } + + fd.body.foreach(bd => { + val params = fd.params.map(_.id).toSet + checkReturnValue(bd, params) + preMapWithContext[Set[Identifier]]((expr, bindings) => expr match { + case l@Let(id, IsTyped(v, ArrayType(_)), b) => { + v match { + case FiniteArray(_, _, _) => () + case FunctionInvocation(_, _) => () + case ArrayUpdated(_, _, _) => () + case _ => ctx.reporter.fatalError(l.getPos, "Cannot alias array: " + l) + } + (None, bindings + id) + } + case l@LetVar(id, IsTyped(v, ArrayType(_)), b) => { + v match { + case FiniteArray(_, _, _) => () + case FunctionInvocation(_, _) => () + case ArrayUpdated(_, _, _) => () + case _ => ctx.reporter.fatalError(l.getPos, "Cannot alias array: " + l) + } + (None, bindings + id) + } + case l@LetDef(fds, body) => { + fds.foreach(fd => fd.body.foreach(bd => checkReturnValue(bd, bindings))) + (None, bindings) + } + + case _ => (None, bindings) + })(bd, params) + }) + } + + /* + * A bit hacky, but not sure of the best way to do something like that + * currently. + */ + private def getReturnedExpr(expr: Expr): Seq[Expr] = expr match { + case Let(_, _, rest) => getReturnedExpr(rest) + case LetVar(_, _, rest) => getReturnedExpr(rest) + case Block(_, rest) => getReturnedExpr(rest) + case IfExpr(_, thenn, elze) => getReturnedExpr(thenn) ++ getReturnedExpr(elze) + case MatchExpr(_, cses) => cses.flatMap{ cse => getReturnedExpr(cse.rhs) } + case e => Seq(expr) + } + + + /* + * returns all fun def in the program, including local definitions inside + * other functions (LetDef). + */ + private def allFunDefs(pgm: Program): Seq[FunDef] = + pgm.definedFunctions.flatMap(fd => + fd.body.toSet.flatMap((bd: Expr) => + nestedFunDefsOf(bd)) + fd) +} diff --git a/src/main/scala/leon/xlang/Expressions.scala b/src/main/scala/leon/xlang/Expressions.scala index d627e0d284f4933cd6ed7ecefbe7dd13f4e658f8..98214ee640bd95227c0b759113ab74d2c9555d94 100644 --- a/src/main/scala/leon/xlang/Expressions.scala +++ b/src/main/scala/leon/xlang/Expressions.scala @@ -15,6 +15,14 @@ object Expressions { trait XLangExpr extends Expr + case class Old(id: Identifier) extends XLangExpr with Terminal with PrettyPrintable { + val getType = id.getType + + def printWith(implicit pctx: PrinterContext): Unit = { + p"old($id)" + } + } + case class Block(exprs: Seq[Expr], last: Expr) extends XLangExpr with Extractable with PrettyPrintable { def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { Some((exprs :+ last, exprs => Block(exprs.init, exprs.last))) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index 45bb36770cddca417ba51582cd7824fc09152199..6b7f7cc6ee3c00827289313ed60e111b6ec3a640 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -9,7 +9,7 @@ import leon.purescala.Expressions._ import leon.purescala.Extractors._ import leon.purescala.Constructors._ import leon.purescala.ExprOps._ -import leon.purescala.TypeOps._ +import leon.purescala.TypeOps.leastUpperBound import leon.purescala.Types._ import leon.xlang.Expressions._ @@ -67,7 +67,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] { val (tRes, tScope, tFun) = toFunction(tExpr) val (eRes, eScope, eFun) = toFunction(eExpr) - val iteRType = leastUpperBound(tRes.getType, eRes.getType).get + val iteRType = leastUpperBound(tRes.getType, eRes.getType).getOrElse(Untyped) val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varsInScope).toSeq val resId = FreshIdentifier("res", iteRType) @@ -218,7 +218,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] { case LetDef(fds, b) => if(fds.size > 1) { - //TODO: no support for true mutually recursion + //TODO: no support for true mutual recursion toFunction(LetDef(Seq(fds.head), LetDef(fds.tail, b))) } else { diff --git a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala index 3a7f8be381cfa5fdb87dc6870cf35141f8d8cf33..59dd3217714f964be5f9ff6a1428617c6e54e17a 100644 --- a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala +++ b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala @@ -12,7 +12,8 @@ object XLangDesugaringPhase extends LeonPhase[Program, Program] { override def run(ctx: LeonContext, pgm: Program): (LeonContext, Program) = { val phases = - ArrayTransformation andThen + //ArrayTransformation andThen + AntiAliasingPhase andThen EpsilonElimination andThen ImperativeCodeElimination diff --git a/src/test/resources/regression/frontends/error/xlang/Array2.scala b/src/test/resources/regression/frontends/error/xlang/Array2.scala deleted file mode 100644 index b1b370395d7e0b648e0b88875b3678eaf4668eb5..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array2.scala +++ /dev/null @@ -1,11 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array2 { - - def foo(): Int = { - val a = Array.fill(5)(5) - val b = a - b(3) - } - -} diff --git a/src/test/resources/regression/frontends/error/xlang/Array3.scala b/src/test/resources/regression/frontends/error/xlang/Array3.scala deleted file mode 100644 index 14a8512015102bd235a9efe41782ba7d5a46fd44..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array3.scala +++ /dev/null @@ -1,14 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array3 { - - def foo(): Int = { - val a = Array.fill(5)(5) - if(a.length > 2) - a(1) = 2 - else - 0 - 0 - } - -} diff --git a/src/test/resources/regression/frontends/error/xlang/Array4.scala b/src/test/resources/regression/frontends/error/xlang/Array4.scala deleted file mode 100644 index e41535d6d267986ba7764a6f457970c6ab33b733..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array4.scala +++ /dev/null @@ -1,10 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array4 { - - def foo(a: Array[Int]): Int = { - val b = a - b(3) - } - -} diff --git a/src/test/resources/regression/frontends/error/xlang/Array5.scala b/src/test/resources/regression/frontends/error/xlang/Array5.scala deleted file mode 100644 index 8b7254e9482ddc7df1196c03b1a191c54e86ea0f..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array5.scala +++ /dev/null @@ -1,12 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array5 { - - def foo(a: Array[Int]): Int = { - a(2) = 4 - a(2) - } - -} - -// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/frontends/error/xlang/Array6.scala b/src/test/resources/regression/frontends/error/xlang/Array6.scala deleted file mode 100644 index c4d0c09541d3c7fe4ea4be1527cd704e96d54bb1..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array6.scala +++ /dev/null @@ -1,12 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - - -object Array6 { - - def foo(): Int = { - val a = Array.fill(5)(5) - var b = a - b(0) - } - -} diff --git a/src/test/resources/regression/frontends/error/xlang/Array7.scala b/src/test/resources/regression/frontends/error/xlang/Array7.scala deleted file mode 100644 index ab6f4c20da5a84adfd8509a60174d78c0c423654..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array7.scala +++ /dev/null @@ -1,11 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array7 { - - def foo(): Int = { - val a = Array.fill(5)(5) - var b = a - b(0) - } - -} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation1.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation1.scala new file mode 100644 index 0000000000000000000000000000000000000000..f0e622c66c31e4fe01ff8be5c1b08e17d4d330eb --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation1.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object ArrayParamMutation1 { + + def update(a: Array[BigInt]): Unit = { + require(a.length > 0) + a(0) = 10 + } + + def f(): BigInt = { + val a = Array.fill(10)(BigInt(0)) + update(a) + a(0) + } ensuring(res => res == 10) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation2.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation2.scala new file mode 100644 index 0000000000000000000000000000000000000000..801b35e0cc545f160fc8061e34fd0ee06b7c3f73 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation2.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object ArrayParamMutation2 { + + def rec(a: Array[BigInt]): BigInt = { + require(a.length > 1 && a(0) >= 0) + if(a(0) == 0) + a(1) + else { + a(0) = a(0) - 1 + a(1) = a(1) + a(0) + rec(a) + } + } ensuring(res => a(0) == 0) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation3.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation3.scala new file mode 100644 index 0000000000000000000000000000000000000000..f575167444f839c0ee900a35de5e4e822624dc21 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation3.scala @@ -0,0 +1,23 @@ +import leon.lang._ + +object ArrayParamMutation3 { + + def odd(a: Array[BigInt]): Boolean = { + require(a.length > 0 && a(0) >= 0) + if(a(0) == 0) false + else { + a(0) = a(0) - 1 + even(a) + } + } ensuring(res => a(0) == 0) + + def even(a: Array[BigInt]): Boolean = { + require(a.length > 0 && a(0) >= 0) + if(a(0) == 0) true + else { + a(0) = a(0) - 1 + odd(a) + } + } ensuring(res => a(0) == 0) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation4.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation4.scala new file mode 100644 index 0000000000000000000000000000000000000000..31af4cd5885ea66a2d11e9387ba7e306423ec4d7 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation4.scala @@ -0,0 +1,23 @@ +import leon.lang._ + +object ArrayParamMutation4 { + + def multipleArgs(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + if(a1(0) == 10) + a2(0) = 13 + else + a2(0) = a1(0) + 1 + } + + def transitiveEffects(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + multipleArgs(a1, a2) + } ensuring(_ => a2(0) >= a1(0)) + + def transitiveReverseEffects(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + multipleArgs(a2, a1) + } ensuring(_ => a1(0) >= a2(0)) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation5.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation5.scala new file mode 100644 index 0000000000000000000000000000000000000000..249a79d1f3b7d8df8c941ab3121c4eafed149e03 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation5.scala @@ -0,0 +1,21 @@ + +import leon.lang._ + +object ArrayParamMutation5 { + + def mutuallyRec1(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a1(0) > 0 && a2.length > 0) + if(a1(0) == 10) { + () + } else { + mutuallyRec2(a1, a2) + } + } ensuring(res => a1(0) == 10) + + def mutuallyRec2(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0 && a1(0) > 0) + a1(0) = 10 + mutuallyRec1(a1, a2) + } + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation6.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation6.scala new file mode 100644 index 0000000000000000000000000000000000000000..29ded427fa6546a103d8da6f98cefc1415f389a6 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation6.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object ArrayParamMutation6 { + + def multipleEffects(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + a1(0) = 11 + a2(0) = 12 + } ensuring(_ => a1(0) != a2(0)) + + def f(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + multipleEffects(a1, a2) + } ensuring(_ => a1(0) == 11 && a2(0) == 12) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation7.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation7.scala new file mode 100644 index 0000000000000000000000000000000000000000..53d67729fd57723d1693c564e99cd3d66ee095ef --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation7.scala @@ -0,0 +1,29 @@ +import leon.lang._ + +object ArrayParamMutation7 { + + def f(i: BigInt)(implicit world: Array[BigInt]): BigInt = { + require(world.length == 3) + + world(1) += 1 //global counter of f + + val res = i*i + world(0) = res + res + } + + def mainProgram(): Unit = { + + implicit val world: Array[BigInt] = Array(0,0,0) + + f(1) + assert(world(0) == 1) + f(2) + assert(world(0) == 4) + f(4) + assert(world(0) == 16) + + assert(world(1) == 3) + } + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation8.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation8.scala new file mode 100644 index 0000000000000000000000000000000000000000..68aa737eb42e6e51073dac27b799e06bde928400 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation8.scala @@ -0,0 +1,25 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +import leon.lang._ + +object ArrayParamMutation8 { + + def odd(a: Array[BigInt]): Boolean = { + require(a.length > 0 && a(0) >= 0) + if(a(0) == 0) false + else { + a(0) = a(0) - 1 + even(a) + } + } ensuring(res => if(old(a)(0) % 2 == 1) res else !res) + + def even(a: Array[BigInt]): Boolean = { + require(a.length > 0 && a(0) >= 0) + if(a(0) == 0) true + else { + a(0) = a(0) - 1 + odd(a) + } + } ensuring(res => if(old(a)(0) % 2 == 0) res else !res) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation9.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation9.scala new file mode 100644 index 0000000000000000000000000000000000000000..f5046b6cf3b40382ccc5d989d81d73bc577da9f7 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation9.scala @@ -0,0 +1,22 @@ +import leon.lang._ + +object ArrayParamMutation9 { + def abs(a: Array[Int]) { + require(a.length > 0) + var i = 0; + (while (i < a.length) { + a(i) = if (a(i) < 0) -a(i) else a(i) // <-- this makes Leon crash + i = i + 1 + }) invariant(i >= 0) + } + + + def main = { + val a = Array(0, -1, 2, -3) + + abs(a) + + a(0) + a(1) - 1 + a(2) - 2 + a(3) - 3 // == 0 + } + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation1.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation1.scala new file mode 100644 index 0000000000000000000000000000000000000000..a7250a7bcfd572c49584110d213f9e9991a10c9f --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation1.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object NestedFunParamsMutation1 { + + def f(): Int = { + def g(a: Array[Int]): Unit = { + require(a.length > 0) + a(0) = 10 + } + + val a = Array(1,2,3,4) + g(a) + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation2.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation2.scala new file mode 100644 index 0000000000000000000000000000000000000000..799a87c6e9e70bf6ef89bfc1fb7a6e116adb7feb --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation2.scala @@ -0,0 +1,21 @@ +import leon.lang._ + +object NestedFunParamsMutation2 { + + def f(): Int = { + def g(a: Array[Int]): Unit = { + require(a.length > 0) + a(0) = 10 + } + + def h(a: Array[Int]): Unit = { + require(a.length > 0) + g(a) + } + + val a = Array(1,2,3,4) + h(a) + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/resources/regression/frontends/error/xlang/Array1.scala b/src/test/resources/regression/xlang/error/Array1.scala similarity index 100% rename from src/test/resources/regression/frontends/error/xlang/Array1.scala rename to src/test/resources/regression/xlang/error/Array1.scala diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing1.scala b/src/test/resources/regression/xlang/error/ArrayAliasing1.scala new file mode 100644 index 0000000000000000000000000000000000000000..30b1652dac16f1f8a9c7b36d83d9a0a52811e3c6 --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing1.scala @@ -0,0 +1,13 @@ +import leon.lang._ + +object ArrayAliasing1 { + + def f1(): BigInt = { + val a = Array.fill(10)(BigInt(0)) + val b = a + b(0) = 10 + a(0) + } ensuring(_ == 10) + +} + diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing10.scala b/src/test/resources/regression/xlang/error/ArrayAliasing10.scala new file mode 100644 index 0000000000000000000000000000000000000000..05737b03d9816be72cf52f4913f57a482abf76dd --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing10.scala @@ -0,0 +1,19 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object ArrayAliasing10 { + + def foo(): Int = { + val a = Array.fill(5)(0) + + def rec(): Array[Int] = { + + def nestedRec(): Array[Int] = { + a + } + nestedRec() + } + val b = rec() + b(0) + } + +} diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing2.scala b/src/test/resources/regression/xlang/error/ArrayAliasing2.scala new file mode 100644 index 0000000000000000000000000000000000000000..4e906865a8848aaa00150c67de16f6b32136c64a --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing2.scala @@ -0,0 +1,11 @@ +import leon.lang._ + +object ArrayAliasing2 { + + def f1(a: Array[BigInt]): BigInt = { + val b = a + b(0) = 10 + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing3.scala b/src/test/resources/regression/xlang/error/ArrayAliasing3.scala new file mode 100644 index 0000000000000000000000000000000000000000..0398fc37b9dc2028e1535878ec377bef1620dd88 --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing3.scala @@ -0,0 +1,11 @@ +import leon.lang._ + +object ArrayAliasing3 { + + def f1(a: Array[BigInt], b: Boolean): BigInt = { + val c = if(b) a else Array[BigInt](1,2,3,4,5) + c(0) = 10 + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing4.scala b/src/test/resources/regression/xlang/error/ArrayAliasing4.scala new file mode 100644 index 0000000000000000000000000000000000000000..2632782c39e853744744b66309ef10342bee386b --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing4.scala @@ -0,0 +1,11 @@ +import leon.lang._ + +object ArrayAliasing4 { + + def f1(a: Array[BigInt]): Array[BigInt] = { + require(a.length > 0) + a(0) = 10 + a + } ensuring(res => res(0) == 10) + +} diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing5.scala b/src/test/resources/regression/xlang/error/ArrayAliasing5.scala new file mode 100644 index 0000000000000000000000000000000000000000..b9363d1ab5a627df29e6a0f0018c73850dcbb529 --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing5.scala @@ -0,0 +1,18 @@ +import leon.lang._ + +object ArrayAliasing5 { + + + def f1(a: Array[BigInt], b: Array[BigInt]): Unit = { + require(a.length > 0 && b.length > 0) + a(0) = 10 + b(0) = 20 + } ensuring(_ => a(0) == 10 && b(0) == 20) + + + def callWithAliases(): Unit = { + val a = Array[BigInt](0,0,0,0) + f1(a, a) + } + +} diff --git a/src/test/resources/regression/frontends/error/xlang/Array8.scala b/src/test/resources/regression/xlang/error/ArrayAliasing6.scala similarity index 80% rename from src/test/resources/regression/frontends/error/xlang/Array8.scala rename to src/test/resources/regression/xlang/error/ArrayAliasing6.scala index bbe5bd5fd92b0f4f9662379693d06924bdaf5461..963a134bf71da7a625252411854d73272f56d574 100644 --- a/src/test/resources/regression/frontends/error/xlang/Array8.scala +++ b/src/test/resources/regression/xlang/error/ArrayAliasing6.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -object Array8 { +object ArrayAliasing6 { def foo(a: Array[Int]): Array[Int] = { a diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing7.scala b/src/test/resources/regression/xlang/error/ArrayAliasing7.scala new file mode 100644 index 0000000000000000000000000000000000000000..21bc94502327b334f2e4e5887d4a7286731c78b7 --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing7.scala @@ -0,0 +1,10 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object ArrayAliasing7 { + + def foo(a: Array[Int]): Array[Int] = { + val b = a + b + } + +} diff --git a/src/test/resources/regression/frontends/error/xlang/Array9.scala b/src/test/resources/regression/xlang/error/ArrayAliasing8.scala similarity index 86% rename from src/test/resources/regression/frontends/error/xlang/Array9.scala rename to src/test/resources/regression/xlang/error/ArrayAliasing8.scala index fbc7dd7376e0966df5ed6eb93bafa7427aeab9e8..e7c27cc9cebf657033da4c07936ce16a510075a0 100644 --- a/src/test/resources/regression/frontends/error/xlang/Array9.scala +++ b/src/test/resources/regression/xlang/error/ArrayAliasing8.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -object Array9 { +object ArrayAliasing8 { def foo(a: Array[Int]): Int = { def rec(): Array[Int] = { diff --git a/src/test/resources/regression/frontends/error/xlang/Array10.scala b/src/test/resources/regression/xlang/error/ArrayAliasing9.scala similarity index 87% rename from src/test/resources/regression/frontends/error/xlang/Array10.scala rename to src/test/resources/regression/xlang/error/ArrayAliasing9.scala index 563cdacdf7e0e66ff56eec571fae4d3e3bbe10be..c84d29c3fbb4866100173b7ac0e7b4d0da9a1e57 100644 --- a/src/test/resources/regression/frontends/error/xlang/Array10.scala +++ b/src/test/resources/regression/xlang/error/ArrayAliasing9.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -object Array10 { +object ArrayAliasing9 { def foo(): Int = { val a = Array.fill(5)(0) diff --git a/src/test/resources/regression/xlang/error/NestedFunctionAliasing1.scala b/src/test/resources/regression/xlang/error/NestedFunctionAliasing1.scala new file mode 100644 index 0000000000000000000000000000000000000000..12feace5413c23e688ac194192482120793b4e24 --- /dev/null +++ b/src/test/resources/regression/xlang/error/NestedFunctionAliasing1.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object NestedFunctinAliasing1 { + + def f(): Int = { + val a = Array(1,2,3,4) + + def g(b: Array[Int]): Unit = { + require(b.length > 0 && a.length > 0) + b(0) = 10 + a(0) = 17 + } ensuring(_ => b(0) == 10) + + g(a) + a(0) + } ensuring(_ == 10) +} diff --git a/src/test/resources/regression/xlang/error/NestedFunctionAliasing2.scala b/src/test/resources/regression/xlang/error/NestedFunctionAliasing2.scala new file mode 100644 index 0000000000000000000000000000000000000000..81a9b82b39fb47af105ef2e3122fe86a8b10dbb6 --- /dev/null +++ b/src/test/resources/regression/xlang/error/NestedFunctionAliasing2.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object NestedFunctinAliasing1 { + + def f(a: Array(1,2,3,4)): Int = { + + def g(b: Array[Int]): Unit = { + require(b.length > 0 && a.length > 0) + b(0) = 10 + a(0) = 17 + } ensuring(_ => b(0) == 10) + + g(a) + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index d568e471f08eb4cf1a558675419daabd9e9c940a..c5571fa4f9b2def4f33a586247aa7eb213483a10 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -22,13 +22,13 @@ class SolversSuite extends LeonTestSuiteWithProgram { val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { (if (SolverFactory.hasNativeZ3) Seq( - ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) + ("fairz3", (ctx: LeonContext, pgm: Program) => new Z3StringFairZ3Solver(ctx, pgm) with ForcedProgramConversion ) ) else Nil) ++ (if (SolverFactory.hasZ3) Seq( - ("smt-z3", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) + ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBZ3Solver(ctx, pgm)) with ForcedProgramConversion ) ) else Nil) ++ (if (SolverFactory.hasCVC4) Seq( - ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBCVC4Solver(ctx, pgm)) with ForcedProgramConversion ) ) else Nil) } @@ -78,7 +78,7 @@ class SolversSuite extends LeonTestSuiteWithProgram { } } case _ => - fail(s"Solver $solver - Constraint "+cnstr.asString+" is unsat!?") + fail(s"Solver $solver - Constraint "+cnstr.asString+" is unsat!? Solver was "+solver.getClass) } } finally { solver.free() diff --git a/src/test/scala/leon/regression/frontends/FrontEndsSuite.scala b/src/test/scala/leon/regression/frontends/FrontEndsSuite.scala index b9bfccd36bc2abcc3479c9d3a703df0cdfe16ab7..a0e52aa0b083c04a5a3e1141d2f3287a1995611e 100644 --- a/src/test/scala/leon/regression/frontends/FrontEndsSuite.scala +++ b/src/test/scala/leon/regression/frontends/FrontEndsSuite.scala @@ -36,7 +36,6 @@ class FrontEndsSuite extends LeonRegressionSuite { } val pipeNormal = xlang.NoXLangFeaturesChecking andThen NoopPhase() // redundant NoopPhase to trigger throwing error between phases - val pipeX = NoopPhase[Program]() val baseDir = "regression/frontends/" forEachFileIn(baseDir+"passing/") { f => @@ -45,8 +44,5 @@ class FrontEndsSuite extends LeonRegressionSuite { forEachFileIn(baseDir+"error/simple/") { f => testFrontend(f, pipeNormal, true) } - forEachFileIn(baseDir+"error/xlang/") { f => - testFrontend(f, pipeX, true) - } } diff --git a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala index c70df950e0768ca71dd7bba013ac04cd7404edab..ca2b4a3c98107c73c9244486200dd0a44a348cb2 100644 --- a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala @@ -251,6 +251,7 @@ object SortedList { case "insertSorted" => Decomp("Assert isSorted(in1)", List( Decomp("ADT Split on 'in1'", List( + Close("CEGIS"), Decomp("Ineq. Split on 'head*' and 'v*'", List( Close("CEGIS"), Decomp("Equivalent Inputs *", List( @@ -259,8 +260,7 @@ object SortedList { )) )), Close("CEGIS") - )), - Close("CEGIS") + )) )) )) } diff --git a/src/test/scala/leon/regression/verification/VerificationSuite.scala b/src/test/scala/leon/regression/verification/VerificationSuite.scala index f2ae97880694baac498b94570f81beb1c21c422f..446d06675cdb9b3eb1c7095137ab95e8a928c399 100644 --- a/src/test/scala/leon/regression/verification/VerificationSuite.scala +++ b/src/test/scala/leon/regression/verification/VerificationSuite.scala @@ -41,7 +41,7 @@ trait VerificationSuite extends LeonRegressionSuite { VerificationPhase andThen (if (desugarXLang) FixReportLabels else NoopPhase[VerificationReport]) - val ctx = createLeonContext(files:_*) + val ctx = createLeonContext(files:_*).copy(reporter = new TestErrorReporter) try { val (_, ast) = extraction.run(ctx, files) diff --git a/src/test/scala/leon/regression/xlang/XLangDesugaringSuite.scala b/src/test/scala/leon/regression/xlang/XLangDesugaringSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..41260ec2df6f9d9f32827c39e7f700da15ccca9e --- /dev/null +++ b/src/test/scala/leon/regression/xlang/XLangDesugaringSuite.scala @@ -0,0 +1,46 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.regression.xlang + +import leon._ +import leon.test._ +import purescala.Definitions.Program +import java.io.File + +class XLangDesugaringSuite extends LeonRegressionSuite { + // Hard-code output directory, for Eclipse purposes + + val pipeline = frontends.scalac.ExtractionPhase andThen new utils.PreprocessingPhase(true) + + def testFrontend(f: File, forError: Boolean) = { + test ("Testing " + f.getName) { + val ctx = createLeonContext() + if (forError) { + intercept[LeonFatalError]{ + pipeline.run(ctx, List(f.getAbsolutePath)) + } + } else { + pipeline.run(ctx, List(f.getAbsolutePath)) + } + } + + } + + private def forEachFileIn(path : String)(block : File => Unit) { + val fs = filesInResourceDir(path, _.endsWith(".scala")) + + for(f <- fs) { + block(f) + } + } + + val baseDir = "regression/xlang/" + + forEachFileIn(baseDir+"passing/") { f => + testFrontend(f, false) + } + forEachFileIn(baseDir+"error/") { f => + testFrontend(f, true) + } + +} diff --git a/src/test/scala/leon/test/TestSilentReporter.scala b/src/test/scala/leon/test/TestSilentReporter.scala index 2cf9ea4f7f6c78d4f07002a01f434a150e0d9034..2a8761584222f02c1e4bf85b6fd031c603771079 100644 --- a/src/test/scala/leon/test/TestSilentReporter.scala +++ b/src/test/scala/leon/test/TestSilentReporter.scala @@ -13,3 +13,10 @@ class TestSilentReporter extends DefaultReporter(Set()) { case _ => } } + +class TestErrorReporter extends DefaultReporter(Set()) { + override def emit(msg: Message): Unit = msg match { + case Message(this.ERROR | this.FATAL, _, _) => super.emit(msg) + case _ => + } +} diff --git a/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala b/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala index b448953acf1856cf1df12289880000226f78f8bf..4f74b00c5f93a3125caece106809dfd4182fa8a8 100644 --- a/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala +++ b/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala @@ -6,7 +6,7 @@ import leon.test._ import leon.purescala.Common._ import leon.purescala.Expressions._ import leon.purescala.Types._ -import leon.purescala.TypeOps._ +import leon.purescala.TypeOps.isSubtypeOf import leon.purescala.Definitions._ import leon.purescala.ExprOps._ @@ -279,4 +279,44 @@ class ExprOpsSuite extends LeonTestSuite with helpers.WithLikelyEq with helpers. } } + + test("preMapWithContext") { ctx => + val expr = Plus(bi(1), Minus(bi(2), bi(3))) + def op(e : Expr, set: Set[Int]): (Option[Expr], Set[Int]) = e match { + case Minus(InfiniteIntegerLiteral(two), e2) if two == BigInt(2) => (Some(bi(2)), set) + case InfiniteIntegerLiteral(one) if one == BigInt(1) => (Some(bi(2)), set) + case InfiniteIntegerLiteral(two) if two == BigInt(2) => (Some(bi(42)), set) + case _ => (None, set) + } + + assert(preMapWithContext(op, false)(expr, Set()) === Plus(bi(2), bi(2))) + assert(preMapWithContext(op, true)(expr, Set()) === Plus(bi(42), bi(42))) + + val expr2 = Let(x.id, bi(1), Let(y.id, bi(2), Plus(x, y))) + def op2(e: Expr, bindings: Map[Identifier, BigInt]): (Option[Expr], Map[Identifier, BigInt]) = e match { + case Let(id, InfiniteIntegerLiteral(v), body) => (None, bindings + (id -> v)) + case Variable(id) => (bindings.get(id).map(v => InfiniteIntegerLiteral(v)), bindings) + case _ => (None, bindings) + } + + assert(preMapWithContext(op2, false)(expr2, Map()) === Let(x.id, bi(1), Let(y.id, bi(2), Plus(bi(1), bi(2))))) + + def op3(e: Expr, bindings: Map[Identifier, BigInt]): (Option[Expr], Map[Identifier, BigInt]) = e match { + case Let(id, InfiniteIntegerLiteral(v), body) => (Some(body), bindings + (id -> v)) + case Variable(id) => (bindings.get(id).map(v => InfiniteIntegerLiteral(v)), bindings) + case _ => (None, bindings) + } + assert(preMapWithContext(op3, true)(expr2, Map()) === Plus(bi(1), bi(2))) + + + val expr4 = Plus(Let(y.id, bi(2), y), + Let(y.id, bi(4), y)) + def op4(e: Expr, bindings: Map[Identifier, BigInt]): (Option[Expr], Map[Identifier, BigInt]) = e match { + case Let(id, InfiniteIntegerLiteral(v), body) => (Some(body), if(bindings.contains(id)) bindings else (bindings + (id -> v))) + case Variable(id) => (bindings.get(id).map(v => InfiniteIntegerLiteral(v)), bindings) + case _ => (None, bindings) + } + assert(preMapWithContext(op4, true)(expr4, Map()) === Plus(bi(2), bi(4))) + } + } diff --git a/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala b/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala index fe01946d158153d2dd9ae2a3be2234ee4cd18aa9..0f30a5ba1a95d39e78a1594f39804c8161e919a6 100644 --- a/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala +++ b/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala @@ -72,17 +72,12 @@ object BatchedQueue { def enqueue(v: T): Queue[T] = { require(invariant) - f match { - case Cons(h, t) => - Queue(f, Cons(v, r)) - case Nil() => - Queue(Cons(v, f), Nil()) - } - + ???[Queue[T]] } ensuring { (res: Queue[T]) => - res.invariant && - res.toList.last == v && - res.content == this.content ++ Set(v) + res.invariant && + res.toList.last == v && + res.size == size + 1 && + res.content == this.content ++ Set(v) } } } diff --git a/testcases/synthesis/etienne-thesis/run.sh b/testcases/synthesis/etienne-thesis/run.sh index ee64d86702076bf5ff909c3437f321498a2afe68..924b99cc57386f1dba92bfb97017b41a801cd8ea 100755 --- a/testcases/synthesis/etienne-thesis/run.sh +++ b/testcases/synthesis/etienne-thesis/run.sh @@ -1,7 +1,7 @@ #!/bin/bash function run { - cmd="./leon --debug=report --timeout=30 --synthesis $1" + cmd="./leon --debug=report --timeout=30 --synthesis --cegis:maxsize=5 $1" echo "Running " $cmd echo "------------------------------------------------------------------------------------------------------------------" $cmd; @@ -35,9 +35,9 @@ run testcases/synthesis/etienne-thesis/UnaryNumerals/Distinct.scala run testcases/synthesis/etienne-thesis/UnaryNumerals/Mult.scala # BatchedQueue -#run testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala +run testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala run testcases/synthesis/etienne-thesis/BatchedQueue/Dequeue.scala # AddressBook -#run testcases/synthesis/etienne-thesis/AddressBook/Make.scala +run testcases/synthesis/etienne-thesis/AddressBook/Make.scala run testcases/synthesis/etienne-thesis/AddressBook/Merge.scala diff --git a/testcases/verification/strings/invalid/CompatibleListChar.scala b/testcases/verification/strings/invalid/CompatibleListChar.scala new file mode 100644 index 0000000000000000000000000000000000000000..86eec34cddcee8055c34d5ffc791b7bbf7a397e7 --- /dev/null +++ b/testcases/verification/strings/invalid/CompatibleListChar.scala @@ -0,0 +1,29 @@ +import leon.lang._ +import leon.annotation._ +import leon.collection._ +import leon.collection.ListOps._ +import leon.lang.synthesis._ + +object CompatibleListChar { + def rec[T](l : List[T], f : T => String): String = l match { + case Cons(head, tail) => f(head) + rec(tail, f) + case Nil() => "" + } + def customToString[T](l : List[T], p: List[Char], d: String, fd: String => String, fp: List[Char] => String, pf: String => List[Char], f : T => String): String = rec(l, f) ensuring { + (res : String) => (p == Nil[Char]() || d == "" || fd(d) == "" || fp(p) == "" || pf(d) == Nil[Char]()) && ((l, res) passes { + case Cons(a, Nil()) => f(a) + }) + } + def customPatternMatching(s: String): Boolean = { + s match { + case "" => true + case b => List(b) match { + case Cons("", Nil()) => true + case Cons(s, Nil()) => false // StrOps.length(s) < BigInt(2) // || (s == "\u0000") //+ "a" + case Cons(_, Cons(_, Nil())) => true + case _ => false + } + case _ => false + } + } holds +} \ No newline at end of file diff --git a/testcases/verification/xlang/AbsFun.scala b/testcases/verification/xlang/AbsFun.scala index fe37632df68bc4ae1c164bb034902e45b18365c4..a6ff9679e27fc32ff4dd8b62a1cf2170386083d8 100644 --- a/testcases/verification/xlang/AbsFun.scala +++ b/testcases/verification/xlang/AbsFun.scala @@ -35,11 +35,7 @@ object AbsFun { isPositive(t, k)) if(k < tab.length) { - val nt = if(tab(k) < 0) { - t.updated(k, -tab(k)) - } else { - t.updated(k, tab(k)) - } + val nt = t.updated(k, if(tab(k) < 0) -tab(k) else tab(k)) while0(nt, k+1, tab) } else { (t, k) @@ -54,11 +50,7 @@ object AbsFun { def property(t: Array[Int], k: Int): Boolean = { require(isPositive(t, k) && t.length >= 0 && k >= 0) if(k < t.length) { - val nt = if(t(k) < 0) { - t.updated(k, -t(k)) - } else { - t.updated(k, t(k)) - } + val nt = t.updated(k, if(t(k) < 0) -t(k) else t(k)) isPositive(nt, k+1) } else true } holds