diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index fb64b3311dfdd2af5b06df85ab2cb26088d3233e..9720666d145de9c9d1bf7af5ff2a91548509b8e2 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -509,9 +509,8 @@ trait ASTExtractors { true case _ => false } - } - + object ExDefaultValueFunction{ /** Matches a function that defines the default value of a parameter */ def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, String, Int, Tree)] = { @@ -520,11 +519,11 @@ trait ASTExtractors { case DefDef(_, name, tparams, vparamss, tpt, rhs) if( vparamss.size <= 1 && name != nme.CONSTRUCTOR && sym.isSynthetic ) => - + // Split the name into pieces, to find owner of the parameter + param.index // Form has to be <owner name>$default$<param index> val symPieces = sym.name.toString.reverse.split("\\$", 3).reverseMap(_.reverse) - + try { if (symPieces(1) != "default" || symPieces(0) == "copy") throw new IllegalArgumentException("") val ownerString = symPieces(0) diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 43d977fb3a273014e8d7e2874e05996268c0b465..388639adb7ad497b8af4a14f28055a2674124489 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -522,7 +522,7 @@ trait CodeExtraction extends ASTExtractors { val fields = args.map { case (fsym, t) => val tpe = leonType(t.tpt.tpe)(defCtx, fsym.pos) val id = cachedWithOverrides(fsym, Some(ccd), tpe) - LeonValDef(id.setPos(t.pos), Some(tpe)).setPos(t.pos) + LeonValDef(id.setPos(t.pos)).setPos(t.pos) } //println(s"Fields of $sym") ccd.setFields(fields) @@ -629,9 +629,10 @@ trait CodeExtraction extends ASTExtractors { val newParams = sym.info.paramss.flatten.map{ sym => val ptpe = leonType(sym.tpe)(nctx, sym.pos) - val newID = FreshIdentifier(sym.name.toString, ptpe).setPos(sym.pos) + val tpe = if (sym.isByNameParam) FunctionType(Seq(), ptpe) else ptpe + val newID = FreshIdentifier(sym.name.toString, tpe).setPos(sym.pos) owners += (newID -> None) - LeonValDef(newID).setPos(sym.pos) + LeonValDef(newID, sym.isByNameParam).setPos(sym.pos) } val tparamsDef = tparams.map(t => TypeParameterDef(t._2)) @@ -768,8 +769,9 @@ trait CodeExtraction extends ASTExtractors { vd.defaultValue = paramsToDefaultValues.get(s.symbol) } - val newVars = for ((s, vd) <- params zip funDef.params) yield { - s.symbol -> (() => Variable(vd.id)) + val newVars = for ((s, vd) <- params zip funDef.params) yield s.symbol -> { + if (s.symbol.isByNameParam) () => Application(Variable(vd.id), Seq()) + else () => Variable(vd.id) } val fctx = dctx.withNewVars(newVars).copy(isExtern = funDef.annotations("extern")) @@ -1530,23 +1532,24 @@ trait CodeExtraction extends ASTExtractors { val fd = getFunDef(sym, c.pos) val newTps = tps.map(t => extractType(t)) + val argsByName = (fd.params zip args).map(p => if (p._1.isLazy) Lambda(Seq(), p._2) else p._2) - FunctionInvocation(fd.typed(newTps), args) + FunctionInvocation(fd.typed(newTps), argsByName) case (IsTyped(rec, ct: ClassType), _, args) if isMethod(sym) => val fd = getFunDef(sym, c.pos) val cd = methodToClass(fd) val newTps = tps.map(t => extractType(t)) + val argsByName = (fd.params zip args).map(p => if (p._1.isLazy) Lambda(Seq(), p._2) else p._2) - MethodInvocation(rec, cd, fd.typed(newTps), args) + MethodInvocation(rec, cd, fd.typed(newTps), argsByName) case (IsTyped(rec, ft: FunctionType), _, args) => application(rec, args) - case (IsTyped(rec, cct: CaseClassType), name, Nil) if cct.fields.exists(_.id.name == name) => - - val fieldID = cct.fields.find(_.id.name == name).get.id + case (IsTyped(rec, cct: CaseClassType), name, Nil) if cct.classDef.fields.exists(_.id.name == name) => + val fieldID = cct.classDef.fields.find(_.id.name == name).get.id caseClassSelector(cct, rec, fieldID) diff --git a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala index a53456e73f72d323aca3597623ab9998c106bf94..3fe13e7c93fefdfcc4321159b0fa7b354788a801 100644 --- a/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala +++ b/src/main/scala/leon/invariant/engine/UnfoldingTemplateSolver.scala @@ -220,7 +220,7 @@ class UnfoldingTemplateSolver(ctx: InferenceContext, program: Program, rootFd: F val resvar = FreshIdentifier("res", fd.returnType, true) // FIXME: Is this correct (ResultVariable(fd.returnType) -> resvar.toVariable)) val ninv = replace(Map(ResultVariable(fd.returnType) -> resvar.toVariable), inv) - Some(Lambda(Seq(ValDef(resvar, Some(fd.returnType))), ninv)) + Some(Lambda(Seq(ValDef(resvar)), ninv)) } } else if (fd.postcondition.isDefined) { val Lambda(resultBinder, _) = fd.postcondition.get diff --git a/src/main/scala/leon/invariant/util/TreeUtil.scala b/src/main/scala/leon/invariant/util/TreeUtil.scala index 4494787081ace61dfe1e1c086a702da6b2be15c4..4109e59877c192570e95bb83b819ebd0a59edef7 100644 --- a/src/main/scala/leon/invariant/util/TreeUtil.scala +++ b/src/main/scala/leon/invariant/util/TreeUtil.scala @@ -43,8 +43,8 @@ object ProgramUtil { def createTemplateFun(plainTemp: Expr): FunctionInvocation = { val tmpl = Lambda(getTemplateIds(plainTemp).toSeq.map(id => ValDef(id)), plainTemp) - val tmplFd = new FunDef(FreshIdentifier("tmpl", FunctionType(Seq(tmpl.getType), BooleanType), false), Seq(), Seq(ValDef(FreshIdentifier("arg", tmpl.getType), - Some(tmpl.getType))), BooleanType) + val tmplFd = new FunDef(FreshIdentifier("tmpl", FunctionType(Seq(tmpl.getType), BooleanType), false), Seq(), + Seq(ValDef(FreshIdentifier("arg", tmpl.getType))), BooleanType) tmplFd.body = Some(BooleanLiteral(true)) FunctionInvocation(TypedFunDef(tmplFd, Seq()), Seq(tmpl)) } diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 103ed1bee94c2203d288848184d62fd6b2d2ce09..57a7d445105e41defb7adcd47da758ea1d976eaf 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -41,27 +41,21 @@ object Definitions { } } - /** A ValDef represents a parameter of a [[purescala.Definitions.FunDef function]] or - * a [[purescala.Definitions.CaseClassDef case class]]. - * - * The optional [[tpe]], if present, overrides the type of the underlying Identifier [[id]]. - * This is useful to instantiate argument types of polymorphic classes. To be consistent, - * never use the type of [[id]] directly; use [[ValDef#getType]] instead. - */ - case class ValDef(id: Identifier, tpe: Option[TypeTree] = None) extends Definition with Typed { + /** + * A ValDef declares a new identifier to be of a certain type. + * The optional tpe, if present, overrides the type of the underlying Identifier id + * This is useful to instantiate argument types of polymorphic functions + */ + case class ValDef(val id: Identifier, val isLazy: Boolean = false) extends Definition with Typed { self: Serializable => - val getType = tpe getOrElse id.getType + val getType = id.getType var defaultValue : Option[FunDef] = None def subDefinitions = Seq() - /** Transform this [[ValDef]] into a [[Expressions.Variable Variable]] - * - * Warning: the variable will not have the same type as this ValDef, but currently - * the Identifier type is enough for all uses in Leon. - */ + /** Transform this [[ValDef]] into a [[Expressions.Variable Variable]] */ def toVariable : Variable = Variable(id) } @@ -509,11 +503,10 @@ object Definitions { if (typesMap.isEmpty) { (fd.params, Map()) } else { - val newParams = fd.params.map { - case vd @ ValDef(id, _) => - val newTpe = translated(vd.getType) - val newId = FreshIdentifier(id.name, newTpe, true).copiedFrom(id) - ValDef(newId).setPos(vd) + val newParams = fd.params.map { vd => + val newTpe = translated(vd.getType) + val newId = FreshIdentifier(vd.id.name, newTpe, true).copiedFrom(vd.id) + vd.copy(id = newId).setPos(vd) } val paramsMap: Map[Identifier, Identifier] = (fd.params zip newParams).map { case (vd1, vd2) => vd1.id -> vd2.id }.toMap diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 84817e872598f75d8b2902aea003896d2eb0cd90..565f1d3f6f2d44d8e3eb07f436431eddb43f746e 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -540,8 +540,8 @@ object ExprOps { } val normalized = postMap { - case Lambda(args, body) => Some(Lambda(args.map(vd => ValDef(subst(vd.id), vd.tpe)), body)) - case Forall(args, body) => Some(Forall(args.map(vd => ValDef(subst(vd.id), vd.tpe)), body)) + case Lambda(args, body) => Some(Lambda(args.map(vd => vd.copy(id = subst(vd.id))), body)) + case Forall(args, body) => Some(Forall(args.map(vd => vd.copy(id = subst(vd.id))), body)) case Let(i, e, b) => Some(Let(subst(i), e, b)) case MatchExpr(scrut, cses) => Some(MatchExpr(scrut, cses.map { cse => cse.copy(pattern = replacePatternBinders(cse.pattern, subst)) @@ -792,7 +792,7 @@ object ExprOps { }) case CaseClassPattern(_, cct, subps) => - val subExprs = (subps zip cct.fields) map { + val subExprs = (subps zip cct.classDef.fields) map { case (p, f) => p.binder.map(_.toVariable).getOrElse(caseClassSelector(cct, in, f.id)) } @@ -869,8 +869,8 @@ object ExprOps { } case CaseClassPattern(ob, cct, subps) => - assert(cct.fields.size == subps.size) - val pairs = cct.fields.map(_.id).toList zip subps.toList + assert(cct.classDef.fields.size == subps.size) + val pairs = cct.classDef.fields.map(_.id).toList zip subps.toList val subTests = pairs.map(p => rec(caseClassSelector(cct, in, p._1), p._2)) val together = and(bind(ob, in) +: subTests :_*) and(IsInstanceOf(in, cct), together) @@ -905,7 +905,7 @@ object ExprOps { pattern match { case CaseClassPattern(b, cct, subps) => assert(cct.fields.size == subps.size) - val pairs = cct.fields.map(_.id).toList zip subps.toList + val pairs = cct.classDef.fields.map(_.id).toList zip subps.toList val subMaps = pairs.map(p => mapForPattern(caseClassSelector(cct, asInstOf(in, cct), p._1), p._2)) val together = subMaps.flatten.toMap bindIn(b, Some(cct)) ++ together diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 6a8c8f9c9415a6ebc3b9c46de5d5e326a50858da..8cfae30d2ebfa070ceb5f50fb58e1eab1c70ccd6 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -112,7 +112,7 @@ object MethodLifting extends TransformationPhase { val fdParams = fd.params map { vd => val newId = FreshIdentifier(vd.id.name, tSubst(vd.id.getType)) - ValDef(newId).setPos(vd.getPos) + vd.copy(id = newId).setPos(vd.getPos) } val paramsMap = fd.params.zip(fdParams).map{ case (from, to) => from.id -> to.id }.toMap val eSubst: Expr => Expr = instantiateType(_, tMap, paramsMap) @@ -140,7 +140,7 @@ object MethodLifting extends TransformationPhase { val retType = instantiateType(fd.returnType, tparamsMap) val fdParams = fd.params map { vd => val newId = FreshIdentifier(vd.id.name, instantiateType(vd.id.getType, tparamsMap)) - ValDef(newId).setPos(vd.getPos) + vd.copy(id = newId).setPos(vd.getPos) } val receiver = FreshIdentifier("thiss", recType).setPos(cd.id) diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 755c540d1f3b23c505488d49ecd945fa324d0080..9a017e7df4f974475986e4939c20bfef0116ae60 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -334,8 +334,8 @@ class PrettyPrinter(opts: PrinterOptions, case Not(expr) => p"\u00AC$expr" - case vd@ValDef(id, _) => - p"$id : ${vd.getType}" + case vd @ ValDef(id, lzy) => + p"$id :${if (lzy) "=> " else ""} ${vd.getType}" vd.defaultValue.foreach { fd => p" = ${fd.body.get}" } case This(_) => p"this" diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 51bed3eaf0f5fc2586f25fc7ff38cd6d8d3857b9..3644191c3261e238e1caf4484604f90dcf920978 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -185,14 +185,6 @@ object TypeOps { freshId(id, typeParamSubst(tps map { case (tpd, tp) => tpd.tp -> tp })(id.getType)) } - def instantiateType(vd: ValDef, tps: Map[TypeParameterDef, TypeTree]): ValDef = { - val ValDef(id, forcedType) = vd - ValDef( - freshId(id, instantiateType(id.getType, tps)), - forcedType map ((tp: TypeTree) => instantiateType(tp, tps)) - ) - } - def instantiateType(tpe: TypeTree, tps: Map[TypeParameterDef, TypeTree]): TypeTree = { if (tps.isEmpty) { tpe @@ -313,7 +305,7 @@ object TypeOps { TypeParameterDef(tpeSub(p.tp).asInstanceOf[TypeParameter]) } val returnType = tpeSub(fd.returnType) - val params = fd.params map (instantiateType(_, tps)) + val params = fd.params map (vd => vd.copy(id = freshId(vd.id, tpeSub(vd.getType)))) val newFd = fd.duplicate(id, tparams, params, returnType) val subCalls = preMap { @@ -332,7 +324,7 @@ object TypeOps { case l @ Lambda(args, body) => val newArgs = args.map { arg => val tpe = tpeSub(arg.getType) - ValDef(freshId(arg.id, tpe)) + arg.copy(id = freshId(arg.id, tpe)) } val mapping = args.map(_.id) zip newArgs.map(_.id) Lambda(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(l) @@ -340,7 +332,7 @@ object TypeOps { case f @ Forall(args, body) => val newArgs = args.map { arg => val tpe = tpeSub(arg.getType) - ValDef(freshId(arg.id, tpe)) + arg.copy(id = freshId(arg.id, tpe)) } val mapping = args.map(_.id) zip newArgs.map(_.id) Forall(newArgs, rec(idsMap ++ mapping)(body)).copiedFrom(f) diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index d39fda3338fbeab4dfa9c453a58afa6298bb6ccd..626cee7d0cac6c4692cfc11ce7ad9d2f2c954e9e 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -102,8 +102,14 @@ object Types { if (tmap.isEmpty) { classDef.fields } else { - // This is the only case where ValDef overrides the type of its Identifier - classDef.fields.map(vd => ValDef(vd.id, Some(instantiateType(vd.getType, tmap)))) + // !! WARNING !! + // vd.id changes but this should not be an issue as selector uses + // classDef.params ids which do not change! + classDef.fields.map { vd => + val newTpe = instantiateType(vd.getType, tmap) + val newId = FreshIdentifier(vd.id.name, newTpe).copiedFrom(vd.id) + vd.copy(id = newId).setPos(vd) + } } } diff --git a/src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala b/src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala index 911af8d84b6b6ccc94dd6ed781a77ca99e8a7771..2d3218c7b09edfa465618117a66b7f428f755b1c 100644 --- a/src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala +++ b/src/main/scala/leon/solvers/isabelle/AdaptationPhase.scala @@ -26,7 +26,7 @@ object AdaptationPhase extends TransformationPhase { CaseClassType(dummy, List(tp)) def mkDummyParameter(tp: TypeParameter) = - ValDef(FreshIdentifier("dummy", mkDummyTyp(tp)), Some(mkDummyTyp(tp))) + ValDef(FreshIdentifier("dummy", mkDummyTyp(tp))) def mkDummyArgument(tree: TypeTree) = CaseClass(CaseClassType(dummy, List(tree)), Nil) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala index dc267825c8356451944b5fb6b7f9de5d55504232..e2eda5e1cc84da7e72fcea0f4b1ce24ff6d91fc9 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Solver.scala @@ -19,8 +19,10 @@ class SMTLIBCVC4Solver(context: LeonContext, program: Program) extends SMTLIBSol Seq( "-q", "--produce-models", - "--no-incremental", - "--tear-down-incremental", + "--incremental", +// "--no-incremental", +// "--tear-down-incremental", +// "--dt-rewrite-error-sel", // Removing since it causes CVC4 to segfault on some inputs "--rewrite-divk", "--print-success", "--lang", "smt" diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index f13ed21a2929a0b845e9c0138092bb36b5fc6428..72ecfb9d34039ad03675d2d5523c34462f9de20e 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -9,7 +9,7 @@ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Definitions._ -import _root_.smtlib.parser.Commands.{Assert => SMTAssert, _} +import _root_.smtlib.parser.Commands.{Assert => SMTAssert, FunDef => SMTFunDef, _} import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.CommandsResponses.{Error => ErrorResponse, _} @@ -62,27 +62,37 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) } protected def getModel(filter: Identifier => Boolean): Model = { - val syms = variables.aSet.filter(filter).toList.map(variables.aToB) + val syms = variables.aSet.filter(filter).map(variables.aToB) if (syms.isEmpty) { Model.empty } else { try { - val cmd: Command = GetValue( - syms.head, - syms.tail.map(s => QualifiedIdentifier(SMTIdentifier(s))) - ) + val cmd = GetModel() emit(cmd) match { - case GetValueResponseSuccess(valuationPairs) => + case GetModelResponseSuccess(smodel) => + var modelFunDefs = Map[SSymbol, DefineFun]() - new Model(valuationPairs.collect { - case (SimpleSymbol(sym), value) if variables.containsB(sym) => - val id = variables.toA(sym) + // first-pass to gather functions + for (me <- smodel) me match { + case me @ DefineFun(SMTFunDef(a, args, _, _)) if args.nonEmpty => + modelFunDefs += a -> me + case _ => + } + + var model = Map[Identifier, Expr]() + + for (me <- smodel) me match { + case DefineFun(SMTFunDef(s, args, kind, e)) if syms(s) => + val id = variables.toA(s) + model += id -> fromSMT(e, id.getType)(Map(), modelFunDefs) + case _ => + } + + new Model(model) - (id, fromSMT(value, id.getType)(Map(), Map())) - }.toMap) case _ => - Model.empty //FIXME improve this + Model.empty // FIXME improve this } } catch { case e : SMTLIBUnsupportedError => @@ -100,6 +110,7 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) variables.push() genericValues.push() sorts.push() + lambdas.push() functions.push() errors.push() @@ -113,6 +124,7 @@ abstract class SMTLIBSolver(val context: LeonContext, val program: Program) variables.pop() genericValues.pop() sorts.pop() + lambdas.pop() functions.pop() errors.pop() diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 07c2eaa567dd41aca7e8f239910b9856ae3ad268..ac5e0fd4bce6a8eb70faf590c25dec9dc630331c 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -11,6 +11,7 @@ import purescala.Expressions._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps._ import purescala.Constructors._ import purescala.Definitions._ @@ -18,7 +19,7 @@ import _root_.smtlib.common._ import _root_.smtlib.printer.{ RecursivePrinter => SMTPrinter } import _root_.smtlib.parser.Commands.{ Constructor => SMTConstructor, - FunDef => _, + FunDef => SMTFunDef, Assert => _, _ } @@ -147,14 +148,15 @@ trait SMTLIBTarget extends Interruptible { protected def freshSym(name: String): SSymbol = id2sym(FreshIdentifier(name)) /* Metadata for CC, and variables */ - protected val constructors = new IncrementalBijection[TypeTree, SSymbol]() - protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() - protected val testers = new IncrementalBijection[TypeTree, SSymbol]() - protected val variables = new IncrementalBijection[Identifier, SSymbol]() + protected val constructors = new IncrementalBijection[TypeTree, SSymbol]() + protected val selectors = new IncrementalBijection[(TypeTree, Int), SSymbol]() + protected val testers = new IncrementalBijection[TypeTree, SSymbol]() + protected val variables = new IncrementalBijection[Identifier, SSymbol]() protected val genericValues = new IncrementalBijection[GenericValue, SSymbol]() - protected val sorts = new IncrementalBijection[TypeTree, Sort]() - protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() - protected val errors = new IncrementalBijection[Unit, Boolean]() + protected val sorts = new IncrementalBijection[TypeTree, Sort]() + protected val functions = new IncrementalBijection[TypedFunDef, SSymbol]() + protected val lambdas = new IncrementalBijection[FunctionType, SSymbol]() + protected val errors = new IncrementalBijection[Unit, Boolean]() protected def hasError = errors.getB(()) contains true protected def addError() = errors += () -> true @@ -247,7 +249,7 @@ trait SMTLIBTarget extends Interruptible { declareSort(RawArrayType(from, library.optionType(to))) case FunctionType(from, to) => - Sort(SMTIdentifier(SSymbol("Array")), Seq(declareSort(tupleTypeWrap(from)), declareSort(to))) + Ints.IntSort() case tp: TypeParameter => declareUninterpretedSort(tp) @@ -330,6 +332,20 @@ trait SMTLIBTarget extends Interruptible { } } + protected def declareLambda(tpe: FunctionType): SSymbol = { + val realTpe = bestRealType(tpe).asInstanceOf[FunctionType] + lambdas.cachedB(realTpe) { + val id = FreshIdentifier("dynLambda") + val s = id2sym(id) + emit(DeclareFun( + s, + (realTpe +: realTpe.from).map(declareSort), + declareSort(realTpe.to) + )) + s + } + } + /* Translate a Leon Expr to an SMTLIB term */ def sortToSMT(s: Sort): SExpr = { @@ -523,7 +539,10 @@ trait SMTLIBTarget extends Interruptible { * ===== Everything else ===== */ case ap @ Application(caller, args) => - ArraysEx.Select(toSMT(caller), toSMT(tupleWrap(args))) + FunctionApplication( + declareLambda(caller.getType.asInstanceOf[FunctionType]), + (caller +: args).map(toSMT) + ) case Not(u) => Core.Not(toSMT(u)) case UMinus(u) => Ints.Neg(toSMT(u)) @@ -645,6 +664,67 @@ trait SMTLIBTarget extends Interruptible { } } + case (SNumeral(n), Some(ft @ FunctionType(from, to))) => + val dynLambda = lambdas.toB(ft) + letDefs.get(dynLambda) match { + case Some(DefineFun(SMTFunDef(a, SortedVar(dispatcher, dkind) +: args, rkind, body))) => + + object EQ { + def unapply(t: Term): Option[(Term, Term)] = t match { + case Core.Equals(e1, e2) => Some((e1, e2)) + case FunctionApplication(f, Seq(e1, e2)) if f.toString == "=" => Some((e1, e2)) + case _ => None + } + } + + object Num { + def unapply(t: Term): Option[BigInt] = t match { + case SNumeral(n) => Some(n) + case FunctionApplication(f, Seq(SNumeral(n))) if f.toString == "-" => Some(-n) + case _ => None + } + } + + val d = symbolToQualifiedId(dispatcher) + def dispatch(t: Term): Term = t match { + case Core.ITE(EQ(di, Num(ni)), thenn, elze) if di == d => + if (ni == n) thenn else dispatch(elze) + case Core.ITE(Core.And(EQ(di, Num(ni)), _), thenn, elze) if di == d => + if (ni == n) thenn else dispatch(elze) + case _ => t + } + + def extract(t: Term): Expr = { + def recCond(term: Term, index: Int): Seq[Expr] = term match { + case Core.And(e1, e2) => + val e1s = recCond(e1, index) + e1s ++ recCond(e2, index + e1s.size) + case EQ(e1, e2) => + recCond(e2, index) + case _ => Seq(fromSMT(term, from(index))) + } + + def recCases(term: Term, matchers: Seq[Expr]): Seq[(Seq[Expr], Expr)] = term match { + case Core.ITE(cond, thenn, elze) => + val cs = recCond(cond, matchers.size) + recCases(thenn, matchers ++ cs) ++ recCases(elze, matchers) + case _ => Seq(matchers -> fromSMT(term, to)) + } + + val cases = recCases(t, Seq.empty) + val (default, rest) = cases.partition(_._1.isEmpty) + + assert(default.size == 1 && rest.forall(_._1.size == from.size)) + PartialLambda(rest, Some(default.head._2), ft) + } + + val lambdaTerm = dispatch(body) + val lambda = extract(lambdaTerm) + lambda + + case None => unsupported(InfiniteIntegerLiteral(n), "Unknown function ref") + } + case (SNumeral(n), Some(RealType)) => FractionalLiteral(n, 1) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index f9fe677fd88f1eb1ae1d1f71e50a4647a6781743..ca2c0e714827b0213667ee479e5338c1292abb6c 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -91,10 +91,11 @@ trait AbstractZ3Solver extends Solver { protected val adtManager = new ADTManager(context) // Bijections between Leon Types/Functions/Ids to Z3 Sorts/Decls/ASTs - protected val functions = new IncrementalBijection[TypedFunDef, Z3FuncDecl]() - protected val generics = new IncrementalBijection[GenericValue, Z3FuncDecl]() - protected val sorts = new IncrementalBijection[TypeTree, Z3Sort]() - protected val variables = new IncrementalBijection[Expr, Z3AST]() + protected val functions = new IncrementalBijection[TypedFunDef, Z3FuncDecl]() + protected val generics = new IncrementalBijection[GenericValue, Z3FuncDecl]() + protected val lambdas = new IncrementalBijection[FunctionType, Z3FuncDecl]() + protected val sorts = new IncrementalBijection[TypeTree, Z3Sort]() + protected val variables = new IncrementalBijection[Expr, Z3AST]() protected val constructors = new IncrementalBijection[TypeTree, Z3FuncDecl]() protected val selectors = new IncrementalBijection[(TypeTree, Int), Z3FuncDecl]() @@ -108,6 +109,7 @@ trait AbstractZ3Solver extends Solver { z3 = new Z3Context(z3cfg) functions.clear() + lambdas.clear() generics.clear() sorts.clear() variables.clear() @@ -190,7 +192,6 @@ trait AbstractZ3Solver extends Solver { } } } - } // Prepares some of the Z3 sorts, but *not* the tuple sorts; these are created on-demand. @@ -230,7 +231,6 @@ trait AbstractZ3Solver extends Solver { declareStructuralSort(tpe) } - case tt @ SetType(base) => sorts.cachedB(tt) { z3.mkSetSort(typeToSort(base)) @@ -257,10 +257,8 @@ trait AbstractZ3Solver extends Solver { case ft @ FunctionType(from, to) => sorts.cachedB(ft) { - val fromSort = typeToSort(tupleTypeWrap(from)) - val toSort = typeToSort(to) - - z3.mkArraySort(fromSort, toSort) + val symbol = z3.mkFreshStringSymbol("fun") + z3.mkUninterpretedSort(symbol) } case other => @@ -496,7 +494,15 @@ trait AbstractZ3Solver extends Solver { z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) case fa @ Application(caller, args) => - z3.mkSelect(rec(caller), rec(tupleWrap(args))) + val ft @ FunctionType(froms, to) = bestRealType(caller.getType) + val funDecl = lambdas.cachedB(ft) { + val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) + val returnSort = typeToSort(to) + + val name = FreshIdentifier("dynLambda").uniqueName + z3.mkFreshFuncDecl(name, sortSeq, returnSort) + } + z3.mkApp(funDecl, (caller +: args).map(rec): _*) case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) @@ -563,8 +569,25 @@ trait AbstractZ3Solver extends Solver { def rec(t: Z3AST, tpe: TypeTree): Expr = { val kind = z3.getASTKind(t) - kind match { - case Z3NumeralIntAST(Some(v)) => { + (kind, tpe) match { + case (Z3NumeralIntAST(Some(v)), ft @ FunctionType(fts, tt)) => lambdas.getB(ft) match { + case None => throw new IllegalArgumentException + case Some(decl) => model.getModelFuncInterpretations.find(_._1 == decl) match { + case None => throw new IllegalArgumentException + case Some((_, mapping, elseValue)) => + val lambdaID = InfiniteIntegerLiteral(v) + val leonElseValue = rec(elseValue, tt) + PartialLambda(mapping.flatMap { case (z3Args, z3Result) => + z3.getASTKind(z3Args.head) match { + case Z3NumeralIntAST(Some(v)) if InfiniteIntegerLiteral(v) == lambdaID => + List((z3Args.tail zip fts).map(p => rec(p._1, p._2)) -> rec(z3Result, tt)) + case _ => Nil + } + }, Some(leonElseValue), ft) + } + } + + case (Z3NumeralIntAST(Some(v)), _) => val leading = t.toString.substring(0, 2 min t.toString.length) if(leading == "#x") { _root_.smtlib.common.Hexadecimal.fromString(t.toString.substring(2)) match { @@ -580,8 +603,8 @@ trait AbstractZ3Solver extends Solver { } else { InfiniteIntegerLiteral(v) } - } - case Z3NumeralIntAST(None) => { + + case (Z3NumeralIntAST(None), _) => _root_.smtlib.common.Hexadecimal.fromString(t.toString.substring(2)) match { case Some(hexa) => tpe match { @@ -591,9 +614,10 @@ trait AbstractZ3Solver extends Solver { } case None => unsound(t, "could not translate Z3NumeralIntAST numeral") } - } - case Z3NumeralRealAST(n: BigInt, d: BigInt) => FractionalLiteral(n, d) - case Z3AppAST(decl, args) => + + case (Z3NumeralRealAST(n: BigInt, d: BigInt), _) => FractionalLiteral(n, d) + + case (Z3AppAST(decl, args), _) => val argsSize = args.size if(argsSize == 0 && (variables containsB t)) { variables.toA(t) diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index c129243d15b661d9e42d578447fdea6cb1aea3e0..6a96848c5fb283ec1a6ae22817afff95f9300122 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -136,23 +136,21 @@ class FairZ3Solver(val context: LeonContext, val program: Program) private val freeVars = new IncrementalSet[Identifier]() private val constraints = new IncrementalSeq[Expr]() - val unrollingBank = new UnrollingBank(context, templateGenerator) + private val incrementals: List[IncrementalState] = List( + errors, freeVars, constraints, functions, generics, lambdas, sorts, variables, + constructors, selectors, testers, unrollingBank + ) + def push() { - errors.push() solver.push() - unrollingBank.push() - freeVars.push() - constraints.push() + incrementals.foreach(_.push()) } def pop() { - errors.pop() solver.pop(1) - unrollingBank.pop() - freeVars.pop() - constraints.pop() + incrementals.foreach(_.pop()) } override def check: Option[Boolean] = { diff --git a/src/main/scala/leon/transformations/InstrumentationUtil.scala b/src/main/scala/leon/transformations/InstrumentationUtil.scala index 2c461e466298614ebae30df8a33ff901fea72fa5..1c3b744d8f667b3a1bb7b653ff299868fe98b38a 100644 --- a/src/main/scala/leon/transformations/InstrumentationUtil.scala +++ b/src/main/scala/leon/transformations/InstrumentationUtil.scala @@ -63,7 +63,7 @@ object InstUtil { val vary = yid.toVariable val args = Seq(xid, yid) val maxType = FunctionType(Seq(IntegerType, IntegerType), IntegerType) - val mfd = new FunDef(FreshIdentifier("max", maxType, false), Seq(), args.map((arg) => ValDef(arg, Some(arg.getType))), IntegerType) + val mfd = new FunDef(FreshIdentifier("max", maxType, false), Seq(), args.map(arg => ValDef(arg)), IntegerType) val cond = GreaterEquals(varx, vary) mfd.body = Some(IfExpr(cond, varx, vary)) diff --git a/src/main/scala/leon/transformations/IntToRealProgram.scala b/src/main/scala/leon/transformations/IntToRealProgram.scala index 21229282b398fcc10cd9548ddaead2f25d8087a3..5ab16991398e5dcb4352e870795fff51045fb9c9 100644 --- a/src/main/scala/leon/transformations/IntToRealProgram.scala +++ b/src/main/scala/leon/transformations/IntToRealProgram.scala @@ -72,8 +72,7 @@ abstract class ProgramTypeTransformer { } def mapDecl(decl: ValDef): ValDef = { - val newtpe = mapType(decl.getType) - new ValDef(mapId(decl.id), Some(newtpe)) + decl.copy(id = mapId(decl.id)) } def mapType(tpe: TypeTree): TypeTree = { @@ -141,9 +140,9 @@ abstract class ProgramTypeTransformer { // FIXME //add a new postcondition newfd.fullBody = if (fd.postcondition.isDefined && newfd.body.isDefined) { - val Lambda(Seq(ValDef(resid, _)), pexpr) = fd.postcondition.get + val Lambda(Seq(ValDef(resid, lzy)), pexpr) = fd.postcondition.get val tempRes = mapId(resid).toVariable - Ensuring(newfd.body.get, Lambda(Seq(ValDef(tempRes.id, Some(tempRes.getType))), transformExpr(pexpr))) + Ensuring(newfd.body.get, Lambda(Seq(ValDef(tempRes.id, lzy)), transformExpr(pexpr))) // Some(mapId(resid), transformExpr(pexpr)) } else NoTree(fd.returnType) @@ -233,4 +232,4 @@ class RealToIntProgram extends ProgramTypeTransformer { } def mappedFun(fd: FunDef): FunDef = newFundefs(fd) -} \ No newline at end of file +} diff --git a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala index d17968dc05b789b936554b9c10dbde22ac360dfb..a696dae0580baf44c6ef955c72ccdcce9791bdab 100644 --- a/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala +++ b/src/main/scala/leon/transformations/NonlinearityEliminationPhase.scala @@ -26,7 +26,7 @@ object MultFuncs { val vary = yid.toVariable val args = Seq(xid, yid) val funcType = FunctionType(Seq(domain, domain), domain) - val mfd = new FunDef(FreshIdentifier("pmult", funcType, false), Seq(), args.map((arg) => ValDef(arg, Some(arg.getType))), domain) + val mfd = new FunDef(FreshIdentifier("pmult", funcType, false), Seq(), args.map(arg => ValDef(arg)), domain) val tmfd = TypedFunDef(mfd, Seq()) //define a body (a) using mult(x,y) = if(x == 0 || y ==0) 0 else mult(x-1,y) + y @@ -47,7 +47,7 @@ object MultFuncs { val post1 = Implies(guard, defn2) // mfd.postcondition = Some((resvar.id, And(Seq(post0, post1)))) - mfd.fullBody = Ensuring(mfd.body.get, Lambda(Seq(ValDef(resvar.id, Some(resvar.getType))), And(Seq(post0, post1)))) + mfd.fullBody = Ensuring(mfd.body.get, Lambda(Seq(ValDef(resvar.id)), And(Seq(post0, post1)))) //set function properties (for now, only monotonicity) mfd.addFlags(Set(Annotation("theoryop", Seq()), Annotation("monotonic", Seq()))) //"distributive" ? mfd @@ -59,7 +59,7 @@ object MultFuncs { val yid = FreshIdentifier("y", domain) val args = Seq(xid, yid) val funcType = FunctionType(Seq(domain, domain), domain) - val fd = new FunDef(FreshIdentifier("mult", funcType, false), Seq(), args.map((arg) => ValDef(arg, Some(arg.getType))), domain) + val fd = new FunDef(FreshIdentifier("mult", funcType, false), Seq(), args.map(arg => ValDef(arg)), domain) val tpivMultFun = TypedFunDef(pivMultFun, Seq()) //the body is defined as mult(x,y) = val px = if(x < 0) -x else x; diff --git a/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala b/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala index 87c5795838b460314d19805cbb8ca6caae4ee233..c4830cf40ca429358b0d2e420bbbe7f7d1f5c276 100644 --- a/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala +++ b/src/main/scala/leon/transformations/SerialInstrumentationPhase.scala @@ -111,7 +111,7 @@ class SerialInstrumenter(ctx: LeonContext, program: Program) { def mapPost(pred: Expr, from: FunDef, to: FunDef) = { pred match { - case Lambda(Seq(ValDef(fromRes, _)), postCond) if (instFuncs.contains(from)) => + case Lambda(Seq(ValDef(fromRes, lzy)), postCond) if (instFuncs.contains(from)) => val toResId = FreshIdentifier(fromRes.name, to.returnType, true) val newpost = postMap((e: Expr) => e match { case Variable(`fromRes`) => @@ -124,7 +124,7 @@ class SerialInstrumenter(ctx: LeonContext, program: Program) { case _ => None })(postCond) - Lambda(Seq(ValDef(toResId)), mapExpr(newpost)) + Lambda(Seq(ValDef(toResId, lzy)), mapExpr(newpost)) case _ => mapExpr(pred) } @@ -489,4 +489,4 @@ abstract class Instrumenter(program: Program, si: SerialInstrumenter) { */ def instrumentMatchCase(me: MatchExpr, mc: MatchCase, caseExprCost: Expr, scrutineeCost: Expr): Expr -} \ No newline at end of file +} diff --git a/src/test/resources/regression/verification/purescala/invalid/CallByName1.scala b/src/test/resources/regression/verification/purescala/invalid/CallByName1.scala new file mode 100644 index 0000000000000000000000000000000000000000..c96ab1617e254c19d8b08a8ecbd818d9cdc305e4 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/CallByName1.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object CallByName1 { + def byName1(i: Int, a: => Int): Int = { + if (i > 0) a + 1 + else 0 + } + + def byName2(i: Int, a: => Int): Int = { + if (i > 0) byName1(i - 1, a) + 2 + else 0 + } + + def test(): Boolean = { + byName1(1, byName2(3, 0)) == 0 && byName1(1, byName2(3, 0)) == 1 + }.holds +} diff --git a/src/test/resources/regression/verification/purescala/valid/CallByName1.scala b/src/test/resources/regression/verification/purescala/valid/CallByName1.scala new file mode 100644 index 0000000000000000000000000000000000000000..912acdc481e8191bd072bd7eaeb13684cd93c384 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/CallByName1.scala @@ -0,0 +1,9 @@ +import leon.lang._ + +object CallByName1 { + def add(a: => Int, b: => Int): Int = a + b + + def test(): Int = { + add(1,2) + } ensuring (_ == 3) +}