From 63376f70fcb5409203286fa7d0abcec5de15dda9 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Wed, 10 Feb 2016 14:15:50 +0100 Subject: [PATCH] Moar quantifier fixes --- .../scala/leon/codegen/CodeGeneration.scala | 13 ++----------- .../scala/leon/codegen/CompilationUnit.scala | 17 +++++++++++++++++ .../leon/evaluators/RecursiveEvaluator.scala | 16 +++++++++------- src/main/scala/leon/purescala/ExprOps.scala | 6 +++--- src/main/scala/leon/purescala/Extractors.scala | 10 +++++----- src/main/scala/leon/purescala/TypeOps.scala | 6 ++++++ .../leon/solvers/smtlib/SMTLIBTarget.scala | 8 ++++++-- .../templates/QuantificationManager.scala | 3 +-- .../scala/leon/solvers/z3/FairZ3Solver.scala | 2 -- 9 files changed, 49 insertions(+), 32 deletions(-) diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index cc85fe1a5..9030ba304 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -8,6 +8,7 @@ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps._ import purescala.Constructors._ import purescala.Extractors._ import purescala.Quantification._ @@ -223,16 +224,6 @@ trait CodeGeneration { ch.freeze } - private def typeParameters(expr: Expr): Seq[TypeParameter] = { - var tparams: Set[TypeParameter] = Set.empty - def extractParameters(tpe: TypeTree): Unit = tpe match { - case tp: TypeParameter => tparams += tp - case NAryType(tps, _) => tps.foreach(extractParameters) - } - preTraversal(e => extractParameters(e.getType))(expr) - tparams.toSeq.sortBy(_.id.uniqueName) - } - private[codegen] val lambdaToClass = scala.collection.mutable.Map.empty[Lambda, String] private[codegen] val classToLambda = scala.collection.mutable.Map.empty[String, Lambda] @@ -241,7 +232,7 @@ trait CodeGeneration { val reverseSubst = structSubst.map(p => p._2 -> p._1) val nl = normalized.asInstanceOf[Lambda] - val tparams: Seq[TypeParameter] = typeParameters(nl) + val tparams: Seq[TypeParameter] = typeParamsOf(nl).toSeq.sortBy(_.id.uniqueName) val closedVars = purescala.ExprOps.variablesOf(nl).toSeq.sortBy(_.uniqueName) val closuresWithoutMonitor = closedVars.map(id => id -> typeToJVM(id.getType)) diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index ff3c4af10..707eae8fa 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -8,6 +8,7 @@ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps._ import purescala.Extractors._ import purescala.Constructors._ import utils.UniqueCounter @@ -248,6 +249,22 @@ class CompilationUnit(val ctx: LeonContext, } l + case l @ Lambda(args, body) => + val (afName, closures, tparams, consSig) = compileLambda(l) + val args = closures.map { case (id, _) => + if (id == monitorID) monitor + else if (id == tpsID) typeParamsOf(l).toSeq.sortBy(_.id.uniqueName).map(registerType).toArray + else throw CompilationException(s"Unexpected closure $id in Lambda compilation") + } + + val lc = loader.loadClass(afName) + val conss = lc.getConstructors.sortBy(_.getParameterTypes.length) + println(conss) + assert(conss.nonEmpty) + val lambdaConstructor = conss.last + println(args.toArray) + lambdaConstructor.newInstance(args.toArray : _*).asInstanceOf[AnyRef] + case f @ IsTyped(FiniteArray(elems, default, IntLiteral(length)), ArrayType(underlying)) => if (length < 0) { throw LeonFatalError( diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 4bbdecc1d..2ebb5ade5 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -501,13 +501,15 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int FiniteSet(els.map(e), base) case l @ Lambda(_, _) => - val (nl, structSubst) = normalizeStructure(matchToIfThenElse(l)) - val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap - val newLambda = replaceFromIDs(mapping, nl).asInstanceOf[Lambda] - if (!gctx.lambdas.isDefinedAt(newLambda)) { - gctx.lambdas += (newLambda -> nl.asInstanceOf[Lambda]) - } - newLambda + val mapping = variablesOf(l).map(id => id -> e(Variable(id))).toMap + val newLambda = replaceFromIDs(mapping, l).asInstanceOf[Lambda] + val (normalized, _) = normalizeStructure(matchToIfThenElse(newLambda)) + val nl = normalized.asInstanceOf[Lambda] + if (!gctx.lambdas.isDefinedAt(nl)) { + val (norm, _) = normalizeStructure(matchToIfThenElse(l)) + gctx.lambdas += (nl -> norm.asInstanceOf[Lambda]) + } + nl case FiniteLambda(mapping, dflt, tpe) => FiniteLambda(mapping.map(p => p._1.map(e) -> e(p._2)), e(dflt), tpe) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index af5a04eb8..2672cfe7f 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -1994,8 +1994,8 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { Let(i, e, apply(b, args)) case LetTuple(is, es, b) => letTuple(is, es, apply(b, args)) - case l @ Lambda(params, body) => - l.withParamSubst(args, body) + //case l @ Lambda(params, body) => + // l.withParamSubst(args, body) case _ => Application(expr, args) } @@ -2017,7 +2017,7 @@ object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { case Application(caller, args) => val newArgs = args.map(rec(_, true)) val newCaller = rec(caller, false) - extract(application(newCaller, newArgs), build) + extract(Application(newCaller, newArgs), build) case FunctionInvocation(fd, args) => val newArgs = args.map(rec(_, true)) extract(FunctionInvocation(fd, newArgs), build) diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index 00541f445..49e6afd3a 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -145,7 +145,7 @@ object Extractors { Some(Seq(t1, t2), (es: Seq[Expr]) => SetUnion(es(0), es(1))) case SetDifference(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => SetDifference(es(0), es(1))) - case mg@MapApply(t1, t2) => + case mg @ MapApply(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => MapApply(es(0), es(1))) case MapUnion(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => MapUnion(es(0), es(1))) @@ -165,9 +165,9 @@ object Extractors { Some(Seq(const, body), (es: Seq[Expr]) => Assert(es(0), oerr, es(1))) /* Other operators */ - case fi@FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _))) - case mi@MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail))) - case fa@Application(caller, args) => Some(caller +: args, as => application(as.head, as.tail)) + case fi @ FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _))) + case mi @ MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail))) + case fa @ Application(caller, args) => Some(caller +: args, as => Application(as.head, as.tail)) case CaseClass(cd, args) => Some((args, CaseClass(cd, _))) case And(args) => Some((args, and)) case Or(args) => Some((args, or)) @@ -197,7 +197,7 @@ object Extractors { val l = as.length nonemptyArray(as.take(l - 2), Some((as(l - 2), as(l - 1)))) })) - case na@NonemptyArray(elems, None) => + case na @ NonemptyArray(elems, None) => val ArrayType(tpe) = na.getType val (indexes, elsOrdered) = elems.toSeq.unzip diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 3310a8a8b..fedac58ec 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -23,6 +23,12 @@ object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree })(t) } + def typeParamsOf(expr: Expr): Set[TypeParameter] = { + var tparams: Set[TypeParameter] = Set.empty + ExprOps.preTraversal(e => typeParamsOf(e.getType))(expr) + tparams + } + def canBeSubtypeOf( tpe: TypeTree, freeParams: Seq[TypeParameter], diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index d239de90d..4b8386949 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -20,7 +20,7 @@ import _root_.smtlib.printer.{ RecursivePrinter => SMTPrinter } import _root_.smtlib.parser.Commands.{ Constructor => SMTConstructor, FunDef => SMTFunDef, - Assert => _, + Assert => SMTAssert, _ } import _root_.smtlib.parser.Terms.{ @@ -533,7 +533,11 @@ trait SMTLIBTarget extends Interruptible { case gv @ GenericValue(tpe, n) => genericValues.cachedB(gv) { - declareVariable(FreshIdentifier("gv" + n, tpe)) + val v = declareVariable(FreshIdentifier("gv" + n, tpe)) + for ((ogv, ov) <- genericValues.aToB if ogv.getType == tpe) { + emit(SMTAssert(Core.Not(Core.Equals(v, ov)))) + } + v } /** diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index b60ac4eb9..84f8b707f 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -441,8 +441,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case Left(quant) if quantified(quant) => subst += quant -> arg case Right(qam) => - val argVal = arg.encoded - eqConstraints += (qam.encoded -> argVal) + eqConstraints += (qam.encoded -> arg.encoded) } val substituter = encoder.substitute(subst.mapValues(_.encoded)) diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 1fb008834..ca1b929e1 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -58,14 +58,12 @@ class FairZ3Solver(val context: LeonContext, val program: Program) r } - /* override def solverCheckAssumptions[R](assumptions: Seq[Z3AST])(block: Option[Boolean] => R): R = { solver.push() // FIXME: remove when z3 bug is fixed val res = solver.checkAssumptions(assumptions : _*) solver.pop() // FIXME: remove when z3 bug is fixed block(res) } - */ def solverGetModel: ModelWrapper = new ModelWrapper { val model = solver.getModel -- GitLab