Skip to content
Snippets Groups Projects
Commit c727548f authored by Manos Koukoutos's avatar Manos Koukoutos
Browse files

Merge ND-evaluator into master

parents cd7a4652 9a418d03
No related branches found
No related tags found
No related merge requests found
Showing
with 923 additions and 296 deletions
/* Copyright 2009-2015 EPFL, Lausanne */
package leon.codegen.runtime;
import java.util.HashMap;
public abstract class Forall {
private static final HashMap<Tuple, Boolean> cache = new HashMap<Tuple, Boolean>();
protected final LeonCodeGenRuntimeHenkinMonitor monitor;
protected final Tuple closures;
protected final boolean check;
public Forall(LeonCodeGenRuntimeMonitor monitor, Tuple closures) throws LeonCodeGenEvaluationException {
if (!(monitor instanceof LeonCodeGenRuntimeHenkinMonitor))
throw new LeonCodeGenEvaluationException("Can't evaluate foralls without domain");
this.monitor = (LeonCodeGenRuntimeHenkinMonitor) monitor;
this.closures = closures;
this.check = (boolean) closures.get(closures.getArity() - 1);
}
public boolean check() {
Tuple key = new Tuple(new Object[] { getClass(), monitor, closures }); // check is in the closures
if (cache.containsKey(key)) {
return cache.get(key);
} else {
boolean res = checkForall();
cache.put(key, res);
return res;
}
}
public abstract boolean checkForall();
}
...@@ -4,4 +4,6 @@ package leon.codegen.runtime; ...@@ -4,4 +4,6 @@ package leon.codegen.runtime;
public abstract class Lambda { public abstract class Lambda {
public abstract Object apply(Object[] args) throws LeonCodeGenRuntimeException; public abstract Object apply(Object[] args) throws LeonCodeGenRuntimeException;
public abstract void checkForall(boolean[] quantified);
public abstract void checkAxiom();
} }
/* Copyright 2009-2015 EPFL, Lausanne */
package leon.codegen.runtime;
/** Such exceptions are thrown when the evaluator encounters a forall
* expression whose shape is not supported in Leon. */
public class LeonCodeGenQuantificationException extends Exception {
private static final long serialVersionUID = -1824885321497473916L;
public LeonCodeGenQuantificationException(String msg) {
super(msg);
}
}
...@@ -7,26 +7,42 @@ import java.util.LinkedList; ...@@ -7,26 +7,42 @@ import java.util.LinkedList;
import java.util.HashMap; import java.util.HashMap;
public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor {
private final HashMap<Integer, List<Tuple>> domains = new HashMap<Integer, List<Tuple>>(); private final HashMap<Integer, List<Tuple>> tpes = new HashMap<Integer, List<Tuple>>();
private final HashMap<Class<?>, List<Tuple>> lambdas = new HashMap<Class<?>, List<Tuple>>();
public final boolean checkForalls;
public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations, boolean checkForalls) {
super(maxInvocations); super(maxInvocations);
this.checkForalls = checkForalls;
}
public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) {
this(maxInvocations, false);
} }
public void add(int type, Tuple input) { public void add(int type, Tuple input) {
if (!domains.containsKey(type)) domains.put(type, new LinkedList<Tuple>()); if (!tpes.containsKey(type)) tpes.put(type, new LinkedList<Tuple>());
domains.get(type).add(input); tpes.get(type).add(input);
}
public void add(Class<?> clazz, Tuple input) {
if (!lambdas.containsKey(clazz)) lambdas.put(clazz, new LinkedList<Tuple>());
lambdas.get(clazz).add(input);
} }
public List<Tuple> domain(Object obj, int type) { public List<Tuple> domain(Object obj, int type) {
List<Tuple> domain = new LinkedList<Tuple>(); List<Tuple> domain = new LinkedList<Tuple>();
if (obj instanceof PartialLambda) { if (obj instanceof PartialLambda) {
for (Tuple key : ((PartialLambda) obj).mapping.keySet()) { PartialLambda l = (PartialLambda) obj;
for (Tuple key : l.mapping.keySet()) {
domain.add(key); domain.add(key);
} }
} else if (obj instanceof Lambda) {
List<Tuple> lambdaDomain = lambdas.get(obj.getClass());
if (lambdaDomain != null) domain.addAll(lambdaDomain);
} }
List<Tuple> tpeDomain = domains.get(type); List<Tuple> tpeDomain = tpes.get(type);
if (tpeDomain != null) domain.addAll(tpeDomain); if (tpeDomain != null) domain.addAll(tpeDomain);
return domain; return domain;
......
...@@ -6,9 +6,15 @@ import java.util.HashMap; ...@@ -6,9 +6,15 @@ import java.util.HashMap;
public final class PartialLambda extends Lambda { public final class PartialLambda extends Lambda {
final HashMap<Tuple, Object> mapping = new HashMap<Tuple, Object>(); final HashMap<Tuple, Object> mapping = new HashMap<Tuple, Object>();
private final Object dflt;
public PartialLambda() { public PartialLambda() {
this(null);
}
public PartialLambda(Object dflt) {
super(); super();
this.dflt = dflt;
} }
public void add(Tuple key, Object value) { public void add(Tuple key, Object value) {
...@@ -20,15 +26,18 @@ public final class PartialLambda extends Lambda { ...@@ -20,15 +26,18 @@ public final class PartialLambda extends Lambda {
Tuple tuple = new Tuple(args); Tuple tuple = new Tuple(args);
if (mapping.containsKey(tuple)) { if (mapping.containsKey(tuple)) {
return mapping.get(tuple); return mapping.get(tuple);
} else if (dflt != null) {
return dflt;
} else { } else {
throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments"); throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments " + tuple);
} }
} }
@Override @Override
public boolean equals(Object that) { public boolean equals(Object that) {
if (that != null && (that instanceof PartialLambda)) { if (that != null && (that instanceof PartialLambda)) {
return mapping.equals(((PartialLambda) that).mapping); PartialLambda l = (PartialLambda) that;
return ((dflt != null && dflt.equals(l.dflt)) || (dflt == null && l.dflt == null)) && mapping.equals(l.mapping);
} else { } else {
return false; return false;
} }
...@@ -36,6 +45,12 @@ public final class PartialLambda extends Lambda { ...@@ -36,6 +45,12 @@ public final class PartialLambda extends Lambda {
@Override @Override
public int hashCode() { public int hashCode() {
return 63 + 11 * mapping.hashCode(); return 63 + 11 * mapping.hashCode() + (dflt == null ? 0 : dflt.hashCode());
} }
@Override
public void checkForall(boolean[] quantified) {}
@Override
public void checkAxiom() {}
} }
...@@ -54,4 +54,20 @@ public final class Tuple { ...@@ -54,4 +54,20 @@ public final class Tuple {
_hash = h; _hash = h;
return h; return h;
} }
@Override
public String toString() {
String str = "(";
boolean first = true;
for (Object obj : elements) {
if (first) {
first = false;
} else {
str += ", ";
}
str += obj == null ? "null" : obj.toString();
}
str += ")";
return str;
}
} }
This diff is collapsed.
...@@ -127,13 +127,24 @@ class CompilationUnit(val ctx: LeonContext, ...@@ -127,13 +127,24 @@ class CompilationUnit(val ctx: LeonContext,
conss.last conss.last
} }
def modelToJVM(model: solvers.Model, maxInvocations: Int): LeonCodeGenRuntimeMonitor = model match { def modelToJVM(model: solvers.Model, maxInvocations: Int, check: Boolean): LeonCodeGenRuntimeMonitor = model match {
case hModel: solvers.HenkinModel => case hModel: solvers.HenkinModel =>
val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations) val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations, check)
for ((tpe, domain) <- hModel.domains; args <- domain) { for ((lambda, domain) <- hModel.doms.lambdas) {
val (afName, _, _) = compileLambda(lambda)
val lc = loader.loadClass(afName)
for (args <- domain) {
// note here that it doesn't matter that `lhm` doesn't yet have its domains
// filled since all values in `args` should be grounded
val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple]
lhm.add(lc, inputJvm)
}
}
for ((tpe, domain) <- hModel.doms.tpes; args <- domain) {
val tpeId = typeId(tpe) val tpeId = typeId(tpe)
// note here that it doesn't matter that `lhm` doesn't yet have its domains // same remark as above about valueToJVM(_)(lhm)
// filled since all values in `args` should be grounded
val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple]
lhm.add(tpeId, inputJvm) lhm.add(tpeId, inputJvm)
} }
...@@ -201,8 +212,13 @@ class CompilationUnit(val ctx: LeonContext, ...@@ -201,8 +212,13 @@ class CompilationUnit(val ctx: LeonContext,
} }
m m
case f @ PartialLambda(mapping, _) => case f @ PartialLambda(mapping, dflt, _) =>
val l = new leon.codegen.runtime.PartialLambda() val l = if (dflt.isDefined) {
new leon.codegen.runtime.PartialLambda(dflt.get)
} else {
new leon.codegen.runtime.PartialLambda()
}
for ((ks,v) <- mapping) { for ((ks,v) <- mapping) {
// Force tuple even with 1/0 elems. // Force tuple even with 1/0 elems.
val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple] val kJvm = tupleConstructor.newInstance(ks.map(valueToJVM).toArray).asInstanceOf[leon.codegen.runtime.Tuple]
...@@ -530,3 +546,5 @@ class CompilationUnit(val ctx: LeonContext, ...@@ -530,3 +546,5 @@ class CompilationUnit(val ctx: LeonContext,
private [codegen] object exprCounter extends UniqueCounter[Unit] private [codegen] object exprCounter extends UniqueCounter[Unit]
private [codegen] object lambdaCounter extends UniqueCounter[Unit] private [codegen] object lambdaCounter extends UniqueCounter[Unit]
private [codegen] object forallCounter extends UniqueCounter[Unit]
...@@ -8,7 +8,7 @@ import purescala.Expressions._ ...@@ -8,7 +8,7 @@ import purescala.Expressions._
import cafebabe._ import cafebabe._
import runtime.{LeonCodeGenRuntimeMonitor => LM, LeonCodeGenRuntimeHenkinMonitor => LHM} import runtime.{LeonCodeGenRuntimeMonitor => LM}
import java.lang.reflect.InvocationTargetException import java.lang.reflect.InvocationTargetException
...@@ -51,9 +51,9 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, ...@@ -51,9 +51,9 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr,
} }
} }
def eval(model: solvers.Model) : Expr = { def eval(model: solvers.Model, check: Boolean = false) : Expr = {
try { try {
val monitor = unit.modelToJVM(model, params.maxFunctionInvocations) val monitor = unit.modelToJVM(model, params.maxFunctionInvocations, check)
evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor)
} catch { } catch {
case ite : InvocationTargetException => throw ite.getCause case ite : InvocationTargetException => throw ite.getCause
......
...@@ -33,12 +33,24 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { ...@@ -33,12 +33,24 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
b -> Constructor[Expr, TypeTree](List(), BooleanType, s => BooleanLiteral(b), ""+b) b -> Constructor[Expr, TypeTree](List(), BooleanType, s => BooleanLiteral(b), ""+b)
}).toMap }).toMap
val chars = (for (c <- Set('a', 'b', 'c', 'd')) yield {
c -> Constructor[Expr, TypeTree](List(), CharType, s => CharLiteral(c), ""+c)
}).toMap
val rationals = (for (n <- Set(0, 1, 2, 3); d <- Set(1,2,3,4)) yield {
(n, d) -> Constructor[Expr, TypeTree](List(), RealType, s => FractionalLiteral(n, d), "" + n + "/" + d)
}).toMap
def intConstructor(i: Int) = ints(i) def intConstructor(i: Int) = ints(i)
def bigIntConstructor(i: Int) = bigInts(i) def bigIntConstructor(i: Int) = bigInts(i)
def boolConstructor(b: Boolean) = booleans(b) def boolConstructor(b: Boolean) = booleans(b)
def charConstructor(c: Char) = chars(c)
def rationalConstructor(n: Int, d: Int) = rationals(n -> d)
def cPattern(c: Constructor[Expr, TypeTree], args: Seq[VPattern[Expr, TypeTree]]) = { def cPattern(c: Constructor[Expr, TypeTree], args: Seq[VPattern[Expr, TypeTree]]) = {
ConstructorPattern[Expr, TypeTree](c, args) ConstructorPattern[Expr, TypeTree](c, args)
} }
...@@ -50,7 +62,6 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { ...@@ -50,7 +62,6 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
getConstructors(t).head.copy(retType = act) getConstructors(t).head.copy(retType = act)
} }
private def getConstructors(t: TypeTree): List[Constructor[Expr, TypeTree]] = t match { private def getConstructors(t: TypeTree): List[Constructor[Expr, TypeTree]] = t match {
case UnitType => case UnitType =>
constructors.getOrElse(t, { constructors.getOrElse(t, {
...@@ -97,8 +108,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { ...@@ -97,8 +108,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
case mt @ MapType(from, to) => case mt @ MapType(from, to) =>
constructors.getOrElse(mt, { constructors.getOrElse(mt, {
val cs = for (size <- List(0, 1, 2, 5)) yield { val cs = for (size <- List(0, 1, 2, 5)) yield {
val subs = (1 to size).flatMap(i => List(from, to)).toList val subs = (1 to size).flatMap(i => List(from, to)).toList
Constructor[Expr, TypeTree](subs, mt, s => FiniteMap(s.grouped(2).map(t => (t(0), t(1))).toMap, from, to), mt.asString(ctx)+"@"+size) Constructor[Expr, TypeTree](subs, mt, s => FiniteMap(s.grouped(2).map(t => (t(0), t(1))).toMap, from, to), mt.asString(ctx)+"@"+size)
} }
constructors += mt -> cs constructors += mt -> cs
...@@ -110,13 +120,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { ...@@ -110,13 +120,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
val cs = for (size <- List(1, 2, 3, 5)) yield { val cs = for (size <- List(1, 2, 3, 5)) yield {
val subs = (1 to size).flatMap(_ => from :+ to).toList val subs = (1 to size).flatMap(_ => from :+ to).toList
Constructor[Expr, TypeTree](subs, ft, { s => Constructor[Expr, TypeTree](subs, ft, { s =>
val args = from.map(tpe => FreshIdentifier("x", tpe, true))
val argsTuple = tupleWrap(args.map(_.toVariable))
val grouped = s.grouped(from.size + 1).toSeq val grouped = s.grouped(from.size + 1).toSeq
val body = grouped.init.foldRight(grouped.last.last) { case (t, elze) => val mapping = grouped.init.map { case args :+ res => (args -> res) }
IfExpr(Equals(argsTuple, tupleWrap(t.init)), t.last, elze) PartialLambda(mapping, Some(grouped.last.last), ft)
}
Lambda(args.map(id => ValDef(id)), body)
}, ft.asString(ctx) + "@" + size) }, ft.asString(ctx) + "@" + size)
} }
constructors += ft -> cs constructors += ft -> cs
...@@ -166,6 +172,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { ...@@ -166,6 +172,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
case (b: java.lang.Boolean, BooleanType) => case (b: java.lang.Boolean, BooleanType) =>
(cPattern(boolConstructor(b), List()), true) (cPattern(boolConstructor(b), List()), true)
case (c: java.lang.Character, CharType) =>
(cPattern(charConstructor(c), List()), true)
case (cc: codegen.runtime.CaseClass, ct: ClassType) => case (cc: codegen.runtime.CaseClass, ct: ClassType) =>
val r = cc.__getRead() val r = cc.__getRead()
...@@ -193,7 +202,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { ...@@ -193,7 +202,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
(ConstructorPattern(c, elems.map(_._1)), elems.forall(_._2)) (ConstructorPattern(c, elems.map(_._1)), elems.forall(_._2))
case _ => case _ =>
ctx.reporter.error("Could not retreive type for :"+cc.getClass.getName) ctx.reporter.error("Could not retrieve type for :"+cc.getClass.getName)
(AnyPattern[Expr, TypeTree](), false) (AnyPattern[Expr, TypeTree](), false)
} }
...@@ -217,6 +226,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { ...@@ -217,6 +226,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
case (gv: GenericValue, t: TypeParameter) => case (gv: GenericValue, t: TypeParameter) =>
(cPattern(getConstructors(t)(gv.id-1), List()), true) (cPattern(getConstructors(t)(gv.id-1), List()), true)
case (v, t) => case (v, t) =>
ctx.reporter.debug("Unsupported value, can't paternify : "+v+" ("+v.getClass+") : "+t) ctx.reporter.debug("Unsupported value, can't paternify : "+v+" ("+v.getClass+") : "+t)
(AnyPattern[Expr, TypeTree](), false) (AnyPattern[Expr, TypeTree](), false)
...@@ -287,8 +297,8 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { ...@@ -287,8 +297,8 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator {
None None
}) })
val stubValues = ints.values ++ bigInts.values ++ booleans.values ++ chars.values ++ rationals.values
val gen = new StubGenerator[Expr, TypeTree]((ints.values ++ bigInts.values ++ booleans.values).toSeq, val gen = new StubGenerator[Expr, TypeTree](stubValues.toSeq,
Some(getConstructors _), Some(getConstructors _),
treatEmptyStubsAsChildless = true) treatEmptyStubsAsChildless = true)
......
...@@ -22,6 +22,9 @@ class AngelicEvaluator(underlying: NDEvaluator) ...@@ -22,6 +22,9 @@ class AngelicEvaluator(underlying: NDEvaluator)
case other@(RuntimeError(_) | EvaluatorError(_)) => case other@(RuntimeError(_) | EvaluatorError(_)) =>
other.asInstanceOf[Result[Nothing]] other.asInstanceOf[Result[Nothing]]
} }
/** Checks that `model |= expr` and that quantifications are all valid */
def check(expr: Expr, model: Model): CheckResult = underlying.check(expr, model)
} }
class DemonicEvaluator(underlying: NDEvaluator) class DemonicEvaluator(underlying: NDEvaluator)
...@@ -39,4 +42,7 @@ class DemonicEvaluator(underlying: NDEvaluator) ...@@ -39,4 +42,7 @@ class DemonicEvaluator(underlying: NDEvaluator)
case other@(RuntimeError(_) | EvaluatorError(_)) => case other@(RuntimeError(_) | EvaluatorError(_)) =>
other.asInstanceOf[Result[Nothing]] other.asInstanceOf[Result[Nothing]]
} }
/** Checks that `model |= expr` and that quantifications are all valid */
def check(expr: Expr, model: Model): CheckResult = underlying.check(expr, model)
} }
\ No newline at end of file
...@@ -8,9 +8,15 @@ import purescala.Definitions._ ...@@ -8,9 +8,15 @@ import purescala.Definitions._
import purescala.Expressions._ import purescala.Expressions._
import codegen.CompilationUnit import codegen.CompilationUnit
import codegen.CompiledExpression
import codegen.CodeGenParams import codegen.CodeGenParams
import leon.codegen.runtime.LeonCodeGenRuntimeException
import leon.codegen.runtime.LeonCodeGenEvaluationException
import leon.codegen.runtime.LeonCodeGenQuantificationException
class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) with DeterministicEvaluator { class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) with DeterministicEvaluator {
val name = "codegen-eval" val name = "codegen-eval"
val description = "Evaluator for PureScala expressions based on compilation to JVM" val description = "Evaluator for PureScala expressions based on compilation to JVM"
...@@ -19,9 +25,55 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva ...@@ -19,9 +25,55 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva
this(ctx, new CompilationUnit(ctx, prog, params)) this(ctx, new CompilationUnit(ctx, prog, params))
} }
private def compileExpr(expression: Expr, args: Seq[Identifier]): Option[CompiledExpression] = {
ctx.timers.evaluators.codegen.compilation.start()
try {
Some(unit.compileExpression(expression, args)(ctx))
} catch {
case t: Throwable =>
ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage)
None
} finally {
ctx.timers.evaluators.codegen.compilation.stop()
}
}
def check(expression: Expr, model: solvers.Model) : CheckResult = {
compileExpr(expression, model.toSeq.map(_._1)).map { ce =>
ctx.timers.evaluators.codegen.runtime.start()
try {
val res = ce.eval(model, check = true)
if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess
else EvaluationResults.CheckValidityFailure
} catch {
case e : ArithmeticException =>
EvaluationResults.CheckRuntimeFailure(e.getMessage)
case e : ArrayIndexOutOfBoundsException =>
EvaluationResults.CheckRuntimeFailure(e.getMessage)
case e : LeonCodeGenRuntimeException =>
EvaluationResults.CheckRuntimeFailure(e.getMessage)
case e : LeonCodeGenEvaluationException =>
EvaluationResults.CheckRuntimeFailure(e.getMessage)
case e : java.lang.ExceptionInInitializerError =>
EvaluationResults.CheckRuntimeFailure(e.getException.getMessage)
case so : java.lang.StackOverflowError =>
EvaluationResults.CheckRuntimeFailure("Stack overflow")
case e : LeonCodeGenQuantificationException =>
EvaluationResults.CheckQuantificationFailure(e.getMessage)
} finally {
ctx.timers.evaluators.codegen.runtime.stop()
}
}.getOrElse(EvaluationResults.CheckRuntimeFailure("Couldn't compile expression."))
}
def eval(expression: Expr, model: solvers.Model) : EvaluationResult = { def eval(expression: Expr, model: solvers.Model) : EvaluationResult = {
val toPairs = model.toSeq compile(expression, model.toSeq.map(_._1)).map { e =>
compile(expression, toPairs.map(_._1)).map { e =>
ctx.timers.evaluators.codegen.runtime.start() ctx.timers.evaluators.codegen.runtime.start()
val res = e(model) val res = e(model)
ctx.timers.evaluators.codegen.runtime.stop() ctx.timers.evaluators.codegen.runtime.stop()
...@@ -30,45 +82,30 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva ...@@ -30,45 +82,30 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva
} }
override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = {
import leon.codegen.runtime.LeonCodeGenRuntimeException compileExpr(expression, args).map(ce => (model: solvers.Model) => {
import leon.codegen.runtime.LeonCodeGenEvaluationException if (args.exists(arg => !model.isDefinedAt(arg))) {
EvaluationResults.EvaluatorError("Model undefined for free arguments")
ctx.timers.evaluators.codegen.compilation.start() } else try {
try { EvaluationResults.Successful(ce.eval(model))
val ce = unit.compileExpression(expression, args)(ctx) } catch {
case e : ArithmeticException =>
Some((model: solvers.Model) => { EvaluationResults.RuntimeError(e.getMessage)
if (args.exists(arg => !model.isDefinedAt(arg))) {
EvaluationResults.EvaluatorError("Model undefined for free arguments")
} else try {
EvaluationResults.Successful(ce.eval(model))
} catch {
case e : ArithmeticException =>
EvaluationResults.RuntimeError(e.getMessage)
case e : ArrayIndexOutOfBoundsException => case e : ArrayIndexOutOfBoundsException =>
EvaluationResults.RuntimeError(e.getMessage) EvaluationResults.RuntimeError(e.getMessage)
case e : LeonCodeGenRuntimeException => case e : LeonCodeGenRuntimeException =>
EvaluationResults.RuntimeError(e.getMessage) EvaluationResults.RuntimeError(e.getMessage)
case e : LeonCodeGenEvaluationException => case e : LeonCodeGenEvaluationException =>
EvaluationResults.EvaluatorError(e.getMessage) EvaluationResults.EvaluatorError(e.getMessage)
case e : java.lang.ExceptionInInitializerError => case e : java.lang.ExceptionInInitializerError =>
EvaluationResults.RuntimeError(e.getException.getMessage) EvaluationResults.RuntimeError(e.getException.getMessage)
case so : java.lang.StackOverflowError => case so : java.lang.StackOverflowError =>
EvaluationResults.RuntimeError("Stack overflow") EvaluationResults.RuntimeError("Stack overflow")
}
} })
})
} catch {
case t: Throwable =>
ctx.reporter.warning(expression.getPos, "Error while compiling expression: "+t.getMessage)
None
} finally {
ctx.timers.evaluators.codegen.compilation.stop()
}
} }
} }
...@@ -3,13 +3,11 @@ ...@@ -3,13 +3,11 @@
package leon package leon
package evaluators package evaluators
import leon.purescala.Extractors.{IsTyped, TopLevelAnds}
import purescala.Common._ import purescala.Common._
import purescala.Definitions._ import purescala.Definitions._
import purescala.Expressions._ import purescala.Expressions._
import purescala.Types._ import purescala.Types._
import purescala.Constructors._
import purescala.ExprOps._
import purescala.Quantification._
import solvers.{HenkinModel, Model} import solvers.{HenkinModel, Model}
abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps: Int) extends Evaluator(ctx, prog) with CEvalHelpers { abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps: Int) extends Evaluator(ctx, prog) with CEvalHelpers {
...@@ -20,17 +18,18 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps ...@@ -20,17 +18,18 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps
type GC <: GlobalContext type GC <: GlobalContext
def initRC(mappings: Map[Identifier, Expr]): RC def initRC(mappings: Map[Identifier, Expr]): RC
def initGC(model: solvers.Model): GC def initGC(model: solvers.Model, check: Boolean): GC
case class EvalError(msg : String) extends Exception case class EvalError(msg : String) extends Exception
case class RuntimeError(msg : String) extends Exception case class RuntimeError(msg : String) extends Exception
case class QuantificationError(msg: String) extends Exception
// Used by leon-web, please do not delete // Used by leon-web, please do not delete
var lastGC: Option[GC] = None var lastGC: Option[GC] = None
def eval(ex: Expr, model: Model) = { def eval(ex: Expr, model: Model) = {
try { try {
lastGC = Some(initGC(model)) lastGC = Some(initGC(model, check = true))
ctx.timers.evaluators.recursive.runtime.start() ctx.timers.evaluators.recursive.runtime.start()
EvaluationResults.Successful(e(ex)(initRC(model.toMap), lastGC.get)) EvaluationResults.Successful(e(ex)(initRC(model.toMap), lastGC.get))
} catch { } catch {
...@@ -47,6 +46,30 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps ...@@ -47,6 +46,30 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps
} }
} }
def check(ex: Expr, model: Model): CheckResult = {
assert(ex.getType == BooleanType, "Can't check non-boolean expression " + ex.asString)
try {
lastGC = Some(initGC(model, check = true))
ctx.timers.evaluators.recursive.runtime.start()
val res = e(ex)(initRC(model.toMap), lastGC.get)
if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess
else EvaluationResults.CheckValidityFailure
} catch {
case so: StackOverflowError =>
EvaluationResults.CheckRuntimeFailure("Stack overflow")
case e @ EvalError(msg) =>
EvaluationResults.CheckRuntimeFailure(msg)
case e @ RuntimeError(msg) =>
EvaluationResults.CheckRuntimeFailure(msg)
case jre: java.lang.RuntimeException =>
EvaluationResults.CheckRuntimeFailure(jre.getMessage)
case qe @ QuantificationError(msg) =>
EvaluationResults.CheckQuantificationFailure(msg)
} finally {
ctx.timers.evaluators.recursive.runtime.stop()
}
}
protected def e(expr: Expr)(implicit rctx: RC, gctx: GC): Value protected def e(expr: Expr)(implicit rctx: RC, gctx: GC): Value
def typeErrorMsg(tree : Expr, expected : TypeTree) : String = s"Type error : expected ${expected.asString}, found ${tree.asString}." def typeErrorMsg(tree : Expr, expected : TypeTree) : String = s"Type error : expected ${expected.asString}, found ${tree.asString}."
...@@ -55,60 +78,62 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps ...@@ -55,60 +78,62 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps
private[evaluators] trait CEvalHelpers { private[evaluators] trait CEvalHelpers {
this: ContextualEvaluator => this: ContextualEvaluator =>
def forallInstantiations(gctx:GC, fargs: Seq[ValDef], conj: Expr) = {
val henkinModel: HenkinModel = gctx.model match {
case hm: HenkinModel => hm
case _ => throw EvalError("Can't evaluate foralls without henkin model")
}
val vars = variablesOf(conj)
val args = fargs.map(_.id).filter(vars)
val quantified = args.toSet
val matcherQuorums = extractQuorums(conj, quantified) /* This is an effort to generalize forall to non-det. solvers
def forallInstantiations(gctx:GC, fargs: Seq[ValDef], conj: Expr) = {
matcherQuorums.flatMap { quorum => val henkinModel: HenkinModel = gctx.model match {
var mappings: Seq[(Identifier, Int, Int)] = Seq.empty case hm: HenkinModel => hm
var constraints: Seq[(Expr, Int, Int)] = Seq.empty case _ => throw EvalError("Can't evaluate foralls without henkin model")
}
for (((expr, args), qidx) <- quorum.zipWithIndex) { val vars = variablesOf(conj)
val (qmappings, qconstraints) = args.zipWithIndex.partition { val args = fargs.map(_.id).filter(vars)
case (Variable(id), aidx) => quantified(id) val quantified = args.toSet
case _ => false
}
mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) val matcherQuorums = extractQuorums(conj, quantified)
constraints ++= qconstraints.map(p => (p._1, qidx, p._2))
}
var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty matcherQuorums.flatMap { quorum =>
val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { var mappings: Seq[(Identifier, Int, Int)] = Seq.empty
val base :: others = es.toList.map(p => (p._2, p._3)) var constraints: Seq[(Expr, Int, Int)] = Seq.empty
equalities ++= others.map(p => base -> p)
(id -> base)
}
val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { for (((expr, args), qidx) <- quorum.zipWithIndex) {
case (acc, (expr, _)) => acc.flatMap(s => henkinModel.domain(expr).map(d => s :+ d)) val (qmappings, qconstraints) = args.zipWithIndex.partition {
case (Variable(id), aidx) => quantified(id)
case _ => false
} }
argSets.map { args => mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2))
val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { constraints ++= qconstraints.map(p => (p._1, qidx, p._2))
case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } }
}.toMap
var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty
val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield {
val base :: others = es.toList.map(p => (p._2, p._3))
equalities ++= others.map(p => base -> p)
(id -> base)
}
val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) {
case (acc, (expr, _)) => acc.flatMap(s => henkinModel.domain(expr).map(d => s :+ d))
}
argSets.map { args =>
val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap {
case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e }
}.toMap
val map = mapping.map { case (id, key) => id -> argMap(key) }
val enabler = andJoin(constraints.map {
case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx))
} ++ equalities.map {
case (k1, k2) => Equals(argMap(k1), argMap(k2))
})
(enabler, map)
}
}*/
val map = mapping.map { case (id, key) => id -> argMap(key) }
val enabler = andJoin(constraints.map {
case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx))
} ++ equalities.map {
case (k1, k2) => Equals(argMap(k1), argMap(k2))
})
(enabler, map)
}
}
}
} }
\ No newline at end of file
...@@ -8,4 +8,4 @@ import purescala.Definitions.Program ...@@ -8,4 +8,4 @@ import purescala.Definitions.Program
class DefaultEvaluator(ctx: LeonContext, prog: Program) class DefaultEvaluator(ctx: LeonContext, prog: Program)
extends RecursiveEvaluator(ctx, prog, 5000) extends RecursiveEvaluator(ctx, prog, 5000)
with HasDefaultGlobalContext with HasDefaultGlobalContext
with HasDefaultRecContext with HasDefaultRecContext
\ No newline at end of file
...@@ -17,7 +17,6 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) ...@@ -17,7 +17,6 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams)
type RC = DualRecContext type RC = DualRecContext
def initRC(mappings: Map[Identifier, Expr]): RC = DualRecContext(mappings) def initRC(mappings: Map[Identifier, Expr]): RC = DualRecContext(mappings)
implicit val debugSection = utils.DebugSectionEvaluation implicit val debugSection = utils.DebugSectionEvaluation
var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations)
......
...@@ -15,4 +15,21 @@ object EvaluationResults { ...@@ -15,4 +15,21 @@ object EvaluationResults {
/** Represents an evaluation that failed (in the evaluator). */ /** Represents an evaluation that failed (in the evaluator). */
case class EvaluatorError(message : String) extends Result(None) case class EvaluatorError(message : String) extends Result(None)
/** Results of checking proposition evaluation.
* Useful for verification of model validity in presence of quantifiers. */
sealed abstract class CheckResult(val success: Boolean)
/** Successful proposition evaluation (model |= expr) */
case object CheckSuccess extends CheckResult(true)
/** Check failed with `model |= !expr` */
case object CheckValidityFailure extends CheckResult(false)
/** Check failed due to evaluation or runtime errors.
* @see [[RuntimeError]] and [[EvaluatorError]] */
case class CheckRuntimeFailure(msg: String) extends CheckResult(false)
/** Check failed due to inconsistence of model with quantified propositions. */
case class CheckQuantificationFailure(msg: String) extends CheckResult(false)
} }
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
package leon package leon
package evaluators package evaluators
import leon.purescala.Types.TypeTree
import purescala.Common._ import purescala.Common._
import purescala.Definitions._ import purescala.Definitions._
import purescala.Expressions._ import purescala.Expressions._
...@@ -19,6 +18,7 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends ...@@ -19,6 +18,7 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends
type Value type Value
type EvaluationResult = EvaluationResults.Result[Value] type EvaluationResult = EvaluationResults.Result[Value]
type CheckResult = EvaluationResults.CheckResult
/** Evaluates an expression, using [[Model.mapping]] as a valuation function for the free variables. */ /** Evaluates an expression, using [[Model.mapping]] as a valuation function for the free variables. */
def eval(expr: Expr, model: Model) : EvaluationResult def eval(expr: Expr, model: Model) : EvaluationResult
...@@ -31,6 +31,9 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends ...@@ -31,6 +31,9 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends
/** Evaluates a ground expression. */ /** Evaluates a ground expression. */
final def eval(expr: Expr) : EvaluationResult = eval(expr, Model.empty) final def eval(expr: Expr) : EvaluationResult = eval(expr, Model.empty)
/** Checks that `model |= expr` and that quantifications are all valid */
def check(expr: Expr, model: Model) : CheckResult
/** Compiles an expression into a function, where the arguments are the free variables in the expression. /** Compiles an expression into a function, where the arguments are the free variables in the expression.
* `argorder` specifies in which order the arguments should be passed. * `argorder` specifies in which order the arguments should be passed.
* The default implementation uses the evaluation function each time, but evaluators are free * The default implementation uses the evaluation function each time, but evaluators are free
......
...@@ -4,9 +4,11 @@ package leon ...@@ -4,9 +4,11 @@ package leon
package evaluators package evaluators
import purescala.Common.Identifier import purescala.Common.Identifier
import purescala.Expressions.Expr import leon.purescala.Expressions.{Lambda, Expr}
import solvers.Model import solvers.Model
import scala.collection.mutable.{Map => MutableMap}
trait RecContext[RC <: RecContext[RC]] { trait RecContext[RC <: RecContext[RC]] {
def mappings: Map[Identifier, Expr] def mappings: Map[Identifier, Expr]
...@@ -25,8 +27,10 @@ case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext ...@@ -25,8 +27,10 @@ case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext
def newVars(news: Map[Identifier, Expr]) = copy(news) def newVars(news: Map[Identifier, Expr]) = copy(news)
} }
class GlobalContext(val model: Model, val maxSteps: Int) { class GlobalContext(val model: Model, val maxSteps: Int, val check: Boolean) {
var stepsLeft = maxSteps var stepsLeft = maxSteps
val lambdas: MutableMap[Lambda, Lambda] = MutableMap.empty
} }
trait HasDefaultRecContext extends ContextualEvaluator { trait HasDefaultRecContext extends ContextualEvaluator {
...@@ -35,6 +39,6 @@ trait HasDefaultRecContext extends ContextualEvaluator { ...@@ -35,6 +39,6 @@ trait HasDefaultRecContext extends ContextualEvaluator {
} }
trait HasDefaultGlobalContext extends ContextualEvaluator { trait HasDefaultGlobalContext extends ContextualEvaluator {
def initGC(model: solvers.Model) = new GlobalContext(model, this.maxSteps) def initGC(model: solvers.Model, check: Boolean) = new GlobalContext(model, this.maxSteps, check)
type GC = GlobalContext type GC = GlobalContext
} }
\ No newline at end of file
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
package leon package leon
package evaluators package evaluators
import leon.purescala.Quantification._
import purescala.Constructors._ import purescala.Constructors._
import purescala.ExprOps._ import purescala.ExprOps._
import purescala.Expressions.Pattern import purescala.Expressions.Pattern
...@@ -12,7 +13,9 @@ import purescala.Types._ ...@@ -12,7 +13,9 @@ import purescala.Types._
import purescala.Common._ import purescala.Common._
import purescala.Expressions._ import purescala.Expressions._
import purescala.Definitions._ import purescala.Definitions._
import solvers.SolverFactory import leon.solvers.{HenkinModel, Model, SolverFactory}
import scala.collection.mutable.{Map => MutableMap}
abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int) abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int)
extends ContextualEvaluator(ctx, prog, maxSteps) extends ContextualEvaluator(ctx, prog, maxSteps)
...@@ -42,11 +45,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int ...@@ -42,11 +45,12 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
val newArgs = args.map(e) val newArgs = args.map(e)
val mapping = l.paramSubst(newArgs) val mapping = l.paramSubst(newArgs)
e(body)(rctx.withNewVars(mapping), gctx) e(body)(rctx.withNewVars(mapping), gctx)
case PartialLambda(mapping, _) => case PartialLambda(mapping, dflt, _) =>
mapping.find { case (pargs, res) => mapping.find { case (pargs, res) =>
(args zip pargs).forall(p => e(Equals(p._1, p._2)) == BooleanLiteral(true)) (args zip pargs).forall(p => e(Equals(p._1, p._2)) == BooleanLiteral(true))
}.map(_._2).getOrElse { }.map(_._2).orElse(dflt).getOrElse {
throw EvalError("Cannot apply partial lambda outside of domain") throw EvalError("Cannot apply partial lambda outside of domain : " +
args.map(e(_).asString(ctx)).mkString("(", ", ", ")"))
} }
case f => case f =>
throw EvalError("Cannot apply non-lambda function " + f.asString) throw EvalError("Cannot apply non-lambda function " + f.asString)
...@@ -180,7 +184,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int ...@@ -180,7 +184,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
(lv,rv) match { (lv,rv) match {
case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2)
case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet)
case (PartialLambda(m1, _), PartialLambda(m2, _)) => BooleanLiteral(m1.toSet == m2.toSet) case (PartialLambda(m1, d1, _), PartialLambda(m2, d2, _)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2)
case _ => BooleanLiteral(lv == rv) case _ => BooleanLiteral(lv == rv)
} }
...@@ -458,20 +462,19 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int ...@@ -458,20 +462,19 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
FiniteSet(els.map(e), base) FiniteSet(els.map(e), base)
case l @ Lambda(_, _) => case l @ Lambda(_, _) =>
val (nl, structSubst) = normalizeStructure(l) val (nl, structSubst) = normalizeStructure(matchToIfThenElse(l))
val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap
replaceFromIDs(mapping, nl) val newLambda = replaceFromIDs(mapping, nl).asInstanceOf[Lambda]
if (!gctx.lambdas.isDefinedAt(newLambda)) {
gctx.lambdas += (newLambda -> nl.asInstanceOf[Lambda])
}
newLambda
case PartialLambda(mapping, tpe) => case PartialLambda(mapping, dflt, tpe) =>
PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), tpe) PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), dflt.map(e), tpe)
case f @ Forall(fargs, TopLevelAnds(conjuncts)) => case Forall(fargs, body) =>
e(andJoin(for (conj <- conjuncts) yield { evalForall(fargs.map(_.id).toSet, body)
val instantiations = forallInstantiations(gctx, fargs, conj)
e(andJoin(instantiations.map { case (enabler, mapping) =>
e(Implies(enabler, conj))(rctx.withNewVars(mapping), gctx)
}))
}))
case ArrayLength(a) => case ArrayLength(a) =>
val FiniteArray(_, _, IntLiteral(length)) = e(a) val FiniteArray(_, _, IntLiteral(length)) = e(a)
...@@ -678,6 +681,140 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int ...@@ -678,6 +681,140 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int
} }
protected def evalForall(quants: Set[Identifier], body: Expr, check: Boolean = true)(implicit rctx: RC, gctx: GC): Expr = {
val henkinModel: HenkinModel = gctx.model match {
case hm: HenkinModel => hm
case _ => throw EvalError("Can't evaluate foralls without henkin model")
}
val TopLevelAnds(conjuncts) = body
e(andJoin(conjuncts.flatMap { conj =>
val vars = variablesOf(conj)
val quantified = quants.filter(vars)
extractQuorums(conj, quantified).flatMap { case (qrm, others) =>
val quorum = qrm.toList
if (quorum.exists { case (TopLevelAnds(paths), _, _) =>
val p = andJoin(paths.filter(path => (variablesOf(path) & quantified).isEmpty))
e(p) == BooleanLiteral(false)
}) List(BooleanLiteral(true)) else {
var mappings: Seq[(Identifier, Int, Int)] = Seq.empty
var constraints: Seq[(Expr, Int, Int)] = Seq.empty
var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty
for (((_, expr, args), qidx) <- quorum.zipWithIndex) {
val (qmappings, qconstraints) = args.zipWithIndex.partition {
case (Variable(id),aidx) => quantified(id)
case _ => false
}
mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2))
constraints ++= qconstraints.map(p => (p._1, qidx, p._2))
}
val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield {
val base :: others = es.toList.map(p => (p._2, p._3))
equalities ++= others.map(p => base -> p)
(id -> base)
}
def domain(expr: Expr): Set[Seq[Expr]] = henkinModel.domain(e(expr) match {
case l: Lambda => gctx.lambdas.getOrElse(l, l)
case ev => ev
})
val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) {
case (acc, (_, expr, _)) => acc.flatMap(s => domain(expr).map(d => s :+ d))
}
argSets.map { args =>
val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap {
case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e }
}.toMap
val map = mapping.map { case (id, key) => id -> argMap(key) }
val enabler = andJoin(constraints.map {
case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx))
} ++ equalities.map {
case (k1, k2) => Equals(argMap(k1), argMap(k2))
})
val ctx = rctx.withNewVars(map)
if (e(enabler)(ctx, gctx) == BooleanLiteral(true)) {
if (gctx.check) {
for ((b,caller,args) <- others if e(b)(ctx, gctx) == BooleanLiteral(true)) {
val evArgs = args.map(arg => e(arg)(ctx, gctx))
if (!domain(caller)(evArgs))
throw QuantificationError("Unhandled transitive implication in " + replaceFromIDs(map, conj))
}
}
e(conj)(ctx, gctx)
} else {
BooleanLiteral(true)
}
}
}
}
})) match {
case res @ BooleanLiteral(true) if check =>
if (gctx.check) {
checkForall(quants, body) match {
case status: ForallInvalid =>
throw QuantificationError("Invalid forall: " + status.getMessage)
case _ =>
// make sure the body doesn't contain matches or lets as these introduce new locals
val cleanBody = expandLets(matchToIfThenElse(body))
val calls = new CollectorWithPaths[(Expr, Seq[Expr], Seq[Expr])] {
def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[Expr])] = e match {
case QuantificationMatcher(IsTyped(caller, _: FunctionType), args) => Some((caller, args, path))
case _ => None
}
override def rec(e: Expr, path: Seq[Expr]): Expr = e match {
case l : Lambda => l
case _ => super.rec(e, path)
}
}.traverse(cleanBody)
for ((caller, appArgs, paths) <- calls) {
val path = andJoin(paths.filter(expr => (variablesOf(expr) & quants).isEmpty))
if (e(path) == BooleanLiteral(true)) e(caller) match {
case _: PartialLambda => // OK
case l: Lambda =>
val nl @ Lambda(args, body) = gctx.lambdas.getOrElse(l, l)
val lambdaQuantified = (appArgs zip args).collect {
case (Variable(id), vd) if quants(id) => vd.id
}.toSet
if (lambdaQuantified.nonEmpty) {
checkForall(lambdaQuantified, body) match {
case lambdaStatus: ForallInvalid =>
throw QuantificationError("Invalid forall: " + lambdaStatus.getMessage)
case _ => // do nothing
}
val axiom = Equals(Application(nl, args.map(_.toVariable)), nl.body)
if (evalForall(args.map(_.id).toSet, axiom, check = false) == BooleanLiteral(false)) {
throw QuantificationError("Unaxiomatic lambda " + l)
}
}
case f =>
throw EvalError("Cannot apply non-lambda function " + f.asString)
}
}
}
}
res
// `res == false` means the quantification is valid since there effectivelly must
// exist an input for which the proposition doesn't hold
case res => res
}
}
} }
...@@ -37,7 +37,8 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) ...@@ -37,7 +37,8 @@ class StreamEvaluator(ctx: LeonContext, prog: Program)
case l @ Lambda(params, body) => case l @ Lambda(params, body) =>
val mapping = l.paramSubst(newArgs) val mapping = l.paramSubst(newArgs)
e(body)(rctx.withNewVars(mapping), gctx).distinct e(body)(rctx.withNewVars(mapping), gctx).distinct
case PartialLambda(mapping, _) => case PartialLambda(mapping, _, _) =>
// FIXME
mapping.collectFirst { mapping.collectFirst {
case (pargs, res) if (newArgs zip pargs).forall { case (f, r) => f == r } => case (pargs, res) if (newArgs zip pargs).forall { case (f, r) => f == r } =>
res res
...@@ -134,7 +135,8 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) ...@@ -134,7 +135,8 @@ class StreamEvaluator(ctx: LeonContext, prog: Program)
).toMap ).toMap
Stream(replaceFromIDs(mapping, nl)) Stream(replaceFromIDs(mapping, nl))
case PartialLambda(mapping, tpe) => // FIXME
case PartialLambda(mapping, tpe, df) =>
def solveOne(pair: (Seq[Expr], Expr)) = { def solveOne(pair: (Seq[Expr], Expr)) = {
val (args, res) = pair val (args, res) = pair
for { for {
...@@ -142,11 +144,11 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) ...@@ -142,11 +144,11 @@ class StreamEvaluator(ctx: LeonContext, prog: Program)
r <- e(res) r <- e(res)
} yield as -> r } yield as -> r
} }
cartesianProduct(mapping map solveOne) map (PartialLambda(_, tpe)) cartesianProduct(mapping map solveOne) map (PartialLambda(_, tpe, df)) // FIXME!!!
case f @ Forall(fargs, TopLevelAnds(conjuncts)) => case f @ Forall(fargs, TopLevelAnds(conjuncts)) =>
Stream() // FIXME
def solveOne(conj: Expr) = { /*def solveOne(conj: Expr) = {
val instantiations = forallInstantiations(gctx, fargs, conj) val instantiations = forallInstantiations(gctx, fargs, conj)
for { for {
es <- cartesianProduct(instantiations.map { case (enabler, mapping) => es <- cartesianProduct(instantiations.map { case (enabler, mapping) =>
...@@ -159,7 +161,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) ...@@ -159,7 +161,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program)
for { for {
conj <- cartesianProduct(conjuncts map solveOne) conj <- cartesianProduct(conjuncts map solveOne)
res <- e(andJoin(conj)) res <- e(andJoin(conj))
} yield res } yield res*/
case p : Passes => case p : Passes =>
e(p.asConstraint) e(p.asConstraint)
...@@ -344,7 +346,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) ...@@ -344,7 +346,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program)
(lv, rv) match { (lv, rv) match {
case (FiniteSet(el1, _), FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteSet(el1, _), FiniteSet(el2, _)) => BooleanLiteral(el1 == el2)
case (FiniteMap(el1, _, _), FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) case (FiniteMap(el1, _, _), FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet)
case (PartialLambda(m1, _), PartialLambda(m2, _)) => BooleanLiteral(m1.toSet == m2.toSet) case (PartialLambda(m1, _, d1), PartialLambda(m2, _, d2)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2)
case _ => BooleanLiteral(lv == rv) case _ => BooleanLiteral(lv == rv)
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment