diff --git a/.larabot.conf b/.larabot.conf index 6a037e0addc616364cb0b049075beb1f7c2d18c7..4f84df1cfdc62750edb7122030615723a91dbca0 100644 --- a/.larabot.conf +++ b/.larabot.conf @@ -1,8 +1,8 @@ commands = [ - "sbt -batch test" - "sbt -batch integration:test" - "sbt -batch regression:test" - "sbt -batch genc:test" + "sbt -batch -Dparallel=10 test" + "sbt -batch -Dparallel=10 integration:test" + "sbt -batch -Dparallel=10 regression:test" + "sbt -batch -Dparallel=10 genc:test" ] trusted = [ @@ -17,3 +17,9 @@ trusted = [ "samarion" "vkuncak" ] + +notify { + master = [ + "leon-dev@googlegroups.com" + ] +} diff --git a/build.sbt b/build.sbt index 0e3e147a10e395fb58a0ebb3f981b9919deb56b9..6507f84a78b8bb3528d58121fe16e40b88f9ce09 100644 --- a/build.sbt +++ b/build.sbt @@ -56,6 +56,20 @@ clean := { } } +lazy val nParallel = { + val p = System.getProperty("parallel") + if (p ne null) { + try { + p.toInt + } catch { + case nfe: NumberFormatException => + 1 + } + } else { + 1 + } +} + lazy val script = taskKey[Unit]("Generate the leon Bash script") script := { @@ -107,14 +121,17 @@ sourcesInBase in Compile := false Keys.fork in run := true + lazy val testSettings = Seq( //Keys.fork := true, - logBuffered := true, - parallelExecution := true + logBuffered := (nParallel > 1), + parallelExecution := (nParallel > 1) //testForkedParallel := true, //javaOptions ++= Seq("-Xss64M", "-Xmx4G") ) +concurrentRestrictions in Global += Tags.limit(Tags.Test, nParallel) + // Unit Tests testOptions in Test := Seq(Tests.Argument("-oDF"), Tests.Filter(_ startsWith "leon.unit.")) @@ -138,7 +155,6 @@ lazy val IsabelleTest = config("isabelle") extend(Test) testOptions in IsabelleTest := Seq(Tests.Argument("-oDF"), Tests.Filter(_ startsWith "leon.isabelle.")) parallelExecution in IsabelleTest := false - fork in IsabelleTest := true // GenC Tests diff --git a/library/lang/Rational.scala b/library/lang/Rational.scala index f4a10215b61032a73f20107c5aeddbac841e66bc..1bcb679a0d4abf6b168117b2af819fa41e027315 100644 --- a/library/lang/Rational.scala +++ b/library/lang/Rational.scala @@ -9,72 +9,65 @@ import scala.language.implicitConversions @library case class Rational(numerator: BigInt, denominator: BigInt) { + require(this.isRational) + def +(that: Rational): Rational = { - require(this.isRational && that.isRational) Rational(this.numerator*that.denominator + that.numerator*this.denominator, this.denominator*that.denominator) } ensuring(res => res.isRational) def -(that: Rational): Rational = { - require(this.isRational && that.isRational) Rational(this.numerator*that.denominator - that.numerator*this.denominator, this.denominator*that.denominator) } ensuring(res => res.isRational) def unary_- : Rational = { - require(this.isRational) Rational(-this.numerator, this.denominator) } ensuring(res => res.isRational) def *(that: Rational): Rational = { - require(this.isRational && that.isRational) Rational(this.numerator*that.numerator, this.denominator*that.denominator) } ensuring(res => res.isRational) def /(that: Rational): Rational = { - require(this.isRational && that.isRational && that.nonZero) + require(that.nonZero) val newNumerator = this.numerator*that.denominator val newDenominator = this.denominator*that.numerator normalize(newNumerator, newDenominator) } ensuring(res => res.isRational) def reciprocal: Rational = { - require(this.isRational && this.nonZero) + require(this.nonZero) normalize(this.denominator, this.numerator) } ensuring(res => res.isRational) def ~(that: Rational): Boolean = { - require(this.isRational && that.isRational) this.numerator*that.denominator == that.numerator*this.denominator } def <(that: Rational): Boolean = { - require(this.isRational && that.isRational) this.numerator*that.denominator < that.numerator*this.denominator } def <=(that: Rational): Boolean = { - require(this.isRational && that.isRational) this.numerator*that.denominator <= that.numerator*this.denominator } def >(that: Rational): Boolean = { - require(this.isRational && that.isRational) this.numerator*that.denominator > that.numerator*this.denominator } def >=(that: Rational): Boolean = { - require(this.isRational && that.isRational) this.numerator*that.denominator >= that.numerator*this.denominator } def nonZero: Boolean = { - require(this.isRational) numerator != 0 } - def isRational: Boolean = denominator > 0 + private def isRational: Boolean = denominator > 0 private def normalize(num: BigInt, den: BigInt): Rational = { + require(den != 0) if(den < 0) Rational(-num, -den) else @@ -88,4 +81,5 @@ object Rational { implicit def bigIntToRat(n: BigInt): Rational = Rational(n, 1) def apply(n: BigInt): Rational = Rational(n, 1) + } diff --git a/library/lang/package.scala b/library/lang/package.scala index b19ec529bdb3975956de00a7da8e8ad39bcf34e4..ab713a588e5451ba8ab61533d6037bd2bd572de8 100644 --- a/library/lang/package.scala +++ b/library/lang/package.scala @@ -20,16 +20,7 @@ package object lang { if (underlying) that else true } } - - implicit class SpecsDecorations[A](val underlying: A) { - @inline - def computes(target: A) = { - underlying - } ensuring { - res => res == target - } - } - + @ignore def forall[A](p: A => Boolean): Boolean = sys.error("Can't execute quantified proposition") @ignore def forall[A,B](p: (A,B) => Boolean): Boolean = sys.error("Can't execute quantified proposition") @ignore def forall[A,B,C](p: (A,B,C) => Boolean): Boolean = sys.error("Can't execute quantified proposition") @@ -56,6 +47,27 @@ package object lang { def passes(tests : A => B ) : Boolean = try { tests(in) == out } catch { case _ : MatchError => true } } + + @ignore + def byExample[A, B](in: A, out: B): Boolean = { + sys.error("Can't execute by example proposition") + } + + implicit class SpecsDecorations[A](val underlying: A) { + @ignore + def computes(target: A) = { + underlying + } ensuring { + res => res == target + } + + @ignore // Programming by example: ???[String] ask input + def ask[I](input : I) = { + underlying + } ensuring { + (res: A) => byExample(input, res) + } + } @ignore object BigInt { diff --git a/src/main/java/leon/codegen/runtime/FiniteLambda.java b/src/main/java/leon/codegen/runtime/FiniteLambda.java new file mode 100644 index 0000000000000000000000000000000000000000..fcdb340191923fe3f408484333171c242a946196 --- /dev/null +++ b/src/main/java/leon/codegen/runtime/FiniteLambda.java @@ -0,0 +1,44 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.codegen.runtime; + +import java.util.HashMap; + +public final class FiniteLambda extends Lambda { + public final HashMap<Tuple, Object> mapping = new HashMap<Tuple, Object>(); + public final Object dflt; + + public FiniteLambda(Object dflt) { + super(); + this.dflt = dflt; + } + + public void add(Tuple key, Object value) { + mapping.put(key, value); + } + + @Override + public Object apply(Object[] args) throws LeonCodeGenRuntimeException { + Tuple tuple = new Tuple(args); + if (mapping.containsKey(tuple)) { + return mapping.get(tuple); + } else { + return dflt; + } + } + + @Override + public boolean equals(Object that) { + if (that != null && (that instanceof FiniteLambda)) { + FiniteLambda l = (FiniteLambda) that; + return dflt.equals(l.dflt) && mapping.equals(l.mapping); + } else { + return false; + } + } + + @Override + public int hashCode() { + return 63 + 11 * mapping.hashCode() + (dflt == null ? 0 : dflt.hashCode()); + } +} diff --git a/src/main/java/leon/codegen/runtime/Forall.java b/src/main/java/leon/codegen/runtime/Forall.java deleted file mode 100644 index f6877b604bcd069af6c2ce20d78a17d82d434dbe..0000000000000000000000000000000000000000 --- a/src/main/java/leon/codegen/runtime/Forall.java +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon.codegen.runtime; - -import java.util.HashMap; - -public abstract class Forall { - private static final HashMap<Tuple, Boolean> cache = new HashMap<Tuple, Boolean>(); - - protected final LeonCodeGenRuntimeHenkinMonitor monitor; - protected final Tuple closures; - protected final boolean check; - - public Forall(LeonCodeGenRuntimeMonitor monitor, Tuple closures) throws LeonCodeGenEvaluationException { - if (!(monitor instanceof LeonCodeGenRuntimeHenkinMonitor)) - throw new LeonCodeGenEvaluationException("Can't evaluate foralls without domain"); - - this.monitor = (LeonCodeGenRuntimeHenkinMonitor) monitor; - this.closures = closures; - this.check = (boolean) closures.get(closures.getArity() - 1); - } - - public boolean check() { - Tuple key = new Tuple(new Object[] { getClass(), monitor, closures }); // check is in the closures - if (cache.containsKey(key)) { - return cache.get(key); - } else { - boolean res = checkForall(); - cache.put(key, res); - return res; - } - } - - public abstract boolean checkForall(); -} diff --git a/src/main/java/leon/codegen/runtime/Lambda.java b/src/main/java/leon/codegen/runtime/Lambda.java index af255726311655efaeddea545c5e6e44afc15b8e..a6abbef37edbe8f87f480a21a6200e32a9e0206b 100644 --- a/src/main/java/leon/codegen/runtime/Lambda.java +++ b/src/main/java/leon/codegen/runtime/Lambda.java @@ -4,6 +4,4 @@ package leon.codegen.runtime; public abstract class Lambda { public abstract Object apply(Object[] args) throws LeonCodeGenRuntimeException; - public abstract void checkForall(boolean[] quantified); - public abstract void checkAxiom(); } diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java b/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java deleted file mode 100644 index f172316a2548a52c6b294f70101a15ebbb8ce98a..0000000000000000000000000000000000000000 --- a/src/main/java/leon/codegen/runtime/LeonCodeGenQuantificationException.java +++ /dev/null @@ -1,14 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon.codegen.runtime; - -/** Such exceptions are thrown when the evaluator encounters a forall - * expression whose shape is not supported in Leon. */ -public class LeonCodeGenQuantificationException extends Exception { - - private static final long serialVersionUID = -1824885321497473916L; - - public LeonCodeGenQuantificationException(String msg) { - super(msg); - } -} diff --git a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java b/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java deleted file mode 100644 index 597beec44b6a1a1719909e00ecb7d7916f0c7c03..0000000000000000000000000000000000000000 --- a/src/main/java/leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor.java +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon.codegen.runtime; - -import java.util.List; -import java.util.LinkedList; -import java.util.HashMap; - -public class LeonCodeGenRuntimeHenkinMonitor extends LeonCodeGenRuntimeMonitor { - private final HashMap<Integer, List<Tuple>> tpes = new HashMap<Integer, List<Tuple>>(); - private final HashMap<Class<?>, List<Tuple>> lambdas = new HashMap<Class<?>, List<Tuple>>(); - public final boolean checkForalls; - - public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations, boolean checkForalls) { - super(maxInvocations); - this.checkForalls = checkForalls; - } - - public LeonCodeGenRuntimeHenkinMonitor(int maxInvocations) { - this(maxInvocations, false); - } - - public void add(int type, Tuple input) { - if (!tpes.containsKey(type)) tpes.put(type, new LinkedList<Tuple>()); - tpes.get(type).add(input); - } - - public void add(Class<?> clazz, Tuple input) { - if (!lambdas.containsKey(clazz)) lambdas.put(clazz, new LinkedList<Tuple>()); - lambdas.get(clazz).add(input); - } - - public List<Tuple> domain(Object obj, int type) { - List<Tuple> domain = new LinkedList<Tuple>(); - if (obj instanceof PartialLambda) { - PartialLambda l = (PartialLambda) obj; - for (Tuple key : l.mapping.keySet()) { - domain.add(key); - } - } else if (obj instanceof Lambda) { - List<Tuple> lambdaDomain = lambdas.get(obj.getClass()); - if (lambdaDomain != null) domain.addAll(lambdaDomain); - } - - List<Tuple> tpeDomain = tpes.get(type); - if (tpeDomain != null) domain.addAll(tpeDomain); - - return domain; - } -} diff --git a/src/main/java/leon/codegen/runtime/PartialLambda.java b/src/main/java/leon/codegen/runtime/PartialLambda.java deleted file mode 100644 index b04036db5e9f81d1eaf7fa2c9a047bfef45a4df8..0000000000000000000000000000000000000000 --- a/src/main/java/leon/codegen/runtime/PartialLambda.java +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon.codegen.runtime; - -import java.util.HashMap; - -public final class PartialLambda extends Lambda { - final HashMap<Tuple, Object> mapping = new HashMap<Tuple, Object>(); - private final Object dflt; - - public PartialLambda() { - this(null); - } - - public PartialLambda(Object dflt) { - super(); - this.dflt = dflt; - } - - public void add(Tuple key, Object value) { - mapping.put(key, value); - } - - @Override - public Object apply(Object[] args) throws LeonCodeGenRuntimeException { - Tuple tuple = new Tuple(args); - if (mapping.containsKey(tuple)) { - return mapping.get(tuple); - } else if (dflt != null) { - return dflt; - } else { - throw new LeonCodeGenRuntimeException("Partial function apply on undefined arguments " + tuple); - } - } - - @Override - public boolean equals(Object that) { - if (that != null && (that instanceof PartialLambda)) { - PartialLambda l = (PartialLambda) that; - return ((dflt != null && dflt.equals(l.dflt)) || (dflt == null && l.dflt == null)) && mapping.equals(l.mapping); - } else { - return false; - } - } - - @Override - public int hashCode() { - return 63 + 11 * mapping.hashCode() + (dflt == null ? 0 : dflt.hashCode()); - } - - @Override - public void checkForall(boolean[] quantified) {} - - @Override - public void checkAxiom() {} -} diff --git a/src/main/scala/leon/LeonOption.scala b/src/main/scala/leon/LeonOption.scala index a0a9c9d92ee78397cce95dd0a52042fa8bc63330..f079399f9995e409b6b72881ef8f81d92c273556 100644 --- a/src/main/scala/leon/LeonOption.scala +++ b/src/main/scala/leon/LeonOption.scala @@ -25,8 +25,10 @@ abstract class LeonOptionDef[+A] { try { parser(s) } catch { case _ : IllegalArgumentException => - reporter.error(s"Invalid option usage: $usageDesc") - Main.displayHelp(reporter, error = true) + reporter.fatalError( + s"Invalid option usage: --$name\n" + + "Try 'leon --help' for more information." + ) } } diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 8adc16bfc19e9febdb0eb3be678db28989d649fb..509a04d38788c922d4fcc3ca9bf5ed4355e2be5c 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -34,8 +34,9 @@ object Main { } // Add whatever you need here. - lazy val allComponents: Set[LeonComponent] = allPhases.toSet ++ Set( - solvers.z3.FairZ3Component, MainComponent, SharedOptions, solvers.smtlib.SMTLIBCVC4Component, solvers.isabelle.Component) + lazy val allComponents : Set[LeonComponent] = allPhases.toSet ++ Set( + solvers.combinators.UnrollingProcedure, MainComponent, SharedOptions, solvers.smtlib.SMTLIBCVC4Component, solvers.isabelle.Component + ) /* * This object holds the options that determine the selected pipeline of Leon. @@ -45,18 +46,18 @@ object Main { val name = "main" val description = "Selection of Leon functionality. Default: verify" - val optEval = LeonStringOptionDef("eval", "Evaluate ground functions through code generation or evaluation (default: evaluation)", "default", "[code|default]") - val optTermination = LeonFlagOptionDef("termination", "Check program termination. Can be used along --verify", false) - val optRepair = LeonFlagOptionDef("repair", "Repair selected functions", false) - val optSynthesis = LeonFlagOptionDef("synthesis", "Partial synthesis of choose() constructs", false) - val optIsabelle = LeonFlagOptionDef("isabelle", "Run Isabelle verification", false) - val optNoop = LeonFlagOptionDef("noop", "No operation performed, just output program", false) - val optVerify = LeonFlagOptionDef("verify", "Verify function contracts", false) - val optHelp = LeonFlagOptionDef("help", "Show help message", false) - val optInstrument = LeonFlagOptionDef("instrument", "Instrument the code for inferring time/depth/stack bounds", false) - val optInferInv = LeonFlagOptionDef("inferInv", "Infer invariants from (instrumented) the code", false) + val optEval = LeonStringOptionDef("eval", "Evaluate ground functions through code generation or evaluation (default: evaluation)", "default", "[codegen|default]") + val optTermination = LeonFlagOptionDef("termination", "Check program termination. Can be used along --verify", false) + val optRepair = LeonFlagOptionDef("repair", "Repair selected functions", false) + val optSynthesis = LeonFlagOptionDef("synthesis", "Partial synthesis of choose() constructs", false) + val optIsabelle = LeonFlagOptionDef("isabelle", "Run Isabelle verification", false) + val optNoop = LeonFlagOptionDef("noop", "No operation performed, just output program", false) + val optVerify = LeonFlagOptionDef("verify", "Verify function contracts", false) + val optHelp = LeonFlagOptionDef("help", "Show help message", false) + val optInstrument = LeonFlagOptionDef("instrument", "Instrument the code for inferring time/depth/stack bounds", false) + val optInferInv = LeonFlagOptionDef("inferInv", "Infer invariants from (instrumented) the code", false) val optLazyEval = LeonFlagOptionDef("lazy", "Handles programs that may use the lazy construct", false) - val optGenc = LeonFlagOptionDef("genc", "Generate C code", false) + val optGenc = LeonFlagOptionDef("genc", "Generate C code", false) override val definedOptions: Set[LeonOptionDef[Any]] = Set(optTermination, optRepair, optSynthesis, optIsabelle, optNoop, optHelp, optEval, optVerify, optInstrument, optInferInv, optLazyEval, optGenc) @@ -115,9 +116,11 @@ object Main { s"Malformed option $opt. Options should have the form --name or --name=value") } // Find respective LeonOptionDef, or report an unknown option - val df = allOptions.find(_.name == name).getOrElse { - initReporter.error(s"Unknown option: $name") - displayHelp(initReporter, error = true) + val df = allOptions.find(_. name == name).getOrElse{ + initReporter.fatalError( + s"Unknown option: $name\n" + + "Try 'leon --help' for more information." + ) } df.parse(value)(initReporter) } diff --git a/src/main/scala/leon/SharedOptions.scala b/src/main/scala/leon/SharedOptions.scala index 839dda206b4a702ad5011049a38c76d7dcd21478..b68e64c83a6fd42e445dd4b5fa3b3bd42a935de4 100644 --- a/src/main/scala/leon/SharedOptions.scala +++ b/src/main/scala/leon/SharedOptions.scala @@ -5,12 +5,11 @@ package leon import leon.utils.{DebugSections, DebugSection} import OptionParsers._ -/* - * This object contains options that are shared among different modules of Leon. - * - * Options that determine the pipeline of Leon are not stored here, - * but in MainComponent in Main.scala. - */ +/** This object contains options that are shared among different modules of Leon. + * + * Options that determine the pipeline of Leon are not stored here, + * but in [[Main.MainComponent]] instead. + */ object SharedOptions extends LeonComponent { val name = "sharedOptions" @@ -45,7 +44,7 @@ object SharedOptions extends LeonComponent { val name = "debug" val description = { val sects = DebugSections.all.toSeq.map(_.name).sorted - val (first, second) = sects.splitAt(sects.length/2) + val (first, second) = sects.splitAt(sects.length/2 + 1) "Enable detailed messages per component.\nAvailable:\n" + " " + first.mkString(", ") + ",\n" + " " + second.mkString(", ") @@ -61,8 +60,6 @@ object SharedOptions extends LeonComponent { Set(rs) case None => throw new IllegalArgumentException - //initReporter.error("Section "+s+" not found, available: "+DebugSections.all.map(_.name).mkString(", ")) - //Set() } } } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 86405da600c7ec006c41da3c4c33433354775b66..14b8f5cdf5fb00237282ac231ed881b9179e053d 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -8,6 +8,7 @@ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps._ import purescala.Constructors._ import purescala.Extractors._ import purescala.Quantification._ @@ -27,9 +28,10 @@ trait CodeGeneration { * isStatic signifies if the current method is static (a function, in Leon terms) */ class Locals private[codegen] ( - vars : Map[Identifier, Int], - args : Map[Identifier, Int], - fields : Map[Identifier, (String,String,String)] + vars : Map[Identifier, Int], + args : Map[Identifier, Int], + fields : Map[Identifier, (String,String,String)], + val tps : Seq[TypeParameter] ) { /** Fetches the offset of a local variable/ parameter from its identifier */ def varToLocal(v: Identifier): Option[Int] = vars.get(v) @@ -39,21 +41,24 @@ trait CodeGeneration { def varToField(v: Identifier): Option[(String,String,String)] = fields.get(v) /** Adds some extra variables to the mapping */ - def withVars(newVars: Map[Identifier, Int]) = new Locals(vars ++ newVars, args, fields) + def withVars(newVars: Map[Identifier, Int]) = new Locals(vars ++ newVars, args, fields, tps) /** Adds an extra variable to the mapping */ - def withVar(nv: (Identifier, Int)) = new Locals(vars + nv, args, fields) + def withVar(nv: (Identifier, Int)) = new Locals(vars + nv, args, fields, tps) - def withArgs(newArgs: Map[Identifier, Int]) = new Locals(vars, args ++ newArgs, fields) + def withArgs(newArgs: Map[Identifier, Int]) = new Locals(vars, args ++ newArgs, fields, tps) - def withFields(newFields: Map[Identifier,(String,String,String)]) = new Locals(vars, args, fields ++ newFields) + def withFields(newFields: Map[Identifier,(String,String,String)]) = new Locals(vars, args, fields ++ newFields, tps) - override def toString = "Locals("+vars + ", " + args + ", " + fields + ")" + def withTypes(newTps: Seq[TypeParameter]) = new Locals(vars, args, fields, tps ++ newTps) + + override def toString = "Locals("+vars + ", " + args + ", " + fields + ", " + tps + ")" } - object NoLocals extends Locals(Map.empty, Map.empty, Map.empty) + object NoLocals extends Locals(Map.empty, Map.empty, Map.empty, Seq.empty) lazy val monitorID = FreshIdentifier("__$monitor") + lazy val tpsID = FreshIdentifier("__$tps") private[codegen] val ObjectClass = "java/lang/Object" private[codegen] val BoxedIntClass = "java/lang/Integer" @@ -73,17 +78,15 @@ trait CodeGeneration { private[codegen] val RationalClass = "leon/codegen/runtime/Rational" private[codegen] val CaseClassClass = "leon/codegen/runtime/CaseClass" private[codegen] val LambdaClass = "leon/codegen/runtime/Lambda" - private[codegen] val ForallClass = "leon/codegen/runtime/Forall" - private[codegen] val PartialLambdaClass = "leon/codegen/runtime/PartialLambda" + private[codegen] val FiniteLambdaClass = "leon/codegen/runtime/FiniteLambda" private[codegen] val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" - private[codegen] val InvalidForallClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" private[codegen] val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" private[codegen] val BadQuantificationClass = "leon/codegen/runtime/LeonCodeGenQuantificationException" private[codegen] val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" private[codegen] val ChooseEntryPointClass = "leon/codegen/runtime/ChooseEntryPoint" private[codegen] val GenericValuesClass = "leon/codegen/runtime/GenericValues" - private[codegen] val MonitorClass = "leon/codegen/runtime/LeonCodeGenRuntimeMonitor" - private[codegen] val HenkinClass = "leon/codegen/runtime/LeonCodeGenRuntimeHenkinMonitor" + private[codegen] val MonitorClass = "leon/codegen/runtime/Monitor" + private[codegen] val NoMonitorClass = "leon/codegen/runtime/NoMonitor" private[codegen] val StrOpsClass = "leon/codegen/runtime/StrOps" def idToSafeJVMName(id: Identifier) = { @@ -165,13 +168,8 @@ trait CodeGeneration { val cf = classes(owner) val (_,mn,_) = leonFunDefToJVMInfo(funDef).get - val paramsTypes = funDef.params.map(a => typeToJVM(a.getType)) - - val realParams = if (requireMonitor) { - ("L" + MonitorClass + ";") +: paramsTypes - } else { - paramsTypes - } + val tpeParam = if (funDef.tparams.isEmpty) Seq() else Seq("[I") + val realParams = ("L" + MonitorClass + ";") +: (tpeParam ++ funDef.params.map(a => typeToJVM(a.getType))) val m = cf.addMethod( typeToJVM(funDef.returnType), @@ -193,9 +191,11 @@ trait CodeGeneration { // An offset we introduce to the parameters: // 1 if this is a method, so we need "this" in position 0 of the stack - // 1 if we are monitoring - val idParams = (if (requireMonitor) Seq(monitorID) else Seq.empty) ++ funDef.paramIds - val newMapping = idParams.zipWithIndex.toMap.mapValues(_ + (if (!isStatic) 1 else 0)) + val receiverOffset = if (isStatic) 0 else 1 + val paramIds = Seq(monitorID) ++ + (if (funDef.tparams.nonEmpty) Seq(tpsID) else Seq.empty) ++ + funDef.paramIds + val newMapping = paramIds.zipWithIndex.toMap.mapValues(_ + receiverOffset) val body = if (params.checkContracts) { funDef.fullBody @@ -203,11 +203,11 @@ trait CodeGeneration { funDef.body.getOrElse(throw CompilationException("Can't compile a FunDef without body: "+funDef.id.name)) } - val locals = NoLocals.withVars(newMapping) + val locals = NoLocals.withVars(newMapping).withTypes(funDef.tparams.map(_.tp)) if (params.recordInvocations) { load(monitorID, ch)(locals) - ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") + ch << InvokeVirtual(MonitorClass, "onInvocation", "()V") } mkExpr(body, ch)(locals) @@ -226,16 +226,17 @@ trait CodeGeneration { private[codegen] val lambdaToClass = scala.collection.mutable.Map.empty[Lambda, String] private[codegen] val classToLambda = scala.collection.mutable.Map.empty[String, Lambda] - protected def compileLambda(l: Lambda): (String, Seq[(Identifier, String)], String) = { + protected def compileLambda(l: Lambda): (String, Seq[(Identifier, String)], Seq[TypeParameter], String) = { val (normalized, structSubst) = purescala.ExprOps.normalizeStructure(matchToIfThenElse(l)) val reverseSubst = structSubst.map(p => p._2 -> p._1) val nl = normalized.asInstanceOf[Lambda] - val closureIDs = purescala.ExprOps.variablesOf(nl).toSeq.sortBy(_.uniqueName) - val closuresWithoutMonitor = closureIDs.map(id => id -> typeToJVM(id.getType)) - val closures = if (requireMonitor) { - (monitorID -> s"L$MonitorClass;") +: closuresWithoutMonitor - } else closuresWithoutMonitor + val tparams: Seq[TypeParameter] = typeParamsOf(nl).toSeq.sortBy(_.id.uniqueName) + + val closedVars = purescala.ExprOps.variablesOf(nl).toSeq.sortBy(_.uniqueName) + val closuresWithoutMonitor = closedVars.map(id => id -> typeToJVM(id.getType)) + val closures = (monitorID -> s"L$MonitorClass;") +: + ((if (tparams.nonEmpty) Seq(tpsID -> "[I") else Seq.empty) ++ closuresWithoutMonitor) val afName = lambdaToClass.getOrElse(nl, { val afName = "Leon$CodeGen$Lambda$" + lambdaCounter.nextGlobal @@ -283,7 +284,7 @@ trait CodeGeneration { val argMapping = nl.args.map(_.id).zipWithIndex.toMap val closureMapping = closures.map { case (id, jvmt) => id -> (afName, id.uniqueName, jvmt) }.toMap - val newLocals = NoLocals.withArgs(argMapping).withFields(closureMapping) + val newLocals = NoLocals.withArgs(argMapping).withFields(closureMapping).withTypes(tparams) locally { val apm = cf.addMethod(s"L$ObjectClass;", "apply", s"[L$ObjectClass;") @@ -377,396 +378,14 @@ trait CodeGeneration { hch.freeze } - locally { - val vmh = cf.addMethod("V", "checkForall", "[Z") - vmh.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val vch = vmh.codeHandler - - vch << ALoad(1) // load argument array - def rec(args: Seq[Identifier], idx: Int, quantified: Set[Identifier]): Unit = args match { - case x :: xs => - val notQuantLabel = vch.getFreshLabel("notQuant") - vch << DUP << Ldc(idx) << BALOAD << IfEq(notQuantLabel) - rec(xs, idx + 1, quantified + x) - vch << Label(notQuantLabel) - rec(xs, idx + 1, quantified) - - case Nil => - if (quantified.nonEmpty) { - checkQuantified(quantified, nl.body, vch)(newLocals) - vch << ALoad(0) << InvokeVirtual(LambdaClass, "checkAxiom", "()V") - } - vch << POP << RETURN - } - - if (requireQuantification) { - rec(nl.args.map(_.id), 0, Set.empty) - } else { - vch << POP << RETURN - } - - vch.freeze - } - - locally { - val vmh = cf.addMethod("V", "checkAxiom") - vmh.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val vch = vmh.codeHandler - - if (requireQuantification) { - val thisVar = FreshIdentifier("this", l.getType) - val axiom = Equals(Application(Variable(thisVar), nl.args.map(_.toVariable)), nl.body) - val axiomLocals = NoLocals.withFields(closureMapping).withVar(thisVar -> 0) - - mkForall(nl.args.map(_.id).toSet, axiom, vch, check = false)(axiomLocals) - - val skip = vch.getFreshLabel("skip") - vch << IfNe(skip) - vch << New(InvalidForallClass) << DUP - vch << Ldc("Unaxiomatic lambda " + l) - vch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") - vch << ATHROW - vch << Label(skip) - } - - vch << RETURN - vch.freeze - } - loader.register(cf) afName }) (afName, closures.map { case p @ (id, jvmt) => - if (id == monitorID) p else (reverseSubst(id) -> jvmt) - }, "(" + closures.map(_._2).mkString("") + ")V") - } - - private def checkQuantified(quantified: Set[Identifier], body: Expr, ch: CodeHandler)(implicit locals: Locals): Unit = { - val skipCheck = ch.getFreshLabel("skipCheck") - - load(monitorID, ch) - ch << CheckCast(HenkinClass) << GetField(HenkinClass, "checkForalls", "Z") - ch << IfEq(skipCheck) - - checkForall(quantified, body)(ctx) match { - case status: ForallInvalid => - ch << New(InvalidForallClass) << DUP - ch << Ldc("Invalid forall: " + status.getMessage) - ch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") - ch << ATHROW - - case ForallValid => - // expand match case expressions and lets so that caller can be compiled given - // the current locals (lets and matches introduce new locals) - val cleanBody = purescala.ExprOps.expandLets(purescala.ExprOps.matchToIfThenElse(body)) - - val calls = new CollectorWithPaths[(Expr, Seq[Expr], Seq[Expr])] { - def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[Expr])] = e match { - case QuantificationMatcher(IsTyped(caller, _: FunctionType), args) => Some((caller, args, path)) - case _ => None - } - - override def rec(e: Expr, path: Seq[Expr]): Expr = e match { - case l : Lambda => l - case _ => super.rec(e, path) - } - }.traverse(cleanBody) - - for ((caller, args, paths) <- calls) { - if ((variablesOf(caller) & quantified).isEmpty) { - val enabler = andJoin(paths.filter(expr => (variablesOf(expr) & quantified).isEmpty)) - val skipCall = ch.getFreshLabel("skipCall") - mkExpr(enabler, ch) - ch << IfEq(skipCall) - - mkExpr(caller, ch) - ch << Ldc(args.size) << NewArray.primitive("T_BOOLEAN") - for ((arg, idx) <- args.zipWithIndex) { - ch << DUP << Ldc(idx) << Ldc(arg match { - case Variable(id) if quantified(id) => 1 - case _ => 0 - }) << BASTORE - } - - ch << InvokeVirtual(LambdaClass, "checkForall", "([Z)V") - - ch << Label(skipCall) - } - } - } - - ch << Label(skipCheck) - } - - private val typeIdCache = scala.collection.mutable.Map.empty[TypeTree, Int] - private[codegen] def typeId(tpe: TypeTree): Int = typeIdCache.get(tpe) match { - case Some(id) => id - case None => - val id = typeIdCache.size - typeIdCache += tpe -> id - id - } - - private[codegen] val forallToClass = scala.collection.mutable.Map.empty[Expr, String] - - private def mkForall(quants: Set[Identifier], body: Expr, ch: CodeHandler, check: Boolean = true)(implicit locals: Locals): Unit = { - val (afName, closures, consSig) = compileForall(quants, body) - ch << New(afName) << DUP - load(monitorID, ch) - mkTuple(closures.map(_.toVariable) :+ BooleanLiteral(check), ch) - ch << InvokeSpecial(afName, constructorName, consSig) - ch << InvokeVirtual(ForallClass, "check", "()Z") - } - - private def compileForall(quants: Set[Identifier], body: Expr): (String, Seq[Identifier], String) = { - val (nl, structSubst) = purescala.ExprOps.normalizeStructure(matchToIfThenElse(body)) - val reverseSubst = structSubst.map(p => p._2 -> p._1) - val nquants = quants.flatMap(structSubst.get) - - val closures = (purescala.ExprOps.variablesOf(nl) -- nquants).toSeq.sortBy(_.uniqueName) - - val afName = forallToClass.getOrElse(nl, { - val afName = "Leon$CodeGen$Forall$" + forallCounter.nextGlobal - forallToClass += nl -> afName - - val cf = new ClassFile(afName, Some(ForallClass)) - - cf.setFlags(( - CLASS_ACC_SUPER | - CLASS_ACC_PUBLIC | - CLASS_ACC_FINAL - ).asInstanceOf[U2]) - - locally { - val cch = cf.addConstructor(s"L$MonitorClass;", s"L$TupleClass;").codeHandler - - cch << ALoad(0) << ALoad(1) << ALoad(2) - cch << InvokeSpecial(ForallClass, constructorName, s"(L$MonitorClass;L$TupleClass;)V") - cch << RETURN - cch.freeze - } - - locally { - val cfm = cf.addMethod("Z", "checkForall") - - cfm.setFlags(( - METHOD_ACC_PUBLIC | - METHOD_ACC_FINAL - ).asInstanceOf[U2]) - - val cfch = cfm.codeHandler - - cfch << ALoad(0) << GetField(ForallClass, "closures", s"L$TupleClass;") - - val closureVars = (for ((id, idx) <- closures.zipWithIndex) yield { - val slot = cfch.getFreshVar - cfch << DUP << Ldc(idx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") - mkUnbox(id.getType, cfch) - cfch << (id.getType match { - case ValueType() => IStore(slot) - case _ => AStore(slot) - }) - id -> slot - }).toMap - - cfch << POP - - val monitorSlot = cfch.getFreshVar - cfch << ALoad(0) << GetField(ForallClass, "monitor", s"L$HenkinClass;") - cfch << AStore(monitorSlot) - - implicit val locals = NoLocals.withVars(closureVars).withVar(monitorID -> monitorSlot) - - val skipCheck = cfch.getFreshLabel("skipCheck") - cfch << ALoad(0) << GetField(ForallClass, "check", "Z") - cfch << IfEq(skipCheck) - checkQuantified(nquants, nl, cfch) - cfch << Label(skipCheck) - - val TopLevelAnds(conjuncts) = nl - val endLabel = cfch.getFreshLabel("forallEnd") - - for (conj <- conjuncts) { - val vars = purescala.ExprOps.variablesOf(conj) - val quantified = nquants.filter(vars) - - val matchQuorums = extractQuorums(conj, quantified) - - var allSlots: List[Int] = Nil - var freeSlots: List[Int] = Nil - def getSlot(): Int = freeSlots match { - case x :: xs => - freeSlots = xs - x - case Nil => - val slot = cfch.getFreshVar - allSlots = allSlots :+ slot - slot - } - - for ((qrm, others) <- matchQuorums) { - val quorum = qrm.toList - - def rec(mis: List[(Expr, Expr, Seq[Expr], Int)], locs: Map[Identifier, Int], pointers: Map[(Int, Int), Identifier]): Unit = mis match { - case (TopLevelAnds(paths), expr, args, qidx) :: rest => - load(monitorID, cfch) - cfch << CheckCast(HenkinClass) - - mkExpr(expr, cfch) - cfch << Ldc(typeId(expr.getType)) - cfch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") - cfch << InvokeInterface(JavaListClass, "iterator", s"()L$JavaIteratorClass;") - - val loop = cfch.getFreshLabel("loop") - val out = cfch.getFreshLabel("out") - cfch << Label(loop) - // it - cfch << DUP - // it, it - cfch << InvokeInterface(JavaIteratorClass, "hasNext", "()Z") - // it, hasNext - cfch << IfEq(out) << DUP - // it, it - cfch << InvokeInterface(JavaIteratorClass, "next", s"()L$ObjectClass;") - // it, elem - cfch << CheckCast(TupleClass) - - val (newLoc, newPtr) = (for ((arg, aidx) <- args.zipWithIndex) yield { - val id = FreshIdentifier("q", arg.getType, true) - val slot = getSlot() - - cfch << DUP << Ldc(aidx) << InvokeVirtual(TupleClass, "get", s"(I)L$ObjectClass;") - mkUnbox(arg.getType, cfch) - cfch << (typeToJVM(arg.getType) match { - case "I" | "Z" => IStore(slot) - case _ => AStore(slot) - }) - - (id -> slot, (qidx -> aidx) -> id) - }).unzip - - cfch << POP - // it - - rec(rest, locs ++ newLoc, pointers ++ newPtr) - - cfch << Goto(loop) - cfch << Label(out) << POP - - case Nil => - val okLabel = cfch.getFreshLabel("assignmentOk") - - var mappings: Seq[(Identifier, Int, Int)] = Seq.empty - var constraints: Seq[(Expr, Int, Int)] = Seq.empty - var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty - - for ((q @ (_, expr, args), qidx) <- quorum.zipWithIndex) { - val (qmappings, qconstraints) = args.zipWithIndex.partition { - case (Variable(id), aidx) => quantified(id) - case _ => false - } - - mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) - constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) - } - - val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { - val base :: others = es.toList.map(p => (p._2, p._3)) - equalities ++= others.map(p => base -> p) - (id -> base) - } - - val enabler = andJoin(constraints.map { - case (e, qidx, aidx) => Equals(e, pointers(qidx -> aidx).toVariable) - } ++ equalities.map { - case (k1, k2) => Equals(pointers(k1).toVariable, pointers(k2).toVariable) - }) - - val varsMap = quantified.map(id => id -> locs(pointers(mapping(id)))).toMap - val varLocals = locals.withVars(varsMap) - - mkExpr(enabler, cfch)(varLocals.withVars(locs)) - cfch << IfEq(okLabel) - - val checkOk = cfch.getFreshLabel("checkOk") - load(monitorID, cfch) - cfch << GetField(HenkinClass, "checkForalls", "Z") - cfch << IfEq(checkOk) - - var nextLabel: Option[String] = None - for ((b,caller,args) <- others) { - nextLabel.foreach(label => cfch << Label(label)) - nextLabel = Some(cfch.getFreshLabel("next")) - - mkExpr(b, cfch)(varLocals) - cfch << IfEq(nextLabel.get) - - load(monitorID, cfch) - cfch << CheckCast(HenkinClass) - mkExpr(caller, cfch)(varLocals) - cfch << Ldc(typeId(caller.getType)) - cfch << InvokeVirtual(HenkinClass, "domain", s"(L$ObjectClass;I)L$JavaListClass;") - mkTuple(args, cfch)(varLocals) - cfch << InvokeInterface(JavaListClass, "contains", s"(L$ObjectClass;)Z") - cfch << IfNe(nextLabel.get) - - cfch << New(InvalidForallClass) << DUP - cfch << Ldc("Unhandled transitive implication in " + conj) - cfch << InvokeSpecial(InvalidForallClass, constructorName, "(Ljava/lang/String;)V") - cfch << ATHROW - } - nextLabel.foreach(label => cfch << Label(label)) - - cfch << Label(checkOk) - mkExpr(conj, cfch)(varLocals) - cfch << IfNe(okLabel) - - // -- Forall is false! -- - // POP all the iterators... - for (_ <- List.range(0, quorum.size)) cfch << POP - - // ... and return false - cfch << Ldc(0) << Goto(endLabel) - cfch << Label(okLabel) - } - - val skipQuorum = cfch.getFreshLabel("skipQuorum") - for ((TopLevelAnds(paths), _, _) <- quorum) { - val p = andJoin(paths.filter(path => (variablesOf(path) & quantified).isEmpty)) - mkExpr(p, cfch) - cfch << IfEq(skipQuorum) - } - - val mis = quorum.zipWithIndex.map { case ((p, e, as), idx) => (p, e, as, idx) } - rec(mis, Map.empty, Map.empty) - freeSlots = allSlots - - cfch << Label(skipQuorum) - } - } - - cfch << Ldc(1) << Label(endLabel) - cfch << IRETURN - - cfch.freeze - } - - loader.register(cf) - - afName - }) - - (afName, closures.map(reverseSubst), s"(L$MonitorClass;L$TupleClass;)V") + if (id == monitorID || id == tpsID) p else (reverseSubst(id) -> jvmt) + }, tparams, "(" + closures.map(_._2).mkString("") + ")V") } // also makes tuples with 0/1 args @@ -783,6 +402,31 @@ trait CodeGeneration { ch << InvokeSpecial(TupleClass, constructorName, s"([L$ObjectClass;)V") } + private def loadTypes(tps: Seq[TypeTree], ch: CodeHandler)(implicit locals: Locals): Unit = { + if (tps.nonEmpty) { + ch << Ldc(tps.size) + ch << NewArray.primitive("T_INT") + for ((tpe,idx) <- tps.zipWithIndex) { + ch << DUP << Ldc(idx) << Ldc(registerType(tpe)) << IASTORE + } + + if (locals.tps.nonEmpty) { + load(monitorID, ch) + ch << SWAP + + ch << Ldc(locals.tps.size) + ch << NewArray.primitive("T_INT") + for ((tpe,idx) <- locals.tps.zipWithIndex) { + ch << DUP << Ldc(idx) << Ldc(registerType(tpe)) << IASTORE + } + + ch << SWAP + load(tpsID, ch) + ch << InvokeVirtual(MonitorClass, "typeParams", s"([I[I[I)[I") + } + } + } + private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)(implicit locals: Locals) { e match { case Variable(id) => @@ -839,9 +483,7 @@ trait CodeGeneration { throw CompilationException("Unknown class : " + cct.id) } ch << New(ccName) << DUP - if (requireMonitor) { - load(monitorID, ch) - } + load(monitorID, ch) for((a, vd) <- as zip cct.classDef.fields) { vd.getType match { @@ -968,11 +610,6 @@ trait CodeGeneration { throw CompilationException("Unknown method : " + tfd.id) } - if (requireMonitor) { - load(monitorID, ch) - ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") - } - // Get static field ch << GetStatic(className, fieldName, typeToJVM(tfd.fd.returnType)) @@ -1024,10 +661,9 @@ trait CodeGeneration { ch << POP << POP // list, it, cons, cons, elem, list - if (requireMonitor) { - load(monitorID, ch) - ch << DUP_X2 << POP - } + load(monitorID, ch) + ch << DUP_X2 << POP + ch << InvokeSpecial(consName, constructorName, ccApplySig) // list, it, newList ch << DUP_X2 << POP << SWAP << POP @@ -1039,15 +675,41 @@ trait CodeGeneration { ch << POP // list + case FunctionInvocation(tfd, as) if abstractFunDefs(tfd.fd.id) => + val id = registerAbstractFD(tfd.fd) + + load(monitorID, ch) + + ch << Ldc(id) + if (tfd.fd.tparams.nonEmpty) { + loadTypes(tfd.tps, ch) + } else { + ch << Ldc(0) << NewArray.primitive("T_INT") + } + + ch << Ldc(as.size) + ch << NewArray(ObjectClass) + + for ((e, i) <- as.zipWithIndex) { + ch << DUP + ch << Ldc(i) + mkExpr(e, ch) + mkBox(e.getType, ch) + ch << AASTORE + } + + ch << InvokeVirtual(MonitorClass, "onAbstractInvocation", s"(I[I[L$ObjectClass;)L$ObjectClass;") + + mkUnbox(tfd.returnType, ch) + // Static lazy fields/ functions case fi @ FunctionInvocation(tfd, as) => val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse { throw CompilationException("Unknown method : " + tfd.id) } - if (requireMonitor) { - load(monitorID, ch) - } + load(monitorID, ch) + loadTypes(tfd.tps, ch) for((a, vd) <- as zip tfd.fd.params) { vd.getType match { @@ -1072,10 +734,6 @@ trait CodeGeneration { throw CompilationException("Unknown method : " + tfd.id) } - if (requireMonitor) { - load(monitorID, ch) - ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") - } // Load receiver mkExpr(rec,ch) @@ -1097,11 +755,10 @@ trait CodeGeneration { } // Receiver of the method call - mkExpr(rec,ch) + mkExpr(rec, ch) - if (requireMonitor) { - load(monitorID, ch) - } + load(monitorID, ch) + loadTypes(tfd.tps, ch) for((a, vd) <- as zip tfd.fd.params) { vd.getType match { @@ -1133,39 +790,31 @@ trait CodeGeneration { ch << InvokeVirtual(LambdaClass, "apply", s"([L$ObjectClass;)L$ObjectClass;") mkUnbox(app.getType, ch) - case p @ PartialLambda(mapping, optDflt, _) => - ch << New(PartialLambdaClass) << DUP - optDflt match { - case Some(dflt) => - mkBoxedExpr(dflt, ch) - ch << InvokeSpecial(PartialLambdaClass, constructorName, s"(L$ObjectClass;)V") - case None => - ch << InvokeSpecial(PartialLambdaClass, constructorName, "()V") - } + case p @ FiniteLambda(mapping, dflt, _) => + ch << New(FiniteLambdaClass) << DUP + mkBoxedExpr(dflt, ch) + ch << InvokeSpecial(FiniteLambdaClass, constructorName, s"(L$ObjectClass;)V") for ((es,v) <- mapping) { ch << DUP mkTuple(es, ch) mkBoxedExpr(v, ch) - ch << InvokeVirtual(PartialLambdaClass, "add", s"(L$TupleClass;L$ObjectClass;)V") + ch << InvokeVirtual(FiniteLambdaClass, "add", s"(L$TupleClass;L$ObjectClass;)V") } case l @ Lambda(args, body) => - val (afName, closures, consSig) = compileLambda(l) + val (afName, closures, tparams, consSig) = compileLambda(l) ch << New(afName) << DUP for ((id,jvmt) <- closures) { - if (id == monitorID) { - load(monitorID, ch) + if (id == tpsID) { + loadTypes(tparams, ch) } else { mkExpr(Variable(id), ch) } } ch << InvokeSpecial(afName, constructorName, consSig) - case f @ Forall(args, body) => - mkForall(args.map(_.id).toSet, body, ch) - // String processing => case StringConcat(l, r) => mkExpr(l, ch) @@ -1407,11 +1056,43 @@ trait CodeGeneration { ch << InvokeSpecial(ErrorClass, constructorName, "(Ljava/lang/String;)V") ch << ATHROW + case forall @ Forall(fargs, body) => + val id = registerForall(forall, locals.tps) + val args = variablesOf(forall).toSeq.sortBy(_.uniqueName) + + load(monitorID, ch) + ch << Ldc(id) + if (locals.tps.nonEmpty) { + load(tpsID, ch) + } else { + ch << Ldc(0) << NewArray.primitive("T_INT") + } + + ch << Ldc(args.size) + ch << NewArray(ObjectClass) + + for ((id, i) <- args.zipWithIndex) { + ch << DUP + ch << Ldc(i) + mkExpr(Variable(id), ch) + mkBox(id.getType, ch) + ch << AASTORE + } + + ch << InvokeVirtual(MonitorClass, "onForallInvocation", s"(I[I[L$ObjectClass;)Z") + case choose: Choose => val prob = synthesis.Problem.fromSpec(choose.pred) - val id = runtime.ChooseEntryPoint.register(prob, this) + val id = registerProblem(prob, locals.tps) + + load(monitorID, ch) ch << Ldc(id) + if (locals.tps.nonEmpty) { + load(tpsID, ch) + } else { + ch << Ldc(0) << NewArray.primitive("T_INT") + } ch << Ldc(prob.as.size) ch << NewArray(ObjectClass) @@ -1424,7 +1105,7 @@ trait CodeGeneration { ch << AASTORE } - ch << InvokeStatic(ChooseEntryPointClass, "invoke", s"(I[L$ObjectClass;)L$ObjectClass;") + ch << InvokeVirtual(MonitorClass, "onChooseInvocation", s"(I[I[L$ObjectClass;)L$ObjectClass;") mkUnbox(choose.getType, ch) @@ -1731,9 +1412,7 @@ trait CodeGeneration { // accessor method locally { - val parameters = if (requireMonitor) { - Seq(monitorID -> s"L$MonitorClass;") - } else Seq() + val parameters = Seq(monitorID -> s"L$MonitorClass;") val paramMapping = parameters.map(_._1).zipWithIndex.toMap.mapValues(_ + (if (isStatic) 0 else 1)) val newLocs = NoLocals.withVars(paramMapping) @@ -1749,11 +1428,6 @@ trait CodeGeneration { val body = lzy.body.getOrElse(throw CompilationException("Lazy field without body?")) val initLabel = ch.getFreshLabel("isInitialized") - if (requireMonitor) { - load(monitorID, ch)(newLocs) - ch << InvokeVirtual(MonitorClass, "onInvoke", "()V") - } - if (isStatic) { ch << GetStatic(cName, underlyingName, underlyingType) } else { @@ -1890,9 +1564,7 @@ trait CodeGeneration { // definition of the constructor locally { - val constrParams = if (requireMonitor) { - Seq(monitorID -> s"L$MonitorClass;") - } else Seq() + val constrParams = Seq(monitorID -> s"L$MonitorClass;") val newLocs = NoLocals.withVars { constrParams.map(_._1).zipWithIndex.toMap.mapValues(_ + 1) @@ -1909,8 +1581,8 @@ trait CodeGeneration { case Some(parent) => val pName = defToJVMName(parent.classDef) // Load monitor object - if (requireMonitor) cch << ALoad(1) - val constrSig = if (requireMonitor) "(L" + MonitorClass + ";)V" else "()V" + cch << ALoad(1) + val constrSig = "(L" + MonitorClass + ";)V" cch << InvokeSpecial(pName, constructorName, constrSig) case None => @@ -1985,9 +1657,7 @@ trait CodeGeneration { // Case class parameters val fieldsTypes = ccd.fields.map { vd => (vd.id, typeToJVM(vd.getType)) } - val constructorArgs = if (requireMonitor) { - (monitorID -> s"L$MonitorClass;") +: fieldsTypes - } else fieldsTypes + val constructorArgs = (monitorID -> s"L$MonitorClass;") +: fieldsTypes val newLocs = NoLocals.withFields(constructorArgs.map { case (id, jvmt) => (id, (cName, id.name, jvmt)) @@ -2013,62 +1683,54 @@ trait CodeGeneration { } // definition of the constructor - if(!params.doInstrument && !requireMonitor && ccd.fields.isEmpty && !ccd.methods.exists(_.canBeField)) { - cf.addDefaultConstructor - } else { - for((id, jvmt) <- constructorArgs) { - val fh = cf.addField(jvmt, id.name) - fh.setFlags(( - FIELD_ACC_PUBLIC | - FIELD_ACC_FINAL - ).asInstanceOf[U2]) - } - - if (params.doInstrument) { - val fh = cf.addField("I", instrumentedField) - fh.setFlags(FIELD_ACC_PUBLIC) - } + for((id, jvmt) <- constructorArgs) { + val fh = cf.addField(jvmt, id.name) + fh.setFlags(( + FIELD_ACC_PUBLIC | + FIELD_ACC_FINAL + ).asInstanceOf[U2]) + } - val cch = cf.addConstructor(constructorArgs.map(_._2) : _*).codeHandler + if (params.doInstrument) { + val fh = cf.addField("I", instrumentedField) + fh.setFlags(FIELD_ACC_PUBLIC) + } - if (params.doInstrument) { - cch << ALoad(0) - cch << Ldc(0) - cch << PutField(cName, instrumentedField, "I") - } + val cch = cf.addConstructor(constructorArgs.map(_._2) : _*).codeHandler - var c = 1 - for((id, jvmt) <- constructorArgs) { - cch << ALoad(0) - cch << (jvmt match { - case "I" | "Z" => ILoad(c) - case _ => ALoad(c) - }) - cch << PutField(cName, id.name, jvmt) - c += 1 - } + if (params.doInstrument) { + cch << ALoad(0) + cch << Ldc(0) + cch << PutField(cName, instrumentedField, "I") + } - // Call parent constructor AFTER initializing case class parameters - if (ccd.parent.isDefined) { - cch << ALoad(0) - if (requireMonitor) { - cch << ALoad(1) - cch << InvokeSpecial(pName.get, constructorName, s"(L$MonitorClass;)V") - } else { - cch << InvokeSpecial(pName.get, constructorName, "()V") - } - } else { - // Call constructor of java.lang.Object - cch << ALoad(0) - cch << InvokeSpecial(ObjectClass, constructorName, "()V") - } + var c = 1 + for((id, jvmt) <- constructorArgs) { + cch << ALoad(0) + cch << (jvmt match { + case "I" | "Z" => ILoad(c) + case _ => ALoad(c) + }) + cch << PutField(cName, id.name, jvmt) + c += 1 + } - // Now initialize fields - for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)(newLocs) } - for (field <- strictFields) { initStrictField(cch, cName , field, isStatic = false)(newLocs) } - cch << RETURN - cch.freeze + // Call parent constructor AFTER initializing case class parameters + if (ccd.parent.isDefined) { + cch << ALoad(0) + cch << ALoad(1) + cch << InvokeSpecial(pName.get, constructorName, s"(L$MonitorClass;)V") + } else { + // Call constructor of java.lang.Object + cch << ALoad(0) + cch << InvokeSpecial(ObjectClass, constructorName, "()V") } + + // Now initialize fields + for (lzy <- lazyFields) { initLazyField(cch, cName, lzy, isStatic = false)(newLocs) } + for (field <- strictFields) { initStrictField(cch, cName , field, isStatic = false)(newLocs) } + cch << RETURN + cch.freeze } locally { diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 26f976c095489f22cbc97dd9ee810d1d71ef9f99..9cef8cd0452fa6596f1ebc2e86ed5d6f4d394c1f 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -8,11 +8,11 @@ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps.typeParamsOf import purescala.Extractors._ import purescala.Constructors._ -import codegen.runtime.LeonCodeGenRuntimeMonitor -import codegen.runtime.LeonCodeGenRuntimeHenkinMonitor import utils.UniqueCounter +import runtime.{Monitor, StdMonitor} import cafebabe._ import cafebabe.AbstractByteCodes._ @@ -24,22 +24,62 @@ import scala.collection.JavaConverters._ import java.lang.reflect.Constructor +import synthesis.Problem class CompilationUnit(val ctx: LeonContext, val program: Program, val params: CodeGenParams = CodeGenParams.default) extends CodeGeneration { + protected[codegen] val requireQuantification = program.definedFunctions.exists { fd => exists { case _: Forall => true case _ => false } (fd.fullBody) } - protected[codegen] val requireMonitor = params.requireMonitor || requireQuantification - val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader) var classes = Map[Definition, ClassFile]() + var defToModuleOrClass = Map[Definition, Definition]() + val abstractFunDefs = program.definedFunctions.filter(_.body.isEmpty).map(_.id).toSet + + val runtimeCounter = new UniqueCounter[Unit] + + var runtimeTypeToIdMap = Map[TypeTree, Int]() + var runtimeIdToTypeMap = Map[Int, TypeTree]() + def registerType(tpe: TypeTree): Int = runtimeTypeToIdMap.get(tpe) match { + case Some(id) => id + case None => + val id = runtimeCounter.nextGlobal + runtimeTypeToIdMap += tpe -> id + runtimeIdToTypeMap += id -> tpe + id + } + + var runtimeProblemMap = Map[Int, (Seq[TypeParameter], Problem)]() + + def registerProblem(p: Problem, tps: Seq[TypeParameter]): Int = { + val id = runtimeCounter.nextGlobal + runtimeProblemMap += id -> (tps, p) + id + } + + var runtimeForallMap = Map[Int, (Seq[TypeParameter], Forall)]() + + def registerForall(f: Forall, tps: Seq[TypeParameter]): Int = { + val id = runtimeCounter.nextGlobal + runtimeForallMap += id -> (tps, f) + id + } + + var runtimeAbstractMap = Map[Int, FunDef]() + + def registerAbstractFD(fd: FunDef): Int = { + val id = runtimeCounter.nextGlobal + runtimeAbstractMap += id -> fd + id + } + def defineClass(df: Definition) { val cName = defToJVMName(df) @@ -65,8 +105,7 @@ class CompilationUnit(val ctx: LeonContext, def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { classes.get(cd) match { case Some(cf) => - val monitorType = if (requireMonitor) "L"+MonitorClass+";" else "" - val sig = "(" + monitorType + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" + val sig = "(L"+MonitorClass+";" + cd.fields.map(f => typeToJVM(f.getType)).mkString("") + ")V" Some((cf.className, sig)) case _ => None } @@ -84,9 +123,9 @@ class CompilationUnit(val ctx: LeonContext, */ def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = { funDefInfo.get(fd).orElse { - val monitorType = if (requireMonitor) "L"+MonitorClass+";" else "" - - val sig = "(" + monitorType + fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" + typeToJVM(fd.returnType) + val sig = "(L"+MonitorClass+";" + + (if (fd.tparams.nonEmpty) "[I" else "") + + fd.params.map(a => typeToJVM(a.getType)).mkString("") + ")" + typeToJVM(fd.returnType) defToModuleOrClass.get(fd).flatMap(m => classes.get(m)) match { case Some(cf) => @@ -127,30 +166,14 @@ class CompilationUnit(val ctx: LeonContext, conss.last } - def modelToJVM(model: solvers.Model, maxInvocations: Int, check: Boolean): LeonCodeGenRuntimeMonitor = model match { - case hModel: solvers.HenkinModel => - val lhm = new LeonCodeGenRuntimeHenkinMonitor(maxInvocations, check) - for ((lambda, domain) <- hModel.doms.lambdas) { - val (afName, _, _) = compileLambda(lambda) - val lc = loader.loadClass(afName) - - for (args <- domain) { - // note here that it doesn't matter that `lhm` doesn't yet have its domains - // filled since all values in `args` should be grounded - val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] - lhm.add(lc, inputJvm) - } - } + def getMonitor(model: solvers.Model, maxInvocations: Int): Monitor = { + val bodies = model.toSeq.filter { case (id, v) => abstractFunDefs(id) }.toMap + val domains = model match { + case hm: solvers.PartialModel => Some(hm.domains) + case _ => None + } - for ((tpe, domain) <- hModel.doms.tpes; args <- domain) { - val tpeId = typeId(tpe) - // same remark as above about valueToJVM(_)(lhm) - val inputJvm = tupleConstructor.newInstance(args.map(valueToJVM(_)(lhm)).toArray).asInstanceOf[leon.codegen.runtime.Tuple] - lhm.add(tpeId, inputJvm) - } - lhm - case _ => - new LeonCodeGenRuntimeMonitor(maxInvocations) + new StdMonitor(this, maxInvocations, bodies, domains) } /** Translates Leon values (not generic expressions) to JVM compatible objects. @@ -159,7 +182,7 @@ class CompilationUnit(val ctx: LeonContext, * This means it is safe to return AnyRef (as opposed to primitive types), because * reflection needs this anyway. */ - def valueToJVM(e: Expr)(implicit monitor: LeonCodeGenRuntimeMonitor): AnyRef = e match { + def valueToJVM(e: Expr)(implicit monitor: Monitor): AnyRef = e match { case IntLiteral(v) => new java.lang.Integer(v) @@ -190,8 +213,8 @@ class CompilationUnit(val ctx: LeonContext, case CaseClass(cct, args) => caseClassConstructor(cct.classDef) match { case Some(cons) => - val realArgs = if (requireMonitor) monitor +: args.map(valueToJVM) else args.map(valueToJVM) - cons.newInstance(realArgs.toArray : _*).asInstanceOf[AnyRef] + val jvmArgs = monitor +: args.map(valueToJVM) + cons.newInstance(jvmArgs.toArray : _*).asInstanceOf[AnyRef] case None => ctx.reporter.fatalError("Case class constructor not found?!?") } @@ -215,12 +238,8 @@ class CompilationUnit(val ctx: LeonContext, } m - case f @ PartialLambda(mapping, dflt, _) => - val l = if (dflt.isDefined) { - new leon.codegen.runtime.PartialLambda(dflt.get) - } else { - new leon.codegen.runtime.PartialLambda() - } + case f @ FiniteLambda(mapping, dflt, _) => + val l = new leon.codegen.runtime.FiniteLambda(valueToJVM(dflt)) for ((ks,v) <- mapping) { // Force tuple even with 1/0 elems. @@ -230,6 +249,22 @@ class CompilationUnit(val ctx: LeonContext, } l + case l @ Lambda(args, body) => + val (afName, closures, tparams, consSig) = compileLambda(l) + val args = closures.map { case (id, _) => + if (id == monitorID) monitor + else if (id == tpsID) typeParamsOf(l).toSeq.sortBy(_.id.uniqueName).map(registerType).toArray + else throw CompilationException(s"Unexpected closure $id in Lambda compilation") + } + + val lc = loader.loadClass(afName) + val conss = lc.getConstructors.sortBy(_.getParameterTypes.length) + println(conss) + assert(conss.nonEmpty) + val lambdaConstructor = conss.last + println(args.toArray) + lambdaConstructor.newInstance(args.toArray : _*).asInstanceOf[AnyRef] + case f @ IsTyped(FiniteArray(elems, default, IntLiteral(length)), ArrayType(underlying)) => if (length < 0) { throw LeonFatalError( @@ -271,10 +306,6 @@ class CompilationUnit(val ctx: LeonContext, case _ => throw CompilationException(s"Unexpected expression $e in valueToJVM") - - // Just slightly overkill... - //case _ => - // compileExpression(e, Seq()).evalToJVM(Seq(),monitor) } /** Translates JVM objects back to Leon values of the appropriate type */ @@ -340,6 +371,15 @@ class CompilationUnit(val ctx: LeonContext, }.toMap FiniteMap(pairs, from, to) + case (lambda: runtime.FiniteLambda, ft @ FunctionType(from, to)) => + val mapping = lambda.mapping.asScala.map { entry => + val k = jvmToValue(entry._1, tupleTypeWrap(from)) + val v = jvmToValue(entry._2, to) + unwrapTuple(k, from.size) -> v + } + val dflt = jvmToValue(lambda.dflt, to) + FiniteLambda(mapping.toSeq, dflt, ft) + case (lambda: runtime.Lambda, _: FunctionType) => val cls = lambda.getClass @@ -390,11 +430,7 @@ class CompilationUnit(val ctx: LeonContext, val argsTypes = args.map(a => typeToJVM(a.getType)) - val realArgs = if (requireMonitor) { - ("L" + MonitorClass + ";") +: argsTypes - } else { - argsTypes - } + val realArgs = ("L" + MonitorClass + ";") +: argsTypes val m = cf.addMethod( typeToJVM(e.getType), @@ -410,11 +446,7 @@ class CompilationUnit(val ctx: LeonContext, val ch = m.codeHandler - val newMapping = if (requireMonitor) { - args.zipWithIndex.toMap.mapValues(_ + 1) + (monitorID -> 0) - } else { - args.zipWithIndex.toMap - } + val newMapping = Map(monitorID -> 0) ++ args.zipWithIndex.toMap.mapValues(_ + 1) mkExpr(e, ch)(NoLocals.withVars(newMapping)) @@ -480,10 +512,10 @@ class CompilationUnit(val ctx: LeonContext, * method invocations here :( */ val locals = NoLocals.withVar(monitorID -> ch.getFreshVar) - ch << New(MonitorClass) << DUP - ch << Ldc(Int.MaxValue) // Allow "infinite" method calls - ch << InvokeSpecial(MonitorClass, cafebabe.Defaults.constructorName, "(I)V") + ch << New(NoMonitorClass) << DUP + ch << InvokeSpecial(NoMonitorClass, cafebabe.Defaults.constructorName, "()V") ch << AStore(locals.varToLocal(monitorID).get) // position 0 + for (lzy <- lazyFields) { initLazyField(ch, cName, lzy, isStatic = true)(locals) } for (field <- strictFields) { initStrictField(ch, cName , field, isStatic = true)(locals) } ch << RETURN diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index f9fca911564c61ad984fa97c3f2ac0da7fc021b4..6467a2068cf56227f26d01df6bec2287d515d48d 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -8,7 +8,7 @@ import purescala.Expressions._ import cafebabe._ -import runtime.{LeonCodeGenRuntimeMonitor => LM} +import runtime.Monitor import java.lang.reflect.InvocationTargetException @@ -21,29 +21,21 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, private val params = unit.params - def argsToJVM(args: Seq[Expr], monitor: LM): Seq[AnyRef] = { + def argsToJVM(args: Seq[Expr], monitor: Monitor): Seq[AnyRef] = { args.map(unit.valueToJVM(_)(monitor)) } - def evalToJVM(args: Seq[AnyRef], monitor: LM): AnyRef = { + def evalToJVM(args: Seq[AnyRef], monitor: Monitor): AnyRef = { assert(args.size == argsDecl.size) - val realArgs = if (unit.requireMonitor) { - monitor +: args - } else { - args - } + val allArgs = monitor +: args - if (realArgs.isEmpty) { - meth.invoke(null) - } else { - meth.invoke(null, realArgs.toArray : _*) - } + meth.invoke(null, allArgs.toArray : _*) } // This may throw an exception. We unwrap it if needed. // We also need to reattach a type in some cases (sets, maps). - def evalFromJVM(args: Seq[AnyRef], monitor: LM) : Expr = { + def evalFromJVM(args: Seq[AnyRef], monitor: Monitor) : Expr = { try { unit.jvmToValue(evalToJVM(args, monitor), exprType) } catch { @@ -51,9 +43,10 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression: Expr, } } - def eval(model: solvers.Model, check: Boolean = false) : Expr = { + def eval(model: solvers.Model) : Expr = { try { - val monitor = unit.modelToJVM(model, params.maxFunctionInvocations, check) + val monitor = unit.getMonitor(model, params.maxFunctionInvocations) + evalFromJVM(argsToJVM(argsDecl.map(model), monitor), monitor) } catch { case ite : InvocationTargetException => throw ite.getCause diff --git a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala deleted file mode 100644 index 84968a169370ff3c415bf9689f48c86577eed44e..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package codegen.runtime - -import utils._ -import purescala.Expressions._ -import purescala.ExprOps.valuateWithModel -import purescala.Constructors._ -import solvers.SolverFactory - -import java.util.WeakHashMap -import java.lang.ref.WeakReference -import scala.collection.mutable.{HashMap => MutableMap} -import scala.concurrent.duration._ - -import codegen.CompilationUnit - -import synthesis._ - -object ChooseEntryPoint { - implicit val debugSection = DebugSectionSynthesis - - private case class ChooseId(id: Int) { } - - private[this] val context = new WeakHashMap[ChooseId, (WeakReference[CompilationUnit], Problem)]() - private[this] val cache = new WeakHashMap[ChooseId, MutableMap[Seq[AnyRef], java.lang.Object]]() - - private[this] val ids = new WeakHashMap[CompilationUnit, MutableMap[Problem, ChooseId]]() - - private val intCounter = new UniqueCounter[Unit] - intCounter.nextGlobal // Start with 1 - - private def getUniqueId(unit: CompilationUnit, p: Problem): ChooseId = synchronized { - if (!ids.containsKey(unit)) { - ids.put(unit, new MutableMap()) - } - - if (ids.get(unit) contains p) { - ids.get(unit)(p) - } else { - val cid = new ChooseId(intCounter.nextGlobal) - ids.get(unit) += p -> cid - cid - } - } - - def register(p: Problem, unit: CompilationUnit): Int = { - val cid = getUniqueId(unit, p) - - context.put(cid, new WeakReference(unit) -> p) - - cid.id - } - - def invoke(i: Int, inputs: Array[AnyRef]): java.lang.Object = { - val id = ChooseId(i) - val (ur, p) = context.get(id) - val unit = ur.get - - val program = unit.program - val ctx = unit.ctx - - ctx.reporter.debug("Executing choose (codegen)!") - val is = inputs.toSeq - - if (!cache.containsKey(id)) { - cache.put(id, new MutableMap()) - } - - val chCache = cache.get(id) - - if (chCache contains is) { - chCache(is) - } else { - val tStart = System.currentTimeMillis - - val solverf = SolverFactory.default(ctx, program).withTimeout(10.second) - val solver = solverf.getNewSolver() - - val inputsMap = (p.as zip inputs).map { - case (id, v) => - Equals(Variable(id), unit.jvmToValue(v, id.getType)) - } - - solver.assertCnstr(andJoin(Seq(p.pc, p.phi) ++ inputsMap)) - - try { - solver.check match { - case Some(true) => - val model = solver.getModel - - val valModel = valuateWithModel(model) _ - - val res = p.xs.map(valModel) - val leonRes = tupleWrap(res) - - val total = System.currentTimeMillis-tStart - - ctx.reporter.debug("Synthesis took "+total+"ms") - ctx.reporter.debug("Finished synthesis with "+leonRes.asString(ctx)) - - val obj = unit.valueToJVM(leonRes)(new LeonCodeGenRuntimeMonitor(unit.params.maxFunctionInvocations)) - chCache += is -> obj - obj - case Some(false) => - throw new LeonCodeGenRuntimeException("Constraint is UNSAT") - case _ => - throw new LeonCodeGenRuntimeException("Timeout exceeded") - } - } finally { - solver.free() - solverf.shutdown() - } - } - } -} diff --git a/src/main/scala/leon/codegen/runtime/Monitor.scala b/src/main/scala/leon/codegen/runtime/Monitor.scala new file mode 100644 index 0000000000000000000000000000000000000000..0861ce3bc406711e5d4d460c5e5cacdb7eb09cd0 --- /dev/null +++ b/src/main/scala/leon/codegen/runtime/Monitor.scala @@ -0,0 +1,252 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package codegen.runtime + +import utils._ +import purescala.Expressions._ +import purescala.Constructors._ +import purescala.Definitions._ +import purescala.Common._ +import purescala.Types._ +import purescala.TypeOps._ +import purescala.ExprOps.{valuateWithModel, replaceFromIDs, variablesOf} +import purescala.Quantification.{extractQuorums, Domains} + +import codegen.CompilationUnit + +import scala.collection.immutable.{Map => ScalaMap} +import scala.collection.mutable.{HashMap => MutableMap, Set => MutableSet} +import scala.concurrent.duration._ + +import solvers.SolverFactory +import solvers.combinators.UnrollingProcedure + +import synthesis._ + +abstract class Monitor { + def onInvocation(): Unit + + def typeParams(params: Array[Int], tps: Array[Int], newTps: Array[Int]): Array[Int] + + def onAbstractInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef + + def onChooseInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef + + def onForallInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): Boolean +} + +class NoMonitor extends Monitor { + def onInvocation(): Unit = {} + + def typeParams(params: Array[Int], tps: Array[Int], newTps: Array[Int]): Array[Int] = { + throw new LeonCodeGenEvaluationException("No monitor available.") + } + + def onAbstractInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef = { + throw new LeonCodeGenEvaluationException("No monitor available.") + } + + def onChooseInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef = { + throw new LeonCodeGenEvaluationException("No monitor available.") + } + + def onForallInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): Boolean = { + throw new LeonCodeGenEvaluationException("No monitor available.") + } +} + +class StdMonitor(unit: CompilationUnit, invocationsMax: Int, bodies: ScalaMap[Identifier, Expr], domains: Option[Domains] = None) extends Monitor { + + private[this] var invocations = 0 + + def onInvocation(): Unit = { + if (invocationsMax >= 0) { + if (invocations < invocationsMax) { + invocations += 1; + } else { + throw new LeonCodeGenEvaluationException("Maximum number of invocations reached ("+invocationsMax+")."); + } + } + } + + def typeParams(params: Array[Int], tps: Array[Int], newTps: Array[Int]): Array[Int] = { + val tparams = params.toSeq.map(unit.runtimeIdToTypeMap(_).asInstanceOf[TypeParameter]) + val static = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) + val newTypes = newTps.toSeq.map(unit.runtimeIdToTypeMap(_)) + val tpMap = (tparams.map(TypeParameterDef(_)) zip newTypes).toMap + static.map(tpe => unit.registerType(instantiateType(tpe, tpMap))).toArray + } + + def onAbstractInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): AnyRef = { + val fd = unit.runtimeAbstractMap(id) + + // TODO: extract types too! + + bodies.get(fd.id) match { + case Some(expr) => + throw new LeonCodeGenRuntimeException("Found body!") + + case None => + throw new LeonCodeGenRuntimeException("Did not find body!") + } + } + + private[this] val chooseCache = new MutableMap[(Int, Seq[AnyRef]), AnyRef]() + + def onChooseInvocation(id: Int, tps: Array[Int], inputs: Array[AnyRef]): AnyRef = { + implicit val debugSection = DebugSectionSynthesis + + val (tparams, p) = unit.runtimeProblemMap(id) + + val program = unit.program + val ctx = unit.ctx + + ctx.reporter.debug("Executing choose (codegen)!") + val is = inputs.toSeq + + if (chooseCache contains ((id, is))) { + chooseCache((id, is)) + } else { + val tStart = System.currentTimeMillis + + val solverf = SolverFactory.default(ctx, program).withTimeout(10.second) + val solver = solverf.getNewSolver() + + val newTypes = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) + val tpMap = (tparams.map(TypeParameterDef(_)) zip newTypes).toMap + + val newXs = p.xs.map { id => + val newTpe = instantiateType(id.getType, tpMap) + if (id.getType == newTpe) id else FreshIdentifier(id.name, newTpe, true) + } + + val newAs = p.as.map { id => + val newTpe = instantiateType(id.getType, tpMap) + if (id.getType == newTpe) id else FreshIdentifier(id.name, newTpe, true) + } + + val inputsMap = (newAs zip inputs).map { + case (id, v) => Equals(Variable(id), unit.jvmToValue(v, id.getType)) + } + + val expr = instantiateType(and(p.pc, p.phi), tpMap, (p.as zip newAs).toMap ++ (p.xs zip newXs)) + solver.assertCnstr(andJoin(expr +: inputsMap)) + + try { + solver.check match { + case Some(true) => + val model = solver.getModel + + val valModel = valuateWithModel(model) _ + + val res = newXs.map(valModel) + val leonRes = tupleWrap(res) + + val total = System.currentTimeMillis-tStart + + ctx.reporter.debug("Synthesis took "+total+"ms") + ctx.reporter.debug("Finished synthesis with "+leonRes.asString(ctx)) + + val obj = unit.valueToJVM(leonRes)(this) + chooseCache += (id, is) -> obj + obj + case Some(false) => + throw new LeonCodeGenRuntimeException("Constraint is UNSAT") + case _ => + throw new LeonCodeGenRuntimeException("Timeout exceeded") + } + } finally { + solver.free() + solverf.shutdown() + } + } + } + + private[this] val forallCache = new MutableMap[(Int, Seq[AnyRef]), Boolean]() + + def onForallInvocation(id: Int, tps: Array[Int], args: Array[AnyRef]): Boolean = { + implicit val debugSection = DebugSectionVerification + + val (tparams, f) = unit.runtimeForallMap(id) + + val program = unit.program + val ctx = unit.ctx.copy(options = unit.ctx.options.map { + case LeonOption(optDef, value) if optDef == UnrollingProcedure.optFeelingLucky => + LeonOption(optDef)(false) + case opt => opt + }) + + ctx.reporter.debug("Executing forall (codegen)!") + val argsSeq = args.toSeq + + if (forallCache contains ((id, argsSeq))) { + forallCache((id, argsSeq)) + } else { + val tStart = System.currentTimeMillis + + val solverf = SolverFactory.default(ctx, program).withTimeout(1.second) + val solver = solverf.getNewSolver() + + val newTypes = tps.toSeq.map(unit.runtimeIdToTypeMap(_)) + val tpMap = (tparams.map(TypeParameterDef(_)) zip newTypes).toMap + + val vars = variablesOf(f).toSeq.sortBy(_.uniqueName) + val newVars = vars.map(id => FreshIdentifier(id.name, instantiateType(id.getType, tpMap), true)) + + val Forall(fargs, body) = instantiateType(f, tpMap, (vars zip newVars).toMap) + val mapping = (newVars zip argsSeq).map(p => p._1 -> unit.jvmToValue(p._2, p._1.getType)).toMap + val cnstr = Not(replaceFromIDs(mapping, body)) + solver.assertCnstr(cnstr) + + if (domains.isDefined) { + val dom = domains.get + val quantifiers = fargs.map(_.id).toSet + val quorums = extractQuorums(body, quantifiers) + + val domainCnstr = orJoin(quorums.map { quorum => + val quantifierDomains = quorum.flatMap { case (path, caller, args) => + val domain = caller match { + case Variable(id) => dom.get(mapping(id)) + case _ => ctx.reporter.fatalError("Unexpected quantifier matcher: " + caller) + } + + args.zipWithIndex.flatMap { + case (Variable(id),idx) if quantifiers(id) => + Some(id -> domain.map(cargs => path -> cargs(idx))) + case _ => None + } + } + + val domainMap = quantifierDomains.groupBy(_._1).mapValues(_.map(_._2).flatten) + andJoin(domainMap.toSeq.map { case (id, dom) => + orJoin(dom.toSeq.map { case (path, value) => and(path, Equals(Variable(id), value)) }) + }) + }) + + solver.assertCnstr(domainCnstr) + } + + try { + solver.check match { + case Some(negRes) => + val res = !negRes + val total = System.currentTimeMillis-tStart + + ctx.reporter.debug("Verification took "+total+"ms") + ctx.reporter.debug("Finished forall evaluation with: "+res) + + forallCache += (id, argsSeq) -> res + res + + case _ => + throw new LeonCodeGenRuntimeException("Timeout exceeded") + } + } finally { + solver.free() + solverf.shutdown() + } + } + } +} + diff --git a/src/main/scala/leon/datagen/GrammarDataGen.scala b/src/main/scala/leon/datagen/GrammarDataGen.scala index cd86c707ddb893918e512f6d7101cc4cc92b6405..04541e78a6e63d1fa8670f4d4aeb12dd9c4417a2 100644 --- a/src/main/scala/leon/datagen/GrammarDataGen.scala +++ b/src/main/scala/leon/datagen/GrammarDataGen.scala @@ -4,14 +4,17 @@ package leon package datagen import purescala.Expressions._ -import purescala.Types.TypeTree +import purescala.Types._ import purescala.Common._ import purescala.Constructors._ import purescala.Extractors._ +import purescala.ExprOps._ import evaluators._ import bonsai.enumerators._ import grammars._ +import utils.UniqueCounter +import utils.SeqUtils.cartesianProduct /** Utility functions to generate values of a given type. * In fact, it could be used to generate *terms* of a given type, @@ -19,9 +22,40 @@ import grammars._ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar[TypeTree] = ValueGrammar) extends DataGenerator { implicit val ctx = evaluator.context + // Assume e contains generic values with index 0. + // Return a series of expressions with all normalized combinations of generic values. + private def expandGenerics(e: Expr): Seq[Expr] = { + val c = new UniqueCounter[TypeParameter] + val withUniqueCounters: Expr = postMap { + case GenericValue(t, _) => + Some(GenericValue(t, c.next(t))) + case _ => None + }(e) + + val indices = c.current + + val (tps, substInt) = (for { + tp <- indices.keySet.toSeq + } yield tp -> (for { + from <- 0 to indices(tp) + to <- 0 to from + } yield (from, to))).unzip + + val combos = cartesianProduct(substInt) + + val substitutions = combos map { subst => + tps.zip(subst).map { case (tp, (from, to)) => + (GenericValue(tp, from): Expr) -> (GenericValue(tp, to): Expr) + }.toMap + } + + substitutions map (replace(_, withUniqueCounters)) + + } + def generate(tpe: TypeTree): Iterator[Expr] = { - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](grammar.getProductions) - enum.iterator(tpe) + val enum = new MemoizedEnumerator[TypeTree, Expr, ProductionRule[TypeTree, Expr]](grammar.getProductions) + enum.iterator(tpe).flatMap(expandGenerics) } def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]] = { @@ -51,4 +85,8 @@ class GrammarDataGen(evaluator: Evaluator, grammar: ExpressionGrammar[TypeTree] } } + def generateMapping(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int) = { + generateFor(ins, satisfying, maxValid, maxEnumerated) map (ins zip _) + } + } diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index b82dc6a0f256fec3c3cf733addad76f899456f4a..1975400500b741a8fb59fe953b72faf08bf8968e 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -13,7 +13,7 @@ import purescala.Constructors._ import codegen.CompilationUnit import codegen.CodeGenParams -import codegen.runtime.LeonCodeGenRuntimeMonitor +import codegen.runtime.StdMonitor import vanuatoo.{Pattern => VPattern, _} import evaluators._ @@ -131,7 +131,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { Constructor[Expr, TypeTree](subs, ft, { s => val grouped = s.grouped(from.size + 1).toSeq val mapping = grouped.init.map { case args :+ res => (args -> res) } - PartialLambda(mapping, Some(grouped.last.last), ft) + FiniteLambda(mapping, grouped.last.last, ft) }, ft.asString(ctx) + "@" + size) } constructors += ft -> cs @@ -262,7 +262,8 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { Some((args : Expr) => { try { - val monitor = new LeonCodeGenRuntimeMonitor(unit.params.maxFunctionInvocations) + val monitor = new StdMonitor(unit, unit.params.maxFunctionInvocations, Map()) + val jvmArgs = ce.argsToJVM(Seq(args), monitor) val result = ce.evalFromJVM(jvmArgs, monitor) diff --git a/src/main/scala/leon/evaluators/AbstractEvaluator.scala b/src/main/scala/leon/evaluators/AbstractEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..f86f5389a79751524b9b61e60dd8b7b804df8ee1 --- /dev/null +++ b/src/main/scala/leon/evaluators/AbstractEvaluator.scala @@ -0,0 +1,97 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package evaluators + +import purescala.Extractors.Operator +import purescala.Constructors._ +import purescala.Expressions._ +import purescala.Types._ +import purescala.Definitions.{TypedFunDef, Program} +import purescala.DefOps +import purescala.TypeOps +import purescala.ExprOps +import purescala.Expressions.Expr +import leon.utils.DebugSectionSynthesis + +/** The evaluation returns a pair (e, t), + * where e is the expression evaluated as much as possible, and t is the way the expression has been evaluated. + * Caution: If and Match statement require the condition to be non-abstract. */ +class AbstractEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvaluator(ctx, prog, 50000) with HasDefaultGlobalContext with HasDefaultRecContext { + lazy val scalaEv = new ScalacEvaluator(underlying, ctx, prog) + + /** Evaluates resuts which can be evaluated directly + * For example, concatenation of two string literals */ + val underlying = new DefaultEvaluator(ctx, prog) + underlying.setEvaluationFailOnChoose(true) + override type Value = (Expr, Expr) + + override val description: String = "Evaluates string programs but keeps the formula which generated the string" + override val name: String = "String Tracing evaluator" + + protected def e(expr: Expr)(implicit rctx: RC, gctx: GC): (Expr, Expr) = expr match { + case Variable(id) => + rctx.mappings.get(id) match { + case Some(v) if v != expr => + e(v) + case _ => + (expr, expr) + } + + case e if ExprOps.isValue(e) => + (e, e) + + case IfExpr(cond, thenn, elze) => + val first = underlying.e(cond) + first match { + case BooleanLiteral(true) => + ctx.reporter.ifDebug(printer => printer(thenn))(DebugSectionSynthesis) + e(thenn) + case BooleanLiteral(false) => e(elze) + case _ => throw EvalError(typeErrorMsg(first, BooleanType)) + } + + case MatchExpr(scrut, cases) => + val (escrut, tscrut) = e(scrut) + val rscrut = escrut + cases.toStream.map(c => underlying.matchesCase(rscrut, c)).find(_.nonEmpty) match { + case Some(Some((c, mappings))) => + e(c.rhs)(rctx.withNewVars(mappings), gctx) + case _ => + throw RuntimeError("MatchError: "+rscrut.asString+" did not match any of the cases :" + cases) + } + + case FunctionInvocation(tfd, args) => + if (gctx.stepsLeft < 0) { + throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")") + } + gctx.stepsLeft -= 1 + val evArgs = args map e + val evArgsValues = evArgs.map(_._1) + val evArgsOrigin = evArgs.map(_._2) + + // build a mapping for the function... + val frame = rctx.withNewVars(tfd.paramSubst(evArgsValues)) + + val callResult = if ((evArgsValues forall ExprOps.isValue) && tfd.fd.annotations("extern") && ctx.classDir.isDefined) { + (scalaEv.call(tfd, evArgsValues), functionInvocation(tfd.fd, evArgsOrigin)) + } else { + if((!tfd.hasBody && !rctx.mappings.isDefinedAt(tfd.id)) || tfd.body.exists(b => ExprOps.exists(e => e.isInstanceOf[Choose])(b))) { + (functionInvocation(tfd.fd, evArgsValues), functionInvocation(tfd.fd, evArgsOrigin)) + } else { + val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) + e(body)(frame, gctx) + } + } + callResult + case Operator(es, builder) => + val (ees, ts) = es.map(e).unzip + if(ees forall ExprOps.isValue) { + (underlying.e(builder(ees)), builder(ts)) + } else { + (builder(ees), builder(ts)) + } + } + + +} diff --git a/src/main/scala/leon/evaluators/AngelicEvaluator.scala b/src/main/scala/leon/evaluators/AngelicEvaluator.scala index 99d704f67c7485ba8e00727c4edcc5d30644b4cc..57f233cf44caa63dd587e0d792eb33ec62b89d55 100644 --- a/src/main/scala/leon/evaluators/AngelicEvaluator.scala +++ b/src/main/scala/leon/evaluators/AngelicEvaluator.scala @@ -22,9 +22,6 @@ class AngelicEvaluator(underlying: NDEvaluator) case other@(RuntimeError(_) | EvaluatorError(_)) => other.asInstanceOf[Result[Nothing]] } - - /** Checks that `model |= expr` and that quantifications are all valid */ - def check(expr: Expr, model: Model): CheckResult = underlying.check(expr, model) } class DemonicEvaluator(underlying: NDEvaluator) @@ -42,7 +39,4 @@ class DemonicEvaluator(underlying: NDEvaluator) case other@(RuntimeError(_) | EvaluatorError(_)) => other.asInstanceOf[Result[Nothing]] } - - /** Checks that `model |= expr` and that quantifications are all valid */ - def check(expr: Expr, model: Model): CheckResult = underlying.check(expr, model) -} \ No newline at end of file +} diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala index 533ba695ca27f478a03ac7b6cf53d46885e11309..59e9be41863b313202f3821e137419e8872cbc01 100644 --- a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala +++ b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala @@ -13,7 +13,6 @@ import codegen.CodeGenParams import leon.codegen.runtime.LeonCodeGenRuntimeException import leon.codegen.runtime.LeonCodeGenEvaluationException -import leon.codegen.runtime.LeonCodeGenQuantificationException class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Evaluator(ctx, unit.program) with DeterministicEvaluator { @@ -38,40 +37,6 @@ class CodeGenEvaluator(ctx: LeonContext, val unit : CompilationUnit) extends Eva } } - def check(expression: Expr, model: solvers.Model) : CheckResult = { - compileExpr(expression, model.toSeq.map(_._1)).map { ce => - ctx.timers.evaluators.codegen.runtime.start() - try { - val res = ce.eval(model, check = true) - if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess - else EvaluationResults.CheckValidityFailure - } catch { - case e : ArithmeticException => - EvaluationResults.CheckRuntimeFailure(e.getMessage) - - case e : ArrayIndexOutOfBoundsException => - EvaluationResults.CheckRuntimeFailure(e.getMessage) - - case e : LeonCodeGenRuntimeException => - EvaluationResults.CheckRuntimeFailure(e.getMessage) - - case e : LeonCodeGenEvaluationException => - EvaluationResults.CheckRuntimeFailure(e.getMessage) - - case e : java.lang.ExceptionInInitializerError => - EvaluationResults.CheckRuntimeFailure(e.getException.getMessage) - - case so : java.lang.StackOverflowError => - EvaluationResults.CheckRuntimeFailure("Stack overflow") - - case e : LeonCodeGenQuantificationException => - EvaluationResults.CheckQuantificationFailure(e.getMessage) - } finally { - ctx.timers.evaluators.codegen.runtime.stop() - } - }.getOrElse(EvaluationResults.CheckRuntimeFailure("Couldn't compile expression.")) - } - def eval(expression: Expr, model: solvers.Model) : EvaluationResult = { compile(expression, model.toSeq.map(_._1)).map { e => ctx.timers.evaluators.codegen.runtime.start() diff --git a/src/main/scala/leon/evaluators/ContextualEvaluator.scala b/src/main/scala/leon/evaluators/ContextualEvaluator.scala index 0fc33102a04716816fc3b2a83faa1384b37da1fd..51205fc5c8221c439801786b3c9ee3cc7b2e3287 100644 --- a/src/main/scala/leon/evaluators/ContextualEvaluator.scala +++ b/src/main/scala/leon/evaluators/ContextualEvaluator.scala @@ -8,7 +8,7 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.Types._ -import solvers.{HenkinModel, Model} +import solvers.{PartialModel, Model} abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps: Int) extends Evaluator(ctx, prog) with CEvalHelpers { @@ -20,7 +20,9 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps def initRC(mappings: Map[Identifier, Expr]): RC def initGC(model: solvers.Model, check: Boolean): GC - case class EvalError(msg : String) extends Exception + case class EvalError(msg : String) extends Exception { + override def getMessage = msg + Option(super.getMessage).map("\n" + _).getOrElse("") + } case class RuntimeError(msg : String) extends Exception case class QuantificationError(msg: String) extends Exception @@ -46,33 +48,9 @@ abstract class ContextualEvaluator(ctx: LeonContext, prog: Program, val maxSteps } } - def check(ex: Expr, model: Model): CheckResult = { - assert(ex.getType == BooleanType, "Can't check non-boolean expression " + ex.asString) - try { - lastGC = Some(initGC(model, check = true)) - ctx.timers.evaluators.recursive.runtime.start() - val res = e(ex)(initRC(model.toMap), lastGC.get) - if (res == BooleanLiteral(true)) EvaluationResults.CheckSuccess - else EvaluationResults.CheckValidityFailure - } catch { - case so: StackOverflowError => - EvaluationResults.CheckRuntimeFailure("Stack overflow") - case e @ EvalError(msg) => - EvaluationResults.CheckRuntimeFailure(msg) - case e @ RuntimeError(msg) => - EvaluationResults.CheckRuntimeFailure(msg) - case jre: java.lang.RuntimeException => - EvaluationResults.CheckRuntimeFailure(jre.getMessage) - case qe @ QuantificationError(msg) => - EvaluationResults.CheckQuantificationFailure(msg) - } finally { - ctx.timers.evaluators.recursive.runtime.stop() - } - } - protected def e(expr: Expr)(implicit rctx: RC, gctx: GC): Value - def typeErrorMsg(tree : Expr, expected : TypeTree) : String = s"Type error : expected ${expected.asString}, found ${tree.asString}." + def typeErrorMsg(tree : Expr, expected : TypeTree) : String = s"Type error : expected ${expected.asString}, found ${tree.asString} of type ${tree.getType}." } @@ -82,8 +60,8 @@ private[evaluators] trait CEvalHelpers { /* This is an effort to generalize forall to non-det. solvers def forallInstantiations(gctx:GC, fargs: Seq[ValDef], conj: Expr) = { - val henkinModel: HenkinModel = gctx.model match { - case hm: HenkinModel => hm + val henkinModel: PartialModel = gctx.model match { + case hm: PartialModel => hm case _ => throw EvalError("Can't evaluate foralls without henkin model") } @@ -136,4 +114,4 @@ private[evaluators] trait CEvalHelpers { -} \ No newline at end of file +} diff --git a/src/main/scala/leon/evaluators/DualEvaluator.scala b/src/main/scala/leon/evaluators/DualEvaluator.scala index 4c405c8b6f216ee9b839101d8bff5574035b05f4..7976dbfc207679ee92f81b86fd4f00d582b2e9f0 100644 --- a/src/main/scala/leon/evaluators/DualEvaluator.scala +++ b/src/main/scala/leon/evaluators/DualEvaluator.scala @@ -9,20 +9,20 @@ import purescala.Definitions._ import purescala.Types._ import codegen._ +import codegen.runtime.{StdMonitor, Monitor} class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) extends RecursiveEvaluator(ctx, prog, params.maxFunctionInvocations) - with HasDefaultGlobalContext -{ + with HasDefaultGlobalContext { type RC = DualRecContext def initRC(mappings: Map[Identifier, Expr]): RC = DualRecContext(mappings) implicit val debugSection = utils.DebugSectionEvaluation - var monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) - val unit = new CompilationUnit(ctx, prog, params) + var monitor: Monitor = new StdMonitor(unit, params.maxFunctionInvocations, Map()) + val isCompiled = prog.definedFunctions.toSet case class DualRecContext(mappings: Map[Identifier, Expr], needJVMRef: Boolean = false) extends RecContext[DualRecContext] { @@ -37,7 +37,9 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) val (className, methodName, _) = unit.leonFunDefToJVMInfo(tfd.fd).get - val allArgs = if (params.requireMonitor) monitor +: args else args + val allArgs = Seq(monitor) ++ + (if (tfd.fd.tparams.nonEmpty) Seq(tfd.tps.map(unit.registerType(_)).toArray) else Seq()) ++ + args ctx.reporter.debug(s"Calling $className.$methodName(${args.mkString(",")})") @@ -124,9 +126,8 @@ class DualEvaluator(ctx: LeonContext, prog: Program, params: CodeGenParams) } } - override def eval(ex: Expr, model: solvers.Model) = { - monitor = new runtime.LeonCodeGenRuntimeMonitor(params.maxFunctionInvocations) + monitor = unit.getMonitor(model, params.maxFunctionInvocations) super.eval(ex, model) } diff --git a/src/main/scala/leon/evaluators/EvaluationResults.scala b/src/main/scala/leon/evaluators/EvaluationResults.scala index 18f7a0c92d448f98c8f6a271d91e021a649e3b9c..b9cbecdde4cdd31b2b29dcb82fdd23c50aaf9a24 100644 --- a/src/main/scala/leon/evaluators/EvaluationResults.scala +++ b/src/main/scala/leon/evaluators/EvaluationResults.scala @@ -15,21 +15,4 @@ object EvaluationResults { /** Represents an evaluation that failed (in the evaluator). */ case class EvaluatorError(message : String) extends Result(None) - - /** Results of checking proposition evaluation. - * Useful for verification of model validity in presence of quantifiers. */ - sealed abstract class CheckResult(val success: Boolean) - - /** Successful proposition evaluation (model |= expr) */ - case object CheckSuccess extends CheckResult(true) - - /** Check failed with `model |= !expr` */ - case object CheckValidityFailure extends CheckResult(false) - - /** Check failed due to evaluation or runtime errors. - * @see [[RuntimeError]] and [[EvaluatorError]] */ - case class CheckRuntimeFailure(msg: String) extends CheckResult(false) - - /** Check failed due to inconsistence of model with quantified propositions. */ - case class CheckQuantificationFailure(msg: String) extends CheckResult(false) } diff --git a/src/main/scala/leon/evaluators/Evaluator.scala b/src/main/scala/leon/evaluators/Evaluator.scala index ff0f35f1241547d66f81f0b341fca508276b40ea..400409577ac7aebf490d8a961ac44c67036989c7 100644 --- a/src/main/scala/leon/evaluators/Evaluator.scala +++ b/src/main/scala/leon/evaluators/Evaluator.scala @@ -18,7 +18,6 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends type Value type EvaluationResult = EvaluationResults.Result[Value] - type CheckResult = EvaluationResults.CheckResult /** Evaluates an expression, using [[Model.mapping]] as a valuation function for the free variables. */ def eval(expr: Expr, model: Model) : EvaluationResult @@ -31,9 +30,6 @@ abstract class Evaluator(val context: LeonContext, val program: Program) extends /** Evaluates a ground expression. */ final def eval(expr: Expr) : EvaluationResult = eval(expr, Model.empty) - /** Checks that `model |= expr` and that quantifications are all valid */ - def check(expr: Expr, model: Model) : CheckResult - /** Compiles an expression into a function, where the arguments are the free variables in the expression. * `argorder` specifies in which order the arguments should be passed. * The default implementation uses the evaluation function each time, but evaluators are free diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 75c691edf1f7e5a4ca056bf8c8448bff9b5ead77..689980682c86d03bb715b50e0dd2ff6ce530c331 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -3,19 +3,21 @@ package leon package evaluators -import leon.purescala.Quantification._ +import purescala.Quantification._ import purescala.Constructors._ import purescala.ExprOps._ import purescala.Expressions.Pattern import purescala.Extractors._ -import purescala.TypeOps._ +import purescala.TypeOps.isSubtypeOf import purescala.Types._ import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ -import leon.solvers.{HenkinModel, Model, SolverFactory} +import purescala.DefOps +import solvers.{PartialModel, Model, SolverFactory} +import solvers.combinators.UnrollingProcedure import scala.collection.mutable.{Map => MutableMap} -import leon.purescala.DefOps +import scala.concurrent.duration._ import org.apache.commons.lang3.StringEscapeUtils abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int) @@ -28,6 +30,11 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int lazy val scalaEv = new ScalacEvaluator(this, ctx, prog) protected var clpCache = Map[(Choose, Seq[Expr]), Expr]() + protected var frlCache = Map[(Forall, Seq[Expr]), Expr]() + + private var evaluationFailsOnChoose = false + /** Sets the flag if when encountering a Choose, it should fail instead of solving it. */ + def setEvaluationFailOnChoose(b: Boolean) = { this.evaluationFailsOnChoose = b; this } protected[evaluators] def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match { case Variable(id) => @@ -37,7 +44,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case Some(v) => v case None => - throw EvalError("No value for identifier " + id.asString + " in mapping.") + throw EvalError("No value for identifier " + id.asString + " in mapping " + rctx.mappings) } case Application(caller, args) => @@ -46,13 +53,10 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val newArgs = args.map(e) val mapping = l.paramSubst(newArgs) e(body)(rctx.withNewVars(mapping), gctx) - case PartialLambda(mapping, dflt, _) => + case FiniteLambda(mapping, dflt, _) => mapping.find { case (pargs, res) => (args zip pargs).forall(p => e(Equals(p._1, p._2)) == BooleanLiteral(true)) - }.map(_._2).orElse(dflt).getOrElse { - throw EvalError("Cannot apply partial lambda outside of domain : " + - args.map(e(_).asString(ctx)).mkString("(", ", ", ")")) - } + }.map(_._2).getOrElse(dflt) case f => throw EvalError("Cannot apply non-lambda function " + f.asString) } @@ -73,16 +77,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int e(IfExpr(Not(cond), Error(expr.getType, oerr.getOrElse("Assertion failed @"+expr.getPos)), body)) case en@Ensuring(body, post) => - if ( exists{ - case Hole(_,_) => true - case WithOracle(_,_) => true - case _ => false - }(en)) { - import synthesis.ConversionPhase.convert - e(convert(en, ctx)) - } else { - e(en.toAssert) - } + e(en.toAssert) case Error(tpe, desc) => throw RuntimeError("Error reached in evaluation: " + desc) @@ -192,7 +187,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (lv,rv) match { case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) - case (PartialLambda(m1, d1, _), PartialLambda(m2, d2, _)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) + case (FiniteLambda(m1, d1, _), FiniteLambda(m2, d2, _)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) case _ => BooleanLiteral(lv == rv) } @@ -506,19 +501,95 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int FiniteSet(els.map(e), base) case l @ Lambda(_, _) => - val (nl, structSubst) = normalizeStructure(matchToIfThenElse(l)) - val mapping = variablesOf(l).map(id => structSubst(id) -> e(Variable(id))).toMap - val newLambda = replaceFromIDs(mapping, nl).asInstanceOf[Lambda] - if (!gctx.lambdas.isDefinedAt(newLambda)) { - gctx.lambdas += (newLambda -> nl.asInstanceOf[Lambda]) + val mapping = variablesOf(l).map(id => id -> e(Variable(id))).toMap + val newLambda = replaceFromIDs(mapping, l).asInstanceOf[Lambda] + val (normalized, _) = normalizeStructure(matchToIfThenElse(newLambda)) + val nl = normalized.asInstanceOf[Lambda] + if (!gctx.lambdas.isDefinedAt(nl)) { + val (norm, _) = normalizeStructure(matchToIfThenElse(l)) + gctx.lambdas += (nl -> norm.asInstanceOf[Lambda]) } - newLambda + nl + + case FiniteLambda(mapping, dflt, tpe) => + FiniteLambda(mapping.map(p => p._1.map(e) -> e(p._2)), e(dflt), tpe) + + case f @ Forall(fargs, body) => + + implicit val debugSection = utils.DebugSectionVerification + + ctx.reporter.debug("Executing forall!") + + val mapping = variablesOf(f).map(id => id -> rctx.mappings(id)).toMap + val context = mapping.toSeq.sortBy(_._1.uniqueName).map(_._2) + + frlCache.getOrElse((f, context), { + val tStart = System.currentTimeMillis + + val newCtx = ctx.copy(options = ctx.options.map { + case LeonOption(optDef, value) if optDef == UnrollingProcedure.optFeelingLucky => + LeonOption(optDef)(false) + case opt => opt + }) + + val solverf = SolverFactory.getFromSettings(newCtx, program).withTimeout(1.second) + val solver = solverf.getNewSolver() + + try { + val cnstr = Not(replaceFromIDs(mapping, body)) + solver.assertCnstr(cnstr) + + gctx.model match { + case pm: PartialModel => + val quantifiers = fargs.map(_.id).toSet + val quorums = extractQuorums(body, quantifiers) + + val domainCnstr = orJoin(quorums.map { quorum => + val quantifierDomains = quorum.flatMap { case (path, caller, args) => + val matcher = e(expr) match { + case l: Lambda => gctx.lambdas.getOrElse(l, l) + case ev => ev + } + + val domain = pm.domains.get(matcher) + args.zipWithIndex.flatMap { + case (Variable(id),idx) if quantifiers(id) => + Some(id -> domain.map(cargs => path -> cargs(idx))) + case _ => None + } + } - case PartialLambda(mapping, dflt, tpe) => - PartialLambda(mapping.map(p => p._1.map(e) -> e(p._2)), dflt.map(e), tpe) + val domainMap = quantifierDomains.groupBy(_._1).mapValues(_.map(_._2).flatten) + andJoin(domainMap.toSeq.map { case (id, dom) => + orJoin(dom.toSeq.map { case (path, value) => and(path, Equals(Variable(id), value)) }) + }) + }) - case Forall(fargs, body) => - evalForall(fargs.map(_.id).toSet, body) + solver.assertCnstr(domainCnstr) + + case _ => + } + + solver.check match { + case Some(negRes) => + val total = System.currentTimeMillis-tStart + val res = BooleanLiteral(!negRes) + ctx.reporter.debug("Verification took "+total+"ms") + ctx.reporter.debug("Finished forall evaluation with: "+res) + + frlCache += (f, context) -> res + res + case _ => + throw RuntimeError("Timeout exceeded") + } + } catch { + case e: Throwable => + throw EvalError("Runtime verification of forall failed: "+e.getMessage) + } finally { + solverf.reclaim(solver) + solverf.shutdown() + } + }) case ArrayLength(a) => val FiniteArray(_, _, IntLiteral(length)) = e(a) @@ -560,6 +631,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (l,r) => throw EvalError(typeErrorMsg(l, MapType(r.getType, g.getType))) } + case u @ MapUnion(m1,m2) => (e(m1), e(m2)) match { case (f1@FiniteMap(ss1, _, _), FiniteMap(ss2, _, _)) => val newSs = ss1 ++ ss2 @@ -568,6 +640,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case (l, r) => throw EvalError(typeErrorMsg(l, m1.getType)) } + case i @ MapIsDefinedAt(m,k) => (e(m), e(k)) match { case (FiniteMap(ss, _, _), e) => BooleanLiteral(ss.contains(e)) case (l, r) => throw EvalError(typeErrorMsg(l, m.getType)) @@ -577,6 +650,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int e(p.asConstraint) case choose: Choose => + if(evaluationFailsOnChoose) { + throw EvalError("Evaluator set to not solve choose constructs") + } implicit val debugSection = utils.DebugSectionSynthesis @@ -592,7 +668,6 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int val solverf = SolverFactory.getFromSettings(ctx, program) val solver = solverf.getNewSolver() - try { val eqs = p.as.map { case id => @@ -637,7 +712,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case Some(Some((c, mappings))) => e(c.rhs)(rctx.withNewVars(mappings), gctx) case _ => - throw RuntimeError("MatchError: "+rscrut.asString+" did not match any of the cases") + throw RuntimeError("MatchError: "+rscrut.asString+" did not match any of the cases:\n"+cases) } case gl: GenericValue => gl @@ -645,9 +720,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int case l : Literal[_] => l case other => - context.reporter.error(other.getPos, "Error: don't know how to handle " + other.asString + " in Evaluator ("+other.getClass+").") - println("RecursiveEvaluator error:" + other.asString) - throw EvalError("Unhandled case in Evaluator : " + other.asString) + throw EvalError("Unhandled case in Evaluator : [" + other.getClass + "] " + other.asString) } def matchesCase(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Option[(MatchCase, Map[Identifier, Expr])] = { @@ -723,142 +796,5 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int } yield (caze, r) } } - - - protected def evalForall(quants: Set[Identifier], body: Expr, check: Boolean = true)(implicit rctx: RC, gctx: GC): Expr = { - val henkinModel: HenkinModel = gctx.model match { - case hm: HenkinModel => hm - case _ => throw EvalError("Can't evaluate foralls without henkin model") -} - - val TopLevelAnds(conjuncts) = body - e(andJoin(conjuncts.flatMap { conj => - val vars = variablesOf(conj) - val quantified = quants.filter(vars) - - extractQuorums(conj, quantified).flatMap { case (qrm, others) => - val quorum = qrm.toList - - if (quorum.exists { case (TopLevelAnds(paths), _, _) => - val p = andJoin(paths.filter(path => (variablesOf(path) & quantified).isEmpty)) - e(p) == BooleanLiteral(false) - }) List(BooleanLiteral(true)) else { - - var mappings: Seq[(Identifier, Int, Int)] = Seq.empty - var constraints: Seq[(Expr, Int, Int)] = Seq.empty - var equalities: Seq[((Int, Int), (Int, Int))] = Seq.empty - - for (((_, expr, args), qidx) <- quorum.zipWithIndex) { - val (qmappings, qconstraints) = args.zipWithIndex.partition { - case (Variable(id),aidx) => quantified(id) - case _ => false - } - - mappings ++= qmappings.map(p => (p._1.asInstanceOf[Variable].id, qidx, p._2)) - constraints ++= qconstraints.map(p => (p._1, qidx, p._2)) - } - - val mapping = for ((id, es) <- mappings.groupBy(_._1)) yield { - val base :: others = es.toList.map(p => (p._2, p._3)) - equalities ++= others.map(p => base -> p) - (id -> base) - } - - def domain(expr: Expr): Set[Seq[Expr]] = henkinModel.domain(e(expr) match { - case l: Lambda => gctx.lambdas.getOrElse(l, l) - case ev => ev - }) - - val argSets = quorum.foldLeft[List[Seq[Seq[Expr]]]](List(Seq.empty)) { - case (acc, (_, expr, _)) => acc.flatMap(s => domain(expr).map(d => s :+ d)) - } - - argSets.map { args => - val argMap: Map[(Int, Int), Expr] = args.zipWithIndex.flatMap { - case (a, qidx) => a.zipWithIndex.map { case (e, aidx) => (qidx, aidx) -> e } - }.toMap - - val map = mapping.map { case (id, key) => id -> argMap(key) } - val enabler = andJoin(constraints.map { - case (e, qidx, aidx) => Equals(e, argMap(qidx -> aidx)) - } ++ equalities.map { - case (k1, k2) => Equals(argMap(k1), argMap(k2)) - }) - - val ctx = rctx.withNewVars(map) - if (e(enabler)(ctx, gctx) == BooleanLiteral(true)) { - if (gctx.check) { - for ((b,caller,args) <- others if e(b)(ctx, gctx) == BooleanLiteral(true)) { - val evArgs = args.map(arg => e(arg)(ctx, gctx)) - if (!domain(caller)(evArgs)) - throw QuantificationError("Unhandled transitive implication in " + replaceFromIDs(map, conj)) - } - } - - e(conj)(ctx, gctx) - } else { - BooleanLiteral(true) - } - } - } - } - })) match { - case res @ BooleanLiteral(true) if check => - if (gctx.check) { - checkForall(quants, body) match { - case status: ForallInvalid => - throw QuantificationError("Invalid forall: " + status.getMessage) - case _ => - // make sure the body doesn't contain matches or lets as these introduce new locals - val cleanBody = expandLets(matchToIfThenElse(body)) - val calls = new CollectorWithPaths[(Expr, Seq[Expr], Seq[Expr])] { - def collect(e: Expr, path: Seq[Expr]): Option[(Expr, Seq[Expr], Seq[Expr])] = e match { - case QuantificationMatcher(IsTyped(caller, _: FunctionType), args) => Some((caller, args, path)) - case _ => None - } - - override def rec(e: Expr, path: Seq[Expr]): Expr = e match { - case l : Lambda => l - case _ => super.rec(e, path) - } - }.traverse(cleanBody) - - for ((caller, appArgs, paths) <- calls) { - val path = andJoin(paths.filter(expr => (variablesOf(expr) & quants).isEmpty)) - if (e(path) == BooleanLiteral(true)) e(caller) match { - case _: PartialLambda => // OK - case l: Lambda => - val nl @ Lambda(args, body) = gctx.lambdas.getOrElse(l, l) - val lambdaQuantified = (appArgs zip args).collect { - case (Variable(id), vd) if quants(id) => vd.id - }.toSet - - if (lambdaQuantified.nonEmpty) { - checkForall(lambdaQuantified, body) match { - case lambdaStatus: ForallInvalid => - throw QuantificationError("Invalid forall: " + lambdaStatus.getMessage) - case _ => // do nothing - } - - val axiom = Equals(Application(nl, args.map(_.toVariable)), nl.body) - if (evalForall(args.map(_.id).toSet, axiom, check = false) == BooleanLiteral(false)) { - throw QuantificationError("Unaxiomatic lambda " + l) - } - } - case f => - throw EvalError("Cannot apply non-lambda function " + f.asString) - } - } - } - } - - res - - // `res == false` means the quantification is valid since there effectivelly must - // exist an input for which the proposition doesn't hold - case res => res - } - } - } diff --git a/src/main/scala/leon/evaluators/StreamEvaluator.scala b/src/main/scala/leon/evaluators/StreamEvaluator.scala index 9cc6dd132036ffdde4a5ef70e2be87e6276f9f3c..3fff6e523a945b23761a232094c9327ba276e201 100644 --- a/src/main/scala/leon/evaluators/StreamEvaluator.scala +++ b/src/main/scala/leon/evaluators/StreamEvaluator.scala @@ -6,15 +6,19 @@ package evaluators import purescala.Constructors._ import purescala.ExprOps._ import purescala.Extractors._ -import purescala.TypeOps._ +import purescala.TypeOps.{leastUpperBound, isSubtypeOf} import purescala.Types._ import purescala.Common.Identifier import purescala.Definitions.{TypedFunDef, Program} import purescala.Expressions._ +import purescala.Quantification._ -import leon.solvers.SolverFactory +import leon.solvers.{SolverFactory, PartialModel} +import leon.solvers.combinators.UnrollingProcedure import leon.utils.StreamUtils._ +import scala.concurrent.duration._ + class StreamEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvaluator(ctx, prog, 50000) with NDEvaluator @@ -37,12 +41,12 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) case l @ Lambda(params, body) => val mapping = l.paramSubst(newArgs) e(body)(rctx.withNewVars(mapping), gctx).distinct - case PartialLambda(mapping, _, _) => + case FiniteLambda(mapping, dflt, _) => // FIXME - mapping.collectFirst { + Stream(mapping.collectFirst { case (pargs, res) if (newArgs zip pargs).forall { case (f, r) => f == r } => res - }.toStream + }.getOrElse(dflt)) case _ => Stream() } @@ -58,16 +62,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) e(IfExpr(Not(cond), Error(expr.getType, oerr.getOrElse("Assertion failed @"+expr.getPos)), body)) case en@Ensuring(body, post) => - if ( exists{ - case Hole(_,_) => true - case WithOracle(_,_) => true - case _ => false - }(en)) { - import synthesis.ConversionPhase.convert - e(convert(en, ctx)) - } else { - e(en.toAssert) - } + e(en.toAssert) case Error(tpe, desc) => Stream() @@ -115,6 +110,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) case Or(args) if args.isEmpty => Stream(BooleanLiteral(false)) + case Or(args) => e(args.head).distinct.flatMap { case BooleanLiteral(true) => Stream(BooleanLiteral(true)) @@ -126,42 +122,93 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) e(Or(Not(lhs), rhs)) case l @ Lambda(_, _) => - val (nl, structSubst) = normalizeStructure(l) val mapping = variablesOf(l).map(id => - structSubst(id) -> (e(Variable(id)) match { + id -> (e(Variable(id)) match { case Stream(v) => v case _ => return Stream() }) ).toMap - Stream(replaceFromIDs(mapping, nl)) - - // FIXME - case PartialLambda(mapping, tpe, df) => - def solveOne(pair: (Seq[Expr], Expr)) = { - val (args, res) = pair - for { - as <- cartesianProduct(args map e) - r <- e(res) - } yield as -> r - } - cartesianProduct(mapping map solveOne) map (PartialLambda(_, tpe, df)) // FIXME!!! - - case f @ Forall(fargs, TopLevelAnds(conjuncts)) => - Stream() // FIXME - /*def solveOne(conj: Expr) = { - val instantiations = forallInstantiations(gctx, fargs, conj) - for { - es <- cartesianProduct(instantiations.map { case (enabler, mapping) => - e(Implies(enabler, conj))(rctx.withNewVars(mapping), gctx) - }) - res <- e(andJoin(es)) - } yield res - } + Stream(replaceFromIDs(mapping, l)) - for { - conj <- cartesianProduct(conjuncts map solveOne) - res <- e(andJoin(conj)) - } yield res*/ + case fl @ FiniteLambda(mapping, dflt, tpe) => + // finite lambda should always be ground! + Stream(fl) + + case f @ Forall(fargs, body) => + + // TODO add memoization + implicit val debugSection = utils.DebugSectionVerification + + ctx.reporter.debug("Executing forall!") + + val mapping = variablesOf(f).map(id => id -> rctx.mappings(id)).toMap + val context = mapping.toSeq.sortBy(_._1.uniqueName).map(_._2) + + val tStart = System.currentTimeMillis + + val newCtx = ctx.copy(options = ctx.options.map { + case LeonOption(optDef, value) if optDef == UnrollingProcedure.optFeelingLucky => + LeonOption(optDef)(false) + case opt => opt + }) + + val solverf = SolverFactory.getFromSettings(newCtx, program).withTimeout(1.second) + val solver = solverf.getNewSolver() + + try { + val cnstr = Not(replaceFromIDs(mapping, body)) + solver.assertCnstr(cnstr) + + gctx.model match { + case pm: PartialModel => + val quantifiers = fargs.map(_.id).toSet + val quorums = extractQuorums(body, quantifiers) + + val domainCnstr = orJoin(quorums.map { quorum => + val quantifierDomains = quorum.flatMap { case (path, caller, args) => + val optMatcher = e(expr) match { + case Stream(l: Lambda) => Some(gctx.lambdas.getOrElse(l, l)) + case Stream(ev) => Some(ev) + case _ => None + } + + optMatcher.toSeq.flatMap { matcher => + val domain = pm.domains.get(matcher) + args.zipWithIndex.flatMap { + case (Variable(id),idx) if quantifiers(id) => + Some(id -> domain.map(cargs => path -> cargs(idx))) + case _ => None + } + } + } + + val domainMap = quantifierDomains.groupBy(_._1).mapValues(_.map(_._2).flatten) + andJoin(domainMap.toSeq.map { case (id, dom) => + orJoin(dom.toSeq.map { case (path, value) => and(path, Equals(Variable(id), value)) }) + }) + }) + + solver.assertCnstr(domainCnstr) + + case _ => + } + + solver.check match { + case Some(negRes) => + val total = System.currentTimeMillis-tStart + val res = BooleanLiteral(!negRes) + ctx.reporter.debug("Verification took "+total+"ms") + ctx.reporter.debug("Finished forall evaluation with: "+res) + Stream(res) + case _ => + Stream() + } + } catch { + case e: Throwable => Stream() + } finally { + solverf.reclaim(solver) + solverf.shutdown() + } case p : Passes => e(p.asConstraint) @@ -177,13 +224,18 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) val tStart = System.currentTimeMillis - val solverf = SolverFactory.getFromSettings(ctx, program) + val newCtx = ctx.copy(options = ctx.options.map { + case LeonOption(optDef, value) if optDef == UnrollingProcedure.optFeelingLucky => + LeonOption(optDef)(false) + case opt => opt + }) + + val solverf = SolverFactory.getFromSettings(newCtx, program) val solver = solverf.getNewSolver() try { val eqs = p.as.map { - case id => - Equals(Variable(id), rctx.mappings(id)) + case id => Equals(Variable(id), rctx.mappings(id)) } val cnstr = andJoin(eqs ::: p.pc :: p.phi :: Nil) @@ -346,7 +398,7 @@ class StreamEvaluator(ctx: LeonContext, prog: Program) (lv, rv) match { case (FiniteSet(el1, _), FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _), FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) - case (PartialLambda(m1, _, d1), PartialLambda(m2, _, d2)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) + case (FiniteLambda(m1, d1, _), FiniteLambda(m2, d2, _)) => BooleanLiteral(m1.toSet == m2.toSet && d1 == d2) case _ => BooleanLiteral(lv == rv) } diff --git a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala b/src/main/scala/leon/evaluators/StringTracingEvaluator.scala deleted file mode 100644 index 62cde913fae123ab0fcd8c3d019953af9efb5942..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/evaluators/StringTracingEvaluator.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package evaluators - -import purescala.Extractors.Operator -import purescala.Expressions._ -import purescala.Types._ -import purescala.Definitions.{TypedFunDef, Program} -import purescala.DefOps -import purescala.Expressions.Expr -import leon.utils.DebugSectionSynthesis -import org.apache.commons.lang3.StringEscapeUtils - -class StringTracingEvaluator(ctx: LeonContext, prog: Program) extends ContextualEvaluator(ctx, prog, 50000) with HasDefaultGlobalContext with HasDefaultRecContext { - - val underlying = new DefaultEvaluator(ctx, prog) { - override protected[evaluators] def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match { - - case FunctionInvocation(TypedFunDef(fd, Nil), Seq(input)) if fd == prog.library.escape.get => - e(input) match { - case StringLiteral(s) => - StringLiteral(StringEscapeUtils.escapeJava(s)) - case _ => throw EvalError(typeErrorMsg(input, StringType)) - } - - case FunctionInvocation(tfd, args) => - if (gctx.stepsLeft < 0) { - throw RuntimeError("Exceeded number of allocated methods calls ("+gctx.maxSteps+")") - } - gctx.stepsLeft -= 1 - - val evArgs = args map e - - // build a mapping for the function... - val frame = rctx.withNewVars(tfd.paramSubst(evArgs)) - - val callResult = if (tfd.fd.annotations("extern") && ctx.classDir.isDefined) { - scalaEv.call(tfd, evArgs) - } else { - if(!tfd.hasBody && !rctx.mappings.isDefinedAt(tfd.id)) { - throw EvalError("Evaluation of function with unknown implementation.") - } - - val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) - e(body)(frame, gctx) - } - - callResult - - case Variable(id) => - rctx.mappings.get(id) match { - case Some(v) if v != expr => - e(v) - case Some(v) => - v - case None => - expr - } - case StringConcat(s1, s2) => - val es1 = e(s1) - val es2 = e(s2) - (es1, es2) match { - case (StringLiteral(_), StringLiteral(_)) => - (super.e(StringConcat(es1, es2))) - case _ => - StringConcat(es1, es2) - } - case expr => - super.e(expr) - } - } - override type Value = (Expr, Expr) - - override val description: String = "Evaluates string programs but keeps the formula which generated the string" - override val name: String = "String Tracing evaluator" - - protected def e(expr: Expr)(implicit rctx: RC, gctx: GC): (Expr, Expr) = expr match { - case Variable(id) => - rctx.mappings.get(id) match { - case Some(v) if v != expr => - e(v) - case Some(v) => - (v, expr) - case None => - (expr, expr) - } - - case StringConcat(s1, s2) => - val (es1, t1) = e(s1) - val (es2, t2) = e(s2) - (es1, es2) match { - case (StringLiteral(_), StringLiteral(_)) => - (underlying.e(StringConcat(es1, es2)), StringConcat(t1, t2)) - case _ => - (StringConcat(es1, es2), StringConcat(t1, t2)) - } - case StringLength(s1) => - val (es1, t1) = e(s1) - es1 match { - case StringLiteral(_) => - (underlying.e(StringLength(es1)), StringLength(t1)) - case _ => - (StringLength(es1), StringLength(t1)) - } - - case expr@StringLiteral(s) => - (expr, expr) - - case IfExpr(cond, thenn, elze) => - val first = underlying.e(cond) - first match { - case BooleanLiteral(true) => - ctx.reporter.ifDebug(printer => printer(thenn))(DebugSectionSynthesis) - e(thenn) - case BooleanLiteral(false) => e(elze) - case _ => throw EvalError(typeErrorMsg(first, BooleanType)) - } - - case Operator(es, builder) => - val (ees, ts) = es.map(e).unzip - (underlying.e(builder(ees)), builder(ts)) - - } - - -} diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 67e5c4502ad1810b323616916e3edfab9c05a8ea..f8509be9c4d678b10a52f7736a406bb078b6bfbc 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -210,6 +210,37 @@ trait ASTExtractors { case _ => None } } + + + /** Matches the `A computes B` expression at the end of any expression A, and returns (A, B).*/ + object ExComputesExpression { + def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { + case Apply(Select( + Apply(TypeApply(ExSelected("leon", "lang", "package", "SpecsDecorations"), List(_)), realExpr :: Nil), + ExNamed("computes")), expected::Nil) + => Some((realExpr, expected)) + case _ => None + } + } + + /** Matches the `O ask I` expression at the end of any expression O, and returns (I, O).*/ + object ExAskExpression { + def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { + case Apply(TypeApply(Select( + Apply(TypeApply(ExSelected("leon", "lang", "package", "SpecsDecorations"), List(_)), output :: Nil), + ExNamed("ask")), List(_)), input::Nil) + => Some((input, output)) + case _ => None + } + } + + object ExByExampleExpression { + def unapply(tree: Apply) : Option[(Tree, Tree)] = tree match { + case Apply(TypeApply(ExSelected("leon", "lang", "package", "byExample"), List(_, _)), input :: res_output :: Nil) + => Some((input, res_output)) + case _ => None + } + } /** Extracts the `(input, output) passes { case In => Out ...}` and returns (input, output, list of case classes) */ object ExPasses { diff --git a/src/main/scala/leon/frontends/scalac/ClassgenPhase.scala b/src/main/scala/leon/frontends/scalac/ClassgenPhase.scala index b2650740483613b02d3cd208b913febdb15bca5f..3ac856763e13222f606e5e645d2f06a31e7fa60c 100644 --- a/src/main/scala/leon/frontends/scalac/ClassgenPhase.scala +++ b/src/main/scala/leon/frontends/scalac/ClassgenPhase.scala @@ -31,7 +31,7 @@ object ClassgenPhase extends LeonPhase[List[String], List[String]] { _.getLocation.getPath }.orElse( for { // We are in Eclipse. Look in Eclipse plugins to find scala lib - eclipseHome <- Option(System.getenv("ECLIPSE_HOME")) + eclipseHome <- Option(System.getenv("ECLIPSE_HOME")) pluginsHome = eclipseHome + "/plugins" plugins <- scala.util.Try(new File(pluginsHome).listFiles().map{ _.getAbsolutePath }).toOption path <- plugins.find{ _ contains "scala-library"} @@ -40,7 +40,7 @@ object ClassgenPhase extends LeonPhase[List[String], List[String]] { "make sure to set the ECLIPSE_HOME environment variable to your Eclipse installation home directory" )) - val tempOut = Files.createTempDirectory(new File("tmp/").toPath, "classes").toFile + val tempOut = Files.createTempDirectory("classes").toFile settings.classpath.value = scalaLib settings.usejavacp.value = false diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 3e93664008654edef6fa057c56451db0a108bc6b..0451b4c08603fb1b10ae91e2e38cc8902cd1649d 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -22,7 +22,7 @@ import Common._ import Extractors._ import Constructors._ import ExprOps._ -import TypeOps._ +import TypeOps.{leastUpperBound, typesCompatible, typeParamsOf, canBeSubtypeOf} import xlang.Expressions.{Block => LeonBlock, _} import xlang.ExprOps._ @@ -137,10 +137,6 @@ trait CodeExtraction extends ASTExtractors { private var currentFunDef: FunDef = null - //This is a bit misleading, if an expr is not mapped then it has no owner, if it is mapped to None it means - //that it can have any owner - private var owners: Map[Identifier, Option[FunDef]] = Map() - // This one never fails, on error, it returns Untyped def leonType(tpt: Type)(implicit dctx: DefContext, pos: Position): LeonType = { try { @@ -317,19 +313,45 @@ trait CodeExtraction extends ASTExtractors { } private def fillLeonUnit(u: ScalaUnit): Unit = { + def extractClassMembers(sym: Symbol, tpl: Template): Unit = { + for (t <- tpl.body if !t.isEmpty) { + extractFunOrMethodBody(Some(sym), t) + } + + classToInvariants.get(sym).foreach { bodies => + val fd = new FunDef(invId, Seq.empty, Seq.empty, BooleanType) + fd.addFlag(IsADTInvariant) + + val cd = classesToClasses(sym) + cd.registerMethod(fd) + cd.addFlag(IsADTInvariant) + val ctparams = sym.tpe match { + case TypeRef(_, _, tps) => + extractTypeParams(tps).map(_._1) + case _ => + Nil + } + + val tparamsMap = (ctparams zip cd.tparams.map(_.tp)).toMap + val dctx = DefContext(tparamsMap) + + val body = andJoin(bodies.toSeq.filter(_ != EmptyTree).map { + body => flattenBlocks(extractTreeOrNoTree(body)(dctx)) + }) + + fd.fullBody = body + } + } + for (t <- u.defs) t match { case t if isIgnored(t.symbol) => // ignore case ExAbstractClass(_, sym, tpl) => - for (t <- tpl.body if !t.isEmpty) { - extractFunOrMethodBody(Some(sym), t) - } + extractClassMembers(sym, tpl) case ExCaseClass(_, sym, _, tpl) => - for (t <- tpl.body if !t.isEmpty) { - extractFunOrMethodBody(Some(sym), t) - } + extractClassMembers(sym, tpl) case ExObjectDef(n, templ) => for (t <- templ.body if !t.isEmpty) t match { @@ -338,14 +360,10 @@ trait CodeExtraction extends ASTExtractors { None case ExAbstractClass(_, sym, tpl) => - for (t <- tpl.body if !t.isEmpty) { - extractFunOrMethodBody(Some(sym), t) - } + extractClassMembers(sym, tpl) case ExCaseClass(_, sym, _, tpl) => - for (t <- tpl.body if !t.isEmpty) { - extractFunOrMethodBody(Some(sym), t) - } + extractClassMembers(sym, tpl) case t => extractFunOrMethodBody(None, t) @@ -446,6 +464,7 @@ trait CodeExtraction extends ASTExtractors { private var isMethod = Set[Symbol]() private var methodToClass = Map[FunDef, LeonClassDef]() + private var classToInvariants = Map[Symbol, Set[Tree]]() /** * For the function in $defs with name $owner, find its parameter with index $index, @@ -547,8 +566,53 @@ trait CodeExtraction extends ASTExtractors { if (tpe != id.getType) println(tpe, id.getType) LeonValDef(id.setPos(t.pos)).setPos(t.pos) } + //println(s"Fields of $sym") ccd.setFields(fields) + + // checks whether this type definition could lead to an infinite type + def computeChains(tpe: LeonType): Map[TypeParameterDef, Set[LeonClassDef]] = { + var seen: Set[LeonClassDef] = Set.empty + var chains: Map[TypeParameterDef, Set[LeonClassDef]] = Map.empty + + def rec(tpe: LeonType): Set[LeonClassDef] = tpe match { + case ct: ClassType => + val root = ct.classDef.root + if (!seen(ct.classDef.root)) { + seen += ct.classDef.root + for (cct <- ct.root.knownCCDescendants; + (tp, tpe) <- cct.classDef.tparams zip cct.tps) { + val relevant = rec(tpe) + chains += tp -> (chains.getOrElse(tp, Set.empty) ++ relevant) + for (cd <- relevant; vd <- cd.fields) { + rec(vd.getType) + } + } + } + Set(root) + + case Types.NAryType(tpes, _) => + tpes.flatMap(rec).toSet + } + + rec(tpe) + chains + } + + val chains = computeChains(ccd.typed) + + def check(tp: TypeParameterDef, seen: Set[LeonClassDef]): Unit = chains.get(tp) match { + case Some(classDefs) => + if ((seen intersect classDefs).nonEmpty) { + outOfSubsetError(sym.pos, "Infinite types are not allowed") + } else { + for (cd <- classDefs; tp <- cd.tparams) check(tp, seen + cd) + } + case None => + } + + for (tp <- ccd.tparams) check(tp, Set.empty) + case _ => } @@ -572,6 +636,9 @@ trait CodeExtraction extends ASTExtractors { cd.registerMethod(fd) + case ExRequiredExpression(body) => + classToInvariants += sym -> (classToInvariants.getOrElse(sym, Set.empty) + body) + // Default values for parameters case t@ ExDefaultValueFunction(fsym, _, _, _, owner, index, _) => isMethod += fsym @@ -625,6 +692,8 @@ trait CodeExtraction extends ASTExtractors { } } + private val invId = FreshIdentifier("inv", BooleanType) + private var isLazy = Set[LeonValDef]() private var defsToDefs = Map[Symbol, FunDef]() @@ -639,7 +708,6 @@ trait CodeExtraction extends ASTExtractors { val ptpe = leonType(sym.tpe)(nctx, sym.pos) val tpe = if (sym.isByNameParam) FunctionType(Seq(), ptpe) else ptpe val newID = FreshIdentifier(sym.name.toString, tpe).setPos(sym.pos) - owners += (newID -> None) val vd = LeonValDef(newID).setPos(sym.pos) if (sym.isByNameParam) { @@ -798,21 +866,7 @@ trait CodeExtraction extends ASTExtractors { }} else body0 val finalBody = try { - flattenBlocks(extractTreeOrNoTree(body)(fctx)) match { - case e if e.getType.isInstanceOf[ArrayType] => - getOwner(e) match { - case Some(Some(fd)) if fd == funDef => - e - - case None => - e - - case _ => - outOfSubsetError(body, "Function cannot return an array that is not locally defined") - } - case e => - e - } + flattenBlocks(extractTreeOrNoTree(body)(fctx)) } catch { case e: ImpureCodeEncounteredException => e.emit() @@ -827,6 +881,10 @@ trait CodeExtraction extends ASTExtractors { NoTree(funDef.returnType) } + if (fctx.isExtern && !exists(_.isInstanceOf[NoTree])(finalBody)) { + reporter.warning(finalBody.getPos, "External function could be extracted as Leon tree: "+finalBody) + } + funDef.fullBody = finalBody // Post-extraction sanity checks @@ -955,11 +1013,7 @@ trait CodeExtraction extends ASTExtractors { private def extractTreeOrNoTree(tr: Tree)(implicit dctx: DefContext): LeonExpr = { try { - val res = extractTree(tr) - if (dctx.isExtern) { - reporter.warning(res.getPos, "External function could be extracted as Leon tree") - } - res + extractTree(tr) } catch { case e: ImpureCodeEncounteredException => if (dctx.isExtern) { @@ -1012,6 +1066,30 @@ trait CodeExtraction extends ASTExtractors { Ensuring(b, post) + case t @ ExComputesExpression(body, expected) => + val b = extractTreeOrNoTree(body).setPos(body.pos) + val expected_expr = extractTreeOrNoTree(expected).setPos(expected.pos) + + val resId = FreshIdentifier("res", b.getType).setPos(current.pos) + val post = Lambda(Seq(LeonValDef(resId)), Equals(Variable(resId), expected_expr)).setPos(current.pos) + + Ensuring(b, post) + + case t @ ExByExampleExpression(input, output) => + val input_expr = extractTreeOrNoTree(input).setPos(input.pos) + val output_expr = extractTreeOrNoTree(output).setPos(output.pos) + Passes(input_expr, output_expr, MatchCase(WildcardPattern(None), Some(BooleanLiteral(false)), NoTree(output_expr.getType))::Nil) + + case t @ ExAskExpression(input, output) => + val input_expr = extractTreeOrNoTree(input).setPos(input.pos) + val output_expr = extractTreeOrNoTree(output).setPos(output.pos) + + val resId = FreshIdentifier("res", output_expr.getType).setPos(current.pos) + val post = Lambda(Seq(LeonValDef(resId)), + Passes(input_expr, Variable(resId), MatchCase(WildcardPattern(None), Some(BooleanLiteral(false)), NoTree(output_expr.getType))::Nil)).setPos(current.pos) + + Ensuring(output_expr, post) + case ExAssertExpression(contract, oerr) => val const = extractTree(contract) val b = rest.map(extractTreeOrNoTree).getOrElse(UnitLiteral()) @@ -1090,15 +1168,6 @@ trait CodeExtraction extends ASTExtractors { val newID = FreshIdentifier(vs.name.toString, binderTpe) val valTree = extractTree(bdy) - if(valTree.getType.isInstanceOf[ArrayType]) { - getOwner(valTree) match { - case None => - owners += (newID -> Some(currentFunDef)) - case _ => - outOfSubsetError(tr, "Cannot alias array") - } - } - val restTree = rest match { case Some(rst) => val nctx = dctx.withNewVar(vs -> (() => Variable(newID))) @@ -1138,7 +1207,7 @@ trait CodeExtraction extends ASTExtractors { case _ => (Nil, restTree) } - LetDef(funDefWithBody +: other_fds, block) + letDef(funDefWithBody +: other_fds, block) // FIXME case ExDefaultValueFunction @@ -1151,15 +1220,6 @@ trait CodeExtraction extends ASTExtractors { val newID = FreshIdentifier(vs.name.toString, binderTpe) val valTree = extractTree(bdy) - if(valTree.getType.isInstanceOf[ArrayType]) { - getOwner(valTree) match { - case None => - owners += (newID -> Some(currentFunDef)) - case Some(_) => - outOfSubsetError(tr, "Cannot alias array") - } - } - val restTree = rest match { case Some(rst) => { val nv = vs -> (() => Variable(newID)) @@ -1178,9 +1238,6 @@ trait CodeExtraction extends ASTExtractors { case Some(fun) => val Variable(id) = fun() val rhsTree = extractTree(rhs) - if(rhsTree.getType.isInstanceOf[ArrayType] && getOwner(rhsTree).isDefined) { - outOfSubsetError(tr, "Cannot alias array") - } Assignment(id, rhsTree) case None => @@ -1223,18 +1280,6 @@ trait CodeExtraction extends ASTExtractors { outOfSubsetError(tr, "Array update only works on variables") } - getOwner(lhsRec) match { - // case Some(Some(fd)) if fd != currentFunDef => - // outOfSubsetError(tr, "cannot update an array that is not defined locally") - - // case Some(None) => - // outOfSubsetError(tr, "cannot update an array that is not defined locally") - - case Some(_) => - - case None => sys.error("This array: " + lhsRec + " should have had an owner") - } - val indexRec = extractTree(index) val newValueRec = extractTree(newValue) ArrayUpdate(lhsRec, indexRec, newValueRec) @@ -1309,7 +1354,6 @@ trait CodeExtraction extends ASTExtractors { val aTpe = extractType(tpt) val oTpe = oracleType(ops.pos, aTpe) val newID = FreshIdentifier(sym.name.toString, oTpe) - owners += (newID -> None) newID } @@ -1331,7 +1375,6 @@ trait CodeExtraction extends ASTExtractors { val vds = args map { vd => val aTpe = extractType(vd.tpt) val newID = FreshIdentifier(vd.symbol.name.toString, aTpe) - owners += (newID -> None) LeonValDef(newID) } @@ -1347,7 +1390,6 @@ trait CodeExtraction extends ASTExtractors { val vds = args map { case (tpt, sym) => val aTpe = extractType(tpt) val newID = FreshIdentifier(sym.name.toString, aTpe) - owners += (newID -> None) LeonValDef(newID) } @@ -1908,34 +1950,6 @@ trait CodeExtraction extends ASTExtractors { } } - private def getReturnedExpr(expr: LeonExpr): Seq[LeonExpr] = expr match { - case Let(_, _, rest) => getReturnedExpr(rest) - case LetVar(_, _, rest) => getReturnedExpr(rest) - case LeonBlock(_, rest) => getReturnedExpr(rest) - case IfExpr(_, thenn, elze) => getReturnedExpr(thenn) ++ getReturnedExpr(elze) - case MatchExpr(_, cses) => cses.flatMap{ cse => getReturnedExpr(cse.rhs) } - case _ => Seq(expr) - } - - def getOwner(exprs: Seq[LeonExpr]): Option[Option[FunDef]] = { - val exprOwners: Seq[Option[Option[FunDef]]] = exprs.map { - case Variable(id) => - owners.get(id) - case _ => - None - } - - if(exprOwners.contains(None)) - None - else if(exprOwners.contains(Some(None))) - Some(None) - else if(exprOwners.exists(o1 => exprOwners.exists(o2 => o1 != o2))) - Some(None) - else - exprOwners.head - } - - def getOwner(expr: LeonExpr): Option[Option[FunDef]] = getOwner(getReturnedExpr(expr)) } def containsLetDef(expr: LeonExpr): Boolean = { diff --git a/src/main/scala/leon/grammars/BaseGrammar.scala b/src/main/scala/leon/grammars/BaseGrammar.scala index f11f937498051eb47c2c522a0faa2a1499545175..6e0a2ee5e6842255aac5755c9ee27005e5360eb5 100644 --- a/src/main/scala/leon/grammars/BaseGrammar.scala +++ b/src/main/scala/leon/grammars/BaseGrammar.scala @@ -7,56 +7,65 @@ import purescala.Types._ import purescala.Expressions._ import purescala.Constructors._ +/** The basic grammar for Leon expressions. + * Generates the most obvious expressions for a given type, + * without regard of context (variables in scope, current function etc.) + * Also does some trivial simplifications. + */ case object BaseGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)), - nonTerminal(List(BooleanType), { case Seq(a) => not(a) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }), - nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), + terminal(BooleanLiteral(false), Tags.BooleanC), + terminal(BooleanLiteral(true), Tags.BooleanC), + nonTerminal(List(BooleanType), { case Seq(a) => not(a) }, Tags.Not), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }, Tags.And), + nonTerminal(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }, Tags.Or ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(Int32Type, Int32Type), { case Seq(a, b) => LessEquals(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), nonTerminal(List(IntegerType, IntegerType), { case Seq(a, b) => LessEquals(a, b) }) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(Int32Type, Int32Type), { case Seq(a,b) => times(a, b) }, Tags.Times) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }), - nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => plus(a, b) }, Tags.Plus ), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => minus(a, b) }, Tags.Minus), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => times(a, b) }, Tags.Times), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Modulo(a, b) }, Tags.Mod), + nonTerminal(List(IntegerType, IntegerType), { case Seq(a,b) => Division(a, b) }, Tags.Div) ) case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(isTerminal = false)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)} ) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct) ) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), nonTerminal(List(st, st), { case Seq(a, b) => SetUnion(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetIntersection(a, b) }), nonTerminal(List(st, st), { case Seq(a, b) => SetDifference(a, b) }) @@ -64,7 +73,7 @@ case object BaseGrammar extends ExpressionGrammar[TypeTree] { case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/Constants.scala b/src/main/scala/leon/grammars/Constants.scala new file mode 100644 index 0000000000000000000000000000000000000000..81c55346052668e0d82b05ee240867eb1e5c468c --- /dev/null +++ b/src/main/scala/leon/grammars/Constants.scala @@ -0,0 +1,33 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Expressions._ +import purescala.Types.TypeTree +import purescala.ExprOps.collect +import purescala.Extractors.IsTyped + +/** Generates constants found in an [[leon.purescala.Expressions.Expr]]. + * Some constants that are generated by other grammars (like 0, 1) will be excluded + */ +case class Constants(e: Expr) extends ExpressionGrammar[TypeTree] { + + private val excluded: Set[Expr] = Set( + InfiniteIntegerLiteral(1), + InfiniteIntegerLiteral(0), + IntLiteral(1), + IntLiteral(0), + BooleanLiteral(true), + BooleanLiteral(false) + ) + + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { + val literals = collect[Expr]{ + case IsTyped(l:Literal[_], `t`) => Set(l) + case _ => Set() + }(e) + + (literals -- excluded map (terminal(_, Tags.Constant))).toSeq + } +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/DepthBoundedGrammar.scala deleted file mode 100644 index fc999be644bf2c4a7a20a73403cf7b1001bb9b68..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/DepthBoundedGrammar.scala +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -case class DepthBoundedGrammar[T](g: ExpressionGrammar[NonTerminal[T]], bound: Int) extends ExpressionGrammar[NonTerminal[T]] { - def computeProductions(l: NonTerminal[T])(implicit ctx: LeonContext): Seq[Gen] = g.computeProductions(l).flatMap { - case gen => - if (l.depth == Some(bound) && gen.subTrees.nonEmpty) { - Nil - } else if (l.depth.exists(_ > bound)) { - Nil - } else { - List ( - nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) - ) - } - } -} diff --git a/src/main/scala/leon/grammars/Empty.scala b/src/main/scala/leon/grammars/Empty.scala index 70ebddc98f21fc872aef8635fe36de7e9ba9bbce..737f9cdf389454f403a6581e13eec7fafa383f34 100644 --- a/src/main/scala/leon/grammars/Empty.scala +++ b/src/main/scala/leon/grammars/Empty.scala @@ -5,6 +5,7 @@ package grammars import purescala.Types.Typed +/** The empty expression grammar */ case class Empty[T <: Typed]() extends ExpressionGrammar[T] { - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = Nil + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = Nil } diff --git a/src/main/scala/leon/grammars/EqualityGrammar.scala b/src/main/scala/leon/grammars/EqualityGrammar.scala index e9463a771204d877a4d748c373b6d198e2c2591b..a2f9c41360ada03334ace63eca3ca46f9f6d5ff7 100644 --- a/src/main/scala/leon/grammars/EqualityGrammar.scala +++ b/src/main/scala/leon/grammars/EqualityGrammar.scala @@ -6,13 +6,15 @@ package grammars import purescala.Types._ import purescala.Constructors._ -import bonsai._ - +/** A grammar of equalities + * + * @param types The set of types for which equalities will be generated + */ case class EqualityGrammar(types: Set[TypeTree]) extends ExpressionGrammar[TypeTree] { - override def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => types.toList map { tp => - nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }) + nonTerminal(List(tp, tp), { case Seq(a, b) => equality(a, b) }, Tags.Equals) } case _ => Nil diff --git a/src/main/scala/leon/grammars/ExpressionGrammar.scala b/src/main/scala/leon/grammars/ExpressionGrammar.scala index ac394ab840bddf0d498080a04e447ce66de07caa..3179312b7f65444eb3e8c39357fd449e13339c8f 100644 --- a/src/main/scala/leon/grammars/ExpressionGrammar.scala +++ b/src/main/scala/leon/grammars/ExpressionGrammar.scala @@ -6,23 +6,37 @@ package grammars import purescala.Expressions._ import purescala.Types._ import purescala.Common._ +import transformers.Union +import utils.Timer import scala.collection.mutable.{HashMap => MutableMap} +/** Represents a context-free grammar of expressions + * + * @tparam T The type of nonterminal symbols for this grammar + */ abstract class ExpressionGrammar[T <: Typed] { - type Gen = Generator[T, Expr] - private[this] val cache = new MutableMap[T, Seq[Gen]]() + type Prod = ProductionRule[T, Expr] - def terminal(builder: => Expr) = { - Generator[T, Expr](Nil, { (subs: Seq[Expr]) => builder }) + private[this] val cache = new MutableMap[T, Seq[Prod]]() + + /** Generates a [[ProductionRule]] without nonterminal symbols */ + def terminal(builder: => Expr, tag: Tags.Tag = Tags.Top, cost: Int = 1) = { + ProductionRule[T, Expr](Nil, { (subs: Seq[Expr]) => builder }, tag, cost) } - def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr)): Generator[T, Expr] = { - Generator[T, Expr](subs, builder) + /** Generates a [[ProductionRule]] with nonterminal symbols */ + def nonTerminal(subs: Seq[T], builder: (Seq[Expr] => Expr), tag: Tags.Tag = Tags.Top, cost: Int = 1): ProductionRule[T, Expr] = { + ProductionRule[T, Expr](subs, builder, tag, cost) } - def getProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = { + /** The list of production rules for this grammar for a given nonterminal. + * This is the cached version of [[getProductions]] which clients should use. + * + * @param t The nonterminal for which production rules will be generated + */ + def getProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = { cache.getOrElse(t, { val res = computeProductions(t) cache += t -> res @@ -30,9 +44,13 @@ abstract class ExpressionGrammar[T <: Typed] { }) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] + /** The list of production rules for this grammar for a given nonterminal + * + * @param t The nonterminal for which production rules will be generated + */ + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] - def filter(f: Gen => Boolean) = { + def filter(f: Prod => Boolean) = { new ExpressionGrammar[T] { def computeProductions(t: T)(implicit ctx: LeonContext) = ExpressionGrammar.this.computeProductions(t).filter(f) } @@ -44,14 +62,19 @@ abstract class ExpressionGrammar[T <: Typed] { final def printProductions(printer: String => Unit)(implicit ctx: LeonContext) { - for ((t, gs) <- cache; g <- gs) { - val subs = g.subTrees.map { t => - FreshIdentifier(Console.BOLD+t.asString+Console.RESET, t.getType).toVariable - } + for ((t, gs) <- cache) { + val lhs = f"${Console.BOLD}${t.asString}%50s${Console.RESET} ::=" + if (gs.isEmpty) { + printer(s"$lhs ε") + } else for (g <- gs) { + val subs = g.subTrees.map { t => + FreshIdentifier(Console.BOLD + t.asString + Console.RESET, t.getType).toVariable + } - val gen = g.builder(subs).asString + val gen = g.builder(subs).asString - printer(f"${Console.BOLD}${t.asString}%30s${Console.RESET} ::= $gen") + printer(s"$lhs $gen") + } } } } diff --git a/src/main/scala/leon/grammars/FunctionCalls.scala b/src/main/scala/leon/grammars/FunctionCalls.scala index 14f92393934c18804bdb130e9c1617b915a347bd..1233fb1931a83b5ca674019be0c85144339dd19f 100644 --- a/src/main/scala/leon/grammars/FunctionCalls.scala +++ b/src/main/scala/leon/grammars/FunctionCalls.scala @@ -10,8 +10,14 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Expressions._ +/** Generates non-recursive function calls + * + * @param currentFunction The currend function for which no calls will be generated + * @param types The candidate real type parameters for [[currentFunction]] + * @param exclude An additional set of functions for which no calls will be generated + */ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[TypeTree], exclude: Set[FunDef]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { def getCandidates(fd: FunDef): Seq[TypedFunDef] = { // Prevents recursive calls @@ -73,7 +79,7 @@ case class FunctionCalls(prog: Program, currentFunction: FunDef, types: Seq[Type val funcs = visibleFunDefsFromMain(prog).toSeq.sortBy(_.id).flatMap(getCandidates).filterNot(filter) funcs.map{ tfd => - nonTerminal(tfd.params.map(_.getType), { sub => FunctionInvocation(tfd, sub) }) + nonTerminal(tfd.params.map(_.getType), FunctionInvocation(tfd, _), Tags.tagOf(tfd.fd, isSafe = false)) } } } diff --git a/src/main/scala/leon/grammars/Generator.scala b/src/main/scala/leon/grammars/Generator.scala deleted file mode 100644 index 18d132e2c25ea222324dc05809220f12d0fb7100..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/Generator.scala +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import bonsai.{Generator => Gen} - -object GrammarTag extends Enumeration { - val Top = Value -} -import GrammarTag._ - -class Generator[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value) extends Gen[T,R](subTrees, builder) -object Generator { - def apply[T, R](subTrees: Seq[T], builder: Seq[R] => R, tag: Value = Top) = new Generator(subTrees, builder, tag) -} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/Grammars.scala b/src/main/scala/leon/grammars/Grammars.scala index 23b1dd5a14cfeb82dd4555832e777597615b337e..06aba3d5f5343cc7e2807854f0b4665bfa1a602c 100644 --- a/src/main/scala/leon/grammars/Grammars.scala +++ b/src/main/scala/leon/grammars/Grammars.scala @@ -7,6 +7,7 @@ import purescala.Expressions._ import purescala.Definitions._ import purescala.Types._ import purescala.TypeOps._ +import transformers.OneOf import synthesis.{SynthesisContext, Problem} @@ -16,6 +17,7 @@ object Grammars { BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ inputs.map { _.getType }) || OneOf(inputs) || + Constants(currentFunction.fullBody) || FunctionCalls(prog, currentFunction, inputs.map(_.getType), exclude) || SafeRecursiveCalls(prog, ws, pc) } @@ -28,3 +30,4 @@ object Grammars { g.filter(g => g.subTrees.forall(t => typeDepth(t.getType) <= b)) } } + diff --git a/src/main/scala/leon/grammars/NonTerminal.scala b/src/main/scala/leon/grammars/NonTerminal.scala index 7492ffac5c17df326084f846857c6ac3bebe1775..600189ffa06378841f6bf3285f7f1bd7bb6116f5 100644 --- a/src/main/scala/leon/grammars/NonTerminal.scala +++ b/src/main/scala/leon/grammars/NonTerminal.scala @@ -5,7 +5,14 @@ package grammars import purescala.Types._ -case class NonTerminal[T](t: TypeTree, l: T, depth: Option[Int] = None) extends Typed { +/** A basic non-terminal symbol of a grammar. + * + * @param t The type of which expressions will be generated + * @param l A label that characterizes this [[NonTerminal]] + * @param depth The optional depth within the syntax tree where this [[NonTerminal]] is. + * @tparam L The type of label for this NonTerminal. + */ +case class NonTerminal[L](t: TypeTree, l: L, depth: Option[Int] = None) extends Typed { val getType = t override def asString(implicit ctx: LeonContext) = t.asString+"#"+l+depth.map(d => "@"+d).getOrElse("") diff --git a/src/main/scala/leon/grammars/ProductionRule.scala b/src/main/scala/leon/grammars/ProductionRule.scala new file mode 100644 index 0000000000000000000000000000000000000000..fc493a7d9d17557a26a8e54ff4615d39ba922190 --- /dev/null +++ b/src/main/scala/leon/grammars/ProductionRule.scala @@ -0,0 +1,18 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import bonsai.Generator + +/** Represents a production rule of a non-terminal symbol of an [[ExpressionGrammar]]. + * + * @param subTrees The nonterminals that are used in the right-hand side of this [[ProductionRule]] + * (and will generate deeper syntax trees). + * @param builder A function that builds the syntax tree that this [[ProductionRule]] represents from nested trees. + * @param tag Gives information about the nature of this production rule. + * @tparam T The type of nonterminal symbols of the grammar + * @tparam R The type of syntax trees of the grammar + */ +case class ProductionRule[T, R](override val subTrees: Seq[T], override val builder: Seq[R] => R, tag: Tags.Tag, cost: Int = 1) + extends Generator[T,R](subTrees, builder) diff --git a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala index 1bbcb0523158ac95713f5a0d4a16f0f35e14edf4..f3234176a8c17378a7a5f027f38cbd42069ae7d6 100644 --- a/src/main/scala/leon/grammars/SafeRecursiveCalls.scala +++ b/src/main/scala/leon/grammars/SafeRecursiveCalls.scala @@ -9,15 +9,25 @@ import purescala.ExprOps._ import purescala.Expressions._ import synthesis.utils.Helpers._ +/** Generates recursive calls that will not trivially result in non-termination. + * + * @param ws An expression that contains the known set [[synthesis.Witnesses.Terminating]] expressions + * @param pc The path condition for the generated [[Expr]] by this grammar + */ case class SafeRecursiveCalls(prog: Program, ws: Expr, pc: Expr) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { val calls = terminatingCalls(prog, t, ws, pc) calls.map { - case (e, free) => + case (fi, free) => val freeSeq = free.toSeq - nonTerminal(freeSeq.map(_.getType), { sub => replaceFromIDs(freeSeq.zip(sub).toMap, e) }) + nonTerminal( + freeSeq.map(_.getType), + { sub => replaceFromIDs(freeSeq.zip(sub).toMap, fi) }, + Tags.tagOf(fi.tfd.fd, isSafe = true), + 2 + ) } } } diff --git a/src/main/scala/leon/grammars/SimilarTo.scala b/src/main/scala/leon/grammars/SimilarTo.scala index 77e912792965d860fc934eb016370c8f2b57fd8f..3a7708e9a77960ffbfde98d478d2ca7c73c713d0 100644 --- a/src/main/scala/leon/grammars/SimilarTo.scala +++ b/src/main/scala/leon/grammars/SimilarTo.scala @@ -3,21 +3,24 @@ package leon package grammars +import transformers._ import purescala.Types._ import purescala.TypeOps._ import purescala.Extractors._ import purescala.Definitions._ -import purescala.ExprOps._ -import purescala.DefOps._ import purescala.Expressions._ import synthesis._ +/** A grammar that generates expressions by inserting small variations in [[e]] + * @param e The [[Expr]] to which small variations will be inserted + * @param terminals A set of [[Expr]]s that may be inserted into [[e]] as small variations + */ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisContext, p: Problem) extends ExpressionGrammar[NonTerminal[String]] { val excludeFCalls = sctx.settings.functionsToIgnore - val normalGrammar = DepthBoundedGrammar(EmbeddedGrammar( + val normalGrammar: ExpressionGrammar[NonTerminal[String]] = DepthBoundedGrammar(EmbeddedGrammar( BaseGrammar || EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ terminals.map { _.getType }) || OneOf(terminals.toSeq :+ e) || @@ -37,9 +40,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - private[this] var similarCache: Option[Map[L, Seq[Gen]]] = None + private[this] var similarCache: Option[Map[L, Seq[Prod]]] = None - def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: L)(implicit ctx: LeonContext): Seq[Prod] = { t match { case NonTerminal(_, "B", _) => normalGrammar.computeProductions(t) case _ => @@ -54,7 +57,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } } - def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Gen)] = { + def computeSimilar(e : Expr)(implicit ctx: LeonContext): Seq[(L, Prod)] = { def getLabel(t: TypeTree) = { val tpe = bestRealType(t) @@ -67,9 +70,9 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte case _ => false } - def rec(e: Expr, gl: L): Seq[(L, Gen)] = { + def rec(e: Expr, gl: L): Seq[(L, Prod)] = { - def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Gen)] = { + def gens(e: Expr, gl: L, subs: Seq[Expr], builder: (Seq[Expr] => Expr)): Seq[(L, Prod)] = { val subGls = subs.map { s => getLabel(s.getType) } // All the subproductions for sub gl @@ -81,8 +84,8 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val swaps = if (subs.size > 1 && !isCommutative(e)) { - (for (i <- 0 until subs.size; - j <- i+1 until subs.size) yield { + (for (i <- subs.indices; + j <- i+1 until subs.size) yield { if (subs(i).getType == subs(j).getType) { val swapSubs = subs.updated(i, subs(j)).updated(j, subs(i)) @@ -98,18 +101,18 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte allSubs ++ injectG ++ swaps } - def cegis(gl: L): Seq[(L, Gen)] = { + def cegis(gl: L): Seq[(L, Prod)] = { normalGrammar.getProductions(gl).map(gl -> _) } - def int32Variations(gl: L, e : Expr): Seq[(L, Gen)] = { + def int32Variations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(BVMinus(e, IntLiteral(1))), gl -> terminal(BVPlus (e, IntLiteral(1))) ) } - def intVariations(gl: L, e : Expr): Seq[(L, Gen)] = { + def intVariations(gl: L, e : Expr): Seq[(L, Prod)] = { Seq( gl -> terminal(Minus(e, InfiniteIntegerLiteral(1))), gl -> terminal(Plus (e, InfiniteIntegerLiteral(1))) @@ -118,7 +121,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte // Find neighbor case classes that are compatible with the arguments: // Turns And(e1, e2) into Or(e1, e2)... - def ccVariations(gl: L, cc: CaseClass): Seq[(L, Gen)] = { + def ccVariations(gl: L, cc: CaseClass): Seq[(L, Prod)] = { val CaseClass(cct, args) = cc val neighbors = cct.root.knownCCDescendants diff Seq(cct) @@ -129,7 +132,7 @@ case class SimilarTo(e: Expr, terminals: Set[Expr] = Set(), sctx: SynthesisConte } val funFilter = (fd: FunDef) => fd.isSynthetic || (excludeFCalls contains fd) - val subs: Seq[(L, Gen)] = e match { + val subs: Seq[(L, Prod)] = e match { case _: Terminal | _: Let | _: LetDef | _: MatchExpr => gens(e, gl, Nil, { _ => e }) ++ cegis(gl) diff --git a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/SizeBoundedGrammar.scala deleted file mode 100644 index 1b25e30f61aa74598feb255366fe10a153bc9e30..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/grammars/SizeBoundedGrammar.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package grammars - -import purescala.Types._ -import leon.utils.SeqUtils.sumTo - -case class SizedLabel[T <: Typed](underlying: T, size: Int) extends Typed { - val getType = underlying.getType - - override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" -} - -case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[SizedLabel[T]] { - def computeProductions(sl: SizedLabel[T])(implicit ctx: LeonContext): Seq[Gen] = { - if (sl.size <= 0) { - Nil - } else if (sl.size == 1) { - g.getProductions(sl.underlying).filter(_.subTrees.isEmpty).map { gen => - terminal(gen.builder(Seq())) - } - } else { - g.getProductions(sl.underlying).filter(_.subTrees.nonEmpty).flatMap { gen => - val sizes = sumTo(sl.size-1, gen.subTrees.size) - - for (ss <- sizes) yield { - val subSizedLabels = (gen.subTrees zip ss) map (s => SizedLabel(s._1, s._2)) - - nonTerminal(subSizedLabels, gen.builder) - } - } - } - } -} diff --git a/src/main/scala/leon/grammars/Tags.scala b/src/main/scala/leon/grammars/Tags.scala new file mode 100644 index 0000000000000000000000000000000000000000..4a6b6fca491b8db6f74622edd9298ec5cd6053b0 --- /dev/null +++ b/src/main/scala/leon/grammars/Tags.scala @@ -0,0 +1,65 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars + +import purescala.Types.CaseClassType +import purescala.Definitions.FunDef + +object Tags { + /** A class for tags that tag a [[ProductionRule]] with the kind of expression in generates. */ + abstract class Tag + case object Top extends Tag // Tag for the top-level of the grammar (default) + case object Zero extends Tag // Tag for 0 + case object One extends Tag // Tag for 1 + case object BooleanC extends Tag // Tag for boolean constants + case object Constant extends Tag // Tag for other constants + case object And extends Tag // Tags for boolean operations + case object Or extends Tag + case object Not extends Tag + case object Plus extends Tag // Tags for arithmetic operations + case object Minus extends Tag + case object Times extends Tag + case object Mod extends Tag + case object Div extends Tag + case object Variable extends Tag // Tag for variables + case object Equals extends Tag // Tag for equality + /** Constructors like Tuple, CaseClass... + * + * @param isTerminal If true, this constructor represents a terminal symbol + * (in practice, case class with 0 fields) + */ + case class Constructor(isTerminal: Boolean) extends Tag + /** Tag for function calls + * + * @param isMethod Whether the function called is a method + * @param isSafe Whether this constructor represents a safe function call. + * We need this because this call implicitly contains a variable, + * so we want to allow constants in all arguments. + */ + case class FunCall(isMethod: Boolean, isSafe: Boolean) extends Tag + + /** The set of tags that represent constants */ + val isConst: Set[Tag] = Set(Zero, One, Constant, BooleanC, Constructor(true)) + + /** The set of tags that represent commutative operations */ + val isCommut: Set[Tag] = Set(Plus, Times, Equals) + + /** The set of tags which have trivial results for equal arguments */ + val symmetricTrivial = Set(Minus, And, Or, Equals, Div, Mod) + + /** Tags which allow constants in all their operands + * + * In reality, the current version never allows that: it is only allowed in safe function calls + * which by construction contain a hidden reference to a variable. + * TODO: Experiment with different conditions, e.g. are constants allowed in + * top-level/ general function calls/ constructors/...? + */ + def allConstArgsAllowed(t: Tag) = t match { + case FunCall(_, true) => true + case _ => false + } + + def tagOf(cct: CaseClassType) = Constructor(cct.fields.isEmpty) + def tagOf(fd: FunDef, isSafe: Boolean) = FunCall(fd.methodOwner.isDefined, isSafe) +} \ No newline at end of file diff --git a/src/main/scala/leon/grammars/ValueGrammar.scala b/src/main/scala/leon/grammars/ValueGrammar.scala index 98850c8df4adcf3e776970c176cf37c251823917..d3c42201728f4b03d9518b3db503cff9189dcc8b 100644 --- a/src/main/scala/leon/grammars/ValueGrammar.scala +++ b/src/main/scala/leon/grammars/ValueGrammar.scala @@ -6,62 +6,64 @@ package grammars import purescala.Types._ import purescala.Expressions._ +/** A grammar of values (ground terms) */ case object ValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { case BooleanType => List( - terminal(BooleanLiteral(true)), - terminal(BooleanLiteral(false)) + terminal(BooleanLiteral(true), Tags.One), + terminal(BooleanLiteral(false), Tags.Zero) ) case Int32Type => List( - terminal(IntLiteral(0)), - terminal(IntLiteral(1)), - terminal(IntLiteral(5)) + terminal(IntLiteral(0), Tags.Zero), + terminal(IntLiteral(1), Tags.One), + terminal(IntLiteral(5), Tags.Constant) ) case IntegerType => List( - terminal(InfiniteIntegerLiteral(0)), - terminal(InfiniteIntegerLiteral(1)), - terminal(InfiniteIntegerLiteral(5)) + terminal(InfiniteIntegerLiteral(0), Tags.Zero), + terminal(InfiniteIntegerLiteral(1), Tags.One), + terminal(InfiniteIntegerLiteral(5), Tags.Constant) ) case StringType => List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("foo")), - terminal(StringLiteral("bar")) + terminal(StringLiteral(""), Tags.Constant), + terminal(StringLiteral("a"), Tags.Constant), + terminal(StringLiteral("foo"), Tags.Constant), + terminal(StringLiteral("bar"), Tags.Constant) ) case tp: TypeParameter => - for (ind <- (1 to 3).toList) yield { - terminal(GenericValue(tp, ind)) - } + List( + terminal(GenericValue(tp, 0)) + ) case TupleType(stps) => List( - nonTerminal(stps, { sub => Tuple(sub) }) + nonTerminal(stps, Tuple, Tags.Constructor(stps.isEmpty)) ) case cct: CaseClassType => List( - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) ) case act: AbstractClassType => act.knownCCDescendants.map { cct => - nonTerminal(cct.fields.map(_.getType), { case rs => CaseClass(cct, rs)}) + nonTerminal(cct.fields.map(_.getType), CaseClass(cct, _), Tags.tagOf(cct)) } case st @ SetType(base) => List( - nonTerminal(List(base), { case elems => FiniteSet(elems.toSet, base) }), - nonTerminal(List(base, base), { case elems => FiniteSet(elems.toSet, base) }) + terminal(FiniteSet(Set(), base), Tags.Constant), + nonTerminal(List(base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)), + nonTerminal(List(base, base), { elems => FiniteSet(elems.toSet, base) }, Tags.Constructor(isTerminal = false)) ) case UnitType => List( - terminal(UnitLiteral()) + terminal(UnitLiteral(), Tags.Constant) ) case _ => diff --git a/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..02e045497a28db205d4a33a300ac0b742510920a --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/DepthBoundedGrammar.scala @@ -0,0 +1,21 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +/** Limits a grammar to a specific expression depth */ +case class DepthBoundedGrammar[L](g: ExpressionGrammar[NonTerminal[L]], bound: Int) extends ExpressionGrammar[NonTerminal[L]] { + def computeProductions(l: NonTerminal[L])(implicit ctx: LeonContext): Seq[Prod] = g.computeProductions(l).flatMap { + case gen => + if (l.depth == Some(bound) && gen.isNonTerminal) { + Nil + } else if (l.depth.exists(_ > bound)) { + Nil + } else { + List ( + nonTerminal(gen.subTrees.map(sl => sl.copy(depth = l.depth.map(_+1).orElse(Some(1)))), gen.builder) + ) + } + } +} diff --git a/src/main/scala/leon/grammars/EmbeddedGrammar.scala b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala similarity index 74% rename from src/main/scala/leon/grammars/EmbeddedGrammar.scala rename to src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala index 8dcbc6ec10f9aa42895e5f876cdd4d72479de229..d989a8804b32f62697b7f31e498e61393a12c35b 100644 --- a/src/main/scala/leon/grammars/EmbeddedGrammar.scala +++ b/src/main/scala/leon/grammars/transformers/EmbeddedGrammar.scala @@ -2,10 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.Constructors._ +import leon.purescala.Types.Typed /** * Embed a grammar Li->Expr within a grammar Lo->Expr @@ -13,9 +12,9 @@ import purescala.Constructors._ * We rely on a bijection between Li and Lo labels */ case class EmbeddedGrammar[Ti <: Typed, To <: Typed](innerGrammar: ExpressionGrammar[Ti], iToo: Ti => To, oToi: To => Ti) extends ExpressionGrammar[To] { - def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: To)(implicit ctx: LeonContext): Seq[Prod] = { innerGrammar.computeProductions(oToi(t)).map { innerGen => - nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder) + nonTerminal(innerGen.subTrees.map(iToo), innerGen.builder, innerGen.tag) } } } diff --git a/src/main/scala/leon/grammars/OneOf.scala b/src/main/scala/leon/grammars/transformers/OneOf.scala similarity index 56% rename from src/main/scala/leon/grammars/OneOf.scala rename to src/main/scala/leon/grammars/transformers/OneOf.scala index 0e10c096151c1fdf83d3c7e7f10c4a4a6518215b..5c57c6a1a48179e2d813398aa651022df6cae35a 100644 --- a/src/main/scala/leon/grammars/OneOf.scala +++ b/src/main/scala/leon/grammars/transformers/OneOf.scala @@ -2,14 +2,15 @@ package leon package grammars +package transformers -import purescala.Types._ -import purescala.Expressions._ -import purescala.TypeOps._ -import purescala.Constructors._ +import purescala.Expressions.Expr +import purescala.Types.TypeTree +import purescala.TypeOps.isSubtypeOf +/** Generates one production rule for each expression in a sequence that has compatible type */ case class OneOf(inputs: Seq[Expr]) extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = { + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = { inputs.collect { case i if isSubtypeOf(i.getType, t) => terminal(i) diff --git a/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..1b605359fdf18d02f08d105a9cccc58757b99262 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/SizeBoundedGrammar.scala @@ -0,0 +1,59 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import purescala.Types.Typed +import utils.SeqUtils._ + +/** Adds information about size to a nonterminal symbol */ +case class SizedNonTerm[T <: Typed](underlying: T, size: Int) extends Typed { + val getType = underlying.getType + + override def asString(implicit ctx: LeonContext) = underlying.asString+"|"+size+"|" +} + +/** Limits a grammar by producing expressions of size bounded by the [[SizedNonTerm.size]] of a given [[SizedNonTerm]]. + * + * In case of commutative operations, the grammar will produce trees skewed to the right + * (i.e. the right subtree will always be larger). Notice we do not lose generality in case of + * commutative operations. + */ +case class SizeBoundedGrammar[T <: Typed](g: ExpressionGrammar[T], optimizeCommut: Boolean) extends ExpressionGrammar[SizedNonTerm[T]] { + def computeProductions(sl: SizedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + if (sl.size <= 0) { + Nil + } else if (sl.size == 1) { + g.getProductions(sl.underlying).filter(_.isTerminal).map { gen => + terminal(gen.builder(Seq()), gen.tag) + } + } else { + g.getProductions(sl.underlying).filter(_.isNonTerminal).flatMap { gen => + + // Ad-hoc equality that does not take into account position etc.of TaggedNonTerminal's + // TODO: Ugly and hacky + def characteristic(t: T): Typed = t match { + case TaggedNonTerm(underlying, _, _, _) => + underlying + case other => + other + } + + // Optimization: When we have a commutative operation and all the labels are the same, + // we can skew the expression to always be right-heavy + val sizes = if(optimizeCommut && Tags.isCommut(gen.tag) && gen.subTrees.map(characteristic).toSet.size == 1) { + sumToOrdered(sl.size-gen.cost, gen.arity) + } else { + sumTo(sl.size-gen.cost, gen.arity) + } + + for (ss <- sizes) yield { + val subSizedLabels = (gen.subTrees zip ss) map (s => SizedNonTerm(s._1, s._2)) + + nonTerminal(subSizedLabels, gen.builder, gen.tag) + } + } + } + } +} diff --git a/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala new file mode 100644 index 0000000000000000000000000000000000000000..43ce13e850ed1b52460ef1a74d7b039adacbd519 --- /dev/null +++ b/src/main/scala/leon/grammars/transformers/TaggedGrammar.scala @@ -0,0 +1,111 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package grammars +package transformers + +import leon.purescala.Types.Typed +import Tags._ + +/** Adds to a nonterminal information about about the tag of its parent's [[leon.grammars.ProductionRule.tag]] + * and additional information. + * + * @param underlying The underlying nonterminal + * @param tag The tag of the parent of this nonterminal + * @param pos The index of this nonterminal in its father's production rule + * @param isConst Whether this nonterminal is obliged to generate/not generate constants. + * + */ +case class TaggedNonTerm[T <: Typed](underlying: T, tag: Tag, pos: Int, isConst: Option[Boolean]) extends Typed { + val getType = underlying.getType + + private val cString = isConst match { + case Some(true) => "↓" + case Some(false) => "↑" + case None => "○" + } + + /** [[isConst]] is printed as follows: ↓ for constants only, ↑ for nonconstants only, + * ○ for anything allowed. + */ + override def asString(implicit ctx: LeonContext): String = s"$underlying%$tag@$pos$cString" +} + +/** Constraints a grammar to reduce redundancy by utilizing information provided by the [[TaggedNonTerm]]. + * + * 1) In case of associative operations, right associativity is enforced. + * 2) Does not generate + * - neutral and absorbing elements (incl. boolean equality) + * - nested negations + * 3) Excludes method calls on nullary case objects, e.g. Nil().size + * 4) Enforces that no constant trees are generated (and recursively for each subtree) + * + * @param g The underlying untagged grammar + */ +case class TaggedGrammar[T <: Typed](g: ExpressionGrammar[T]) extends ExpressionGrammar[TaggedNonTerm[T]] { + + private def exclude(tag: Tag, pos: Int): Set[Tag] = (tag, pos) match { + case (Top, _) => Set() + case (And, 0) => Set(And, BooleanC) + case (And, 1) => Set(BooleanC) + case (Or, 0) => Set(Or, BooleanC) + case (Or, 1) => Set(BooleanC) + case (Plus, 0) => Set(Plus, Zero, One) + case (Plus, 1) => Set(Zero) + case (Minus, 1) => Set(Zero) + case (Not, _) => Set(Not, BooleanC) + case (Times, 0) => Set(Times, Zero, One) + case (Times, 1) => Set(Zero, One) + case (Equals,_) => Set(Not, BooleanC) + case (Div | Mod, 0 | 1) => Set(Zero, One) + case (FunCall(true, _), 0) => Set(Constructor(true)) // Don't allow Nil().size etc. + case _ => Set() + } + + def computeProductions(t: TaggedNonTerm[T])(implicit ctx: LeonContext): Seq[Prod] = { + + // Point (4) for this level + val constFilter: g.Prod => Boolean = t.isConst match { + case Some(b) => + innerGen => isConst(innerGen.tag) == b + case None => + _ => true + } + + g.computeProductions(t.underlying) + // Include only constants iff constants are forced, only non-constants iff they are forced + .filter(constFilter) + // Points (1), (2). (3) + .filterNot { innerGen => exclude(t.tag, t.pos)(innerGen.tag) } + .flatMap { innerGen => + + def nt(isConst: Int => Option[Boolean]) = nonTerminal( + innerGen.subTrees.zipWithIndex.map { + case (t, pos) => TaggedNonTerm(t, innerGen.tag, pos, isConst(pos)) + }, + innerGen.builder, + innerGen.tag + ) + + def powerSet[A](t: Set[A]): Set[Set[A]] = { + @scala.annotation.tailrec + def pwr(t: Set[A], ps: Set[Set[A]]): Set[Set[A]] = + if (t.isEmpty) ps + else pwr(t.tail, ps ++ (ps map (_ + t.head))) + + pwr(t, Set(Set.empty[A])) + } + + // Allow constants everywhere if this is allowed, otherwise demand at least 1 variable. + // Aka. tag subTrees correctly so point (4) is enforced in the lower level + // (also, make sure we treat terminals correctly). + if (innerGen.isTerminal || allConstArgsAllowed(innerGen.tag)) { + Seq(nt(_ => None)) + } else { + val indices = innerGen.subTrees.indices.toSet + (powerSet(indices) - indices) map (indices => nt(x => Some(indices(x)))) + } + } + } + +} diff --git a/src/main/scala/leon/grammars/Or.scala b/src/main/scala/leon/grammars/transformers/Union.scala similarity index 73% rename from src/main/scala/leon/grammars/Or.scala rename to src/main/scala/leon/grammars/transformers/Union.scala index e691a245984eaeb11277b9278505b49cf623fed3..471625ac3c22c22456f49f366ed26e5195b5f4ab 100644 --- a/src/main/scala/leon/grammars/Or.scala +++ b/src/main/scala/leon/grammars/transformers/Union.scala @@ -2,8 +2,9 @@ package leon package grammars +package transformers -import purescala.Types._ +import purescala.Types.Typed case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGrammar[T] { val subGrammars: Seq[ExpressionGrammar[T]] = gs.flatMap { @@ -11,6 +12,6 @@ case class Union[T <: Typed](gs: Seq[ExpressionGrammar[T]]) extends ExpressionGr case g => Seq(g) } - def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Gen] = + def computeProductions(t: T)(implicit ctx: LeonContext): Seq[Prod] = subGrammars.flatMap(_.getProductions(t)) } diff --git a/src/main/scala/leon/invariant/engine/RefinementEngine.scala b/src/main/scala/leon/invariant/engine/RefinementEngine.scala index f60a86af0228c30db578c2d288cb3b5c4de125f9..9502cefad49fb3a44242194e880caeaca74b6dab 100644 --- a/src/main/scala/leon/invariant/engine/RefinementEngine.scala +++ b/src/main/scala/leon/invariant/engine/RefinementEngine.scala @@ -5,7 +5,7 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ -import purescala.TypeOps._ +import purescala.TypeOps.instantiateType import purescala.Extractors._ import purescala.Types._ import java.io._ diff --git a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala index d38bdce1f781e9f438ff2bb273b0e2911e0c6ba7..1398270d0a39ffad5d0e7cfa0a8017b679b5dff6 100644 --- a/src/main/scala/leon/invariant/engine/SpecInstantiator.scala +++ b/src/main/scala/leon/invariant/engine/SpecInstantiator.scala @@ -5,6 +5,7 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.ExprOps._ +import leon.purescala.TypeOps.instantiateType import purescala.Extractors._ import purescala.Types._ import java.io._ @@ -108,7 +109,6 @@ class SpecInstantiator(ctx: InferenceContext, program: Program, ctrTracker: Cons resetUntempCalls(formula.fd, newUntemplatedCalls ++ calls) } - import leon.purescala.TypeOps._ def specForCall(call: Call): Option[Expr] = { val argmap = formalToActual(call) val tfd = call.fi.tfd diff --git a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala index f1fedec33f8502ba17e6364c75941bded8fbd21a..8a5347ad279139b808d72bc967ce2dcfe3d76a15 100644 --- a/src/main/scala/leon/laziness/LazinessEliminationPhase.scala +++ b/src/main/scala/leon/laziness/LazinessEliminationPhase.scala @@ -13,7 +13,6 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Extractors._ import purescala.Types._ -import purescala.TypeOps._ import leon.invariant.util.TypeUtil._ import leon.invariant.util.LetTupleSimplification._ import leon.verification.VerificationPhase diff --git a/src/main/scala/leon/laziness/LazyClosureConverter.scala b/src/main/scala/leon/laziness/LazyClosureConverter.scala index 00d2bbaa88cb979f8f610600c0f4fcc12b60f04d..84d6344176d44f50ebe5df6e528340b8e22f6fec 100644 --- a/src/main/scala/leon/laziness/LazyClosureConverter.scala +++ b/src/main/scala/leon/laziness/LazyClosureConverter.scala @@ -27,7 +27,7 @@ import leon.TransformationPhase import LazinessUtil._ import ProgramUtil._ import PredicateUtil._ -import purescala.TypeOps._ +import purescala.TypeOps.bestRealType /** * (a) add state to every function in the program @@ -780,7 +780,7 @@ class LazyClosureConverter(p: Program, ctx: LeonContext, transformCaseClasses assignBodiesToFunctions assignContractsForEvals - addDefs( + ProgramUtil.addDefs( copyProgram(p, (defs: Seq[Definition]) => defs.flatMap { case fd: FunDef if funMap.contains(fd) => diff --git a/src/main/scala/leon/laziness/LazyClosureFactory.scala b/src/main/scala/leon/laziness/LazyClosureFactory.scala index a36a6d8431de2e350abe8b0a49c971a1d146ba01..faab5ef4dc21771530c319d80e53a55108f262af 100644 --- a/src/main/scala/leon/laziness/LazyClosureFactory.scala +++ b/src/main/scala/leon/laziness/LazyClosureFactory.scala @@ -13,7 +13,6 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Extractors._ import purescala.Types._ -import purescala.TypeOps._ import leon.invariant.util.TypeUtil._ import leon.invariant.util.LetTupleSimplification._ import java.io.File diff --git a/src/main/scala/leon/laziness/LazyExpressionLifter.scala b/src/main/scala/leon/laziness/LazyExpressionLifter.scala index 17cb6e858cb38f765576f7e4dd5e48a5df59b422..acc5598e3a3f96abf1cb27673aeb4cecb61b7708 100644 --- a/src/main/scala/leon/laziness/LazyExpressionLifter.scala +++ b/src/main/scala/leon/laziness/LazyExpressionLifter.scala @@ -178,7 +178,7 @@ object LazyExpressionLifter { case d => d }) val progWithClasses = - if (createUniqueIds) addDefs(progWithFuns, fvClasses, anchorDef.get) + if (createUniqueIds) ProgramUtil.addDefs(progWithFuns, fvClasses, anchorDef.get) else progWithFuns if (!newfuns.isEmpty) { val modToNewDefs = newfuns.values.groupBy(_._2).map { case (k, v) => (k, v.map(_._1)) }.toMap diff --git a/src/main/scala/leon/laziness/LazyVerificationPhase.scala b/src/main/scala/leon/laziness/LazyVerificationPhase.scala index d9b4c18e1f55de015e7009819946bfc8cbc700b7..242e07338f1832fe9c5d898b1c6c64a4a857f4cf 100644 --- a/src/main/scala/leon/laziness/LazyVerificationPhase.scala +++ b/src/main/scala/leon/laziness/LazyVerificationPhase.scala @@ -13,7 +13,6 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Extractors._ import purescala.Types._ -import purescala.TypeOps._ import leon.invariant.util.TypeUtil._ import leon.invariant.util.LetTupleSimplification._ import leon.verification.VerificationPhase @@ -118,7 +117,7 @@ object LazyVerificationPhase { if (debugInferProgram) prettyPrintProgramToFile(inferctx.inferProgram, checkCtx, "-inferProg", true) - val results = (new InferenceEngine(inferctx)).analyseProgram(inferctx.inferProgram, + val results = (new InferenceEngine(inferctx)).analyseProgram(inferctx.inferProgram, funsToCheck.map(InstUtil.userFunctionName), vcSolver, None) new InferenceReport(results.map { case (fd, ic) => (fd -> List[VC](ic)) }, inferctx.inferProgram)(inferctx) } else { diff --git a/src/main/scala/leon/purescala/CheckADTFieldsTypes.scala b/src/main/scala/leon/purescala/CheckADTFieldsTypes.scala deleted file mode 100644 index d4583e55311cfaf843983a3f8af70ac46f7b3675..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/purescala/CheckADTFieldsTypes.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package purescala - -import Definitions._ -import TypeOps._ - -object CheckADTFieldsTypes extends UnitPhase[Program] { - - val name = "ADT Fields" - val description = "Check that fields of ADTs are hierarchy roots" - - def apply(ctx: LeonContext, program: Program) = { - program.definedClasses.foreach { - case ccd: CaseClassDef => - for(vd <- ccd.fields) { - val tpe = vd.getType - if (bestRealType(tpe) != tpe) { - ctx.reporter.warning(ccd.getPos, "Definition of "+ccd.id.asString(ctx)+" has a field of a sub-type ("+vd.asString(ctx)+"): " + - "this type is not supported as-is by solvers and will be up-cast. " + - "This may cause issues such as crashes.") - } - } - case _ => - } - } - -} diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/leon/purescala/Common.scala index 63ec4a7d1d38245ba7b835524b6c5cd2aec4efed..8b9ecba6ce5d18b7e431500cf7ce86f4e2dd6d10 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/leon/purescala/Common.scala @@ -70,6 +70,10 @@ object Common { def toVariable: Variable = Variable(this) def freshen: Identifier = FreshIdentifier(name, tpe, alwaysShowUniqueID).copiedFrom(this) + + def duplicate(name: String = name, tpe: TypeTree = tpe, alwaysShowUniqueID: Boolean = alwaysShowUniqueID) = { + FreshIdentifier(name, tpe, alwaysShowUniqueID) + } override def compare(that: Identifier): Int = { val ord = implicitly[Ordering[(String, Int, Int)]] diff --git a/src/main/scala/leon/purescala/Constructors.scala b/src/main/scala/leon/purescala/Constructors.scala index e3ce7d1e01780ce02ba424396dabd3670f426966..95218cb2857a97c6834ac039f20c426a0736be8f 100644 --- a/src/main/scala/leon/purescala/Constructors.scala +++ b/src/main/scala/leon/purescala/Constructors.scala @@ -131,7 +131,7 @@ object Constructors { */ def caseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier): Expr = { caseClass match { - case CaseClass(ct, fields) if ct.classDef == classType.classDef => + case CaseClass(ct, fields) if ct.classDef == classType.classDef && !ct.classDef.hasInvariant => fields(ct.classDef.selectorID2Index(selector)) case _ => CaseClassSelector(classType, caseClass, selector) @@ -317,6 +317,13 @@ object Constructors { } /** $encodingof simplified `fn(realArgs)` (function application). + * Transforms + * {{{ ((x: A, y: B) => g(x, y))(c, d) }}} + * into + * {{{val x0 = c + * val y0 = d + * g(x0, y0)}}} + * and further simplifies it. * @see [[purescala.Expressions.Lambda Lambda]] * @see [[purescala.Expressions.Application Application]] */ @@ -338,6 +345,7 @@ object Constructors { val (ids, bds) = defs.unzip letTuple(ids, tupleWrap(bds), replaceFromIDs(subst, body)) + case _ => Application(fn, realArgs) } diff --git a/src/main/scala/leon/purescala/DefOps.scala b/src/main/scala/leon/purescala/DefOps.scala index 81ebbdbec38e1484baf106eac9652d62297ae493..ea4bff382d1d20b52ba27823cdf92dbaf62d851c 100644 --- a/src/main/scala/leon/purescala/DefOps.scala +++ b/src/main/scala/leon/purescala/DefOps.scala @@ -4,7 +4,10 @@ package leon.purescala import Definitions._ import Expressions._ +import Common.Identifier import ExprOps.{preMap, functionCallsOf} +import leon.purescala.Types.AbstractClassType +import leon.purescala.Types._ object DefOps { @@ -274,78 +277,393 @@ object DefOps { case _ => None } - + + /** Clones the given program by replacing some functions by other functions. + * + * @param p The original program + * @param fdMapF Given f, returns Some(g) if f should be replaced by g, and None if f should be kept. + * @param fiMapF Given a previous function invocation and its new function definition, returns the expression to use. + * By default it is the function invocation using the new function definition. + * @return the new program with a map from the old functions to the new functions */ def replaceFunDefs(p: Program)(fdMapF: FunDef => Option[FunDef], - fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { - - var fdMapCache = Map[FunDef, Option[FunDef]]() + fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) + : (Program, Map[FunDef, FunDef])= { + + var fdMapFCache = Map[FunDef, Option[FunDef]]() // Original fdMapF cache + var fdMapCache = Map[FunDef, Option[FunDef]]() // Final replacement. + def fdMapFCached(fd: FunDef): Option[FunDef] = { + fdMapFCache.get(fd) match { + case Some(e) => e + case None => + val new_fd = fdMapF(fd) + fdMapFCache += fd -> new_fd + new_fd + } + } + + def duplicateParents(fd: FunDef): Unit = { + fdMapCache.get(fd) match { + case None => + fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate())) + for(fp <- p.callGraph.callers(fd)) { + duplicateParents(fp) + } + case _ => + } + } + def fdMap(fd: FunDef): FunDef = { - if (!(fdMapCache contains fd)) { - fdMapCache += fd -> fdMapF(fd) + fdMapCache.get(fd) match { + case Some(Some(e)) => e + case Some(None) => fd + case None => + if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) { + duplicateParents(fd) + } else { // Verify that for all + fdMapCache += fd -> None + } + fdMapCache(fd).getOrElse(fd) } - - fdMapCache(fd).getOrElse(fd) } - val newP = p.copy(units = for (u <- p.units) yield { u.copy( defs = u.defs.map { case m : ModuleDef => m.copy(defs = for (df <- m.defs) yield { df match { - case f : FunDef => - val newF = fdMap(f) - newF.fullBody = replaceFunCalls(newF.fullBody, fdMap, fiMapF) - newF - case d => - d + case f : FunDef => fdMap(f) + case d => d } }) case d => d } ) }) - + + for(fd <- newP.definedFunctions) { + if(ExprOps.exists{ + case FunctionInvocation(TypedFunDef(fd, targs), fargs) => fdMapCache contains fd + case MatchExpr(_, cases) => cases.exists(c => PatternOps.exists{ + case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => fdMapCache contains fd + case _ => false + }(c.pattern)) + case _ => false + }(fd.fullBody)) { + fd.fullBody = replaceFunCalls(fd.fullBody, fdMap, fiMapF) + } + } (newP, fdMapCache.collect{ case (ofd, Some(nfd)) => ofd -> nfd }) } - def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap) = { + def replaceFunCalls(e: Expr, fdMapF: FunDef => FunDef, fiMapF: (FunctionInvocation, FunDef) => Option[Expr] = defaultFiMap): Expr = { preMap { + case MatchExpr(scrut, cases) => + Some(MatchExpr(scrut, cases.map(matchcase => matchcase match { + case MatchCase(pattern, guard, rhs) => MatchCase(replaceFunCalls(pattern, fdMapF), guard, rhs) + }))) case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => fiMapF(fi, fdMapF(fd)).map(_.setPos(fi)) case _ => None }(e) } + + def replaceFunCalls(p: Pattern, fdMapF: FunDef => FunDef): Pattern = PatternOps.preMap{ + case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId, TypedFunDef(fdMapF(fd), tps), subp)) + case _ => None + }(p) + + private def defaultCdMap(cc: CaseClass, ccd: CaseClassType): Option[Expr] = (cc, ccd) match { + case (CaseClass(old, args), newCcd) if old.classDef != newCcd => + Some(CaseClass(newCcd, args)) + case _ => + None + } + + /** Clones the given program by replacing some classes by other classes. + * + * @param p The original program + * @param cdMapF Given c returns Some(d) where d can take an abstract parent and return a class e if c should be replaced by e, and None if c should be kept. + * @param ciMapF Given a previous case class invocation and its new case class definition, returns the expression to use. + * By default it is the case class construction using the new case class definition. + * @return the new program with a map from the old case classes to the new case classes, with maps concerning identifiers and function definitions. */ + def replaceCaseClassDefs(p: Program)(_cdMapF: CaseClassDef => Option[Option[AbstractClassType] => CaseClassDef], + ciMapF: (CaseClass, CaseClassType) => Option[Expr] = defaultCdMap) + : (Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = { + var cdMapFCache = Map[CaseClassDef, Option[Option[AbstractClassType] => CaseClassDef]]() + var cdMapCache = Map[ClassDef, Option[ClassDef]]() + var idMapCache = Map[Identifier, Identifier]() + var fdMapFCache = Map[FunDef, Option[FunDef]]() + var fdMapCache = Map[FunDef, Option[FunDef]]() + def cdMapF(cd: ClassDef): Option[Option[AbstractClassType] => CaseClassDef] = { + cd match { + case ccd: CaseClassDef => + cdMapFCache.getOrElse(ccd, { + val new_cd_potential = _cdMapF(ccd) + cdMapFCache += ccd -> new_cd_potential + new_cd_potential + }) + case acd: AbstractClassDef => None + } + } + def tpMap[T <: TypeTree](tt: T): T = TypeOps.postMap{ + case AbstractClassType(asd, targs) => Some(AbstractClassType(cdMap(asd).asInstanceOf[AbstractClassDef], targs)) + case CaseClassType(ccd, targs) => Some(CaseClassType(cdMap(ccd).asInstanceOf[CaseClassDef], targs)) + case e => None + }(tt).asInstanceOf[T] + + def duplicateClassDef(cd: ClassDef): ClassDef = { + cdMapCache.get(cd) match { + case Some(new_cd) => + new_cd.get // None would have meant that this class would never be duplicated, which is not possible. + case None => + val parent = cd.parent.map(duplicateAbstractClassType) + val new_cd = cdMapF(cd).map(f => f(parent)).getOrElse{ + cd match { + case acd:AbstractClassDef => acd.duplicate(parent = parent) + case ccd:CaseClassDef => + ccd.duplicate(parent = parent, fields = ccd.fieldsIds.map(id => ValDef(idMap(id)))) // Should not cycle since fields have to be abstract. + } + } + cdMapCache += cd -> Some(new_cd) + new_cd + } + } + + def duplicateAbstractClassType(act: AbstractClassType): AbstractClassType = { + TypeOps.postMap{ + case AbstractClassType(acd, tps) => Some(AbstractClassType(duplicateClassDef(acd).asInstanceOf[AbstractClassDef], tps)) + case CaseClassType(ccd, tps) => Some(CaseClassType(duplicateClassDef(ccd).asInstanceOf[CaseClassDef], tps)) + case _ => None + }(act).asInstanceOf[AbstractClassType] + } + + // If at least one descendants or known case class needs conversion, then all the hierarchy will be converted. + // If something extends List[A] and A is modified, then the first something should be modified. + def dependencies(s: ClassDef): Set[ClassDef] = { + leon.utils.fixpoint((s: Set[ClassDef]) => s ++ s.flatMap(_.knownDescendants) ++ s.flatMap(_.parent.toList.flatMap(p => TypeOps.collect[ClassDef]{ + case AbstractClassType(acd, _) => Set(acd:ClassDef) ++ acd.knownDescendants + case CaseClassType(ccd, _) => Set(ccd:ClassDef) + case _ => Set() + }(p))))(Set(s)) + } + + def cdMap(cd: ClassDef): ClassDef = { + cdMapCache.get(cd) match { + case Some(Some(new_cd)) => new_cd + case Some(None) => cd + case None => + if(cdMapF(cd).isDefined || dependencies(cd).exists(cd => cdMapF(cd).isDefined)) { // Needs replacement in any case. + duplicateClassDef(cd) + } else { + cdMapCache += cd -> None + } + cdMapCache(cd).getOrElse(cd) + } + } + def idMap(id: Identifier): Identifier = { + if (!(idMapCache contains id)) { + val new_id = id.duplicate(tpe = tpMap(id.getType)) + idMapCache += id -> new_id + } + idMapCache(id) + } + + def idHasToChange(id: Identifier): Boolean = { + typeHasToChange(id.getType) + } + + def typeHasToChange(tp: TypeTree): Boolean = { + TypeOps.exists{ + case AbstractClassType(acd, _) => cdMap(acd) != acd + case CaseClassType(ccd, _) => cdMap(ccd) != ccd + case _ => false + }(tp) + } + + def patternHasToChange(p: Pattern): Boolean = { + PatternOps.exists { + case CaseClassPattern(optId, cct, sub) => optId.exists(idHasToChange) || typeHasToChange(cct) + case InstanceOfPattern(optId, cct) => optId.exists(idHasToChange) || typeHasToChange(cct) + case Extractors.Pattern(optId, subp, builder) => optId.exists(idHasToChange) + case e => false + } (p) + } + + def exprHasToChange(e: Expr): Boolean = { + ExprOps.exists{ + case Let(id, expr, body) => idHasToChange(id) + case Variable(id) => idHasToChange(id) + case ci @ CaseClass(cct, args) => typeHasToChange(cct) + case CaseClassSelector(cct, expr, identifier) => typeHasToChange(cct) || idHasToChange(identifier) + case IsInstanceOf(e, cct) => typeHasToChange(cct) + case AsInstanceOf(e, cct) => typeHasToChange(cct) + case MatchExpr(scrut, cases) => + cases.exists{ + case MatchCase(pattern, optGuard, rhs) => + patternHasToChange(pattern) + } + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + tps.exists(typeHasToChange) + case _ => + false + }(e) + } + + def funDefHasToChange(fd: FunDef): Boolean = { + exprHasToChange(fd.fullBody) || fd.params.exists(vid => typeHasToChange(vid.id.getType)) || typeHasToChange(fd.returnType) + } + + def funHasToChange(fd: FunDef): Boolean = { + funDefHasToChange(fd) || p.callGraph.transitiveCallees(fd).exists(fd => + fdMapFCache.get(fd) match { + case Some(Some(_)) => true + case Some(None) => false + case None => funDefHasToChange(fd) + }) + } + + def fdMapFCached(fd: FunDef): Option[FunDef] = { + fdMapFCache.get(fd) match { + case Some(e) => e + case None => + val new_fd = if(funHasToChange(fd)) { + Some(fd.duplicate(params = fd.params.map(vd => ValDef(idMap(vd.id))), returnType = tpMap(fd.returnType))) + } else { + None + } + fdMapFCache += fd -> new_fd + new_fd + } + } + + def duplicateParents(fd: FunDef): Unit = { + fdMapCache.get(fd) match { + case None => + fdMapCache += fd -> fdMapFCached(fd).orElse(Some(fd.duplicate())) + for(fp <- p.callGraph.callers(fd)) { + duplicateParents(fp) + } + case _ => + } + } + + def fdMap(fd: FunDef): FunDef = { + fdMapCache.get(fd) match { + case Some(Some(e)) => e + case Some(None) => fd + case None => + if(fdMapFCached(fd).isDefined || p.callGraph.transitiveCallees(fd).exists(fd => fdMapFCached(fd).isDefined)) { + duplicateParents(fd) + } else { + fdMapCache += fd -> None + } + fdMapCache(fd).getOrElse(fd) + } + } + + val newP = p.copy(units = for (u <- p.units) yield { + u.copy( + defs = u.defs.map { + case m : ModuleDef => + m.copy(defs = for (df <- m.defs) yield { + df match { + case cd : ClassDef => cdMap(cd) + case fd : FunDef => fdMap(fd) + case d => d + } + }) + case d => d + } + ) + }) + def replaceClassDefUse(e: Pattern): Pattern = PatternOps.postMap{ + case CaseClassPattern(optId, cct, sub) => Some(CaseClassPattern(optId.map(idMap), tpMap[CaseClassType](cct), sub)) + case InstanceOfPattern(optId, cct) => Some(InstanceOfPattern(optId.map(idMap), tpMap[ClassType](cct))) + case UnapplyPattern(optId, TypedFunDef(fd, tps), subp) => Some(UnapplyPattern(optId.map(idMap), TypedFunDef(fdMap(fd), tps.map(tpMap)), subp)) + case Extractors.Pattern(Some(id), subp, builder) => Some(builder(Some(idMap(id)), subp)) + case e => None + }(e) + + def replaceClassDefsUse(e: Expr): Expr = { + ExprOps.postMap { + case Let(id, expr, body) => Some(Let(idMap(id), expr, body)) + case Lambda(vd, body) => Some(Lambda(vd.map(vd => ValDef(idMap(vd.id))), body)) + case Variable(id) => Some(Variable(idMap(id))) + case ci @ CaseClass(ct, args) => + ciMapF(ci, tpMap(ct)).map(_.setPos(ci)) + case CaseClassSelector(cct, expr, identifier) => + val new_cct = tpMap(cct) + val selection = (if(new_cct != cct || new_cct.classDef.fieldsIds != cct.classDef.fieldsIds) idMap(identifier) else identifier) + Some(CaseClassSelector(new_cct, expr, selection)) + case IsInstanceOf(e, ct) => Some(IsInstanceOf(e, tpMap(ct))) + case AsInstanceOf(e, ct) => Some(AsInstanceOf(e, tpMap(ct))) + case MatchExpr(scrut, cases) => + Some(MatchExpr(scrut, cases.map{ + case MatchCase(pattern, optGuard, rhs) => + MatchCase(replaceClassDefUse(pattern), optGuard, rhs) + })) + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + defaultFiMap(fi, fdMap(fd)).map(_.setPos(fi)) + case _ => + None + }(e) + } + + for(fd <- newP.definedFunctions) { + if(fdMapCache.getOrElse(fd, None).isDefined) { + fd.fullBody = replaceClassDefsUse(fd.fullBody) + } + } + (newP, + cdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd}, + idMapCache, + fdMapCache.collect{case (cd, Some(new_cd)) => cd -> new_cd }) + } + + - def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = { + def addDefs(p: Program, cds: Traversable[Definition], after: Definition): Program = { var found = false val res = p.copy(units = for (u <- p.units) yield { u.copy( - defs = u.defs.map { + defs = u.defs.flatMap { case m: ModuleDef => val newdefs = for (df <- m.defs) yield { df match { case `after` => found = true - after +: fds.toSeq - case d => - Seq(d) + after +: cds.toSeq + case d => Seq(d) } } - m.copy(defs = newdefs.flatten) - case d => d + Seq(m.copy(defs = newdefs.flatten)) + case `after` => + found = true + after +: cds.toSeq + case d => Seq(d) } ) }) + if (!found) { - println("addFunDefs could not find anchor function!") + println(s"addDefs could not find anchor definition! Not found: $after") + p.definedFunctions.filter(f => f.id.name == after.id.name).map(fd => fd.id.name + " : " + fd) match { + case Nil => + case e => println("Did you mean " + e) + } + println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n")) } + res } + def addFunDefs(p: Program, fds: Traversable[FunDef], after: FunDef): Program = addDefs(p, fds, after) + + def addClassDefs(p: Program, fds: Traversable[ClassDef], after: ClassDef): Program = addDefs(p, fds, after) + // @Note: This function does not filter functions in classdefs def filterFunDefs(p: Program, fdF: FunDef => Boolean): Program = { p.copy(units = p.units.map { u => diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index 8fb753d23d35364059115f7032bef3310c492d82..733eaf124dbb6dd8499c81347f7fc864c23ca453 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -182,7 +182,6 @@ object Definitions { lazy val singleCaseClasses : Seq[CaseClassDef] = defs.collect { case c @ CaseClassDef(_, _, None, _) => c } - } // A class that represents flags that annotate a FunDef with different attributes @@ -217,7 +216,8 @@ object Definitions { case object IsSynthetic extends FunctionFlag // Is inlined case object IsInlined extends FunctionFlag - + // Is an ADT invariant method + case object IsADTInvariant extends FunctionFlag with ClassFlag /** Useful because case classes and classes are somewhat unified in some * patterns (of pattern-matching, that is) */ @@ -268,6 +268,15 @@ object Definitions { def flags = _flags + private var _invariant: Option[FunDef] = None + + def invariant = _invariant + def hasInvariant = flags contains IsADTInvariant + def setInvariant(fd: FunDef): Unit = { + addFlag(IsADTInvariant) + _invariant = Some(fd) + } + def annotations: Set[String] = extAnnotations.keySet def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap @@ -289,6 +298,23 @@ object Definitions { ccd } + def isInductive: Boolean = { + def induct(tpe: TypeTree, seen: Set[ClassDef]): Boolean = tpe match { + case ct: ClassType => + val root = ct.classDef.root + seen(root) || ct.fields.forall(vd => induct(vd.getType, seen + root)) + case TupleType(tpes) => + tpes.forall(tpe => induct(tpe, seen)) + case _ => true + } + + if (this == root && !this.isAbstract) false + else if (this != root) root.isInductive + else knownCCDescendants.forall { ccd => + ccd.fields.forall(vd => induct(vd.getType, Set(root))) + } + } + val isAbstract: Boolean val isCaseObject: Boolean @@ -315,6 +341,20 @@ object Definitions { AbstractClassType(this, tps) } def typed: AbstractClassType = typed(tparams.map(_.tp)) + + /** Duplication of this [[CaseClassDef]]. + * @note This will not add known case class children + */ + def duplicate( + id: Identifier = this.id.freshen, + tparams: Seq[TypeParameterDef] = this.tparams, + parent: Option[AbstractClassType] = this.parent + ): AbstractClassDef = { + val acd = new AbstractClassDef(id, tparams, parent) + acd.addFlags(this.flags) + parent.map(_.classDef.ancestors.map(_.registerChild(acd))) + acd.copiedFrom(this) + } } /** Case classes/objects. */ @@ -351,6 +391,24 @@ object Definitions { CaseClassType(this, tps) } def typed: CaseClassType = typed(tparams.map(_.tp)) + + /** Duplication of this [[CaseClassDef]]. + * @note This will not replace recursive case class def calls in [[arguments]] nor the parent abstract class types + */ + def duplicate( + id: Identifier = this.id.freshen, + tparams: Seq[TypeParameterDef] = this.tparams, + fields: Seq[ValDef] = this.fields, + parent: Option[AbstractClassType] = this.parent, + isCaseObject: Boolean = this.isCaseObject + ): CaseClassDef = { + val cd = new CaseClassDef(id, tparams, parent, isCaseObject) + cd.setFields(fields) + cd.addFlags(this.flags) + cd.copiedFrom(this) + parent.map(_.classDef.ancestors.map(_.registerChild(cd))) + cd + } } /** Function/method definition. @@ -442,6 +500,7 @@ object Definitions { def canBeField = canBeLazyField || canBeStrictField def isRealFunction = !canBeField def isSynthetic = flags contains IsSynthetic + def isInvariant = flags contains IsADTInvariant def methodOwner = flags collectFirst { case IsMethod(cd) => cd } /* Wrapping in TypedFunDef */ diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 49d945ec6217a42578da7fef97bf572252cb6c26..ca52d4592842da60629423e9d0ba3ff52ea2a73c 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -19,285 +19,19 @@ import solvers._ * * The generic operations lets you apply operations on a whole tree * expression. You can look at: - * - [[ExprOps.fold foldRight]] - * - [[ExprOps.preTraversal preTraversal]] - * - [[ExprOps.postTraversal postTraversal]] - * - [[ExprOps.preMap preMap]] - * - [[ExprOps.postMap postMap]] - * - [[ExprOps.genericTransform genericTransform]] + * - [[SubTreeOps.fold foldRight]] + * - [[SubTreeOps.preTraversal preTraversal]] + * - [[SubTreeOps.postTraversal postTraversal]] + * - [[SubTreeOps.preMap preMap]] + * - [[SubTreeOps.postMap postMap]] + * - [[SubTreeOps.genericTransform genericTransform]] * * These operations usually take a higher order function that gets applied to the * expression tree in some strategy. They provide an expressive way to build complex * operations on Leon expressions. * */ -object ExprOps { - - /* ======== - * Core API - * ======== - * - * All these functions should be stable, tested, and used everywhere. Modify - * with care. - */ - - - /** Does a right tree fold - * - * A right tree fold applies the input function to the subnodes first (from left - * to right), and combine the results along with the current node value. - * - * @param f a function that takes the current node and the seq - * of results form the subtrees. - * @param e The Expr on which to apply the fold. - * @return The expression after applying `f` on all subtrees. - * @note the computation is lazy, hence you should not rely on side-effects of `f` - */ - def fold[T](f: (Expr, Seq[T]) => T)(e: Expr): T = { - val rec = fold(f) _ - val Operator(es, _) = e - - //Usages of views makes the computation lazy. (which is useful for - //contains-like operations) - f(e, es.view.map(rec)) - } - - /** Pre-traversal of the tree. - * - * Invokes the input function on every node '''before''' visiting - * children. Traverse children from left to right subtrees. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) - * }}} - * will yield, in order: - * {{{ - * f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c) - * }}} - * - * @param f a function to apply on each node of the expression - * @param e the expression to traverse - */ - def preTraversal(f: Expr => Unit)(e: Expr): Unit = { - val rec = preTraversal(f) _ - val Operator(es, _) = e - f(e) - es.foreach(rec) - } - - /** Post-traversal of the tree. - * - * Invokes the input function on every node '''after''' visiting - * children. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) - * }}} - * will yield, in order: - * {{{ - * f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c))) - * }}} - * - * @param f a function to apply on each node of the expression - * @param e the expression to traverse - */ - def postTraversal(f: Expr => Unit)(e: Expr): Unit = { - val rec = postTraversal(f) _ - val Operator(es, _) = e - es.foreach(rec) - f(e) - } - - /** Pre-transformation of the tree. - * - * Takes a partial function of replacements and substitute - * '''before''' recursing down the trees. - * - * Supports two modes : - * - * - If applyRec is false (default), will only substitute once on each level. - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, d) // And not Add(a, f) because it only substitute once for each level. - * }}} - * - * - If applyRec is true, it will substitute multiple times on each level: - * - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, f) - * }}} - * - * @note The mode with applyRec true can diverge if f is not well formed - */ - def preMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { - val rec = preMap(f, applyRec) _ - - val newV = if (applyRec) { - // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (e) - } else { - f(e) getOrElse e - } - - val Operator(es, builder) = newV - val newEs = es.map(rec) - - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(newV) - } else { - newV - } - } - - /** Post-transformation of the tree. - * - * Takes a partial function of replacements. - * Substitutes '''after''' recursing down the trees. - * - * Supports two modes : - * - * - If applyRec is false (default), will only substitute once on each level. - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e - * }}} - * will yield: - * {{{ - * Add(a, Minus(e, c)) - * }}} - * - * - If applyRec is true, it will substitute multiple times on each level: - * e.g. - * {{{ - * Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f - * }}} - * will yield: - * {{{ - * Add(a, f) - * }}} - * - * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) - */ - def postMap(f: Expr => Option[Expr], applyRec : Boolean = false)(e: Expr): Expr = { - val rec = postMap(f, applyRec) _ - - val Operator(es, builder) = e - val newEs = es.map(rec) - val newV = { - if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { - builder(newEs).copiedFrom(e) - } else { - e - } - } - - if (applyRec) { - // Apply f as long as it returns Some() - fixpoint { e : Expr => f(e) getOrElse e } (newV) - } else { - f(newV) getOrElse newV - } - - } - - - /** Applies functions and combines results in a generic way - * - * Start with an initial value, and apply functions to nodes before - * and after the recursion in the children. Combine the results of - * all children and apply a final function on the resulting node. - * - * @param pre a function applied on a node before doing a recursion in the children - * @param post a function applied to the node built from the recursive application to - all children - * @param combiner a function to combine the resulting values from all children with - the current node - * @param init the initial value - * @param expr the expression on which to apply the transform - * - * @see [[simpleTransform]] - * @see [[simplePreTransform]] - * @see [[simplePostTransform]] - */ - def genericTransform[C](pre: (Expr, C) => (Expr, C), - post: (Expr, C) => (Expr, C), - combiner: (Expr, Seq[C]) => C)(init: C)(expr: Expr) = { - - def rec(eIn: Expr, cIn: C): (Expr, C) = { - - val (expr, ctx) = pre(eIn, cIn) - val Operator(es, builder) = expr - val (newExpr, newC) = { - val (nes, cs) = es.map{ rec(_, ctx)}.unzip - val newE = builder(nes).copiedFrom(expr) - - (newE, combiner(newE, cs)) - } - - post(newExpr, newC) - } - - rec(expr, init) - } - - /* - * ============= - * Auxiliary API - * ============= - * - * Convenient methods using the Core API. - */ - - /** Checks if the predicate holds in some sub-expression */ - def exists(matcher: Expr => Boolean)(e: Expr): Boolean = { - fold[Boolean]({ (e, subs) => matcher(e) || subs.contains(true) } )(e) - } - - /** Collects a set of objects from all sub-expressions */ - def collect[T](matcher: Expr => Set[T])(e: Expr): Set[T] = { - fold[Set[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) - } - - def collectPreorder[T](matcher: Expr => Seq[T])(e: Expr): Seq[T] = { - fold[Seq[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) - } - - /** Returns a set of all sub-expressions matching the predicate */ - def filter(matcher: Expr => Boolean)(e: Expr): Set[Expr] = { - collect[Expr] { e => Set(e) filter matcher }(e) - } - - /** Counts how many times the predicate holds in sub-expressions */ - def count(matcher: Expr => Int)(e: Expr): Int = { - fold[Int]({ (e, subs) => matcher(e) + subs.sum } )(e) - } - - /** Replaces bottom-up sub-expressions by looking up for them in a map */ - def replace(substs: Map[Expr,Expr], expr: Expr) : Expr = { - postMap(substs.lift)(expr) - } - - /** Replaces bottom-up sub-expressions by looking up for them in the provided order */ - def replaceSeq(substs: Seq[(Expr, Expr)], expr: Expr): Expr = { - var res = expr - for (s <- substs) { - res = replace(Map(s), res) - } - res - } - +object ExprOps extends { val Deconstructor = Operator } with SubTreeOps[Expr] { /** Replaces bottom-up sub-identifiers by looking up for them in a map */ def replaceFromIDs(substs: Map[Identifier, Expr], expr: Expr) : Expr = { postMap({ @@ -332,7 +66,7 @@ object ExprOps { Lambda(args, rec(binders ++ args.map(_.id), bd)) case Forall(args, bd) => Forall(args, rec(binders ++ args.map(_.id), bd)) - case Operator(subs, builder) => + case Deconstructor(subs, builder) => builder(subs map (rec(binders, _))) }).copiedFrom(e) @@ -341,7 +75,7 @@ object ExprOps { /** Returns the set of free variables in an expression */ def variablesOf(expr: Expr): Set[Identifier] = { - import leon.xlang.Expressions.LetVar + import leon.xlang.Expressions._ fold[Set[Identifier]] { case (e, subs) => val subvs = subs.flatten.toSet @@ -375,6 +109,13 @@ object ExprOps { case _ => Set() }(expr) } + + def nestedFunDefsOf(expr: Expr): Set[FunDef] = { + collect[FunDef] { + case LetDef(fds, _) => fds.toSet + case _ => Set() + }(expr) + } /** Returns functions in directly nested LetDefs */ def directlyNestedFunDefs(e: Expr): Set[FunDef] = { @@ -442,7 +183,7 @@ object ExprOps { case l @ Let(i,e,b) => val newID = FreshIdentifier(i.name, i.getType, alwaysShowUniqueID = true).copiedFrom(i) - Some(Let(newID, e, replace(Map(Variable(i) -> Variable(newID)), b))) + Some(Let(newID, e, replaceFromIDs(Map(i -> Variable(newID)), b))) case _ => None }(expr) @@ -597,7 +338,7 @@ object ExprOps { def simplerLet(t: Expr) : Option[Expr] = t match { case letExpr @ Let(i, t: Terminal, b) if isDeterministic(b) => - Some(replace(Map(Variable(i) -> t), b)) + Some(replaceFromIDs(Map(i -> t), b)) case letExpr @ Let(i,e,b) if isDeterministic(b) => { val occurrences = count { @@ -608,7 +349,7 @@ object ExprOps { if(occurrences == 0) { Some(b) } else if(occurrences == 1) { - Some(replace(Map(Variable(i) -> e), b)) + Some(replaceFromIDs(Map(i -> e), b)) } else { None } @@ -619,7 +360,7 @@ object ExprOps { val (remIds, remExprs) = (ids zip exprs).filter { case (id, value: Terminal) => - newBody = replace(Map(Variable(id) -> value), newBody) + newBody = replaceFromIDs(Map(id -> value), newBody) //we replace, so we drop old false case (id, value) => @@ -695,7 +436,7 @@ object ExprOps { case i @ IfExpr(t1,t2,t3) => IfExpr(rec(t1, s),rec(t2, s),rec(t3, s)) case m @ MatchExpr(scrut, cses) => matchExpr(rec(scrut, s), cses.map(inCase(_, s))).setPos(m) case p @ Passes(in, out, cses) => Passes(rec(in, s), rec(out,s), cses.map(inCase(_, s))).setPos(p) - case n @ Operator(args, recons) => { + case n @ Deconstructor(args, recons) => { var change = false val rargs = args.map(a => { val ra = rec(a, s) @@ -1143,7 +884,7 @@ object ExprOps { GenericValue(tp, 0) case ft @ FunctionType(from, to) => - PartialLambda(Seq.empty, Some(simplestValue(to)), ft) + FiniteLambda(Seq.empty, simplestValue(to), ft) case _ => throw LeonFatalError("I can't choose simplest value for type " + tpe) } @@ -1204,7 +945,7 @@ object ExprOps { def transform(expr: Expr): Option[Expr] = expr match { case IfExpr(c, t, e) => None - case nop@Operator(ts, op) => { + case nop@Deconstructor(ts, op) => { val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } if(iteIndex == -1) None else { val (beforeIte, startIte) = ts.splitAt(iteIndex) @@ -1355,7 +1096,7 @@ object ExprOps { formulaSize(rhs) + og.map(formulaSize).getOrElse(0) + patternSize(p) }.sum - case Operator(es, _) => + case Deconstructor(es, _) => es.map(formulaSize).sum+1 } @@ -1449,6 +1190,12 @@ object ExprOps { case Minus(Plus(e1, e2), Plus(e3, e4)) if e1 == e4 && e2 == e3 => InfiniteIntegerLiteral(0) case Minus(Plus(e1, e2), Plus(Plus(e3, e4), e5)) if e1 == e4 && e2 == e3 => UMinus(e5) + case StringConcat(StringLiteral(""), a) => a + case StringConcat(a, StringLiteral("")) => a + case StringConcat(StringLiteral(a), StringLiteral(b)) => StringLiteral(a+b) + case StringConcat(StringLiteral(a), StringConcat(StringLiteral(b), c)) => StringConcat(StringLiteral(a+b), c) + case StringConcat(StringConcat(c, StringLiteral(a)), StringLiteral(b)) => StringConcat(c, StringLiteral(a+b)) + case StringConcat(a, StringConcat(b, c)) => StringConcat(StringConcat(a, b), c) //default case e => e }).copiedFrom(expr) @@ -1537,6 +1284,179 @@ object ExprOps { case _ => false } + + /** Checks whether two expressions can be homomorphic and returns the corresponding mapping */ + def canBeHomomorphic(t1: Expr, t2: Expr): Option[Map[Identifier, Identifier]] = { + val freeT1Variables = ExprOps.variablesOf(t1) + val freeT2Variables = ExprOps.variablesOf(t2) + + def mergeContexts(a: Option[Map[Identifier, Identifier]], b: =>Option[Map[Identifier, Identifier]]) = a match { + case Some(m) => + b match { + case Some(n) if (m.keySet & n.keySet) forall (key => m(key) == n(key)) => + Some(m ++ n) + case _ =>None + } + case _ => None + } + object Same { + def unapply(tt: (Expr, Expr)): Option[(Expr, Expr)] = { + if (tt._1.getClass == tt._2.getClass) { + Some(tt) + } else { + None + } + } + } + implicit class AugmentedContext(c: Option[Map[Identifier, Identifier]]) { + def &&(other: => Option[Map[Identifier, Identifier]]) = mergeContexts(c, other) + def --(other: Seq[Identifier]) = + c.map(_ -- other) + } + implicit class AugmentedBooleant(c: Boolean) { + def &&(other: => Option[Map[Identifier, Identifier]]) = if(c) other else None + } + implicit class AugmentedSeq[T](c: Seq[T]) { + def mergeall(p: T => Option[Map[Identifier, Identifier]]) = + (Option(Map[Identifier, Identifier]()) /: c) { + case (s, c) => s && p(c) + } + } + + + def idHomo(i1: Identifier, i2: Identifier): Option[Map[Identifier, Identifier]] = { + if(!(freeT1Variables(i1) || freeT2Variables(i2)) || i1 == i2) Some(Map(i1 -> i2)) else None + } + + def fdHomo(fd1: FunDef, fd2: FunDef): Option[Map[Identifier, Identifier]] = { + if(fd1.params.size == fd2.params.size) { + val newMap = Map(( + (fd1.id -> fd2.id) +: + (fd1.paramIds zip fd2.paramIds)): _*) + Option(newMap) && isHomo(fd1.fullBody, fd2.fullBody) + } else None + } + + def isHomo(t1: Expr, t2: Expr): Option[Map[Identifier, Identifier]] = { + def casesMatch(cs1 : Seq[MatchCase], cs2 : Seq[MatchCase]) : Option[Map[Identifier, Identifier]] = { + def patternHomo(p1: Pattern, p2: Pattern): (Boolean, Map[Identifier, Identifier]) = (p1, p2) match { + case (InstanceOfPattern(ob1, cd1), InstanceOfPattern(ob2, cd2)) => + (ob1.size == ob2.size && cd1 == cd2, Map((ob1 zip ob2).toSeq : _*)) + + case (WildcardPattern(ob1), WildcardPattern(ob2)) => + (ob1.size == ob2.size, Map((ob1 zip ob2).toSeq : _*)) + + case (CaseClassPattern(ob1, ccd1, subs1), CaseClassPattern(ob2, ccd2, subs2)) => + val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) + + if (ob1.size == ob2.size && ccd1 == ccd2 && subs1.size == subs2.size) { + (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { + case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) + } + } else { + (false, Map()) + } + + case (UnapplyPattern(ob1, fd1, subs1), UnapplyPattern(ob2, fd2, subs2)) => + val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) + + if (ob1.size == ob2.size && fd1 == fd2 && subs1.size == subs2.size) { + (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { + case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) + } + } else { + (false, Map()) + } + + case (TuplePattern(ob1, subs1), TuplePattern(ob2, subs2)) => + val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) + + if (ob1.size == ob2.size && subs1.size == subs2.size) { + (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { + case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) + } + } else { + (false, Map()) + } + + case (LiteralPattern(ob1, lit1), LiteralPattern(ob2,lit2)) => + (ob1.size == ob2.size && lit1 == lit2, (ob1 zip ob2).toMap) + + case _ => + (false, Map()) + } + + (cs1 zip cs2).mergeall { + case (MatchCase(p1, g1, e1), MatchCase(p2, g2, e2)) => + val (h, nm) = patternHomo(p1, p2) + val g: Option[Map[Identifier, Identifier]] = (g1, g2) match { + case (Some(g1), Some(g2)) => Some(nm) && isHomo(g1,g2) + case (None, None) => Some(Map()) + case _ => None + } + val e = Some(nm) && isHomo(e1, e2) + + h && g && e + } + + } + + import synthesis.Witnesses.Terminating + + val res: Option[Map[Identifier, Identifier]] = (t1, t2) match { + case (Variable(i1), Variable(i2)) => + idHomo(i1, i2) + + case (Let(id1, v1, e1), Let(id2, v2, e2)) => + isHomo(v1, v2) && + isHomo(e1, e2) && Some(Map(id1 -> id2)) + + case (LetDef(fds1, e1), LetDef(fds2, e2)) => + fds1.size == fds2.size && + { + val zipped = fds1.zip(fds2) + (zipped mergeall (fds => fdHomo(fds._1, fds._2))) && Some(zipped.map(fds => fds._1.id -> fds._2.id).toMap) && + isHomo(e1, e2) + } + + case (MatchExpr(s1, cs1), MatchExpr(s2, cs2)) => + cs1.size == cs2.size && casesMatch(cs1,cs2) && isHomo(s1, s2) + + case (Passes(in1, out1, cs1), Passes(in2, out2, cs2)) => + (cs1.size == cs2.size && casesMatch(cs1,cs2)) && isHomo(in1,in2) && isHomo(out1,out2) + + case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) => + idHomo(tfd1.fd.id, tfd2.fd.id) && tfd1.tps.zip(tfd2.tps).mergeall{ case (t1, t2) => if(t1 == t2) Option(Map()) else None} && + (args1 zip args2).mergeall{ case (a1, a2) => isHomo(a1, a2) } + + case (Terminating(tfd1, args1), Terminating(tfd2, args2)) => + idHomo(tfd1.fd.id, tfd2.fd.id) && tfd1.tps.zip(tfd2.tps).mergeall{ case (t1, t2) => if(t1 == t2) Option(Map()) else None} && + (args1 zip args2).mergeall{ case (a1, a2) => isHomo(a1, a2) } + + case (Lambda(defs, body), Lambda(defs2, body2)) => + // We remove variables introduced by lambdas. + (isHomo(body, body2) && + (defs zip defs2).mergeall{ case (ValDef(a1), ValDef(a2)) => Option(Map(a1 -> a2)) } + ) -- (defs.map(_.id)) + + case (v1, v2) if isValue(v1) && isValue(v2) => + v1 == v2 && Some(Map[Identifier, Identifier]()) + + case Same(Operator(es1, _), Operator(es2, _)) => + (es1.size == es2.size) && + (es1 zip es2).mergeall{ case (e1, e2) => isHomo(e1, e2) } + + case _ => + None + } + + res + } + + isHomo(t1,t2) + + + } // ensuring (res => res.isEmpty || isHomomorphic(t1, t2)(res.get)) /** Checks whether two trees are homomoprhic modulo an identifier map. * @@ -1667,9 +1587,10 @@ object ExprOps { fdHomo(tfd1.fd, tfd2.fd) && (args1 zip args2).forall{ case (a1, a2) => isHomo(a1, a2) } - // TODO: Seems a lot is missing, like Literals + case (v1, v2) if isValue(v1) && isValue(v2) => + v1 == v2 - case Same(Operator(es1, _), Operator(es2, _)) => + case Same(Deconstructor(es1, _), Deconstructor(es2, _)) => (es1.size == es2.size) && (es1 zip es2).forall{ case (e1, e2) => isHomo(e1, e2) } @@ -2008,7 +1929,7 @@ object ExprOps { f(e, initParent) - val Operator(es, _) = e + val Deconstructor(es, _) = e es foreach rec } @@ -2071,8 +1992,8 @@ object ExprOps { Let(i, e, apply(b, args)) case LetTuple(is, es, b) => letTuple(is, es, apply(b, args)) - case l@Lambda(params, body) => - l.withParamSubst(args, body) + //case l @ Lambda(params, body) => + // l.withParamSubst(args, body) case _ => Application(expr, args) } @@ -2094,14 +2015,14 @@ object ExprOps { case Application(caller, args) => val newArgs = args.map(rec(_, true)) val newCaller = rec(caller, false) - extract(application(newCaller, newArgs), build) + extract(Application(newCaller, newArgs), build) case FunctionInvocation(fd, args) => val newArgs = args.map(rec(_, true)) extract(FunctionInvocation(fd, newArgs), build) case l @ Lambda(args, body) => val newBody = rec(body, true) extract(Lambda(args, newBody), build) - case Operator(es, recons) => recons(es.map(rec(_, build))) + case Deconstructor(es, recons) => recons(es.map(rec(_, build))) } rec(lift(expr), true) @@ -2129,7 +2050,7 @@ object ExprOps { fds ++= nfds - Some(LetDef(nfds.map(_._2), b)) + Some(letDef(nfds.map(_._2), b)) case FunctionInvocation(tfd, args) => if (fds contains tfd.fd) { @@ -2239,4 +2160,42 @@ object ExprOps { fun } + /** Returns true if expr is a value of type t */ + def isValueOfType(e: Expr, t: TypeTree): Boolean = { + (e, t) match { + case (StringLiteral(_), StringType) => true + case (IntLiteral(_), Int32Type) => true + case (InfiniteIntegerLiteral(_), IntegerType) => true + case (CharLiteral(_), CharType) => true + case (FractionalLiteral(_, _), RealType) => true + case (BooleanLiteral(_), BooleanType) => true + case (UnitLiteral(), UnitType) => true + case (GenericValue(t, _), tp) => t == tp + case (Tuple(elems), TupleType(bases)) => + elems zip bases forall (eb => isValueOfType(eb._1, eb._2)) + case (FiniteSet(elems, tbase), SetType(base)) => + tbase == base && + (elems forall isValue) + case (FiniteMap(elems, tk, tv), MapType(from, to)) => + tk == from && tv == to && + (elems forall (kv => isValueOfType(kv._1, from) && isValueOfType(kv._2, to) )) + case (NonemptyArray(elems, defaultValues), ArrayType(base)) => + elems.values forall (x => isValueOfType(x, base)) + case (EmptyArray(tpe), ArrayType(base)) => + tpe == base + case (CaseClass(ct, args), ct2@AbstractClassType(classDef, tps)) => + TypeOps.isSubtypeOf(ct, ct2) && + ((args zip ct.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2))) + case (CaseClass(ct, args), ct2@CaseClassType(classDef, tps)) => + ct == ct2 && + ((args zip ct.fieldsTypes) forall (argstyped => isValueOfType(argstyped._1, argstyped._2))) + case (Lambda(valdefs, body), FunctionType(ins, out)) => + (valdefs zip ins forall (vdin => vdin._1.getType == vdin._2)) && + body.getType == out + case _ => false + } + } + + /** Returns true if expr is a value. Stronger than isGround */ + val isValue = (e: Expr) => isValueOfType(e, e.getType) } diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index 2a714abaf6143eac9bb1d86cb762d87fec150176..dba3903654592da7000912b43bd1a5b6a5d55ed5 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -76,10 +76,6 @@ object Expressions { val getType = tpe } - case class Old(id: Identifier) extends Expr with Terminal { - val getType = id.getType - } - /** Precondition of an [[Expressions.Expr]]. Corresponds to the Leon keyword *require* * * @param pred The precondition formula inside ``require(...)`` @@ -165,7 +161,7 @@ object Expressions { * @param body The body of the expression after the function */ case class LetDef(fds: Seq[FunDef], body: Expr) extends Expr { - assert(fds.nonEmpty) + require(fds.nonEmpty) val getType = body.getType } @@ -231,7 +227,7 @@ object Expressions { } } - case class PartialLambda(mapping: Seq[(Seq[Expr], Expr)], default: Option[Expr], tpe: FunctionType) extends Expr { + case class FiniteLambda(mapping: Seq[(Seq[Expr], Expr)], default: Expr, tpe: FunctionType) extends Expr { val getType = tpe } @@ -275,7 +271,7 @@ object Expressions { * @param rhs The expression to the right of `=>` * @see [[Expressions.MatchExpr]] */ - case class MatchCase(pattern : Pattern, optGuard : Option[Expr], rhs: Expr) extends Tree { + case class MatchCase(pattern: Pattern, optGuard: Option[Expr], rhs: Expr) extends Tree { def expressions: Seq[Expr] = optGuard.toList :+ rhs } @@ -364,6 +360,23 @@ object Expressions { someValue.id ) } + + // Extracts without taking care of the binder. (contrary to Extractos.Pattern) + object PatternExtractor extends SubTreeOps.Extractor[Pattern] { + def unapply(e: Pattern): Option[(Seq[Pattern], (Seq[Pattern]) => Pattern)] = e match { + case (_: InstanceOfPattern) | (_: WildcardPattern) | (_: LiteralPattern[_]) => + Some(Seq(), es => e) + case CaseClassPattern(binder, ct, subpatterns) => + Some(subpatterns, es => CaseClassPattern(binder, ct, es)) + case TuplePattern(binder, subpatterns) => + Some(subpatterns, es => TuplePattern(binder, es)) + case UnapplyPattern(binder, unapplyFun, subpatterns) => + Some(subpatterns, es => UnapplyPattern(binder, unapplyFun, es)) + case _ => None + } + } + + object PatternOps extends { val Deconstructor = PatternExtractor } with SubTreeOps[Pattern] /** Symbolic I/O examples as a match/case. * $encodingof `out == (in match { cases; case _ => out })` @@ -579,7 +592,10 @@ object Expressions { /** $encodingof `lhs.subString(start, end)` for strings */ case class SubString(expr: Expr, start: Expr, end: Expr) extends Expr { val getType = { - if (expr.getType == StringType && (start == IntegerType || start == Int32Type) && (end == IntegerType || end == Int32Type)) StringType + val ext = expr.getType + val st = start.getType + val et = end.getType + if (ext == StringType && (st == IntegerType || st == Int32Type) && (et == IntegerType || et == Int32Type)) StringType else Untyped } } @@ -771,7 +787,7 @@ object Expressions { * * [[exprs]] should always contain at least 2 elements. * If you are not sure about this requirement, you should use - * [[purescala.Constructors#tupleWrap purescala's constructor tupleWrap]] + * [[leon.purescala.Constructors#tupleWrap purescala's constructor tupleWrap]] * * @param exprs The expressions in the tuple */ @@ -784,7 +800,7 @@ object Expressions { * * Index is 1-based, first element of tuple is 1. * If you are not sure that [[tuple]] is indeed of a TupleType, - * you should use [[purescala.Constructors$.tupleSelect(t:leon\.purescala\.Expressions\.Expr,index:Int,isTuple:Boolean):leon\.purescala\.Expressions\.Expr* purescala's constructor tupleSelect]] + * you should use [[leon.purescala.Constructors.tupleSelect(t:leon\.purescala\.Expressions\.Expr,index:Int,isTuple:Boolean):leon\.purescala\.Expressions\.Expr* purescala's constructor tupleSelect]] */ case class TupleSelect(tuple: Expr, index: Int) extends Expr { require(index >= 1) diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index e2581dd8cdb33e3d025e8206d72dd32fc4ca59f7..49e6afd3aeea07ad6ec48813d8ced497374b79c3 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -7,12 +7,11 @@ import Expressions._ import Common._ import Types._ import Constructors._ -import ExprOps._ -import Definitions.Program +import Definitions.{Program, AbstractClassDef, CaseClassDef} object Extractors { - object Operator { + object Operator extends SubTreeOps.Extractor[Expr] { def unapply(expr: Expr): Option[(Seq[Expr], (Seq[Expr]) => Expr)] = expr match { /* Unary operators */ case Not(t) => @@ -53,7 +52,7 @@ object Extractors { Some((Seq(a), (es: Seq[Expr]) => ArrayLength(es.head))) case Lambda(args, body) => Some((Seq(body), (es: Seq[Expr]) => Lambda(args, es.head))) - case PartialLambda(mapping, dflt, tpe) => + case FiniteLambda(mapping, dflt, tpe) => val sze = tpe.from.size + 1 val subArgs = mapping.flatMap { case (args, v) => args :+ v } val builder = (as: Seq[Expr]) => { @@ -64,10 +63,9 @@ object Extractors { case Seq() => Seq.empty case _ => sys.error("unexpected number of key/value expressions") } - val (nas, nd) = if (dflt.isDefined) (as.init, Some(as.last)) else (as, None) - PartialLambda(rec(nas), nd, tpe) + FiniteLambda(rec(as.init), as.last, tpe) } - Some((subArgs ++ dflt, builder)) + Some((subArgs :+ dflt, builder)) case Forall(args, body) => Some((Seq(body), (es: Seq[Expr]) => Forall(args, es.head))) @@ -147,7 +145,7 @@ object Extractors { Some(Seq(t1, t2), (es: Seq[Expr]) => SetUnion(es(0), es(1))) case SetDifference(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => SetDifference(es(0), es(1))) - case mg@MapApply(t1, t2) => + case mg @ MapApply(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => MapApply(es(0), es(1))) case MapUnion(t1, t2) => Some(Seq(t1, t2), (es: Seq[Expr]) => MapUnion(es(0), es(1))) @@ -167,9 +165,9 @@ object Extractors { Some(Seq(const, body), (es: Seq[Expr]) => Assert(es(0), oerr, es(1))) /* Other operators */ - case fi@FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _))) - case mi@MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail))) - case fa@Application(caller, args) => Some(caller +: args, as => application(as.head, as.tail)) + case fi @ FunctionInvocation(fd, args) => Some((args, FunctionInvocation(fd, _))) + case mi @ MethodInvocation(rec, cd, tfd, args) => Some((rec +: args, as => MethodInvocation(as.head, cd, tfd, as.tail))) + case fa @ Application(caller, args) => Some(caller +: args, as => Application(as.head, as.tail)) case CaseClass(cd, args) => Some((args, CaseClass(cd, _))) case And(args) => Some((args, and)) case Or(args) => Some((args, or)) @@ -199,7 +197,7 @@ object Extractors { val l = as.length nonemptyArray(as.take(l - 2), Some((as(l - 2), as(l - 1)))) })) - case na@NonemptyArray(elems, None) => + case na @ NonemptyArray(elems, None) => val ArrayType(tpe) = na.getType val (indexes, elsOrdered) = elems.toSeq.unzip @@ -250,6 +248,8 @@ object Extractors { None } } + + // Extractors for types are available at Types.NAryType trait Extractable { def extract: Option[(Seq[Expr], Seq[Expr] => Expr)] @@ -367,7 +367,7 @@ object Extractors { def unapply(me : MatchExpr) : Option[(Pattern, Expr, Expr)] = { Option(me) collect { - case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, variablesOf(scrut)) => + case MatchExpr(scrut, List(SimpleCase(pattern, body))) if !aliased(pattern.binders, ExprOps.variablesOf(scrut)) => ( pattern, scrut, body ) } } diff --git a/src/main/scala/leon/purescala/MethodLifting.scala b/src/main/scala/leon/purescala/MethodLifting.scala index 2014739eaa3fba0c84ce10685087262e7e227872..0e9e11d171a15e2b0b5fb77d0b3f0e65d44be38b 100644 --- a/src/main/scala/leon/purescala/MethodLifting.scala +++ b/src/main/scala/leon/purescala/MethodLifting.scala @@ -251,6 +251,10 @@ object MethodLifting extends TransformationPhase { ) } + if (cd.methods.exists(md => md.id == fd.id && md.isInvariant)) { + cd.setInvariant(nfd) + } + mdToFds += fd -> nfd fdsOf += cd.id.name -> (fdsOf.getOrElse(cd.id.name, Set()) + nfd) } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 039bef507339294874ad59e7e38dfe335b8f6a2f..20502a13640e1ab3fd5820a9bcfc38de46ece0f5 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -81,15 +81,12 @@ class PrettyPrinter(opts: PrinterOptions, } p"$name" - case Old(id) => - p"old($id)" - case Variable(id) => p"$id" case Let(b,d,e) => - p"""|val $b = $d - |$e""" + p"""|val $b = $d + |$e""" case LetDef(a::q,body) => p"""|$a @@ -117,11 +114,17 @@ class PrettyPrinter(opts: PrinterOptions, |}""" case p@Passes(in, out, tests) => - optP { - p"""|($in, $out) passes { - | ${nary(tests, "\n")} - |}""" + tests match { + case Seq(MatchCase(_, Some(BooleanLiteral(false)), NoTree(_))) => + p"""|byExample($in, $out)""" + case _ => + optP { + p"""|($in, $out) passes { + | ${nary(tests, "\n")} + |}""" + } } + case c @ WithOracle(vars, pred) => p"""|withOracle { (${typed(vars)}) => @@ -273,19 +276,15 @@ class PrettyPrinter(opts: PrinterOptions, case Lambda(args, body) => optP { p"($args) => $body" } - case PartialLambda(mapping, dflt, _) => + case FiniteLambda(mapping, dflt, _) => optP { def pm(p: (Seq[Expr], Expr)): PrinterHelpers.Printable = (pctx: PrinterContext) => p"${purescala.Constructors.tupleWrap(p._1)} => ${p._2}"(pctx) if (mapping.isEmpty) { - p"{}" + p"{ * => ${dflt} }" } else { - p"{ ${nary(mapping map pm)} }" - } - - if (dflt.isDefined) { - p" getOrElse ${dflt.get}" + p"{ ${nary(mapping map pm)}, * => ${dflt} }" } } diff --git a/src/main/scala/leon/purescala/Quantification.scala b/src/main/scala/leon/purescala/Quantification.scala index bb5115baab042a1bd524fdb2f73deeabefa59243..21608b0facdb2d2318320a0e58974d23f56dc491 100644 --- a/src/main/scala/leon/purescala/Quantification.scala +++ b/src/main/scala/leon/purescala/Quantification.scala @@ -49,7 +49,7 @@ object Quantification { res.filter(ms => ms.forall(m => reverseMap(m) subsetOf ms)) } - def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[(Set[(Expr, Expr, Seq[Expr])], Set[(Expr, Expr, Seq[Expr])])] = { + def extractQuorums(expr: Expr, quantified: Set[Identifier]): Seq[Set[(Expr, Expr, Seq[Expr])]] = { object QMatcher { def unapply(e: Expr): Option[(Expr, Seq[Expr])] = e match { case QuantificationMatcher(expr, args) => @@ -65,49 +65,19 @@ object Quantification { val allMatchers = CollectorWithPaths { case QMatcher(expr, args) => expr -> args }.traverse(expr) val matchers = allMatchers.map { case ((caller, args), path) => (path, caller, args) }.toSet - val quorums = extractQuorums(matchers, quantified, + extractQuorums(matchers, quantified, (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case QMatcher(e, a) => (p._1, e, a) }.toSet, (p: (Expr, Expr, Seq[Expr])) => p._3.collect { case Variable(id) if quantified(id) => id }.toSet) - - quorums.map(quorum => quorum -> matchers.filter(m => !quorum(m))) - } - - def extractModel( - asMap: Map[Identifier, Expr], - funDomains: Map[Identifier, Set[Seq[Expr]]], - typeDomains: Map[TypeTree, Set[Seq[Expr]]], - evaluator: DeterministicEvaluator - ): Map[Identifier, Expr] = asMap.map { case (id, expr) => - id -> (funDomains.get(id) match { - case Some(domain) => - PartialLambda(domain.toSeq.map { es => - val optEv = evaluator.eval(Application(expr, es)).result - es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + Application(expr, es))) - }, None, id.getType.asInstanceOf[FunctionType]) - - case None => postMap { - case p @ PartialLambda(mapping, dflt, tpe) => - Some(PartialLambda(typeDomains.get(tpe) match { - case Some(domain) => domain.toSeq.map { es => - val optEv = evaluator.eval(Application(p, es)).result - es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + Application(p, es))) - } - case _ => Seq.empty - }, None, tpe)) - case _ => None - } (expr) - }) } - object HenkinDomains { - def empty = new HenkinDomains(Map.empty, Map.empty) + object Domains { + def empty = new Domains(Map.empty, Map.empty) } - class HenkinDomains (val lambdas: Map[Lambda, Set[Seq[Expr]]], val tpes: Map[TypeTree, Set[Seq[Expr]]]) { + class Domains (val lambdas: Map[Lambda, Set[Seq[Expr]]], val tpes: Map[TypeTree, Set[Seq[Expr]]]) { def get(e: Expr): Set[Seq[Expr]] = { val specialized: Set[Seq[Expr]] = e match { - case PartialLambda(_, Some(dflt), _) => scala.sys.error("No domain for non-partial lambdas") - case PartialLambda(mapping, _, _) => mapping.map(_._1).toSet + case FiniteLambda(mapping, _, _) => mapping.map(_._1).toSet case l: Lambda => lambdas.getOrElse(l, Set.empty) case _ => Set.empty } @@ -117,7 +87,7 @@ object Quantification { object QuantificationMatcher { private def flatApplication(expr: Expr): Option[(Expr, Seq[Expr])] = expr match { - case Application(fi: FunctionInvocation, args) => Some((fi, args)) + case Application(fi: FunctionInvocation, args) => None case Application(caller: Application, args) => flatApplication(caller) match { case Some((c, prevArgs)) => Some((c, prevArgs ++ args)) case None => None diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 67cc994649e6b454ee8389fbfdd7674da4aebb6a..11f0c187e144873c386a01702326056e636e1225 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -9,14 +9,12 @@ import Common._ import Expressions._ import Types._ import Definitions._ -import org.apache.commons.lang3.StringEscapeUtils -/** This pretty-printer only print valid scala syntax */ +/** This pretty-printer only prints valid scala syntax */ class ScalaPrinter(opts: PrinterOptions, opgm: Option[Program], sb: StringBuffer = new StringBuffer) extends PrettyPrinter(opts, opgm, sb) { - private val dbquote = "\"" override def pp(tree: Tree)(implicit ctx: PrinterContext): Unit = { tree match { diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index f0ff379ffd6ec0419d631777c623098b21838a57..e06055dc4d9dfce8b24d2d8ddb698ebbbc781079 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -3,18 +3,24 @@ package leon package purescala +import collection.mutable.ListBuffer import Common._ import Definitions._ import Expressions._ import Extractors._ +import Constructors.letDef class ScopeSimplifier extends Transformer { case class Scope(inScope: Set[Identifier] = Set(), oldToNew: Map[Identifier, Identifier] = Map(), funDefs: Map[FunDef, FunDef] = Map()) { - + def register(oldNew: (Identifier, Identifier)): Scope = { val newId = oldNew._2 copy(inScope = inScope + newId, oldToNew = oldToNew + oldNew) } + + def register(oldNews: Seq[(Identifier, Identifier)]): Scope = { + (this /: oldNews){ case (oldScope, oldNew) => oldScope.register(oldNew) } + } def registerFunDef(oldNew: (FunDef, FunDef)): Scope = { copy(funDefs = funDefs + oldNew) @@ -44,22 +50,23 @@ class ScopeSimplifier extends Transformer { } val fds_mapping = for((fd, newId) <- fds_newIds) yield { + val localScopeToRegister = ListBuffer[(Identifier, Identifier)]() // We record the mapping of these variables only for the function. val newArgs = for(ValDef(id) <- fd.params) yield { - val newArg = genId(id, newScope) - newScope = newScope.register(id -> newArg) + val newArg = genId(id, newScope.register(localScopeToRegister)) + localScopeToRegister += (id -> newArg) // This renaming happens only inside the function. ValDef(newArg) } val newFd = fd.duplicate(id = newId, params = newArgs) newScope = newScope.registerFunDef(fd -> newFd) - (newFd, fd) + (newFd, localScopeToRegister, fd) } - for((newFd, fd) <- fds_mapping) { - newFd.fullBody = rec(fd.fullBody, newScope) + for((newFd, localScopeToRegister, fd) <- fds_mapping) { + newFd.fullBody = rec(fd.fullBody, newScope.register(localScopeToRegister)) } - LetDef(fds_mapping.map(_._1), rec(body, newScope)) + letDef(fds_mapping.map(_._1), rec(body, newScope)) case MatchExpr(scrut, cases) => val rs = rec(scrut, scope) diff --git a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala index aa4c204d00187ab211a7164a45ae41d8d77d8f8f..d1cc5f86c6659b2fca51d42b3ebbb0feab20d084 100644 --- a/src/main/scala/leon/purescala/SelfPrettyPrinter.scala +++ b/src/main/scala/leon/purescala/SelfPrettyPrinter.scala @@ -1,58 +1,126 @@ package leon.purescala -import leon.evaluators.StringTracingEvaluator import leon.purescala +import leon.solvers.{ Model, SolverFactory } +import leon.LeonContext +import leon.evaluators +import leon.utils.StreamUtils +import leon.purescala.Quantification._ +import leon.utils.DebugSectionSynthesis +import leon.utils.DebugSectionVerification import purescala.Definitions.Program -import leon.evaluators.StringTracingEvaluator import purescala.Expressions._ import purescala.Types.StringType -import leon.utils.DebugSectionSynthesis -import leon.utils.DebugSectionVerification -import leon.purescala.Quantification._ import purescala.Constructors._ import purescala.ExprOps._ -import purescala.Expressions.{Pattern, Expr} +import purescala.Expressions._ +import purescala.Expressions.{Choose } import purescala.Extractors._ import purescala.TypeOps._ import purescala.Types._ import purescala.Common._ -import purescala.Expressions._ import purescala.Definitions._ -import leon.solvers.{ HenkinModel, Model, SolverFactory } -import leon.LeonContext -import leon.evaluators +import scala.collection.mutable.ListBuffer +import leon.evaluators.DefaultEvaluator + +object SelfPrettyPrinter { + def prettyPrintersForType(inputType: TypeTree)(implicit ctx: LeonContext, program: Program): Stream[Lambda] = { + (new SelfPrettyPrinter).prettyPrintersForType(inputType) + } + def print(v: Expr, orElse: =>String, excluded: Set[FunDef] = Set())(implicit ctx: LeonContext, program: Program): String = { + (new SelfPrettyPrinter).print(v, orElse, excluded) + } +} /** This pretty-printer uses functions defined in Leon itself. * If not pretty printing function is defined, return the default value instead * @param The list of functions which should be excluded from pretty-printing (to avoid rendering counter-examples of toString methods using the method itself) * @return a user defined string for the given typed expression. */ -object SelfPrettyPrinter { - def print(v: Expr, orElse: =>String, excluded: FunDef => Boolean = Set())(implicit ctx: LeonContext, program: Program): String = { - (program.definedFunctions find { - case fd if !excluded(fd) => - fd.returnType == StringType && fd.params.length == 1 && TypeOps.isSubtypeOf(v.getType, fd.params.head.getType) && fd.id.name.toLowerCase().endsWith("tostring") && - program.callGraph.transitiveCallees(fd).forall { fde => - !purescala.ExprOps.exists( _.isInstanceOf[Choose])(fde.fullBody) +class SelfPrettyPrinter { + implicit val section = leon.utils.DebugSectionEvaluation + private var allowedFunctions = Set[FunDef]() + private var excluded = Set[FunDef]() + /** Functions whose name does not need to end with `tostring` or which can be abstract, i.e. which may contain a choose construct.*/ + def allowFunction(fd: FunDef) = { allowedFunctions += fd; this } + + def excludeFunctions(fds: Set[FunDef]) = { excluded ++= fds; this } + def excludeFunction(fd: FunDef) = { excluded += fd; this } + + /** Returns a list of possible lambdas that can transform the input type to a String. + * At this point, it does not consider yet the inputType. Only [[prettyPrinterFromCandidate]] will consider it. */ + def prettyPrintersForType(inputType: TypeTree/*, existingPp: Map[TypeTree, List[Lambda]] = Map()*/)(implicit ctx: LeonContext, program: Program): Stream[Lambda] = { + program.definedFunctions.toStream flatMap { + fd => + val isCandidate = fd.returnType == StringType && + fd.params.length >= 1 && + !excluded(fd) && + (allowedFunctions(fd) || ( + fd.id.name.toLowerCase().endsWith("tostring"))) + if(isCandidate) { + prettyPrinterFromCandidate(fd, inputType) + } else Stream.Empty + } + } + + + def prettyPrinterFromCandidate(fd: FunDef, inputType: TypeTree)(implicit ctx: LeonContext, program: Program): Stream[Lambda] = { + TypeOps.canBeSubtypeOf(inputType, fd.tparams.map(_.tp), fd.params.head.getType) match { + case Some(genericTypeMap) => + val defGenericTypeMap = genericTypeMap.map{ case (k, v) => (Definitions.TypeParameterDef(k), v) } + def gatherPrettyPrinters(funIds: List[Identifier], acc: ListBuffer[Stream[Lambda]] = ListBuffer()): Option[Stream[List[Lambda]]] = funIds match { + case Nil => Some(StreamUtils.cartesianProduct(acc.toList)) + case funId::tail => // For each function, find an expression which could be provided if it exists. + funId.getType match { + case FunctionType(Seq(in), StringType) => // Should have one argument. + val candidates = prettyPrintersForType(in) + gatherPrettyPrinters(tail, acc += candidates) + case _ => None + } + } + val funIds = fd.params.tail.map(x => TypeOps.instantiateType(x.id, defGenericTypeMap)).toList + gatherPrettyPrinters(funIds) match { + case Some(l) => for(lambdas <- l) yield { + val x = FreshIdentifier("x", inputType) // verify the type + Lambda(Seq(ValDef(x)), functionInvocation(fd, Variable(x)::lambdas)) + } + case _ => Stream.empty } + case None => Stream.empty + } + } + + + /** Actually prints the expression with as alternative the given orElse */ + def print(v: Expr, orElse: =>String, excluded: Set[FunDef] = Set())(implicit ctx: LeonContext, program: Program): String = { + this.excluded = excluded + val s = prettyPrintersForType(v.getType) // TODO: Included the variable excluded if necessary. + s.take(100).find(l => l match { // Limit the number of pretty-printers. + case Lambda(_, FunctionInvocation(TypedFunDef(fd, _), _)) => + (program.callGraph.transitiveCallees(fd) + fd).forall { fde => + !ExprOps.exists( _.isInstanceOf[Choose])(fde.fullBody) + } + case _ => false }) match { - case Some(fd) => - //println("Found fd: " + fd.id.name) - val ste = new StringTracingEvaluator(ctx, program) + case None => orElse + case Some(l) => + ctx.reporter.debug("Executing pretty printer for type " + v.getType + " : " + l + " on " + v) + val ste = new DefaultEvaluator(ctx, program) try { - val result = ste.eval(FunctionInvocation(fd.typed, Seq(v))) - - result.result match { - case Some((StringLiteral(res), _)) if res != "" => - res - case _ => - orElse - } + val toEvaluate = application(l, Seq(v)) + val result = ste.eval(toEvaluate) + + result.result match { + case Some(StringLiteral(res)) if res != "" => + res + case res => + ctx.reporter.debug("not a string literal " + res) + orElse + } } catch { case e: evaluators.ContextualEvaluator#EvalError => + ctx.reporter.debug("Error " + e.msg) orElse } - case None => - orElse } } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/purescala/SubTreeOps.scala b/src/main/scala/leon/purescala/SubTreeOps.scala new file mode 100644 index 0000000000000000000000000000000000000000..140bd5edc2ff5f316a7afb0d90442df12b78ced8 --- /dev/null +++ b/src/main/scala/leon/purescala/SubTreeOps.scala @@ -0,0 +1,327 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package purescala + +import Expressions.Expr +import Types.TypeTree +import Common._ +import utils._ + +object SubTreeOps { + trait Extractor[SubTree <: Tree] { + def unapply(e: SubTree): Option[(Seq[SubTree], (Seq[SubTree]) => SubTree)] + } +} +trait SubTreeOps[SubTree <: Tree] { + val Deconstructor: SubTreeOps.Extractor[SubTree] + + /* ======== + * Core API + * ======== + * + * All these functions should be stable, tested, and used everywhere. Modify + * with care. + */ + + /** Does a right tree fold + * + * A right tree fold applies the input function to the subnodes first (from left + * to right), and combine the results along with the current node value. + * + * @param f a function that takes the current node and the seq + * of results form the subtrees. + * @param e The value on which to apply the fold. + * @return The expression after applying `f` on all subtrees. + * @note the computation is lazy, hence you should not rely on side-effects of `f` + */ + def fold[T](f: (SubTree, Seq[T]) => T)(e: SubTree): T = { + val rec = fold(f) _ + val Deconstructor(es, _) = e + + //Usages of views makes the computation lazy. (which is useful for + //contains-like operations) + f(e, es.view.map(rec)) + } + + + /** Pre-traversal of the tree. + * + * Invokes the input function on every node '''before''' visiting + * children. Traverse children from left to right subtrees. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) + * }}} + * will yield, in order: + * {{{ + * f(Add(a, Minus(b, c))); f(a); f(Minus(b, c)); f(b); f(c) + * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse + */ + def preTraversal(f: SubTree => Unit)(e: SubTree): Unit = { + val rec = preTraversal(f) _ + val Deconstructor(es, _) = e + f(e) + es.foreach(rec) + } + + /** Post-traversal of the tree. + * + * Invokes the input function on every node '''after''' visiting + * children. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) + * }}} + * will yield, in order: + * {{{ + * f(a), f(b), f(c), f(Minus(b, c)), f(Add(a, Minus(b, c))) + * }}} + * + * @param f a function to apply on each node of the expression + * @param e the expression to traverse + */ + def postTraversal(f: SubTree => Unit)(e: SubTree): Unit = { + val rec = postTraversal(f) _ + val Deconstructor(es, _) = e + es.foreach(rec) + f(e) + } + + /** Pre-transformation of the tree. + * + * Takes a partial function of replacements and substitute + * '''before''' recursing down the trees. + * + * Supports two modes : + * + * - If applyRec is false (default), will only substitute once on each level. + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, d) // And not Add(a, f) because it only substitute once for each level. + * }}} + * + * - If applyRec is true, it will substitute multiple times on each level: + * + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, f) + * }}} + * + * @note The mode with applyRec true can diverge if f is not well formed + */ + def preMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = { + val rec = preMap(f, applyRec) _ + + val newV = if (applyRec) { + // Apply f as long as it returns Some() + fixpoint { e : SubTree => f(e) getOrElse e } (e) + } else { + f(e) getOrElse e + } + + val Deconstructor(es, builder) = newV + val newEs = es.map(rec) + + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(newV) + } else { + newV + } + } + + + /** Post-transformation of the tree. + * + * Takes a partial function of replacements. + * Substitutes '''after''' recursing down the trees. + * + * Supports two modes : + * + * - If applyRec is false (default), will only substitute once on each level. + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(b,c) -> z, Minus(e,c) -> d, b -> e + * }}} + * will yield: + * {{{ + * Add(a, Minus(e, c)) + * }}} + * + * - If applyRec is true, it will substitute multiple times on each level: + * e.g. + * {{{ + * Add(a, Minus(b, c)) with replacements: Minus(e,c) -> d, b -> e, d -> f + * }}} + * will yield: + * {{{ + * Add(a, f) + * }}} + * + * @note The mode with applyRec true can diverge if f is not well formed (i.e. not convergent) + */ + def postMap(f: SubTree => Option[SubTree], applyRec : Boolean = false)(e: SubTree): SubTree = { + val rec = postMap(f, applyRec) _ + + val Deconstructor(es, builder) = e + val newEs = es.map(rec) + val newV = { + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + } + + if (applyRec) { + // Apply f as long as it returns Some() + fixpoint { e : SubTree => f(e) getOrElse e } (newV) + } else { + f(newV) getOrElse newV + } + + } + + + /** Applies functions and combines results in a generic way + * + * Start with an initial value, and apply functions to nodes before + * and after the recursion in the children. Combine the results of + * all children and apply a final function on the resulting node. + * + * @param pre a function applied on a node before doing a recursion in the children + * @param post a function applied to the node built from the recursive application to + all children + * @param combiner a function to combine the resulting values from all children with + the current node + * @param init the initial value + * @param expr the expression on which to apply the transform + * + * @see [[simpleTransform]] + * @see [[simplePreTransform]] + * @see [[simplePostTransform]] + */ + def genericTransform[C](pre: (SubTree, C) => (SubTree, C), + post: (SubTree, C) => (SubTree, C), + combiner: (SubTree, Seq[C]) => C)(init: C)(expr: SubTree) = { + + def rec(eIn: SubTree, cIn: C): (SubTree, C) = { + + val (expr, ctx) = pre(eIn, cIn) + val Deconstructor(es, builder) = expr + val (newExpr, newC) = { + val (nes, cs) = es.map{ rec(_, ctx)}.unzip + val newE = builder(nes).copiedFrom(expr) + + (newE, combiner(newE, cs)) + } + + post(newExpr, newC) + } + + rec(expr, init) + } + + /** Pre-transformation of the tree, with a context value from "top-down". + * + * Takes a partial function of replacements. + * Substitutes '''before''' recursing down the trees. The function returns + * an option of the new value, as well as the new context to be used for + * the recursion in its children. The context is "lost" when going back up, + * so changes made by one node will not be see by its siblings. + */ + def preMapWithContext[C](f: (SubTree, C) => (Option[SubTree], C), applyRec: Boolean = false) + (e: SubTree, c: C): SubTree = { + + def rec(expr: SubTree, context: C): SubTree = { + + val (newV, newCtx) = { + if(applyRec) { + var ctx = context + val finalV = fixpoint{ e: SubTree => { + val res = f(e, ctx) + ctx = res._2 + res._1.getOrElse(e) + }} (expr) + (finalV, ctx) + } else { + val res = f(expr, context) + (res._1.getOrElse(expr), res._2) + } + } + + val Deconstructor(es, builder) = newV + val newEs = es.map(e => rec(e, newCtx)) + + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(newV) + } else { + newV + } + + } + + rec(e, c) + } + + /* + * ============= + * Auxiliary API + * ============= + * + * Convenient methods using the Core API. + */ + + /** Checks if the predicate holds in some sub-expression */ + def exists(matcher: SubTree => Boolean)(e: SubTree): Boolean = { + fold[Boolean]({ (e, subs) => matcher(e) || subs.contains(true) } )(e) + } + + /** Collects a set of objects from all sub-expressions */ + def collect[T](matcher: SubTree => Set[T])(e: SubTree): Set[T] = { + fold[Set[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) + } + + def collectPreorder[T](matcher: SubTree => Seq[T])(e: SubTree): Seq[T] = { + fold[Seq[T]]({ (e, subs) => matcher(e) ++ subs.flatten } )(e) + } + + /** Returns a set of all sub-expressions matching the predicate */ + def filter(matcher: SubTree => Boolean)(e: SubTree): Set[SubTree] = { + collect[SubTree] { e => Set(e) filter matcher }(e) + } + + /** Counts how many times the predicate holds in sub-expressions */ + def count(matcher: SubTree => Int)(e: SubTree): Int = { + fold[Int]({ (e, subs) => matcher(e) + subs.sum } )(e) + } + + /** Replaces bottom-up sub-expressions by looking up for them in a map */ + def replace(substs: Map[SubTree,SubTree], expr: SubTree) : SubTree = { + postMap(substs.lift)(expr) + } + + /** Replaces bottom-up sub-expressions by looking up for them in the provided order */ + def replaceSeq(substs: Seq[(SubTree, SubTree)], expr: SubTree): SubTree = { + var res = expr + for (s <- substs) { + res = replace(Map(s), res) + } + res + } + +} \ No newline at end of file diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index db655365c24a7304831e818849238bba0a849de9..fedac58ec483953a58c6cb9576ba0340da30b720 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -11,16 +11,22 @@ import Extractors._ import Constructors._ import ExprOps.preMap -object TypeOps { +object TypeOps extends { val Deconstructor = NAryType } with SubTreeOps[TypeTree] { def typeDepth(t: TypeTree): Int = t match { - case NAryType(tps, builder) => 1+ (0 +: (tps map typeDepth)).max + case NAryType(tps, builder) => 1 + (0 +: (tps map typeDepth)).max } - def typeParamsOf(t: TypeTree): Set[TypeParameter] = t match { - case tp: TypeParameter => Set(tp) - case _ => - val NAryType(subs, _) = t - subs.flatMap(typeParamsOf).toSet + def typeParamsOf(t: TypeTree): Set[TypeParameter] = { + collect[TypeParameter]({ + case tp: TypeParameter => Set(tp) + case _ => Set.empty + })(t) + } + + def typeParamsOf(expr: Expr): Set[TypeParameter] = { + var tparams: Set[TypeParameter] = Set.empty + ExprOps.preTraversal(e => typeParamsOf(e.getType))(expr) + tparams } def canBeSubtypeOf( @@ -127,9 +133,43 @@ object TypeOps { if (args.forall(_.isDefined)) Some(TupleType(args.map(_.get))) else None case (FunctionType(from1, to1), FunctionType(from2, to2)) => - // TODO: make functions contravariant to arg. types - if (from1 == from2) { - leastUpperBound(to1, to2) map { FunctionType(from1, _) } + val args = (from1 zip from2).map(p => greatestLowerBound(p._1, p._2)) + if (args.forall(_.isDefined)) { + leastUpperBound(to1, to2) map { FunctionType(args.map(_.get), _) } + } else { + None + } + + case (o1, o2) if o1 == o2 => Some(o1) + case _ => None + } + + def greatestLowerBound(t1: TypeTree, t2: TypeTree): Option[TypeTree] = (t1,t2) match { + case (c1: ClassType, c2: ClassType) => + + def computeChains(ct: ClassType): Set[ClassType] = ct.parent match { + case Some(pct) => + computeChains(pct) + ct + case None => + Set(ct) + } + + if (computeChains(c1)(c2)) { + Some(c2) + } else if (computeChains(c2)(c1)) { + Some(c1) + } else { + None + } + + case (TupleType(args1), TupleType(args2)) => + val args = (args1 zip args2).map(p => greatestLowerBound(p._1, p._2)) + if (args.forall(_.isDefined)) Some(TupleType(args.map(_.get))) else None + + case (FunctionType(from1, to1), FunctionType(from2, to2)) => + val args = (from1 zip from2).map(p => leastUpperBound(p._1, p._2)) + if (args.forall(_.isDefined)) { + greatestLowerBound(to1, to2).map { FunctionType(args.map(_.get), _) } } else { None } @@ -185,7 +225,7 @@ object TypeOps { id } } - + def instantiateType(id: Identifier, tps: Map[TypeParameterDef, TypeTree]): Identifier = { freshId(id, typeParamSubst(tps map { case (tpd, tp) => tpd.tp -> tp })(id.getType)) } @@ -313,7 +353,7 @@ object TypeOps { val returnType = tpeSub(fd.returnType) val params = fd.params map (vd => vd.copy(id = freshId(vd.id, tpeSub(vd.getType)))) val newFd = fd.duplicate(id, tparams, params, returnType) - val subCalls = preMap { + val subCalls = ExprOps.preMap { case fi @ FunctionInvocation(tfd, args) if tfd.fd == fd => Some(FunctionInvocation(newFd.typed(tfd.tps), args).copiedFrom(fi)) case _ => @@ -335,7 +375,7 @@ object TypeOps { } val newBd = srec(subCalls(bd)).copiedFrom(bd) - LetDef(newFds, newBd).copiedFrom(l) + letDef(newFds, newBd).copiedFrom(l) case l @ Lambda(args, body) => val newArgs = args.map { arg => @@ -381,6 +421,10 @@ object TypeOps { case m @ FiniteMap(elems, from, to) => FiniteMap(elems.map{ case (k, v) => (srec(k), srec(v)) }, tpeSub(from), tpeSub(to)).copiedFrom(m) + case f @ FiniteLambda(mapping, dflt, FunctionType(from, to)) => + FiniteLambda(mapping.map { case (ks, v) => ks.map(srec) -> srec(v) }, srec(dflt), + FunctionType(from.map(tpeSub), tpeSub(to))).copiedFrom(f) + case v @ Variable(id) if idsMap contains id => Variable(idsMap(id)).copiedFrom(v) diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index 3a0a85bb24045df18ab65b0afa488a34c8921315..6f2518d549be6b1d439ff9cb08594558d02449d5 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -114,6 +114,8 @@ object Types { } } + def invariant = classDef.invariant.map(_.typed(tps)) + def knownDescendants = classDef.knownDescendants.map( _.typed(tps) ) def knownCCDescendants: Seq[CaseClassType] = classDef.knownCCDescendants.map( _.typed(tps) ) @@ -128,12 +130,12 @@ object Types { case t => throw LeonFatalError("Unexpected translated parent type: "+t) } } - } + case class AbstractClassType(classDef: AbstractClassDef, tps: Seq[TypeTree]) extends ClassType case class CaseClassType(classDef: CaseClassDef, tps: Seq[TypeTree]) extends ClassType - object NAryType { + object NAryType extends SubTreeOps.Extractor[TypeTree] { def unapply(t: TypeTree): Option[(Seq[TypeTree], Seq[TypeTree] => TypeTree)] = t match { case CaseClassType(ccd, ts) => Some((ts, ts => CaseClassType(ccd, ts))) case AbstractClassType(acd, ts) => Some((ts, ts => AbstractClassType(acd, ts))) @@ -142,6 +144,7 @@ object Types { case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) + /* n-ary operators */ case t => Some(Nil, _ => t) } } diff --git a/src/main/scala/leon/repair/Repairman.scala b/src/main/scala/leon/repair/Repairman.scala index 37f187679897bbec908ffeaf18f5385198f915c9..9dcd782a3323138e5bfe4c68822fba614e2065e7 100644 --- a/src/main/scala/leon/repair/Repairman.scala +++ b/src/main/scala/leon/repair/Repairman.scala @@ -3,6 +3,7 @@ package leon package repair +import leon.datagen.GrammarDataGen import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ @@ -25,7 +26,6 @@ import synthesis.Witnesses._ import synthesis.graph.{dotGenIds, DotGenerator} import rules._ -import grammars._ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeoutMs: Option[Long], repairTimeoutMs: Option[Long]) { implicit val ctx = ctx0 @@ -155,7 +155,7 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou }(DebugSectionReport) if (synth.settings.generateDerivationTrees) { - val dot = new DotGenerator(search.g) + val dot = new DotGenerator(search) dot.writeFile("derivation"+ dotGenIds.nextGlobal + ".dot") } @@ -236,29 +236,10 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou def discoverTests(): ExamplesBank = { - import bonsai.enumerators._ - val maxEnumerated = 1000 val maxValid = 400 val evaluator = new CodeGenEvaluator(ctx, program, CodeGenParams.default) - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) - - val inputs = enum.iterator(tupleTypeWrap(fd.params map { _.getType})).map(unwrapTuple(_, fd.params.size)) - - val filtering: Seq[Expr] => Boolean = fd.precondition match { - case None => - _ => true - case Some(pre) => - val argIds = fd.paramIds - evaluator.compile(pre, argIds) match { - case Some(evalFun) => - val sat = EvaluationResults.Successful(BooleanLiteral(true)); - { (es: Seq[Expr]) => evalFun(new solvers.Model((argIds zip es).toMap)) == sat } - case None => - { _ => false } - } - } val inputsToExample: Seq[Expr] => Example = { ins => evaluator.eval(functionInvocation(fd, ins)) match { @@ -269,10 +250,10 @@ class Repairman(ctx0: LeonContext, initProgram: Program, fd: FunDef, verifTimeou } } - val generatedTests = inputs - .take(maxEnumerated) - .filter(filtering) - .take(maxValid) + val dataGen = new GrammarDataGen(evaluator) + + val generatedTests = dataGen + .generateFor(fd.paramIds, fd.precOrTrue, maxValid, maxEnumerated) .map(inputsToExample) .toList diff --git a/src/main/scala/leon/solvers/Model.scala b/src/main/scala/leon/solvers/Model.scala index 07bdee913f21605fbc41f660af608c492e5ee1b5..57f93655f5311655ffd1d479762dac839e750517 100644 --- a/src/main/scala/leon/solvers/Model.scala +++ b/src/main/scala/leon/solvers/Model.scala @@ -5,6 +5,7 @@ package solvers import purescala.Expressions._ import purescala.Common.Identifier +import purescala.Quantification.Domains import purescala.ExprOps._ trait AbstractModel[+This <: Model with AbstractModel[This]] @@ -40,8 +41,9 @@ trait AbstractModel[+This <: Model with AbstractModel[This]] "Model()" } else { (for ((k,v) <- mapping.toSeq.sortBy(_._1)) yield { - f" ${k.asString}%-20s -> ${v.asString}" - }).mkString("Model(\n", ",\n", ")") + val valuePadded = v.asString.replaceAll("\n", "\n"+(" "*26)) + f" ${k.asString}%-20s -> ${valuePadded}" + }).mkString("Model(\n", ",\n", "\n)") } } } @@ -68,6 +70,7 @@ class Model(protected val mapping: Map[Identifier, Expr]) def isDefinedAt(id: Identifier): Boolean = mapping.isDefinedAt(id) def get(id: Identifier): Option[Expr] = mapping.get(id) def getOrElse[E >: Expr](id: Identifier, e: E): E = get(id).getOrElse(e) + def ids = mapping.keys def apply(id: Identifier): Expr = get(id).getOrElse { throw new IllegalArgumentException } } @@ -78,3 +81,21 @@ object Model { class ModelBuilder extends AbstractModelBuilder[Model] { def result = new Model(mapBuilder.result) } + +class PartialModel(mapping: Map[Identifier, Expr], val domains: Domains) + extends Model(mapping) + with AbstractModel[PartialModel] { + + override def newBuilder = new PartialModelBuilder(domains) +} + +object PartialModel { + def empty = new PartialModel(Map.empty, Domains.empty) +} + +class PartialModelBuilder(domains: Domains) + extends ModelBuilder + with AbstractModelBuilder[PartialModel] { + + override def result = new PartialModel(mapBuilder.result, domains) +} diff --git a/src/main/scala/leon/solvers/QuantificationSolver.scala b/src/main/scala/leon/solvers/QuantificationSolver.scala deleted file mode 100644 index fa11ab6613bd65b196cce87ee062c3c56f0b95f9..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/QuantificationSolver.scala +++ /dev/null @@ -1,35 +0,0 @@ -package leon -package solvers - -import purescala.Common._ -import purescala.Expressions._ -import purescala.Quantification._ -import purescala.Definitions._ -import purescala.Types._ - -class HenkinModel(mapping: Map[Identifier, Expr], val doms: HenkinDomains) - extends Model(mapping) - with AbstractModel[HenkinModel] { - override def newBuilder = new HenkinModelBuilder(doms) - - def domain(expr: Expr) = doms.get(expr) -} - -object HenkinModel { - def empty = new HenkinModel(Map.empty, HenkinDomains.empty) -} - -class HenkinModelBuilder(domains: HenkinDomains) - extends ModelBuilder - with AbstractModelBuilder[HenkinModel] { - override def result = new HenkinModel(mapBuilder.result, domains) -} - -trait QuantificationSolver { - val program: Program - def getModel: HenkinModel - - protected lazy val requireQuantification = program.definedFunctions.exists { fd => - purescala.ExprOps.exists { case _: Forall => true case _ => false } (fd.fullBody) - } -} diff --git a/src/main/scala/leon/solvers/SolverFactory.scala b/src/main/scala/leon/solvers/SolverFactory.scala index 67d28f877019a3e4741df6d268c68fc2d86d17ee..4f8b1f50ae346887c945823ff429761be1a34a78 100644 --- a/src/main/scala/leon/solvers/SolverFactory.scala +++ b/src/main/scala/leon/solvers/SolverFactory.scala @@ -31,13 +31,13 @@ object SolverFactory { } val definedSolvers = Map( - "fairz3" -> "Native Z3 with z3-templates for unfolding (default)", + "fairz3" -> "Native Z3 with z3-templates for unrolling (default)", "smt-cvc4" -> "CVC4 through SMT-LIB", "smt-z3" -> "Z3 through SMT-LIB", "smt-z3-q" -> "Z3 through SMT-LIB, with quantified encoding", "smt-cvc4-proof" -> "CVC4 through SMT-LIB, in-solver inductive reasoning, for proofs only", "smt-cvc4-cex" -> "CVC4 through SMT-LIB, in-solver finite-model-finding, for counter-examples only", - "unrollz3" -> "Native Z3 with leon-templates for unfolding", + "unrollz3" -> "Native Z3 with leon-templates for unrolling", "ground" -> "Only solves ground verification conditions by evaluating them", "enum" -> "Enumeration-based counter-example-finder", "isabelle" -> "Isabelle2015 through libisabelle with various automated tactics" @@ -79,10 +79,12 @@ object SolverFactory { def getFromName(ctx: LeonContext, program: Program)(name: String): SolverFactory[TimeoutSolver] = name match { case "fairz3" => - SolverFactory(() => new FairZ3Solver(ctx, program) with TimeoutSolver) + // Previously: new FairZ3Solver(ctx, program) with TimeoutSolver + SolverFactory(() => new Z3StringFairZ3Solver(ctx, program) with TimeoutSolver) case "unrollz3" => - SolverFactory(() => new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) + // Previously: new UnrollingSolver(ctx, program, new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver + SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new UninterpretedZ3Solver(ctx, program)) with TimeoutSolver) case "enum" => SolverFactory(() => new EnumerationSolver(ctx, program) with TimeoutSolver) @@ -91,10 +93,12 @@ object SolverFactory { SolverFactory(() => new GroundSolver(ctx, program) with TimeoutSolver) case "smt-z3" => - SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) + // Previously: new UnrollingSolver(ctx, program, new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver + SolverFactory(() => new Z3StringUnrollingSolver(ctx, program, (program: Program) => new SMTLIBZ3Solver(ctx, program)) with TimeoutSolver) case "smt-z3-q" => - SolverFactory(() => new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) + // Previously: new SMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver + SolverFactory(() => new Z3StringSMTLIBZ3QuantifiedSolver(ctx, program) with TimeoutSolver) case "smt-cvc4" => SolverFactory(() => new UnrollingSolver(ctx, program, new SMTLIBCVC4Solver(ctx, program)) with TimeoutSolver) diff --git a/src/main/scala/leon/solvers/SolverUnsupportedError.scala b/src/main/scala/leon/solvers/SolverUnsupportedError.scala index 5d519160d7aed9fce7a42584c8d53806e53e265a..2efc8ea39b0da8494b2cd1309b3dcf9c2ca9cec3 100644 --- a/src/main/scala/leon/solvers/SolverUnsupportedError.scala +++ b/src/main/scala/leon/solvers/SolverUnsupportedError.scala @@ -7,7 +7,7 @@ import purescala.Common.Tree object SolverUnsupportedError { def msg(t: Tree, s: Solver, reason: Option[String]) = { - s" is unsupported by solver ${s.name}" + reason.map(":\n " + _ ).getOrElse("") + s"(of ${t.getClass}) is unsupported by solver ${s.name}" + reason.map(":\n " + _ ).getOrElse("") } } diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 1235d69c6d16a2726971cac386732fdcea501b95..95a8cdd67722660f666a9256a0020fab9d93b5d2 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -4,359 +4,685 @@ package leon package solvers package combinators +import purescala.Printable import purescala.Common._ import purescala.Definitions._ import purescala.Quantification._ import purescala.Constructors._ +import purescala.Extractors._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps.bestRealType import utils._ -import z3.FairZ3Component.{optFeelingLucky, optUseCodeGen, optAssumePre, optNoChecks, optUnfoldFactor} import templates._ import evaluators._ import Template._ -class UnrollingSolver(val context: LeonContext, val program: Program, underlying: Solver) - extends Solver - with NaiveAssumptionSolver - with EvaluatingSolver - with QuantificationSolver { +trait UnrollingProcedure extends LeonComponent { + val name = "Unroll-P" + val description = "Leon Unrolling Procedure" - val feelingLucky = context.findOptionOrDefault(optFeelingLucky) - val useCodeGen = context.findOptionOrDefault(optUseCodeGen) - val assumePreHolds = context.findOptionOrDefault(optAssumePre) - val disableChecks = context.findOptionOrDefault(optNoChecks) - val unfoldFactor = context.findOptionOrDefault(optUnfoldFactor) + val optUnrollFactor = LeonLongOptionDef("unrollfactor", "Number of unfoldings to perform in each unfold step", default = 1, "<PosInt>") + val optFeelingLucky = LeonFlagOptionDef("feelinglucky", "Use evaluator to find counter-examples early", false) + val optCheckModels = LeonFlagOptionDef("checkmodels", "Double-check counter-examples with evaluator", false) + val optUnrollCores = LeonFlagOptionDef("unrollcores", "Use unsat-cores to drive unfolding while remaining fair", false) + val optUseCodeGen = LeonFlagOptionDef("codegen", "Use compiled evaluator instead of interpreter", false) + val optAssumePre = LeonFlagOptionDef("assumepre", "Assume precondition holds (pre && f(x) = body) when unfolding", false) + val optPartialModels = LeonFlagOptionDef("partialmodels", "Extract domains for quantifiers and bounded first-class functions", false) - protected var lastCheckResult : (Boolean, Option[Boolean], Option[HenkinModel]) = (false, None, None) + override val definedOptions: Set[LeonOptionDef[Any]] = + Set(optCheckModels, optFeelingLucky, optUseCodeGen, optUnrollCores, optAssumePre, optUnrollFactor, optPartialModels) +} - private val freeVars = new IncrementalSet[Identifier]() - private val constraints = new IncrementalSeq[Expr]() +object UnrollingProcedure extends UnrollingProcedure - protected var interrupted : Boolean = false +trait AbstractUnrollingSolver[T] + extends UnrollingProcedure + with Solver + with EvaluatingSolver { - val reporter = context.reporter + val unfoldFactor = context.findOptionOrDefault(optUnrollFactor) + val feelingLucky = context.findOptionOrDefault(optFeelingLucky) + val checkModels = context.findOptionOrDefault(optCheckModels) + val useCodeGen = context.findOptionOrDefault(optUseCodeGen) + val unrollUnsatCores = context.findOptionOrDefault(optUnrollCores) + val assumePreHolds = context.findOptionOrDefault(optAssumePre) + val partialModels = context.findOptionOrDefault(optPartialModels) - def name = "U:"+underlying.name + protected var foundDefinitiveAnswer = false + protected var definitiveAnswer : Option[Boolean] = None + protected var definitiveModel : Model = Model.empty + protected var definitiveCore : Set[Expr] = Set.empty - def free() { - underlying.free() + def check: Option[Boolean] = { + genericCheck(Set.empty) } - val templateGenerator = new TemplateGenerator(new TemplateEncoder[Expr] { - def encodeId(id: Identifier): Expr= { - Variable(id.freshen) - } + def getModel: Model = if (foundDefinitiveAnswer && definitiveAnswer.getOrElse(false)) { + definitiveModel + } else { + Model.empty + } - def encodeExpr(bindings: Map[Identifier, Expr])(e: Expr): Expr = { - replaceFromIDs(bindings, e) - } + def getUnsatCore: Set[Expr] = if (foundDefinitiveAnswer && !definitiveAnswer.getOrElse(true)) { + definitiveCore + } else { + Set.empty + } - def substitute(substMap: Map[Expr, Expr]): Expr => Expr = { - (e: Expr) => replace(substMap, e) - } + private val freeVars = new IncrementalMap[Identifier, T]() + private val constraints = new IncrementalSeq[Expr]() - def mkNot(e: Expr) = not(e) - def mkOr(es: Expr*) = orJoin(es) - def mkAnd(es: Expr*) = andJoin(es) - def mkEquals(l: Expr, r: Expr) = Equals(l, r) - def mkImplies(l: Expr, r: Expr) = implies(l, r) - }, assumePreHolds) + protected var interrupted : Boolean = false - val unrollingBank = new UnrollingBank(context, templateGenerator) + protected val reporter = context.reporter - val solver = underlying + lazy val templateGenerator = new TemplateGenerator(templateEncoder, assumePreHolds) + lazy val unrollingBank = new UnrollingBank(context, templateGenerator) - def assertCnstr(expression: Expr) { - constraints += expression + def push(): Unit = { + unrollingBank.push() + constraints.push() + freeVars.push() + } - val freeIds = variablesOf(expression) + def pop(): Unit = { + unrollingBank.pop() + constraints.pop() + freeVars.pop() + } - freeVars ++= freeIds + override def reset() = { + foundDefinitiveAnswer = false + interrupted = false - val newVars = freeIds.map(_.toVariable: Expr) + unrollingBank.reset() + constraints.reset() + freeVars.reset() + } - val bindings = newVars.zip(newVars).toMap + override def interrupt(): Unit = { + interrupted = true + } - val newClauses = unrollingBank.getClauses(expression, bindings) + override def recoverInterrupt(): Unit = { + interrupted = false + } + + def assertCnstr(expression: Expr, bindings: Map[Identifier, T]): Unit = { + constraints += expression + freeVars ++= bindings + val newClauses = unrollingBank.getClauses(expression, bindings.map { case (k, v) => Variable(k) -> v }) for (cl <- newClauses) { - solver.assertCnstr(cl) + solverAssert(cl) } } - override def dbg(msg: => Any) = underlying.dbg(msg) - - def push() { - unrollingBank.push() - solver.push() - freeVars.push() - constraints.push() + def foundAnswer(res: Option[Boolean], model: Model = Model.empty, core: Set[Expr] = Set.empty) = { + foundDefinitiveAnswer = true + definitiveAnswer = res + definitiveModel = model + definitiveCore = core } - def pop() { - unrollingBank.pop() - solver.pop() - freeVars.pop() - constraints.pop() + implicit val printable: T => Printable + val templateEncoder: TemplateEncoder[T] + + def solverAssert(cnstr: T): Unit + + /** We define solverCheckAssumptions in CPS in order for solvers that don't + * support this feature to be able to use the provided [[solverCheck]] CPS + * construction. + */ + def solverCheckAssumptions[R](assumptions: Seq[T])(block: Option[Boolean] => R): R = + solverCheck(assumptions)(block) + + def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = + genericCheck(assumptions) + + /** Provides CPS solver.check call. CPS is necessary in order for calls that + * depend on solver.getModel to be able to access the model BEFORE the call + * to solver.pop() is issued. + * + * The underlying solver therefore performs the following sequence of + * solver calls: + * {{{ + * solver.push() + * for (cls <- clauses) solver.assertCnstr(cls) + * val res = solver.check + * block(res) + * solver.pop() + * }}} + * + * This ordering guarantees that [[block]] can safely call solver.getModel. + * + * This sequence of calls can also be used to mimic solver.checkAssumptions() + * for solvers that don't support the construct natively. + */ + def solverCheck[R](clauses: Seq[T])(block: Option[Boolean] => R): R + + def solverUnsatCore: Option[Seq[T]] + + trait ModelWrapper { + def get(id: Identifier): Option[Expr] + def eval(elem: T, tpe: TypeTree): Option[Expr] + + private[AbstractUnrollingSolver] def extract(b: T, m: Matcher[T]): Option[Seq[Expr]] = { + val QuantificationTypeMatcher(fromTypes, _) = m.tpe + val optEnabler = eval(b, BooleanType) + optEnabler.filter(_ == BooleanLiteral(true)).flatMap { _ => + val optArgs = (m.args zip fromTypes).map { case (arg, tpe) => eval(arg.encoded, tpe) } + if (optArgs.forall(_.isDefined)) { + Some(optArgs.map(_.get)) + } else { + None + } + } + } } - def check: Option[Boolean] = { - genericCheck(Set()) - } + def solverGetModel: ModelWrapper - def hasFoundAnswer = lastCheckResult._1 + private def emit(silenceErrors: Boolean)(msg: String) = + if (silenceErrors) reporter.debug(msg) else reporter.warning(msg) - private def extractModel(model: Model): HenkinModel = { - val allVars = freeVars.toSet + private def extractModel(wrapper: ModelWrapper): Model = + new Model(freeVars.toMap.map(p => p._1 -> wrapper.get(p._1).getOrElse(simplestValue(p._1.getType)))) - def extract(b: Expr, m: Matcher[Expr]): Set[Seq[Expr]] = { - val QuantificationTypeMatcher(fromTypes, _) = m.tpe - val optEnabler = evaluator.eval(b, model).result + private def validateModel(model: Model, assumptions: Seq[Expr], silenceErrors: Boolean): Boolean = { + val expr = andJoin(assumptions ++ constraints) - if (optEnabler == Some(BooleanLiteral(true))) { - val optArgs = m.args.map(arg => evaluator.eval(arg.encoded, model).result) - if (optArgs.forall(_.isDefined)) { - Set(optArgs.map(_.get)) - } else { - Set.empty - } - } else { - Set.empty - } + evaluator.eval(expr, model) match { + case EvaluationResults.Successful(BooleanLiteral(true)) => + reporter.debug("- Model validated.") + true + + case EvaluationResults.Successful(_) => + reporter.debug("- Invalid model.") + false + + case EvaluationResults.RuntimeError(msg) => + emit(silenceErrors)("- Model leads to runtime error: " + msg) + false + + case EvaluationResults.EvaluatorError(msg) => + emit(silenceErrors)("- Model leads to evaluation error: " + msg) + false } + } + + private def getPartialModel: PartialModel = { + val wrapped = solverGetModel - val (typeInsts, partialInsts, lambdaInsts) = templateGenerator.manager.instantiations + val typeInsts = templateGenerator.manager.typeInstantiations + val partialInsts = templateGenerator.manager.partialInstantiations + val lambdaInsts = templateGenerator.manager.lambdaInstantiations val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { - case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => wrapped.extract(b, m) }.toSet } - val funDomains: Map[Identifier, Set[Seq[Expr]]] = partialInsts.map { - case (Variable(id), domain) => id -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + val funDomains: Map[Identifier, Set[Seq[Expr]]] = freeVars.toMap.map { case (id, idT) => + id -> partialInsts.get(idT).toSeq.flatten.flatMap { case (b, m) => wrapped.extract(b, m) }.toSet } val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { - case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + case (l, domain) => l -> domain.flatMap { case (b, m) => wrapped.extract(b, m) }.toSet } - val asDMap = purescala.Quantification.extractModel(model.toMap, funDomains, typeDomains, evaluator) - val domains = new HenkinDomains(lambdaDomains, typeDomains) - new HenkinModel(asDMap, domains) - } + val model = new Model(freeVars.toMap.map { case (id, _) => + val value = wrapped.get(id).getOrElse(simplestValue(id.getType)) + id -> (funDomains.get(id) match { + case Some(domain) => + val dflt = value match { + case FiniteLambda(_, dflt, _) => dflt + case Lambda(_, IfExpr(_, _, dflt)) => dflt + case _ => scala.sys.error("Can't extract default from " + value) + } - def foundAnswer(res: Option[Boolean], model: Option[HenkinModel] = None) = { - lastCheckResult = (true, res, model) + FiniteLambda(domain.toSeq.map { es => + val optEv = evaluator.eval(application(value, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + application(value, es))) + }, dflt, id.getType.asInstanceOf[FunctionType]) + + case None => postMap { + case p @ FiniteLambda(mapping, dflt, tpe) => + Some(FiniteLambda(typeDomains.get(tpe) match { + case Some(domain) => domain.toSeq.map { es => + val optEv = evaluator.eval(application(value, es)).result + es -> optEv.getOrElse(scala.sys.error("Unexpectedly failed to evaluate " + application(value, es))) + } + case _ => Seq.empty + }, dflt, tpe)) + case _ => None + } (value) + }) + }) + + val domains = new Domains(lambdaDomains, typeDomains) + new PartialModel(model.toMap, domains) } - def validatedModel(silenceErrors: Boolean = false): (Boolean, HenkinModel) = { - val lastModel = solver.getModel - val clauses = templateGenerator.manager.checkClauses - val optModel = if (clauses.isEmpty) Some(lastModel) else { - solver.push() - for (clause <- clauses) { - solver.assertCnstr(clause) + private def getTotalModel: Model = { + val wrapped = solverGetModel + + def checkForalls(quantified: Set[Identifier], body: Expr): Option[String] = { + val matchers = collect[(Expr, Seq[Expr])] { + case QuantificationMatcher(e, args) => Set(e -> args) + case _ => Set.empty + } (body) + + if (matchers.isEmpty) + return Some("No matchers found.") + + val matcherToQuants = matchers.foldLeft(Map.empty[Expr, Set[Identifier]]) { + case (acc, (m, args)) => acc + (m -> (acc.getOrElse(m, Set.empty) ++ args.flatMap { + case Variable(id) if quantified(id) => Set(id) + case _ => Set.empty[Identifier] + })) } - reporter.debug(" - Verifying model transitivity") - val solverModel = solver.check match { - case Some(true) => - Some(solver.getModel) + val bijectiveMappings = matcherToQuants.filter(_._2.nonEmpty).groupBy(_._2) + if (bijectiveMappings.size > 1) + return Some("Non-bijective mapping for symbol " + bijectiveMappings.head._2.head._1.asString) - case Some(false) => - val msg = "- Transitivity independence not guaranteed for model" - if (silenceErrors) { - reporter.debug(msg) - } else { - reporter.warning(msg) - } - None + def quantifiedArg(e: Expr): Boolean = e match { + case Variable(id) => quantified(id) + case QuantificationMatcher(_, args) => args.forall(quantifiedArg) + case _ => false + } - case None => - val msg = "- Unknown for transitivity independence!?" - if (silenceErrors) { - reporter.debug(msg) - } else { - reporter.warning(msg) - } - None + postTraversal(m => m match { + case QuantificationMatcher(_, args) => + val qArgs = args.filter(quantifiedArg) + + if (qArgs.nonEmpty && qArgs.size < args.size) + return Some("Mixed ground and quantified arguments in " + m.asString) + + case Operator(es, _) if es.collect { case Variable(id) if quantified(id) => id }.nonEmpty => + return Some("Invalid operation on quantifiers " + m.asString) + + case (_: Equals) | (_: And) | (_: Or) | (_: Implies) => // OK + + case Operator(es, _) if (es.flatMap(variablesOf).toSet & quantified).nonEmpty => + return Some("Unandled implications from operation " + m.asString) + + case _ => + }) (body) + + body match { + case Variable(id) if quantified(id) => + Some("Unexpected free quantifier " + id.asString) + case _ => None } + } + + val issues: Iterable[(Seq[Identifier], Expr, String)] = for { + q <- templateGenerator.manager.quantifications.view + if wrapped.eval(q.holds, BooleanType) == Some(BooleanLiteral(true)) + msg <- checkForalls(q.quantifiers.map(_._1).toSet, q.body) + } yield (q.quantifiers.map(_._1), q.body, msg) - solver.pop() - solverModel + if (issues.nonEmpty) { + val (quantifiers, body, msg) = issues.head + reporter.warning("Model soundness not guaranteed for \u2200" + + quantifiers.map(_.asString).mkString(",") + ". " + body.asString+" :\n => " + msg) } - optModel match { - case None => - (false, extractModel(lastModel)) + val typeInsts = templateGenerator.manager.typeInstantiations + val partialInsts = templateGenerator.manager.partialInstantiations - case Some(m) => - val model = extractModel(m) + def extractCond(params: Seq[Identifier], args: Seq[(T, Expr)], structure: Map[T, Identifier]): Seq[Expr] = (params, args) match { + case (id +: rparams, (v, arg) +: rargs) => + if (templateGenerator.manager.isQuantifier(v)) { + structure.get(v) match { + case Some(pid) => Equals(Variable(id), Variable(pid)) +: extractCond(rparams, rargs, structure) + case None => extractCond(rparams, rargs, structure + (v -> id)) + } + } else { + Equals(Variable(id), arg) +: extractCond(rparams, rargs, structure) + } + case _ => Seq.empty + } - val expr = andJoin(constraints.toSeq) - val fullModel = model set freeVars.toSet + new Model(freeVars.toMap.map { case (id, idT) => + val value = wrapped.get(id).getOrElse(simplestValue(id.getType)) + id -> (id.getType match { + case FunctionType(from, to) => + val params = from.map(tpe => FreshIdentifier("x", tpe, true)) + val domain = partialInsts.get(idT).orElse(typeInsts.get(bestRealType(id.getType))).toSeq.flatten + val conditionals = domain.flatMap { case (b, m) => + wrapped.extract(b, m).map { args => + val result = evaluator.eval(application(value, args)).result.getOrElse { + scala.sys.error("Unexpectedly failed to evaluate " + application(value, args)) + } - (evaluator.check(expr, fullModel) match { - case EvaluationResults.CheckSuccess => - reporter.debug("- Model validated.") - true + val cond = if (m.args.exists(arg => templateGenerator.manager.isQuantifier(arg.encoded))) { + extractCond(params, m.args.map(_.encoded) zip args, Map.empty) + } else { + (params zip args).map(p => Equals(Variable(p._1), p._2)) + } - case EvaluationResults.CheckValidityFailure => - reporter.debug("- Invalid model.") - false + cond -> result + } + } - case EvaluationResults.CheckRuntimeFailure(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to evaluation error: " + msg) - } else { - reporter.warning("- Model leads to evaluation error: " + msg) + val filteredConds = conditionals + .foldLeft(Map.empty[Seq[Expr], Expr]) { case (mapping, (conds, result)) => + if (mapping.isDefinedAt(conds)) mapping else mapping + (conds -> result) } - false - case EvaluationResults.CheckQuantificationFailure(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to quantification error: " + msg) - } else { - reporter.warning("- Model leads to quantification error: " + msg) + if (filteredConds.isEmpty) { + // TODO: warning?? + value + } else { + val rest :+ ((_, dflt)) = filteredConds.toSeq.sortBy(_._1.size) + val body = rest.foldLeft(dflt) { case (elze, (conds, res)) => + if (conds.isEmpty) elze else IfExpr(andJoin(conds), res, elze) } - false - }, fullModel) - } + + Lambda(params.map(ValDef(_)), body) + } + + case _ => value + }) + }) } def genericCheck(assumptions: Set[Expr]): Option[Boolean] = { - lastCheckResult = (false, None, None) + foundDefinitiveAnswer = false + + val encoder = templateGenerator.encoder.encodeExpr(freeVars.toMap) _ + val assumptionsSeq : Seq[Expr] = assumptions.toSeq + val encodedAssumptions : Seq[T] = assumptionsSeq.map(encoder) + val encodedToAssumptions : Map[T, Expr] = (encodedAssumptions zip assumptionsSeq).toMap + + def encodedCoreToCore(core: Seq[T]): Set[Expr] = { + core.flatMap(ast => encodedToAssumptions.get(ast) match { + case Some(n @ Not(Variable(_))) => Some(n) + case Some(v @ Variable(_)) => Some(v) + case _ => None + }).toSet + } - while(!hasFoundAnswer && !interrupted) { + while(!foundDefinitiveAnswer && !interrupted) { reporter.debug(" - Running search...") + var quantify = false - solver.push() - solver.assertCnstr(andJoin((assumptions ++ unrollingBank.satisfactionAssumptions).toSeq)) - val res = solver.check + def check[R](clauses: Seq[T])(block: Option[Boolean] => R) = + if (partialModels) solverCheckAssumptions(clauses)(block) else solverCheck(clauses)(block) - reporter.debug(" - Finished search with blocked literals") + val timer = context.timers.solvers.check.start() + check(encodedAssumptions.toSeq ++ unrollingBank.satisfactionAssumptions) { res => + timer.stop() - res match { - case None => - solver.pop() + reporter.debug(" - Finished search with blocked literals") - reporter.ifDebug { debug => - reporter.debug("Solver returned unknown!?") - } - foundAnswer(None) + res match { + case None => + foundAnswer(None) - case Some(true) => // SAT - val (valid, model) = if (!this.disableChecks && requireQuantification) { - validatedModel(silenceErrors = false) - } else { - true -> extractModel(solver.getModel) - } + case Some(true) => // SAT + val (stop, model) = if (interrupted) { + (true, Model.empty) + } else if (partialModels) { + (true, getPartialModel) + } else { + val clauses = templateGenerator.manager.checkClauses + if (clauses.isEmpty) { + (true, extractModel(solverGetModel)) + } else { + reporter.debug(" - Verifying model transitivity") + + val timer = context.timers.solvers.check.start() + solverCheck(clauses) { res => + timer.stop() + + reporter.debug(" - Finished transitivity check") + + res match { + case Some(true) => + (true, getTotalModel) + + case Some(false) => + reporter.debug(" - Transitivity not guaranteed for model") + (false, Model.empty) + + case None => + reporter.warning(" - Unknown for transitivity!?") + (false, Model.empty) + } + } + } + } - solver.pop() - if (valid) { - foundAnswer(Some(true), Some(model)) - } else { - reporter.error("Something went wrong. The model should have been valid, yet we got this : ") - reporter.error(model.asString(context)) - foundAnswer(None, Some(model)) - } + if (!interrupted) { + if (!stop) { + if (!unrollingBank.canInstantiate) { + reporter.error("Something went wrong. The model is not transitive yet we can't instantiate!?") + reporter.error(model.asString(context)) + foundAnswer(None, model) + } else { + quantify = true + } + } else { + val valid = !checkModels || validateModel(model, assumptionsSeq, silenceErrors = false) + + if (valid) { + foundAnswer(Some(true), model) + } else { + reporter.error("Something went wrong. The model should have been valid, yet we got this : ") + reporter.error(model.asString(context)) + foundAnswer(None, model) + } + } + } - case Some(false) if !unrollingBank.canUnroll => - solver.pop() - foundAnswer(Some(false)) + if (interrupted) { + foundAnswer(None) + } - case Some(false) => - //debug("UNSAT BECAUSE: "+solver.getUnsatCore.mkString("\n AND \n")) - //debug("UNSAT BECAUSE: "+core.mkString(" AND ")) - solver.pop() + case Some(false) if !unrollingBank.canUnroll => + solverUnsatCore match { + case Some(core) => + val exprCore = encodedCoreToCore(core) + foundAnswer(Some(false), core = exprCore) + case None => + foundAnswer(Some(false)) + } - if (!interrupted) { - if (feelingLucky) { - reporter.debug(" - Running search without blocked literals (w/ lucky test)") - } else { - reporter.debug(" - Running search without blocked literals (w/o lucky test)") + case Some(false) => + if (unrollUnsatCores) { + solverUnsatCore match { + case Some(core) => + unrollingBank.decreaseAllGenerations() + + for (c <- core) templateGenerator.encoder.extractNot(c) match { + case Some(b) => unrollingBank.promoteBlocker(b) + case None => reporter.fatalError("Unexpected blocker polarity for unsat core unrolling: " + c) + } + case None => + reporter.fatalError("Can't unroll unsat core for incompatible solver " + name) + } } + } + } - solver.push() - solver.assertCnstr(andJoin(assumptions.toSeq ++ unrollingBank.refutationAssumptions)) - val res2 = solver.check + if (!quantify && !foundDefinitiveAnswer && !interrupted) { + if (feelingLucky) { + reporter.debug(" - Running search without blocked literals (w/ lucky test)") + } else { + reporter.debug(" - Running search without blocked literals (w/o lucky test)") + } - res2 match { - case Some(false) => - //reporter.debug("UNSAT WITHOUT Blockers") - foundAnswer(Some(false)) + val timer = context.timers.solvers.check.start() + solverCheckAssumptions(encodedAssumptions.toSeq ++ unrollingBank.refutationAssumptions) { res2 => + timer.stop() - case Some(true) => - if (feelingLucky && !interrupted) { - // we might have been lucky :D - val (valid, model) = validatedModel(silenceErrors = true) - if (valid) foundAnswer(Some(true), Some(model)) - } + reporter.debug(" - Finished search without blocked literals") - case None => - foundAnswer(None) - } - solver.pop() - } + res2 match { + case Some(false) => + solverUnsatCore match { + case Some(core) => + val exprCore = encodedCoreToCore(core) + foundAnswer(Some(false), core = exprCore) + case None => + foundAnswer(Some(false)) + } - if(interrupted) { - foundAnswer(None) + case Some(true) => + if (this.feelingLucky && !interrupted) { + // we might have been lucky :D + val model = extractModel(solverGetModel) + val valid = validateModel(model, assumptionsSeq, silenceErrors = true) + if (valid) foundAnswer(Some(true), model) + } + + case None => + foundAnswer(None) } + } + } - if(!hasFoundAnswer) { - reporter.debug("- We need to keep going.") + if (!foundDefinitiveAnswer && !interrupted) { + reporter.debug("- We need to keep going") - // unfolling `unfoldFactor` times - for (i <- 1 to unfoldFactor.toInt) { - val toRelease = unrollingBank.getBlockersToUnlock + reporter.debug(" - more instantiations") + val newClauses = unrollingBank.instantiateQuantifiers(force = quantify) - reporter.debug(" - more unrollings") + for (cls <- newClauses) { + solverAssert(cls) + } - val newClauses = unrollingBank.unrollBehind(toRelease) + reporter.debug(" - finished instantiating") - for (ncl <- newClauses) { - solver.assertCnstr(ncl) - } - } + // unfolling `unfoldFactor` times + for (i <- 1 to unfoldFactor.toInt) { + val toRelease = unrollingBank.getBlockersToUnlock - reporter.debug(" - finished unrolling") + reporter.debug(" - more unrollings") + + val newClauses = unrollingBank.unrollBehind(toRelease) + + for (ncl <- newClauses) { + solverAssert(ncl) } + } + + reporter.debug(" - finished unrolling") } } - if(interrupted) { + if (interrupted) { None } else { - lastCheckResult._2 + definitiveAnswer } } +} + +class UnrollingSolver(val context: LeonContext, val program: Program, underlying: Solver) + extends AbstractUnrollingSolver[Expr] { + + override val name = "U:"+underlying.name - def getModel: HenkinModel = { - lastCheckResult match { - case (true, Some(true), Some(m)) => - m.filter(freeVars.toSet) - case _ => - HenkinModel.empty + def free() { + underlying.free() + } + + val printable = (e: Expr) => e + + val templateEncoder = new TemplateEncoder[Expr] { + def encodeId(id: Identifier): Expr= { + Variable(id.freshen) + } + + def encodeExpr(bindings: Map[Identifier, Expr])(e: Expr): Expr = { + replaceFromIDs(bindings, e) + } + + def substitute(substMap: Map[Expr, Expr]): Expr => Expr = { + (e: Expr) => replace(substMap, e) + } + + def mkNot(e: Expr) = not(e) + def mkOr(es: Expr*) = orJoin(es) + def mkAnd(es: Expr*) = andJoin(es) + def mkEquals(l: Expr, r: Expr) = Equals(l, r) + def mkImplies(l: Expr, r: Expr) = implies(l, r) + + def extractNot(e: Expr): Option[Expr] = e match { + case Not(b) => Some(b) + case _ => None } } - override def reset() = { + val solver = underlying + + def assertCnstr(expression: Expr): Unit = { + assertCnstr(expression, variablesOf(expression).map(id => id -> id.toVariable).toMap) + } + + def solverAssert(cnstr: Expr): Unit = { + solver.assertCnstr(cnstr) + } + + def solverCheck[R](clauses: Seq[Expr])(block: Option[Boolean] => R): R = { + solver.push() + for (cls <- clauses) solver.assertCnstr(cls) + val res = solver.check + val r = block(res) + solver.pop() + r + } + + def solverUnsatCore = None + + def solverGetModel: ModelWrapper = new ModelWrapper { + val model = solver.getModel + def get(id: Identifier): Option[Expr] = model.get(id) + def eval(elem: Expr, tpe: TypeTree): Option[Expr] = evaluator.eval(elem, model).result + override def toString = model.toMap.mkString("\n") + } + + override def dbg(msg: => Any) = underlying.dbg(msg) + + override def push(): Unit = { + super.push() + solver.push() + } + + override def pop(): Unit = { + super.pop() + solver.pop() + } + + override def foundAnswer(res: Option[Boolean], model: Model = Model.empty, core: Set[Expr] = Set.empty) = { + super.foundAnswer(res, model, core) + + if (!interrupted && res == None && model == None) { + reporter.ifDebug { debug => + debug("Unknown result!?") + } + } + } + + override def reset(): Unit = { underlying.reset() - lastCheckResult = (false, None, None) - freeVars.reset() - constraints.reset() - interrupted = false + super.reset() } override def interrupt(): Unit = { - interrupted = true + super.interrupt() solver.interrupt() } override def recoverInterrupt(): Unit = { solver.recoverInterrupt() - interrupted = false + super.recoverInterrupt() } } diff --git a/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala new file mode 100644 index 0000000000000000000000000000000000000000..b94233f285e1fe63c486dd2711c63e115ef1dd7b --- /dev/null +++ b/src/main/scala/leon/solvers/combinators/Z3StringCapableSolver.scala @@ -0,0 +1,271 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package combinators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Quantification._ +import purescala.Constructors._ +import purescala.Expressions._ +import purescala.ExprOps._ +import purescala.Types._ +import purescala.DefOps +import purescala.TypeOps +import purescala.Extractors._ +import utils._ +import templates._ +import evaluators._ +import Template._ +import leon.solvers.z3.Z3StringConversion +import leon.utils.Bijection +import leon.solvers.z3.StringEcoSystem + +object Z3StringCapableSolver { + def thatShouldBeConverted(t: TypeTree): Boolean = TypeOps.exists{ _== StringType }(t) + def thatShouldBeConverted(e: Expr): Boolean = exists(e => thatShouldBeConverted(e.getType))(e) + def thatShouldBeConverted(id: Identifier): Boolean = thatShouldBeConverted(id.getType) + def thatShouldBeConverted(vd: ValDef): Boolean = thatShouldBeConverted(vd.id) + def thatShouldBeConverted(fd: FunDef): Boolean = { + (fd.body exists thatShouldBeConverted)|| (fd.paramIds exists thatShouldBeConverted) + } + def thatShouldBeConverted(cd: ClassDef): Boolean = cd match { + case ccd:CaseClassDef => ccd.fields.exists(thatShouldBeConverted) + case _ => false + } + def thatShouldBeConverted(p: Program): Boolean = { + (p.definedFunctions exists thatShouldBeConverted) || + (p.definedClasses exists thatShouldBeConverted) + } + + def convert(p: Program): (Program, Option[Z3StringConversion]) = { + val converter = new Z3StringConversion(p) + import converter.Forward._ + var hasStrings = false + val program_with_strings = converter.getProgram + val (program_with_correct_classes, cdMap, idMap, fdMap) = if(program_with_strings.definedClasses.exists{ case c: CaseClassDef => c.fieldsIds.exists(id => TypeOps.exists{ _ == StringType}(id.getType)) case _ => false}) { + val res:(Program, Map[ClassDef, ClassDef], Map[Identifier, Identifier], Map[FunDef, FunDef]) = DefOps.replaceCaseClassDefs(program_with_strings)((cd: ClassDef) => { + cd match { + case acd:AbstractClassDef => None + case ccd:CaseClassDef => + if(ccd.fieldsIds.exists(id => TypeOps.exists(StringType == _)(id.getType))) { + Some((parent: Option[AbstractClassType]) => ccd.duplicate(convertId(ccd.id), ccd.tparams, ccd.fieldsIds.map(id => ValDef(convertId(id))), parent, ccd.isCaseObject)) + } else None + } + }) + converter.mappedVariables.clear() // We will compose them later, they have been stored in idMap + res + } else { + (program_with_strings, Map[ClassDef, ClassDef](), Map[Identifier, Identifier](), Map[FunDef, FunDef]()) + } + val fdMapInverse = fdMap.map(kv => kv._2 -> kv._1).toMap + val idMapInverse = idMap.map(kv => kv._2 -> kv._1).toMap + var globalFdMap = Map[FunDef, (Map[Identifier, Identifier], FunDef)]() + val (new_program, _) = DefOps.replaceFunDefs(program_with_correct_classes)((fd: FunDef) => { + globalFdMap.get(fd).map(_._2).orElse( + if(thatShouldBeConverted(fd)) { + val idMap = fd.params.zip(fd.params).map(origvd_vd => origvd_vd._1.id -> convertId(origvd_vd._2.id)).toMap + val newFdId = convertId(fd.id) + val newFd = fd.duplicate(newFdId, + fd.tparams, + fd.params.map(vd => ValDef(idMap(vd.id))), + convertType(fd.returnType)) + globalFdMap += fd -> ((idMap, newFd)) + hasStrings = hasStrings || (program_with_strings.library.escape.get != fd) + Some(newFd) + } else None + ) + }) + if(!hasStrings) { + (p, None) + } else { + converter.globalFdMap ++= globalFdMap.view.map(kv => (kv._1, kv._2._2)) + for((fd, (idMap, newFd)) <- globalFdMap) { + implicit val idVarMap = idMap.mapValues(id => Variable(id)) + newFd.fullBody = convertExpr(newFd.fullBody) + } + converter.mappedVariables.composeA(id => idMapInverse.getOrElse(id, id)) + converter.globalFdMap.composeA(fd => fdMapInverse.getOrElse(fd, fd)) + converter.globalClassMap ++= cdMap + (new_program, Some(converter)) + } + } +} + +abstract class Z3StringCapableSolver[+TUnderlying <: Solver]( + val context: LeonContext, + val program: Program, + val underlyingConstructor: (Program, Option[Z3StringConversion]) => TUnderlying) extends Solver { + + protected val (new_program, optConverter) = Z3StringCapableSolver.convert(program) + var someConverter = optConverter + + val underlying = underlyingConstructor(new_program, someConverter) + var solverInvokedWithStrings = false + + def getModel: leon.solvers.Model = { + val model = underlying.getModel + someConverter match { + case None => model + case Some(converter) => + val ids = model.ids.toSeq + val exprs = ids.map(model.apply) + import converter.Backward._ + val original_ids = ids.map(convertId) + val original_exprs = exprs.map{ case e => convertExpr(e)(Map()) } + + model match { + case hm: PartialModel => + val new_domain = new Domains( + hm.domains.lambdas.map(kv => + (convertExpr(kv._1)(Map()).asInstanceOf[Lambda], + kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap, + hm.domains.tpes.map(kv => + (convertType(kv._1), + kv._2.map(e => e.map(e => convertExpr(e)(Map()))))).toMap + ) + + new PartialModel(original_ids.zip(original_exprs).toMap, new_domain) + case _ => + new Model(original_ids.zip(original_exprs).toMap) + } + } + } + + // Members declared in leon.utils.Interruptible + def interrupt(): Unit = underlying.interrupt() + def recoverInterrupt(): Unit = underlying.recoverInterrupt() + + // Converts expression on the fly if needed, creating a string converter if needed. + def convertExprOnTheFly(expression: Expr, withConverter: Z3StringConversion => Expr): Expr = { + someConverter match { + case None => + if(solverInvokedWithStrings || exists(e => TypeOps.exists(StringType == _)(e.getType))(expression)) { // On the fly conversion + solverInvokedWithStrings = true + val c = new Z3StringConversion(program) + someConverter = Some(c) + withConverter(c) + } else expression + case Some(converter) => + withConverter(converter) + } + } + + // Members declared in leon.solvers.Solver + def assertCnstr(expression: Expr): Unit = { + someConverter.map{converter => + import converter.Forward._ + val newExpression = convertExpr(expression)(Map()) + underlying.assertCnstr(newExpression) + }.getOrElse{ + underlying.assertCnstr(convertExprOnTheFly(expression, _.Forward.convertExpr(expression)(Map()))) + } + } + def getUnsatCore: Set[Expr] = { + someConverter.map{converter => + import converter.Backward._ + underlying.getUnsatCore map (e => convertExpr(e)(Map())) + }.getOrElse(underlying.getUnsatCore) + } + + def check: Option[Boolean] = underlying.check + def free(): Unit = underlying.free() + def pop(): Unit = underlying.pop() + def push(): Unit = underlying.push() + def reset(): Unit = underlying.reset() + def name: String = underlying.name +} + +import z3._ + +trait Z3StringAbstractZ3Solver[TUnderlying <: Solver] extends AbstractZ3Solver { self: Z3StringCapableSolver[TUnderlying] => +} + +trait Z3StringNaiveAssumptionSolver[TUnderlying <: Solver] extends NaiveAssumptionSolver { self: Z3StringCapableSolver[TUnderlying] => +} + +trait Z3StringEvaluatingSolver[TUnderlying <: EvaluatingSolver] extends EvaluatingSolver{ self: Z3StringCapableSolver[TUnderlying] => + // Members declared in leon.solvers.EvaluatingSolver + val useCodeGen: Boolean = underlying.useCodeGen +} + +class ConvertibleCodeGenEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) + extends CodeGenEvaluator(context, originalProgram) { + + override def compile(expression: Expr, args: Seq[Identifier]) : Option[solvers.Model=>EvaluationResult] = { + import converter._ + super.compile(Backward.convertExpr(expression)(Map()), args.map(Backward.convertId)) + .map(evaluator => (m: Model) => Forward.convertResult(evaluator(Backward.convertModel(m))) + ) + } +} + +class ConvertibleDefaultEvaluator(context: LeonContext, originalProgram: Program, val converter: Z3StringConversion) + extends DefaultEvaluator(context, originalProgram) { + + override def eval(ex: Expr, model: Model): EvaluationResults.Result[Expr] = { + import converter._ + Forward.convertResult(super.eval(Backward.convertExpr(ex)(Map()), Backward.convertModel(model))) + } +} + +class FairZ3SolverWithBackwardEvaluator(context: LeonContext, program: Program, + originalProgram: Program, someConverter: Option[Z3StringConversion]) extends FairZ3Solver(context, program) { + override lazy val evaluator: DeterministicEvaluator = { // We evaluate expressions using the original evaluator + someConverter match { + case Some(converter) => + if (useCodeGen) { + new ConvertibleCodeGenEvaluator(context, originalProgram, converter) + } else { + new ConvertibleDefaultEvaluator(context, originalProgram, converter) + } + case None => + if (useCodeGen) { + new CodeGenEvaluator(context, program) + } else { + new DefaultEvaluator(context, program) + } + } + } +} + +class Z3StringFairZ3Solver(context: LeonContext, program: Program) + extends Z3StringCapableSolver(context, program, + (prgm: Program, someConverter: Option[Z3StringConversion]) => + new FairZ3SolverWithBackwardEvaluator(context, prgm, program, someConverter)) + with Z3StringEvaluatingSolver[FairZ3Solver] { + + // Members declared in leon.solvers.z3.AbstractZ3Solver + protected[leon] val z3cfg: _root_.z3.scala.Z3Config = underlying.z3cfg + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions.map(e => convertExprOnTheFly(e, _.Forward.convertExpr(e)(Map())))) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } + } +} + +class Z3StringUnrollingSolver(context: LeonContext, program: Program, underlyingSolverConstructor: Program => Solver) + extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => + new UnrollingSolver(context, program, underlyingSolverConstructor(program))) + with Z3StringNaiveAssumptionSolver[UnrollingSolver] + with Z3StringEvaluatingSolver[UnrollingSolver] { + + override def getUnsatCore = super[Z3StringNaiveAssumptionSolver].getUnsatCore +} + +class Z3StringSMTLIBZ3QuantifiedSolver(context: LeonContext, program: Program) + extends Z3StringCapableSolver(context, program, (program: Program, converter: Option[Z3StringConversion]) => + new smtlib.SMTLIBZ3QuantifiedSolver(context, program)) { + + override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { + someConverter match { + case None => underlying.checkAssumptions(assumptions) + case Some(converter) => + underlying.checkAssumptions(assumptions map (e => converter.Forward.convertExpr(e)(Map()))) + } + } +} + diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala index 87cc849b41d2507d7e0ca53a351b2f605d387b46..03cf3aef32fc59c33283c70cebfa2f918e549e76 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBCVC4Target.scala @@ -63,7 +63,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { RawArrayValue(k, Map(), fromSMT(elem, v)) case ft @ FunctionType(from, to) => - PartialLambda(Seq.empty, Some(fromSMT(elem, to)), ft) + FiniteLambda(Seq.empty, fromSMT(elem, to), ft) case MapType(k, v) => FiniteMap(Map(), k, v) @@ -75,7 +75,7 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { RawArrayValue(k, Map(), fromSMT(elem, v)) case ft @ FunctionType(from, to) => - PartialLambda(Seq.empty, Some(fromSMT(elem, to)), ft) + FiniteLambda(Seq.empty, fromSMT(elem, to), ft) case MapType(k, v) => FiniteMap(Map(), k, v) @@ -88,9 +88,9 @@ trait SMTLIBCVC4Target extends SMTLIBTarget { RawArrayValue(k, elems + (fromSMT(key, k) -> fromSMT(elem, v)), base) case FunctionType(from, v) => - val PartialLambda(mapping, dflt, ft) = fromSMT(arr, otpe) + val FiniteLambda(mapping, dflt, ft) = fromSMT(arr, otpe) val args = unwrapTuple(fromSMT(key, tupleTypeWrap(from)), from.size) - PartialLambda(mapping :+ (args -> fromSMT(elem, v)), dflt, ft) + FiniteLambda(mapping :+ (args -> fromSMT(elem, v)), dflt, ft) case MapType(k, v) => val FiniteMap(elems, k, v) = fromSMT(arr, otpe) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index 47017bcf1471770f1b1e9fb574a81bed34f4c515..4b838694956205c17432abc229012eb19f98b143 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -20,7 +20,7 @@ import _root_.smtlib.printer.{ RecursivePrinter => SMTPrinter } import _root_.smtlib.parser.Commands.{ Constructor => SMTConstructor, FunDef => SMTFunDef, - Assert => _, + Assert => SMTAssert, _ } import _root_.smtlib.parser.Terms.{ @@ -104,6 +104,7 @@ trait SMTLIBTarget extends Interruptible { interpreter.eval(cmd) match { case err @ ErrorResponse(msg) if !hasError && !interrupted && !rawOut => reporter.warning(s"Unexpected error from $targetName solver: $msg") + //println(Thread.currentThread().getStackTrace.map(_.toString).take(10).mkString("\n")) // Store that there was an error. Now all following check() // invocations will return None addError() @@ -205,7 +206,7 @@ trait SMTLIBTarget extends Interruptible { case ft @ FunctionType(from, to) => val elems = r.elems.toSeq.map { case (k, v) => unwrapTuple(k, from.size) -> v } - PartialLambda(elems, Some(r.default), ft) + FiniteLambda(elems, r.default, ft) case MapType(from, to) => // We expect a RawArrayValue with keys in from and values in Option[to], @@ -532,7 +533,11 @@ trait SMTLIBTarget extends Interruptible { case gv @ GenericValue(tpe, n) => genericValues.cachedB(gv) { - declareVariable(FreshIdentifier("gv" + n, tpe)) + val v = declareVariable(FreshIdentifier("gv" + n, tpe)) + for ((ogv, ov) <- genericValues.aToB if ogv.getType == tpe) { + emit(SMTAssert(Core.Not(Core.Equals(v, ov)))) + } + v } /** diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index 2d0ac830bcd2773e65e0a3cf5f6b57c753fdd8ca..f9f8e386c628c3550983afece5a984197e3e07ba 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -8,15 +8,15 @@ import purescala.Common._ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ -import purescala.Definitions._ + import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.interpreters.Z3Interpreter import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} import _root_.smtlib.theories.ArraysEx -import leon.solvers.z3.Z3StringConversion -trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { +trait SMTLIBZ3Target extends SMTLIBTarget { + def targetName = "z3" def interpreterOps(ctx: LeonContext) = { @@ -40,11 +40,11 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { override protected def declareSort(t: TypeTree): Sort = { val tpe = normalizeType(t) sorts.cachedB(tpe) { - convertType(tpe) match { + tpe match { case SetType(base) => super.declareSort(BooleanType) declareSetSort(base) - case t => + case _ => super.declareSort(t) } } @@ -69,13 +69,9 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { Sort(SMTIdentifier(setSort.get), Seq(declareSort(of))) } - override protected def fromSMT(t: Term, expected_otpe: Option[TypeTree] = None) + override protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - val otpe = expected_otpe match { - case Some(StringType) => Some(listchar) - case _ => expected_otpe - } - val res = (t, otpe) match { + (t, otpe) match { case (SimpleSymbol(s), Some(tp: TypeParameter)) => val n = s.name.split("!").toList.last GenericValue(tp, n.toInt) @@ -100,16 +96,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { case _ => super.fromSMT(t, otpe) } - expected_otpe match { - case Some(StringType) => - StringLiteral(convertToString(res)(program)) - case _ => res - } - } - - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = toSMT(e) - def targetApplication(tfd: TypedFunDef, args: Seq[Term])(implicit bindings: Map[Identifier, Term]): Term = { - FunctionApplication(declareFunction(tfd), args) } override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { @@ -146,7 +132,6 @@ trait SMTLIBZ3Target extends SMTLIBTarget with Z3StringConversion[Term] { case SetIntersection(l, r) => ArrayMap(SSymbol("and"), toSMT(l), toSMT(r)) - case StringConverted(result) => result case _ => super.toSMT(e) } diff --git a/src/main/scala/leon/solvers/templates/DatatypeManager.scala b/src/main/scala/leon/solvers/templates/DatatypeManager.scala new file mode 100644 index 0000000000000000000000000000000000000000..dcfa67e8343f88fb53dfa0a18212e0a2aa383016 --- /dev/null +++ b/src/main/scala/leon/solvers/templates/DatatypeManager.scala @@ -0,0 +1,219 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon +package solvers +package templates + +import purescala.Common._ +import purescala.Definitions._ +import purescala.Expressions._ +import purescala.Constructors._ +import purescala.Extractors._ +import purescala.ExprOps._ +import purescala.Types._ +import purescala.TypeOps.bestRealType + +import utils._ +import utils.SeqUtils._ +import Instantiation._ +import Template._ + +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +case class FreshFunction(expr: Expr) extends Expr with Extractable { + val getType = BooleanType + val extract = Some(Seq(expr), (exprs: Seq[Expr]) => FreshFunction(exprs.head)) +} + +object DatatypeTemplate { + + def apply[T]( + encoder: TemplateEncoder[T], + manager: DatatypeManager[T], + tpe: TypeTree + ) : DatatypeTemplate[T] = { + val id = FreshIdentifier("x", tpe, true) + val expr = manager.typeUnroller(Variable(id)) + + val pathVar = FreshIdentifier("b", BooleanType, true) + + var condVars = Map[Identifier, T]() + var condTree = Map[Identifier, Set[Identifier]](pathVar -> Set.empty).withDefaultValue(Set.empty) + def storeCond(pathVar: Identifier, id: Identifier): Unit = { + condVars += id -> encoder.encodeId(id) + condTree += pathVar -> (condTree(pathVar) + id) + } + + var guardedExprs = Map[Identifier, Seq[Expr]]() + def storeGuarded(pathVar: Identifier, expr: Expr): Unit = { + val prev = guardedExprs.getOrElse(pathVar, Nil) + guardedExprs += pathVar -> (expr +: prev) + } + + def requireDecomposition(e: Expr): Boolean = exists { + case _: FunctionInvocation => true + case _ => false + } (e) + + def rec(pathVar: Identifier, expr: Expr): Unit = expr match { + case i @ IfExpr(cond, thenn, elze) if requireDecomposition(i) => + val newBool1: Identifier = FreshIdentifier("b", BooleanType, true) + val newBool2: Identifier = FreshIdentifier("b", BooleanType, true) + + storeCond(pathVar, newBool1) + storeCond(pathVar, newBool2) + + storeGuarded(pathVar, or(Variable(newBool1), Variable(newBool2))) + storeGuarded(pathVar, or(not(Variable(newBool1)), not(Variable(newBool2)))) + storeGuarded(pathVar, Equals(Variable(newBool1), cond)) + + rec(newBool1, thenn) + rec(newBool2, elze) + + case a @ And(es) if requireDecomposition(a) => + val partitions = groupWhile(es)(!requireDecomposition(_)) + for (e <- partitions.map(andJoin)) rec(pathVar, e) + + case _ => + storeGuarded(pathVar, expr) + } + + rec(pathVar, expr) + + val (idT, pathVarT) = (encoder.encodeId(id), encoder.encodeId(pathVar)) + val (clauses, blockers, _, functions, _, _) = Template.encode(encoder, + pathVar -> pathVarT, Seq(id -> idT), condVars, Map.empty, guardedExprs, Seq.empty, Seq.empty) + + new DatatypeTemplate[T](encoder, manager, + pathVar -> pathVarT, idT, condVars, condTree, clauses, blockers, functions) + } +} + +class DatatypeTemplate[T] private ( + val encoder: TemplateEncoder[T], + val manager: DatatypeManager[T], + val pathVar: (Identifier, T), + val argument: T, + val condVars: Map[Identifier, T], + val condTree: Map[Identifier, Set[Identifier]], + val clauses: Seq[T], + val blockers: Map[T, Set[TemplateCallInfo[T]]], + val functions: Set[(T, FunctionType, T)]) extends Template[T] { + + val args = Seq(argument) + val exprVars = Map.empty[Identifier, T] + val applications = Map.empty[T, Set[App[T]]] + val lambdas = Seq.empty[LambdaTemplate[T]] + val matchers = Map.empty[T, Set[Matcher[T]]] + val quantifications = Seq.empty[QuantificationTemplate[T]] + + def instantiate(blocker: T, arg: T): Instantiation[T] = instantiate(blocker, Seq(Left(arg))) +} + +class DatatypeManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(encoder) { + + private val classCache: MutableMap[ClassType, FunDef] = MutableMap.empty + + private def classTypeUnroller(ct: ClassType): FunDef = classCache.get(ct) match { + case Some(fd) => fd + case None => + val param = ValDef(FreshIdentifier("x", ct)) + val fd = new FunDef(FreshIdentifier("unroll"+ct.classDef.id), Seq.empty, Seq(param), BooleanType) + classCache += ct -> fd + + val matchers = ct match { + case (act: AbstractClassType) => act.knownCCDescendants + case (cct: CaseClassType) => Seq(cct) + } + + fd.body = Some(MatchExpr(param.toVariable, matchers.map { cct => + val pattern = CaseClassPattern(None, cct, cct.fields.map(vd => WildcardPattern(Some(vd.id)))) + val expr = andJoin(cct.fields.map(vd => typeUnroller(Variable(vd.id)))) + MatchCase(pattern, None, expr) + })) + + fd + } + + private val requireChecking: MutableSet[ClassType] = MutableSet.empty + private val requireCache: MutableMap[TypeTree, Boolean] = MutableMap.empty + + private def requireTypeUnrolling(tpe: TypeTree): Boolean = requireCache.get(tpe) match { + case Some(res) => res + case None => + val res = tpe match { + case ft: FunctionType => true + case ct: CaseClassType if ct.classDef.hasParent => true + case ct: ClassType if requireChecking(ct.root) => false + case ct: ClassType => + requireChecking += ct.root + val classTypes = ct.root +: (ct.root match { + case (act: AbstractClassType) => act.knownDescendants + case (cct: CaseClassType) => Seq.empty + }) + + classTypes.exists(ct => ct.classDef.hasInvariant || ct.fieldsTypes.exists(requireTypeUnrolling)) + + case BooleanType | UnitType | CharType | IntegerType | + RealType | Int32Type | StringType | (_: TypeParameter) => false + + case NAryType(tpes, _) => tpes.exists(requireTypeUnrolling) + } + + requireCache += tpe -> res + res + } + + def typeUnroller(expr: Expr): Expr = expr.getType match { + case tpe if !requireTypeUnrolling(tpe) => + BooleanLiteral(true) + + case ct: ClassType => + val inv = if (ct.classDef.root.hasInvariant) { + Seq(FunctionInvocation(ct.root.invariant.get, Seq(expr))) + } else { + Seq.empty + } + + val subtype = if (ct != ct.root) { + Seq(IsInstanceOf(expr, ct)) + } else { + Seq.empty + } + + val induct = if (!ct.classDef.isInductive) { + val matchers = ct.root match { + case (act: AbstractClassType) => act.knownCCDescendants + case (cct: CaseClassType) => Seq(cct) + } + + MatchExpr(expr, matchers.map { cct => + val pattern = CaseClassPattern(None, cct, cct.fields.map(vd => WildcardPattern(Some(vd.id)))) + val expr = andJoin(cct.fields.map(vd => typeUnroller(Variable(vd.id)))) + MatchCase(pattern, None, expr) + }) + } else { + val fd = classTypeUnroller(ct.root) + FunctionInvocation(fd.typed, Seq(expr)) + } + + andJoin(inv ++ subtype :+ induct) + + case TupleType(tpes) => + andJoin(tpes.zipWithIndex.map(p => typeUnroller(TupleSelect(expr, p._2 + 1)))) + + case FunctionType(_, _) => + FreshFunction(expr) + + case _ => scala.sys.error("TODO") + } + + private val typeCache: MutableMap[TypeTree, DatatypeTemplate[T]] = MutableMap.empty + + protected def typeTemplate(tpe: TypeTree): DatatypeTemplate[T] = typeCache.getOrElse(tpe, { + val template = DatatypeTemplate(encoder, this, tpe) + typeCache += tpe -> template + template + }) +} + diff --git a/src/main/scala/leon/solvers/templates/LambdaManager.scala b/src/main/scala/leon/solvers/templates/LambdaManager.scala index 036654c25ab64839efa5af4cb0d49a6566ffcf20..a63fdd7a6e06d7ea9cc0b9d8b7ac1f0b89219bab 100644 --- a/src/main/scala/leon/solvers/templates/LambdaManager.scala +++ b/src/main/scala/leon/solvers/templates/LambdaManager.scala @@ -5,16 +5,22 @@ package solvers package templates import purescala.Common._ +import purescala.Definitions._ import purescala.Expressions._ +import purescala.Constructors._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps.bestRealType import utils._ +import utils.SeqUtils._ import Instantiation._ import Template._ -case class App[T](caller: T, tpe: FunctionType, args: Seq[Arg[T]]) { +import scala.collection.mutable.{Map => MutableMap, Set => MutableSet} + +case class App[T](caller: T, tpe: FunctionType, args: Seq[Arg[T]], encoded: T) { override def toString = "(" + caller + " : " + tpe + ")" + args.map(_.encoded).mkString("(", ",", ")") } @@ -39,10 +45,12 @@ object LambdaTemplate { val id = ids._2 val tpe = ids._1.getType.asInstanceOf[FunctionType] - val (clauses, blockers, applications, matchers, templateString) = - Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, + val (clauses, blockers, applications, functions, matchers, templateString) = + Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, quantifications, substMap = baseSubstMap + ids, optApp = Some(id -> tpe)) + assert(functions.isEmpty, "Only synthetic type explorers should introduce functions!") + val lambdaString : () => String = () => { "Template for lambda " + ids._1 + ": " + lambda + " is :\n" + templateString() } @@ -63,11 +71,11 @@ object LambdaTemplate { clauses, blockers, applications, - quantifications, - matchers, lambdas, + matchers, + quantifications, keyDeps, - key, + key -> structSubst, lambdaString ) } @@ -75,7 +83,7 @@ object LambdaTemplate { trait KeyedTemplate[T, E <: Expr] { val dependencies: Map[Identifier, T] - val structuralKey: E + val structure: E lazy val key: (E, Seq[T]) = { def rec(e: Expr): Seq[Identifier] = e match { @@ -91,7 +99,7 @@ trait KeyedTemplate[T, E <: Expr] { case _ => Seq.empty } - structuralKey -> rec(structuralKey).distinct.map(dependencies) + structure -> rec(structure).map(dependencies) } } @@ -107,15 +115,17 @@ class LambdaTemplate[T] private ( val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], - val quantifications: Seq[QuantificationTemplate[T]], - val matchers: Map[T, Set[Matcher[T]]], val lambdas: Seq[LambdaTemplate[T]], + val matchers: Map[T, Set[Matcher[T]]], + val quantifications: Seq[QuantificationTemplate[T]], val dependencies: Map[Identifier, T], - val structuralKey: Lambda, + val struct: (Lambda, Map[Identifier, Identifier]), stringRepr: () => String) extends Template[T] with KeyedTemplate[T, Lambda] { val args = arguments.map(_._2) - val tpe = ids._1.getType.asInstanceOf[FunctionType] + val tpe = bestRealType(ids._1.getType).asInstanceOf[FunctionType] + val functions: Set[(T, FunctionType, T)] = Set.empty + val (structure, structSubst) = struct def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): LambdaTemplate[T] = { val newStart = substituter(start) @@ -135,14 +145,14 @@ class LambdaTemplate[T] private ( )) } - val newQuantifications = quantifications.map(_.substitute(substituter, matcherSubst)) + val newLambdas = lambdas.map(_.substitute(substituter, matcherSubst)) val newMatchers = matchers.map { case (b, ms) => val bp = if (b == start) newStart else b bp -> ms.map(_.substitute(substituter, matcherSubst)) } - val newLambdas = lambdas.map(_.substitute(substituter, matcherSubst)) + val newQuantifications = quantifications.map(_.substitute(substituter, matcherSubst)) val newDependencies = dependencies.map(p => p._1 -> substituter(p._2)) @@ -158,11 +168,11 @@ class LambdaTemplate[T] private ( newClauses, newBlockers, newApplications, - newQuantifications, - newMatchers, newLambdas, + newMatchers, + newQuantifications, newDependencies, - structuralKey, + struct, stringRepr ) } @@ -172,8 +182,8 @@ class LambdaTemplate[T] private ( new LambdaTemplate[T]( ids._1 -> idT, encoder, manager, pathVar, arguments, condVars, exprVars, condTree, clauses map substituter, // make sure the body-defining clause is inlined! - blockers, applications, quantifications, matchers, lambdas, - dependencies, structuralKey, stringRepr + blockers, applications, lambdas, matchers, quantifications, + dependencies, struct, stringRepr ) } @@ -185,27 +195,90 @@ class LambdaTemplate[T] private ( } } -class LambdaManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(encoder) { +class LambdaManager[T](encoder: TemplateEncoder[T]) extends DatatypeManager(encoder) { private[templates] lazy val trueT = encoder.encodeExpr(Map.empty)(BooleanLiteral(true)) protected[templates] val byID = new IncrementalMap[T, LambdaTemplate[T]] protected val byType = new IncrementalMap[FunctionType, Map[(Expr, Seq[T]), LambdaTemplate[T]]].withDefaultValue(Map.empty) protected val applications = new IncrementalMap[FunctionType, Set[(T, App[T])]].withDefaultValue(Set.empty) - protected val freeLambdas = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) + protected val knownFree = new IncrementalMap[FunctionType, Set[T]].withDefaultValue(Set.empty) + protected val maybeFree = new IncrementalMap[FunctionType, Set[(T, T)]].withDefaultValue(Set.empty) + protected val freeBlockers = new IncrementalMap[FunctionType, Set[(T, T)]].withDefaultValue(Set.empty) private val instantiated = new IncrementalSet[(T, App[T])] override protected def incrementals: List[IncrementalState] = - super.incrementals ++ List(byID, byType, applications, freeLambdas, instantiated) + super.incrementals ++ List(byID, byType, applications, knownFree, maybeFree, freeBlockers, instantiated) + + def registerFunction(b: T, tpe: FunctionType, f: T): Instantiation[T] = { + val ft = bestRealType(tpe).asInstanceOf[FunctionType] + val bs = fixpoint((bs: Set[T]) => bs ++ bs.flatMap(blockerParents))(Set(b)) + + val (known, neqClauses) = if ((bs intersect typeEnablers).nonEmpty) { + maybeFree += ft -> (maybeFree(ft) + (b -> f)) + (false, byType(ft).values.toSeq.map { t => + encoder.mkImplies(b, encoder.mkNot(encoder.mkEquals(t.ids._2, f))) + }) + } else { + knownFree += ft -> (knownFree(ft) + f) + (true, byType(ft).values.toSeq.map(t => encoder.mkNot(encoder.mkEquals(t.ids._2, f)))) + } - def registerFree(lambdas: Seq[(Identifier, T)]): Unit = { - for ((id, idT) <- lambdas) id.getType match { - case ft: FunctionType => - freeLambdas += ft -> (freeLambdas(ft) + idT) - case _ => + val extClauses = freeBlockers(tpe).map { case (oldB, freeF) => + val equals = encoder.mkEquals(f, freeF) + val nextB = encoder.encodeId(FreshIdentifier("b_or", BooleanType, true)) + val extension = encoder.mkOr(if (known) equals else encoder.mkAnd(b, equals), nextB) + encoder.mkEquals(oldB, extension) + } + + val instantiation = Instantiation.empty[T] withClauses (neqClauses ++ extClauses) + + applications(tpe).foldLeft(instantiation) { + case (instantiation, (app @ (_, App(caller, _, args, _)))) => + val equals = encoder.mkAnd(b, encoder.mkEquals(f, caller)) + instantiation withApp (app -> TemplateAppInfo(f, equals, args)) } } + def assumptions: Seq[T] = freeBlockers.toSeq.flatMap(_._2.map(p => encoder.mkNot(p._1))) + + private val typeBlockers = new IncrementalMap[T, T]() + private val typeEnablers: MutableSet[T] = MutableSet.empty + + private def typeUnroller(blocker: T, app: App[T]): Instantiation[T] = typeBlockers.get(app.encoded) match { + case Some(typeBlocker) => + implies(blocker, typeBlocker) + (Seq(encoder.mkImplies(blocker, typeBlocker)), Map.empty, Map.empty) + + case None => + val App(caller, tpe @ FunctionType(_, to), args, value) = app + val typeBlocker = encoder.encodeId(FreshIdentifier("t", BooleanType)) + typeBlockers += value -> typeBlocker + implies(blocker, typeBlocker) + + val template = typeTemplate(to) + val instantiation = template.instantiate(typeBlocker, value) + + val (b, extClauses) = if (knownFree(tpe) contains caller) { + (blocker, Seq.empty) + } else { + val firstB = encoder.encodeId(FreshIdentifier("b_free", BooleanType, true)) + implies(firstB, typeBlocker) + typeEnablers += firstB + + val nextB = encoder.encodeId(FreshIdentifier("b_or", BooleanType, true)) + freeBlockers += tpe -> (freeBlockers(tpe) + (nextB -> caller)) + + val clause = encoder.mkEquals(firstB, encoder.mkAnd(blocker, encoder.mkOr( + knownFree(tpe).map(idT => encoder.mkEquals(caller, idT)).toSeq ++ + maybeFree(tpe).map { case (b, idT) => encoder.mkAnd(b, encoder.mkEquals(caller, idT)) } :+ + nextB : _*))) + (firstB, Seq(clause)) + } + + instantiation withClauses extClauses withClause encoder.mkImplies(b, typeBlocker) + } + def instantiateLambda(template: LambdaTemplate[T]): (T, Instantiation[T]) = { byType(template.tpe).get(template.key) match { case Some(template) => @@ -215,51 +288,60 @@ class LambdaManager[T](encoder: TemplateEncoder[T]) extends TemplateManager(enco val idT = encoder.encodeId(template.ids._1) val newTemplate = template.withId(idT) - var clauses : Clauses[T] = equalityClauses(newTemplate) - var appBlockers : AppBlockers[T] = Map.empty.withDefaultValue(Set.empty) - // make sure the new lambda isn't equal to any free lambda var - clauses ++= freeLambdas(newTemplate.tpe).map(pIdT => encoder.mkNot(encoder.mkEquals(idT, pIdT))) + val instantiation = Instantiation.empty[T] withClauses ( + equalityClauses(newTemplate) ++ + knownFree(newTemplate.tpe).map(f => encoder.mkNot(encoder.mkEquals(idT, f))).toSeq ++ + maybeFree(newTemplate.tpe).map { p => + encoder.mkImplies(p._1, encoder.mkNot(encoder.mkEquals(idT, p._2))) + }) byID += idT -> newTemplate byType += newTemplate.tpe -> (byType(newTemplate.tpe) + (newTemplate.key -> newTemplate)) - for (blockedApp @ (_, App(caller, tpe, args)) <- applications(newTemplate.tpe)) { - val equals = encoder.mkEquals(idT, caller) - appBlockers += (blockedApp -> (appBlockers(blockedApp) + TemplateAppInfo(newTemplate, equals, args))) + val inst = applications(newTemplate.tpe).foldLeft(instantiation) { + case (instantiation, app @ (_, App(caller, _, args, _))) => + val equals = encoder.mkEquals(idT, caller) + instantiation withApp (app -> TemplateAppInfo(newTemplate, equals, args)) } - (idT, (clauses, Map.empty, appBlockers)) + (idT, inst) } } def instantiateApp(blocker: T, app: App[T]): Instantiation[T] = { - val App(caller, tpe, args) = app - val instantiation = Instantiation.empty[T] + val App(caller, tpe @ FunctionType(_, to), args, encoded) = app - if (freeLambdas(tpe).contains(caller)) instantiation else { - val key = blocker -> app + val instantiation: Instantiation[T] = if (byID contains caller) { + Instantiation.empty + } else { + typeUnroller(blocker, app) + } - if (instantiated(key)) instantiation else { - instantiated += key + val key = blocker -> app + if (instantiated(key)) { + instantiation + } else { + instantiated += key - if (byID contains caller) { - instantiation withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) - } else { + if (knownFree(tpe) contains caller) { + instantiation withApp (key -> TemplateAppInfo(caller, trueT, args)) + } else if (byID contains caller) { + instantiation withApp (key -> TemplateAppInfo(byID(caller), trueT, args)) + } else { - // make sure that even if byType(tpe) is empty, app is recorded in blockers - // so that UnrollingBank will generate the initial block! - val init = instantiation withApps Map(key -> Set.empty) - val inst = byType(tpe).values.foldLeft(init) { - case (instantiation, template) => - val equals = encoder.mkEquals(template.ids._2, caller) - instantiation withApp (key -> TemplateAppInfo(template, equals, args)) - } + // make sure that even if byType(tpe) is empty, app is recorded in blockers + // so that UnrollingBank will generate the initial block! + val init = instantiation withApps Map(key -> Set.empty) + val inst = byType(tpe).values.foldLeft(init) { + case (instantiation, template) => + val equals = encoder.mkEquals(template.ids._2, caller) + instantiation withApp (key -> TemplateAppInfo(template, equals, args)) + } - applications += tpe -> (applications(tpe) + key) + applications += tpe -> (applications(tpe) + key) - inst - } + inst } } } diff --git a/src/main/scala/leon/solvers/templates/QuantificationManager.scala b/src/main/scala/leon/solvers/templates/QuantificationManager.scala index bc31267c09d8023390ac541b0cb99ecece6691ea..85a2add2721a8b03cb3399bdbc73d796d74853e4 100644 --- a/src/main/scala/leon/solvers/templates/QuantificationManager.scala +++ b/src/main/scala/leon/solvers/templates/QuantificationManager.scala @@ -11,6 +11,7 @@ import purescala.Constructors._ import purescala.Expressions._ import purescala.ExprOps._ import purescala.Types._ +import purescala.TypeOps._ import purescala.Quantification.{QuantificationTypeMatcher => QTM} import Instantiation._ @@ -54,8 +55,9 @@ class QuantificationTemplate[T]( val matchers: Map[T, Set[Matcher[T]]], val lambdas: Seq[LambdaTemplate[T]], val dependencies: Map[Identifier, T], - val structuralKey: Forall) extends KeyedTemplate[T, Forall] { + val struct: (Forall, Map[Identifier, Identifier])) extends KeyedTemplate[T, Forall] { + val structure = struct._1 lazy val start = pathVar._2 def substitute(substituter: T => T, matcherSubst: Map[T, Matcher[T]]): QuantificationTemplate[T] = { @@ -87,7 +89,7 @@ class QuantificationTemplate[T]( }, lambdas.map(_.substitute(substituter, matcherSubst)), dependencies.map { case (id, value) => id -> substituter(value) }, - structuralKey + struct ) } } @@ -116,8 +118,8 @@ object QuantificationTemplate { val insts: (Identifier, T) = inst -> encoder.encodeId(inst) val guards: (Identifier, T) = guard -> encoder.encodeId(guard) - val (clauses, blockers, applications, matchers, _) = - Template.encode(encoder, pathVar, quantifiers, condVars, exprVars, guardedExprs, lambdas, + val (clauses, blockers, applications, functions, matchers, _) = + Template.encode(encoder, pathVar, quantifiers, condVars, exprVars, guardedExprs, lambdas, Seq.empty, substMap = baseSubstMap + q2s + insts + guards + qs) val (structuralQuant, structSubst) = normalizeStructure(proposition) @@ -126,23 +128,27 @@ object QuantificationTemplate { new QuantificationTemplate[T](quantificationManager, pathVar, qs, q2s, insts, guards._2, quantifiers, condVars, exprVars, condTree, - clauses, blockers, applications, matchers, lambdas, keyDeps, key) + clauses, blockers, applications, matchers, lambdas, keyDeps, key -> structSubst) } } class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManager[T](encoder) { - private val quantifications = new IncrementalSeq[MatcherQuantification] + private[solvers] val quantifications = new IncrementalSeq[MatcherQuantification] + private val instCtx = new InstantiationContext - private val handled = new ContextMap - private val ignored = new ContextMap + private val ignoredMatchers = new IncrementalSeq[(Int, Set[T], Matcher[T])] + private val ignoredSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Int, Set[T], Map[T,Arg[T]])]] + private val handledSubsts = new IncrementalMap[MatcherQuantification, MutableSet[(Set[T], Map[T,Arg[T]])]] - private val known = new IncrementalSet[T] - private val lambdaAxioms = new IncrementalSet[(LambdaTemplate[T], Seq[(Identifier, T)])] + private val lambdaAxioms = new IncrementalSet[((Expr, Seq[T]), Seq[(Identifier, T)])] private val templates = new IncrementalMap[(Expr, Seq[T]), T] override protected def incrementals: List[IncrementalState] = - List(quantifications, instCtx, handled, ignored, known, lambdaAxioms, templates) ++ super.incrementals + List(quantifications, instCtx, ignoredMatchers, ignoredSubsts, + handledSubsts, lambdaAxioms, templates) ++ super.incrementals + + private var currentGen = 0 private sealed abstract class MatcherKey(val tpe: TypeTree) private case class CallerKey(caller: T, tt: TypeTree) extends MatcherKey(tt) @@ -150,8 +156,8 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage private case class TypeKey(tt: TypeTree) extends MatcherKey(tt) private def matcherKey(caller: T, tpe: TypeTree): MatcherKey = tpe match { - case _: FunctionType if known(caller) => CallerKey(caller, tpe) - case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structuralKey, tpe) + case ft: FunctionType if knownFree(ft)(caller) => CallerKey(caller, tpe) + case _: FunctionType if byID.isDefinedAt(caller) => LambdaKey(byID(caller).structure, tpe) case _ => TypeKey(tpe) } @@ -159,55 +165,56 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage private def correspond(qm: Matcher[T], m: Matcher[T]): Boolean = correspond(qm, m.caller, m.tpe) - @inline - private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = - matcherKey(qm.caller, qm.tpe) == matcherKey(caller, tpe) + private def correspond(qm: Matcher[T], caller: T, tpe: TypeTree): Boolean = { + val qkey = matcherKey(qm.caller, qm.tpe) + val key = matcherKey(caller, tpe) + qkey == key || (qkey.tpe == key.tpe && (qkey.isInstanceOf[TypeKey] || key.isInstanceOf[TypeKey])) + } private val uniformQuantMap: MutableMap[TypeTree, Seq[T]] = MutableMap.empty private val uniformQuantSet: MutableSet[T] = MutableSet.empty def isQuantifier(idT: T): Boolean = uniformQuantSet(idT) - private def uniformSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { - qs.groupBy(_._1.getType).flatMap { case (tpe, qst) => + def uniformQuants(ids: Seq[Identifier]): Seq[T] = { + val mapping = ids.groupBy(id => bestRealType(id.getType)).flatMap { case (tpe, idst) => val prev = uniformQuantMap.get(tpe) match { case Some(seq) => seq case None => Seq.empty } - if (prev.size >= qst.size) { - qst.map(_._2) zip prev.take(qst.size) + if (prev.size >= idst.size) { + idst zip prev.take(idst.size) } else { - val (handled, newQs) = qst.splitAt(prev.size) - val uQs = newQs.map(p => p._2 -> encoder.encodeId(p._1)) + val (handled, newIds) = idst.splitAt(prev.size) + val uIds = newIds.map(id => id -> encoder.encodeId(id)) - uniformQuantMap(tpe) = prev ++ uQs.map(_._2) - uniformQuantSet ++= uQs.map(_._2) + uniformQuantMap(tpe) = prev ++ uIds.map(_._2) + uniformQuantSet ++= uIds.map(_._2) - (handled.map(_._2) zip prev) ++ uQs + (handled zip prev) ++ uIds } }.toMap + + ids.map(mapping) } - def assumptions: Seq[T] = quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq + private def uniformSubst(qs: Seq[(Identifier, T)]): Map[T, T] = { + (qs.map(_._2) zip uniformQuants(qs.map(_._1))).toMap + } - def instantiations: (Map[TypeTree, Matchers], Map[T, Matchers], Map[Lambda, Matchers]) = { - var typeInsts: Map[TypeTree, Matchers] = Map.empty - var partialInsts: Map[T, Matchers] = Map.empty - var lambdaInsts: Map[Lambda, Matchers] = Map.empty + override def assumptions: Seq[T] = super.assumptions ++ + quantifications.collect { case q: Quantification => q.currentQ2Var }.toSeq - val instantiations = handled.instantiations ++ instCtx.map.instantiations - for ((key, matchers) <- instantiations) key match { - case TypeKey(tpe) => typeInsts += tpe -> matchers - case CallerKey(caller, _) => partialInsts += caller -> matchers - case LambdaKey(lambda, _) => lambdaInsts += lambda -> matchers - } + def typeInstantiations: Map[TypeTree, Matchers] = instCtx.map.instantiations.collect { + case (TypeKey(tpe), matchers) => tpe -> matchers + } - (typeInsts, partialInsts, lambdaInsts) + def lambdaInstantiations: Map[Lambda, Matchers] = instCtx.map.instantiations.collect { + case (LambdaKey(lambda, tpe), matchers) => lambda -> (matchers ++ instCtx.map.get(TypeKey(tpe)).toMatchers) } - override def registerFree(ids: Seq[(Identifier, T)]): Unit = { - super.registerFree(ids) - known ++= ids.map(_._2) + def partialInstantiations: Map[T, Matchers] = instCtx.map.instantiations.collect { + case (CallerKey(caller, tpe), matchers) => caller -> (matchers ++ instCtx.map.get(TypeKey(tpe)).toMatchers) } private def matcherDepth(m: Matcher[T]): Int = 1 + (0 +: m.args.map { @@ -229,7 +236,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } def +(p: (Set[T], Matcher[T])): Context = if (apply(p)) this else { - val prev = ctx.getOrElse(p._2, Seq.empty) + val prev = ctx.getOrElse(p._2, Set.empty) val newSet = prev.filterNot(set => p._1.subsetOf(set)).toSet + p._1 new Context(ctx + (p._2 -> newSet)) } @@ -282,7 +289,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage def get(key: MatcherKey): Context = key match { case TypeKey(tpe) => tpeMap.getOrElse(tpe, new Context) - case key => funMap.getOrElse(key, new Context) + case key => funMap.getOrElse(key, new Context) ++ tpeMap.getOrElse(key.tpe, new Context) } def instantiations: Map[MatcherKey, Matchers] = @@ -339,9 +346,9 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } } - private trait MatcherQuantification { + private[solvers] trait MatcherQuantification { val pathVar: (Identifier, T) - val quantified: Set[T] + val quantifiers: Seq[(Identifier, T)] val matchers: Set[Matcher[T]] val allMatchers: Map[T, Set[Matcher[T]]] val condVars: Map[Identifier, T] @@ -352,6 +359,10 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val applications: Map[T, Set[App[T]]] val lambdas: Seq[LambdaTemplate[T]] + val holds: T + val body: Expr + + lazy val quantified: Set[T] = quantifiers.map(_._2).toSet lazy val start = pathVar._2 private lazy val depth = matchers.map(matcherDepth).max @@ -411,23 +422,26 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage private def extractSubst(mapping: Set[(Set[T], Matcher[T], Matcher[T])]): (Set[T], Map[T,Arg[T]], Boolean) = { var constraints: Set[T] = Set.empty var eqConstraints: Set[(T, T)] = Set.empty - var matcherEqs: List[(T, T)] = Nil var subst: Map[T, Arg[T]] = Map.empty + var matcherEqs: Set[(T, T)] = Set.empty + def strictnessCnstr(qarg: Arg[T], arg: Arg[T]): Unit = (qarg, arg) match { + case (Right(qam), Right(am)) => (qam.args zip am.args).foreach(p => strictnessCnstr(p._1, p._2)) + case _ => matcherEqs += qarg.encoded -> arg.encoded + } + for { (bs, qm @ Matcher(qcaller, _, qargs, _), m @ Matcher(caller, _, args, _)) <- mapping _ = constraints ++= bs - _ = matcherEqs :+= qm.encoded -> m.encoded (qarg, arg) <- (qargs zip args) + _ = strictnessCnstr(qarg, arg) } qarg match { case Left(quant) if subst.isDefinedAt(quant) => eqConstraints += (quant -> arg.encoded) case Left(quant) if quantified(quant) => subst += quant -> arg case Right(qam) => - val argVal = arg.encoded - eqConstraints += (qam.encoded -> argVal) - matcherEqs :+= qam.encoded -> argVal + eqConstraints += (qam.encoded -> arg.encoded) } val substituter = encoder.substitute(subst.mapValues(_.encoded)) @@ -445,43 +459,68 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage for (mapping <- mappings(bs, matcher)) { val (enablers, subst, isStrict) = extractSubst(mapping) - val (enabler, optEnabler) = freshBlocker(enablers) - - val baseSubst = subst ++ instanceSubst(enablers).mapValues(Left(_)) - val (substMap, inst) = Template.substitution(encoder, QuantificationManager.this, - condVars, exprVars, condTree, Seq.empty, lambdas, baseSubst, pathVar._1, enabler) - if (!skip(substMap)) { - if (optEnabler.isDefined) { - instantiation = instantiation withClause encoder.mkEquals(enabler, optEnabler.get) + if (!skip(subst)) { + if (!isStrict) { + ignoreSubst(enablers, subst) + } else { + instantiation ++= instantiateSubst(enablers, subst, strict = true) } + } + } - instantiation ++= inst - instantiation ++= Template.instantiate(encoder, QuantificationManager.this, - clauses, blockers, applications, Seq.empty, Map.empty[T, Set[Matcher[T]]], lambdas, substMap) + instantiation + } - val msubst = substMap.collect { case (c, Right(m)) => c -> m } - val substituter = encoder.substitute(substMap.mapValues(_.encoded)) + def instantiateSubst(enablers: Set[T], subst: Map[T, Arg[T]], strict: Boolean = false): Instantiation[T] = { + if (handledSubsts(this)(enablers -> subst)) { + Instantiation.empty[T] + } else { + handledSubsts(this) += enablers -> subst - for ((b,ms) <- allMatchers; m <- ms) { - val sb = enablers ++ (if (b == start) Set.empty else Set(substituter(b))) - val sm = m.substitute(substituter, msubst) + var instantiation = Instantiation.empty[T] + val (enabler, optEnabler) = freshBlocker(enablers) + if (optEnabler.isDefined) { + instantiation = instantiation withClause encoder.mkEquals(enabler, optEnabler.get) + } - if (matchers(m)) { - handled += sb -> sm - } else if (transMatchers(m) && isStrict) { - instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) - } else { - ignored += sb -> sm - } + val baseSubst = subst ++ instanceSubst(enabler).mapValues(Left(_)) + val (substMap, inst) = Template.substitution[T](encoder, QuantificationManager.this, + condVars, exprVars, condTree, Seq.empty, lambdas, Set.empty, baseSubst, pathVar._1, enabler) + instantiation ++= inst + + val msubst = substMap.collect { case (c, Right(m)) => c -> m } + val substituter = encoder.substitute(substMap.mapValues(_.encoded)) + instantiation ++= Template.instantiate(encoder, QuantificationManager.this, + clauses, blockers, applications, Map.empty, substMap) + + for ((b,ms) <- allMatchers; m <- ms) { + val sb = enablers ++ (if (b == start) Set.empty else Set(substituter(b))) + val sm = m.substitute(substituter, msubst) + + if (strict && (matchers(m) || transMatchers(m))) { + instantiation ++= instCtx.instantiate(sb, sm)(quantifications.toSeq : _*) + } else if (!matchers(m)) { + ignoredMatchers += ((currentGen + 3, sb, sm)) } } + + instantiation } + } - instantiation + def ignoreSubst(enablers: Set[T], subst: Map[T, Arg[T]]): Unit = { + val msubst = subst.collect { case (c, Right(m)) => c -> m } + val substituter = encoder.substitute(subst.mapValues(_.encoded)) + val nextGen = (if (matchers.forall { m => + val sm = m.substitute(substituter, msubst) + instCtx(enablers -> sm) + }) currentGen + 3 else currentGen + 3) + + ignoredSubsts(this) += ((nextGen, enablers, subst)) } - protected def instanceSubst(enablers: Set[T]): Map[T, T] + protected def instanceSubst(enabler: T): Map[T, T] protected def skip(subst: Map[T, Arg[T]]): Boolean = false } @@ -492,7 +531,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val q2s: (Identifier, T), val insts: (Identifier, T), val guardVar: T, - val quantified: Set[T], + val quantifiers: Seq[(Identifier, T)], val matchers: Set[Matcher[T]], val allMatchers: Map[T, Set[Matcher[T]]], val condVars: Map[Identifier, T], @@ -501,14 +540,21 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], - val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification { + val lambdas: Seq[LambdaTemplate[T]], + val template: QuantificationTemplate[T]) extends MatcherQuantification { var currentQ2Var: T = qs._2 + val holds = qs._2 + val body = { + val quantified = quantifiers.map(_._1).toSet + val mapping = template.struct._2.map(p => p._2 -> p._1.toVariable) + replaceFromIDs(mapping, template.structure.body) + } - protected def instanceSubst(enablers: Set[T]): Map[T, T] = { + protected def instanceSubst(enabler: T): Map[T, T] = { val nextQ2Var = encoder.encodeId(q2s._1) - val subst = Map(qs._2 -> currentQ2Var, guardVar -> encodeEnablers(enablers), + val subst = Map(qs._2 -> currentQ2Var, guardVar -> enabler, q2s._2 -> nextQ2Var, insts._2 -> encoder.encodeId(insts._1)) currentQ2Var = nextQ2Var @@ -517,19 +563,26 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage } private lazy val blockerId = FreshIdentifier("blocker", BooleanType, true) - private lazy val blockerCache: MutableMap[T, T] = MutableMap.empty + private lazy val enablersToBlocker: MutableMap[Set[T], T] = MutableMap.empty + private lazy val blockerToEnablers: MutableMap[T, Set[T]] = MutableMap.empty private def freshBlocker(enablers: Set[T]): (T, Option[T]) = enablers.toSeq match { case Seq(b) if isBlocker(b) => (b, None) case _ => - val enabler = encodeEnablers(enablers) - blockerCache.get(enabler) match { + val last = enablersToBlocker.get(enablers).orElse { + val initialEnablers = enablers.flatMap(e => blockerToEnablers.getOrElse(e, Set(e))) + enablersToBlocker.get(initialEnablers) + } + + last match { case Some(b) => (b, None) case None => val nb = encoder.encodeId(blockerId) - blockerCache += enabler -> nb + enablersToBlocker += enablers -> nb + blockerToEnablers += nb -> enablers for (b <- enablers if isBlocker(b)) implies(b, nb) blocker(nb) - (nb, Some(enabler)) + + (nb, Some(encodeEnablers(enablers))) } } @@ -537,7 +590,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val pathVar: (Identifier, T), val blocker: T, val guardVar: T, - val quantified: Set[T], + val quantifiers: Seq[(Identifier, T)], val matchers: Set[Matcher[T]], val allMatchers: Map[T, Set[Matcher[T]]], val condVars: Map[Identifier, T], @@ -546,13 +599,19 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], - val lambdas: Seq[LambdaTemplate[T]]) extends MatcherQuantification { + val lambdas: Seq[LambdaTemplate[T]], + val template: LambdaTemplate[T]) extends MatcherQuantification { - protected def instanceSubst(enablers: Set[T]): Map[T, T] = { - // no need to add an equality clause here since it is already contained in the Axiom clauses - val (newBlocker, optEnabler) = freshBlocker(enablers) - val guardT = if (optEnabler.isDefined) encoder.mkAnd(start, optEnabler.get) else start - Map(guardVar -> guardT, blocker -> newBlocker) + val holds = start + + val body = { + val quantified = quantifiers.map(_._1).toSet + val mapping = template.structSubst.map(p => p._2 -> p._1.toVariable) + replaceFromIDs(mapping, template.structure) + } + + protected def instanceSubst(enabler: T): Map[T, T] = { + Map(guardVar -> start, blocker -> enabler) } override protected def skip(subst: Map[T, Arg[T]]): Boolean = { @@ -588,15 +647,90 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage (m: Matcher[T]) => m.args.collect { case Left(a) if quantified(a) => a }.toSet) } + private def instantiateConstants(quantifiers: Seq[(Identifier, T)], matchers: Set[Matcher[T]]): Instantiation[T] = { + val quantifierSubst = uniformSubst(quantifiers) + val substituter = encoder.substitute(quantifierSubst) + var instantiation: Instantiation[T] = Instantiation.empty + + for { + m <- matchers + sm = m.substitute(substituter, Map.empty) + if !instCtx.corresponding(sm).exists(_._2.args == sm.args) + } { + instantiation ++= instCtx.instantiate(Set.empty, m)(quantifications.toSeq : _*) + instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) + } + + def unifyMatchers(matchers: Seq[Matcher[T]]): Unit = matchers match { + case sm +: others => + for (pm <- others if correspond(pm, sm)) { + val encodedArgs = (sm.args zip pm.args).map(p => p._1.encoded -> p._2.encoded) + val mismatches = encodedArgs.zipWithIndex.collect { + case ((sa, pa), idx) if isQuantifier(sa) && isQuantifier(pa) && sa != pa => (idx, (pa, sa)) + }.toMap + + def extractChains(indexes: Seq[Int], partials: Seq[Seq[Int]]): Seq[Seq[Int]] = indexes match { + case idx +: xs => + val (p1, p2) = mismatches(idx) + val newPartials = Seq(idx) +: partials.map { seq => + if (mismatches(seq.head)._1 == p2) idx +: seq + else if (mismatches(seq.last)._2 == p1) seq :+ idx + else seq + } + + val (closed, remaining) = newPartials.partition { seq => + mismatches(seq.head)._1 == mismatches(seq.last)._2 + } + closed ++ extractChains(xs, partials ++ remaining) + + case _ => Seq.empty + } + + val chains = extractChains(mismatches.keys.toSeq, Seq.empty) + val positions = chains.foldLeft(Map.empty[Int, Int]) { (mapping, seq) => + val res = seq.min + mapping ++ seq.map(i => i -> res) + } + + def extractArgs(args: Seq[Arg[T]]): Seq[Arg[T]] = + (0 until args.size).map(i => args(positions.getOrElse(i, i))) + + instantiation ++= instCtx.instantiate(Set.empty, sm.copy(args = extractArgs(sm.args)))(quantifications.toSeq : _*) + instantiation ++= instCtx.instantiate(Set.empty, pm.copy(args = extractArgs(pm.args)))(quantifications.toSeq : _*) + } + + unifyMatchers(others) + + case _ => + } + + val substMatchers = matchers.map(_.substitute(substituter, Map.empty)) + unifyMatchers(substMatchers.toSeq) + + instantiation + } + def instantiateAxiom(template: LambdaTemplate[T], substMap: Map[T, Arg[T]]): Instantiation[T] = { - val quantifiers = template.arguments flatMap { - case (id, idT) => substMap(idT).left.toOption.map(id -> _) - } filter (p => isQuantifier(p._2)) + def quantifiedMatcher(m: Matcher[T]): Boolean = m.args.exists(a => a match { + case Left(v) => isQuantifier(v) + case Right(m) => quantifiedMatcher(m) + }) + + val quantified = template.arguments flatMap { + case (id, idT) => substMap(idT) match { + case Left(v) if isQuantifier(v) => Some(id) + case Right(m) if quantifiedMatcher(m) => Some(id) + case _ => None + } + } + + val quantifiers = quantified zip uniformQuants(quantified) + val key = template.key -> quantifiers - if (quantifiers.isEmpty || lambdaAxioms(template -> quantifiers)) { + if (quantifiers.isEmpty || lambdaAxioms(key)) { Instantiation.empty[T] } else { - lambdaAxioms += template -> quantifiers + lambdaAxioms += key val blockerT = encoder.encodeId(blockerId) val guard = FreshIdentifier("guard", BooleanType, true) @@ -616,80 +750,51 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage val appT = encoder.encodeExpr((template.arguments.map(_._1) zip encArgs.map(_.encoded)).toMap + template.ids)(app) val selfMatcher = Matcher(template.ids._2, template.tpe, encArgs, appT) + val instMatchers = allMatchers + (template.start -> (allMatchers.getOrElse(template.start, Set.empty) + selfMatcher)) + val enablingClause = encoder.mkImplies(guardT, blockerT) - instantiateAxiom( - template.pathVar._1 -> substituter(template.start), - blockerT, - guardT, - quantifiers, - qMatchers, - allMatchers + (template.start -> (allMatchers.getOrElse(template.start, Set.empty) + selfMatcher)), - template.condVars map { case (id, idT) => id -> substituter(idT) }, - template.exprVars map { case (id, idT) => id -> substituter(idT) }, - template.condTree, - (template.clauses map substituter) :+ enablingClause, - template.blockers map { case (b, fis) => - substituter(b) -> fis.map(fi => fi.copy( - args = fi.args.map(_.substitute(substituter, msubst)) - )) - }, - template.applications map { case (b, apps) => - substituter(b) -> apps.map(app => app.copy( - caller = substituter(app.caller), - args = app.args.map(_.substitute(substituter, msubst)) - )) - }, - template.lambdas map (_.substitute(substituter, msubst)) - ) - } - } + val condVars = template.condVars map { case (id, idT) => id -> substituter(idT) } + val exprVars = template.exprVars map { case (id, idT) => id -> substituter(idT) } + val clauses = (template.clauses map substituter) :+ enablingClause + val blockers = template.blockers map { case (b, fis) => + substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(_.substitute(substituter, msubst)))) + } - def instantiateAxiom( - pathVar: (Identifier, T), - blocker: T, - guardVar: T, - quantifiers: Seq[(Identifier, T)], - matchers: Set[Matcher[T]], - allMatchers: Map[T, Set[Matcher[T]]], - condVars: Map[Identifier, T], - exprVars: Map[Identifier, T], - condTree: Map[Identifier, Set[Identifier]], - clauses: Seq[T], - blockers: Map[T, Set[TemplateCallInfo[T]]], - applications: Map[T, Set[App[T]]], - lambdas: Seq[LambdaTemplate[T]] - ): Instantiation[T] = { - val quantified = quantifiers.map(_._2).toSet - val matchQuorums = extractQuorums(quantified, matchers, lambdas) + val applications = template.applications map { case (b, apps) => + substituter(b) -> apps.map(app => app.copy( + caller = substituter(app.caller), + args = app.args.map(_.substitute(substituter, msubst)) + )) + } - var instantiation = Instantiation.empty[T] + val lambdas = template.lambdas map (_.substitute(substituter, msubst)) + + val quantified = quantifiers.map(_._2).toSet + val matchQuorums = extractQuorums(quantified, qMatchers, lambdas) - for (matchers <- matchQuorums) { - val axiom = new LambdaAxiom(pathVar, blocker, guardVar, quantified, - matchers, allMatchers, condVars, exprVars, condTree, - clauses, blockers, applications, lambdas - ) + var instantiation = Instantiation.empty[T] + + for (matchers <- matchQuorums) { + val axiom = new LambdaAxiom(template.pathVar._1 -> substituter(template.start), + blockerT, guardT, quantifiers, matchers, instMatchers, condVars, exprVars, template.condTree, + clauses, blockers, applications, lambdas, template) - quantifications += axiom + quantifications += axiom + handledSubsts += axiom -> MutableSet.empty + ignoredSubsts += axiom -> MutableSet.empty - val newCtx = new InstantiationContext() - for ((b,m) <- instCtx.instantiated) { - instantiation ++= newCtx.instantiate(b, m)(axiom) + val newCtx = new InstantiationContext() + for ((b,m) <- instCtx.instantiated) { + instantiation ++= newCtx.instantiate(b, m)(axiom) + } + instCtx.merge(newCtx) } - instCtx.merge(newCtx) - } - val quantifierSubst = uniformSubst(quantifiers) - val substituter = encoder.substitute(quantifierSubst) + instantiation ++= instantiateConstants(quantifiers, qMatchers) - for { - m <- matchers - sm = m.substitute(substituter, Map.empty) - if !instCtx.corresponding(sm).exists(_._2.args == sm.args) - } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) - - instantiation + instantiation + } } def instantiateQuantification(template: QuantificationTemplate[T]): (T, Instantiation[T]) = { @@ -712,13 +817,14 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage template.pathVar, template.qs._1 -> newQ, template.q2s, template.insts, template.guardVar, - quantified, matchers, template.matchers, + template.quantifiers, matchers, template.matchers, template.condVars, template.exprVars, template.condTree, template.clauses map substituter, // one clause depends on 'q' (and therefore 'newQ') - template.blockers, template.applications, template.lambdas - ) + template.blockers, template.applications, template.lambdas, template) quantifications += quantification + handledSubsts += quantification -> MutableSet.empty + ignoredSubsts += quantification -> MutableSet.empty val newCtx = new InstantiationContext() for ((b,m) <- instCtx.instantiated) { @@ -737,14 +843,7 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage encoder.mkImplies(template.start, encoder.mkEquals(qT, newQs)) } - val quantifierSubst = uniformSubst(template.quantifiers) - val substituter = encoder.substitute(quantifierSubst) - - for { - (_, ms) <- template.matchers; m <- ms - sm = m.substitute(substituter, Map.empty) - if !instCtx.corresponding(sm).exists(_._2.args == sm.args) - } instantiation ++= instCtx.instantiate(Set.empty, sm)(quantifications.toSeq : _*) + instantiation ++= instantiateConstants(template.quantifiers, template.matchers.flatMap(_._2).toSet) templates += template.key -> qT (qT, instantiation) @@ -755,61 +854,128 @@ class QuantificationManager[T](encoder: TemplateEncoder[T]) extends LambdaManage instCtx.instantiate(Set(blocker), matcher)(quantifications.toSeq : _*) } - private type SetDef = (T, (Identifier, T), (Identifier, T), Seq[T], T, T, T) - private val setConstructors: MutableMap[TypeTree, SetDef] = MutableMap.empty + def hasIgnored: Boolean = ignoredSubsts.nonEmpty || ignoredMatchers.nonEmpty + + def instantiateIgnored(force: Boolean = false): Instantiation[T] = { + currentGen = if (!force) currentGen + 1 else { + val gens = ignoredSubsts.toSeq.flatMap(_._2).map(_._1) ++ ignoredMatchers.toSeq.map(_._1) + if (gens.isEmpty) currentGen else gens.min + } + + var instantiation = Instantiation.empty[T] + + val matchersToRelease = ignoredMatchers.toList.flatMap { case e @ (gen, b, m) => + if (gen == currentGen) { + ignoredMatchers -= e + Some(b -> m) + } else { + None + } + } + + for ((bs,m) <- matchersToRelease) { + instCtx.instantiate(bs, m)(quantifications.toSeq : _*) + } + + val substsToRelease = quantifications.toList.flatMap { q => + val qsubsts = ignoredSubsts(q) + qsubsts.toList.flatMap { case e @ (gen, enablers, subst) => + if (gen == currentGen) { + qsubsts -= e + Some((q, enablers, subst)) + } else { + None + } + } + } + + for ((q, enablers, subst) <- substsToRelease) { + instantiation ++= q.instantiateSubst(enablers, subst, strict = false) + } + + instantiation + } def checkClauses: Seq[T] = { val clauses = new scala.collection.mutable.ListBuffer[T] + //val keySets = scala.collection.mutable.Map.empty[MatcherKey, T] + val keyClause = MutableMap.empty[MatcherKey, (Seq[T], T)] - for ((key, ctx) <- ignored.instantiations) { - val insts = instCtx.map.get(key).toMatchers - + for ((_, bs, m) <- ignoredMatchers) { + val key = matcherKey(m.caller, m.tpe) val QTM(argTypes, _) = key.tpe - val tupleType = tupleTypeWrap(argTypes) - - val (guardT, (setPrev, setPrevT), (setNext, setNextT), elems, containsT, emptyT, setT) = - setConstructors.getOrElse(tupleType, { - val guard = FreshIdentifier("guard", BooleanType) - val setPrev = FreshIdentifier("prevSet", SetType(tupleType)) - val setNext = FreshIdentifier("nextSet", SetType(tupleType)) - val elems = argTypes.map(tpe => FreshIdentifier("elem", tpe)) - - val elemExpr = tupleWrap(elems.map(_.toVariable)) - val contextExpr = And( - Implies(Variable(guard), Equals(Variable(setNext), - SetUnion(Variable(setPrev), FiniteSet(Set(elemExpr), tupleType)))), - Implies(Not(Variable(guard)), Equals(Variable(setNext), Variable(setPrev)))) - - val guardP = guard -> encoder.encodeId(guard) - val setPrevP = setPrev -> encoder.encodeId(setPrev) - val setNextP = setNext -> encoder.encodeId(setNext) - val elemsP = elems.map(e => e -> encoder.encodeId(e)) - - val containsT = encoder.encodeExpr(elemsP.toMap + setPrevP)(ElementOfSet(elemExpr, setPrevP._1.toVariable)) - val emptyT = encoder.encodeExpr(Map.empty)(FiniteSet(Set.empty, tupleType)) - val contextT = encoder.encodeExpr(Map(guardP, setPrevP, setNextP) ++ elemsP)(contextExpr) - - val setDef = (guardP._2, setPrevP, setNextP, elemsP.map(_._2), containsT, emptyT, contextT) - setConstructors += key.tpe -> setDef - setDef - }) - - var prev = emptyT - for ((b, m) <- insts.toSeq) { - val next = encoder.encodeId(setNext) - val argsMap = (elems zip m.args).map { case (idT, arg) => idT -> arg.encoded } - val substMap = Map(guardT -> b, setPrevT -> prev, setNextT -> next) ++ argsMap - prev = next - clauses += encoder.substitute(substMap)(setT) + + val (values, clause) = keyClause.getOrElse(key, { + val insts = instCtx.map.get(key).toMatchers + + val guard = FreshIdentifier("guard", BooleanType) + val elems = argTypes.map(tpe => FreshIdentifier("elem", tpe)) + val values = argTypes.map(tpe => FreshIdentifier("value", tpe)) + val expr = andJoin(Variable(guard) +: (elems zip values).map(p => Equals(Variable(p._1), Variable(p._2)))) + + val guardP = guard -> encoder.encodeId(guard) + val elemsP = elems.map(e => e -> encoder.encodeId(e)) + val valuesP = values.map(v => v -> encoder.encodeId(v)) + val exprT = encoder.encodeExpr(elemsP.toMap ++ valuesP + guardP)(expr) + + val disjuncts = insts.toSeq.map { case (b, im) => + val bp = if (m.caller != im.caller) encoder.mkAnd(encoder.mkEquals(m.caller, im.caller), b) else b + val subst = (elemsP.map(_._2) zip im.args.map(_.encoded)).toMap + (guardP._2 -> bp) + encoder.substitute(subst)(exprT) + } + + val res = (valuesP.map(_._2), encoder.mkOr(disjuncts : _*)) + keyClause += key -> res + res + }) + + val b = encodeEnablers(bs) + val substMap = (values zip m.args.map(_.encoded)).toMap + clauses += encoder.substitute(substMap)(encoder.mkImplies(b, clause)) + } + + for (q <- quantifications) { + val guard = FreshIdentifier("guard", BooleanType) + val elems = q.quantifiers.map(_._1) + val values = elems.map(id => id.freshen) + val expr = andJoin(Variable(guard) +: (elems zip values).map(p => Equals(Variable(p._1), Variable(p._2)))) + + val guardP = guard -> encoder.encodeId(guard) + val elemsP = elems.map(e => e -> encoder.encodeId(e)) + val valuesP = values.map(v => v -> encoder.encodeId(v)) + val exprT = encoder.encodeExpr(elemsP.toMap ++ valuesP + guardP)(expr) + + val disjunction = handledSubsts(q) match { + case set if set.isEmpty => encoder.encodeExpr(Map.empty)(BooleanLiteral(false)) + case set => encoder.mkOr(set.toSeq.map { case (enablers, subst) => + val b = if (enablers.isEmpty) trueT else encoder.mkAnd(enablers.toSeq : _*) + val substMap = (elemsP.map(_._2) zip q.quantifiers.map(p => subst(p._2).encoded)).toMap + (guardP._2 -> b) + encoder.substitute(substMap)(exprT) + } : _*) } - val setMap = Map(setPrevT -> prev) - for ((b, m) <- ctx.toSeq) { - val substMap = setMap ++ (elems zip m.args).map(p => p._1 -> p._2.encoded) - clauses += encoder.substitute(substMap)(encoder.mkImplies(b, containsT)) + for ((_, enablers, subst) <- ignoredSubsts(q)) { + val b = if (enablers.isEmpty) trueT else encoder.mkAnd(enablers.toSeq : _*) + val substMap = (valuesP.map(_._2) zip q.quantifiers.map(p => subst(p._2).encoded)).toMap + clauses += encoder.substitute(substMap)(encoder.mkImplies(b, disjunction)) } } + for ((key, ctx) <- instCtx.map.instantiations) { + val QTM(argTypes, _) = key.tpe + + for { + (tpe,idx) <- argTypes.zipWithIndex + quants <- uniformQuantMap.get(tpe) if quants.nonEmpty + (b, m) <- ctx + arg = m.args(idx).encoded if !isQuantifier(arg) + } clauses += encoder.mkAnd(quants.map(q => encoder.mkNot(encoder.mkEquals(q, arg))) : _*) + } + + for ((tpe, base +: rest) <- uniformQuantMap; q <- rest) { + clauses += encoder.mkEquals(base, q) + } + clauses.toSeq } } diff --git a/src/main/scala/leon/solvers/templates/TemplateEncoder.scala b/src/main/scala/leon/solvers/templates/TemplateEncoder.scala index f3eb2ad28f3bab2e72770f388e6ecbe92524dc54..16d7b3cdc7f695eb7524895fd880de4f23a1083c 100644 --- a/src/main/scala/leon/solvers/templates/TemplateEncoder.scala +++ b/src/main/scala/leon/solvers/templates/TemplateEncoder.scala @@ -18,4 +18,6 @@ trait TemplateEncoder[T] { def mkAnd(ts: T*): T def mkEquals(l: T, r: T): T def mkImplies(l: T, r: T): T + + def extractNot(v: T): Option[T] } diff --git a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala index aa0e5dad6506be945919cb22361cce649f40c077..4556f7dd6507f9a8c8a745b0323dbbdb9cf0032b 100644 --- a/src/main/scala/leon/solvers/templates/TemplateGenerator.scala +++ b/src/main/scala/leon/solvers/templates/TemplateGenerator.scala @@ -9,11 +9,12 @@ import purescala.Expressions._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.TypeOps._ +import purescala.TypeOps.bestRealType import purescala.Definitions._ import purescala.Constructors._ import purescala.Quantification._ +import utils.SeqUtils._ import Instantiation._ class TemplateGenerator[T](val encoder: TemplateEncoder[T], @@ -32,6 +33,16 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], private def emptyClauses: Clauses = (Map.empty, Map.empty, Map.empty, Map.empty, Seq.empty, Seq.empty) + private implicit class ClausesWrapper(clauses: Clauses) { + def ++(that: Clauses): Clauses = { + val (thisConds, thisExprs, thisTree, thisGuarded, thisLambdas, thisQuants) = clauses + val (thatConds, thatExprs, thatTree, thatGuarded, thatLambdas, thatQuants) = that + + (thisConds ++ thatConds, thisExprs ++ thatExprs, thisTree merge thatTree, + thisGuarded merge thatGuarded, thisLambdas ++ thatLambdas, thisQuants ++ thatQuants) + } + } + val manager = new QuantificationManager[T](encoder) def mkTemplate(body: Expr): FunctionTemplate[T] = { @@ -39,8 +50,10 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], return cacheExpr(body) } - val fakeFunDef = new FunDef(FreshIdentifier("fake", alwaysShowUniqueID = true), Nil, variablesOf(body).toSeq.map(ValDef(_)), body.getType) + val arguments = variablesOf(body).toSeq.map(ValDef(_)) + val fakeFunDef = new FunDef(FreshIdentifier("fake", alwaysShowUniqueID = true), Nil, arguments, body.getType) + fakeFunDef.precondition = Some(andJoin(arguments.map(vd => manager.typeUnroller(vd.toVariable)))) fakeFunDef.body = Some(body) val res = mkTemplate(fakeFunDef.typed, false) @@ -54,7 +67,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], } // The precondition if it exists. - val prec : Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p)) + val prec : Option[Expr] = tfd.precondition.map(p => simplifyHOFunctions(matchToIfThenElse(p))) val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b)) val lambdaBody : Option[Expr] = newBody.map(b => simplifyHOFunctions(b)) @@ -63,20 +76,18 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val lambdaArguments: Seq[Identifier] = lambdaBody.map(lambdaArgs).toSeq.flatten val invocation : Expr = FunctionInvocation(tfd, funDefArgs.map(_.toVariable)) - val invocationEqualsBody : Option[Expr] = lambdaBody match { + val invocationEqualsBody : Seq[Expr] = lambdaBody match { case Some(body) if isRealFunDef => - val b : Expr = And( - liftedEquals(invocation, body, lambdaArguments), - Equals(invocation, body)) + val bs = liftedEquals(invocation, body, lambdaArguments) :+ Equals(invocation, body) - Some(if(prec.isDefined) { - Implies(prec.get, b) + if(prec.isDefined) { + bs.map(Implies(prec.get, _)) } else { - b - }) + bs + } case _ => - None + Seq.empty } val start : Identifier = FreshIdentifier("start", BooleanType, true) @@ -87,16 +98,17 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], val substMap : Map[Identifier, T] = arguments.toMap + pathVar + val (bodyConds, bodyExprs, bodyTree, bodyGuarded, bodyLambdas, bodyQuantifications) = if (isRealFunDef) { - invocationEqualsBody.map(expr => mkClauses(start, expr, substMap)).getOrElse(emptyClauses) + invocationEqualsBody.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) } else { - mkClauses(start, lambdaBody.get, substMap) + (prec.toSeq :+ lambdaBody.get).foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(start, cls, substMap)) } // Now the postcondition. val (condVars, exprVars, condTree, guardedExprs, lambdas, quantifications) = tfd.postcondition match { case Some(post) => - val newPost : Expr = application(matchToIfThenElse(post), Seq(invocation)) + val newPost : Expr = simplifyHOFunctions(application(matchToIfThenElse(post), Seq(invocation))) val postHolds : Expr = if(tfd.hasPrecondition) { @@ -128,7 +140,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], case _ => Seq.empty } - private def liftedEquals(invocation: Expr, body: Expr, args: Seq[Identifier], inlineFirst: Boolean = false): Expr = { + private def liftedEquals(invocation: Expr, body: Expr, args: Seq[Identifier], inlineFirst: Boolean = false): Seq[Expr] = { def rec(i: Expr, b: Expr, args: Seq[Identifier], inline: Boolean): Seq[Expr] = i.getType match { case FunctionType(from, to) => val (currArgs, nextArgs) = args.splitAt(from.size) @@ -141,7 +153,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], Seq.empty } - andJoin(rec(invocation, body, args, inlineFirst)) + rec(invocation, body, args, inlineFirst) } private def minimalFlattening(inits: Set[Identifier], conj: Expr): (Set[Identifier], Expr) = { @@ -206,11 +218,14 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], // Represents clauses of the form: // id => expr && ... && expr var guardedExprs = Map[Identifier, Seq[Expr]]() - def storeGuarded(guardVar : Identifier, expr : Expr) : Unit = { - assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean") + def storeGuarded(guardVar: Identifier, expr: Expr) : Unit = { + assert(expr.getType == BooleanType, expr.asString(Program.empty)(LeonContext.empty) + " is not of type Boolean." + ( + purescala.ExprOps.fold[String]{ (e, se) => + s"$e is of type ${e.getType}" + se.map(child => "\n " + "\n".r.replaceAllIn(child, "\n ")).mkString + }(expr) + )) val prev = guardedExprs.getOrElse(guardVar, Nil) - guardedExprs += guardVar -> (expr +: prev) } @@ -237,25 +252,6 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], }(e) } - def groupWhile[T](es: Seq[T])(p: T => Boolean): Seq[Seq[T]] = { - var res: Seq[Seq[T]] = Nil - - var c = es - while (!c.isEmpty) { - val (span, rest) = c.span(p) - - if (span.isEmpty) { - res :+= Seq(rest.head) - c = rest.tail - } else { - res :+= span - c = rest - } - } - - res - } - def rec(pathVar: Identifier, expr: Expr): Expr = { expr match { case a @ Assert(cond, err, body) => @@ -314,8 +310,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], storeExpr(newExpr) def recAnd(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match { - case x :: Nil if !requireDecomposition(x) => - storeGuarded(pathVar, Equals(Variable(newExpr), x)) + case x :: Nil => + storeGuarded(pathVar, Equals(Variable(newExpr), rec(pathVar, x))) case x :: xs => val newBool : Identifier = FreshIdentifier("b", BooleanType, true) @@ -327,8 +323,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], recAnd(newBool, xs) - case Nil => - storeGuarded(pathVar, Variable(newExpr)) + case Nil => scala.sys.error("Should never happen!") } recAnd(pathVar, seq) @@ -344,8 +339,8 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], storeExpr(newExpr) def recOr(pathVar: Identifier, partitions: Seq[Expr]): Unit = partitions match { - case x :: Nil if !requireDecomposition(x) => - storeGuarded(pathVar, Equals(Variable(newExpr), x)) + case x :: Nil => + storeGuarded(pathVar, Equals(Variable(newExpr), rec(pathVar, x))) case x :: xs => val newBool : Identifier = FreshIdentifier("b", BooleanType, true) @@ -357,8 +352,7 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], recOr(newBool, xs) - case Nil => - storeGuarded(pathVar, Not(Variable(newExpr))) + case Nil => scala.sys.error("Should never happen!") } recOr(pathVar, seq) @@ -394,7 +388,6 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], } case c @ Choose(Lambda(params, cond)) => - val cs = params.map(_.id.freshen.toVariable) for (c <- cs) { @@ -407,16 +400,25 @@ class TemplateGenerator[T](val encoder: TemplateEncoder[T], tupleWrap(cs) + case FiniteLambda(mapping, dflt, FunctionType(from, to)) => + val args = from.map(tpe => FreshIdentifier("x", tpe)) + val body = mapping.toSeq.foldLeft(dflt) { case (elze, (exprs, res)) => + IfExpr(andJoin((args zip exprs).map(p => Equals(Variable(p._1), p._2))), res, elze) + } + + rec(pathVar, Lambda(args.map(ValDef(_)), body)) + case l @ Lambda(args, body) => val idArgs : Seq[Identifier] = lambdaArgs(l) val trArgs : Seq[T] = idArgs.map(id => substMap.getOrElse(id, encoder.encodeId(id))) val lid = FreshIdentifier("lambda", bestRealType(l.getType), true) - val clause = liftedEquals(Variable(lid), l, idArgs, inlineFirst = true) + val clauses = liftedEquals(Variable(lid), l, idArgs, inlineFirst = true) val localSubst: Map[Identifier, T] = substMap ++ condVars ++ exprVars ++ lambdaVars val clauseSubst: Map[Identifier, T] = localSubst ++ (idArgs zip trArgs) - val (lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaTemplates, lambdaQuants) = mkClauses(pathVar, clause, clauseSubst) + val (lambdaConds, lambdaExprs, lambdaTree, lambdaGuarded, lambdaTemplates, lambdaQuants) = + clauses.foldLeft(emptyClauses)((clsSet, cls) => clsSet ++ mkClauses(pathVar, cls, clauseSubst)) val ids: (Identifier, T) = lid -> storeLambda(lid) val dependencies: Map[Identifier, T] = variablesOf(l).map(id => id -> localSubst(id)).toMap diff --git a/src/main/scala/leon/solvers/templates/TemplateInfo.scala b/src/main/scala/leon/solvers/templates/TemplateInfo.scala index 27df9b25d13412c106f2bf30d6c75b1266b19d93..dfdd664cbdf1f110690d05e959bc1a14cc615327 100644 --- a/src/main/scala/leon/solvers/templates/TemplateInfo.scala +++ b/src/main/scala/leon/solvers/templates/TemplateInfo.scala @@ -16,11 +16,24 @@ case class TemplateCallInfo[T](tfd: TypedFunDef, args: Seq[Arg[T]]) { } } -case class TemplateAppInfo[T](template: LambdaTemplate[T], equals: T, args: Seq[Arg[T]]) { +case class TemplateAppInfo[T](template: Either[LambdaTemplate[T], T], equals: T, args: Seq[Arg[T]]) { override def toString = { - template.ids._2 + "|" + equals + args.map { + val caller = template match { + case Left(tmpl) => tmpl.ids._2 + case Right(c) => c + } + + caller + "|" + equals + args.map { case Right(m) => m.toString case Left(v) => v.toString }.mkString("(", ",", ")") } } + +object TemplateAppInfo { + def apply[T](template: LambdaTemplate[T], equals: T, args: Seq[Arg[T]]): TemplateAppInfo[T] = + TemplateAppInfo(Left(template), equals, args) + + def apply[T](caller: T, equals: T, args: Seq[Arg[T]]): TemplateAppInfo[T] = + TemplateAppInfo(Right(caller), equals, args) +} diff --git a/src/main/scala/leon/solvers/templates/TemplateManager.scala b/src/main/scala/leon/solvers/templates/TemplateManager.scala index 2b75f08f0480cf272515bb8d8393e01e29d4dbf1..c81abc384934aea21fb1d5d183925e7c6d013fb7 100644 --- a/src/main/scala/leon/solvers/templates/TemplateManager.scala +++ b/src/main/scala/leon/solvers/templates/TemplateManager.scala @@ -8,19 +8,20 @@ import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ import purescala.Quantification._ +import purescala.Constructors._ import purescala.Extractors._ import purescala.ExprOps._ import purescala.Types._ -import purescala.TypeOps._ +import purescala.TypeOps.bestRealType import utils._ import scala.collection.generic.CanBuildFrom object Instantiation { - type Clauses[T] = Seq[T] - type CallBlockers[T] = Map[T, Set[TemplateCallInfo[T]]] - type AppBlockers[T] = Map[(T, App[T]), Set[TemplateAppInfo[T]]] + type Clauses[T] = Seq[T] + type CallBlockers[T] = Map[T, Set[TemplateCallInfo[T]]] + type AppBlockers[T] = Map[(T, App[T]), Set[TemplateAppInfo[T]]] type Instantiation[T] = (Clauses[T], CallBlockers[T], AppBlockers[T]) def empty[T] = (Seq.empty[T], Map.empty[T, Set[TemplateCallInfo[T]]], Map.empty[(T, App[T]), Set[TemplateAppInfo[T]]]) @@ -33,7 +34,7 @@ object Instantiation { implicit class MapSeqWrapper[A,B](map: Map[A,Seq[B]]) { def merge(that: Map[A,Seq[B]]): Map[A,Seq[B]] = (map.keys ++ that.keys).map { k => - k -> (map.getOrElse(k, Seq.empty) ++ that.getOrElse(k, Seq.empty)) + k -> (map.getOrElse(k, Seq.empty) ++ that.getOrElse(k, Seq.empty)).distinct }.toMap } @@ -61,32 +62,36 @@ import Template.Arg trait Template[T] { self => val encoder : TemplateEncoder[T] - val manager : QuantificationManager[T] + val manager : TemplateManager[T] + + val pathVar : (Identifier, T) + val args : Seq[T] - val pathVar: (Identifier, T) - val args : Seq[T] val condVars : Map[Identifier, T] val exprVars : Map[Identifier, T] val condTree : Map[Identifier, Set[Identifier]] - val clauses : Seq[T] - val blockers : Map[T, Set[TemplateCallInfo[T]]] + + val clauses : Seq[T] + val blockers : Map[T, Set[TemplateCallInfo[T]]] val applications : Map[T, Set[App[T]]] + val functions : Set[(T, FunctionType, T)] + val lambdas : Seq[LambdaTemplate[T]] + val quantifications : Seq[QuantificationTemplate[T]] - val matchers : Map[T, Set[Matcher[T]]] - val lambdas : Seq[LambdaTemplate[T]] + val matchers : Map[T, Set[Matcher[T]]] lazy val start = pathVar._2 def instantiate(aVar: T, args: Seq[Arg[T]]): Instantiation[T] = { val (substMap, instantiation) = Template.substitution(encoder, manager, - condVars, exprVars, condTree, quantifications, lambdas, + condVars, exprVars, condTree, quantifications, lambdas, functions, (this.args zip args).toMap + (start -> Left(aVar)), pathVar._1, aVar) instantiation ++ instantiate(substMap) } protected def instantiate(substMap: Map[T, Arg[T]]): Instantiation[T] = { - Template.instantiate(encoder, manager, - clauses, blockers, applications, quantifications, matchers, lambdas, substMap) + Template.instantiate(encoder, manager, clauses, + blockers, applications, matchers, substMap) } override def toString : String = "Instantiated template" @@ -127,6 +132,9 @@ object Template { Matcher(encodeExpr(caller), bestRealType(caller.getType), arguments.map(arg => Left(encodeExpr(arg))), encodeExpr(app)) } + type Apps[T] = Map[T, Set[App[T]]] + type Functions[T] = Set[(T, FunctionType, T)] + def encode[T]( encoder: TemplateEncoder[T], pathVar: (Identifier, T), @@ -135,34 +143,60 @@ object Template { exprVars: Map[Identifier, T], guardedExprs: Map[Identifier, Seq[Expr]], lambdas: Seq[LambdaTemplate[T]], + quantifications: Seq[QuantificationTemplate[T]], substMap: Map[Identifier, T] = Map.empty[Identifier, T], optCall: Option[TypedFunDef] = None, optApp: Option[(T, FunctionType)] = None - ) : (Seq[T], Map[T, Set[TemplateCallInfo[T]]], Map[T, Set[App[T]]], Map[T, Set[Matcher[T]]], () => String) = { + ) : (Clauses[T], CallBlockers[T], Apps[T], Functions[T], Map[T, Set[Matcher[T]]], () => String) = { val idToTrId : Map[Identifier, T] = - condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) + condVars ++ exprVars + pathVar ++ arguments ++ substMap ++ lambdas.map(_.ids) ++ quantifications.map(_.qs) val encodeExpr : Expr => T = encoder.encodeExpr(idToTrId) - val clauses : Seq[T] = (for ((b,es) <- guardedExprs; e <- es) yield { - encodeExpr(Implies(Variable(b), e)) - }).toSeq + val (clauses, cleanGuarded, functions) = { + var functions: Set[(T, FunctionType, T)] = Set.empty + var clauses: Seq[T] = Seq.empty + + val cleanGuarded = guardedExprs.map { + case (b, es) => b -> es.map { e => + def clean(expr: Expr): Expr = postMap { + case FreshFunction(f) => Some(BooleanLiteral(true)) + case _ => None + } (expr) + + val withPaths = CollectorWithPaths { case FreshFunction(f) => f }.traverse(e) + functions ++= withPaths.map { case (f, TopLevelAnds(paths)) => + val tpe = bestRealType(f.getType).asInstanceOf[FunctionType] + val path = andJoin(paths.map(clean)) + (encodeExpr(and(Variable(b), path)), tpe, encodeExpr(f)) + } + + val cleanExpr = clean(e) + clauses :+= encodeExpr(Implies(Variable(b), cleanExpr)) + cleanExpr + } + } + + (clauses, cleanGuarded, functions) + } val optIdCall = optCall.map(tfd => TemplateCallInfo[T](tfd, arguments.map(p => Left(p._2)))) val optIdApp = optApp.map { case (idT, tpe) => - App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2))) + val id = FreshIdentifier("x", tpe, true) + val encoded = encoder.encodeExpr(Map(id -> idT) ++ arguments)(Application(Variable(id), arguments.map(_._1.toVariable))) + App(idT, bestRealType(tpe).asInstanceOf[FunctionType], arguments.map(p => Left(p._2)), encoded) } lazy val invocMatcher = optCall.filter(_.returnType.isInstanceOf[FunctionType]) .map(tfd => invocationMatcher(encodeExpr)(tfd, arguments.map(_._1.toVariable))) val (blockers, applications, matchers) = { - var blockers : Map[Identifier, Set[TemplateCallInfo[T]]] = Map.empty - var applications : Map[Identifier, Set[App[T]]] = Map.empty - var matchers : Map[Identifier, Set[Matcher[T]]] = Map.empty + var blockers : Map[Identifier, Set[TemplateCallInfo[T]]] = Map.empty + var applications : Map[Identifier, Set[App[T]]] = Map.empty + var matchers : Map[Identifier, Set[Matcher[T]]] = Map.empty - for ((b,es) <- guardedExprs) { + for ((b,es) <- cleanGuarded) { var funInfos : Set[TemplateCallInfo[T]] = Set.empty var appInfos : Set[App[T]] = Set.empty var matchInfos : Set[Matcher[T]] = Set.empty @@ -192,23 +226,22 @@ object Template { funInfos ++= firstOrderCallsOf(e).map(p => TemplateCallInfo(p._1, p._2.map(encodeArg))) appInfos ++= firstOrderAppsOf(e).map { case (c, args) => - App(encodeExpr(c), bestRealType(c.getType).asInstanceOf[FunctionType], args.map(encodeArg)) + val tpe = bestRealType(c.getType).asInstanceOf[FunctionType] + App(encodeExpr(c), tpe, args.map(encodeArg), encodeExpr(Application(c, args))) } matchInfos ++= exprToMatcher.values } - val calls = funInfos -- optIdCall + val calls = funInfos.filter(i => Some(i) != optIdCall) if (calls.nonEmpty) blockers += b -> calls - val apps = appInfos -- optIdApp + val apps = appInfos.filter(i => Some(i) != optIdApp) if (apps.nonEmpty) applications += b -> apps - val matchs = matchInfos.filter { case m @ Matcher(mc, mtpe, margs, _) => - !optIdApp.exists { case App(ac, atpe, aargs) => - mc == ac && mtpe == atpe && margs == aargs - } - } ++ (if (funInfos.exists(info => Some(info) == optIdCall)) invocMatcher else None) + val matchs = (matchInfos.filter { case m @ Matcher(_, _, _, menc) => + !optIdApp.exists { case App(_, _, _, aenc) => menc == aenc } + } ++ (if (funInfos.exists(info => Some(info) == optIdCall)) invocMatcher else None)) if (matchs.nonEmpty) matchers += b -> matchs } @@ -224,8 +257,8 @@ object Template { " * Activating boolean : " + pathVar._1 + "\n" + " * Control booleans : " + condVars.keys.mkString(", ") + "\n" + " * Expression vars : " + exprVars.keys.mkString(", ") + "\n" + - " * Clauses : " + (if (guardedExprs.isEmpty) "\n" else { - "\n " + (for ((b,es) <- guardedExprs; e <- es) yield (b + " ==> " + e)).mkString("\n ") + "\n" + " * Clauses : " + (if (cleanGuarded.isEmpty) "\n" else { + "\n " + (for ((b,es) <- cleanGuarded; e <- es) yield (b + " ==> " + e)).mkString("\n ") + "\n" }) + " * Invocation-blocks :" + (if (blockers.isEmpty) "\n" else { "\n " + blockers.map(p => p._1 + " ==> " + p._2).mkString("\n ") + "\n" @@ -241,17 +274,18 @@ object Template { }.mkString("\n") } - (clauses, encodedBlockers, encodedApps, encodedMatchers, stringRepr) + (clauses, encodedBlockers, encodedApps, functions, encodedMatchers, stringRepr) } def substitution[T]( encoder: TemplateEncoder[T], - manager: QuantificationManager[T], + manager: TemplateManager[T], condVars: Map[Identifier, T], exprVars: Map[Identifier, T], condTree: Map[Identifier, Set[Identifier]], quantifications: Seq[QuantificationTemplate[T]], lambdas: Seq[LambdaTemplate[T]], + functions: Set[(T, FunctionType, T)], baseSubst: Map[T, Arg[T]], pathVar: Identifier, aVar: T @@ -259,40 +293,56 @@ object Template { val freshSubst = exprVars.map { case (id, idT) => idT -> encoder.encodeId(id) } ++ manager.freshConds(pathVar -> aVar, condVars, condTree) val matcherSubst = baseSubst.collect { case (c, Right(m)) => c -> m } - var subst = freshSubst.mapValues(Left(_)) ++ baseSubst - // /!\ CAREFUL /!\ - // We have to be wary while computing the lambda subst map since lambdas can - // depend on each other. However, these dependencies cannot be cyclic so it - // suffices to make sure the traversal order is correct. + var subst = freshSubst.mapValues(Left(_)) ++ baseSubst var instantiation : Instantiation[T] = Instantiation.empty - var seen : Set[LambdaTemplate[T]] = Set.empty - - val lambdaKeys = lambdas.map(lambda => lambda.ids._1 -> lambda).toMap - def extractSubst(lambda: LambdaTemplate[T]): Unit = { - for { - dep <- lambda.dependencies.flatMap(p => lambdaKeys.get(p._1)) - if !seen(dep) - } extractSubst(dep) - - if (!seen(lambda)) { - val substMap = subst.mapValues(_.encoded) - val substLambda = lambda.substitute(encoder.substitute(substMap), matcherSubst) - val (idT, inst) = manager.instantiateLambda(substLambda) - instantiation ++= inst - subst += lambda.ids._2 -> Left(idT) - seen += lambda - } + + manager match { + case lmanager: LambdaManager[T] => + val funSubstituter = encoder.substitute(subst.mapValues(_.encoded)) + for ((b,tpe,f) <- functions) { + instantiation ++= lmanager.registerFunction(funSubstituter(b), tpe, funSubstituter(f)) + } + + // /!\ CAREFUL /!\ + // We have to be wary while computing the lambda subst map since lambdas can + // depend on each other. However, these dependencies cannot be cyclic so it + // suffices to make sure the traversal order is correct. + var seen : Set[LambdaTemplate[T]] = Set.empty + + val lambdaKeys = lambdas.map(lambda => lambda.ids._1 -> lambda).toMap + def extractSubst(lambda: LambdaTemplate[T]): Unit = { + for { + dep <- lambda.dependencies.flatMap(p => lambdaKeys.get(p._1)) + if !seen(dep) + } extractSubst(dep) + + if (!seen(lambda)) { + val substMap = subst.mapValues(_.encoded) + val substLambda = lambda.substitute(encoder.substitute(substMap), matcherSubst) + val (idT, inst) = lmanager.instantiateLambda(substLambda) + instantiation ++= inst + subst += lambda.ids._2 -> Left(idT) + seen += lambda + } + } + + for (l <- lambdas) extractSubst(l) + + case _ => } - for (l <- lambdas) extractSubst(l) + manager match { + case qmanager: QuantificationManager[T] => + for (q <- quantifications) { + val substMap = subst.mapValues(_.encoded) + val substQuant = q.substitute(encoder.substitute(substMap), matcherSubst) + val (qT, inst) = qmanager.instantiateQuantification(substQuant) + instantiation ++= inst + subst += q.qs._2 -> Left(qT) + } - for (q <- quantifications) { - val substMap = subst.mapValues(_.encoded) - val substQuant = q.substitute(encoder.substitute(substMap), matcherSubst) - val (qT, inst) = manager.instantiateQuantification(substQuant) - instantiation ++= inst - subst += q.qs._2 -> Left(qT) + case _ => } (subst, instantiation) @@ -300,13 +350,11 @@ object Template { def instantiate[T]( encoder: TemplateEncoder[T], - manager: QuantificationManager[T], + manager: TemplateManager[T], clauses: Seq[T], blockers: Map[T, Set[TemplateCallInfo[T]]], applications: Map[T, Set[App[T]]], - quantifications: Seq[QuantificationTemplate[T]], matchers: Map[T, Set[Matcher[T]]], - lambdas: Seq[LambdaTemplate[T]], substMap: Map[T, Arg[T]] ): Instantiation[T] = { @@ -314,20 +362,31 @@ object Template { val msubst = substMap.collect { case (c, Right(m)) => c -> m } val newClauses = clauses.map(substituter) + val newBlockers = blockers.map { case (b,fis) => substituter(b) -> fis.map(fi => fi.copy(args = fi.args.map(_.substitute(substituter, msubst)))) } var instantiation: Instantiation[T] = (newClauses, newBlockers, Map.empty) - for ((b,apps) <- applications; bp = substituter(b); app <- apps) { - val newApp = app.copy(caller = substituter(app.caller), args = app.args.map(_.substitute(substituter, msubst))) - instantiation ++= manager.instantiateApp(bp, newApp) + manager match { + case lmanager: LambdaManager[T] => + for ((b,apps) <- applications; bp = substituter(b); app <- apps) { + val newApp = app.copy(caller = substituter(app.caller), args = app.args.map(_.substitute(substituter, msubst))) + instantiation ++= lmanager.instantiateApp(bp, newApp) + } + + case _ => } - for ((b, matchs) <- matchers; bp = substituter(b); m <- matchs) { - val newMatcher = m.substitute(substituter, msubst) - instantiation ++= manager.instantiateMatcher(bp, newMatcher) + manager match { + case qmanager: QuantificationManager[T] => + for ((b, matchs) <- matchers; bp = substituter(b); m <- matchs) { + val newMatcher = m.substitute(substituter, msubst) + instantiation ++= qmanager.instantiateMatcher(bp, newMatcher) + } + + case _ => } instantiation @@ -339,7 +398,7 @@ object FunctionTemplate { def apply[T]( tfd: TypedFunDef, encoder: TemplateEncoder[T], - manager: QuantificationManager[T], + manager: TemplateManager[T], pathVar: (Identifier, T), arguments: Seq[(Identifier, T)], condVars: Map[Identifier, T], @@ -351,9 +410,8 @@ object FunctionTemplate { isRealFunDef: Boolean ) : FunctionTemplate[T] = { - val (clauses, blockers, applications, matchers, templateString) = - Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, - substMap = quantifications.map(q => q.qs).toMap, + val (clauses, blockers, applications, functions, matchers, templateString) = + Template.encode(encoder, pathVar, arguments, condVars, exprVars, guardedExprs, lambdas, quantifications, optCall = Some(tfd)) val funString : () => String = () => { @@ -374,9 +432,10 @@ object FunctionTemplate { clauses, blockers, applications, - quantifications, - matchers, + functions, lambdas, + matchers, + quantifications, isRealFunDef, funString ) @@ -386,7 +445,7 @@ object FunctionTemplate { class FunctionTemplate[T] private( val tfd: TypedFunDef, val encoder: TemplateEncoder[T], - val manager: QuantificationManager[T], + val manager: TemplateManager[T], val pathVar: (Identifier, T), val args: Seq[T], val condVars: Map[Identifier, T], @@ -395,19 +454,15 @@ class FunctionTemplate[T] private( val clauses: Seq[T], val blockers: Map[T, Set[TemplateCallInfo[T]]], val applications: Map[T, Set[App[T]]], - val quantifications: Seq[QuantificationTemplate[T]], - val matchers: Map[T, Set[Matcher[T]]], + val functions: Set[(T, FunctionType, T)], val lambdas: Seq[LambdaTemplate[T]], + val matchers: Map[T, Set[Matcher[T]]], + val quantifications: Seq[QuantificationTemplate[T]], isRealFunDef: Boolean, stringRepr: () => String) extends Template[T] { private lazy val str : String = stringRepr() override def toString : String = str - - override def instantiate(aVar: T, args: Seq[Arg[T]]): Instantiation[T] = { - if (!isRealFunDef) manager.registerFree(tfd.params.map(_.id) zip args.map(_.left.get)) - super.instantiate(aVar, args) - } } class TemplateManager[T](protected[templates] val encoder: TemplateEncoder[T]) extends IncrementalState { @@ -436,7 +491,8 @@ class TemplateManager[T](protected[templates] val encoder: TemplateEncoder[T]) e def blocker(b: T): Unit = condImplies += (b -> Set.empty) def isBlocker(b: T): Boolean = condImplies.isDefinedAt(b) || condImplied.isDefinedAt(b) - + def blockerParents(b: T): Set[T] = condImplied(b) + def implies(b1: T, b2: T): Unit = implies(b1, Set(b2)) def implies(b1: T, b2s: Set[T]): Unit = { val fb2s = b2s.filter(_ != b1) diff --git a/src/main/scala/leon/solvers/templates/UnrollingBank.scala b/src/main/scala/leon/solvers/templates/UnrollingBank.scala index 4543262d74204dd9f77d3eeea8882bf543956a90..cb7fe90d280160d03a5b9fbc190a7851b3cd3c1b 100644 --- a/src/main/scala/leon/solvers/templates/UnrollingBank.scala +++ b/src/main/scala/leon/solvers/templates/UnrollingBank.scala @@ -28,7 +28,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat private val appInfos = new IncrementalMap[(T, App[T]), (Int, Int, T, T, Set[TemplateAppInfo[T]])]() private val appBlockers = new IncrementalMap[(T, App[T]), T]() private val blockerToApps = new IncrementalMap[T, (T, App[T])]() - private val functionVars = new IncrementalMap[TypeTree, Set[T]]() def push() { callInfos.push() @@ -37,7 +36,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat appInfos.push() appBlockers.push() blockerToApps.push() - functionVars.push() } def pop() { @@ -47,7 +45,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat appInfos.pop() appBlockers.pop() blockerToApps.pop() - functionVars.pop() } def clear() { @@ -57,7 +54,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat appInfos.clear() appBlockers.clear() blockerToApps.clear() - functionVars.clear() } def reset() { @@ -67,7 +63,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat appInfos.reset() appBlockers.reset() blockerToApps.clear() - functionVars.reset() } def dumpBlockers() = { @@ -91,6 +86,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def refutationAssumptions = manager.assumptions def canUnroll = callInfos.nonEmpty || appInfos.nonEmpty + def canInstantiate = manager.hasIgnored def currentBlockers = callInfos.map(_._2._3).toSeq ++ appInfos.map(_._2._4).toSeq @@ -141,14 +137,13 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } private def freshAppBlocks(apps: Traversable[(T, App[T])]) : Seq[T] = { - apps.filter(!appBlockers.isDefinedAt(_)).toSeq.map { case app @ (blocker, App(caller, tpe, _)) => + apps.filter(!appBlockers.isDefinedAt(_)).toSeq.map { + case app @ (blocker, App(caller, tpe, _, _)) => + val firstB = encoder.encodeId(FreshIdentifier("b_lambda", BooleanType, true)) + val clause = encoder.mkImplies(encoder.mkNot(firstB), encoder.mkNot(blocker)) - val firstB = encoder.encodeId(FreshIdentifier("b_lambda", BooleanType, true)) - val freeEq = functionVars.getOrElse(tpe, Set()).toSeq.map(t => encoder.mkEquals(t, caller)) - val clause = encoder.mkImplies(encoder.mkNot(encoder.mkOr((freeEq :+ firstB) : _*)), encoder.mkNot(blocker)) - - appBlockers += app -> firstB - clause + appBlockers += app -> firstB + clause } } @@ -173,10 +168,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat val trArgs = template.tfd.params.map(vd => Left(bindings(Variable(vd.id)))) - for (vd <- template.tfd.params if vd.getType.isInstanceOf[FunctionType]) { - functionVars += vd.getType -> (functionVars.getOrElse(vd.getType, Set()) + bindings(vd.toVariable)) - } - // ...now this template defines clauses that are all guarded // by that activating boolean. If that activating boolean is // undefined (or false) these clauses have no effect... @@ -218,9 +209,9 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat def promoteBlocker(b: T) = { if (callInfos contains b) { - val (_, origGen, ast, fis) = callInfos(b) + val (_, origGen, notB, fis) = callInfos(b) - callInfos += b -> (1, origGen, ast, fis) + callInfos += b -> (1, origGen, notB, fis) } if (blockerToApps contains b) { @@ -231,6 +222,31 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat } } + def instantiateQuantifiers(force: Boolean = false): Seq[T] = { + val (newExprs, callBlocks, appBlocks) = manager.instantiateIgnored(force) + val blockExprs = freshAppBlocks(appBlocks.keys) + val gens = (callInfos.values.map(_._1) ++ appInfos.values.map(_._1)) + val gen = if (gens.nonEmpty) gens.min else 0 + + for ((b, newInfos) <- callBlocks) { + registerCallBlocker(nextGeneration(gen), b, newInfos) + } + + for ((newApp, newInfos) <- appBlocks) { + registerAppBlocker(nextGeneration(gen), newApp, newInfos) + } + + val clauses = newExprs ++ blockExprs + if (clauses.nonEmpty) { + reporter.debug("Instantiating ignored quantifiers ("+clauses.size+")") + for (cl <- clauses) { + reporter.debug(" . "+cl) + } + } + + clauses + } + def unrollBehind(ids: Seq[T]): Seq[T] = { assert(ids.forall(id => (callInfos contains id) || (blockerToApps contains id))) @@ -254,7 +270,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat newClauses :+= extension } - var fastAppInfos : Map[(T, App[T]), (Int, Set[TemplateAppInfo[T]])] = Map.empty + var fastAppInfos : Seq[((T, App[T]), (Int, Set[TemplateAppInfo[T]]))] = Seq.empty for ((id, (gen, _, _, infos)) <- newCallInfos; info @ TemplateCallInfo(tfd, args) <- infos) { var newCls = Seq[T]() @@ -268,7 +284,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat // we need to define this defBlocker and link it to definition val defBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType)) defBlockers += info -> defBlocker - manager.implies(id, defBlocker) val template = templateGenerator.mkTemplate(tfd) //reporter.debug(template) @@ -302,6 +317,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat // We connect it to the defBlocker: blocker => defBlocker if (defBlocker != id) { newCls :+= encoder.mkImplies(id, defBlocker) + manager.implies(id, defBlocker) } reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") @@ -312,7 +328,9 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat newClauses ++= newCls } - for ((app @ (b, _), (gen, infos)) <- thisAppInfos ++ fastAppInfos; info @ TemplateAppInfo(template, equals, args) <- infos) { + for ((app @ (b, _), (gen, infos)) <- thisAppInfos ++ fastAppInfos; + info @ TemplateAppInfo(tmpl, equals, args) <- infos; + template <- tmpl.left) { var newCls = Seq.empty[T] val lambdaBlocker = lambdaBlockers.get(info) match { @@ -321,7 +339,6 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat case None => val lambdaBlocker = encoder.encodeId(FreshIdentifier("d", BooleanType)) lambdaBlockers += info -> lambdaBlocker - manager.implies(b, lambdaBlocker) val (newExprs, callBlocks, appBlocks) = template.instantiate(lambdaBlocker, args) val blockExprs = freshAppBlocks(appBlocks.keys) @@ -341,6 +358,7 @@ class UnrollingBank[T <% Printable](ctx: LeonContext, templateGenerator: Templat val enabler = if (equals == manager.trueT) b else encoder.mkAnd(equals, b) newCls :+= encoder.mkImplies(enabler, lambdaBlocker) + manager.implies(b, lambdaBlocker) reporter.debug("Unrolling behind "+info+" ("+newCls.size+")") for (cl <- newCls) { diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index ac1e8855a4a53ed5ef2f66c34d4b015d9a26ba3f..7ae02d83e0a6245c7345ef7349211e6c246221ac 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -30,7 +30,7 @@ trait AbstractZ3Solver extends Solver { val library = program.library - protected[z3] val reporter : Reporter = context.reporter + protected val reporter : Reporter = context.reporter context.interruptManager.registerForInterrupts(this) @@ -51,7 +51,7 @@ trait AbstractZ3Solver extends Solver { protected[leon] val z3cfg : Z3Config protected[leon] var z3 : Z3Context = null - override def free() { + override def free(): Unit = { freed = true if (z3 ne null) { z3.delete() @@ -59,18 +59,13 @@ trait AbstractZ3Solver extends Solver { } } - protected[z3] var interrupted = false - - override def interrupt() { - interrupted = true + override def interrupt(): Unit = { if(z3 ne null) { z3.interrupt() } } - override def recoverInterrupt() { - interrupted = false - } + override def recoverInterrupt(): Unit = () def functionDefToDecl(tfd: TypedFunDef): Z3FuncDecl = { functions.cachedB(tfd) { @@ -262,324 +257,313 @@ trait AbstractZ3Solver extends Solver { } case other => - throw SolverUnsupportedError(other, this) + unsupported(other) } - - protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier, Z3AST] = Map.empty): Z3AST = { - implicit var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { + var z3Vars: Map[Identifier,Z3AST] = if(initialMap.nonEmpty) { initialMap } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of // variable has to remain in a map etc. variables.aToB.collect{ case (Variable(id), p2) => id -> p2 } } - new Z3StringConversion[Z3AST] { - def getProgram = AbstractZ3Solver.this.program - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { - rec(e) - } - def targetApplication(tfd: TypedFunDef, args: Seq[Z3AST])(implicit bindings: Map[Identifier, Z3AST]): Z3AST = { - z3.mkApp(functionDefToDecl(tfd), args: _*) + + def rec(ex: Expr): Z3AST = ex match { + + // TODO: Leave that as a specialization? + case LetTuple(ids, e, b) => { + z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => + val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) + entry } - def rec(ex: Expr): Z3AST = ex match { - - // TODO: Leave that as a specialization? - case LetTuple(ids, e, b) => { - z3Vars = z3Vars ++ ids.zipWithIndex.map { case (id, ix) => - val entry = id -> rec(tupleSelect(e, ix + 1, ids.size)) - entry - } - val rb = rec(b) - z3Vars = z3Vars -- ids - rb - } - - case p @ Passes(_, _, _) => - rec(p.asConstraint) - - case me @ MatchExpr(s, cs) => - rec(matchToIfThenElse(me)) - - case Let(i, e, b) => { - val re = rec(e) - z3Vars = z3Vars + (i -> re) - val rb = rec(b) - z3Vars = z3Vars - i - rb - } - - case Waypoint(_, e, _) => rec(e) - case a @ Assert(cond, err, body) => - rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) - - case e @ Error(tpe, _) => { - val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) - // Might introduce dupplicates (e), but no worries here - variables += (e -> newAST) - newAST - } - case v @ Variable(id) => z3Vars.get(id) match { - case Some(ast) => + val rb = rec(b) + z3Vars = z3Vars -- ids + rb + } + + case p @ Passes(_, _, _) => + rec(p.asConstraint) + + case me @ MatchExpr(s, cs) => + rec(matchToIfThenElse(me)) + + case Let(i, e, b) => { + val re = rec(e) + z3Vars = z3Vars + (i -> re) + val rb = rec(b) + z3Vars = z3Vars - i + rb + } + + case Waypoint(_, e, _) => rec(e) + case a @ Assert(cond, err, body) => + rec(IfExpr(cond, body, Error(a.getType, err.getOrElse("Assertion failed")).setPos(a.getPos)).setPos(a.getPos)) + + case e @ Error(tpe, _) => { + val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) + // Might introduce dupplicates (e), but no worries here + variables += (e -> newAST) + newAST + } + case v @ Variable(id) => z3Vars.get(id) match { + case Some(ast) => + ast + case None => { + variables.getB(v) match { + case Some(ast) => ast - case None => { - variables.getB(v) match { - case Some(ast) => - ast - - case None => - val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) - z3Vars = z3Vars + (id -> newAST) - variables += (v -> newAST) - newAST - } - } - } - - case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) - case And(exs) => z3.mkAnd(exs.map(rec): _*) - case Or(exs) => z3.mkOr(exs.map(rec): _*) - case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) - case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) - case Not(e) => z3.mkNot(rec(e)) - case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) - case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) - case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) - case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) - case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() - case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) - case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) - case Minus(l, r) => z3.mkSub(rec(l), rec(r)) - case Times(l, r) => z3.mkMul(rec(l), rec(r)) - case Division(l, r) => { - val rl = rec(l) - val rr = rec(r) - z3.mkITE( - z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), - z3.mkDiv(rl, rr), - z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) - ) - } - case Remainder(l, r) => { - val q = rec(Division(l, r)) - z3.mkSub(rec(l), z3.mkMul(rec(r), q)) - } - case Modulo(l, r) => { - z3.mkMod(rec(l), rec(r)) - } - case UMinus(e) => z3.mkUnaryMinus(rec(e)) - - case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) - case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) - case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) - case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) - case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) - - case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) - case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) - case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) - case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) - case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) - case BVUMinus(e) => z3.mkBVNeg(rec(e)) - case BVNot(e) => z3.mkBVNot(rec(e)) - case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) - case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) - case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) - case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) - case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) - case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) - case LessThan(l, r) => l.getType match { - case IntegerType => z3.mkLT(rec(l), rec(r)) - case RealType => z3.mkLT(rec(l), rec(r)) - case Int32Type => z3.mkBVSlt(rec(l), rec(r)) - case CharType => z3.mkBVSlt(rec(l), rec(r)) - } - case LessEquals(l, r) => l.getType match { - case IntegerType => z3.mkLE(rec(l), rec(r)) - case RealType => z3.mkLE(rec(l), rec(r)) - case Int32Type => z3.mkBVSle(rec(l), rec(r)) - case CharType => z3.mkBVSle(rec(l), rec(r)) - //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") - } - case GreaterThan(l, r) => l.getType match { - case IntegerType => z3.mkGT(rec(l), rec(r)) - case RealType => z3.mkGT(rec(l), rec(r)) - case Int32Type => z3.mkBVSgt(rec(l), rec(r)) - case CharType => z3.mkBVSgt(rec(l), rec(r)) - } - case GreaterEquals(l, r) => l.getType match { - case IntegerType => z3.mkGE(rec(l), rec(r)) - case RealType => z3.mkGE(rec(l), rec(r)) - case Int32Type => z3.mkBVSge(rec(l), rec(r)) - case CharType => z3.mkBVSge(rec(l), rec(r)) - } - - case StringConverted(result) => - result - - case u : UnitLiteral => - val tpe = normalizeType(u.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor() - - case t @ Tuple(es) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val constructor = constructors.toB(tpe) - constructor(es.map(rec): _*) - - case ts @ TupleSelect(t, i) => - val tpe = normalizeType(t.getType) - typeToSort(tpe) - val selector = selectors.toB((tpe, i-1)) - selector(rec(t)) - - case c @ CaseClass(ct, args) => - typeToSort(ct) // Making sure the sort is defined - val constructor = constructors.toB(ct) - constructor(args.map(rec): _*) - - case c @ CaseClassSelector(cct, cc, sel) => - typeToSort(cct) // Making sure the sort is defined - val selector = selectors.toB(cct, c.selectorIndex) - selector(rec(cc)) - - case AsInstanceOf(expr, ct) => - rec(expr) - - case IsInstanceOf(e, act: AbstractClassType) => - act.knownCCDescendants match { - case Seq(cct) => - rec(IsInstanceOf(e, cct)) - case more => - val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) - rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) - } - - case IsInstanceOf(e, cct: CaseClassType) => - typeToSort(cct) // Making sure the sort is defined - val tester = testers.toB(cct) - tester(rec(e)) - - case al @ ArraySelect(a, i) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val content = selectors.toB((tpe, 1))(sa) - - z3.mkSelect(content, rec(i)) - - case al @ ArrayUpdated(a, i, e) => - val tpe = normalizeType(a.getType) - - val sa = rec(a) - val ssize = selectors.toB((tpe, 0))(sa) - val scontent = selectors.toB((tpe, 1))(sa) - - val newcontent = z3.mkStore(scontent, rec(i), rec(e)) - - val constructor = constructors.toB(tpe) - - constructor(ssize, newcontent) - - case al @ ArrayLength(a) => - val tpe = normalizeType(a.getType) - val sa = rec(a) - selectors.toB((tpe, 0))(sa) - - case arr @ FiniteArray(elems, oDefault, length) => - val at @ ArrayType(base) = normalizeType(arr.getType) - typeToSort(at) - - val default = oDefault.getOrElse(simplestValue(base)) - - val ar = rec(RawArrayValue(Int32Type, elems.map { - case (i, e) => IntLiteral(i) -> e - }, default)) - - constructors.toB(at)(rec(length), ar) - - case f @ FunctionInvocation(tfd, args) => - z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) - - case fa @ Application(caller, args) => - val ft @ FunctionType(froms, to) = normalizeType(caller.getType) - val funDecl = lambdas.cachedB(ft) { - val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) - val returnSort = typeToSort(to) - - val name = FreshIdentifier("dynLambda").uniqueName - z3.mkFreshFuncDecl(name, sortSeq, returnSort) - } - z3.mkApp(funDecl, (caller +: args).map(rec): _*) - - case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) - case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) - case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) - case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) - case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) - case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) - - case RawArrayValue(keyTpe, elems, default) => - val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) - - elems.foldLeft(ar) { - case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) - } - - /** - * ===== Map operations ===== - */ - case m @ FiniteMap(elems, from, to) => - val MapType(_, t) = normalizeType(m.getType) - - rec(RawArrayValue(from, elems.map{ - case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) - }.toMap, CaseClass(library.noneType(t), Seq()))) - - case MapApply(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - // Really ?!? We don't check that it is actually != None? - selectors.toB(library.someType(t), 0)(el) - - case MapIsDefinedAt(m, k) => - val mt @ MapType(_, t) = normalizeType(m.getType) - typeToSort(mt) - - val el = z3.mkSelect(rec(m), rec(k)) - - testers.toB(library.someType(t))(el) - - case MapUnion(m1, FiniteMap(elems, _, _)) => - val mt @ MapType(_, t) = normalizeType(m1.getType) - typeToSort(mt) - - elems.foldLeft(rec(m1)) { case (m, (k,v)) => - z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) - } - - - case gv @ GenericValue(tp, id) => - z3.mkApp(genericValueToDecl(gv)) - - case other => - unsupported(other) + + case None => + val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) + z3Vars = z3Vars + (id -> newAST) + variables += (v -> newAST) + newAST } - }.rec(expr) + } + } + + case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) + case And(exs) => z3.mkAnd(exs.map(rec): _*) + case Or(exs) => z3.mkOr(exs.map(rec): _*) + case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) + case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) + case Not(e) => z3.mkNot(rec(e)) + case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) + case InfiniteIntegerLiteral(v) => z3.mkNumeral(v.toString, typeToSort(IntegerType)) + case FractionalLiteral(n, d) => z3.mkNumeral(s"$n / $d", typeToSort(RealType)) + case CharLiteral(c) => z3.mkInt(c, typeToSort(CharType)) + case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() + case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) + case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) + case Minus(l, r) => z3.mkSub(rec(l), rec(r)) + case Times(l, r) => z3.mkMul(rec(l), rec(r)) + case Division(l, r) => { + val rl = rec(l) + val rr = rec(r) + z3.mkITE( + z3.mkGE(rl, z3.mkNumeral("0", typeToSort(IntegerType))), + z3.mkDiv(rl, rr), + z3.mkUnaryMinus(z3.mkDiv(z3.mkUnaryMinus(rl), rr)) + ) + } + case Remainder(l, r) => { + val q = rec(Division(l, r)) + z3.mkSub(rec(l), z3.mkMul(rec(r), q)) + } + case Modulo(l, r) => { + z3.mkMod(rec(l), rec(r)) + } + case UMinus(e) => z3.mkUnaryMinus(rec(e)) + + case RealPlus(l, r) => z3.mkAdd(rec(l), rec(r)) + case RealMinus(l, r) => z3.mkSub(rec(l), rec(r)) + case RealTimes(l, r) => z3.mkMul(rec(l), rec(r)) + case RealDivision(l, r) => z3.mkDiv(rec(l), rec(r)) + case RealUMinus(e) => z3.mkUnaryMinus(rec(e)) + + case BVPlus(l, r) => z3.mkBVAdd(rec(l), rec(r)) + case BVMinus(l, r) => z3.mkBVSub(rec(l), rec(r)) + case BVTimes(l, r) => z3.mkBVMul(rec(l), rec(r)) + case BVDivision(l, r) => z3.mkBVSdiv(rec(l), rec(r)) + case BVRemainder(l, r) => z3.mkBVSrem(rec(l), rec(r)) + case BVUMinus(e) => z3.mkBVNeg(rec(e)) + case BVNot(e) => z3.mkBVNot(rec(e)) + case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) + case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) + case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) + case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) + case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) + case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) + case LessThan(l, r) => l.getType match { + case IntegerType => z3.mkLT(rec(l), rec(r)) + case RealType => z3.mkLT(rec(l), rec(r)) + case Int32Type => z3.mkBVSlt(rec(l), rec(r)) + case CharType => z3.mkBVSlt(rec(l), rec(r)) + } + case LessEquals(l, r) => l.getType match { + case IntegerType => z3.mkLE(rec(l), rec(r)) + case RealType => z3.mkLE(rec(l), rec(r)) + case Int32Type => z3.mkBVSle(rec(l), rec(r)) + case CharType => z3.mkBVSle(rec(l), rec(r)) + //case _ => throw new IllegalStateException(s"l: $l, Left type: ${l.getType} Expr: $ex") + } + case GreaterThan(l, r) => l.getType match { + case IntegerType => z3.mkGT(rec(l), rec(r)) + case RealType => z3.mkGT(rec(l), rec(r)) + case Int32Type => z3.mkBVSgt(rec(l), rec(r)) + case CharType => z3.mkBVSgt(rec(l), rec(r)) + } + case GreaterEquals(l, r) => l.getType match { + case IntegerType => z3.mkGE(rec(l), rec(r)) + case RealType => z3.mkGE(rec(l), rec(r)) + case Int32Type => z3.mkBVSge(rec(l), rec(r)) + case CharType => z3.mkBVSge(rec(l), rec(r)) + } + + case u : UnitLiteral => + val tpe = normalizeType(u.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor() + + case t @ Tuple(es) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val constructor = constructors.toB(tpe) + constructor(es.map(rec): _*) + + case ts @ TupleSelect(t, i) => + val tpe = normalizeType(t.getType) + typeToSort(tpe) + val selector = selectors.toB((tpe, i-1)) + selector(rec(t)) + + case c @ CaseClass(ct, args) => + typeToSort(ct) // Making sure the sort is defined + val constructor = constructors.toB(ct) + constructor(args.map(rec): _*) + + case c @ CaseClassSelector(cct, cc, sel) => + typeToSort(cct) // Making sure the sort is defined + val selector = selectors.toB(cct, c.selectorIndex) + selector(rec(cc)) + + case AsInstanceOf(expr, ct) => + rec(expr) + + case IsInstanceOf(e, act: AbstractClassType) => + act.knownCCDescendants match { + case Seq(cct) => + rec(IsInstanceOf(e, cct)) + case more => + val i = FreshIdentifier("e", act, alwaysShowUniqueID = true) + rec(Let(i, e, orJoin(more map(IsInstanceOf(Variable(i), _))))) + } + + case IsInstanceOf(e, cct: CaseClassType) => + typeToSort(cct) // Making sure the sort is defined + val tester = testers.toB(cct) + tester(rec(e)) + + case al @ ArraySelect(a, i) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val content = selectors.toB((tpe, 1))(sa) + + z3.mkSelect(content, rec(i)) + + case al @ ArrayUpdated(a, i, e) => + val tpe = normalizeType(a.getType) + + val sa = rec(a) + val ssize = selectors.toB((tpe, 0))(sa) + val scontent = selectors.toB((tpe, 1))(sa) + + val newcontent = z3.mkStore(scontent, rec(i), rec(e)) + + val constructor = constructors.toB(tpe) + + constructor(ssize, newcontent) + + case al @ ArrayLength(a) => + val tpe = normalizeType(a.getType) + val sa = rec(a) + selectors.toB((tpe, 0))(sa) + + case arr @ FiniteArray(elems, oDefault, length) => + val at @ ArrayType(base) = normalizeType(arr.getType) + typeToSort(at) + + val default = oDefault.getOrElse(simplestValue(base)) + + val ar = rec(RawArrayValue(Int32Type, elems.map { + case (i, e) => IntLiteral(i) -> e + }, default)) + + constructors.toB(at)(rec(length), ar) + + case f @ FunctionInvocation(tfd, args) => + z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) + + case fa @ Application(caller, args) => + val ft @ FunctionType(froms, to) = normalizeType(caller.getType) + val funDecl = lambdas.cachedB(ft) { + val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) + val returnSort = typeToSort(to) + + val name = FreshIdentifier("dynLambda").uniqueName + z3.mkFreshFuncDecl(name, sortSeq, returnSort) + } + z3.mkApp(funDecl, (caller +: args).map(rec): _*) + + case ElementOfSet(e, s) => z3.mkSetMember(rec(e), rec(s)) + case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) + case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) + case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) + case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) + case f @ FiniteSet(elems, base) => elems.foldLeft(z3.mkEmptySet(typeToSort(base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) + + case RawArrayValue(keyTpe, elems, default) => + val ar = z3.mkConstArray(typeToSort(keyTpe), rec(default)) + + elems.foldLeft(ar) { + case (array, (k, v)) => z3.mkStore(array, rec(k), rec(v)) + } + + /** + * ===== Map operations ===== + */ + case m @ FiniteMap(elems, from, to) => + val MapType(_, t) = normalizeType(m.getType) + + rec(RawArrayValue(from, elems.map{ + case (k, v) => (k, CaseClass(library.someType(t), Seq(v))) + }, CaseClass(library.noneType(t), Seq()))) + + case MapApply(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + // Really ?!? We don't check that it is actually != None? + selectors.toB(library.someType(t), 0)(el) + + case MapIsDefinedAt(m, k) => + val mt @ MapType(_, t) = normalizeType(m.getType) + typeToSort(mt) + + val el = z3.mkSelect(rec(m), rec(k)) + + testers.toB(library.someType(t))(el) + + case MapUnion(m1, FiniteMap(elems, _, _)) => + val mt @ MapType(_, t) = normalizeType(m1.getType) + typeToSort(mt) + + elems.foldLeft(rec(m1)) { case (m, (k,v)) => + z3.mkStore(m, rec(k), rec(CaseClass(library.someType(t), Seq(v)))) + } + + + case gv @ GenericValue(tp, id) => + z3.mkApp(genericValueToDecl(gv)) + + case other => + unsupported(other) + } + + rec(expr) } protected[leon] def fromZ3Formula(model: Z3Model, tree: Z3AST, tpe: TypeTree): Expr = { - def rec(t: Z3AST, expected_tpe: TypeTree): Expr = { + + def rec(t: Z3AST, tpe: TypeTree): Expr = { val kind = z3.getASTKind(t) - val tpe = Z3StringTypeConversion.convert(expected_tpe)(program) - val res = kind match { + kind match { case Z3NumeralIntAST(Some(v)) => val leading = t.toString.substring(0, 2 min t.toString.length) if(leading == "#x") { @@ -634,7 +618,7 @@ trait AbstractZ3Solver extends Solver { case Z3AppAST(decl, args) => val argsSize = args.size - if(argsSize == 0 && (variables containsB t)) { + if (argsSize == 0 && (variables containsB t)) { variables.toA(t) } else if(functions containsB decl) { val tfd = functions.toA(decl) @@ -694,13 +678,13 @@ trait AbstractZ3Solver extends Solver { case None => simplestValue(ft) case Some((_, mapping, elseValue)) => val leonElseValue = rec(elseValue, tt) - PartialLambda(mapping.flatMap { case (z3Args, z3Result) => + FiniteLambda(mapping.flatMap { case (z3Args, z3Result) => if (t == z3Args.head) { List((z3Args.tail zip fts).map(p => rec(p._1, p._2)) -> rec(z3Result, tt)) } else { Nil } - }, Some(leonElseValue), ft) + }, leonElseValue, ft) } } @@ -769,11 +753,6 @@ trait AbstractZ3Solver extends Solver { } case _ => unsound(t, "unexpected AST") } - expected_tpe match { - case StringType => - StringLiteral(Z3StringTypeConversion.convertToString(res)(program)) - case _ => res - } } rec(tree, normalizeType(tpe)) @@ -790,11 +769,10 @@ trait AbstractZ3Solver extends Solver { } def idToFreshZ3Id(id: Identifier): Z3AST = { - val correctType = Z3StringTypeConversion.convert(id.getType)(program) - z3.mkFreshConst(id.uniqueName, typeToSort(correctType)) + z3.mkFreshConst(id.uniqueName, typeToSort(id.getType)) } - def reset() = { + def reset(): Unit = { throw new CantResetException(this) } diff --git a/src/main/scala/leon/solvers/z3/FairZ3Component.scala b/src/main/scala/leon/solvers/z3/FairZ3Component.scala deleted file mode 100644 index 70ebd260f931a6a428a84afad9640accf193887d..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/solvers/z3/FairZ3Component.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -package leon -package solvers.z3 - -trait FairZ3Component extends LeonComponent { - val name = "Z3-f" - val description = "Fair Z3 Solver" - - val optEvalGround = LeonFlagOptionDef("evalground", "Use evaluator on functions applied to ground arguments", false) - val optCheckModels = LeonFlagOptionDef("checkmodels", "Double-check counter-examples with evaluator", false) - val optFeelingLucky = LeonFlagOptionDef("feelinglucky","Use evaluator to find counter-examples early", false) - val optUseCodeGen = LeonFlagOptionDef("codegen", "Use compiled evaluator instead of interpreter", false) - val optUnrollCores = LeonFlagOptionDef("unrollcores", "Use unsat-cores to drive unrolling while remaining fair", false) - val optAssumePre = LeonFlagOptionDef("assumepre", "Assume precondition holds (pre && f(x) = body) when unfolding", false) - val optNoChecks = LeonFlagOptionDef("nochecks", "Disable counter-example check in presence of foralls" , false) - val optUnfoldFactor = LeonLongOptionDef("unfoldFactor", "Number of unfoldings to perform in each unfold step", default = 1, "<PosInt>") - - override val definedOptions: Set[LeonOptionDef[Any]] = - Set(optEvalGround, optCheckModels, optFeelingLucky, optUseCodeGen, optUnrollCores, optAssumePre, optUnfoldFactor) -} - -object FairZ3Component extends FairZ3Component diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index c975fd2d37e7563137d9b1fd4cbe0f2cbb7892f4..5c12fd3863da241f09a667133731a9b31182f10f 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -17,6 +17,7 @@ import purescala.ExprOps._ import purescala.Types._ import solvers.templates._ +import solvers.combinators._ import Template._ import evaluators._ @@ -25,32 +26,19 @@ import termination._ class FairZ3Solver(val context: LeonContext, val program: Program) extends AbstractZ3Solver - with Z3ModelReconstruction - with FairZ3Component - with EvaluatingSolver - with QuantificationSolver { + with AbstractUnrollingSolver[Z3AST] { enclosing => - val feelingLucky = context.findOptionOrDefault(optFeelingLucky) - val checkModels = context.findOptionOrDefault(optCheckModels) - val useCodeGen = context.findOptionOrDefault(optUseCodeGen) - val evalGroundApps = context.findOptionOrDefault(optEvalGround) - val unrollUnsatCores = context.findOptionOrDefault(optUnrollCores) - val assumePreHolds = context.findOptionOrDefault(optAssumePre) - val disableChecks = context.findOptionOrDefault(optNoChecks) - - assert(!checkModels || !disableChecks, "Options \"checkmodels\" and \"nochecks\" are mutually exclusive") - protected val errors = new IncrementalBijection[Unit, Boolean]() protected def hasError = errors.getB(()) contains true protected def addError() = errors += () -> true - protected[z3] def getEvaluator : Evaluator = evaluator + override val name = "Z3-f" + override val description = "Fair Z3 Solver" - private val terminator : TerminationChecker = new SimpleTerminationChecker(context, program) - - protected[z3] def getTerminator : TerminationChecker = terminator + override protected val reporter = context.reporter + override def reset(): Unit = super[AbstractZ3Solver].reset() // FIXME: Dirty hack to bypass z3lib bug. Assumes context is the same over all instances of FairZ3Solver protected[leon] val z3cfg = context.synchronized { new Z3Config( @@ -60,53 +48,81 @@ class FairZ3Solver(val context: LeonContext, val program: Program) )} toggleWarningMessages(true) - private def extractModel(model: Z3Model, ids: Set[Identifier]): HenkinModel = { - def extract(b: Z3AST, m: Matcher[Z3AST]): Set[Seq[Expr]] = { - val QuantificationTypeMatcher(fromTypes, _) = m.tpe - val optEnabler = model.evalAs[Boolean](b) + def solverCheck[R](clauses: Seq[Z3AST])(block: Option[Boolean] => R): R = { + solver.push() + for (cls <- clauses) solver.assertCnstr(cls) + val res = solver.check + val r = block(res) + solver.pop() + r + } - if (optEnabler == Some(true)) { - val optArgs = (m.args zip fromTypes).map { - p => softFromZ3Formula(model, model.eval(p._1.encoded, true).get, p._2) - } + override def solverCheckAssumptions[R](assumptions: Seq[Z3AST])(block: Option[Boolean] => R): R = { + solver.push() // FIXME: remove when z3 bug is fixed + val res = solver.checkAssumptions(assumptions : _*) + solver.pop() // FIXME: remove when z3 bug is fixed + block(res) + } - if (optArgs.forall(_.isDefined)) { - Set(optArgs.map(_.get)) - } else { - Set.empty - } - } else { - Set.empty + def solverGetModel: ModelWrapper = new ModelWrapper { + val model = solver.getModel + + /* + val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap + val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { + if (functions containsB p._1) { + val tfd = functions.toA(p._1) + if (!tfd.hasImplementation) { + val (cses, default) = p._2 + val ite = cses.foldLeft(fromZ3Formula(model, default, tfd.returnType))((expr, q) => IfExpr( + andJoin(q._1.zip(tfd.params).map(a12 => Equals(fromZ3Formula(model, a12._1, a12._2.getType), Variable(a12._2.id)))), + fromZ3Formula(model, q._2, tfd.returnType), + expr)) + Seq((tfd.id, ite)) + } else Seq() + } else Seq() + }) + + val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { + if(functions containsB p._1) { + val tfd = functions.toA(p._1) + if(!tfd.hasImplementation) { + Seq((tfd.id, fromZ3Formula(model, p._2, tfd.returnType))) + } else Seq() + } else Seq() + }).toMap + + val leonModel = extractModel(model, freeVars.toSet) + val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) + */ + + def get(id: Identifier): Option[Expr] = variables.getB(id.toVariable).flatMap { + z3ID => eval(z3ID, id.getType) match { + case Some(Variable(id)) => None + case e => e } } - val (typeInsts, partialInsts, lambdaInsts) = templateGenerator.manager.instantiations - - val typeDomains: Map[TypeTree, Set[Seq[Expr]]] = typeInsts.map { - case (tpe, domain) => tpe -> domain.flatMap { case (b, m) => extract(b, m) }.toSet - } - - val funDomains: Map[Identifier, Set[Seq[Expr]]] = partialInsts.flatMap { - case (c, domain) => variables.getA(c).collect { - case Variable(id) => id -> domain.flatMap { case (b, m) => extract(b, m) }.toSet + def eval(elem: Z3AST, tpe: TypeTree): Option[Expr] = tpe match { + case BooleanType => model.evalAs[Boolean](elem).map(BooleanLiteral) + case Int32Type => model.evalAs[Int](elem).map(IntLiteral).orElse { + model.eval(elem).flatMap(t => softFromZ3Formula(model, t, Int32Type)) + } + case IntegerType => model.evalAs[Int](elem).map(InfiniteIntegerLiteral(_)) + case other => model.eval(elem) match { + case None => None + case Some(t) => softFromZ3Formula(model, t, other) } } - val lambdaDomains: Map[Lambda, Set[Seq[Expr]]] = lambdaInsts.map { - case (l, domain) => l -> domain.flatMap { case (b, m) => extract(b, m) }.toSet - } - - val asMap = modelToMap(model, ids) - val asDMap = purescala.Quantification.extractModel(asMap, funDomains, typeDomains, evaluator) - val domains = new HenkinDomains(lambdaDomains, typeDomains) - new HenkinModel(asDMap, domains) + override def toString = model.toString } - implicit val z3Printable = (z3: Z3AST) => new Printable { + val printable = (z3: Z3AST) => new Printable { def asString(implicit ctx: LeonContext) = z3.toString } - val templateGenerator = new TemplateGenerator(new TemplateEncoder[Z3AST] { + val templateEncoder = new TemplateEncoder[Z3AST] { def encodeId(id: Identifier): Z3AST = { idToFreshZ3Id(id) } @@ -127,31 +143,33 @@ class FairZ3Solver(val context: LeonContext, val program: Program) def mkAnd(es: Z3AST*) = z3.mkAnd(es : _*) def mkEquals(l: Z3AST, r: Z3AST) = z3.mkEq(l, r) def mkImplies(l: Z3AST, r: Z3AST) = z3.mkImplies(l, r) - }, assumePreHolds) + def extractNot(l: Z3AST): Option[Z3AST] = z3.getASTKind(l) match { + case Z3AppAST(decl, args) => z3.getDeclKind(decl) match { + case Z3DeclKind.OpNot => Some(args.head) + case Z3DeclKind.OpUninterpreted => None + } + case ast => None + } + } initZ3() val solver = z3.mkSolver() - private val freeVars = new IncrementalSet[Identifier]() - private val constraints = new IncrementalSeq[Expr]() - - val tr = implicitly[Z3AST => Printable] - - val unrollingBank = new UnrollingBank(context, templateGenerator) - private val incrementals: List[IncrementalState] = List( - errors, freeVars, constraints, functions, generics, lambdas, sorts, variables, - constructors, selectors, testers, unrollingBank + errors, functions, generics, lambdas, sorts, variables, + constructors, selectors, testers ) - def push() { + override def push(): Unit = { + super.push() solver.push() incrementals.foreach(_.push()) } - def pop() { + override def pop(): Unit = { + super.pop() solver.pop(1) incrementals.foreach(_.pop()) } @@ -160,7 +178,7 @@ class FairZ3Solver(val context: LeonContext, val program: Program) if (hasError) { None } else { - fairCheck(Set()) + super.check } } @@ -168,330 +186,43 @@ class FairZ3Solver(val context: LeonContext, val program: Program) if (hasError) { None } else { - fairCheck(assumptions) + super.checkAssumptions(assumptions) } } - var foundDefinitiveAnswer = false - var definitiveAnswer : Option[Boolean] = None - var definitiveModel : HenkinModel = HenkinModel.empty - var definitiveCore : Set[Expr] = Set.empty - - def assertCnstr(expression: Expr) { + def assertCnstr(expression: Expr): Unit = { try { - val newFreeVars = variablesOf(expression) - freeVars ++= newFreeVars + val bindings = variablesOf(expression).map(id => id -> variables.cachedB(Variable(id)) { + templateGenerator.encoder.encodeId(id) + }).toMap - // We make sure all free variables are registered as variables - freeVars.foreach { v => - variables.cachedB(Variable(v)) { - templateGenerator.encoder.encodeId(v) - } - } - - constraints += expression - - val newClauses = unrollingBank.getClauses(expression, variables.aToB) - - for (cl <- newClauses) { - solver.assertCnstr(cl) - } + assertCnstr(expression, bindings) } catch { case _: Unsupported => addError() } } - def getModel = { - definitiveModel + def solverAssert(cnstr: Z3AST): Unit = { + solver.assertCnstr(cnstr) } - def getUnsatCore = { - definitiveCore - } - - def fairCheck(assumptions: Set[Expr]): Option[Boolean] = { - foundDefinitiveAnswer = false + def solverUnsatCore = Some(solver.getUnsatCore) - def entireFormula = andJoin(assumptions.toSeq ++ constraints.toSeq) + override def foundAnswer(res: Option[Boolean], model: Model = Model.empty, core: Set[Expr] = Set.empty) = { + super.foundAnswer(res, model, core) - def foundAnswer(answer: Option[Boolean], model: HenkinModel = HenkinModel.empty, core: Set[Expr] = Set.empty) : Unit = { - foundDefinitiveAnswer = true - definitiveAnswer = answer - definitiveModel = model - definitiveCore = core - } - - // these are the optional sequence of assumption literals - val assumptionsAsZ3: Seq[Z3AST] = assumptions.map(toZ3Formula(_)).toSeq - val assumptionsAsZ3Set: Set[Z3AST] = assumptionsAsZ3.toSet - - def z3CoreToCore(core: Seq[Z3AST]): Set[Expr] = { - core.filter(assumptionsAsZ3Set).map(ast => fromZ3Formula(null, ast, BooleanType) match { - case n @ Not(Variable(_)) => n - case v @ Variable(_) => v - case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") - }).toSet - } - - def validatedModel(silenceErrors: Boolean) : (Boolean, HenkinModel) = { - if (interrupted) { - (false, HenkinModel.empty) - } else { - val lastModel = solver.getModel - val clauses = templateGenerator.manager.checkClauses - val optModel = if (clauses.isEmpty) Some(lastModel) else { - solver.push() - for (clause <- clauses) { - solver.assertCnstr(clause) - } - - reporter.debug(" - Enforcing model transitivity") - val timer = context.timers.solvers.z3.check.start() - solver.push() // FIXME: remove when z3 bug is fixed - val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.satisfactionAssumptions) :_*) - solver.pop() // FIXME: remove when z3 bug is fixed - timer.stop() - - val solverModel = res match { - case Some(true) => - Some(solver.getModel) - - case Some(false) => - val msg = "- Transitivity independence not guaranteed for model" - if (silenceErrors) { - reporter.debug(msg) - } else { - reporter.warning(msg) - } - None - - case None => - val msg = "- Unknown for transitivity independence!?" - if (silenceErrors) { - reporter.debug(msg) - } else { - reporter.warning(msg) - } - None - } - - solver.pop() - solverModel - } - - val model = optModel getOrElse lastModel - - val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap - val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { - if (functions containsB p._1) { - val tfd = functions.toA(p._1) - if (!tfd.hasImplementation) { - val (cses, default) = p._2 - val ite = cses.foldLeft(fromZ3Formula(model, default, tfd.returnType))((expr, q) => IfExpr( - andJoin( - q._1.zip(tfd.params).map(a12 => Equals(fromZ3Formula(model, a12._1, a12._2.getType), Variable(a12._2.id))) - ), - fromZ3Formula(model, q._2, tfd.returnType), - expr)) - Seq((tfd.id, ite)) - } else Seq() - } else Seq() - }) - - val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { - if(functions containsB p._1) { - val tfd = functions.toA(p._1) - if(!tfd.hasImplementation) { - Seq((tfd.id, fromZ3Formula(model, p._2, tfd.returnType))) - } else Seq() - } else Seq() - }).toMap - - val leonModel = extractModel(model, freeVars.toSet) - val fullModel = leonModel ++ (functionsAsMap ++ constantFunctionsAsMap) - - if (!optModel.isDefined) { - (false, leonModel) - } else { - (evaluator.check(entireFormula, fullModel) match { - case EvaluationResults.CheckSuccess => - reporter.debug("- Model validated.") - true - - case EvaluationResults.CheckValidityFailure => - reporter.debug("- Invalid model.") - false - - case EvaluationResults.CheckRuntimeFailure(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to evaluation error: " + msg) - } else { - reporter.warning("- Model leads to evaluation error: " + msg) - } - false - - case EvaluationResults.CheckQuantificationFailure(msg) => - if (silenceErrors) { - reporter.debug("- Model leads to quantification error: " + msg) - } else { - reporter.warning("- Model leads to quantification error: " + msg) - } - false - }, leonModel) + if (!interrupted && res == None && model == None) { + reporter.ifDebug { debug => + if (solver.getReasonUnknown != "canceled") { + debug("Z3 returned unknown: " + solver.getReasonUnknown) } } } + } - while(!foundDefinitiveAnswer && !interrupted) { - - //val blockingSetAsZ3 : Seq[Z3AST] = blockingSet.toSeq.map(toZ3Formula(_).get) - // println("Blocking set : " + blockingSet) - - reporter.debug(" - Running Z3 search...") - - //reporter.debug("Searching in:\n"+solver.getAssertions.toSeq.mkString("\nAND\n")) - //reporter.debug("Unroll. Assumptions:\n"+unrollingBank.z3CurrentZ3Blockers.mkString(" && ")) - //reporter.debug("Userland Assumptions:\n"+assumptionsAsZ3.mkString(" && ")) - - val timer = context.timers.solvers.z3.check.start() - solver.push() // FIXME: remove when z3 bug is fixed - val res = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.satisfactionAssumptions) :_*) - solver.pop() // FIXME: remove when z3 bug is fixed - timer.stop() - - reporter.debug(" - Finished search with blocked literals") - - lazy val allVars: Set[Identifier] = freeVars.toSet - - res match { - case None => - reporter.ifDebug { debug => - if (solver.getReasonUnknown != "canceled") { - debug("Z3 returned unknown: " + solver.getReasonUnknown) - } - } - foundAnswer(None) - - case Some(true) => // SAT - val (valid, model) = if (!this.disableChecks && (this.checkModels || requireQuantification)) { - validatedModel(false) - } else { - true -> extractModel(solver.getModel, allVars) - } - - if (valid) { - foundAnswer(Some(true), model) - } else { - reporter.error("Something went wrong. The model should have been valid, yet we got this : ") - reporter.error(model.asString(context)) - foundAnswer(None, model) - } - - case Some(false) if !unrollingBank.canUnroll => - - val core = z3CoreToCore(solver.getUnsatCore()) - - foundAnswer(Some(false), core = core) - - // This branch is both for with and without unsat cores. The - // distinction is made inside. - case Some(false) => - - def coreElemToBlocker(c: Z3AST): (Z3AST, Boolean) = { - z3.getASTKind(c) match { - case Z3AppAST(decl, args) => - z3.getDeclKind(decl) match { - case Z3DeclKind.OpNot => - (args.head, true) - case Z3DeclKind.OpUninterpreted => - (c, false) - } - - case ast => - (c, false) - } - } - - if (unrollUnsatCores) { - unrollingBank.decreaseAllGenerations() - - for (c <- solver.getUnsatCore()) { - val (z3ast, pol) = coreElemToBlocker(c) - assert(pol) - - unrollingBank.promoteBlocker(z3ast) - } - - } - - //debug("UNSAT BECAUSE: "+solver.getUnsatCore.mkString("\n AND \n")) - //debug("UNSAT BECAUSE: "+core.mkString(" AND ")) - - if (!interrupted) { - if (this.feelingLucky) { - // we need the model to perform the additional test - reporter.debug(" - Running search without blocked literals (w/ lucky test)") - } else { - reporter.debug(" - Running search without blocked literals (w/o lucky test)") - } - - val timer = context.timers.solvers.z3.check.start() - solver.push() // FIXME: remove when z3 bug is fixed - val res2 = solver.checkAssumptions((assumptionsAsZ3 ++ unrollingBank.refutationAssumptions) : _*) - solver.pop() // FIXME: remove when z3 bug is fixed - timer.stop() - - reporter.debug(" - Finished search without blocked literals") - - res2 match { - case Some(false) => - //reporter.debug("UNSAT WITHOUT Blockers") - foundAnswer(Some(false), core = z3CoreToCore(solver.getUnsatCore)) - case Some(true) => - //reporter.debug("SAT WITHOUT Blockers") - if (this.feelingLucky && !interrupted) { - // we might have been lucky :D - val (wereWeLucky, cleanModel) = validatedModel(true) - - if(wereWeLucky) { - foundAnswer(Some(true), cleanModel) - } - } - - case None => - foundAnswer(None) - } - } - - if(interrupted) { - foundAnswer(None) - } - - if(!foundDefinitiveAnswer) { - reporter.debug("- We need to keep going.") - - val toRelease = unrollingBank.getBlockersToUnlock - - reporter.debug(" - more unrollings") - - val newClauses = unrollingBank.unrollBehind(toRelease) - - for(ncl <- newClauses) { - solver.assertCnstr(ncl) - } - - //readLine() - - reporter.debug(" - finished unrolling") - } - } - } - - if(interrupted) { - None - } else { - definitiveAnswer - } + override def interrupt(): Unit = { + super[AbstractZ3Solver].interrupt() + super[AbstractUnrollingSolver].interrupt() } } diff --git a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala index 3daf1ad4964ad73e8c4d9701ae4e65d0f4170897..b644f687af3950d03bb062f0ef0f030c3691f6b4 100644 --- a/src/main/scala/leon/solvers/z3/Z3StringConversion.scala +++ b/src/main/scala/leon/solvers/z3/Z3StringConversion.scala @@ -7,88 +7,371 @@ import purescala.Expressions._ import purescala.Constructors._ import purescala.Types._ import purescala.Definitions._ -import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} -import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} -import _root_.smtlib.interpreters.Z3Interpreter -import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} -import _root_.smtlib.theories.ArraysEx import leon.utils.Bijection +import leon.purescala.DefOps +import leon.purescala.TypeOps +import leon.purescala.Extractors.Operator +import leon.evaluators.EvaluationResults -object Z3StringTypeConversion { - def convert(t: TypeTree)(implicit p: Program) = new Z3StringTypeConversion { def getProgram = p }.convertType(t) - def convertToString(e: Expr)(implicit p: Program) = new Z3StringTypeConversion{ def getProgram = p }.convertToString(e) -} - -trait Z3StringTypeConversion { - val stringBijection = new Bijection[String, Expr]() +object StringEcoSystem { + private def withIdentifier[T](name: String, tpe: TypeTree = Untyped)(f: Identifier => T): T = { + val id = FreshIdentifier(name, tpe) + f(id) + } + private def withIdentifiers[T](name: String, tpe: TypeTree, name2: String, tpe2: TypeTree = Untyped)(f: (Identifier, Identifier) => T): T = { + withIdentifier(name, tpe)(id => withIdentifier(name2, tpe2)(id2 => f(id, id2))) + } + + val StringList = AbstractClassDef(FreshIdentifier("StringList"), Seq(), None) + val StringListTyped = StringList.typed + val StringCons = withIdentifiers("head", CharType, "tail", StringListTyped){ (head, tail) => + val d = CaseClassDef(FreshIdentifier("StringCons"), Seq(), Some(StringListTyped), false) + d.setFields(Seq(ValDef(head), ValDef(tail))) + d + } + StringList.registerChild(StringCons) + val StringConsTyped = StringCons.typed + val StringNil = CaseClassDef(FreshIdentifier("StringNil"), Seq(), Some(StringListTyped), false) + val StringNilTyped = StringNil.typed + StringList.registerChild(StringNil) - lazy val conschar = program.lookupCaseClass("leon.collection.Cons") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Cons in Z3 solver") + val StringSize = withIdentifiers("l", StringListTyped, "StringSize"){ (lengthArg, id) => + val fd = new FunDef(id, Seq(), Seq(ValDef(lengthArg)), IntegerType) + fd.body = Some(withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(lengthArg), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, InfiniteIntegerLiteral(BigInt(0))), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + Plus(InfiniteIntegerLiteral(BigInt(1)), FunctionInvocation(fd.typed, Seq(Variable(t))))) + )) + }) + fd } - lazy val nilchar = program.lookupCaseClass("leon.collection.Nil") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find Nil in Z3 solver") + val StringListConcat = withIdentifiers("x", StringListTyped, "y", StringListTyped) { (x, y) => + val fd = new FunDef(FreshIdentifier("StringListConcat"), Seq(), Seq(ValDef(x), ValDef(y)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped){ (h, t) => + MatchExpr(Variable(x), Seq( + MatchCase(CaseClassPattern(None, StringNilTyped, Seq()), None, Variable(y)), + MatchCase(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), None, + CaseClass(StringConsTyped, Seq(Variable(h), FunctionInvocation(fd.typed, Seq(Variable(t), Variable(y))))) + ))) + } + ) + fd } - lazy val listchar = program.lookupAbstractClass("leon.collection.List") match { - case Some(cc) => cc.typed(Seq(CharType)) - case _ => throw new Exception("Could not find List in Z3 solver") + + val StringTake = withIdentifiers("tt", StringListTyped, "it", StringListTyped) { (tt, it) => + val fd = new FunDef(FreshIdentifier("StringTake"), Seq(), Seq(ValDef(tt), ValDef(it)), StringListTyped) + fd.body = Some{ + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(tt), Variable(it))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringNilTyped, Seq()), + CaseClass(StringConsTyped, Seq(Variable(h), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))))) + )))) + } + } + } + fd } - def lookupFunDef(s: String): FunDef = program.lookupFunDef(s) match { - case Some(fd) => fd - case _ => throw new Exception("Could not find function "+s+" in program") + + val StringDrop = withIdentifiers("td", StringListTyped, "id", IntegerType) { (td, id) => + val fd = new FunDef(FreshIdentifier("StringDrop"), Seq(), Seq(ValDef(td), ValDef(id)), StringListTyped) + fd.body = Some( + withIdentifiers("h", CharType, "t", StringListTyped) { (h, t) => + withIdentifier("i", IntegerType){ i => + MatchExpr(Tuple(Seq(Variable(td), Variable(id))), Seq( + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringNilTyped, Seq()), WildcardPattern(None))), None, + InfiniteIntegerLiteral(BigInt(0))), + MatchCase(TuplePattern(None, Seq(CaseClassPattern(None, StringConsTyped, Seq(WildcardPattern(Some(h)), WildcardPattern(Some(t)))), WildcardPattern(Some(i)))), None, + IfExpr(LessThan(Variable(i), InfiniteIntegerLiteral(BigInt(0))), + CaseClass(StringConsTyped, Seq(Variable(h), Variable(t))), + FunctionInvocation(fd.typed, Seq(Variable(t), Minus(Variable(i), InfiniteIntegerLiteral(BigInt(1))))) + )))) + }} + ) + fd } - lazy val list_size = lookupFunDef("leon.collection.List.size").typed(Seq(CharType)) - lazy val list_++ = lookupFunDef("leon.collection.List.++").typed(Seq(CharType)) - lazy val list_take = lookupFunDef("leon.collection.List.take").typed(Seq(CharType)) - lazy val list_drop = lookupFunDef("leon.collection.List.drop").typed(Seq(CharType)) - lazy val list_slice = lookupFunDef("leon.collection.List.slice").typed(Seq(CharType)) - private lazy val program = getProgram + val StringSlice = withIdentifier("s", StringListTyped) { s => withIdentifiers("from", IntegerType, "to", IntegerType) { (from, to) => + val fd = new FunDef(FreshIdentifier("StringSlice"), Seq(), Seq(ValDef(s), ValDef(from), ValDef(to)), StringListTyped) + fd.body = Some( + FunctionInvocation(StringTake.typed, + Seq(FunctionInvocation(StringDrop.typed, Seq(Variable(s), Variable(from))), + Minus(Variable(to), Variable(from))))) + fd + } } - def getProgram: Program + val classDefs = Seq(StringList, StringCons, StringNil) + val funDefs = Seq(StringSize, StringListConcat, StringTake, StringDrop, StringSlice) +} + +class Z3StringConversion(val p: Program) extends Z3StringConverters { + import StringEcoSystem._ + def getProgram = program_with_string_methods - def convertType(t: TypeTree): TypeTree = t match { - case StringType => listchar - case _ => t + lazy val program_with_string_methods = { + val p2 = DefOps.addClassDefs(p, StringEcoSystem.classDefs, p.library.Nil.get) + DefOps.addFunDefs(p2, StringEcoSystem.funDefs, p2.library.escape.get) } - def convertToString(e: Expr)(implicit p: Program): String = +} + +trait Z3StringConverters { + import StringEcoSystem._ + val mappedVariables = new Bijection[Identifier, Identifier]() + + val globalClassMap = new Bijection[ClassDef, ClassDef]() // To be added manually + + val globalFdMap = new Bijection[FunDef, FunDef]() + + val stringBijection = new Bijection[String, Expr]() + + def convertToString(e: Expr): String = stringBijection.cachedA(e) { e match { case CaseClass(_, Seq(CharLiteral(c), l)) => c + convertToString(l) case CaseClass(_, Seq()) => "" } } - def convertFromString(v: String) = + def convertFromString(v: String): Expr = stringBijection.cachedB(v) { - v.toList.foldRight(CaseClass(nilchar, Seq())){ - case (char, l) => CaseClass(conschar, Seq(CharLiteral(char), l)) + v.toList.foldRight(CaseClass(StringNilTyped, Seq())){ + case (char, l) => CaseClass(StringConsTyped, Seq(CharLiteral(char), l)) } } -} - -trait Z3StringConversion[TargetType] extends Z3StringTypeConversion { - def convertToTarget(e: Expr)(implicit bindings: Map[Identifier, TargetType]): TargetType - def targetApplication(fd: TypedFunDef, args: Seq[TargetType])(implicit bindings: Map[Identifier, TargetType]): TargetType - object StringConverted { - def unapply(e: Expr)(implicit bindings: Map[Identifier, TargetType]): Option[TargetType] = e match { + trait BidirectionalConverters { + def convertFunDef(fd: FunDef): FunDef + def hasIdConversion(id: Identifier): Boolean + def convertId(id: Identifier): Identifier + def convertClassDef(d: ClassDef): ClassDef + def isTypeToConvert(tpe: TypeTree): Boolean + def convertType(tpe: TypeTree): TypeTree + def convertPattern(pattern: Pattern): Pattern + def convertExpr(expr: Expr)(implicit bindings: Map[Identifier, Expr]): Expr + object TypeConverted { + def unapply(t: TypeTree): Option[TypeTree] = Some(t match { + case cct@CaseClassType(ccd, args) => CaseClassType(convertClassDef(ccd).asInstanceOf[CaseClassDef], args) + case act@AbstractClassType(acd, args) => AbstractClassType(convertClassDef(acd).asInstanceOf[AbstractClassDef], args) + case NAryType(es, builder) => + builder(es map convertType) + }) + } + object PatternConverted { + def unapply(e: Pattern): Option[Pattern] = Some(e match { + case InstanceOfPattern(binder, ct) => + InstanceOfPattern(binder.map(convertId), convertType(ct).asInstanceOf[ClassType]) + case WildcardPattern(binder) => + WildcardPattern(binder.map(convertId)) + case CaseClassPattern(binder, ct, subpatterns) => + CaseClassPattern(binder.map(convertId), convertType(ct).asInstanceOf[CaseClassType], subpatterns map convertPattern) + case TuplePattern(binder, subpatterns) => + TuplePattern(binder.map(convertId), subpatterns map convertPattern) + case UnapplyPattern(binder, TypedFunDef(fd, tpes), subpatterns) => + UnapplyPattern(binder.map(convertId), TypedFunDef(convertFunDef(fd), tpes map convertType), subpatterns map convertPattern) + case PatternExtractor(es, builder) => + builder(es map convertPattern) + }) + } + + object ExprConverted { + def unapply(e: Expr)(implicit bindings: Map[Identifier, Expr]): Option[Expr] = Some(e match { + case Variable(id) if bindings contains id => bindings(id).copiedFrom(e) + case Variable(id) if hasIdConversion(id) => Variable(convertId(id)).copiedFrom(e) + case Variable(id) => e + case pl @ FiniteLambda(mappings, default, tpe) => + FiniteLambda( + mappings.map(kv => (kv._1.map(argtpe => convertExpr(argtpe)), + convertExpr(kv._2))), + convertExpr(default), convertType(tpe).asInstanceOf[FunctionType]) + case Lambda(args, body) => + val new_bindings = scala.collection.mutable.ListBuffer[(Identifier, Identifier)]() + val new_args = for(arg <- args) yield { + val in = arg.getType + val new_id = convertId(arg.id) + if(new_id ne arg.id) { + new_bindings += (arg.id -> new_id) + ValDef(new_id) + } else arg + } + val res = Lambda(new_args, convertExpr(body)(bindings ++ new_bindings.map(t => (t._1, Variable(t._2))))).copiedFrom(e) + res + case Let(a, expr, body) if isTypeToConvert(a.getType) => + val new_a = convertId(a) + val new_bindings = bindings + (a -> Variable(new_a)) + val expr2 = convertExpr(expr)(new_bindings) + val body2 = convertExpr(body)(new_bindings) + Let(new_a, expr2, body2).copiedFrom(e) + case CaseClass(CaseClassType(ccd, tpes), args) => + CaseClass(CaseClassType(ccd, tpes map convertType), args map convertExpr).copiedFrom(e) + case CaseClassSelector(CaseClassType(ccd, tpes), caseClass, selector) => + CaseClassSelector(CaseClassType(ccd, tpes map convertType), convertExpr(caseClass), selector).copiedFrom(e) + case MethodInvocation(rec: Expr, cd: ClassDef, TypedFunDef(fd, tpes), args: Seq[Expr]) => + MethodInvocation(convertExpr(rec), cd, TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) + case FunctionInvocation(TypedFunDef(fd, tpes), args) => + FunctionInvocation(TypedFunDef(convertFunDef(fd), tpes map convertType), args map convertExpr).copiedFrom(e) + case This(ct: ClassType) => + This(convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case IsInstanceOf(expr, ct) => + IsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case AsInstanceOf(expr, ct) => + AsInstanceOf(convertExpr(expr), convertType(ct).asInstanceOf[ClassType]).copiedFrom(e) + case Tuple(args) => + Tuple(for(arg <- args) yield convertExpr(arg)).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { + MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) + }) + case Operator(es, builder) => + val rec = convertExpr _ + val newEs = es.map(rec) + if ((newEs zip es).exists { case (bef, aft) => aft ne bef }) { + builder(newEs).copiedFrom(e) + } else { + e + } + case e => e + }) + } + + def convertModel(model: Model): Model = { + new Model(model.ids.map{i => + val id = convertId(i) + id -> convertExpr(model(i))(Map()) + }.toMap) + } + + def convertResult(result: EvaluationResults.Result[Expr]) = { + result match { + case EvaluationResults.Successful(e) => EvaluationResults.Successful(convertExpr(e)(Map())) + case result => result + } + } + } + + object Forward extends BidirectionalConverters { + /* The conversion between functions should already have taken place */ + def convertFunDef(fd: FunDef): FunDef = { + globalFdMap.getBorElse(fd, fd) + } + /* The conversion between classdefs should already have taken place */ + def convertClassDef(cd: ClassDef): ClassDef = { + globalClassMap.getBorElse(cd, cd) + } + def hasIdConversion(id: Identifier): Boolean = { + mappedVariables.containsA(id) + } + def convertId(id: Identifier): Identifier = { + mappedVariables.getB(id) match { + case Some(idB) => idB + case None => + if(isTypeToConvert(id.getType)) { + val new_id = FreshIdentifier(id.name, convertType(id.getType)) + mappedVariables += (id -> new_id) + new_id + } else id + } + } + def isTypeToConvert(tpe: TypeTree): Boolean = + TypeOps.exists(StringType == _)(tpe) + def convertType(tpe: TypeTree): TypeTree = tpe match { + case StringType => StringList.typed + case TypeConverted(t) => t + } + def convertPattern(e: Pattern): Pattern = e match { + case LiteralPattern(binder, StringLiteral(s)) => + s.foldRight(CaseClassPattern(None, StringNilTyped, Seq())) { + case (elem, pattern) => + CaseClassPattern(None, StringConsTyped, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) + } + case PatternConverted(e) => e + } + + /** Method which can use recursively StringConverted in its body in unapply positions */ + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = e match { + case Variable(id) if isTypeToConvert(id.getType) => Variable(convertId(id)).copiedFrom(e) case StringLiteral(v) => - // No string support for z3 at this moment. val stringEncoding = convertFromString(v) - Some(convertToTarget(stringEncoding)) + convertExpr(stringEncoding).copiedFrom(e) case StringLength(a) => - Some(targetApplication(list_size, Seq(convertToTarget(a)))) + FunctionInvocation(StringSize.typed, Seq(convertExpr(a))).copiedFrom(e) case StringConcat(a, b) => - Some(targetApplication(list_++, Seq(convertToTarget(a), convertToTarget(b)))) + FunctionInvocation(StringListConcat.typed, Seq(convertExpr(a), convertExpr(b))).copiedFrom(e) case SubString(a, start, Plus(start2, length)) if start == start2 => - Some(targetApplication(list_take, - Seq(targetApplication(list_drop, Seq(convertToTarget(a), convertToTarget(start))), convertToTarget(length)))) + FunctionInvocation(StringTake.typed, + Seq(FunctionInvocation(StringDrop.typed, Seq(convertExpr(a), convertExpr(start))), convertExpr(length))).copiedFrom(e) case SubString(a, start, end) => - Some(targetApplication(list_slice, Seq(convertToTarget(a), convertToTarget(start), convertToTarget(end)))) - case _ => None + FunctionInvocation(StringSlice.typed, Seq(convertExpr(a), convertExpr(start), convertExpr(end))).copiedFrom(e) + case MatchExpr(scrutinee, cases) => + MatchExpr(convertExpr(scrutinee), for(MatchCase(pattern, guard, rhs) <- cases) yield { + MatchCase(convertPattern(pattern), guard.map(convertExpr), convertExpr(rhs)) + }) + case ExprConverted(e) => e } - - def apply(t: TypeTree): TypeTree = convertType(t) } -} \ No newline at end of file + + object Backward extends BidirectionalConverters { + def convertFunDef(fd: FunDef): FunDef = { + globalFdMap.getAorElse(fd, fd) + } + /* The conversion between classdefs should already have taken place */ + def convertClassDef(cd: ClassDef): ClassDef = { + globalClassMap.getAorElse(cd, cd) + } + def hasIdConversion(id: Identifier): Boolean = { + mappedVariables.containsB(id) + } + def convertId(id: Identifier): Identifier = { + mappedVariables.getA(id) match { + case Some(idA) => idA + case None => + if(isTypeToConvert(id.getType)) { + val old_type = convertType(id.getType) + val old_id = FreshIdentifier(id.name, old_type) + mappedVariables += (old_id -> id) + old_id + } else id + } + } + def convertIdToMapping(id: Identifier): (Identifier, Variable) = { + id -> Variable(convertId(id)) + } + def isTypeToConvert(tpe: TypeTree): Boolean = + TypeOps.exists(t => TypeOps.isSubtypeOf(t, StringListTyped))(tpe) + def convertType(tpe: TypeTree): TypeTree = tpe match { + case StringList | StringCons | StringNil => StringType + case TypeConverted(t) => t + } + def convertPattern(e: Pattern): Pattern = e match { + case CaseClassPattern(b, StringNilTyped, Seq()) => + LiteralPattern(b.map(convertId), StringLiteral("")) + case CaseClassPattern(b, StringConsTyped, Seq(LiteralPattern(_, CharLiteral(elem)), subpattern)) => + convertPattern(subpattern) match { + case LiteralPattern(_, StringLiteral(s)) + => LiteralPattern(b.map(convertId), StringLiteral(elem + s)) + case e => LiteralPattern(None, StringLiteral("Failed to parse pattern back as string:" + e)) + } + case PatternConverted(e) => e + } + + def convertExpr(e: Expr)(implicit bindings: Map[Identifier, Expr]): Expr = + e match { + case cc@CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, StringListTyped)=> + StringLiteral(convertToString(cc)) + case FunctionInvocation(StringSize, Seq(a)) => + StringLength(convertExpr(a)).copiedFrom(e) + case FunctionInvocation(StringListConcat, Seq(a, b)) => + StringConcat(convertExpr(a), convertExpr(b)).copiedFrom(e) + case FunctionInvocation(StringTake, + Seq(FunctionInvocation(StringDrop, Seq(a, start)), length)) => + val rstart = convertExpr(start) + SubString(convertExpr(a), rstart, plus(rstart, convertExpr(length))).copiedFrom(e) + case ExprConverted(e) => e + } + } +} diff --git a/src/main/scala/leon/synthesis/ConversionPhase.scala b/src/main/scala/leon/synthesis/ConversionPhase.scala index 13f1788fffec11ca03139210d52815699d386324..ce9f1525f5e012800400c7b9c0776914d989b43b 100644 --- a/src/main/scala/leon/synthesis/ConversionPhase.scala +++ b/src/main/scala/leon/synthesis/ConversionPhase.scala @@ -60,7 +60,7 @@ object ConversionPhase extends UnitPhase[Program] { * post(res) * } * - * 3) Completes abstract definitions: + * 3) Completes abstract definitions (IF NOT EXTERN): * * def foo(a: T) = { * require(..a..) @@ -92,14 +92,14 @@ object ConversionPhase extends UnitPhase[Program] { * (in practice, there will be no pre-and postcondition) */ - def convert(e : Expr, ctx : LeonContext) : Expr = { + def convert(e : Expr, ctx : LeonContext, isExtern: Boolean) : Expr = { val (pre, body, post) = breakDownSpecs(e) // Ensure that holes are not found in pre and/or post conditions (pre ++ post).foreach { preTraversal{ case h : Hole => - ctx.reporter.error(s"Holes are not supported in pre- or postconditions. @ ${h.getPos}") + ctx.reporter.error(s"Holes like $h are not supported in pre- or postconditions. @ ${h.getPos}") case wo: WithOracle => ctx.reporter.error(s"WithOracle expressions are not supported in pre- or postconditions: ${wo.asString(ctx)} @ ${wo.getPos}") case _ => @@ -183,8 +183,12 @@ object ConversionPhase extends UnitPhase[Program] { } case None => - val newPost = post getOrElse Lambda(Seq(ValDef(FreshIdentifier("res", e.getType))), BooleanLiteral(true)) - withPrecondition(Choose(newPost), pre) + if (isExtern) { + e + } else { + val newPost = post getOrElse Lambda(Seq(ValDef(FreshIdentifier("res", e.getType))), BooleanLiteral(true)) + withPrecondition(Choose(newPost), pre) + } } // extract spec from chooses at the top-level @@ -202,7 +206,7 @@ object ConversionPhase extends UnitPhase[Program] { def apply(ctx: LeonContext, pgm: Program): Unit = { // TODO: remove side-effects for (fd <- pgm.definedFunctions) { - fd.fullBody = convert(fd.fullBody,ctx) + fd.fullBody = convert(fd.fullBody, ctx, fd.annotations("extern")) } } diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 5077f9467dab2edd3abf8fdf7da01d5eae59ff5e..fef98f7d79920b9edcca0e0375a84f253b2f79c6 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -6,13 +6,10 @@ package synthesis import purescala.Expressions._ import purescala.Definitions._ import purescala.ExprOps._ -import purescala.Types.TypeTree import purescala.Common._ import purescala.Constructors._ -import purescala.Extractors._ import evaluators._ import grammars._ -import bonsai.enumerators._ import codegen._ import datagen._ import solvers._ @@ -25,6 +22,13 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { implicit val ctx = ctx0 val reporter = ctx.reporter + + private var keepAbstractExamples = false + /** If true, will not evaluate examples to check them. */ + def setKeepAbstractExamples(b: Boolean) = { this.keepAbstractExamples = b; this } + /** Sets if evalution of the result of tests should stop on choose statements. + * Useful for programming by Example */ + def setEvaluationFailOnChoose(b: Boolean) = { evaluator.setEvaluationFailOnChoose(b); this } def extractFromFunDef(fd: FunDef, partition: Boolean): ExamplesBank = fd.postcondition match { case Some(Lambda(Seq(ValDef(id)), post)) => @@ -41,6 +45,8 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { Some(InOutExample(fd.params.map(p => t(p.id)), Seq(t(id)))) } else if ((ids & insIds) == insIds) { Some(InExample(fd.params.map(p => t(p.id)))) + } else if((ids & outsIds) == outsIds) { // Examples provided on a part of the inputs. + Some(InOutExample(fd.params.map(p => t.getOrElse(p.id, Variable(p.id))), Seq(t(id)))) } else { None } @@ -67,27 +73,31 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { case None => ExamplesBank(Nil, Nil) } - - // Extract examples from the passes found in expression + + /** Extract examples from the passes found in expression */ def extractFromProblem(p: Problem): ExamplesBank = { val testClusters = extractTestsOf(and(p.pc, p.phi)) // Finally, we keep complete tests covering all as++xs val allIds = (p.as ++ p.xs).toSet val insIds = p.as.toSet - + val outsIds = p.xs.toSet + val examples = testClusters.toSeq.flatMap { t => val ids = t.keySet if ((ids & allIds) == allIds) { Some(InOutExample(p.as.map(t), p.xs.map(t))) } else if ((ids & insIds) == insIds) { Some(InExample(p.as.map(t))) + } else if((ids & outsIds) == outsIds) { // Examples provided on a part of the inputs. + Some(InOutExample(p.as.map(p => t.getOrElse(p, Variable(p))), p.xs.map(t))) } else { None } } def isValidExample(ex: Example): Boolean = { + if(this.keepAbstractExamples) return true // TODO: Abstract interpretation here ? val (mapping, cond) = ex match { case io: InOutExample => (Map((p.as zip io.ins) ++ (p.xs zip io.outs): _*), And(p.pc, p.phi)) @@ -110,13 +120,14 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { val datagen = new GrammarDataGen(evaluator, ValueGrammar) val solverDataGen = new SolverDataGen(ctx, program, (ctx, pgm) => SolverFactory(() => new FairZ3Solver(ctx, pgm))) - val generatedExamples = datagen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample(_)) + val generatedExamples = datagen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample) - val solverExamples = solverDataGen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample(_)) + val solverExamples = solverDataGen.generateFor(ids, pc, maxValid, maxEnumerated).map(InExample) ExamplesBank(generatedExamples.toSeq ++ solverExamples.toList, Nil) } + /** Extracts all passes constructs from the given postcondition, merges them if needed */ private def extractTestsOf(e: Expr): Set[Map[Identifier, Expr]] = { val allTests = collect[Map[Identifier, Expr]] { case Passes(ins, outs, cases) => @@ -133,14 +144,15 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { case _ => test } } - - // Check whether we can extract all ids from example - val results = exs.collect { case e if infos.forall(_._2.isDefinedAt(e)) => - infos.map{ case (id, f) => id -> f(e) }.toMap + try { + // Check whether we can extract all ids from example + val results = exs.collect { case e if this.keepAbstractExamples || infos.forall(_._2.isDefinedAt(e)) => + infos.map{ case (id, f) => id -> f(e) }.toMap + } + results.toSet + } catch { + case e: IDExtractionException => Set() } - - results.toSet - case _ => Set() }(e) @@ -149,7 +161,8 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { consolidateTests(allTests) } - + /** Processes ((in, out) passes { + * cs[=>Case pattExpr if guard => outR]*/ private def caseToExamples(in: Expr, out: Expr, cs: MatchCase, examplesPerCase: Int = 5): Seq[(Expr,Expr)] = { def doSubstitute(subs : Seq[(Identifier, Expr)], e : Expr) = @@ -180,34 +193,30 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { case (a, b, c) => None }) getOrElse { + // If the input contains free variables, it does not provide concrete examples. // We will instantiate them according to a simple grammar to get them. - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree, Expr]](ValueGrammar.getProductions) - val values = enum.iterator(tupleTypeWrap(freeVars.map { _.getType })) - val instantiations = values.map { - v => freeVars.zip(unwrapTuple(v, freeVars.size)).toMap - } + if(this.keepAbstractExamples) { + cs.optGuard match { + case Some(BooleanLiteral(false)) => + Seq() + case None => + Seq((pattExpr, cs.rhs)) + case Some(pred) => + Seq((Require(pred, pattExpr), cs.rhs)) + } + } else { + val dataGen = new GrammarDataGen(evaluator) - def filterGuard(e: Expr, mapping: Map[Identifier, Expr]): Boolean = cs.optGuard match { - case Some(guard) => - // in -> e should be enough. We shouldn't find any subexpressions of in. - evaluator.eval(replace(Map(in -> e), guard), mapping) match { - case EvaluationResults.Successful(BooleanLiteral(true)) => true - case _ => false - } + val theGuard = replace(Map(in -> pattExpr), cs.optGuard.getOrElse(BooleanLiteral(true))) - case None => - true + dataGen.generateFor(freeVars, theGuard, examplesPerCase, 1000).toSeq map { vals => + val inst = freeVars.zip(vals).toMap + val inR = replaceFromIDs(inst, pattExpr) + val outR = replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs)) + (inR, outR) + } } - - if(cs.optGuard == Some(BooleanLiteral(false))) { - Nil - } else (for { - inst <- instantiations.toSeq - inR = replaceFromIDs(inst, pattExpr) - outR = replaceFromIDs(inst, doSubstitute(ieMap, cs.rhs)) - if filterGuard(inR, inst) - } yield (inR, outR)).take(examplesPerCase) } } } @@ -249,6 +258,8 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { } consolidated } + + case class IDExtractionException(msg: String) extends Exception(msg) /** Extract ids in ins/outs args, and compute corresponding extractors for values map * @@ -268,13 +279,13 @@ class ExamplesFinder(ctx0: LeonContext, program: Program) { case Tuple(vs) => vs.map(extractIds).zipWithIndex.flatMap{ case (ids, i) => ids.map{ case (id, e) => - (id, andThen({ case Tuple(vs) => vs(i) }, e)) + (id, andThen({ case Tuple(vs) => vs(i) case e => throw new IDExtractionException("Expected Tuple, got " + e) }, e)) } } case CaseClass(cct, args) => args.map(extractIds).zipWithIndex.flatMap { case (ids, i) => ids.map{ case (id, e) => - (id, andThen({ case CaseClass(cct2, vs) if cct2 == cct => vs(i) } ,e)) + (id, andThen({ case CaseClass(cct2, vs) if cct2 == cct => vs(i) case e => throw new IDExtractionException("Expected Case class of type " + cct + ", got " + e) } ,e)) } } diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index cd27e272d53e9669f4c2d1f2c8e07356819b4827..3a86ca64a79238aec4e59b197e001fd4c24660b4 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -35,8 +35,10 @@ abstract class PreprocessingRule(name: String) extends Rule(name) { /** Contains the list of all available rules for synthesis */ object Rules { + + def all: List[Rule] = all(false) /** Returns the list of all available rules for synthesis */ - def all = List[Rule]( + def all(naiveGrammar: Boolean): List[Rule] = List[Rule]( StringRender, Unification.DecompTrivialClash, Unification.OccursCheck, // probably useless @@ -54,8 +56,8 @@ object Rules { OptimisticGround, EqualitySplit, InequalitySplit, - CEGIS, - TEGIS, + if(naiveGrammar) NaiveCEGIS else CEGIS, + //TEGIS, //BottomUpTEGIS, rules.Assert, DetupleOutput, diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index dfe90d2f775e8ed3736020cb1e827c75fe602ad9..affa41df46292ae73a6b9a87150283ec8c0ee997 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -31,7 +31,8 @@ class Solution(val pre: Expr, val defs: Set[FunDef], val term: Expr, val isTrust } def toExpr = { - letDef(defs.toList, guardedTerm) + if(defs.isEmpty) guardedTerm else + LetDef(defs.toList, guardedTerm) } // Projects a solution (ignore several output variables) diff --git a/src/main/scala/leon/synthesis/SourceInfo.scala b/src/main/scala/leon/synthesis/SourceInfo.scala index 4bb10d38c9ffc7a7667d165b84e4f65c1edc9e0c..8ab07929d78479656f18ce1fd652cfa7ef870e17 100644 --- a/src/main/scala/leon/synthesis/SourceInfo.scala +++ b/src/main/scala/leon/synthesis/SourceInfo.scala @@ -45,6 +45,10 @@ object SourceInfo { ci } + if (results.isEmpty) { + ctx.reporter.warning("No 'choose' found. Maybe the functions you chose do not exist?") + } + results.sortBy(_.source.getPos) } diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index b9ba6df1f688e01edf631e995ba2f8623bcfc5fe..ac4d30614d8269ce78a92a95e232855eb40d9fbd 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -3,13 +3,11 @@ package leon package synthesis -import purescala.ExprOps._ - +import purescala.ExprOps.replace import purescala.ScalaPrinter -import leon.utils._ import purescala.Definitions.{Program, FunDef} -import leon.utils.ASCIIHelpers +import leon.utils._ import graph._ object SynthesisPhase extends TransformationPhase { @@ -21,11 +19,13 @@ object SynthesisPhase extends TransformationPhase { val optDerivTrees = LeonFlagOptionDef( "derivtrees", "Generate derivation trees", false) // CEGIS options - val optCEGISOptTimeout = LeonFlagOptionDef( "cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true) - val optCEGISVanuatoo = LeonFlagOptionDef( "cegis:vanuatoo", "Generate inputs using new korat-style generator", false) + val optCEGISOptTimeout = LeonFlagOptionDef("cegis:opttimeout", "Consider a time-out of CE-search as untrusted solution", true ) + val optCEGISVanuatoo = LeonFlagOptionDef("cegis:vanuatoo", "Generate inputs using new korat-style generator", false) + val optCEGISNaiveGrammar = LeonFlagOptionDef("cegis:naive", "Use the old naive grammar for CEGIS", false) + val optCEGISMaxSize = LeonLongOptionDef("cegis:maxsize", "Maximum size of expressions synthesized by CEGIS", 5L, "N") override val definedOptions : Set[LeonOptionDef[Any]] = - Set(optManual, optCostModel, optDerivTrees, optCEGISOptTimeout, optCEGISVanuatoo) + Set(optManual, optCostModel, optDerivTrees, optCEGISOptTimeout, optCEGISVanuatoo, optCEGISNaiveGrammar, optCEGISMaxSize) def processOptions(ctx: LeonContext): SynthesisSettings = { val ms = ctx.findOption(optManual) @@ -53,11 +53,13 @@ object SynthesisPhase extends TransformationPhase { timeoutMs = timeout map { _ * 1000 }, generateDerivationTrees = ctx.findOptionOrDefault(optDerivTrees), costModel = costModel, - rules = Rules.all ++ (if(ms.isDefined) Seq(rules.AsChoose, rules.SygusCVC4) else Seq()), + rules = Rules.all(ctx.findOptionOrDefault(optCEGISNaiveGrammar)) ++ + (if(ms.isDefined) Seq(rules.AsChoose, rules.SygusCVC4) else Seq()), manualSearch = ms, functions = ctx.findOption(SharedOptions.optFunctions) map { _.toSet }, - cegisUseOptTimeout = ctx.findOption(optCEGISOptTimeout), - cegisUseVanuatoo = ctx.findOption(optCEGISVanuatoo) + cegisUseOptTimeout = ctx.findOptionOrDefault(optCEGISOptTimeout), + cegisUseVanuatoo = ctx.findOptionOrDefault(optCEGISVanuatoo), + cegisMaxSize = ctx.findOptionOrDefault(optCEGISMaxSize).toInt ) } @@ -80,7 +82,7 @@ object SynthesisPhase extends TransformationPhase { try { if (options.generateDerivationTrees) { - val dot = new DotGenerator(search.g) + val dot = new DotGenerator(search) dot.writeFile("derivation"+dotGenIds.nextGlobal+".dot") } diff --git a/src/main/scala/leon/synthesis/SynthesisSettings.scala b/src/main/scala/leon/synthesis/SynthesisSettings.scala index 5202818e18765ebf4086ef41d1685967a14940d0..61dc24ece71081c0f02f5bdcb38d9d9eeb0fee14 100644 --- a/src/main/scala/leon/synthesis/SynthesisSettings.scala +++ b/src/main/scala/leon/synthesis/SynthesisSettings.scala @@ -16,7 +16,8 @@ case class SynthesisSettings( functionsToIgnore: Set[FunDef] = Set(), // Cegis related options - cegisUseOptTimeout: Option[Boolean] = None, - cegisUseVanuatoo: Option[Boolean] = None + cegisUseOptTimeout: Boolean = true, + cegisUseVanuatoo : Boolean = false, + cegisMaxSize: Int = 5 ) diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index bafed6ec2bab51539bfc0547563bbad2aeea873e..efd1ad13e0538f855487e94b7b1a35d7d893627f 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -70,21 +70,19 @@ class Synthesizer(val context : LeonContext, // Print out report for synthesis, if necessary reporter.ifDebug { printer => - import java.io.FileWriter import java.text.SimpleDateFormat import java.util.Date val categoryName = ci.fd.getPos.file.toString.split("/").dropRight(1).lastOption.getOrElse("?") val benchName = categoryName+"."+ci.fd.id.name - var time = lastTime/1000.0; + val time = lastTime/1000.0 val defs = visibleDefsFrom(ci.fd)(program).collect { case cd: ClassDef => 1 + cd.fields.size case fd: FunDef => 1 + fd.params.size + formulaSize(fd.fullBody) } - val psize = defs.sum; - + val psize = defs.sum val (size, calls, proof) = result.headOption match { case Some((sol, trusted)) => diff --git a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala b/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala index 6e9dc237667e286cdbd9a82875a986dfbf6b8aba..e5071b9b4ecbdcde1d1b5a339b486c7820d20bfc 100644 --- a/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/ExamplesAdder.scala @@ -3,20 +3,46 @@ package leon package synthesis package disambiguation -import leon.LeonContext -import leon.purescala.Expressions._ +import purescala.Types.FunctionType import purescala.Common.FreshIdentifier import purescala.Constructors.{ and, tupleWrap } import purescala.Definitions.{ FunDef, Program, ValDef } -import purescala.ExprOps.expressionToPattern -import purescala.Expressions.{ BooleanLiteral, Equals, Expr, Lambda, MatchCase, Passes, Variable, WildcardPattern } +import purescala.ExprOps import purescala.Extractors.TopLevelAnds -import leon.purescala.Expressions._ +import purescala.Expressions._ /** * @author Mikael */ +object ExamplesAdder { + def replaceGenericValuesByVariable(e: Expr): (Expr, Map[Expr, Expr]) = { + var assignment = Map[Expr, Expr]() + var extension = 'a' + var id = "" + (ExprOps.postMap({ expr => expr match { + case g@GenericValue(tpe, index) => + val newIdentifier = FreshIdentifier(tpe.id.name.take(1).toLowerCase() + tpe.id.name.drop(1) + extension + id, tpe.id.getType) + if(extension != 'z' && extension != 'Z') + extension = (extension.toInt + 1).toChar + else if(extension == 'z') // No more than 52 generic variables in practice? + extension = 'A' + else { + if(id == "") id = "1" else id = (id.toInt + 1).toString + } + + val newVar = Variable(newIdentifier) + assignment += g -> newVar + Some(newVar) + case _ => None + } })(e), assignment) + } +} + class ExamplesAdder(ctx0: LeonContext, program: Program) { + import ExamplesAdder._ + var _removeFunctionParameters = false + + def setRemoveFunctionParameters(b: Boolean) = { _removeFunctionParameters = b; this } /** Accepts the nth alternative of a question (0 being the current one) */ def acceptQuestion[T <: Expr](fd: FunDef, q: Question[T], alternativeIndex: Int): Unit = { @@ -25,9 +51,12 @@ class ExamplesAdder(ctx0: LeonContext, program: Program) { addToFunDef(fd, Seq((newIn, newOut))) } + private def filterCases(cases: Seq[MatchCase]) = cases.filter(c => c.optGuard != Some(BooleanLiteral(false))) + /** Adds the given input/output examples to the function definitions */ - def addToFunDef(fd: FunDef, examples: Seq[(Expr, Expr)]) = { - val inputVariables = tupleWrap(fd.params.map(p => Variable(p.id): Expr)) + def addToFunDef(fd: FunDef, examples: Seq[(Expr, Expr)]): Unit = { + val params = if(_removeFunctionParameters) fd.params.filter(x => !x.getType.isInstanceOf[FunctionType]) else fd.params + val inputVariables = tupleWrap(params.map(p => Variable(p.id): Expr)) val newCases = examples.map{ case (in, out) => exampleToCase(in, out) } fd.postcondition match { case Some(Lambda(Seq(ValDef(id)), post)) => @@ -44,7 +73,7 @@ class ExamplesAdder(ctx0: LeonContext, program: Program) { } else { val newPasses = exprs(i) match { case Passes(in, out, cases) => - Passes(in, out, (cases ++ newCases).distinct ) + Passes(in, out, (filterCases(cases) ++ newCases).distinct ) case _ => ??? } val newPost = and(exprs.updated(i, newPasses) : _*) @@ -68,12 +97,12 @@ class ExamplesAdder(ctx0: LeonContext, program: Program) { } private def exampleToCase(in: Expr, out: Expr): MatchCase = { - val (inPattern, inGuard) = expressionToPattern(in) - if(inGuard != BooleanLiteral(true)) { + val (inPattern, inGuard) = ExprOps.expressionToPattern(in) + if(inGuard == BooleanLiteral(true)) { + MatchCase(inPattern, None, out) + } else /*if (in == in_raw) { } *else*/ { val id = FreshIdentifier("out", in.getType, true) MatchCase(WildcardPattern(Some(id)), Some(Equals(Variable(id), in)), out) - } else { - MatchCase(inPattern, None, out) } } } diff --git a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala index bfa6a62120af6171d001b6a026630734eb6c10fd..81f98f86432dc54ffb446f473fee3a1afcf358ca 100644 --- a/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala +++ b/src/main/scala/leon/synthesis/disambiguation/QuestionBuilder.scala @@ -1,21 +1,20 @@ package leon package synthesis.disambiguation +import datagen.GrammarDataGen import synthesis.Solution import evaluators.DefaultEvaluator import purescala.Expressions._ import purescala.ExprOps -import purescala.Constructors._ -import purescala.Extractors._ import purescala.Types.{StringType, TypeTree} import purescala.Common.Identifier import purescala.Definitions.Program import purescala.DefOps -import grammars.ValueGrammar -import bonsai.enumerators.MemoizedEnumerator +import grammars._ import solvers.ModelBuilder import scala.collection.mutable.ListBuffer -import grammars._ +import evaluators.AbstractEvaluator +import scala.annotation.tailrec object QuestionBuilder { /** Sort methods for questions. You can build your own */ @@ -69,15 +68,15 @@ object QuestionBuilder { /** Specific enumeration of strings, which can be used with the QuestionBuilder#setValueEnumerator method */ object SpecialStringValueGrammar extends ExpressionGrammar[TypeTree] { - def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Gen] = t match { - case StringType => - List( - terminal(StringLiteral("")), - terminal(StringLiteral("a")), - terminal(StringLiteral("\"'\n\t")), - terminal(StringLiteral("Lara 2007")) - ) - case _ => ValueGrammar.computeProductions(t) + def computeProductions(t: TypeTree)(implicit ctx: LeonContext): Seq[Prod] = t match { + case StringType => + List( + terminal(StringLiteral("")), + terminal(StringLiteral("a")), + terminal(StringLiteral("\"'\n\t")), + terminal(StringLiteral("Lara 2007")) + ) + case _ => ValueGrammar.computeProductions(t) } } } @@ -92,11 +91,9 @@ object QuestionBuilder { * * @tparam T A subtype of Expr that will be the type used in the Question[T] results. * @param input The identifier of the unique function's input. Must be typed or the type should be defined by setArgumentType - * @param ruleApplication The set of solutions for the body of f * @param filter A function filtering which outputs should be considered for comparison. - * It takes as input the sequence of outputs already considered for comparison, and the new output. - * It should return Some(result) if the result can be shown, and None else. - * @return An ordered + * It takes as input the sequence of outputs already considered for comparison, and the new output. + * It should return Some(result) if the result can be shown, and None else. * */ class QuestionBuilder[T <: Expr]( @@ -129,35 +126,72 @@ class QuestionBuilder[T <: Expr]( private def run(s: Solution, elems: Seq[(Identifier, Expr)]): Option[Expr] = { val newProgram = DefOps.addFunDefs(p, s.defs, p.definedFunctions.head) - val e = new DefaultEvaluator(c, newProgram) + val e = new AbstractEvaluator(c, newProgram) val model = new ModelBuilder model ++= elems val modelResult = model.result() - e.eval(s.term, modelResult).result + for{x <- e.eval(s.term, modelResult).result + res = x._1 + simp = ExprOps.simplifyArithmetic(res)} + yield simp + } + + /** Make all generic values unique. + * Duplicate generic values are not suitable for disambiguating questions since they remove an order. */ + def makeGenericValuesUnique(a: Expr): Expr = { + var genVals = Set[Expr with Terminal]() + def freshenValue(g: Expr with Terminal): Option[Expr with Terminal] = g match { + case g: GenericValue => Some(GenericValue(g.tp, g.id + 1)) + case StringLiteral(s) => + val i = s.lastIndexWhere { c => c < '0' || c > '9' } + val prefix = s.take(i+1) + val suffix = s.drop(i+1) + Some(StringLiteral(prefix + (if(suffix == "") "0" else (suffix.toInt + 1).toString))) + case InfiniteIntegerLiteral(i) => Some(InfiniteIntegerLiteral(i+1)) + case IntLiteral(i) => if(i == Integer.MAX_VALUE) None else Some(IntLiteral(i+1)) + case CharLiteral(c) => if(c == Char.MaxValue) None else Some(CharLiteral((c+1).toChar)) + case otherLiteral => None + } + @tailrec @inline def freshValue(g: Expr with Terminal): Expr with Terminal = { + if(genVals contains g) + freshenValue(g) match { + case None => g + case Some(v) => freshValue(v) + } + else { + genVals += g + g + } + } + ExprOps.postMap{ e => e match { + case g:Expr with Terminal => + Some(freshValue(g)) + case _ => None + }}(a) } /** Returns a list of input/output questions to ask to the user. */ def result(): List[Question[T]] = { if(solutions.isEmpty) return Nil - - val enum = new MemoizedEnumerator[TypeTree, Expr, Generator[TypeTree,Expr]](value_enumerator.getProductions) - val values = enum.iterator(tupleTypeWrap(_argTypes)) - val instantiations = values.map { - v => input.zip(unwrapTuple(v, input.size)) - } - - val enumerated_inputs = instantiations.take(expressionsToTake).toList - + + val datagen = new GrammarDataGen(new DefaultEvaluator(c, p), value_enumerator) + val enumerated_inputs = datagen.generateMapping(input, BooleanLiteral(true), expressionsToTake, expressionsToTake) + .map(inputs => + inputs.map(id_expr => + (id_expr._1, makeGenericValuesUnique(id_expr._2)))).toList + val solution = solutions.head val alternatives = solutions.drop(1).take(solutionsToTake).toList val questions = ListBuffer[Question[T]]() - for{possible_input <- enumerated_inputs - current_output_nonfiltered <- run(solution, possible_input) - current_output <- filter(Seq(), current_output_nonfiltered)} { + for { + possibleInput <- enumerated_inputs + currentOutputNonFiltered <- run(solution, possibleInput) + currentOutput <- filter(Seq(), currentOutputNonFiltered) + } { - val alternative_outputs = ((ListBuffer[T](current_output) /: alternatives) { (prev, alternative) => - run(alternative, possible_input) match { - case Some(alternative_output) if alternative_output != current_output => + val alternative_outputs = (ListBuffer[T](currentOutput) /: alternatives) { (prev, alternative) => + run(alternative, possibleInput) match { + case Some(alternative_output) if alternative_output != currentOutput => filter(prev, alternative_output) match { case Some(alternative_output_filtered) => prev += alternative_output_filtered @@ -165,11 +199,11 @@ class QuestionBuilder[T <: Expr]( } case _ => prev } - }).drop(1).toList.distinct - if(alternative_outputs.nonEmpty || keepEmptyAlternativeQuestions(current_output)) { - questions += Question(possible_input.map(_._2), current_output, alternative_outputs.sortWith((e,f) => _alternativeSortMethod.compare(e, f) <= 0)) + }.drop(1).toList.distinct + if(alternative_outputs.nonEmpty || keepEmptyAlternativeQuestions(currentOutput)) { + questions += Question(possibleInput.map(_._2), currentOutput, alternative_outputs.sortWith((e,f) => _alternativeSortMethod.compare(e, f) <= 0)) } } questions.toList.sortBy(_questionSorMethod(_)) } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/graph/DotGenerator.scala b/src/main/scala/leon/synthesis/graph/DotGenerator.scala index 7da38716116f51d89e751a8aa12d709be776e17c..78ef7b371487a6711d3508b9712f7806e9c551e0 100644 --- a/src/main/scala/leon/synthesis/graph/DotGenerator.scala +++ b/src/main/scala/leon/synthesis/graph/DotGenerator.scala @@ -6,7 +6,11 @@ import leon.utils.UniqueCounter import java.io.{File, FileWriter, BufferedWriter} -class DotGenerator(g: Graph) { +class DotGenerator(search: Search) { + + implicit val ctx = search.ctx + + val g = search.g private val idCounter = new UniqueCounter[Unit] idCounter.nextGlobal // Start with 1 @@ -80,12 +84,14 @@ class DotGenerator(g: Graph) { } def nodeDesc(n: Node): String = n match { - case an: AndNode => an.ri.toString - case on: OrNode => on.p.toString + case an: AndNode => an.ri.asString + case on: OrNode => on.p.asString } def drawNode(res: StringBuffer, name: String, n: Node) { + val index = n.parent.map(_.descendants.indexOf(n) + " ").getOrElse("") + def escapeHTML(str: String) = str.replaceAll("&", "&").replaceAll("<", "<").replaceAll(">", ">") val color = if (n.isSolved) { @@ -109,10 +115,10 @@ class DotGenerator(g: Graph) { res append "<TR><TD BORDER=\"0\">"+escapeHTML(n.cost.asString)+"</TD></TR>" } - res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(nodeDesc(n)))+"</TD></TR>" + res append "<TR><TD BORDER=\"1\" BGCOLOR=\""+color+"\">"+escapeHTML(limit(index + nodeDesc(n)))+"</TD></TR>" if (n.isSolved) { - res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.toString))+"</TD></TR>" + res append "<TR><TD BGCOLOR=\""+color+"\">"+escapeHTML(limit(n.generateSolutions().head.asString))+"</TD></TR>" } res append "</TABLE>>, shape = \"none\" ];\n" @@ -126,4 +132,4 @@ class DotGenerator(g: Graph) { } } -object dotGenIds extends UniqueCounter[Unit] \ No newline at end of file +object dotGenIds extends UniqueCounter[Unit] diff --git a/src/main/scala/leon/synthesis/graph/Search.scala b/src/main/scala/leon/synthesis/graph/Search.scala index 98554a5ae492972e0b7b3915979d9af829d81555..c630e315d9777110b5dcde7adc42cf6172161af3 100644 --- a/src/main/scala/leon/synthesis/graph/Search.scala +++ b/src/main/scala/leon/synthesis/graph/Search.scala @@ -10,7 +10,7 @@ import scala.collection.mutable.ArrayBuffer import leon.utils.Interruptible import java.util.concurrent.atomic.AtomicBoolean -abstract class Search(ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel) extends Interruptible { +abstract class Search(val ctx: LeonContext, ci: SourceInfo, p: Problem, costModel: CostModel) extends Interruptible { val g = new Graph(costModel, p) def findNodeToExpandFrom(n: Node): Option[Node] diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index df2c44193412a55af004dfa7695901044a4b5b53..d3dc6347280a45642935e6ea3c314246a3cb6958 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -65,7 +65,7 @@ case object ADTSplit extends Rule("ADT Split.") { case Some((id, act, cases)) => val oas = p.as.filter(_ != id) - val subInfo = for(ccd <- cases) yield { + val subInfo0 = for(ccd <- cases) yield { val cct = CaseClassType(ccd, act.tps) val args = cct.fields.map { vd => FreshIdentifier(vd.id.name, vd.getType, true) }.toList @@ -89,6 +89,10 @@ case object ADTSplit extends Rule("ADT Split.") { (cct, subProblem, subPattern) } + val subInfo = subInfo0.sortBy{ case (cct, _, _) => + cct.fieldsTypes.count { t => t == act } + } + val onSuccess: List[Solution] => Option[Solution] = { case sols => diff --git a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala index 2f3869af16b71f9635e36d27774f55a7cee7140c..4c12f58224427c1d74654638e28965d746f93d54 100644 --- a/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala +++ b/src/main/scala/leon/synthesis/rules/BottomUpTegis.scala @@ -14,7 +14,6 @@ import codegen.CodeGenParams import grammars._ import bonsai.enumerators._ -import bonsai.{Generator => Gen} case object BottomUpTEGIS extends BottomUpTEGISLike[TypeTree]("BU TEGIS") { def getGrammar(sctx: SynthesisContext, p: Problem) = { @@ -51,13 +50,13 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val nTests = tests.size - var compiled = Map[Generator[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() + var compiled = Map[ProductionRule[T, Expr], Vector[Vector[Expr]] => Option[Vector[Expr]]]() /** * Compile Generators to functions from Expr to Expr. The compiled * generators will be passed to the enumerator */ - def compile(gen: Generator[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { + def compile(gen: ProductionRule[T, Expr]): Vector[Vector[Expr]] => Option[Vector[Expr]] = { compiled.getOrElse(gen, { val executor = if (gen.subTrees.isEmpty) { @@ -108,7 +107,7 @@ abstract class BottomUpTEGISLike[T <: Typed](name: String) extends Rule(name) { val targetType = tupleTypeWrap(p.xs.map(_.getType)) val wrappedTests = tests.map { case (is, os) => (is, tupleWrap(os))} - val enum = new BottomUpEnumerator[T, Expr, Expr, Generator[T, Expr]]( + val enum = new BottomUpEnumerator[T, Expr, Expr, ProductionRule[T, Expr]]( grammar.getProductions, wrappedTests, { (vecs, gen) => diff --git a/src/main/scala/leon/synthesis/rules/CEGIS.scala b/src/main/scala/leon/synthesis/rules/CEGIS.scala index 1fcf01d52088ea9d4d25d184a673ef8335a8d260..b0de64ed05458d22cc113170dc850e2c1e2f6a3b 100644 --- a/src/main/scala/leon/synthesis/rules/CEGIS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGIS.scala @@ -4,16 +4,31 @@ package leon package synthesis package rules -import purescala.Types._ - import grammars._ -import utils._ +import grammars.transformers._ +import purescala.Types.TypeTree -case object CEGIS extends CEGISLike[TypeTree]("CEGIS") { +/** Basic implementation of CEGIS that uses a naive grammar */ +case object NaiveCEGIS extends CEGISLike[TypeTree]("Naive CEGIS") { def getParams(sctx: SynthesisContext, p: Problem) = { CegisParams( grammar = Grammars.typeDepthBound(Grammars.default(sctx, p), 2), // This limits type depth - rootLabel = {(tpe: TypeTree) => tpe } + rootLabel = {(tpe: TypeTree) => tpe }, + optimizations = false + ) + } +} + +/** More advanced implementation of CEGIS that uses a less permissive grammar + * and some optimizations + */ +case object CEGIS extends CEGISLike[TaggedNonTerm[TypeTree]]("CEGIS") { + def getParams(sctx: SynthesisContext, p: Problem) = { + val base = NaiveCEGIS.getParams(sctx,p).grammar + CegisParams( + grammar = TaggedGrammar(base), + rootLabel = TaggedNonTerm(_, Tags.Top, 0, None), + optimizations = true ) } } diff --git a/src/main/scala/leon/synthesis/rules/CEGISLike.scala b/src/main/scala/leon/synthesis/rules/CEGISLike.scala index d577f7f9fe1f260f4af9ca4d3cb20ca868e37fcc..291e485d70b80b580095a484e02448166de9e18c 100644 --- a/src/main/scala/leon/synthesis/rules/CEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/CEGISLike.scala @@ -4,10 +4,6 @@ package leon package synthesis package rules -import leon.utils.SeqUtils -import solvers._ -import grammars._ - import purescala.Expressions._ import purescala.Common._ import purescala.Definitions._ @@ -16,44 +12,59 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Constructors._ -import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} +import solvers._ +import grammars._ +import grammars.transformers._ +import leon.utils._ import evaluators._ import datagen._ import codegen.CodeGenParams +import scala.collection.mutable.{HashMap=>MutableMap, ArrayBuffer} + abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case class CegisParams( grammar: ExpressionGrammar[T], rootLabel: TypeTree => T, - maxUnfoldings: Int = 5 + optimizations: Boolean, + maxSize: Option[Int] = None ) def getParams(sctx: SynthesisContext, p: Problem): CegisParams def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { + val exSolverTo = 2000L val cexSolverTo = 2000L - // Track non-deterministic programs up to 10'000 programs, or give up + // Track non-deterministic programs up to 100'000 programs, or give up val nProgramsLimit = 100000 val sctx = hctx.sctx val ctx = sctx.context + val timers = ctx.timers.synthesis.cegis + // CEGIS Flags to activate or deactivate features - val useOptTimeout = sctx.settings.cegisUseOptTimeout.getOrElse(true) - val useVanuatoo = sctx.settings.cegisUseVanuatoo.getOrElse(false) + val useOptTimeout = sctx.settings.cegisUseOptTimeout + val useVanuatoo = sctx.settings.cegisUseVanuatoo // Limits the number of programs CEGIS will specifically validate individually val validateUpTo = 3 + val passingRatio = 10 val interruptManager = sctx.context.interruptManager val params = getParams(sctx, p) - if (params.maxUnfoldings == 0) { + // If this CEGISLike forces a maxSize, take it, otherwise find it in the settings + val maxSize = params.maxSize.getOrElse(sctx.settings.cegisMaxSize) + + ctx.reporter.debug(s"This is $name. Settings: optimizations = ${params.optimizations}, maxSize = $maxSize, vanuatoo=$useVanuatoo, optTimeout=$useOptTimeout") + + if (maxSize == 0) { return Nil } @@ -61,13 +72,13 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { private var termSize = 0 - val grammar = SizeBoundedGrammar(params.grammar) + val grammar = SizeBoundedGrammar(params.grammar, params.optimizations) - def rootLabel = SizedLabel(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) + def rootLabel = SizedNonTerm(params.rootLabel(tupleTypeWrap(p.xs.map(_.getType))), termSize) - var nAltsCache = Map[SizedLabel[T], Int]() + var nAltsCache = Map[SizedNonTerm[T], Int]() - def countAlternatives(l: SizedLabel[T]): Int = { + def countAlternatives(l: SizedNonTerm[T]): Int = { if (!(nAltsCache contains l)) { val count = grammar.getProductions(l).map { gen => gen.subTrees.map(countAlternatives).product @@ -91,18 +102,18 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { * b3 => c6 == H(c4, c5) * * c1 -> Seq( - * (b1, F(c2, c3), Set(c2, c3)) - * (b2, G(c4, c5), Set(c4, c5)) + * (b1, F(_, _), Seq(c2, c3)) + * (b2, G(_, _), Seq(c4, c5)) * ) * c6 -> Seq( - * (b3, H(c7, c8), Set(c7, c8)) + * (b3, H(_, _), Seq(c7, c8)) * ) */ private var cTree: Map[Identifier, Seq[(Identifier, Seq[Expr] => Expr, Seq[Identifier])]] = Map() // C identifiers corresponding to p.xs - private var rootC: Identifier = _ + private var rootC: Identifier = _ private var bs: Set[Identifier] = Set() @@ -110,19 +121,19 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { class CGenerator { - private var buffers = Map[SizedLabel[T], Stream[Identifier]]() + private var buffers = Map[SizedNonTerm[T], Stream[Identifier]]() - private var slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + private var slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) - private def streamOf(t: SizedLabel[T]): Stream[Identifier] = Stream.continually( + private def streamOf(t: SizedNonTerm[T]): Stream[Identifier] = Stream.continually( FreshIdentifier(t.asString, t.getType, true) ) def rewind(): Unit = { - slots = Map[SizedLabel[T], Int]().withDefaultValue(0) + slots = Map[SizedNonTerm[T], Int]().withDefaultValue(0) } - def getNext(t: SizedLabel[T]) = { + def getNext(t: SizedNonTerm[T]) = { if (!(buffers contains t)) { buffers += t -> streamOf(t) } @@ -140,13 +151,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { def updateCTree(): Unit = { + ctx.timers.synthesis.cegis.updateCTree.start() def freshB() = { val id = FreshIdentifier("B", BooleanType, true) bs += id id } - def defineCTreeFor(l: SizedLabel[T], c: Identifier): Unit = { + def defineCTreeFor(l: SizedNonTerm[T], c: Identifier): Unit = { if (!(cTree contains c)) { val cGen = new CGenerator() @@ -182,11 +194,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.ifDebug { printer => printer("Grammar so far:") grammar.printProductions(printer) + printer("") } bsOrdered = bs.toSeq.sorted + excludedPrograms = ArrayBuffer() setCExpr(computeCExpr()) + ctx.timers.synthesis.cegis.updateCTree.stop() } /** @@ -233,9 +248,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { cache(c) } - SeqUtils.cartesianProduct(seqs).map { ls => - ls.foldLeft(Set[Identifier]())(_ ++ _) - } + SeqUtils.cartesianProduct(seqs).map(_.flatten.toSet) } allProgramsFor(Seq(rootC)) @@ -287,7 +300,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case (e, alt) => IfExpr(alt._1.toVariable, exprOf(alt), e) } } else { - Error(c.getType, "Impossibru") + Error(c.getType, s"Empty production rule: $c") } cToFd(c).fullBody = body @@ -325,11 +338,10 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { solFd.fullBody = Ensuring( FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), - Lambda(p.xs.map(ValDef(_)), p.phi) + Lambda(p.xs.map(ValDef), p.phi) ) - - phiFd.body = Some( + phiFd.body = Some( letTuple(p.xs, FunctionInvocation(cTreeFd.typed, p.as.map(_.toVariable)), p.phi) @@ -373,46 +385,56 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { private val innerPhi = outerExprToInnerExpr(p.phi) private var programCTree: Program = _ - private var tester: (Example, Set[Identifier]) => EvaluationResults.Result[Expr] = _ + + private var evaluator: DefaultEvaluator = _ private def setCExpr(cTreeInfo: (Expr, Seq[FunDef])): Unit = { val (cTree, newFds) = cTreeInfo cTreeFd.body = Some(cTree) programCTree = addFunDefs(innerProgram, newFds, cTreeFd) + evaluator = new DefaultEvaluator(sctx.context, programCTree) //println("-- "*30) //println(programCTree.asString) //println(".. "*30) + } - //val evaluator = new DualEvaluator(sctx.context, programCTree, CodeGenParams.default) - val evaluator = new DefaultEvaluator(sctx.context, programCTree) - - tester = - { (ex: Example, bValues: Set[Identifier]) => - // TODO: Test output value as well - val envMap = bs.map(b => b -> BooleanLiteral(bValues(b))).toMap - - ex match { - case InExample(ins) => - val fi = FunctionInvocation(phiFd.typed, ins) - evaluator.eval(fi, envMap) + def testForProgram(bValues: Set[Identifier])(ex: Example): Boolean = { - case InOutExample(ins, outs) => - val fi = FunctionInvocation(cTreeFd.typed, ins) - val eq = equality(fi, tupleWrap(outs)) - evaluator.eval(eq, envMap) - } - } - } + val origImpl = cTreeFd.fullBody + val outerSol = getExpr(bValues) + val innerSol = outerExprToInnerExpr(outerSol) + val cnstr = letTuple(p.xs, innerSol, innerPhi) + cTreeFd.fullBody = innerSol + + timers.testForProgram.start() + val res = ex match { + case InExample(ins) => + evaluator.eval(cnstr, p.as.zip(ins).toMap) + + case InOutExample(ins, outs) => + val eq = equality(innerSol, tupleWrap(outs)) + evaluator.eval(eq, p.as.zip(ins).toMap) + } + timers.testForProgram.stop() + cTreeFd.fullBody = origImpl - def testForProgram(bValues: Set[Identifier])(ex: Example): Boolean = { - tester(ex, bValues) match { + res match { case EvaluationResults.Successful(res) => res == BooleanLiteral(true) case EvaluationResults.RuntimeError(err) => + /*if (err.contains("Empty production rule")) { + println(programCTree.asString) + println(bValues) + println(ex) + println(this.getExpr(bValues)) + (new Throwable).printStackTrace() + println(err) + println() + }*/ sctx.reporter.debug("RE testing CE: "+err) false @@ -420,18 +442,18 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { sctx.reporter.debug("Error testing CE: "+err) false } - } - + } // Returns the outer expression corresponding to a B-valuation def getExpr(bValues: Set[Identifier]): Expr = { + def getCValue(c: Identifier): Expr = { cTree(c).find(i => bValues(i._1)).map { case (b, builder, cs) => builder(cs.map(getCValue)) }.getOrElse { - simplestValue(c.getType) + Error(c.getType, "Impossible assignment of bs") } } @@ -445,60 +467,70 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { def validatePrograms(bss: Set[Set[Identifier]]): Either[Stream[Solution], Seq[Seq[Expr]]] = { val origImpl = cTreeFd.fullBody - val cexs = for (bs <- bss.toSeq) yield { + var cexs = Seq[Seq[Expr]]() + + for (bs <- bss.toSeq) { val outerSol = getExpr(bs) val innerSol = outerExprToInnerExpr(outerSol) - + //println(s"Testing $outerSol") cTreeFd.fullBody = innerSol val cnstr = and(innerPc, letTuple(p.xs, innerSol, Not(innerPhi))) - //println("Solving for: "+cnstr.asString) + val eval = new DefaultEvaluator(ctx, innerProgram) - val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) - val solver = solverf.getNewSolver() - try { - solver.assertCnstr(cnstr) - solver.check match { - case Some(true) => - excludeProgram(bs, true) - val model = solver.getModel - //println("Found counter example: ") - //for ((s, v) <- model) { - // println(" "+s.asString+" -> "+v.asString) - //} - - //val evaluator = new DefaultEvaluator(ctx, prog) - //println(evaluator.eval(cnstr, model)) - - Some(p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))) - - case Some(false) => - // UNSAT, valid program - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) + if (cexs exists (cex => eval.eval(cnstr, p.as.zip(cex).toMap).result == Some(BooleanLiteral(true)))) { + //println(s"Program $outerSol fails!") + excludeProgram(bs, true) + cTreeFd.fullBody = origImpl + } else { + //println("Solving for: "+cnstr.asString) + + val solverf = SolverFactory.getFromSettings(ctx, innerProgram).withTimeout(cexSolverTo) + val solver = solverf.getNewSolver() + try { + solver.assertCnstr(cnstr) + solver.check match { + case Some(true) => + excludeProgram(bs, true) + val model = solver.getModel + //println("Found counter example: ") + //for ((s, v) <- model) { + // println(" "+s.asString+" -> "+v.asString) + //} + + //val evaluator = new DefaultEvaluator(ctx, prog) + //println(evaluator.eval(cnstr, model)) + //println(s"Program $outerSol fails with cex ${p.as.map(a => model.getOrElse(a, simplestValue(a.getType)))}") + cexs +:= p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) + + case Some(false) => + // UNSAT, valid program + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, true))) - case None => - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - // Optimistic valid solution - return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) - } else { - None - } + case None => + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") + // Optimistic valid solution + return Left(Stream(Solution(BooleanLiteral(true), Set(), outerSol, false))) + } + } + } finally { + solverf.reclaim(solver) + solverf.shutdown() + cTreeFd.fullBody = origImpl } - } finally { - solverf.reclaim(solver) - solverf.shutdown() - cTreeFd.fullBody = origImpl } } - Right(cexs.flatten) + Right(cexs) } var excludedPrograms = ArrayBuffer[Set[Identifier]]() + def allProgramsClosed = allProgramsCount() <= excludedPrograms.size + // Explicitly remove program computed by bValues from the search space // // If the bValues comes from models, we make sure the bValues we exclude @@ -542,9 +574,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { //println(" --- Constraints ---") //println(" - "+toFind.asString) try { - //TODO: WHAT THE F IS THIS? - //val bsOrNotBs = andJoin(bsOrdered.map(b => if (bs(b)) b.toVariable else Not(b.toVariable))) - //solver.assertCnstr(bsOrNotBs) solver.assertCnstr(toFind) for ((c, alts) <- cTree) { @@ -660,9 +689,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ndProgram.init() var unfolding = 1 - val maxUnfoldings = params.maxUnfoldings - - sctx.reporter.debug(s"maxUnfoldings=$maxUnfoldings") var baseExampleInputs: ArrayBuffer[Example] = new ArrayBuffer[Example]() @@ -670,7 +696,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { ndProgram.grammar.printProductions(printer) } - // We populate the list of examples with a predefined one + // We populate the list of examples with a defined one sctx.reporter.debug("Acquiring initial list of examples") baseExampleInputs ++= p.eb.examples @@ -708,7 +734,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } - /** * We generate tests for discarding potential programs */ @@ -738,8 +763,6 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { try { do { - var skipCESearch = false - // Unfold formula ndProgram.unfold() @@ -748,6 +771,7 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { val nInitial = prunedPrograms.size sctx.reporter.debug("#Programs: "+nInitial) + //sctx.reporter.ifDebug{ printer => // val limit = 100 @@ -764,34 +788,33 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { // We further filter the set of working programs to remove those that fail on known examples if (hasInputExamples) { + timers.filter.start() for (bs <- prunedPrograms if !interruptManager.isInterrupted) { - var valid = true val examples = allInputExamples() - while(valid && examples.hasNext) { - val e = examples.next() - if (!ndProgram.testForProgram(bs)(e)) { - failedTestsStats(e) += 1 - sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}") - wrongPrograms += bs - prunedPrograms -= bs - - valid = false - } + examples.find(e => !ndProgram.testForProgram(bs)(e)).foreach { e => + failedTestsStats(e) += 1 + sctx.reporter.debug(f" Program: ${ndProgram.getExpr(bs).asString}%-80s failed on: ${e.asString}") + wrongPrograms += bs + prunedPrograms -= bs } if (wrongPrograms.size+1 % 1000 == 0) { sctx.reporter.debug("..."+wrongPrograms.size) } } + timers.filter.stop() } val nPassing = prunedPrograms.size - sctx.reporter.debug("#Programs passing tests: "+nPassing) + val nTotal = ndProgram.allProgramsCount() + //println(s"Iotal: $nTotal, passing: $nPassing") + + sctx.reporter.debug(s"#Programs passing tests: $nPassing out of $nTotal") sctx.reporter.ifDebug{ printer => - for (p <- prunedPrograms.take(10)) { + for (p <- prunedPrograms.take(100)) { printer(" - "+ndProgram.getExpr(p).asString) } - if(nPassing > 10) { + if(nPassing > 100) { printer(" - ...") } } @@ -805,94 +828,86 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { } } + // We can skip CE search if - we have excluded all programs or - we do so with validatePrograms + var skipCESearch = nPassing == 0 || interruptManager.isInterrupted || { + // If the number of pruned programs is very small, or by far smaller than the number of total programs, + // we hypothesize it will be easier to just validate them individually. + // Otherwise, we validate a small number of programs just in case we are lucky FIXME is this last clause useful? + val (programsToValidate, otherPrograms) = if (nTotal / nPassing > passingRatio || nPassing < 10) { + (prunedPrograms, Nil) + } else { + prunedPrograms.splitAt(validateUpTo) + } - if (nPassing == 0 || interruptManager.isInterrupted) { - // No test passed, we can skip solver and unfold again, if possible - skipCESearch = true - } else { - var doFilter = true - - if (validateUpTo > 0) { - // Validate the first N programs individualy - ndProgram.validatePrograms(prunedPrograms.take(validateUpTo)) match { - case Left(sols) if sols.nonEmpty => - doFilter = false - result = Some(RuleClosed(sols)) - case Right(cexs) => - baseExampleInputs ++= cexs.map(InExample) - - if (nPassing <= validateUpTo) { - // All programs failed verification, we filter everything out and unfold - doFilter = false - skipCESearch = true + ndProgram.validatePrograms(programsToValidate) match { + case Left(sols) if sols.nonEmpty => + // Found solution! Exit CEGIS + result = Some(RuleClosed(sols)) + true + case Right(cexs) => + // Found some counterexamples + val newCexs = cexs.map(InExample) + baseExampleInputs ++= newCexs + // Retest whether the newly found C-E invalidates some programs + for (p <- otherPrograms if !interruptManager.isInterrupted) { + // Exclude any programs that fail at least one new cex + newCexs.find { cex => !ndProgram.testForProgram(p)(cex) }.foreach { cex => + failedTestsStats(cex) += 1 + ndProgram.excludeProgram(p, true) } - } + } + // If we excluded all programs, we can skip CE search + programsToValidate.size >= nPassing } + } - if (doFilter) { - sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") - wrongPrograms.foreach { - ndProgram.excludeProgram(_, true) - } + if (!skipCESearch) { + sctx.reporter.debug("Excluding "+wrongPrograms.size+" programs") + wrongPrograms.foreach { + ndProgram.excludeProgram(_, true) } } // CEGIS Loop at a given unfolding level - while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted) { + while (result.isEmpty && !skipCESearch && !interruptManager.isInterrupted && !ndProgram.allProgramsClosed) { + timers.loop.start() ndProgram.solveForTentativeProgram() match { case Some(Some(bs)) => - // Should we validate this program with Z3? - - val validateWithZ3 = if (hasInputExamples) { - - if (allInputExamples().forall(ndProgram.testForProgram(bs))) { - // All valid inputs also work with this, we need to - // make sure by validating this candidate with z3 - true - } else { - println("testing failed ?!") - // One valid input failed with this candidate, we can skip + // No inputs to test or all valid inputs also work with this. + // We need to make sure by validating this candidate with z3 + sctx.reporter.debug("Found tentative model, need to validate!") + ndProgram.solveForCounterExample(bs) match { + case Some(Some(inputsCE)) => + sctx.reporter.debug("Found counter-example:" + inputsCE) + val ce = InExample(inputsCE) + // Found counter example! Exclude this program + baseExampleInputs += ce ndProgram.excludeProgram(bs, false) - false - } - } else { - // No inputs or capability to test, we need to ask Z3 - true - } - sctx.reporter.debug("Found tentative model (Validate="+validateWithZ3+")!") - - if (validateWithZ3) { - ndProgram.solveForCounterExample(bs) match { - case Some(Some(inputsCE)) => - sctx.reporter.debug("Found counter-example:"+inputsCE) - val ce = InExample(inputsCE) - // Found counter example! - baseExampleInputs += ce - - // Retest whether the newly found C-E invalidates all programs - if (prunedPrograms.forall(p => !ndProgram.testForProgram(p)(ce))) { - skipCESearch = true - } else { - ndProgram.excludeProgram(bs, false) - } - - case Some(None) => - // Found no counter example! Program is a valid solution + + // Retest whether the newly found C-E invalidates some programs + prunedPrograms.foreach { p => + if (!ndProgram.testForProgram(p)(ce)) ndProgram.excludeProgram(p, true) + } + + case Some(None) => + // Found no counter example! Program is a valid solution + val expr = ndProgram.getExpr(bs) + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) + + case None => + // We are not sure + sctx.reporter.debug("Unknown") + if (useOptTimeout) { + // Interpret timeout in CE search as "the candidate is valid" + sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr))) - - case None => - // We are not sure - sctx.reporter.debug("Unknown") - if (useOptTimeout) { - // Interpret timeout in CE search as "the candidate is valid" - sctx.reporter.info("CEGIS could not prove the validity of the resulting expression") - val expr = ndProgram.getExpr(bs) - result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) - } else { - result = Some(RuleFailed()) - } - } + result = Some(RuleClosed(Solution(BooleanLiteral(true), Set(), expr, isTrusted = false))) + } else { + // Ok, we failed to validate, exclude this program + ndProgram.excludeProgram(bs, false) + // TODO: Make CEGIS fail early when it fails on 1 program? + // result = Some(RuleFailed()) + } } case Some(None) => @@ -901,11 +916,14 @@ abstract class CEGISLike[T <: Typed](name: String) extends Rule(name) { case None => result = Some(RuleFailed()) } + + timers.loop.stop() } unfolding += 1 - } while(unfolding <= maxUnfoldings && result.isEmpty && !interruptManager.isInterrupted) + } while(unfolding <= maxSize && result.isEmpty && !interruptManager.isInterrupted) + if (interruptManager.isInterrupted) interruptManager.recoverInterrupt() result.getOrElse(RuleFailed()) } catch { diff --git a/src/main/scala/leon/synthesis/rules/CEGLESS.scala b/src/main/scala/leon/synthesis/rules/CEGLESS.scala index c12edac075bc8525d395d5f792ef4579c0d109f1..36cc7f9e65dae8af9d8c17d4db936dc4400c0ece 100644 --- a/src/main/scala/leon/synthesis/rules/CEGLESS.scala +++ b/src/main/scala/leon/synthesis/rules/CEGLESS.scala @@ -4,10 +4,10 @@ package leon package synthesis package rules +import leon.grammars.transformers.Union import purescala.ExprOps._ import purescala.Types._ import purescala.Extractors._ -import utils._ import grammars._ import Witnesses._ @@ -24,7 +24,7 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { val inputs = p.as.map(_.toVariable) sctx.reporter.ifDebug { printer => - printer("Guides available:") + printer("Guides available:") for (g <- guides) { printer(" - "+g.asString(ctx)) } @@ -35,7 +35,8 @@ case object CEGLESS extends CEGISLike[NonTerminal[String]]("CEGLESS") { CegisParams( grammar = guidedGrammar, rootLabel = { (tpe: TypeTree) => NonTerminal(tpe, "G0") }, - maxUnfoldings = (0 +: guides.map(depth(_) + 1)).max + optimizations = false, + maxSize = Some((0 +: guides.map(depth(_) + 1)).max) ) } } diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index 2ae2b1d5d0292a6ed725055e61b1b4af4100a63c..d3b4c823dd7110763316d121407bcf94820c5826 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -83,7 +83,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { } } - var eb = p.qeb.mapIns { info => + val eb = p.qeb.mapIns { info => List(info.flatMap { case (id, v) => ebMapInfo.get(id) match { case Some(m) => @@ -103,7 +103,8 @@ case object DetupleInput extends NormalizingRule("Detuple In") { case CaseClass(ct, args) => val (cts, es) = args.zip(ct.fields).map { case (CaseClassSelector(ct, e, id), field) if field.id == id => (ct, e) - case _ => return e + case _ => + return e }.unzip if (cts.distinct.size == 1 && es.distinct.size == 1) { @@ -126,7 +127,7 @@ case object DetupleInput extends NormalizingRule("Detuple In") { val sub = Problem(newAs, subWs, subPc, subProblem, p.xs, eb) - val s = {substAll(reverseMap, _:Expr)} andThen { simplePostTransform(recompose) } + val s = (substAll(reverseMap, _:Expr)) andThen simplePostTransform(recompose) Some(decomp(List(sub), forwardMap(s), s"Detuple ${reverseMap.keySet.mkString(", ")}")) } else { diff --git a/src/main/scala/leon/synthesis/rules/StringRender.scala b/src/main/scala/leon/synthesis/rules/StringRender.scala index f03c54b560e81428851c83b86b9430d4e706e20f..715d54385e07a86f642fa454b78bcf72f020b528 100644 --- a/src/main/scala/leon/synthesis/rules/StringRender.scala +++ b/src/main/scala/leon/synthesis/rules/StringRender.scala @@ -6,36 +6,33 @@ package rules import scala.annotation.tailrec import scala.collection.mutable.ListBuffer - import bonsai.enumerators.MemoizedEnumerator -import leon.evaluators.DefaultEvaluator -import leon.evaluators.StringTracingEvaluator -import leon.synthesis.programsets.DirectProgramSet -import leon.synthesis.programsets.JoinProgramSet -import leon.purescala.Common.FreshIdentifier -import leon.purescala.Common.Identifier -import leon.purescala.DefOps -import leon.purescala.Definitions.FunDef -import leon.purescala.Definitions.FunDef -import leon.purescala.Definitions.ValDef -import leon.purescala.ExprOps -import leon.solvers.Model -import leon.solvers.ModelBuilder -import leon.solvers.string.StringSolver -import leon.utils.DebugSectionSynthesis +import evaluators.DefaultEvaluator +import evaluators.AbstractEvaluator +import purescala.Definitions.{FunDef, ValDef, Program, TypedFunDef, CaseClassDef, AbstractClassDef} +import purescala.Common._ +import purescala.Types._ import purescala.Constructors._ -import purescala.Definitions._ -import purescala.ExprOps._ import purescala.Expressions._ import purescala.Extractors._ import purescala.TypeOps -import purescala.Types._ +import purescala.DefOps +import purescala.ExprOps +import purescala.SelfPrettyPrinter +import solvers.Model +import solvers.ModelBuilder +import solvers.string.StringSolver +import synthesis.programsets.DirectProgramSet +import synthesis.programsets.JoinProgramSet +import leon.utils.DebugSectionSynthesis + + /** A template generator for a given type tree. * Extend this class using a concrete type tree, * Then use the apply method to get a hole which can be a placeholder for holes in the template. - * Each call to the ``.instantiate` method of the subsequent Template will provide different instances at each position of the hole. + * Each call to the `.instantiate` method of the subsequent Template will provide different instances at each position of the hole. */ abstract class TypedTemplateGenerator(t: TypeTree) { import StringRender.WithIds @@ -74,60 +71,84 @@ case object StringRender extends Rule("StringRender") { var EDIT_ME = "_edit_me_" - var _defaultTypeToString: Option[Map[TypeTree, FunDef]] = None - - def defaultMapTypeToString()(implicit hctx: SearchContext): Map[TypeTree, FunDef] = { - _defaultTypeToString.getOrElse{ - // Updates the cache with the functions converting standard types to string. - val res = (hctx.program.library.StrOps.toSeq.flatMap { StrOps => - StrOps.defs.collect{ case d: FunDef if d.params.length == 1 && d.returnType == StringType => d.params.head.getType -> d } - }).toMap - _defaultTypeToString = Some(res) - res - } - } - - /** Returns a toString function converter if it has been defined. */ - class WithFunDefConverter(implicit hctx: SearchContext) { - def unapply(tpe: TypeTree): Option[FunDef] = { - _defaultTypeToString.flatMap(_.get(tpe)) - } - } + var enforceDefaultStringMethodsIfAvailable = true + var enforceSelfStringMethodsIfAvailable = false val booleanTemplate = (a: Expr) => StringTemplateGenerator(Hole => IfExpr(a, Hole, Hole)) - /** Returns a seq of expressions such as `x + y + "1" + y + "2" + z` associated to an expected result string `"1, 2"`. - * We use these equations so that we can find the values of the constants x, y, z and so on. - * This uses a custom evaluator which does not concatenate string but reminds the calculation. - */ - def createProblems(inlineFunc: Seq[FunDef], inlineExpr: Expr, examples: ExamplesBank): Seq[(Expr, String)] = ??? + import StringSolver.{StringFormToken, StringForm, Problem => SProblem, Equation, Assignment} - /** For each solution to the problem such as `x + "1" + y + j + "2" + z = 1, 2`, outputs all possible assignments if they exist. */ - def solveProblems(problems: Seq[(Expr, String)]): Seq[Map[Identifier, String]] = ??? + /** Augment the left-hand-side to have possible function calls, such as x + "const" + customToString(_) ... + * Function calls will be eliminated when converting to a valid problem. + */ + sealed abstract class AugmentedStringFormToken + case class RegularStringFormToken(e: StringFormToken) extends AugmentedStringFormToken + case class OtherStringFormToken(e: Expr) extends AugmentedStringFormToken + type AugmentedStringForm = List[AugmentedStringFormToken] - import StringSolver.{StringFormToken, StringForm, Problem => SProblem, Equation, Assignment} + /** Augments the right-hand-side to have possible function calls, such as "const" + customToString(_) ... + * Function calls will be eliminated when converting to a valid problem. + */ + sealed abstract class AugmentedStringChunkRHS + case class RegularStringChunk(e: String) extends AugmentedStringChunkRHS + case class OtherStringChunk(e: Expr) extends AugmentedStringChunkRHS + type AugmentedStringLiteral = List[AugmentedStringChunkRHS] /** Converts an expression to a stringForm, suitable for StringSolver */ - def toStringForm(e: Expr, acc: List[StringFormToken] = Nil)(implicit hctx: SearchContext): Option[StringForm] = e match { + def toStringForm(e: Expr, acc: List[AugmentedStringFormToken] = Nil)(implicit hctx: SearchContext): Option[AugmentedStringForm] = e match { case StringLiteral(s) => - Some(Left(s)::acc) - case Variable(id) => Some(Right(id)::acc) + Some(RegularStringFormToken(Left(s))::acc) + case Variable(id) => Some(RegularStringFormToken(Right(id))::acc) case StringConcat(lhs, rhs) => toStringForm(rhs, acc).flatMap(toStringForm(lhs, _)) + case e:Application => Some(OtherStringFormToken(e)::acc) + case e:FunctionInvocation => Some(OtherStringFormToken(e)::acc) case _ => None } /** Returns the string associated to the expression if it is computable */ - def toStringLiteral(e: Expr): Option[String] = e match { - case StringLiteral(s) => Some(s) - case StringConcat(lhs, rhs) => toStringLiteral(lhs).flatMap(k => toStringLiteral(rhs).map(l => k + l)) + def toStringLiteral(e: Expr): Option[AugmentedStringLiteral] = e match { + case StringLiteral(s) => Some(List(RegularStringChunk(s))) + case StringConcat(lhs, rhs) => + toStringLiteral(lhs).flatMap(k => toStringLiteral(rhs).map(l => (k.init, k.last, l) match { + case (kinit, RegularStringChunk(s), RegularStringChunk(sp)::ltail) => + kinit ++ (RegularStringChunk(s + sp)::ltail) + case _ => k ++ l + })) + case e: Application => Some(List(OtherStringChunk(e))) + case e: FunctionInvocation => Some(List(OtherStringChunk(e))) case _ => None } + /** Converts an equality AugmentedStringForm == AugmentedStringLiteral to a list of equations + * For that, splits both strings on function applications. If they yield the same value, we can split, else it fails. */ + def toEquations(lhs: AugmentedStringForm, rhs: AugmentedStringLiteral): Option[List[Equation]] = { + def rec(lhs: AugmentedStringForm, rhs: AugmentedStringLiteral, + accEqs: ListBuffer[Equation], accLeft: ListBuffer[StringFormToken], accRight: StringBuffer): Option[List[Equation]] = (lhs, rhs) match { + case (Nil, Nil) => + (accLeft.toList, accRight.toString) match { + case (Nil, "") => Some(accEqs.toList) + case (lhs, rhs) => Some((accEqs += ((lhs, rhs))).toList) + } + case (OtherStringFormToken(e)::lhstail, OtherStringChunk(f)::rhstail) => + if(ExprOps.canBeHomomorphic(e, f).nonEmpty) { + rec(lhstail, rhstail, accEqs += ((accLeft.toList, accRight.toString)), ListBuffer[StringFormToken](), new StringBuffer) + } else None + case (OtherStringFormToken(e)::lhstail, Nil) => + None + case (Nil, OtherStringChunk(f)::rhstail) => + None + case (lhs, RegularStringChunk(s)::rhstail) => + rec(lhs, rhstail, accEqs, accLeft, accRight append s) + case (RegularStringFormToken(e)::lhstail, rhs) => + rec(lhstail, rhs, accEqs, accLeft += e, accRight) + } + rec(lhs, rhs, ListBuffer[Equation](), ListBuffer[StringFormToken](), new StringBuffer) + } + /** Returns a stream of assignments compatible with input/output examples for the given template */ def findAssignments(p: Program, inputs: Seq[Identifier], examples: ExamplesBank, template: Expr)(implicit hctx: SearchContext): Stream[Map[Identifier, String]] = { - //new Evaluator() - val e = new StringTracingEvaluator(hctx.context, p) + val e = new AbstractEvaluator(hctx.context, p) @tailrec def gatherEquations(s: List[InOutExample], acc: ListBuffer[Equation] = ListBuffer()): Option[SProblem] = s match { case Nil => Some(acc.toList) @@ -139,34 +160,39 @@ case object StringRender extends Rule("StringRender") { val evalResult = e.eval(template, modelResult) evalResult.result match { case None => - hctx.reporter.debug("Eval = None : ["+template+"] in ["+inputs.zip(in)+"]") + hctx.reporter.info("Eval = None : ["+template+"] in ["+inputs.zip(in)+"]") + hctx.reporter.info(evalResult) None case Some((sfExpr, abstractSfExpr)) => //ctx.reporter.debug("Eval = ["+sfExpr+"] (from "+abstractSfExpr+")") val sf = toStringForm(sfExpr) val rhs = toStringLiteral(rhExpr.head) - if(sf.isEmpty || rhs.isEmpty) { - hctx.reporter.ifDebug(printer => printer("sf empty or rhs empty ["+sfExpr+"] => ["+sf+"] in ["+rhs+"]")) - None - } else gatherEquations(q, acc += ((sf.get, rhs.get))) + (sf, rhs) match { + case (Some(sfget), Some(rhsget)) => + toEquations(sfget, rhsget) match { + case Some(equations) => + gatherEquations(q, acc ++= equations) + case None => + hctx.reporter.info("Could not extract equations from ["+sfget+"] == ["+rhsget+"]\n coming from ... == " + rhExpr) + None + } + case _ => + hctx.reporter.info("sf empty or rhs empty ["+sfExpr+"] => ["+sf+"] in ["+rhs+"]") + None + } } } else { - hctx.reporter.ifDebug(printer => printer("RHS.length != 1 : ["+rhExpr+"]")) + hctx.reporter.info("RHS.length != 1 : ["+rhExpr+"]") None } } gatherEquations((examples.valids ++ examples.invalids).collect{ case io:InOutExample => io }.toList) match { - case Some(problem) => - hctx.reporter.debug("Problem: ["+StringSolver.renderProblem(problem)+"]") - val res = StringSolver.solve(problem) - hctx.reporter.debug("Solution found:"+res.nonEmpty) - res - case None => - hctx.reporter.ifDebug(printer => printer("No problem found")) - Stream.empty + case Some(problem) => StringSolver.solve(problem) + case None => Stream.empty } } + /** With a given (template, fundefs, consts) will find consts so that (expr, funs) passes all the examples */ def findSolutions(examples: ExamplesBank, template: Stream[WithIds[Expr]], funDefs: Seq[(FunDef, Stream[WithIds[Expr]])])(implicit hctx: SearchContext, p: Problem): RuleApplication = { // Fun is a stream of many function applications. val funs= JoinProgramSet.direct(funDefs.map(fbody => fbody._2.map((fbody._1, _))).map(d => DirectProgramSet(d))) @@ -176,7 +202,7 @@ case object StringRender extends Rule("StringRender") { def computeSolutions(funDefsBodies: Seq[(FunDef, WithIds[Expr])], template: WithIds[Expr]): Stream[Assignment] = { val funDefs = for((funDef, body) <- funDefsBodies) yield { funDef.body = Some(body._1); funDef } val newProgram = DefOps.addFunDefs(hctx.program, funDefs, hctx.sctx.functionContext) - findAssignments(newProgram, p.as, examples, template._1) + findAssignments(newProgram, p.as.filter{ x => !x.getType.isInstanceOf[FunctionType] }, examples, template._1) } val tagged_solutions = @@ -219,7 +245,7 @@ case object StringRender extends Rule("StringRender") { var transformMap = Map[FunDef, FunDef]() def mapExpr(body: Expr): Expr = { ExprOps.preMap((e: Expr) => e match { - case FunctionInvocation(TypedFunDef(fd, _), args) if fd != program.library.escape.get => Some(FunctionInvocation(getMapping(fd).typed, args)) + case FunctionInvocation(TypedFunDef(fd, _), args) if fds(fd) => Some(functionInvocation(getMapping(fd), args)) case e => None })(body) } @@ -241,10 +267,6 @@ case object StringRender extends Rule("StringRender") { case class DependentType(caseClassParent: Option[TypeTree], inputName: String, typeToConvert: TypeTree) - object StringSynthesisContext { - def empty(implicit hctx: SearchContext) = new StringSynthesisContext(None, new StringSynthesisResult(Map(), Set())) - } - type MapFunctions = Map[DependentType, (FunDef, Stream[WithIds[Expr]])] /** Result of the current synthesis process */ @@ -266,17 +288,39 @@ case object StringRender extends Rule("StringRender") { s0 } } + type StringConverters = Map[TypeTree, List[Expr => Expr]] + + /** Companion object to create a StringSynthesisContext */ + object StringSynthesisContext { + def empty( + abstractStringConverters: StringConverters, + originalInputs: Set[Expr], + provided_functions: Seq[Identifier])(implicit hctx: SearchContext) = + new StringSynthesisContext(None, new StringSynthesisResult(Map(), Set()), + abstractStringConverters, + originalInputs, + provided_functions) + } /** Context for the current synthesis process */ class StringSynthesisContext( val currentCaseClassParent: Option[TypeTree], - val result: StringSynthesisResult + val result: StringSynthesisResult, + val abstractStringConverters: StringConverters, + val originalInputs: Set[Expr], + val provided_functions: Seq[Identifier] )(implicit hctx: SearchContext) { def add(d: DependentType, f: FunDef, s: Stream[WithIds[Expr]]): StringSynthesisContext = { - new StringSynthesisContext(currentCaseClassParent, result.add(d, f, s)) + new StringSynthesisContext(currentCaseClassParent, result.add(d, f, s), + abstractStringConverters, + originalInputs, + provided_functions) } def copy(currentCaseClassParent: Option[TypeTree]=currentCaseClassParent, result: StringSynthesisResult = result): StringSynthesisContext = - new StringSynthesisContext(currentCaseClassParent, result) + new StringSynthesisContext(currentCaseClassParent, result, + abstractStringConverters, + originalInputs, + provided_functions) def freshFunName(s: String) = result.freshFunName(s) } @@ -296,7 +340,8 @@ case object StringRender extends Rule("StringRender") { val funName = funName3(0).toLower + funName3.substring(1) val funId = FreshIdentifier(ctx.freshFunName(funName), alwaysShowUniqueID = true) val argId= FreshIdentifier(tpe.typeToConvert.asString(hctx.context).toLowerCase()(0).toString, tpe.typeToConvert) - val fd = new FunDef(funId, Nil, ValDef(argId) :: Nil, StringType) // Empty function. + val tparams = hctx.sctx.functionContext.tparams + val fd = new FunDef(funId, tparams, ValDef(argId) :: ctx.provided_functions.map(ValDef(_)).toList, StringType) // Empty function. fd } @@ -372,54 +417,89 @@ case object StringRender extends Rule("StringRender") { val dependentType = DependentType(ctx.currentCaseClassParent, input.asString(hctx.program)(hctx.context), input.getType) ctx.result.adtToString.get(dependentType) match { case Some(fd) => - gatherInputs(ctx, q, result += Stream((functionInvocation(fd._1, Seq(input)), Nil))) + gatherInputs(ctx, q, result += Stream((functionInvocation(fd._1, input::ctx.provided_functions.toList.map(Variable)), Nil))) case None => // No function can render the current type. + // We should not rely on calling the original function on the first line of the body of the function itself. + val exprs1s = (new SelfPrettyPrinter) + .allowFunction(hctx.sctx.functionContext) + .excludeFunction(hctx.sctx.functionContext) + .prettyPrintersForType(input.getType)(hctx.context, hctx.program) + .map(l => (application(l, Seq(input)), List[Identifier]())) // Use already pre-defined pretty printers. + val exprs1 = exprs1s.toList.sortBy{ case (Lambda(_, FunctionInvocation(fd, _)), _) if fd == hctx.sctx.functionContext => 0 case _ => 1} + val exprs2 = ctx.abstractStringConverters.getOrElse(input.getType, Nil).map(f => (f(input), List[Identifier]())) + val defaultConverters: Stream[WithIds[Expr]] = exprs1.toStream #::: exprs2.toStream + val recursiveConverters: Stream[WithIds[Expr]] = + (new SelfPrettyPrinter) + .prettyPrinterFromCandidate(hctx.sctx.functionContext, input.getType)(hctx.context, hctx.program) + .map(l => (application(l, Seq(input)), List[Identifier]())) + + def mergeResults(templateConverters: =>Stream[WithIds[Expr]]): Stream[WithIds[Expr]] = { + if(defaultConverters.isEmpty) templateConverters + else if(enforceDefaultStringMethodsIfAvailable) { + if(enforceSelfStringMethodsIfAvailable) + recursiveConverters #::: defaultConverters + else { + defaultConverters #::: recursiveConverters + } + } + else recursiveConverters #::: defaultConverters #::: templateConverters + } + input.getType match { case StringType => gatherInputs(ctx, q, result += - (Stream((input, Nil), - (FunctionInvocation( - hctx.program.library.escape.get.typed, - Seq(input)): Expr, Nil)))) + mergeResults(Stream((input, Nil), + (functionInvocation( + hctx.program.library.escape.get, List(input)): Expr, Nil)))) case BooleanType => val (bTemplate, vs) = booleanTemplate(input).instantiateWithVars - gatherInputs(ctx, q, result += Stream((BooleanToString(input), Nil), (bTemplate, vs))) + gatherInputs(ctx, q, result += mergeResults(Stream((BooleanToString(input), Nil), (bTemplate, vs)))) case WithStringconverter(converter) => // Base case - gatherInputs(ctx, q, result += Stream((converter(input), Nil))) + gatherInputs(ctx, q, result += mergeResults(Stream((converter(input), Nil)))) case t: ClassType => - // Create the empty function body and updates the assignments parts. - val fd = createEmptyFunDef(ctx, dependentType) - val ctx2 = preUpdateFunDefBody(dependentType, fd, ctx) // Inserts the FunDef in the assignments so that it can already be used. - t.root match { - case act@AbstractClassType(acd@AbstractClassDef(id, tparams, parent), tps) => - // Create a complete FunDef body with pattern matching - - val allKnownDescendantsAreCCAndHaveZeroArgs = act.knownCCDescendants.forall { x => x match { - case CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => ccd.fields.isEmpty - case _ => false - }} - - //TODO: Test other templates not only with Wilcard patterns, but more cases options for non-recursive classes (e.g. Option, Boolean, Finite parameterless case classes.) - val (ctx3, cases) = ((ctx2, ListBuffer[Stream[WithIds[MatchCase]]]()) /: act.knownCCDescendants) { - case ((ctx22, acc), cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2)) => - val (newCases, result) = extractCaseVariants(cct, ctx22) - val ctx23 = ctx22.copy(result = result) - (ctx23, acc += newCases) - case ((adtToString, acc), e) => hctx.reporter.fatalError("Could not handle this class definition for string rendering " + e) - } - - val allMatchExprsEnd = JoinProgramSet(cases.map(DirectProgramSet(_)), mergeMatchCases(fd)).programs // General pattern match expressions - val allMatchExprs = if(allKnownDescendantsAreCCAndHaveZeroArgs) { - Stream(constantPatternMatching(fd, act)) ++ allMatchExprsEnd - } else allMatchExprsEnd - gatherInputs(ctx3.add(dependentType, fd, allMatchExprs), q, result += Stream((functionInvocation(fd, Seq(input)), Nil))) - case cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => - val (newCases, result3) = extractCaseVariants(cct, ctx2) - val allMatchExprs = newCases.map(acase => mergeMatchCases(fd)(Seq(acase))) - gatherInputs(ctx2.copy(result = result3).add(dependentType, fd, allMatchExprs), q, result += Stream((functionInvocation(fd, Seq(input)), Nil))) + if(enforceDefaultStringMethodsIfAvailable && !defaultConverters.isEmpty) { + gatherInputs(ctx, q, result += defaultConverters) + } else { + // Create the empty function body and updates the assignments parts. + val fd = createEmptyFunDef(ctx, dependentType) + val ctx2 = preUpdateFunDefBody(dependentType, fd, ctx) // Inserts the FunDef in the assignments so that it can already be used. + t.root match { + case act@AbstractClassType(acd@AbstractClassDef(id, tparams, parent), tps) => + // Create a complete FunDef body with pattern matching + + val allKnownDescendantsAreCCAndHaveZeroArgs = act.knownCCDescendants.forall { x => x match { + case CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => ccd.fields.isEmpty + case _ => false + }} + + //TODO: Test other templates not only with Wilcard patterns, but more cases options for non-recursive classes (e.g. Option, Boolean, Finite parameterless case classes.) + val (ctx3, cases) = ((ctx2, ListBuffer[Stream[WithIds[MatchCase]]]()) /: act.knownCCDescendants) { + case ((ctx22, acc), cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2)) => + val (newCases, result) = extractCaseVariants(cct, ctx22) + val ctx23 = ctx22.copy(result = result) + (ctx23, acc += newCases) + case ((adtToString, acc), e) => hctx.reporter.fatalError("Could not handle this class definition for string rendering " + e) + } + + val allMatchExprsEnd = JoinProgramSet(cases.map(DirectProgramSet(_)), mergeMatchCases(fd)).programs // General pattern match expressions + val allMatchExprs = if(allKnownDescendantsAreCCAndHaveZeroArgs) { + Stream(constantPatternMatching(fd, act)) ++ allMatchExprsEnd + } else allMatchExprsEnd + gatherInputs(ctx3.add(dependentType, fd, allMatchExprs), q, + result += Stream((functionInvocation(fd, input::ctx.provided_functions.toList.map(Variable)), Nil))) + case cct@CaseClassType(ccd@CaseClassDef(id, tparams, parent, isCaseObject), tparams2) => + val (newCases, result3) = extractCaseVariants(cct, ctx2) + val allMatchExprs = newCases.map(acase => mergeMatchCases(fd)(Seq(acase))) + gatherInputs(ctx2.copy(result = result3).add(dependentType, fd, allMatchExprs), q, + result += Stream((functionInvocation(fd, input::ctx.provided_functions.toList.map(Variable)), Nil))) + } } case TypeParameter(t) => - hctx.reporter.fatalError("Could not handle type parameter for string rendering " + t) + if(defaultConverters.isEmpty) { + hctx.reporter.fatalError("Could not handle type parameter for string rendering " + t) + } else { + gatherInputs(ctx, q, result += mergeResults(Stream.empty)) + } case tpe => hctx.reporter.fatalError("Could not handle class type for string rendering " + tpe) } @@ -455,7 +535,7 @@ case object StringRender extends Rule("StringRender") { } template } - (templates.flatten, ctx2.result) // TODO: Flatten or interleave? + (templates.flatten, ctx2.result) } def instantiateOn(implicit hctx: SearchContext, p: Problem): Traversable[RuleInstantiation] = { @@ -463,15 +543,32 @@ case object StringRender extends Rule("StringRender") { p.xs match { case List(IsTyped(v, StringType)) => val description = "Creates a standard string conversion function" - - val defaultToStringFunctions = defaultMapTypeToString() - + val examplesFinder = new ExamplesFinder(hctx.context, hctx.program) + .setKeepAbstractExamples(true) + .setEvaluationFailOnChoose(true) val examples = examplesFinder.extractFromProblem(p) + val abstractStringConverters: StringConverters = + (p.as.flatMap { case x => x.getType match { + case FunctionType(Seq(aType), StringType) => List((aType, (arg: Expr) => application(Variable(x), Seq(arg)))) + case _ => Nil + }}).groupBy(_._1).mapValues(_.map(_._2)) + + val (inputVariables, functionVariables) = p.as.partition ( x => x.getType match { + case f: FunctionType => false + case _ => true + }) + val ruleInstantiations = ListBuffer[RuleInstantiation]() + val originalInputs = inputVariables.map(Variable) ruleInstantiations += RuleInstantiation("String conversion") { - val (expr, synthesisResult) = createFunDefsTemplates(StringSynthesisContext.empty, p.as.map(Variable)) + val (expr, synthesisResult) = createFunDefsTemplates( + StringSynthesisContext.empty( + abstractStringConverters, + originalInputs.toSet, + functionVariables + ), originalInputs) val funDefs = synthesisResult.adtToString /*val toDebug: String = (("\nInferred functions:" /: funDefs)( (t, s) => diff --git a/src/main/scala/leon/synthesis/rules/TEGIS.scala b/src/main/scala/leon/synthesis/rules/TEGIS.scala index d7ec34617ee7dc50745c3b6839511e2c00a6037e..3d496d0597e1947af0eb83504be5af449d7854f1 100644 --- a/src/main/scala/leon/synthesis/rules/TEGIS.scala +++ b/src/main/scala/leon/synthesis/rules/TEGIS.scala @@ -6,7 +6,6 @@ package rules import purescala.Types._ import grammars._ -import utils._ case object TEGIS extends TEGISLike[TypeTree]("TEGIS") { def getParams(sctx: SynthesisContext, p: Problem) = { diff --git a/src/main/scala/leon/synthesis/rules/TEGISLike.scala b/src/main/scala/leon/synthesis/rules/TEGISLike.scala index 91084ae4f6d69d055c36f0ce2c75bc4b41bfa763..93e97de6f1ad97c40def5b77c0d79fbb60282633 100644 --- a/src/main/scala/leon/synthesis/rules/TEGISLike.scala +++ b/src/main/scala/leon/synthesis/rules/TEGISLike.scala @@ -12,6 +12,7 @@ import datagen._ import evaluators._ import codegen.CodeGenParams import grammars._ +import leon.utils.GrowableIterable import scala.collection.mutable.{HashMap => MutableMap} @@ -40,7 +41,7 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val nTests = if (p.pc == BooleanLiteral(true)) 50 else 20 - val useVanuatoo = sctx.settings.cegisUseVanuatoo.getOrElse(false) + val useVanuatoo = sctx.settings.cegisUseVanuatoo val inputGenerator: Iterator[Seq[Expr]] = if (useVanuatoo) { new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, nTests, 3000) @@ -53,8 +54,6 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val failedTestsStats = new MutableMap[Seq[Expr], Int]().withDefaultValue(0) - def hasInputExamples = gi.nonEmpty - var n = 1 def allInputExamples() = { if (n == 10 || n == 50 || n % 500 == 0) { @@ -64,14 +63,12 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { gi.iterator } - var tests = p.eb.valids.map(_.ins).distinct - if (gi.nonEmpty) { - val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) - val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) + val evalParams = CodeGenParams.default.copy(maxFunctionInvocations = 2000) + val evaluator = new DualEvaluator(sctx.context, sctx.program, evalParams) - val enum = new MemoizedEnumerator[T, Expr, Generator[T, Expr]](grammar.getProductions) + val enum = new MemoizedEnumerator[T, Expr, ProductionRule[T, Expr]](grammar.getProductions) val targetType = tupleTypeWrap(p.xs.map(_.getType)) @@ -80,7 +77,6 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { val allExprs = enum.iterator(params.rootLabel(targetType)) var candidate: Option[Expr] = None - var n = 1 def findNext(): Option[Expr] = { candidate = None @@ -111,14 +107,9 @@ abstract class TEGISLike[T <: Typed](name: String) extends Rule(name) { candidate } - def toStream: Stream[Solution] = { - findNext() match { - case Some(e) => - Stream.cons(Solution(BooleanLiteral(true), Set(), e, isTrusted = false), toStream) - case None => - Stream.empty - } - } + val toStream = Stream.continually(findNext()).takeWhile(_.nonEmpty).map( e => + Solution(BooleanLiteral(true), Set(), e.get, isTrusted = false) + ) RuleClosed(toStream) } else { diff --git a/src/main/scala/leon/synthesis/utils/Helpers.scala b/src/main/scala/leon/synthesis/utils/Helpers.scala index acd285a4570f93ee9dd85ba3dd29a7e4b120c25a..4bfedc4acbe59440ac7f3382c8187ae201775f02 100644 --- a/src/main/scala/leon/synthesis/utils/Helpers.scala +++ b/src/main/scala/leon/synthesis/utils/Helpers.scala @@ -34,7 +34,18 @@ object Helpers { } } - def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(Expr, Set[Identifier])] = { + /** Given an initial set of function calls provided by a list of [[Terminating]], + * returns function calls that will hopefully be safe to call recursively from within this initial function calls. + * + * For each returned call, one argument is substituted by a "smaller" one, while the rest are left as holes. + * + * @param prog The current program + * @param tpe The expected type for the returned function calls + * @param ws Helper predicates that contain [[Terminating]]s with the initial calls + * @param pc The path condition + * @return A list of pairs of (safe function call, holes), where holes stand for the rest of the arguments of the function. + */ + def terminatingCalls(prog: Program, tpe: TypeTree, ws: Expr, pc: Expr): List[(FunctionInvocation, Set[Identifier])] = { val TopLevelAnds(wss) = ws val TopLevelAnds(clauses) = pc diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index f59618e17ac4eec3881dbdfc30c2ad8133f5ae54..50de827c17168bacd3db23049ed76237f73f132b 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -32,7 +32,7 @@ trait StructuralSize { )) absFun.typed } - + def size(expr: Expr) : Expr = { def funDef(ct: ClassType, cases: ClassType => Seq[MatchCase]): FunDef = { // we want to reuse generic size functions for sub-types diff --git a/src/main/scala/leon/utils/Bijection.scala b/src/main/scala/leon/utils/Bijection.scala index 57a62b665c797b10fab2d099fabd3a722f6e7d27..380799d25e1f73ddbbb57d7989706fd03e5f1821 100644 --- a/src/main/scala/leon/utils/Bijection.scala +++ b/src/main/scala/leon/utils/Bijection.scala @@ -2,17 +2,29 @@ package leon.utils -class Bijection[A, B] { +object Bijection { + def apply[A, B](a: Iterable[(A, B)]): Bijection[A, B] = new Bijection[A, B] ++= a + def apply[A, B](a: (A, B)*): Bijection[A, B] = apply(a.toSeq) +} + +class Bijection[A, B] extends Iterable[(A, B)] { protected var a2b = Map[A, B]() protected var b2a = Map[B, A]() + + def iterator = a2b.iterator def +=(a: A, b: B): Unit = { a2b += a -> b b2a += b -> a } - def +=(t: (A,B)): Unit = { - this += (t._1, t._2) + def +=(t: (A,B)): this.type = { + +=(t._1, t._2) + this + } + + def ++=(t: Iterable[(A, B)]) = { + (this /: t){ case (b, elem) => b += elem } } def clear(): Unit = { @@ -22,6 +34,9 @@ class Bijection[A, B] { def getA(b: B) = b2a.get(b) def getB(a: A) = a2b.get(a) + + def getAorElse(b: B, orElse: =>A) = b2a.getOrElse(b, orElse) + def getBorElse(a: A, orElse: =>B) = a2b.getOrElse(a, orElse) def toA(b: B) = getA(b).get def toB(a: A) = getB(a).get @@ -50,4 +65,11 @@ class Bijection[A, B] { def aSet = a2b.keySet def bSet = b2a.keySet + + def composeA[C](c: A => C): Bijection[C, B] = { + new Bijection[C, B] ++= this.a2b.map(kv => c(kv._1) -> kv._2) + } + def composeB[C](c: B => C): Bijection[A, C] = { + new Bijection[A, C] ++= this.a2b.map(kv => kv._1 -> c(kv._2)) + } } diff --git a/src/main/scala/leon/utils/GrowableIterable.scala b/src/main/scala/leon/utils/GrowableIterable.scala index d05a9f06576a9e3748ba0ba5fdd33656cc9ac457..0b32fe6261b3bd41cf6bb8ad11fcc6161b47d44b 100644 --- a/src/main/scala/leon/utils/GrowableIterable.scala +++ b/src/main/scala/leon/utils/GrowableIterable.scala @@ -1,4 +1,4 @@ -package leon +package leon.utils import scala.collection.mutable.ArrayBuffer diff --git a/src/main/scala/leon/utils/IncrementalSeq.scala b/src/main/scala/leon/utils/IncrementalSeq.scala index 4ec9290b5eb5c2672b0f4fae44760081ca14ba80..fbf042868415378d4af4877ee8766f1303632373 100644 --- a/src/main/scala/leon/utils/IncrementalSeq.scala +++ b/src/main/scala/leon/utils/IncrementalSeq.scala @@ -13,6 +13,7 @@ class IncrementalSeq[A] extends IncrementalState with Builder[A, IncrementalSeq[A]] { private[this] val stack = new Stack[ArrayBuffer[A]]() + stack.push(new ArrayBuffer()) def clear() : Unit = { stack.clear() @@ -20,11 +21,11 @@ class IncrementalSeq[A] extends IncrementalState def reset(): Unit = { clear() - push() + stack.push(new ArrayBuffer()) } def push(): Unit = { - stack.push(new ArrayBuffer()) + stack.push(stack.head.clone) } def pop(): Unit = { @@ -33,9 +34,8 @@ class IncrementalSeq[A] extends IncrementalState def iterator = stack.flatten.iterator def +=(e: A) = { stack.head += e; this } + def -=(e: A) = { stack.head -= e; this } override def newBuilder = new scala.collection.mutable.ArrayBuffer() def result = this - - push() } diff --git a/src/main/scala/leon/utils/InliningPhase.scala b/src/main/scala/leon/utils/InliningPhase.scala index 8053a8dc1e9956c9ef264c247d41f8d96ffbb934..d69bbb9c32d6cd9f0d32281ec7b5af628aac5a15 100644 --- a/src/main/scala/leon/utils/InliningPhase.scala +++ b/src/main/scala/leon/utils/InliningPhase.scala @@ -3,12 +3,13 @@ package leon.utils import leon._ +import purescala.Common._ import purescala.Definitions._ import purescala.Expressions._ -import purescala.TypeOps._ +import purescala.TypeOps.instantiateType import purescala.ExprOps._ import purescala.DefOps._ -import purescala.Constructors.caseClassSelector +import purescala.Constructors.{caseClassSelector, application} object InliningPhase extends TransformationPhase { @@ -25,26 +26,20 @@ object InliningPhase extends TransformationPhase { def doInline(fd: FunDef) = fd.flags(IsInlined) && !doNotInline(fd) - def simplifyImplicitClass(e: Expr) = e match { - case CaseClassSelector(cct, cc: CaseClass, id) => - Some(caseClassSelector(cct, cc, id)) + for (fd <- p.definedFunctions) { + fd.fullBody = preMap ({ + case FunctionInvocation(tfd, args) if doInline(tfd.fd) => + Some(replaceFromIDs((tfd.params.map(_.id) zip args).toMap, tfd.fullBody)) - case _ => - None - } + case CaseClassSelector(cct, cc: CaseClass, id) => + Some(caseClassSelector(cct, cc, id)) - def simplify(e: Expr) = { - fixpoint(postMap(simplifyImplicitClass))(e) - } + case Application(caller: Lambda, args) => + Some(application(caller, args)) - for (fd <- p.definedFunctions) { - fd.fullBody = simplify(preMap ({ - case FunctionInvocation(TypedFunDef(fd, tps), args) if doInline(fd) => - val newBody = instantiateType(fd.fullBody, (fd.tparams zip tps).toMap, Map()) - Some(replaceFromIDs(fd.params.map(_.id).zip(args).toMap, newBody)) case _ => None - }, applyRec = true)(fd.fullBody)) + }, applyRec = true)(fd.fullBody) } filterFunDefs(p, fd => !doInline(fd)) diff --git a/src/main/scala/leon/utils/PreprocessingPhase.scala b/src/main/scala/leon/utils/PreprocessingPhase.scala index 06f35dc3a47df0a48836e6137eaae96218745d39..72103a4f70098527f4f80dba71c47a4b46c5d2d1 100644 --- a/src/main/scala/leon/utils/PreprocessingPhase.scala +++ b/src/main/scala/leon/utils/PreprocessingPhase.scala @@ -36,7 +36,6 @@ class PreprocessingPhase(desugarXLang: Boolean = false, genc: Boolean = false) e MethodLifting andThen TypingPhase andThen synthesis.ConversionPhase andThen - CheckADTFieldsTypes andThen InliningPhase val pipeX = if (!genc && desugarXLang) { diff --git a/src/main/scala/leon/utils/SeqUtils.scala b/src/main/scala/leon/utils/SeqUtils.scala index 002f2ebedc8a6dfb265fbf101c2185b3bfa17ce1..ada7499120353737d34bc79e2c9cc312d6702580 100644 --- a/src/main/scala/leon/utils/SeqUtils.scala +++ b/src/main/scala/leon/utils/SeqUtils.scala @@ -8,7 +8,6 @@ import scala.collection.mutable.ArrayBuffer object SeqUtils { type Tuple[T] = Seq[T] - def cartesianProduct[T](seqs: Tuple[Seq[T]]): Seq[Tuple[T]] = { val sizes = seqs.map(_.size) val max = sizes.product @@ -34,7 +33,10 @@ object SeqUtils { } def sumTo(sum: Int, arity: Int): Seq[Seq[Int]] = { - if (arity == 1) { + require(arity >= 1) + if (sum < arity) { + Nil + } else if (arity == 1) { Seq(Seq(sum)) } else { (1 until sum).flatMap{ n => @@ -42,6 +44,39 @@ object SeqUtils { } } } + + def sumToOrdered(sum: Int, arity: Int): Seq[Seq[Int]] = { + def rec(sum: Int, arity: Int): Seq[Seq[Int]] = { + require(arity > 0) + if (sum < 0) Nil + else if (arity == 1) Seq(Seq(sum)) + else for { + n <- 0 to sum / arity + rest <- rec(sum - arity * n, arity - 1) + } yield n +: rest.map(n + _) + } + + rec(sum, arity) filterNot (_.head == 0) + } + + def groupWhile[T](es: Seq[T])(p: T => Boolean): Seq[Seq[T]] = { + var res: Seq[Seq[T]] = Nil + + var c = es + while (!c.isEmpty) { + val (span, rest) = c.span(p) + + if (span.isEmpty) { + res :+= Seq(rest.head) + c = rest.tail + } else { + res :+= span + c = rest + } + } + + res + } } class CartesianView[+A](views: Seq[SeqView[A, Seq[A]]]) extends SeqView[Seq[A], Seq[Seq[A]]] { @@ -86,4 +121,4 @@ class CartesianView[+A](views: Seq[SeqView[A, Seq[A]]]) extends SeqView[Seq[A], ret } } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/utils/UniqueCounter.scala b/src/main/scala/leon/utils/UniqueCounter.scala index 06a6c0bb4b1badd63df38c3285c5fd8514d249fb..7c7862747271a67d899b9a590bc2d9c5fbb7de40 100644 --- a/src/main/scala/leon/utils/UniqueCounter.scala +++ b/src/main/scala/leon/utils/UniqueCounter.scala @@ -17,4 +17,5 @@ class UniqueCounter[K] { globalId } + def current = nameIds } diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index f4f603393728dd5a7b748c990486533b1cd18db6..45fa8bea46c71643c68a5a69f5f18e9318c4c449 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -125,7 +125,7 @@ object UnitElimination extends TransformationPhase { } } - LetDef(newFds, rest) + letDef(newFds, rest) } case ite@IfExpr(cond, tExpr, eExpr) => diff --git a/src/main/scala/leon/verification/InjectAsserts.scala b/src/main/scala/leon/verification/InjectAsserts.scala index 4e126827bd6cf352692c43e8433857b8894615d4..aa88b39dec1e45377aa77f87ded6a0535abac782 100644 --- a/src/main/scala/leon/verification/InjectAsserts.scala +++ b/src/main/scala/leon/verification/InjectAsserts.scala @@ -8,7 +8,6 @@ import Expressions._ import ExprOps._ import Definitions._ import Constructors._ -import xlang.Expressions._ object InjectAsserts extends SimpleLeonPhase[Program, Program] { @@ -72,6 +71,12 @@ object InjectAsserts extends SimpleLeonPhase[Program, Program] { e ).setPos(e)) + case e @ CaseClass(cct, args) if cct.root.classDef.hasInvariant => + Some(assertion(FunctionInvocation(cct.root.invariant.get, Seq(e)), + Some("ADT invariant"), + e + ).setPos(e)) + case _ => None }) diff --git a/src/main/scala/leon/verification/TraceInductionTactic.scala b/src/main/scala/leon/verification/TraceInductionTactic.scala index ef0941c494c9dc6f1362d471ffbc05db9c7265df..3177f70e8c2e1545b251acb7497cf3ad8d94cd26 100644 --- a/src/main/scala/leon/verification/TraceInductionTactic.scala +++ b/src/main/scala/leon/verification/TraceInductionTactic.scala @@ -10,7 +10,7 @@ import purescala.ExprOps._ import purescala.DefOps._ import purescala.Common._ import purescala.Types._ -import purescala.TypeOps._ +import purescala.TypeOps.instantiateType import purescala.Extractors._ import invariant.util.PredicateUtil._ import leon.utils._ @@ -21,13 +21,13 @@ import leon.utils._ */ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { val description: String = "A tactic that performs induction over the recursions of a function." - + val cg = vctx.program.callGraph - val defaultTactic = new DefaultTactic(vctx) - val deepInduct = true // a flag for enabling deep induction pattern discovery + val defaultTactic = new DefaultTactic(vctx) + val deepInduct = true // a flag for enabling deep induction pattern discovery def generatePostconditions(function: FunDef): Seq[VC] = { - assert(!cg.isRecursive(function) && function.body.isDefined) + assert(!cg.isRecursive(function) && function.body.isDefined) val inductFunname = function.extAnnotations("traceInduct") match { case Seq(Some(arg: String)) => Some(arg) case a => None @@ -35,18 +35,18 @@ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { // print debug info if(inductFunname.isDefined) ctx.reporter.debug("Extracting induction pattern from: "+inductFunname.get)(DebugSectionVerification) - + // helper function def selfRecs(fd: FunDef): Set[FunctionInvocation] = { if(fd.body.isDefined){ collect{ - case fi@FunctionInvocation(tfd, _) if tfd.fd == fd => + case fi@FunctionInvocation(tfd, _) if tfd.fd == fd => Set(fi) case _ => Set.empty[FunctionInvocation] }(fd.body.get) } else Set() } - + if (function.hasPostcondition) { // construct post(body) val prop = application(function.postcondition.get, Seq(function.body.get)) @@ -55,7 +55,7 @@ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { var funInv: Option[FunctionInvocation] = None preTraversal { case _ if funInv.isDefined => - // do nothing + // do nothing case fi @ FunctionInvocation(tfd, args) if cg.isRecursive(tfd.fd) // function is recursive => val argCheck = @@ -69,17 +69,17 @@ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { case FunctionInvocation(_, recArgs) => rest.forall { case (_, i) => calleeParams(i) == recArgs(i) } } - val paramArgs = args.filter(paramVars.contains) + val paramArgs = args.filter(paramVars.contains) paramArgs.toSet.size == paramArgs.size && // paramArgs are unique ? restInv } else { - args.forall(paramVars.contains) && // all arguments are parameters + args.forall(paramVars.contains) && // all arguments are parameters args.toSet.size == args.size // all arguments are unique } if (argCheck) { if (inductFunname.isDefined) { - if (inductFunname.get == tfd.fd.id.name) - funInv = Some(fi) + if (inductFunname.get == tfd.fd.id.name) + funInv = Some(fi) } else { funInv = Some(fi) } @@ -87,7 +87,7 @@ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { case _ => }(prop) funInv match { - case None => + case None => ctx.reporter.warning("Cannot discover induction pattern! Falling back to normal tactic.") defaultTactic.generatePostconditions(function) case Some(finv) => @@ -96,39 +96,39 @@ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { function.params, BooleanType) tactFun.precondition = function.precondition // the body of tactFun is a conjunction of induction pattern of finv, and the property - val callee = finv.tfd.fd + val callee = finv.tfd.fd val paramIndex = paramVars.zipWithIndex.toMap val framePositions = finv.args.zipWithIndex.collect { // note: the index here is w.r.t calleeArgs case (v: Variable, i) if paramVars.contains(v) => (v, i) }.toMap val footprint = paramVars.filterNot(framePositions.keySet.contains) val indexedFootprint = footprint.map { a => paramIndex(a) -> a }.toMap // index here is w.r.t params - + // the returned expression will have boolean value def inductPattern(e: Expr): Expr = { - e match { + e match { case IfExpr(c, th, el) => createAnd(Seq(inductPattern(c), IfExpr(c, inductPattern(th), inductPattern(el)))) - + case MatchExpr(scr, cases) => val scrpat = inductPattern(scr) val casePats = cases.map{ case MatchCase(pat, optGuard, rhs) => - val guardPat = optGuard.toSeq.map(inductPattern _) + val guardPat = optGuard.toSeq.map(inductPattern _) (guardPat, MatchCase(pat, optGuard, inductPattern(rhs))) } val pats = scrpat +: casePats.flatMap(_._1) :+ MatchExpr(scr, casePats.map(_._2)) createAnd(pats) - - case Let(i, v, b) => + + case Let(i, v, b) => createAnd(Seq(inductPattern(v), Let(i, v, inductPattern(b)))) - + case FunctionInvocation(tfd, args) => val argPattern = createAnd(args.map(inductPattern)) if (tfd.fd == callee) { // self recursive call ? - // create a tactFun invocation to mimic the recursion pattern - val indexedArgs = framePositions.map { + // create a tactFun invocation to mimic the recursion pattern + val indexedArgs = framePositions.map { case (f, i) => paramIndex(f) -> args(i) }.toMap ++ indexedFootprint val recArgs = (0 until indexedArgs.size).map(indexedArgs) @@ -137,7 +137,7 @@ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { } else { argPattern } - + case Operator(args, op) => // conjoin all the expressions and return them createAnd(args.map(inductPattern)) @@ -154,7 +154,7 @@ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { // postcondition is `holds` val resid = FreshIdentifier("holds", BooleanType) tactFun.postcondition = Some(Lambda(Seq(ValDef(resid)), resid.toVariable)) - + // print debug info if needed ctx.reporter.debug("Autogenerated tactic fun: "+tactFun)(DebugSectionVerification) @@ -170,9 +170,9 @@ class TraceInductionTactic(vctx: VerificationContext) extends Tactic(vctx) { } else Seq() } - def generatePreconditions(function: FunDef): Seq[VC] = + def generatePreconditions(function: FunDef): Seq[VC] = defaultTactic.generatePreconditions(function) - - def generateCorrectnessConditions(function: FunDef): Seq[VC] = + + def generateCorrectnessConditions(function: FunDef): Seq[VC] = defaultTactic.generateCorrectnessConditions(function) } diff --git a/src/main/scala/leon/verification/VerificationCondition.scala b/src/main/scala/leon/verification/VerificationCondition.scala index c19f75220ba31e4fd3142b36d09f0fbaedde2f5d..299307a87e6e1e4bd2470c377966419bc8c90f15 100644 --- a/src/main/scala/leon/verification/VerificationCondition.scala +++ b/src/main/scala/leon/verification/VerificationCondition.scala @@ -7,7 +7,6 @@ import leon.purescala.Definitions._ import leon.purescala.Types._ import leon.purescala.PrettyPrinter import leon.utils.Positioned -import leon.evaluators.StringTracingEvaluator import leon.solvers._ import leon.LeonContext import leon.purescala.SelfPrettyPrinter diff --git a/src/main/scala/leon/verification/VerificationReport.scala b/src/main/scala/leon/verification/VerificationReport.scala index d695f4ca0877382eeb83838a3cf08267ee935cb6..2875bf3ed074756c5089896c327b487cc029fe34 100644 --- a/src/main/scala/leon/verification/VerificationReport.scala +++ b/src/main/scala/leon/verification/VerificationReport.scala @@ -3,7 +3,6 @@ package leon package verification -import evaluators.StringTracingEvaluator import utils.DebugSectionSynthesis import utils.DebugSectionVerification import leon.purescala @@ -22,7 +21,7 @@ import purescala.Common._ import purescala.Expressions._ import purescala.Definitions._ import purescala.SelfPrettyPrinter -import leon.solvers.{ HenkinModel, Model, SolverFactory } +import leon.solvers.{ PartialModel, Model, SolverFactory } case class VerificationReport(program: Program, results: Map[VC, Option[VCResult]]) { val vrs: Seq[(VC, VCResult)] = results.toSeq.sortBy { case (vc, _) => (vc.fd.id.name, vc.kind.toString) }.map { diff --git a/src/main/scala/leon/xlang/AntiAliasingPhase.scala b/src/main/scala/leon/xlang/AntiAliasingPhase.scala new file mode 100644 index 0000000000000000000000000000000000000000..7eb391088ea64c276aa0bbf7836e4a9511aa978c --- /dev/null +++ b/src/main/scala/leon/xlang/AntiAliasingPhase.scala @@ -0,0 +1,383 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ +package leon.xlang + +import leon.TransformationPhase +import leon.LeonContext +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.Expressions._ +import leon.purescala.ExprOps._ +import leon.purescala.DefOps._ +import leon.purescala.Types._ +import leon.purescala.Constructors._ +import leon.purescala.Extractors._ +import leon.xlang.Expressions._ + +object AntiAliasingPhase extends TransformationPhase { + + val name = "Anti-Aliasing" + val description = "Make aliasing explicit" + + override def apply(ctx: LeonContext, pgm: Program): Program = { + val fds = allFunDefs(pgm) + fds.foreach(fd => checkAliasing(fd)(ctx)) + + var updatedFunctions: Map[FunDef, FunDef] = Map() + + val effects = effectsAnalysis(pgm) + + //for each fun def, all the vars the the body captures. Only + //mutable types. + val varsInScope: Map[FunDef, Set[Identifier]] = (for { + fd <- fds + } yield { + val allFreeVars = fd.body.map(bd => variablesOf(bd)).getOrElse(Set()) + val freeVars = allFreeVars -- fd.params.map(_.id) + val mutableFreeVars = freeVars.filter(id => id.getType.isInstanceOf[ArrayType]) + (fd, mutableFreeVars) + }).toMap + + /* + * The first pass will introduce all new function definitions, + * so that in the next pass we can update function invocations + */ + for { + fd <- fds + } { + updatedFunctions += (fd -> updateFunDef(fd, effects)(ctx)) + } + + for { + fd <- fds + } { + updateBody(fd, effects, updatedFunctions, varsInScope)(ctx) + } + + val res = replaceFunDefs(pgm)(fd => updatedFunctions.get(fd), (fi, fd) => None) + //println(res._1) + res._1 + } + + /* + * Create a new FunDef for a given FunDef in the program. + * Adapt the signature to express its effects. In case the + * function has no effect, this will still introduce a fresh + * FunDef as the body might have to be updated anyway. + */ + private def updateFunDef(fd: FunDef, effects: Effects)(ctx: LeonContext): FunDef = { + + val ownEffects = effects(fd) + val aliasedParams: Seq[Identifier] = fd.params.zipWithIndex.flatMap{ + case (vd, i) if ownEffects.contains(i) => Some(vd) + case _ => None + }.map(_.id) + + fd.body.foreach(body => getReturnedExpr(body).foreach{ + case v@Variable(id) if aliasedParams.contains(id) => + ctx.reporter.fatalError(v.getPos, "Cannot return a shared reference to a mutable object") + case _ => () + }) + //val allBodies: Set[Expr] = + // fd.body.toSet.flatMap((bd: Expr) => nestedFunDefsOf(bd).flatMap(_.body)) ++ fd.body + //allBodies.foreach(body => getReturnedExpr(body).foreach{ + // case v@Variable(id) if aliasedParams.contains(id) => + // ctx.reporter.fatalError(v.getPos, "Cannot return a shared reference to a mutable object: "k+ v) + // case _ => () + //}) + + val newReturnType: TypeTree = if(aliasedParams.isEmpty) fd.returnType else { + tupleTypeWrap(fd.returnType +: aliasedParams.map(_.getType)) + } + val newFunDef = new FunDef(fd.id.freshen, fd.tparams, fd.params, newReturnType) + newFunDef.addFlags(fd.flags) + newFunDef.setPos(fd) + newFunDef + } + + private def updateBody(fd: FunDef, effects: Effects, updatedFunDefs: Map[FunDef, FunDef], varsInScope: Map[FunDef, Set[Identifier]]) + (ctx: LeonContext): Unit = { + + val ownEffects = effects(fd) + val aliasedParams: Seq[Identifier] = fd.params.zipWithIndex.flatMap{ + case (vd, i) if ownEffects.contains(i) => Some(vd) + case _ => None + }.map(_.id) + + val newFunDef = updatedFunDefs(fd) + + if(aliasedParams.isEmpty) { + val newBody = fd.body.map(body => { + makeSideEffectsExplicit(body, Seq(), effects, updatedFunDefs, varsInScope)(ctx) + }) + newFunDef.body = newBody + newFunDef.precondition = fd.precondition + newFunDef.postcondition = fd.postcondition + } else { + val freshLocalVars: Seq[Identifier] = aliasedParams.map(v => v.freshen) + val rewritingMap: Map[Identifier, Identifier] = aliasedParams.zip(freshLocalVars).toMap + + val newBody = fd.body.map(body => { + + val freshBody = replaceFromIDs(rewritingMap.map(p => (p._1, p._2.toVariable)), body) + val explicitBody = makeSideEffectsExplicit(freshBody, freshLocalVars, effects, updatedFunDefs, varsInScope)(ctx) + + //WARNING: only works if side effects in Tuples are extracted from left to right, + // in the ImperativeTransformation phase. + val finalBody: Expr = Tuple(explicitBody +: freshLocalVars.map(_.toVariable)) + + freshLocalVars.zip(aliasedParams).foldLeft(finalBody)((bd, vp) => { + LetVar(vp._1, Variable(vp._2), bd) + }) + + }) + + val newPostcondition = fd.postcondition.map(post => { + val Lambda(Seq(res), postBody) = post + val newRes = ValDef(FreshIdentifier(res.id.name, newFunDef.returnType)) + val newBody = + replace( + aliasedParams.zipWithIndex.map{ case (id, i) => + (id.toVariable, TupleSelect(newRes.toVariable, i+2)): (Expr, Expr)}.toMap ++ + aliasedParams.map(id => + (Old(id), id.toVariable): (Expr, Expr)).toMap + + (res.toVariable -> TupleSelect(newRes.toVariable, 1)), + postBody) + Lambda(Seq(newRes), newBody).setPos(post) + }) + + newFunDef.body = newBody + newFunDef.precondition = fd.precondition + newFunDef.postcondition = newPostcondition + } + } + + //We turn all local val of mutable objects into vars and explicit side effects + //using assignments. We also make sure that no aliasing is being done. + private def makeSideEffectsExplicit + (body: Expr, aliasedParams: Seq[Identifier], effects: Effects, updatedFunDefs: Map[FunDef, FunDef], varsInScope: Map[FunDef, Set[Identifier]]) + (ctx: LeonContext): Expr = { + preMapWithContext[Set[Identifier]]((expr, bindings) => expr match { + + case up@ArrayUpdate(a, i, v) => { + val ra@Variable(id) = a + if(bindings.contains(id)) + (Some(Assignment(id, ArrayUpdated(ra, i, v).setPos(up)).setPos(up)), bindings) + else + (None, bindings) + } + + case l@Let(id, IsTyped(v, ArrayType(_)), b) => { + val varDecl = LetVar(id, v, b).setPos(l) + (Some(varDecl), bindings + id) + } + + case l@LetVar(id, IsTyped(v, ArrayType(_)), b) => { + (None, bindings + id) + } + + //we need to replace local fundef by the new updated fun defs. + case l@LetDef(fds, body) => { + //this might be traversed several time in case of doubly nested fundef, + //so we need to ignore the second times by checking if updatedFunDefs + //contains a mapping or not + val nfds = fds.map(fd => updatedFunDefs.get(fd).getOrElse(fd)) + (Some(LetDef(nfds, body)), bindings) + } + + case fi@FunctionInvocation(fd, args) => { + + val vis: Set[Identifier] = varsInScope.get(fd.fd).getOrElse(Set()) + args.find({ + case Variable(id) => vis.contains(id) + case _ => false + }).foreach(aliasedArg => + ctx.reporter.fatalError(aliasedArg.getPos, "Illegal passing of aliased parameter: " + aliasedArg)) + + updatedFunDefs.get(fd.fd) match { + case None => (None, bindings) + case Some(nfd) => { + val nfi = FunctionInvocation(nfd.typed(fd.tps), args).setPos(fi) + val fiEffects = effects.getOrElse(fd.fd, Set()) + if(fiEffects.nonEmpty) { + val modifiedArgs: Seq[Variable] = + args.zipWithIndex.filter{ case (arg, i) => fiEffects.contains(i) } + .map(_._1.asInstanceOf[Variable]) + + val duplicatedParams = modifiedArgs.diff(modifiedArgs.distinct).distinct + if(duplicatedParams.nonEmpty) + ctx.reporter.fatalError(fi.getPos, "Illegal passing of aliased parameter: " + duplicatedParams.head) + + val freshRes = FreshIdentifier("res", nfd.returnType) + + val extractResults = Block( + modifiedArgs.zipWithIndex.map(p => Assignment(p._1.id, TupleSelect(freshRes.toVariable, p._2 + 2))), + TupleSelect(freshRes.toVariable, 1)) + + + val newExpr = Let(freshRes, nfi, extractResults) + (Some(newExpr), bindings) + } else { + (Some(nfi), bindings) + } + } + } + } + + case _ => (None, bindings) + + })(body, aliasedParams.toSet) + } + + //TODO: in the future, any object with vars could be aliased and so + // we will need a general property + + private type Effects = Map[FunDef, Set[Int]] + + /* + * compute effects for each function in the program, including any nested + * functions (LetDef). + */ + private def effectsAnalysis(pgm: Program): Effects = { + + //currently computed effects (incremental) + var effects: Map[FunDef, Set[Int]] = Map() + //missing dependencies for a function for its effects to be complete + var missingEffects: Map[FunDef, Set[FunctionInvocation]] = Map() + + def effectsFullyComputed(fd: FunDef): Boolean = !missingEffects.isDefinedAt(fd) + + for { + fd <- allFunDefs(pgm) + } { + fd.body match { + case None => + effects += (fd -> Set()) + case Some(body) => { + val mutableParams = fd.params.filter(vd => vd.getType match { + case ArrayType(_) => true + case _ => false + }) + val mutatedParams = mutableParams.filter(vd => exists { + case ArrayUpdate(Variable(a), _, _) => a == vd.id + case _ => false + }(body)) + val mutatedParamsIndices = fd.params.zipWithIndex.flatMap{ + case (vd, i) if mutatedParams.contains(vd) => Some(i) + case _ => None + }.toSet + effects = effects + (fd -> mutatedParamsIndices) + + val missingCalls: Set[FunctionInvocation] = functionCallsOf(body).filterNot(fi => fi.tfd.fd == fd) + if(missingCalls.nonEmpty) + missingEffects += (fd -> missingCalls) + } + } + } + + def rec(): Unit = { + val previousMissingEffects = missingEffects + + for{ (fd, calls) <- missingEffects } { + var newMissingCalls: Set[FunctionInvocation] = calls + for(fi <- calls) { + val mutatedArgs = invocEffects(fi) + val mutatedFunParams: Set[Int] = fd.params.zipWithIndex.flatMap{ + case (vd, i) if mutatedArgs.contains(vd.id) => Some(i) + case _ => None + }.toSet + effects += (fd -> (effects(fd) ++ mutatedFunParams)) + + if(effectsFullyComputed(fi.tfd.fd)) { + newMissingCalls -= fi + } + } + if(newMissingCalls.isEmpty) + missingEffects = missingEffects - fd + else + missingEffects += (fd -> newMissingCalls) + } + + if(missingEffects != previousMissingEffects) { + rec() + } + } + + def invocEffects(fi: FunctionInvocation): Set[Identifier] = { + //TODO: the require should be fine once we consider nested functions as well + //require(effects.isDefinedAt(fi.tfd.fd) + val mutatedParams: Set[Int] = effects.get(fi.tfd.fd).getOrElse(Set()) + fi.args.zipWithIndex.flatMap{ + case (Variable(id), i) if mutatedParams.contains(i) => Some(id) + case _ => None + }.toSet + } + + rec() + effects + } + + + def checkAliasing(fd: FunDef)(ctx: LeonContext): Unit = { + def checkReturnValue(body: Expr, bindings: Set[Identifier]): Unit = { + getReturnedExpr(body).foreach{ + case IsTyped(v@Variable(id), ArrayType(_)) if bindings.contains(id) => + ctx.reporter.fatalError(v.getPos, "Cannot return a shared reference to a mutable object: " + v) + case _ => () + } + } + + fd.body.foreach(bd => { + val params = fd.params.map(_.id).toSet + checkReturnValue(bd, params) + preMapWithContext[Set[Identifier]]((expr, bindings) => expr match { + case l@Let(id, IsTyped(v, ArrayType(_)), b) => { + v match { + case FiniteArray(_, _, _) => () + case FunctionInvocation(_, _) => () + case ArrayUpdated(_, _, _) => () + case _ => ctx.reporter.fatalError(l.getPos, "Cannot alias array: " + l) + } + (None, bindings + id) + } + case l@LetVar(id, IsTyped(v, ArrayType(_)), b) => { + v match { + case FiniteArray(_, _, _) => () + case FunctionInvocation(_, _) => () + case ArrayUpdated(_, _, _) => () + case _ => ctx.reporter.fatalError(l.getPos, "Cannot alias array: " + l) + } + (None, bindings + id) + } + case l@LetDef(fds, body) => { + fds.foreach(fd => fd.body.foreach(bd => checkReturnValue(bd, bindings))) + (None, bindings) + } + + case _ => (None, bindings) + })(bd, params) + }) + } + + /* + * A bit hacky, but not sure of the best way to do something like that + * currently. + */ + private def getReturnedExpr(expr: Expr): Seq[Expr] = expr match { + case Let(_, _, rest) => getReturnedExpr(rest) + case LetVar(_, _, rest) => getReturnedExpr(rest) + case Block(_, rest) => getReturnedExpr(rest) + case IfExpr(_, thenn, elze) => getReturnedExpr(thenn) ++ getReturnedExpr(elze) + case MatchExpr(_, cses) => cses.flatMap{ cse => getReturnedExpr(cse.rhs) } + case e => Seq(expr) + } + + + /* + * returns all fun def in the program, including local definitions inside + * other functions (LetDef). + */ + private def allFunDefs(pgm: Program): Seq[FunDef] = + pgm.definedFunctions.flatMap(fd => + fd.body.toSet.flatMap((bd: Expr) => + nestedFunDefsOf(bd)) + fd) +} diff --git a/src/main/scala/leon/xlang/Expressions.scala b/src/main/scala/leon/xlang/Expressions.scala index d627e0d284f4933cd6ed7ecefbe7dd13f4e658f8..98214ee640bd95227c0b759113ab74d2c9555d94 100644 --- a/src/main/scala/leon/xlang/Expressions.scala +++ b/src/main/scala/leon/xlang/Expressions.scala @@ -15,6 +15,14 @@ object Expressions { trait XLangExpr extends Expr + case class Old(id: Identifier) extends XLangExpr with Terminal with PrettyPrintable { + val getType = id.getType + + def printWith(implicit pctx: PrinterContext): Unit = { + p"old($id)" + } + } + case class Block(exprs: Seq[Expr], last: Expr) extends XLangExpr with Extractable with PrettyPrintable { def extract: Option[(Seq[Expr], (Seq[Expr])=>Expr)] = { Some((exprs :+ last, exprs => Block(exprs.init, exprs.last))) diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index 45bb36770cddca417ba51582cd7824fc09152199..6b7f7cc6ee3c00827289313ed60e111b6ec3a640 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -9,7 +9,7 @@ import leon.purescala.Expressions._ import leon.purescala.Extractors._ import leon.purescala.Constructors._ import leon.purescala.ExprOps._ -import leon.purescala.TypeOps._ +import leon.purescala.TypeOps.leastUpperBound import leon.purescala.Types._ import leon.xlang.Expressions._ @@ -67,7 +67,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] { val (tRes, tScope, tFun) = toFunction(tExpr) val (eRes, eScope, eFun) = toFunction(eExpr) - val iteRType = leastUpperBound(tRes.getType, eRes.getType).get + val iteRType = leastUpperBound(tRes.getType, eRes.getType).getOrElse(Untyped) val modifiedVars: Seq[Identifier] = (tFun.keys ++ eFun.keys).toSet.intersect(varsInScope).toSeq val resId = FreshIdentifier("res", iteRType) @@ -218,7 +218,7 @@ object ImperativeCodeElimination extends UnitPhase[Program] { case LetDef(fds, b) => if(fds.size > 1) { - //TODO: no support for true mutually recursion + //TODO: no support for true mutual recursion toFunction(LetDef(Seq(fds.head), LetDef(fds.tail, b))) } else { diff --git a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala index 3a7f8be381cfa5fdb87dc6870cf35141f8d8cf33..59dd3217714f964be5f9ff6a1428617c6e54e17a 100644 --- a/src/main/scala/leon/xlang/XLangDesugaringPhase.scala +++ b/src/main/scala/leon/xlang/XLangDesugaringPhase.scala @@ -12,7 +12,8 @@ object XLangDesugaringPhase extends LeonPhase[Program, Program] { override def run(ctx: LeonContext, pgm: Program): (LeonContext, Program) = { val phases = - ArrayTransformation andThen + //ArrayTransformation andThen + AntiAliasingPhase andThen EpsilonElimination andThen ImperativeCodeElimination diff --git a/src/test/resources/regression/frontends/error/xlang/Array2.scala b/src/test/resources/regression/frontends/error/xlang/Array2.scala deleted file mode 100644 index b1b370395d7e0b648e0b88875b3678eaf4668eb5..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array2.scala +++ /dev/null @@ -1,11 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array2 { - - def foo(): Int = { - val a = Array.fill(5)(5) - val b = a - b(3) - } - -} diff --git a/src/test/resources/regression/frontends/error/xlang/Array3.scala b/src/test/resources/regression/frontends/error/xlang/Array3.scala deleted file mode 100644 index 14a8512015102bd235a9efe41782ba7d5a46fd44..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array3.scala +++ /dev/null @@ -1,14 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array3 { - - def foo(): Int = { - val a = Array.fill(5)(5) - if(a.length > 2) - a(1) = 2 - else - 0 - 0 - } - -} diff --git a/src/test/resources/regression/frontends/error/xlang/Array4.scala b/src/test/resources/regression/frontends/error/xlang/Array4.scala deleted file mode 100644 index e41535d6d267986ba7764a6f457970c6ab33b733..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array4.scala +++ /dev/null @@ -1,10 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array4 { - - def foo(a: Array[Int]): Int = { - val b = a - b(3) - } - -} diff --git a/src/test/resources/regression/frontends/error/xlang/Array5.scala b/src/test/resources/regression/frontends/error/xlang/Array5.scala deleted file mode 100644 index 8b7254e9482ddc7df1196c03b1a191c54e86ea0f..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array5.scala +++ /dev/null @@ -1,12 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array5 { - - def foo(a: Array[Int]): Int = { - a(2) = 4 - a(2) - } - -} - -// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/frontends/error/xlang/Array6.scala b/src/test/resources/regression/frontends/error/xlang/Array6.scala deleted file mode 100644 index c4d0c09541d3c7fe4ea4be1527cd704e96d54bb1..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array6.scala +++ /dev/null @@ -1,12 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - - -object Array6 { - - def foo(): Int = { - val a = Array.fill(5)(5) - var b = a - b(0) - } - -} diff --git a/src/test/resources/regression/frontends/error/xlang/Array7.scala b/src/test/resources/regression/frontends/error/xlang/Array7.scala deleted file mode 100644 index ab6f4c20da5a84adfd8509a60174d78c0c423654..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/frontends/error/xlang/Array7.scala +++ /dev/null @@ -1,11 +0,0 @@ -/* Copyright 2009-2015 EPFL, Lausanne */ - -object Array7 { - - def foo(): Int = { - val a = Array.fill(5)(5) - var b = a - b(0) - } - -} diff --git a/src/test/resources/regression/termination/valid/Ackermann.scala b/src/test/resources/regression/termination/valid/Ackermann.scala new file mode 100644 index 0000000000000000000000000000000000000000..11ea76bee4ae9ae9c8225642d413d0f9054691f1 --- /dev/null +++ b/src/test/resources/regression/termination/valid/Ackermann.scala @@ -0,0 +1,10 @@ +import leon.lang._ + +object Ackermann { + def ackermann(m: BigInt, n: BigInt): BigInt = { + require(m >= 0 && n >= 0) + if (m == 0) n + 1 + else if (n == 0) ackermann(m - 1, 1) + else ackermann(m - 1, ackermann(m, n - 1)) + } +} diff --git a/src/test/resources/regression/verification/purescala/invalid/AssociativityProperties.scala b/src/test/resources/regression/verification/purescala/invalid/AssociativityProperties.scala new file mode 100644 index 0000000000000000000000000000000000000000..b143816e8d5a28437701567b2459511882a805dd --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/AssociativityProperties.scala @@ -0,0 +1,26 @@ +import leon.lang._ + +object AssociativityProperties { + + def isAssociative[A](f: (A,A) => A): Boolean = { + forall((x: A, y: A, z: A) => f(f(x, y), z) == f(x, f(y, z))) + } + + def isCommutative[A](f: (A,A) => A): Boolean = { + forall((x: A, y: A) => f(x, y) == f(y, x)) + } + + def isRotate[A](f: (A,A) => A): Boolean = { + forall((x: A, y: A, z: A) => f(f(x, y), z) == f(f(y, z), x)) + } + + def assocNotCommutative[A](f: (A,A) => A): Boolean = { + require(isAssociative(f)) + isCommutative(f) + }.holds + + def commNotAssociative[A](f: (A,A) => A): Boolean = { + require(isCommutative(f)) + isAssociative(f) + }.holds +} diff --git a/src/test/resources/regression/verification/purescala/invalid/InductiveQuantification.scala b/src/test/resources/regression/verification/purescala/invalid/InductiveQuantification.scala deleted file mode 100644 index 970b3b53bddb295db19308e3443309f478ee15cc..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/verification/purescala/invalid/InductiveQuantification.scala +++ /dev/null @@ -1,17 +0,0 @@ -import leon.lang._ - -object SizeInc { - - abstract class List[A] - case class Cons[A](head: A, tail: List[A]) extends List[A] - case class Nil[A]() extends List[A] - - def failling_1[A](x: List[A]): Int => Int = { - (i: Int) => x match { - case Cons(head, tail) => 1 + failling_1(tail)(i) - case Nil() => i - } - } ensuring { res => forall((a: Int) => res(a) > 0) } -} - -// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/purescala/valid/AssociativityProperties.scala b/src/test/resources/regression/verification/purescala/valid/AssociativityProperties.scala new file mode 100644 index 0000000000000000000000000000000000000000..5c8530615dff1cdc15868be54a2b0019acba1844 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/AssociativityProperties.scala @@ -0,0 +1,33 @@ +import leon.lang._ + +object AssociativityProperties { + + def isAssociative[A](f: (A,A) => A): Boolean = { + forall((x: A, y: A, z: A) => f(f(x, y), z) == f(x, f(y, z))) + } + + def isCommutative[A](f: (A,A) => A): Boolean = { + forall((x: A, y: A) => f(x, y) == f(y, x)) + } + + def isRotate[A](f: (A,A) => A): Boolean = { + forall((x: A, y: A, z: A) => f(f(x, y), z) == f(f(y, z), x)) + } + + def assocPairs[A,B](f1: (A,A) => A, f2: (B,B) => B) = { + require(isAssociative(f1) && isAssociative(f2)) + val fp = ((p1: (A,B), p2: (A,B)) => (f1(p1._1, p2._1), f2(p1._2, p2._2))) + isAssociative(fp) + }.holds + + def assocRotate[A](f: (A,A) => A): Boolean = { + require(isCommutative(f) && isRotate(f)) + isAssociative(f) + }.holds + + def assocRotateInt(f: (BigInt, BigInt) => BigInt): Boolean = { + require(isCommutative(f) && isRotate(f)) + isAssociative(f) + }.holds + +} diff --git a/src/test/resources/regression/verification/purescala/valid/InductiveQuantification.scala b/src/test/resources/regression/verification/purescala/valid/InductiveQuantification.scala deleted file mode 100644 index 4883043040fdb90df294aac09b076d16c7628cf4..0000000000000000000000000000000000000000 --- a/src/test/resources/regression/verification/purescala/valid/InductiveQuantification.scala +++ /dev/null @@ -1,26 +0,0 @@ -import leon.lang._ - -object SizeInc { - - abstract class List[A] - case class Cons[A](head: A, tail: List[A]) extends List[A] - case class Nil[A]() extends List[A] - - def sizeInc[A](x: List[A]): BigInt => BigInt = { - (i: BigInt) => x match { - case Cons(head, tail) => 1 + sizeInc(tail)(i) - case Nil() => i - } - } ensuring { res => forall((a: BigInt) => a > 0 ==> res(a) > 0) } - - def sizeInc2[A](x: BigInt): List[A] => BigInt = { - require(x > 0) - - (list: List[A]) => list match { - case Cons(head, tail) => 1 + sizeInc2(x)(tail) - case Nil() => x - } - } ensuring { res => forall((a: List[A]) => res(a) > 0) } -} - -// vim: set ts=4 sw=4 et: diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation1.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation1.scala new file mode 100644 index 0000000000000000000000000000000000000000..f0e622c66c31e4fe01ff8be5c1b08e17d4d330eb --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation1.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object ArrayParamMutation1 { + + def update(a: Array[BigInt]): Unit = { + require(a.length > 0) + a(0) = 10 + } + + def f(): BigInt = { + val a = Array.fill(10)(BigInt(0)) + update(a) + a(0) + } ensuring(res => res == 10) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation2.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation2.scala new file mode 100644 index 0000000000000000000000000000000000000000..801b35e0cc545f160fc8061e34fd0ee06b7c3f73 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation2.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object ArrayParamMutation2 { + + def rec(a: Array[BigInt]): BigInt = { + require(a.length > 1 && a(0) >= 0) + if(a(0) == 0) + a(1) + else { + a(0) = a(0) - 1 + a(1) = a(1) + a(0) + rec(a) + } + } ensuring(res => a(0) == 0) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation3.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation3.scala new file mode 100644 index 0000000000000000000000000000000000000000..f575167444f839c0ee900a35de5e4e822624dc21 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation3.scala @@ -0,0 +1,23 @@ +import leon.lang._ + +object ArrayParamMutation3 { + + def odd(a: Array[BigInt]): Boolean = { + require(a.length > 0 && a(0) >= 0) + if(a(0) == 0) false + else { + a(0) = a(0) - 1 + even(a) + } + } ensuring(res => a(0) == 0) + + def even(a: Array[BigInt]): Boolean = { + require(a.length > 0 && a(0) >= 0) + if(a(0) == 0) true + else { + a(0) = a(0) - 1 + odd(a) + } + } ensuring(res => a(0) == 0) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation4.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation4.scala new file mode 100644 index 0000000000000000000000000000000000000000..31af4cd5885ea66a2d11e9387ba7e306423ec4d7 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation4.scala @@ -0,0 +1,23 @@ +import leon.lang._ + +object ArrayParamMutation4 { + + def multipleArgs(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + if(a1(0) == 10) + a2(0) = 13 + else + a2(0) = a1(0) + 1 + } + + def transitiveEffects(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + multipleArgs(a1, a2) + } ensuring(_ => a2(0) >= a1(0)) + + def transitiveReverseEffects(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + multipleArgs(a2, a1) + } ensuring(_ => a1(0) >= a2(0)) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation5.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation5.scala new file mode 100644 index 0000000000000000000000000000000000000000..249a79d1f3b7d8df8c941ab3121c4eafed149e03 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation5.scala @@ -0,0 +1,21 @@ + +import leon.lang._ + +object ArrayParamMutation5 { + + def mutuallyRec1(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a1(0) > 0 && a2.length > 0) + if(a1(0) == 10) { + () + } else { + mutuallyRec2(a1, a2) + } + } ensuring(res => a1(0) == 10) + + def mutuallyRec2(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0 && a1(0) > 0) + a1(0) = 10 + mutuallyRec1(a1, a2) + } + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation6.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation6.scala new file mode 100644 index 0000000000000000000000000000000000000000..29ded427fa6546a103d8da6f98cefc1415f389a6 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation6.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object ArrayParamMutation6 { + + def multipleEffects(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + a1(0) = 11 + a2(0) = 12 + } ensuring(_ => a1(0) != a2(0)) + + def f(a1: Array[BigInt], a2: Array[BigInt]): Unit = { + require(a1.length > 0 && a2.length > 0) + multipleEffects(a1, a2) + } ensuring(_ => a1(0) == 11 && a2(0) == 12) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation7.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation7.scala new file mode 100644 index 0000000000000000000000000000000000000000..53d67729fd57723d1693c564e99cd3d66ee095ef --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation7.scala @@ -0,0 +1,29 @@ +import leon.lang._ + +object ArrayParamMutation7 { + + def f(i: BigInt)(implicit world: Array[BigInt]): BigInt = { + require(world.length == 3) + + world(1) += 1 //global counter of f + + val res = i*i + world(0) = res + res + } + + def mainProgram(): Unit = { + + implicit val world: Array[BigInt] = Array(0,0,0) + + f(1) + assert(world(0) == 1) + f(2) + assert(world(0) == 4) + f(4) + assert(world(0) == 16) + + assert(world(1) == 3) + } + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation8.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation8.scala new file mode 100644 index 0000000000000000000000000000000000000000..68aa737eb42e6e51073dac27b799e06bde928400 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation8.scala @@ -0,0 +1,25 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +import leon.lang._ + +object ArrayParamMutation8 { + + def odd(a: Array[BigInt]): Boolean = { + require(a.length > 0 && a(0) >= 0) + if(a(0) == 0) false + else { + a(0) = a(0) - 1 + even(a) + } + } ensuring(res => if(old(a)(0) % 2 == 1) res else !res) + + def even(a: Array[BigInt]): Boolean = { + require(a.length > 0 && a(0) >= 0) + if(a(0) == 0) true + else { + a(0) = a(0) - 1 + odd(a) + } + } ensuring(res => if(old(a)(0) % 2 == 0) res else !res) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation9.scala b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation9.scala new file mode 100644 index 0000000000000000000000000000000000000000..f5046b6cf3b40382ccc5d989d81d73bc577da9f7 --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/ArrayParamMutation9.scala @@ -0,0 +1,22 @@ +import leon.lang._ + +object ArrayParamMutation9 { + def abs(a: Array[Int]) { + require(a.length > 0) + var i = 0; + (while (i < a.length) { + a(i) = if (a(i) < 0) -a(i) else a(i) // <-- this makes Leon crash + i = i + 1 + }) invariant(i >= 0) + } + + + def main = { + val a = Array(0, -1, 2, -3) + + abs(a) + + a(0) + a(1) - 1 + a(2) - 2 + a(3) - 3 // == 0 + } + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation1.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation1.scala new file mode 100644 index 0000000000000000000000000000000000000000..a7250a7bcfd572c49584110d213f9e9991a10c9f --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation1.scala @@ -0,0 +1,16 @@ +import leon.lang._ + +object NestedFunParamsMutation1 { + + def f(): Int = { + def g(a: Array[Int]): Unit = { + require(a.length > 0) + a(0) = 10 + } + + val a = Array(1,2,3,4) + g(a) + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation2.scala b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation2.scala new file mode 100644 index 0000000000000000000000000000000000000000..799a87c6e9e70bf6ef89bfc1fb7a6e116adb7feb --- /dev/null +++ b/src/test/resources/regression/verification/xlang/valid/NestedFunParamsMutation2.scala @@ -0,0 +1,21 @@ +import leon.lang._ + +object NestedFunParamsMutation2 { + + def f(): Int = { + def g(a: Array[Int]): Unit = { + require(a.length > 0) + a(0) = 10 + } + + def h(a: Array[Int]): Unit = { + require(a.length > 0) + g(a) + } + + val a = Array(1,2,3,4) + h(a) + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/resources/regression/frontends/error/xlang/Array1.scala b/src/test/resources/regression/xlang/error/Array1.scala similarity index 100% rename from src/test/resources/regression/frontends/error/xlang/Array1.scala rename to src/test/resources/regression/xlang/error/Array1.scala diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing1.scala b/src/test/resources/regression/xlang/error/ArrayAliasing1.scala new file mode 100644 index 0000000000000000000000000000000000000000..30b1652dac16f1f8a9c7b36d83d9a0a52811e3c6 --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing1.scala @@ -0,0 +1,13 @@ +import leon.lang._ + +object ArrayAliasing1 { + + def f1(): BigInt = { + val a = Array.fill(10)(BigInt(0)) + val b = a + b(0) = 10 + a(0) + } ensuring(_ == 10) + +} + diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing10.scala b/src/test/resources/regression/xlang/error/ArrayAliasing10.scala new file mode 100644 index 0000000000000000000000000000000000000000..05737b03d9816be72cf52f4913f57a482abf76dd --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing10.scala @@ -0,0 +1,19 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object ArrayAliasing10 { + + def foo(): Int = { + val a = Array.fill(5)(0) + + def rec(): Array[Int] = { + + def nestedRec(): Array[Int] = { + a + } + nestedRec() + } + val b = rec() + b(0) + } + +} diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing2.scala b/src/test/resources/regression/xlang/error/ArrayAliasing2.scala new file mode 100644 index 0000000000000000000000000000000000000000..4e906865a8848aaa00150c67de16f6b32136c64a --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing2.scala @@ -0,0 +1,11 @@ +import leon.lang._ + +object ArrayAliasing2 { + + def f1(a: Array[BigInt]): BigInt = { + val b = a + b(0) = 10 + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing3.scala b/src/test/resources/regression/xlang/error/ArrayAliasing3.scala new file mode 100644 index 0000000000000000000000000000000000000000..0398fc37b9dc2028e1535878ec377bef1620dd88 --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing3.scala @@ -0,0 +1,11 @@ +import leon.lang._ + +object ArrayAliasing3 { + + def f1(a: Array[BigInt], b: Boolean): BigInt = { + val c = if(b) a else Array[BigInt](1,2,3,4,5) + c(0) = 10 + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing4.scala b/src/test/resources/regression/xlang/error/ArrayAliasing4.scala new file mode 100644 index 0000000000000000000000000000000000000000..2632782c39e853744744b66309ef10342bee386b --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing4.scala @@ -0,0 +1,11 @@ +import leon.lang._ + +object ArrayAliasing4 { + + def f1(a: Array[BigInt]): Array[BigInt] = { + require(a.length > 0) + a(0) = 10 + a + } ensuring(res => res(0) == 10) + +} diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing5.scala b/src/test/resources/regression/xlang/error/ArrayAliasing5.scala new file mode 100644 index 0000000000000000000000000000000000000000..b9363d1ab5a627df29e6a0f0018c73850dcbb529 --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing5.scala @@ -0,0 +1,18 @@ +import leon.lang._ + +object ArrayAliasing5 { + + + def f1(a: Array[BigInt], b: Array[BigInt]): Unit = { + require(a.length > 0 && b.length > 0) + a(0) = 10 + b(0) = 20 + } ensuring(_ => a(0) == 10 && b(0) == 20) + + + def callWithAliases(): Unit = { + val a = Array[BigInt](0,0,0,0) + f1(a, a) + } + +} diff --git a/src/test/resources/regression/frontends/error/xlang/Array8.scala b/src/test/resources/regression/xlang/error/ArrayAliasing6.scala similarity index 80% rename from src/test/resources/regression/frontends/error/xlang/Array8.scala rename to src/test/resources/regression/xlang/error/ArrayAliasing6.scala index bbe5bd5fd92b0f4f9662379693d06924bdaf5461..963a134bf71da7a625252411854d73272f56d574 100644 --- a/src/test/resources/regression/frontends/error/xlang/Array8.scala +++ b/src/test/resources/regression/xlang/error/ArrayAliasing6.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -object Array8 { +object ArrayAliasing6 { def foo(a: Array[Int]): Array[Int] = { a diff --git a/src/test/resources/regression/xlang/error/ArrayAliasing7.scala b/src/test/resources/regression/xlang/error/ArrayAliasing7.scala new file mode 100644 index 0000000000000000000000000000000000000000..21bc94502327b334f2e4e5887d4a7286731c78b7 --- /dev/null +++ b/src/test/resources/regression/xlang/error/ArrayAliasing7.scala @@ -0,0 +1,10 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +object ArrayAliasing7 { + + def foo(a: Array[Int]): Array[Int] = { + val b = a + b + } + +} diff --git a/src/test/resources/regression/frontends/error/xlang/Array9.scala b/src/test/resources/regression/xlang/error/ArrayAliasing8.scala similarity index 86% rename from src/test/resources/regression/frontends/error/xlang/Array9.scala rename to src/test/resources/regression/xlang/error/ArrayAliasing8.scala index fbc7dd7376e0966df5ed6eb93bafa7427aeab9e8..e7c27cc9cebf657033da4c07936ce16a510075a0 100644 --- a/src/test/resources/regression/frontends/error/xlang/Array9.scala +++ b/src/test/resources/regression/xlang/error/ArrayAliasing8.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -object Array9 { +object ArrayAliasing8 { def foo(a: Array[Int]): Int = { def rec(): Array[Int] = { diff --git a/src/test/resources/regression/frontends/error/xlang/Array10.scala b/src/test/resources/regression/xlang/error/ArrayAliasing9.scala similarity index 87% rename from src/test/resources/regression/frontends/error/xlang/Array10.scala rename to src/test/resources/regression/xlang/error/ArrayAliasing9.scala index 563cdacdf7e0e66ff56eec571fae4d3e3bbe10be..c84d29c3fbb4866100173b7ac0e7b4d0da9a1e57 100644 --- a/src/test/resources/regression/frontends/error/xlang/Array10.scala +++ b/src/test/resources/regression/xlang/error/ArrayAliasing9.scala @@ -1,6 +1,6 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -object Array10 { +object ArrayAliasing9 { def foo(): Int = { val a = Array.fill(5)(0) diff --git a/src/test/resources/regression/xlang/error/NestedFunctionAliasing1.scala b/src/test/resources/regression/xlang/error/NestedFunctionAliasing1.scala new file mode 100644 index 0000000000000000000000000000000000000000..12feace5413c23e688ac194192482120793b4e24 --- /dev/null +++ b/src/test/resources/regression/xlang/error/NestedFunctionAliasing1.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object NestedFunctinAliasing1 { + + def f(): Int = { + val a = Array(1,2,3,4) + + def g(b: Array[Int]): Unit = { + require(b.length > 0 && a.length > 0) + b(0) = 10 + a(0) = 17 + } ensuring(_ => b(0) == 10) + + g(a) + a(0) + } ensuring(_ == 10) +} diff --git a/src/test/resources/regression/xlang/error/NestedFunctionAliasing2.scala b/src/test/resources/regression/xlang/error/NestedFunctionAliasing2.scala new file mode 100644 index 0000000000000000000000000000000000000000..81a9b82b39fb47af105ef2e3122fe86a8b10dbb6 --- /dev/null +++ b/src/test/resources/regression/xlang/error/NestedFunctionAliasing2.scala @@ -0,0 +1,17 @@ +import leon.lang._ + +object NestedFunctinAliasing1 { + + def f(a: Array(1,2,3,4)): Int = { + + def g(b: Array[Int]): Unit = { + require(b.length > 0 && a.length > 0) + b(0) = 10 + a(0) = 17 + } ensuring(_ => b(0) == 10) + + g(a) + a(0) + } ensuring(_ == 10) + +} diff --git a/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala b/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala index 84b0087caee305b66e16d82ad2adcd9f4b10e06d..812cf88eac2508814b1218eeb85a787bcad2d4ff 100644 --- a/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala +++ b/src/test/scala/leon/integration/evaluators/EvaluatorSuite.scala @@ -369,28 +369,11 @@ class EvaluatorSuite extends LeonTestSuiteWithProgram with ExpressionsDSL { } test("Lambda functions") { implicit fix => - def checkLambda(e: Evaluator, in: Expr, out: PartialFunction[Expr, Boolean]) { - val res = eval(e, in).success - if (!out.isDefinedAt(res) || !out(res)) - throw new AssertionError(s"Evaluation of '$in' with evaluator '${e.name}' produced invalid '$res'.") - } - - val ONE = bi(1) - val TWO = bi(2) - for(e <- allEvaluators) { - checkLambda(e, fcall("Lambda.foo1")(), { - case Lambda(Seq(vd), Variable(id)) if vd.id == id => true - }) - checkLambda(e, fcall("Lambda.foo2")(), { - case Lambda(Seq(vd), Plus(ONE, Variable(id))) if vd.id == id => true - }) - checkLambda(e, fcall("Lambda.foo3")(), { - case Lambda(Seq(vx, vy), Plus(Plus(Variable(x), ONE), Plus(Variable(y), TWO))) if vx.id == x && vy.id == y => true - }) - checkLambda(e, fcall("Lambda.foo4")(TWO), { - case Lambda(Seq(vd), Plus(Variable(id), TWO)) if vd.id == id => true - }) + eval(e, Application(fcall("Lambda.foo1")(), Seq(bi(1)))) === bi(1) + eval(e, Application(fcall("Lambda.foo2")(), Seq(bi(1)))) === bi(2) + eval(e, Application(fcall("Lambda.foo3")(), Seq(bi(1), bi(2)))) === bi(6) + eval(e, Application(fcall("Lambda.foo4")(bi(2)), Seq(bi(1)))) === bi(3) } } diff --git a/src/test/scala/leon/integration/purescala/ExprOpsSuite.scala b/src/test/scala/leon/integration/purescala/ExprOpsSuite.scala index 17260f7e5263cce8e61e57f5e8b684d1a736232f..fe548c1b013e3cbef8e5a7024286c1e9e8781e0f 100644 --- a/src/test/scala/leon/integration/purescala/ExprOpsSuite.scala +++ b/src/test/scala/leon/integration/purescala/ExprOpsSuite.scala @@ -7,6 +7,7 @@ import leon.test._ import leon.purescala.Constructors._ import leon.purescala.Expressions._ import leon.purescala.ExprOps._ +import leon.purescala.Definitions._ import leon.purescala.Common._ class ExprOpsSuite extends LeonTestSuiteWithProgram with helpers.ExpressionsDSL { @@ -101,5 +102,35 @@ class ExprOpsSuite extends LeonTestSuiteWithProgram with helpers.ExpressionsDSL assert(isSubtypeOf(simplestValue(act).getType, act)) assert(simplestValue(cct).getType == cct) } + + test("canBeHomomorphic") { implicit fix => + import leon.purescala.ExprOps.canBeHomomorphic + import leon.purescala.Types._ + import leon.purescala.Definitions._ + val d = FreshIdentifier("d", IntegerType) + val x = FreshIdentifier("x", IntegerType) + val y = FreshIdentifier("y", IntegerType) + assert(canBeHomomorphic(Variable(d), Variable(x)).isEmpty) + val l1 = Lambda(Seq(ValDef(x)), Variable(x)) + val l2 = Lambda(Seq(ValDef(y)), Variable(y)) + assert(canBeHomomorphic(l1, l2).nonEmpty) + val fType = FunctionType(Seq(IntegerType), IntegerType) + val f = FreshIdentifier("f", + FunctionType(Seq(IntegerType, fType, fType), IntegerType)) + val farg1 = FreshIdentifier("arg1", IntegerType) + val farg2 = FreshIdentifier("arg2", fType) + val farg3 = FreshIdentifier("arg3", fType) + + val fdef = new FunDef(f, Seq(), Seq(ValDef(farg1), ValDef(farg2), ValDef(farg3)), IntegerType) + + // Captured variables should be silent, even if they share the same identifier in two places of the code. + assert(canBeHomomorphic( + FunctionInvocation(fdef.typed, Seq(Variable(d), l1, l2)), + FunctionInvocation(fdef.typed, Seq(Variable(d), l1, l1))).nonEmpty) + + assert(canBeHomomorphic( + StringLiteral("1"), + StringLiteral("2")).isEmpty) + } } diff --git a/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala b/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..4ff5ccc71d159fd76ec1e48d99bce4be4dbfb125 --- /dev/null +++ b/src/test/scala/leon/integration/solvers/QuantifierSolverSuite.scala @@ -0,0 +1,145 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.integration.solvers + +import leon.test._ +import leon.purescala.Common._ +import leon.purescala.Definitions._ +import leon.purescala.Expressions._ +import leon.purescala.Constructors._ +import leon.purescala.Types._ +import leon.LeonContext +import leon.LeonOption + +import leon.solvers._ +import leon.solvers.smtlib._ +import leon.solvers.combinators._ +import leon.solvers.z3._ + +class QuantifierSolverSuite extends LeonTestSuiteWithProgram { + + val sources = List() + + override val leonOpts = List("checkmodels") + + val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { + (if (SolverFactory.hasNativeZ3) Seq( + ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) + ) else Nil) ++ + (if (SolverFactory.hasZ3) Seq( + ("smt-z3", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) + ) else Nil) ++ + (if (SolverFactory.hasCVC4) Seq( + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) + ) else Nil) + } + + val f1: Identifier = FreshIdentifier("f1", FunctionType(Seq(IntegerType, IntegerType), IntegerType)) + val A = TypeParameter.fresh("A") + val f2: Identifier = FreshIdentifier("f2", FunctionType(Seq(A, A), A)) + + def app(f: Expr, args: Expr*): Expr = Application(f, args) + def bi(i: Int): Expr = InfiniteIntegerLiteral(i) + + def associative(f: Expr): Expr = { + val FunctionType(Seq(t1, t2), _) = f.getType + assert(t1 == t2, "Can't specify associativity for type " + f.getType) + + val ids @ Seq(x, y, z) = Seq("x", "y", "z").map(name => FreshIdentifier(name, t1, true)) + Forall(ids.map(ValDef(_)), Equals( + app(f, app(f, Variable(x), Variable(y)), Variable(z)), + app(f, Variable(x), app(f, Variable(y), Variable(z))))) + } + + def commutative(f: Expr): Expr = { + val FunctionType(Seq(t1, t2), _) = f.getType + assert(t1 == t2, "Can't specify commutativity for type " + f.getType) + + val ids @ Seq(x, y) = Seq("x", "y").map(name => FreshIdentifier(name, t1, true)) + Forall(ids.map(ValDef(_)), Equals( + app(f, Variable(x), Variable(y)), app(f, Variable(y), Variable(x)))) + } + + def idempotent(f: Expr): Expr = { + val FunctionType(Seq(t1, t2), _) = f.getType + assert(t1 == t2, "Can't specify idempotency for type " + f.getType) + + val ids @ Seq(x, y, z) = Seq("x", "y", "z").map(name => FreshIdentifier(name, t1, true)) + Forall(ids.map(ValDef(_)), Equals( + app(f, Variable(x), Variable(y)), + app(f, Variable(x), app(f, Variable(x), Variable(y))))) + } + + def rotative(f: Expr): Expr = { + val FunctionType(Seq(t1, t2), _) = f.getType + assert(t1 == t2, "Can't specify associativity for type " + f.getType) + + val ids @ Seq(x, y, z) = Seq("x", "y", "z").map(name => FreshIdentifier(name, t1, true)) + Forall(ids.map(ValDef(_)), Equals( + app(f, app(f, Variable(x), Variable(y)), Variable(z)), + app(f, app(f, Variable(y), Variable(z)), Variable(x)))) + } + + val satisfiable = List( + "paper 1" -> and(associative(Variable(f1)), + Not(Equals(app(Variable(f1), app(Variable(f1), bi(1), bi(2)), bi(3)), + app(Variable(f1), bi(1), app(Variable(f1), bi(2), bi(2)))))), + "paper 2" -> and(commutative(Variable(f1)), idempotent(Variable(f1)), + Not(Equals(app(Variable(f1), app(Variable(f1), bi(1), bi(2)), bi(2)), + app(Variable(f1), bi(1), app(Variable(f1), bi(2), app(Variable(f1), bi(1), bi(3))))))), + "assoc not comm int" -> and(associative(Variable(f1)), Not(commutative(Variable(f1)))), + "assoc not comm generic" -> and(associative(Variable(f2)), Not(commutative(Variable(f2)))), + "comm not assoc int" -> and(commutative(Variable(f1)), Not(associative(Variable(f1)))), + "comm not assoc generic" -> and(commutative(Variable(f2)), Not(associative(Variable(f2)))) + ) + + val unification: Expr = { + val ids @ Seq(x, y) = Seq("x", "y").map(name => FreshIdentifier(name, IntegerType, true)) + Forall(ids.map(ValDef(_)), Not(Equals(app(Variable(f1), Variable(x), Variable(y)), app(Variable(f1), Variable(y), Variable(x))))) + } + + val unsatisfiable = List( + "comm + rotate = assoc int" -> and(commutative(Variable(f1)), rotative(Variable(f1)), Not(associative(Variable(f1)))), + "comm + rotate = assoc generic" -> and(commutative(Variable(f2)), rotative(Variable(f2)), Not(associative(Variable(f2)))), + "unification" -> unification + ) + + def checkSolver(solver: Solver, expr: Expr, sat: Boolean)(implicit fix: (LeonContext, Program)): Unit = { + try { + solver.assertCnstr(expr) + solver.check match { + case Some(true) if sat && fix._1.reporter.warningCount > 0 => + fail(s"Solver $solver - Constraint ${expr.asString} doesn't guarantee model validity!?") + case Some(true) if sat => // checkmodels ensures validity + case Some(false) if !sat => // we were looking for unsat + case res => fail(s"Solver $solver - Constraint ${expr.asString} has result $res!?") + } + } finally { + solver.free() + } + } + + for ((sname, sf) <- getFactories; (ename, expr) <- satisfiable) { + test(s"Satisfiable quantified formula $ename in $sname") { implicit fix => + val (ctx, pgm) = fix + val solver = sf(ctx, pgm) + checkSolver(solver, expr, true) + } + + test(s"Satisfiable quantified formula $ename in $sname with partial models") { implicit fix => + val (ctx, pgm) = fix + val newCtx = ctx.copy(options = ctx.options.filter(_ != UnrollingProcedure.optPartialModels) :+ + LeonOption(UnrollingProcedure.optPartialModels)(true)) + val solver = sf(newCtx, pgm) + checkSolver(solver, expr, true) + } + } + + for ((sname, sf) <- getFactories; (ename, expr) <- unsatisfiable) { + test(s"Unsatisfiable quantified formula $ename in $sname") { implicit fix => + val (ctx, pgm) = fix + val solver = sf(ctx, pgm) + checkSolver(solver, expr, false) + } + } +} diff --git a/src/test/scala/leon/integration/solvers/SolversSuite.scala b/src/test/scala/leon/integration/solvers/SolversSuite.scala index d568e471f08eb4cf1a558675419daabd9e9c940a..d2af2030bded5ae60375180661107346164ca0c6 100644 --- a/src/test/scala/leon/integration/solvers/SolversSuite.scala +++ b/src/test/scala/leon/integration/solvers/SolversSuite.scala @@ -22,67 +22,67 @@ class SolversSuite extends LeonTestSuiteWithProgram { val getFactories: Seq[(String, (LeonContext, Program) => Solver)] = { (if (SolverFactory.hasNativeZ3) Seq( - ("fairz3", (ctx: LeonContext, pgm: Program) => new FairZ3Solver(ctx, pgm)) + ("fairz3", (ctx: LeonContext, pgm: Program) => new Z3StringFairZ3Solver(ctx, pgm)) ) else Nil) ++ (if (SolverFactory.hasZ3) Seq( - ("smt-z3", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBZ3Solver(ctx, pgm))) + ("smt-z3", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBZ3Solver(ctx, pgm))) ) else Nil) ++ (if (SolverFactory.hasCVC4) Seq( - ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new UnrollingSolver(ctx, pgm, new SMTLIBCVC4Solver(ctx, pgm))) + ("smt-cvc4", (ctx: LeonContext, pgm: Program) => new Z3StringUnrollingSolver(ctx, pgm, pgm => new SMTLIBCVC4Solver(ctx, pgm))) ) else Nil) } - val types = Seq( - BooleanType, - UnitType, - CharType, + val types = Seq( + BooleanType, + UnitType, + CharType, RealType, - IntegerType, - Int32Type, - StringType, - TypeParameter.fresh("T"), - SetType(IntegerType), - MapType(IntegerType, IntegerType), + IntegerType, + Int32Type, + StringType, + TypeParameter.fresh("T"), + SetType(IntegerType), + MapType(IntegerType, IntegerType), FunctionType(Seq(IntegerType), IntegerType), - TupleType(Seq(IntegerType, BooleanType, Int32Type)) - ) + TupleType(Seq(IntegerType, BooleanType, Int32Type)) + ) - val vs = types.map(FreshIdentifier("v", _).toVariable) + val vs = types.map(FreshIdentifier("v", _).toVariable) // We need to make sure models are not co-finite val cnstrs = vs.map(v => v.getType match { - case UnitType => - Equals(v, simplestValue(v.getType)) - case SetType(base) => - Not(ElementOfSet(simplestValue(base), v)) - case MapType(from, to) => - Not(Equals(MapApply(v, simplestValue(from)), simplestValue(to))) + case UnitType => + Equals(v, simplestValue(v.getType)) + case SetType(base) => + Not(ElementOfSet(simplestValue(base), v)) + case MapType(from, to) => + Not(Equals(MapApply(v, simplestValue(from)), simplestValue(to))) case FunctionType(froms, to) => Not(Equals(Application(v, froms.map(simplestValue)), simplestValue(to))) - case _ => - not(Equals(v, simplestValue(v.getType))) + case _ => + not(Equals(v, simplestValue(v.getType))) }) def checkSolver(solver: Solver, vs: Set[Variable], cnstr: Expr)(implicit fix: (LeonContext, Program)): Unit = { - try { - solver.assertCnstr(cnstr) + try { + solver.assertCnstr(cnstr) - solver.check match { - case Some(true) => - val model = solver.getModel - for (v <- vs) { - if (model.isDefinedAt(v.id)) { - assert(model(v.id).getType === v.getType, s"Solver $solver - Extracting value of type "+v.getType) - } else { - fail(s"Solver $solver - Model does not contain "+v.id.uniqueName+" of type "+v.getType) - } + solver.check match { + case Some(true) => + val model = solver.getModel + for (v <- vs) { + if (model.isDefinedAt(v.id)) { + assert(model(v.id).getType === v.getType, s"Solver $solver - Extracting value of type "+v.getType) + } else { + fail(s"Solver $solver - Model does not contain "+v.id.uniqueName+" of type "+v.getType) } - case _ => - fail(s"Solver $solver - Constraint "+cnstr.asString+" is unsat!?") - } - } finally { - solver.free() + } + case _ => + fail(s"Solver $solver - Constraint "+cnstr.asString+" is unsat!? Solver was "+solver.getClass) } + } finally { + solver.free() + } } // Check that we correctly extract several types from solver models @@ -99,6 +99,6 @@ class SolversSuite extends LeonTestSuiteWithProgram { for ((v,cnstr) <- vs zip cnstrs) { val solver = new EnumerationSolver(fix._1, fix._2) checkSolver(solver, Set(v), cnstr) -} + } } } diff --git a/src/test/scala/leon/integration/solvers/StringRenderSuite.scala b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..cd6bd2ee75a9594d2c0be081db265e75c7a45620 --- /dev/null +++ b/src/test/scala/leon/integration/solvers/StringRenderSuite.scala @@ -0,0 +1,431 @@ +package leon.integration.solvers + + +import org.scalatest.FunSuite +import org.scalatest.Matchers +import leon.test.helpers.ExpressionsDSL +import leon.solvers.string.StringSolver +import leon.purescala.Common.FreshIdentifier +import leon.purescala.Common.Identifier +import leon.purescala.Expressions._ +import leon.purescala.Definitions._ +import leon.purescala.DefOps +import leon.purescala.ExprOps +import leon.purescala.Types._ +import leon.purescala.TypeOps +import leon.purescala.Constructors._ +import leon.synthesis.rules.{StringRender, TypedTemplateGenerator} +import scala.collection.mutable.{HashMap => MMap} +import scala.concurrent.Future +import scala.concurrent.ExecutionContext.Implicits.global +import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.time.SpanSugar._ +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.time.SpanSugar._ +import leon.purescala.Types.Int32Type +import leon.test.LeonTestSuiteWithProgram +import leon.synthesis.SourceInfo +import leon.synthesis.CostModels +import leon.synthesis.graph.SimpleSearch +import leon.synthesis.graph.AndNode +import leon.synthesis.SearchContext +import leon.synthesis.Synthesizer +import leon.synthesis.SynthesisSettings +import leon.synthesis.RuleApplication +import leon.synthesis.RuleClosed +import leon.evaluators._ +import leon.LeonContext +import leon.synthesis.rules.DetupleInput +import leon.synthesis.Rules +import leon.solvers.ModelBuilder +import scala.language.implicitConversions + +class StringRenderSuite extends LeonTestSuiteWithProgram with Matchers with ScalaFutures { + test("Template Generator simple"){ case (ctx: LeonContext, program: Program) => + val TTG = new TypedTemplateGenerator(IntegerType) {} + val e = TTG(hole => Plus(hole, hole)) + val (expr, vars) = e.instantiateWithVars + vars should have length 2 + expr shouldEqual Plus(Variable(vars(0)), Variable(vars(1))) + + val f = TTG.nested(hole => (Minus(hole, expr), vars)) + val (expr2, vars2) = f.instantiateWithVars + vars2 should have length 3 + vars2(0) shouldEqual vars(0) + vars2(1) shouldEqual vars(1) + expr2 shouldEqual Minus(Variable(vars2(2)), expr) + } + + trait withSymbols { + val x = FreshIdentifier("x", StringType) + val y = FreshIdentifier("y", StringType) + val i = FreshIdentifier("i", IntegerType) + val f = FreshIdentifier("f", FunctionType(Seq(IntegerType), StringType)) + val fd = new FunDef(f, Nil, Seq(ValDef(i)), StringType) + val fdi = FunctionInvocation(fd.typed, Seq(Variable(i))) + } + + test("toEquations working"){ case (ctx: LeonContext, program: Program) => + import StringRender._ + new withSymbols { + val lhs = RegularStringFormToken(Left("abc"))::RegularStringFormToken(Right(x))::OtherStringFormToken(fdi)::Nil + val rhs = RegularStringChunk("abcdef")::OtherStringChunk(fdi)::Nil + val p = toEquations(lhs, rhs) + p should not be 'empty + p.get should have length 1 + } + } + + test("toEquations working 2"){ case (ctx: LeonContext, program: Program) => + import StringRender._ + new withSymbols { + val lhs = RegularStringFormToken(Left("abc"))::RegularStringFormToken(Right(x))::OtherStringFormToken(fdi)::RegularStringFormToken(Right(y))::Nil + val rhs = RegularStringChunk("abcdef")::OtherStringChunk(fdi)::RegularStringChunk("123")::Nil + val p = toEquations(lhs, rhs) + p should not be 'empty + p.get should have length 2 + } + } + + test("toEquations failing"){ case (ctx: LeonContext, program: Program) => + import StringRender._ + new withSymbols { + val lhs = RegularStringFormToken(Left("abc"))::RegularStringFormToken(Right(x))::RegularStringFormToken(Right(y))::Nil + val rhs = RegularStringChunk("abcdef")::OtherStringChunk(fdi)::RegularStringChunk("123")::Nil + val p = toEquations(lhs, rhs) + p should be ('empty) + } + } + + def applyStringRenderOn(functionName: String): (FunDef, Program) = { + val ci = synthesisInfos(functionName) + val search = new SimpleSearch(ctx, ci, ci.problem, CostModels.default, Some(200)) + val synth = new Synthesizer(ctx, program, ci, SynthesisSettings(rules = Seq(StringRender))) + val orNode = search.g.root + if (!orNode.isExpanded) { + val hctx = SearchContext(synth.sctx, synth.ci, orNode, search) + orNode.expand(hctx) + } + val andNodes = orNode.descendants.collect { + case n: AndNode => + n + } + + val rulesApps = (for ((t, i) <- andNodes.zipWithIndex) yield { + val status = if (t.isDeadEnd) { + "closed" + } else { + "open" + } + t.ri.asString -> i + }).toMap + rulesApps should contain key "String conversion" + + val rid = rulesApps("String conversion") + val path = List(rid) + + val solutions = (search.traversePath(path) match { + case Some(an: AndNode) => + val hctx = SearchContext(synth.sctx, synth.ci, an, search) + an.ri.apply(hctx) + case _ => throw new Exception("Was not an and node") + }) match { + case RuleClosed(solutions) => solutions + case _ => fail("no solution found") + } + solutions should not be 'empty + val newProgram = DefOps.addFunDefs(synth.program, solutions.head.defs, synth.sctx.functionContext) + val newFd = ci.fd.duplicate() + newFd.body = Some(solutions.head.term) + val (newProgram2, _) = DefOps.replaceFunDefs(newProgram)({ fd => + if(fd == ci.fd) { + Some(newFd) + } else None + }, { (fi: FunctionInvocation, fd: FunDef) => + Some(FunctionInvocation(fd.typed, fi.args)) + }) + (newFd, newProgram2) + } + + def getFunctionArguments(functionName: String) = { + program.lookupFunDef("StringRenderSuite." + functionName) match { + case Some(fd) => fd.params + case None => + fail("Could not find function " + functionName + " in sources. Other functions:" + program.definedFunctions.map(_.id.name).sorted) + } + } + + implicit class StringUtils(s: String) { + def replaceByExample: String = + s.replaceAll("\\((\\w+):(.*) by example", "\\($1:$2 ensuring { (res: String) => ($1, res) passes { case _ if false => \"\" } }") + } + + val sources = List(""" + |import leon.lang._ + |import leon.collection._ + |import leon.lang.synthesis._ + |import leon.annotation._ + | + |object StringRenderSuite { + | def literalSynthesis(i: Int): String = ??? ensuring { (res: String) => (i, res) passes { case 42 => ":42." } } + | + | def booleanSynthesis(b: Boolean): String = ??? ensuring { (res: String) => (b, res) passes { case true => "yes" case false => "no" } } + | def booleanSynthesis2(b: Boolean): String = ??? ensuring { (res: String) => (b, res) passes { case true => "B: true" } } + | //def stringEscape(s: String): String = ??? ensuring { (res: String) => (s, res) passes { case "\n" => "done...\\n" } } + | + | case class Dummy(s: String) + | def dummyToString(d: Dummy): String = "{" + d.s + "}" + | + | case class Dummy2(s: String) + | def dummy2ToString(d: Dummy2): String = "<" + d.s + ">" + | + | case class Config(i: BigInt, t: (Int, String)) + | + | def configToString(c: Config): String = ??? ensuring { (res: String) => (c, res) passes { case Config(BigInt(1), (2, "3")) => "1: 2 -> 3" } } + | def configToString2(c: Config): String = ??? ensuring { (res: String) => (c, res) passes { case Config(BigInt(1), (2, "3")) => "3: 1 -> 2" } } + | + | sealed abstract class Tree + | case class Knot(left: Tree, right: Tree) extends Tree + | case class Bud(v: String) extends Tree + | + | def treeToString(t: Tree): String = ???[String] ensuring { + | (res : String) => (t, res) passes { + | case Knot(Knot(Bud("17"), Bud("14")), Bud("23")) => + | "<<17Y14>Y23>" + | case Bud("foo") => + | "foo" + | case Knot(Bud("foo"), Bud("foo")) => + | "<fooYfoo>" + | case Knot(Bud("foo"), Knot(Bud("bar"), Bud("foo"))) => + | "<fooY<barYfoo>>" + | } + | } + | + | sealed abstract class BList[T] + | case class BCons[T](head: (T, T), tail: BList[T]) extends BList[T] + | case class BNil[T]() extends BList[T] + | + | def bListToString[T](b: BList[T], f: T => String) = ???[String] ensuring + | { (res: String) => (b, res) passes { case BNil() => "[]" case BCons(a, BCons(b, BCons(c, BNil()))) => "[" + f(a._1) + "-" + f(a._2) + ", " + f(b._1) + "-" + f(b._2) + ", " + f(c._1) + "-" + f(c._2) + "]" }} + | + | // Handling one rendering function at a time. + | case class BConfig(flags: BList[Dummy]) + | def bConfigToString(b: BConfig): String = ???[String] ensuring + | { (res: String) => (b, res) passes { case BConfig(BNil()) => "Config" + bListToString[Dummy](BNil(), (x: Dummy) => dummyToString(x)) } } + | + | def customListToString[T](l: List[T], f: T => String): String = ???[String] ensuring + | { (res: String) => (l, res) passes { case _ if false => "" } } + | + | // Handling multiple rendering functions at the same time. + | case class DConfig(dums: List[Dummy], dums2: List[Dummy2]) + | def dConfigToString(dc: DConfig): String = ???[String] ensuring + | { (res: String) => (dc, res) passes { + | case DConfig(Nil(), Nil()) => + | "Solution:\n " + customListToString[Dummy](List[Dummy](), (x : Dummy) => dummyToString(x)) + "\n " + customListToString[Dummy2](List[Dummy2](), (x: Dummy2) => dummy2ToString(x)) } } + | + | case class Node(tag: String, l: List[Edge]) + | case class Edge(start: Node, end: Node) + | + | def nodeToString(n: Node): String = ??? by example + | def edgeToString(e: Edge): String = ??? by example + | def listEdgeToString(l: List[Edge]): String = ??? by example + |} + """.stripMargin.replaceByExample) + implicit val (ctx, program) = getFixture() + + val synthesisInfos = SourceInfo.extractFromProgram(ctx, program).map(si => si.fd.id.name -> si ).toMap + + def synthesizeAndTest(functionName: String, tests: (Seq[Expr], String)*) { + val (fd, program) = applyStringRenderOn(functionName) + val when = new DefaultEvaluator(ctx, program) + val args = getFunctionArguments(functionName) + for((in, out) <- tests) { + val expr = functionInvocation(fd, in) + when.eval(expr) match { + case EvaluationResults.Successful(value) => value shouldEqual StringLiteral(out) + case EvaluationResults.EvaluatorError(msg) => fail(/*program + "\n" + */msg) + case EvaluationResults.RuntimeError(msg) => fail(/*program + "\n" + */"Runtime: " + msg) + } + } + } + def synthesizeAndAbstractTest(functionName: String)(tests: (FunDef, Program) => Seq[(Seq[Expr], Expr)]) { + val (fd, program) = applyStringRenderOn(functionName) + val when_abstract = new AbstractEvaluator(ctx, program) + val args = getFunctionArguments(functionName) + for((in, out) <- tests(fd, program)) { + val expr = functionInvocation(fd, in) + when_abstract.eval(expr) match { + case EvaluationResults.Successful(value) => val m = ExprOps.canBeHomomorphic(value._1, out) + assert(m.nonEmpty, value._1 + " was not homomorphic with " + out) + case EvaluationResults.EvaluatorError(msg) => fail(/*program + "\n" + */msg) + case EvaluationResults.RuntimeError(msg) => fail(/*program + "\n" + */"Runtime: " + msg) + } + } + } + abstract class CCBuilder(ccName: String, prefix: String = "StringRenderSuite.")(implicit program: Program) { + val caseClassName = prefix + ccName + def getType: TypeTree = program.lookupCaseClass(caseClassName).get.typed + def apply(s: Expr*): CaseClass = { + CaseClass(program.lookupCaseClass(caseClassName).get.typed, s.toSeq) + } + def apply(s: String): CaseClass = this.apply(StringLiteral(s)) + } + abstract class ParamCCBuilder(caseClassName: String, prefix: String = "StringRenderSuite.")(implicit program: Program) { + def apply(types: TypeTree*)(s: Expr*): CaseClass = { + CaseClass(program.lookupCaseClass(prefix+caseClassName).get.typed(types), + s.toSeq) + } + } + def method(fName: String)(implicit program: Program) = { + program.lookupFunDef("StringRenderSuite." + fName).get + } + abstract class paramMethod(fName: String)(implicit program: Program) { + val fd = program.lookupFunDef("StringRenderSuite." + fName).get + def apply(types: TypeTree*)(args: Expr*) = + FunctionInvocation(fd.typed(types), args) + } + // Mimics the file above, allows construction of expressions. + case class Constructors(program: Program) { + implicit val p = program + implicit def CCBuilderToType(c: CCBuilder): TypeTree = c.getType + object Knot extends CCBuilder("Knot") + object Bud extends CCBuilder("Bud") + object Dummy extends CCBuilder("Dummy") + object Dummy2 extends CCBuilder("Dummy2") + object Cons extends ParamCCBuilder("Cons", "leon.collection.") + object Nil extends ParamCCBuilder("Nil", "leon.collection.") + object List { + def apply(types: TypeTree*)(elems: Expr*): CaseClass = { + elems.toList match { + case collection.immutable.Nil => Nil(types: _*)() + case a::b => Cons(types: _*)(a, List(types: _*)(b: _*)) + } + } + } + + object BCons extends ParamCCBuilder("BCons") + object BNil extends ParamCCBuilder("BNil") + object BList { + def apply(types: TypeTree*)(elems: Expr*): CaseClass = { + elems.toList match { + case collection.immutable.Nil => BNil(types: _*)() + case a::b => BCons(types: _*)(a, BList(types: _*)(b: _*)) + } + } + def helper(types: TypeTree*)(elems: (Expr, Expr)*): CaseClass = { + this.apply(types: _*)(elems.map(x => tupleWrap(Seq(x._1, x._2))): _*) + } + } + object Config extends CCBuilder("Config") { + def apply(i: Int, b: (Int, String)): CaseClass = + this.apply(InfiniteIntegerLiteral(i), tupleWrap(Seq(IntLiteral(b._1), StringLiteral(b._2)))) + } + object BConfig extends CCBuilder("BConfig") + object DConfig extends CCBuilder("DConfig") + lazy val dummyToString = method("dummyToString") + lazy val dummy2ToString = method("dummy2ToString") + lazy val bListToString = method("bListToString") + object customListToString extends paramMethod("customListToString") + } + + test("Literal synthesis"){ case (ctx: LeonContext, program: Program) => + synthesizeAndTest("literalSynthesis", + Seq(IntLiteral(156)) -> ":156.", + Seq(IntLiteral(-5)) -> ":-5.") + } + + test("boolean Synthesis"){ case (ctx: LeonContext, program: Program) => + synthesizeAndTest("booleanSynthesis", + Seq(BooleanLiteral(true)) -> "yes", + Seq(BooleanLiteral(false)) -> "no") + } + + test("Boolean synthesis 2"){ case (ctx: LeonContext, program: Program) => + synthesizeAndTest("booleanSynthesis2", + Seq(BooleanLiteral(true)) -> "B: true", + Seq(BooleanLiteral(false)) -> "B: false") + } + + /*test("String escape synthesis"){ case (ctx: LeonContext, program: Program) => + synthesizeAndTest("stringEscape", + Seq(StringLiteral("abc")) -> "done...abc", + Seq(StringLiteral("\t")) -> "done...\\t") + + }*/ + + test("Case class synthesis"){ case (ctx: LeonContext, program: Program) => + val c = Constructors(program); import c._ + StringRender.enforceDefaultStringMethodsIfAvailable = false + synthesizeAndTest("configToString", + Seq(Config(4, (5, "foo"))) -> "4: 5 -> foo") + } + + test("Out of order synthesis"){ case (ctx: LeonContext, program: Program) => + val c = Constructors(program); import c._ + StringRender.enforceDefaultStringMethodsIfAvailable = false + synthesizeAndTest("configToString2", + Seq(Config(4, (5, "foo"))) -> "foo: 4 -> 5") + } + + test("Recursive case class synthesis"){ case (ctx: LeonContext, program: Program) => + val c = Constructors(program); import c._ + synthesizeAndTest("treeToString", + Seq(Knot(Knot(Bud("aa"), Bud("bb")), Knot(Bud("mm"), Bud("nn")))) -> + "<<aaYbb>Y<mmYnn>>") + } + + test("Abstract synthesis"){ case (ctx: LeonContext, program: Program) => + val c = Constructors(program); import c._ + val d = FreshIdentifier("d", Dummy) + + synthesizeAndTest("bListToString", + Seq(BList.helper(Dummy)( + (Dummy("a"), Dummy("b")), + (Dummy("c"), Dummy("d"))), + Lambda(Seq(ValDef(d)), FunctionInvocation(dummyToString.typed, Seq(Variable(d))))) + -> + "[{a}-{b}, {c}-{d}]") + + } + + + test("Pretty-printing using inferred not yet defined functions in argument"){ case (ctx: LeonContext, program: Program) => + StringRender.enforceDefaultStringMethodsIfAvailable = true + synthesizeAndAbstractTest("bConfigToString"){ (fd: FunDef, program: Program) => + val c = Constructors(program); import c._ + val arg = BList.helper(Dummy)((Dummy("a"), Dummy("b"))) + val d = FreshIdentifier("d", Dummy) + val lambdaDummyToString = Lambda(Seq(ValDef(d)), FunctionInvocation(dummyToString.typed, Seq(Variable(d)))) + val listDummyToString = functionInvocation(bListToString, Seq(arg, lambdaDummyToString)) + Seq(Seq(BConfig(arg)) -> + StringConcat(StringLiteral("Config"), listDummyToString)) + } + } + + test("Pretty-printing using an existing not yet defined parametrized function") { case (ctx: LeonContext, program: Program) => + StringRender.enforceDefaultStringMethodsIfAvailable = true + + synthesizeAndAbstractTest("dConfigToString"){ (fd: FunDef, program: Program) => + val c = Constructors(program); import c._ + + val listDummy1 = c.List(Dummy)(Dummy("a"), Dummy("b"), Dummy("c")) + val listDummy2 = c.List(Dummy2)(Dummy2("1"), Dummy2("2")) + val arg = DConfig(listDummy1, listDummy2) + + val d = FreshIdentifier("d", Dummy) + val lambdaDummyToString = Lambda(Seq(ValDef(d)), FunctionInvocation(dummyToString.typed, Seq(Variable(d)))) + val d2 = FreshIdentifier("d2", Dummy2) + val lambdaDummy2ToString = Lambda(Seq(ValDef(d2)), FunctionInvocation(dummy2ToString.typed, Seq(Variable(d2)))) + + Seq(Seq(arg) -> + StringConcat(StringConcat(StringConcat( + StringLiteral("Solution:\n "), + customListToString(Dummy)(listDummy1, lambdaDummyToString)), + StringLiteral("\n ")), + customListToString(Dummy2)(listDummy2, lambdaDummy2ToString))) + } + } +} \ No newline at end of file diff --git a/src/test/scala/leon/regression/frontends/FrontEndsSuite.scala b/src/test/scala/leon/regression/frontends/FrontEndsSuite.scala index b9bfccd36bc2abcc3479c9d3a703df0cdfe16ab7..a0e52aa0b083c04a5a3e1141d2f3287a1995611e 100644 --- a/src/test/scala/leon/regression/frontends/FrontEndsSuite.scala +++ b/src/test/scala/leon/regression/frontends/FrontEndsSuite.scala @@ -36,7 +36,6 @@ class FrontEndsSuite extends LeonRegressionSuite { } val pipeNormal = xlang.NoXLangFeaturesChecking andThen NoopPhase() // redundant NoopPhase to trigger throwing error between phases - val pipeX = NoopPhase[Program]() val baseDir = "regression/frontends/" forEachFileIn(baseDir+"passing/") { f => @@ -45,8 +44,5 @@ class FrontEndsSuite extends LeonRegressionSuite { forEachFileIn(baseDir+"error/simple/") { f => testFrontend(f, pipeNormal, true) } - forEachFileIn(baseDir+"error/xlang/") { f => - testFrontend(f, pipeX, true) - } } diff --git a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala index c70df950e0768ca71dd7bba013ac04cd7404edab..ca2b4a3c98107c73c9244486200dd0a44a348cb2 100644 --- a/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala +++ b/src/test/scala/leon/regression/synthesis/SynthesisSuite.scala @@ -251,6 +251,7 @@ object SortedList { case "insertSorted" => Decomp("Assert isSorted(in1)", List( Decomp("ADT Split on 'in1'", List( + Close("CEGIS"), Decomp("Ineq. Split on 'head*' and 'v*'", List( Close("CEGIS"), Decomp("Equivalent Inputs *", List( @@ -259,8 +260,7 @@ object SortedList { )) )), Close("CEGIS") - )), - Close("CEGIS") + )) )) )) } diff --git a/src/test/scala/leon/regression/termination/TerminationSuite.scala b/src/test/scala/leon/regression/termination/TerminationSuite.scala index 6c2df0820a2b0f0c19ee393c12e1d70b36bbf233..215c72272673c5c6e9bbd5173b2b334ed3fc753c 100644 --- a/src/test/scala/leon/regression/termination/TerminationSuite.scala +++ b/src/test/scala/leon/regression/termination/TerminationSuite.scala @@ -34,8 +34,7 @@ class TerminationSuite extends LeonRegressionSuite { val ignored = List( "verification/purescala/valid/MergeSort.scala", - "verification/purescala/valid/Nested14.scala", - "verification/purescala/valid/InductiveQuantification.scala" + "verification/purescala/valid/Nested14.scala" ) val t = if (ignored.exists(displayName.replaceAll("\\\\","/").endsWith)) { diff --git a/src/test/scala/leon/regression/verification/VerificationSuite.scala b/src/test/scala/leon/regression/verification/VerificationSuite.scala index f2ae97880694baac498b94570f81beb1c21c422f..446d06675cdb9b3eb1c7095137ab95e8a928c399 100644 --- a/src/test/scala/leon/regression/verification/VerificationSuite.scala +++ b/src/test/scala/leon/regression/verification/VerificationSuite.scala @@ -41,7 +41,7 @@ trait VerificationSuite extends LeonRegressionSuite { VerificationPhase andThen (if (desugarXLang) FixReportLabels else NoopPhase[VerificationReport]) - val ctx = createLeonContext(files:_*) + val ctx = createLeonContext(files:_*).copy(reporter = new TestErrorReporter) try { val (_, ast) = extraction.run(ctx, files) diff --git a/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala b/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala index 6c450c1c8613f74c5a6679aaaf6035f59e593e66..a0beb7fba142dccf9150cfaca0a0ad2329baa861 100644 --- a/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala +++ b/src/test/scala/leon/regression/verification/purescala/PureScalaVerificationSuite.scala @@ -19,8 +19,8 @@ abstract class PureScalaVerificationSuite extends VerificationSuite { val opts: List[List[String]] = { List( List("--feelinglucky"), - List("--codegen", "--evalground", "--feelinglucky"), - List("--solvers=fairz3,enum", "--codegen", "--evalground", "--feelinglucky") + List("--codegen", /*"--evalground",*/ "--feelinglucky"), + List("--solvers=fairz3,enum", "--codegen", /*"--evalground",*/ "--feelinglucky") ) ++ ( if (isZ3Available) List( List("--solvers=smt-z3", "--feelinglucky") @@ -46,6 +46,7 @@ class PureScalaValidSuite2 extends PureScalaValidSuite { } class PureScalaValidSuite3 extends PureScalaValidSuite { val optionVariants = List(opts(2)) + override val ignored = Seq("valid/Predicate.scala") } class PureScalaValidSuiteZ3 extends PureScalaValidSuite { val optionVariants = if (isZ3Available) List(opts(3)) else Nil diff --git a/src/test/scala/leon/regression/xlang/XLangDesugaringSuite.scala b/src/test/scala/leon/regression/xlang/XLangDesugaringSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..41260ec2df6f9d9f32827c39e7f700da15ccca9e --- /dev/null +++ b/src/test/scala/leon/regression/xlang/XLangDesugaringSuite.scala @@ -0,0 +1,46 @@ +/* Copyright 2009-2015 EPFL, Lausanne */ + +package leon.regression.xlang + +import leon._ +import leon.test._ +import purescala.Definitions.Program +import java.io.File + +class XLangDesugaringSuite extends LeonRegressionSuite { + // Hard-code output directory, for Eclipse purposes + + val pipeline = frontends.scalac.ExtractionPhase andThen new utils.PreprocessingPhase(true) + + def testFrontend(f: File, forError: Boolean) = { + test ("Testing " + f.getName) { + val ctx = createLeonContext() + if (forError) { + intercept[LeonFatalError]{ + pipeline.run(ctx, List(f.getAbsolutePath)) + } + } else { + pipeline.run(ctx, List(f.getAbsolutePath)) + } + } + + } + + private def forEachFileIn(path : String)(block : File => Unit) { + val fs = filesInResourceDir(path, _.endsWith(".scala")) + + for(f <- fs) { + block(f) + } + } + + val baseDir = "regression/xlang/" + + forEachFileIn(baseDir+"passing/") { f => + testFrontend(f, false) + } + forEachFileIn(baseDir+"error/") { f => + testFrontend(f, true) + } + +} diff --git a/src/test/scala/leon/test/TestSilentReporter.scala b/src/test/scala/leon/test/TestSilentReporter.scala index 2cf9ea4f7f6c78d4f07002a01f434a150e0d9034..2a8761584222f02c1e4bf85b6fd031c603771079 100644 --- a/src/test/scala/leon/test/TestSilentReporter.scala +++ b/src/test/scala/leon/test/TestSilentReporter.scala @@ -13,3 +13,10 @@ class TestSilentReporter extends DefaultReporter(Set()) { case _ => } } + +class TestErrorReporter extends DefaultReporter(Set()) { + override def emit(msg: Message): Unit = msg match { + case Message(this.ERROR | this.FATAL, _, _) => super.emit(msg) + case _ => + } +} diff --git a/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala b/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala index b448953acf1856cf1df12289880000226f78f8bf..4f74b00c5f93a3125caece106809dfd4182fa8a8 100644 --- a/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala +++ b/src/test/scala/leon/unit/purescala/ExprOpsSuite.scala @@ -6,7 +6,7 @@ import leon.test._ import leon.purescala.Common._ import leon.purescala.Expressions._ import leon.purescala.Types._ -import leon.purescala.TypeOps._ +import leon.purescala.TypeOps.isSubtypeOf import leon.purescala.Definitions._ import leon.purescala.ExprOps._ @@ -279,4 +279,44 @@ class ExprOpsSuite extends LeonTestSuite with helpers.WithLikelyEq with helpers. } } + + test("preMapWithContext") { ctx => + val expr = Plus(bi(1), Minus(bi(2), bi(3))) + def op(e : Expr, set: Set[Int]): (Option[Expr], Set[Int]) = e match { + case Minus(InfiniteIntegerLiteral(two), e2) if two == BigInt(2) => (Some(bi(2)), set) + case InfiniteIntegerLiteral(one) if one == BigInt(1) => (Some(bi(2)), set) + case InfiniteIntegerLiteral(two) if two == BigInt(2) => (Some(bi(42)), set) + case _ => (None, set) + } + + assert(preMapWithContext(op, false)(expr, Set()) === Plus(bi(2), bi(2))) + assert(preMapWithContext(op, true)(expr, Set()) === Plus(bi(42), bi(42))) + + val expr2 = Let(x.id, bi(1), Let(y.id, bi(2), Plus(x, y))) + def op2(e: Expr, bindings: Map[Identifier, BigInt]): (Option[Expr], Map[Identifier, BigInt]) = e match { + case Let(id, InfiniteIntegerLiteral(v), body) => (None, bindings + (id -> v)) + case Variable(id) => (bindings.get(id).map(v => InfiniteIntegerLiteral(v)), bindings) + case _ => (None, bindings) + } + + assert(preMapWithContext(op2, false)(expr2, Map()) === Let(x.id, bi(1), Let(y.id, bi(2), Plus(bi(1), bi(2))))) + + def op3(e: Expr, bindings: Map[Identifier, BigInt]): (Option[Expr], Map[Identifier, BigInt]) = e match { + case Let(id, InfiniteIntegerLiteral(v), body) => (Some(body), bindings + (id -> v)) + case Variable(id) => (bindings.get(id).map(v => InfiniteIntegerLiteral(v)), bindings) + case _ => (None, bindings) + } + assert(preMapWithContext(op3, true)(expr2, Map()) === Plus(bi(1), bi(2))) + + + val expr4 = Plus(Let(y.id, bi(2), y), + Let(y.id, bi(4), y)) + def op4(e: Expr, bindings: Map[Identifier, BigInt]): (Option[Expr], Map[Identifier, BigInt]) = e match { + case Let(id, InfiniteIntegerLiteral(v), body) => (Some(body), if(bindings.contains(id)) bindings else (bindings + (id -> v))) + case Variable(id) => (bindings.get(id).map(v => InfiniteIntegerLiteral(v)), bindings) + case _ => (None, bindings) + } + assert(preMapWithContext(op4, true)(expr4, Map()) === Plus(bi(2), bi(4))) + } + } diff --git a/testcases/stringrender/CalendarRender.scala b/testcases/stringrender/CalendarRender.scala new file mode 100644 index 0000000000000000000000000000000000000000..062c908be0926f37871f4168330fd556b1d50a42 --- /dev/null +++ b/testcases/stringrender/CalendarRender.scala @@ -0,0 +1,59 @@ +import leon.lang._ +import leon.lang.synthesis._ +import leon.annotation._ +import leon.collection._ + +object CalendartoString { + val dayEventsSep = "\n" + val eventsSep = "\n" + val daysSep = "\n\n" + val hoursSep = "-" + val dayPlusPrefix = " (D+" + val dayPlusSuffix = ")" + val hoursDescriptionSep = " : " + val titleDescriptionSep = "\n" + + case class Week(days: List[Day]) + case class Day(name: String, events: List[Event]) + case class Event(startHour: Int, startMinute: Int, durationMinutes: Int, description: String) + + def renderHour(h: Int, m: Int) = { + //require(m >= 0 && m < 60 && h > 0) + val h_adjusted = h + m / 60 + val realh = h_adjusted % 24 + val realm = m % 60 + val days = h_adjusted / 24 + realh + "h" + (if(realm == 0) "" else (if(realm < 10) "0" + realm else realm.toString)) + (if(days > 0) dayPlusPrefix + days + dayPlusSuffix else "") + } + + def renderEvent(e: Event) = { + renderHour(e.startHour, e.startMinute) + hoursSep + renderHour(e.startHour, e.startMinute + e.durationMinutes) + hoursDescriptionSep + e.description + } + + def renderList[T](l: List[T], f: T => String, prefix: String, separator: String, suffix: String): String = l match { + case Nil() => prefix + suffix + case Cons(e, tail:Cons[T]) => prefix + f(e) + separator + renderList(tail, f, "", separator, suffix) + case Cons(e, Nil()) => prefix + f(e) + suffix + } + + def renderDay(d: Day): String = { + renderList(d.events, (e: Event) => renderEvent(e), d.name + dayEventsSep, eventsSep, "") + } + + def renderWeek(s: Week): String = { + renderList(s.days, (d: Day) => renderDay(d), """Dear manager, +Here is what happened last week: +""", daysSep, "") + } ensuring { (res: String) => + (s, res) passes { + case Week(Cons(Day("Wednesday", Cons(Event(8, 30, 60, "First meeting"), Cons(Event(23, 15, 420, "Bus journey"), Nil()))), Cons(Day("Thursday", Cons(Event(12, 0, 65, "Meal with client"), Nil())), Nil()))) => """Dear manager, +Here is what happened last week: +Wednesday +8h30-9h30 : First meeting +23h15-6h15 (D+1) : Bus journey + +Thursday +12h-13h05 : Meal with client""" + } + } +} \ No newline at end of file diff --git a/testcases/stringrender/CustomRender.scala b/testcases/stringrender/CustomRender.scala new file mode 100644 index 0000000000000000000000000000000000000000..7f81d50f0bf756096222944dbb14bf6a9b405720 --- /dev/null +++ b/testcases/stringrender/CustomRender.scala @@ -0,0 +1,23 @@ +/** + * Name: CustomRender.scala + * Creation: 15.1.2015 + * Author: Mikael Mayer (mikael.mayer@epfl.ch) + * Comments: Custom generic rendering + */ + +import leon.lang._ +import leon.annotation._ +import leon.collection._ +import leon.collection.ListOps._ +import leon.lang.synthesis._ + +object CustomRender { + def generic_list[T](l: List[T], f: T => String): String = { + ??? + } ensuring { + (res: String) => ((l, res) passes { + case Nil() => "[]" + case Cons(a, Cons(b, Nil())) => "[" + f(a) + ", " + f(b) + "]" + }) + } +} \ No newline at end of file diff --git a/testcases/stringrender/JsonRender.scala b/testcases/stringrender/JsonRender.scala index 2b858784576e6cd97a282a2cba59246f5c4fa84f..274ed3a0f9af8ed9128f14c983abf99d714c3393 100644 --- a/testcases/stringrender/JsonRender.scala +++ b/testcases/stringrender/JsonRender.scala @@ -1,8 +1,8 @@ /** - * Name: ListRender.scala - * Creation: 14.10.2015 - * Author: Mika�l Mayer (mikael.mayer@epfl.ch) - * Comments: First benchmark ever for solving string transformation problems. List specifications. + * Name: JsonRender.scala + * Creation: 25.11.2015 + * Author: Mikael Mayer (mikael.mayer@epfl.ch) + * Comments: Json specifications */ import leon.lang._ diff --git a/testcases/stringrender/ModularRender.scala b/testcases/stringrender/ModularRender.scala new file mode 100644 index 0000000000000000000000000000000000000000..cf46ab0ab81162d9bf13fd2cfd89c1ca842f262c --- /dev/null +++ b/testcases/stringrender/ModularRender.scala @@ -0,0 +1,36 @@ +/** + * Name: ModularRender.scala + * Creation: 26.01.2015 + * Author: Mikael Mayer (mikael.mayer@epfl.ch) + * Comments: Modular rendering, in one order or the other. + */ +import leon.lang._ +import leon.annotation._ +import leon.collection._ +import leon.collection.ListOps._ +import leon.lang.synthesis._ + +object ModularRender { + def customToString[T](l : List[T], f : (T) => String): String = { + ??? + } ensuring { + (res : String) => (l, res) passes { + case Nil() => + "[]" + case Cons(t, Nil()) => + "[" + f(t) + "]" + case Cons(ta, Cons(tb, Nil())) => + "[" + f(ta) + ", " + f(tb) + "]" + case Cons(ta, Cons(tb, Cons(tc, Nil()))) => + "[" + f(ta) + ", " + f(tb) + ", " + f(tc) + "]" + } + } + + def booleanToString(b: Boolean) = if(!b) "Up" else "Down" + + case class Configuration(flags: List[Boolean]) + + // We want to write Config:[Up,Down,Up....] + def ConfigToString(config : Configuration): String = + ???[String] ask config +} diff --git a/testcases/stringrender/SymbolGrammarRender.scala b/testcases/stringrender/SymbolGrammarRender.scala new file mode 100644 index 0000000000000000000000000000000000000000..af03bdbbfc4ced66729e03171a1c633cb267d672 --- /dev/null +++ b/testcases/stringrender/SymbolGrammarRender.scala @@ -0,0 +1,54 @@ +/** + * Name: SymbolGrammarRender.scala + * Creation: 14.01.2016 + * Author: Mika�l Mayer (mikael.mayer@epfl.ch) + * Comments: Grammar rendering specifications starting with symbols + */ + +import leon.lang._ +import leon.annotation._ +import leon.collection._ +import leon.collection.ListOps._ +import leon.lang.synthesis._ + +object GrammarRender { + /** A tagged symbol */ + abstract class Symbol + /** A tagged non-terminal, used for markovization */ + case class NonTerminal(tag: String, vcontext: List[Symbol], hcontext: List[Symbol]) extends Symbol + /** A tagged terminal */ + case class Terminal(tag: String) extends Symbol + + /** All possible right-hand-side of rules */ + case class Expansion(ls: List[List[Symbol]]) + /** A grammar here has a start sequence instead of a start symbol */ + case class Grammar(start: List[Symbol], rules: List[(Symbol, Expansion)]) + + def symbol_markov(s: Grammar): String = { + ???[String] + } ensuring { + (res : String) => (s, res) passes { + case Terminal("foo") => + "Tfoo" + case Terminal("\"'\n\t") => + "T\"'\n\t" + case NonTerminal("foo", Nil(), Nil()) => + "Nfoo" + case NonTerminal("\"'\n\t", Nil(), Nil()) => + "N\"'\n\t" + case NonTerminal("foo", Nil(), Cons(Terminal("foo"), Nil())) => + "Nfoo_hTfoo" + case NonTerminal("foo", Cons(Terminal("foo"), Nil()), Nil()) => + "Nfoo_vTfoo" + case NonTerminal("foo", Nil(), Cons(NonTerminal("foo", Nil(), Nil()), Cons(NonTerminal("foo", Nil(), Nil()), Nil()))) => + "Nfoo_hNfoo_Nfoo" + case NonTerminal("foo", Cons(Terminal("foo"), Nil()), Cons(NonTerminal("foo", Nil(), Nil()), Nil())) => + "Nfoo_vTfoo_hNfoo" + case NonTerminal("foo", Cons(NonTerminal("foo", Nil(), Nil()), Cons(NonTerminal("foo", Nil(), Nil()), Nil())), Nil()) => + "Nfoo_vNfoo_Nfoo" + } + } + + def grammarToString(p : Grammar): String = + ???[String] ask p +} \ No newline at end of file diff --git a/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala b/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala index fe01946d158153d2dd9ae2a3be2234ee4cd18aa9..0f30a5ba1a95d39e78a1594f39804c8161e919a6 100644 --- a/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala +++ b/testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala @@ -72,17 +72,12 @@ object BatchedQueue { def enqueue(v: T): Queue[T] = { require(invariant) - f match { - case Cons(h, t) => - Queue(f, Cons(v, r)) - case Nil() => - Queue(Cons(v, f), Nil()) - } - + ???[Queue[T]] } ensuring { (res: Queue[T]) => - res.invariant && - res.toList.last == v && - res.content == this.content ++ Set(v) + res.invariant && + res.toList.last == v && + res.size == size + 1 && + res.content == this.content ++ Set(v) } } } diff --git a/testcases/synthesis/etienne-thesis/run.sh b/testcases/synthesis/etienne-thesis/run.sh index ee64d86702076bf5ff909c3437f321498a2afe68..924b99cc57386f1dba92bfb97017b41a801cd8ea 100755 --- a/testcases/synthesis/etienne-thesis/run.sh +++ b/testcases/synthesis/etienne-thesis/run.sh @@ -1,7 +1,7 @@ #!/bin/bash function run { - cmd="./leon --debug=report --timeout=30 --synthesis $1" + cmd="./leon --debug=report --timeout=30 --synthesis --cegis:maxsize=5 $1" echo "Running " $cmd echo "------------------------------------------------------------------------------------------------------------------" $cmd; @@ -35,9 +35,9 @@ run testcases/synthesis/etienne-thesis/UnaryNumerals/Distinct.scala run testcases/synthesis/etienne-thesis/UnaryNumerals/Mult.scala # BatchedQueue -#run testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala +run testcases/synthesis/etienne-thesis/BatchedQueue/Enqueue.scala run testcases/synthesis/etienne-thesis/BatchedQueue/Dequeue.scala # AddressBook -#run testcases/synthesis/etienne-thesis/AddressBook/Make.scala +run testcases/synthesis/etienne-thesis/AddressBook/Make.scala run testcases/synthesis/etienne-thesis/AddressBook/Merge.scala diff --git a/testcases/verification/math/RationalProps.scala b/testcases/verification/math/RationalProps.scala index aec07246e2e8f76b3a054cd202e7beaae32f3da5..0b13ff1aa3b665bdcfdc6fcb15180a1fa86eaa2d 100644 --- a/testcases/verification/math/RationalProps.scala +++ b/testcases/verification/math/RationalProps.scala @@ -7,63 +7,57 @@ import scala.language.postfixOps object RationalProps { def squarePos(r: Rational): Rational = { - require(r.isRational) r * r } ensuring(_ >= Rational(0)) def additionIsCommutative(p: Rational, q: Rational): Boolean = { - require(p.isRational && q.isRational) p + q == q + p } holds def multiplicationIsCommutative(p: Rational, q: Rational): Boolean = { - require(p.isRational && q.isRational) p * q == q * p } holds def lessThanIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = { - require(p.isRational && q.isRational && r.isRational && p < q && q < r) + require(p < q && q < r) p < r } holds def lessEqualsIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = { - require(p.isRational && q.isRational && r.isRational && p <= q && q <= r) + require(p <= q && q <= r) p <= r } holds def greaterThanIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = { - require(p.isRational && q.isRational && r.isRational && p > q && q > r) + require(p > q && q > r) p > r } holds def greaterEqualsIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = { - require(p.isRational && q.isRational && r.isRational && p >= q && q >= r) + require(p >= q && q >= r) p >= r } holds def distributionMult(p: Rational, q: Rational, r: Rational): Boolean = { - require(p.isRational && q.isRational && r.isRational) (p*(q + r)) ~ (p*q + p*r) } holds def reciprocalIsCorrect(p: Rational): Boolean = { - require(p.isRational && p.nonZero) + require(p.nonZero) (p * p.reciprocal) ~ Rational(1) } holds def additiveInverseIsCorrect(p: Rational): Boolean = { - require(p.isRational) (p + (-p)) ~ Rational(0) } holds //should not hold because q could be 0 def divByZero(p: Rational, q: Rational): Boolean = { - require(p.isRational && q.isRational) ((p / q) * q) ~ p } holds def divByNonZero(p: Rational, q: Rational): Boolean = { - require(p.isRational && q.isRational && q.nonZero) + require(q.nonZero) ((p / q) * q) ~ p } holds @@ -73,17 +67,16 @@ object RationalProps { */ def equivalentIsReflexive(p: Rational): Boolean = { - require(p.isRational) p ~ p } holds def equivalentIsSymmetric(p: Rational, q: Rational): Boolean = { - require(p.isRational && q.isRational && p ~ q) + require(p ~ q) q ~ p } holds def equivalentIsTransitive(p: Rational, q: Rational, r: Rational): Boolean = { - require(p.isRational && q.isRational && r.isRational && p ~ q && q ~ r) + require(p ~ q && q ~ r) p ~ r } holds } diff --git a/testcases/verification/quantification/invalid/SizeInc.scala b/testcases/verification/quantification/invalid/SizeInc.scala deleted file mode 100644 index 970b3b53bddb295db19308e3443309f478ee15cc..0000000000000000000000000000000000000000 --- a/testcases/verification/quantification/invalid/SizeInc.scala +++ /dev/null @@ -1,17 +0,0 @@ -import leon.lang._ - -object SizeInc { - - abstract class List[A] - case class Cons[A](head: A, tail: List[A]) extends List[A] - case class Nil[A]() extends List[A] - - def failling_1[A](x: List[A]): Int => Int = { - (i: Int) => x match { - case Cons(head, tail) => 1 + failling_1(tail)(i) - case Nil() => i - } - } ensuring { res => forall((a: Int) => res(a) > 0) } -} - -// vim: set ts=4 sw=4 et: diff --git a/testcases/verification/quantification/valid/SizeInc.scala b/testcases/verification/quantification/valid/SizeInc.scala deleted file mode 100644 index 9fe7ea96eb44bba951ff620507bad750c1560056..0000000000000000000000000000000000000000 --- a/testcases/verification/quantification/valid/SizeInc.scala +++ /dev/null @@ -1,26 +0,0 @@ -import leon.lang._ - -object SizeInc { - - abstract class List[A] - case class Cons[A](head: A, tail: List[A]) extends List[A] - case class Nil[A]() extends List[A] - - def sizeInc[A](x: List[A]): BigInt => BigInt = { - (i: BigInt) => x match { - case Cons(head, tail) => 1 + sizeInc(tail)(i) - case Nil() => i - } - } ensuring { res => forall((a: BigInt) => a > 0 ==> res(a) > 0) } - - def sizeInc2[A](x: BigInt): List[A] => BigInt = { - require(x > 0) - - (list: List[A]) => list match { - case Cons(head, tail) => 1 + sizeInc2(x)(tail) - case Nil() => x - } - } ensuring { res => forall((a: List[A]) => res(a) > 0) } -} - -// vim: set ts=4 sw=4 et: diff --git a/testcases/verification/strings/invalid/CompatibleListChar.scala b/testcases/verification/strings/invalid/CompatibleListChar.scala new file mode 100644 index 0000000000000000000000000000000000000000..86eec34cddcee8055c34d5ffc791b7bbf7a397e7 --- /dev/null +++ b/testcases/verification/strings/invalid/CompatibleListChar.scala @@ -0,0 +1,29 @@ +import leon.lang._ +import leon.annotation._ +import leon.collection._ +import leon.collection.ListOps._ +import leon.lang.synthesis._ + +object CompatibleListChar { + def rec[T](l : List[T], f : T => String): String = l match { + case Cons(head, tail) => f(head) + rec(tail, f) + case Nil() => "" + } + def customToString[T](l : List[T], p: List[Char], d: String, fd: String => String, fp: List[Char] => String, pf: String => List[Char], f : T => String): String = rec(l, f) ensuring { + (res : String) => (p == Nil[Char]() || d == "" || fd(d) == "" || fp(p) == "" || pf(d) == Nil[Char]()) && ((l, res) passes { + case Cons(a, Nil()) => f(a) + }) + } + def customPatternMatching(s: String): Boolean = { + s match { + case "" => true + case b => List(b) match { + case Cons("", Nil()) => true + case Cons(s, Nil()) => false // StrOps.length(s) < BigInt(2) // || (s == "\u0000") //+ "a" + case Cons(_, Cons(_, Nil())) => true + case _ => false + } + case _ => false + } + } holds +} \ No newline at end of file diff --git a/testcases/verification/xlang/AbsFun.scala b/testcases/verification/xlang/AbsFun.scala index fe37632df68bc4ae1c164bb034902e45b18365c4..a6ff9679e27fc32ff4dd8b62a1cf2170386083d8 100644 --- a/testcases/verification/xlang/AbsFun.scala +++ b/testcases/verification/xlang/AbsFun.scala @@ -35,11 +35,7 @@ object AbsFun { isPositive(t, k)) if(k < tab.length) { - val nt = if(tab(k) < 0) { - t.updated(k, -tab(k)) - } else { - t.updated(k, tab(k)) - } + val nt = t.updated(k, if(tab(k) < 0) -tab(k) else tab(k)) while0(nt, k+1, tab) } else { (t, k) @@ -54,11 +50,7 @@ object AbsFun { def property(t: Array[Int], k: Int): Boolean = { require(isPositive(t, k) && t.length >= 0 && k >= 0) if(k < t.length) { - val nt = if(t(k) < 0) { - t.updated(k, -t(k)) - } else { - t.updated(k, t(k)) - } + val nt = t.updated(k, if(t(k) < 0) -t(k) else t(k)) isPositive(nt, k+1) } else true } holds diff --git a/testcases/web/synthesis/25_String_OutOfOrder.scala b/testcases/web/synthesis/25_String_OutOfOrder.scala index 72785cf605cfbfe26bb70abd37651c32d5eddd42..ec855caeee8e2f2eb67b0743438d85c7902dc47d 100644 --- a/testcases/web/synthesis/25_String_OutOfOrder.scala +++ b/testcases/web/synthesis/25_String_OutOfOrder.scala @@ -5,13 +5,13 @@ import leon.collection.ListOps._ import leon.lang.synthesis._ object OutOfOrderToString { - def argumentsToString(i: Int, j: Int): String = { + def arguments(i: Int, j: Int): String = { ??? } ensuring { (res: String) => ((i, j), res) passes { case (1, 2) => "2, 1" } } - def tupleToString(i: (Int, Int)): String = { + def tuple(i: (Int, Int)): String = { ??? } ensuring { (res: String) => (i, res) passes { case (1, 2) => "2, 1" @@ -27,7 +27,7 @@ object OutOfOrderToString { } } - def listPairToString(l : List[(Int, Int)]): String = { + def listPair(l : List[(Int, Int)]): String = { ???[String] } ensuring { (res : String) => (l, res) passes { @@ -36,7 +36,7 @@ object OutOfOrderToString { } } - def reverselistPairToString(l: List[(Int, Int)]): String = { + def reverselistPair(l: List[(Int, Int)]): String = { ??? } ensuring { (res: String) => (l, res) passes { case Cons((1, 2), Cons((3,4), Nil())) => "4 -> 3, 2 -> 1" @@ -44,7 +44,7 @@ object OutOfOrderToString { case class Rule(input: Int, applied: Option[Int]) - def ruleToString(r : Rule): String = { + def rule(r : Rule): String = { ??? } ensuring { (res : String) => (r, res) passes { diff --git a/testcases/web/synthesis/26_Modular_Render.scala b/testcases/web/synthesis/26_Modular_Render.scala new file mode 100644 index 0000000000000000000000000000000000000000..d180d7e4c8ac4b217136b023afe9bd9c0d833005 --- /dev/null +++ b/testcases/web/synthesis/26_Modular_Render.scala @@ -0,0 +1,34 @@ +/** + * Name: ModularRender.scala + * Creation: 26.01.2015 + * Author: Mikael Mayer (mikael.mayer@epfl.ch) + * Comments: Modular rendering, in one order or the other. + */ +import leon.lang._ +import leon.annotation._ +import leon.collection._ +import leon.collection.ListOps._ +import leon.lang.synthesis._ + +object ModularRender { + + def booleanToString(b: Boolean) = if(b) "Up" else "Down" + + def intToString(b: BigInt) = b.toString + + def customToString[T](l : List[T], f : (T) => String): String = + ???[String] ask l + + case class Configuration(flags: List[Boolean], strokes: List[BigInt]) + + // We want to write + // Solution: + // [Up, Down, Up....] + // [1, 2, 5, ...] + def ConfigToString(config : Configuration): String = + ???[String] ask config + + /** Wrong lemma for demonstration */ + def configurationsAreSimple(c: Configuration) = + (c.flags.size < 3 || c.strokes.size < 2 || c.flags(0) == c.flags(1) || c.strokes(0) == c.strokes(1)) holds +}