From 827d8f01f8975810027b8f90fd92d53641b6e186 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Tue, 9 Feb 2016 11:34:05 +0100 Subject: [PATCH] Added generics to codegen evaluator --- .../scala/leon/codegen/runtime/Monitor.scala | 18 ++++++++++++------ .../scala/leon/evaluators/DualEvaluator.scala | 10 +++++----- .../leon/evaluators/RecursiveEvaluator.scala | 2 ++ src/main/scala/leon/purescala/TypeOps.scala | 2 +- .../solvers/combinators/UnrollingSolver.scala | 16 +++++++++++----- .../templates/QuantificationManager.scala | 12 ++++++++---- 6 files changed, 39 insertions(+), 21 deletions(-) diff --git a/src/main/scala/leon/codegen/runtime/Monitor.scala b/src/main/scala/leon/codegen/runtime/Monitor.scala index ac88c35be..64d80eaa8 100644 --- a/src/main/scala/leon/codegen/runtime/Monitor.scala +++ b/src/main/scala/leon/codegen/runtime/Monitor.scala @@ -115,15 +115,21 @@ class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Id val newTypes = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) val tpMap = (tparams.map(TypeParameterDef(_)) zip newTypes).toMap - val vars = (variablesOf(p.pc) ++ variablesOf(p.phi)).toSeq.sortBy(_.uniqueName) - val newVars = vars.map(id => FreshIdentifier(id.name, instantiateType(id.getType, tpMap), true)) + val newXs = p.xs.map { id => + val newTpe = instantiateType(id.getType, tpMap) + if (id.getType == newTpe) id else FreshIdentifier(id.name, newTpe, true) + } + + val newAs = p.as.map { id => + val newTpe = instantiateType(id.getType, tpMap) + if (id.getType == newTpe) id else FreshIdentifier(id.name, newTpe, true) + } - val args = p.as.map(id => FreshIdentifier(id.name, instantiateType(id.getType, tpMap), true)) - val inputsMap = (args zip inputs).map { + val inputsMap = (newAs zip inputs).map { case (id, v) => Equals(Variable(id), unit.jvmToValue(v, id.getType)) } - val expr = instantiateType(and(p.pc, p.phi), tpMap, (vars zip newVars).toMap) + val expr = instantiateType(and(p.pc, p.phi), tpMap, (p.as zip newAs).toMap ++ (p.xs zip newXs)) solver.assertCnstr(andJoin(expr +: inputsMap)) try { @@ -133,7 +139,7 @@ class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Id val valModel = valuateWithModel(model) _ - val res = p.xs.map(valModel) + val res = newXs.map(valModel) val leonRes = tupleWrap(res) val total = System.currentTimeMillis-tStart diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index e387e4f0c..6fc5b856d 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -13,14 +13,12 @@ import codegen.runtime.{StdMonitor, Monitor} class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) extends RecursiveEvaluator(ctx, prog, params.maxFunctionInvocations) - with HasDefaultGlobalContext -{ + with HasDefaultGlobalContext { type RC = DualRecContext def initRC(mappings: Map[Identifier, Expr]): RC = DualRecContext(mappings) implicit val debugSection = utils.DebugSectionEvaluation - val unit = new CompilationUnit(ctx, prog, params) var monitor: Monitor = new StdMonitor(unit, params.maxFunctionInvocations, Map()) @@ -39,7 +37,10 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) val (className, methodName, _) = unit.leonFunDefToJVMInfo(tfd.fd).get - val allArgs = if (params.requireMonitor) monitor +: args else args + val allArgs = + (if (params.requireMonitor) Seq(monitor) else Seq()) ++ + (if (tfd.fd.tparams.nonEmpty) Seq(tfd.tps.map(unit.registerType(_)).toArray) else Seq()) ++ + args ctx.reporter.debug(s"Calling $className.$methodName(${args.mkString(",")})") @@ -126,7 +127,6 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) } } - override def eval(ex: Expr, model: solvers.Model) = { monitor = unit.getMonitor(model, params.maxFunctionInvocations) super.eval(ex, model) diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 22aa09a4d..e7ff08ba6 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -567,6 +567,8 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int ctx.reporter.debug("Verification took "+total+"ms") ctx.reporter.debug("Finished forall evaluation with: "+res) + println(fargs.map(_.id),replaceFromIDs(mapping, body)) + println(res) frlCache += (f, context) -> res res case _ => diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 6e694843c..3310a8a8b 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -219,7 +219,7 @@ object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree id } } - + def instantiateType(id: Identifier, tps: Map[TypeParameterDef, TypeTree]): Identifier = { freshId(id, typeParamSubst(tps map { case (tpd, tp) => tpd.tp -> tp })(id.getType)) } diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 10447122b..6d7af0ba4 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -263,6 +263,7 @@ trait AbstractUnrollingSolver[T] private def getTotalModel: Model = { val wrapped = solverGetModel + println(wrapped) val typeInsts = templateGenerator.manager.typeInstantiations val partialInsts = templateGenerator.manager.partialInstantiations @@ -310,12 +311,17 @@ trait AbstractUnrollingSolver[T] if (mapping.isDefinedAt(conds)) mapping else mapping + (conds -> result) } - val rest :+ ((_, dflt)) = filteredConds.toSeq.sortBy(_._1.size) - val body = rest.foldLeft(dflt) { case (elze, (conds, res)) => - if (conds.isEmpty) elze else IfExpr(andJoin(conds), res, elze) - } + if (filteredConds.isEmpty) { + // TODO: warning?? + value + } else { + val rest :+ ((_, dflt)) = filteredConds.toSeq.sortBy(_._1.size) + val body = rest.foldLeft(dflt) { case (elze, (conds, res)) => + if (conds.isEmpty) elze else IfExpr(andJoin(conds), res, elze) + } - Lambda(params.map(ValDef(_)), body) + Lambda(params.map(ValDef(_)), body) + } case _ => value }) diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index ea5369e8c..4765060fd 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -422,14 +422,19 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage private def extractSubst(mapping: Set[(Set[T], Matcher[T], Matcher[T])]): (Set[T], Map[T,Arg[T]], Boolean) = { var constraints: Set[T] = Set.empty var eqConstraints: Set[(T, T)] = Set.empty - var matcherEqs: List[(T, T)] = Nil var subst: Map[T, Arg[T]] = Map.empty + var matcherEqs: Set[(T, T)] = Set.empty + def strictnessCnstr(qarg: Arg[T], arg: Arg[T]): Unit = (qarg, arg) match { + case (Right(qam), Right(am)) => (qam.args zip am.args).foreach(p => strictnessCnstr(p._1, p._2)) + case _ => matcherEqs += qarg.encoded -> arg.encoded + } + for { (bs, qm @ Matcher(qcaller, _, qargs, _), m @ Matcher(caller, _, args, _)) <- mapping - _ = constraints ++= bs - _ = matcherEqs :+= qm.encoded -> m.encoded + _ = constraints ++= bs + encoder.mkEquals(qcaller, caller) (qarg, arg) <- (qargs zip args) + _ = strictnessCnstr(qarg, arg) } qarg match { case Left(quant) if subst.isDefinedAt(quant) => eqConstraints += (quant -> arg.encoded) @@ -438,7 +443,6 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage case Right(qam) => val argVal = arg.encoded eqConstraints += (qam.encoded -> argVal) - matcherEqs :+= qam.encoded -> argVal } val substituter = encoder.substitute(subst.mapValues(_.encoded)) -- GitLab