diff --git a/src/main/java/leon/codegen/runtime/CaseClass.java b/src/main/java/leon/codegen/runtime/CaseClass.java index 8cda1d23aa6e74705525bacd35bc529e1a8fb3fc..50243c5c6109b4dc13e67c7cd36b5fe765cc3938 100644 --- a/src/main/java/leon/codegen/runtime/CaseClass.java +++ b/src/main/java/leon/codegen/runtime/CaseClass.java @@ -2,7 +2,9 @@ package leon.codegen.runtime; -public interface CaseClass { +public interface CaseClass { + public abstract int __getRead(); + public abstract Object[] productElements(); public abstract String productName(); diff --git a/src/main/java/leon/codegen/runtime/Tuple.java b/src/main/java/leon/codegen/runtime/Tuple.java index ea1e0880559ee7084a1f3ff9c7295aed42955958..6bd7279dfa41f15dcc92410ee2891cfa088d2c49 100644 --- a/src/main/java/leon/codegen/runtime/Tuple.java +++ b/src/main/java/leon/codegen/runtime/Tuple.java @@ -5,6 +5,12 @@ package leon.codegen.runtime; import java.util.Arrays; public final class Tuple { + private int __read = 0; + + public final int __getRead() { + return __read; + } + private final Object[] elements; // You may think that using varargs here would show less of the internals, @@ -18,6 +24,8 @@ public final class Tuple { if(index < 0 || index >= this.elements.length) { throw new IllegalArgumentException("Invalid tuple index : " + index); } + __read = (1 << (index)) | __read; + return this.elements[index]; } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index 5752b140b825bde0da19e4fe059322e00be8f16c..62e4f9f9dab420e65abbb59a3e8c9845f82cedad 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -19,15 +19,15 @@ object CodeGeneration { private val BoxedIntClass = "java/lang/Integer" private val BoxedBoolClass = "java/lang/Boolean" - private val TupleClass = "leon/codegen/runtime/Tuple" - private val SetClass = "leon/codegen/runtime/Set" - private val MapClass = "leon/codegen/runtime/Map" - private val CaseClassClass = "leon/codegen/runtime/CaseClass" - private val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" + private val TupleClass = "leon/codegen/runtime/Tuple" + private val SetClass = "leon/codegen/runtime/Set" + private val MapClass = "leon/codegen/runtime/Map" + private val CaseClassClass = "leon/codegen/runtime/CaseClass" + private val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" private val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" private val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" - def defToJVMName(p : Program, d : Definition) : String = "Leon$CodeGen$" + d.id.uniqueName + def defToJVMName(d : Definition)(implicit env : CompilationEnvironment) : String = "Leon$CodeGen$" + d.id.uniqueName def typeToJVM(tpe : TypeTree)(implicit env : CompilationEnvironment) : String = tpe match { case Int32Type => "I" @@ -59,13 +59,13 @@ object CodeGeneration { def compileFunDef(funDef : FunDef, ch : CodeHandler)(implicit env : CompilationEnvironment) { val newMapping = funDef.args.map(_.id).zipWithIndex.toMap - val bodyWithPre = if(funDef.hasPrecondition) { + val bodyWithPre = if(funDef.hasPrecondition && env.compileContracts) { IfExpr(funDef.precondition.get, funDef.getBody, Error("Precondition failed")) } else { funDef.getBody } - val bodyWithPost = if(funDef.hasPostcondition) { + val bodyWithPost = if(funDef.hasPostcondition && env.compileContracts) { val freshResID = FreshIdentifier("result").setType(funDef.returnType) val post = purescala.TreeOps.replace(Map(ResultVariable() -> Variable(freshResID)), funDef.postcondition.get) Let(freshResID, bodyWithPre, IfExpr(post, Variable(freshResID), Error("Postcondition failed")) ) @@ -164,7 +164,7 @@ object CodeGeneration { throw CompilationException("Unknown class : " + ccd.id) } ch << CheckCast(ccName) - ch << GetField(ccName, sid.name, typeToJVM(sid.getType)) + instrumentedGetField(ch, ccd, sid) // Tuples (note that instanceOf checks are in mkBranch) case Tuple(es) => @@ -495,8 +495,8 @@ object CodeGeneration { } } - def compileAbstractClassDef(p : Program, acd : AbstractClassDef)(implicit env : CompilationEnvironment) : ClassFile = { - val cName = defToJVMName(p, acd) + def compileAbstractClassDef(acd : AbstractClassDef)(implicit env : CompilationEnvironment) : ClassFile = { + val cName = defToJVMName(acd) val cf = new ClassFile(cName, None) cf.setFlags(( @@ -512,10 +512,36 @@ object CodeGeneration { cf } - def compileCaseClassDef(p : Program, ccd : CaseClassDef)(implicit env : CompilationEnvironment) : ClassFile = { + var doInstrument = true + + /** + * Instrument read operations + */ + val instrumentedField = "__read" + + def instrumentedGetField(ch: CodeHandler, ccd: CaseClassDef, id: Identifier)(implicit env : CompilationEnvironment): Unit = { + ccd.fields.zipWithIndex.find(_._1.id == id) match { + case Some((f, i)) => + val cName = defToJVMName(ccd) + if (doInstrument) { + ch << DUP << DUP + ch << GetField(cName, instrumentedField, "I") + ch << Ldc(1) + ch << Ldc(i) + ch << ISHL + ch << IOR + ch << PutField(cName, instrumentedField, "I") + } + ch << GetField(cName, f.id.name, typeToJVM(f.tpe)) + case None => + throw CompilationException("Unknown field: "+ccd.id.name+"."+id) + } + } + + def compileCaseClassDef(ccd : CaseClassDef)(implicit env : CompilationEnvironment) : ClassFile = { - val cName = defToJVMName(p, ccd) - val pName = ccd.parent.map(parent => defToJVMName(p, parent)) + val cName = defToJVMName(ccd) + val pName = ccd.parent.map(parent => defToJVMName(parent)) val cf = new ClassFile(cName, pName) cf.setFlags(( @@ -528,11 +554,12 @@ object CodeGeneration { cf.addInterface(CaseClassClass) } + val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) } + // definition of the constructor - if(ccd.fields.isEmpty) { + if(!doInstrument && ccd.fields.isEmpty) { cf.addDefaultConstructor } else { - val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) } for((nme, jvmt) <- namesTypes) { val fh = cf.addField(jvmt, nme) @@ -542,11 +569,22 @@ object CodeGeneration { ).asInstanceOf[U2]) } + if (doInstrument) { + val fh = cf.addField("I", instrumentedField) + fh.setFlags(FIELD_ACC_PUBLIC) + } + val cch = cf.addConstructor(namesTypes.map(_._2).toList).codeHandler cch << ALoad(0) cch << InvokeSpecial(pName.getOrElse("java/lang/Object"), constructorName, "()V") + if (doInstrument) { + cch << ALoad(0) + cch << Ldc(0) + cch << PutField(cName, instrumentedField, "I") + } + var c = 1 for((nme, jvmt) <- namesTypes) { cch << ALoad(0) @@ -561,6 +599,20 @@ object CodeGeneration { cch.freeze } + locally { + val pnm = cf.addMethod("I", "__getRead") + pnm.setFlags(( + METHOD_ACC_PUBLIC | + METHOD_ACC_FINAL + ).asInstanceOf[U2]) + + val pnch = pnm.codeHandler + + pnch << ALoad(0) << GetField(cName, instrumentedField, "I") << IRETURN + + pnch.freeze + } + locally { val pnm = cf.addMethod("Ljava/lang/String;", "productName") pnm.setFlags(( @@ -591,7 +643,7 @@ object CodeGeneration { pech << DUP pech << Ldc(i) pech << ALoad(0) - pech << GetField(cName, f.id.name, typeToJVM(f.tpe)) + instrumentedGetField(pech, ccd, f.id) mkBox(f.tpe, pech) pech << AASTORE } @@ -624,13 +676,13 @@ object CodeGeneration { if(!ccd.fields.isEmpty) { ech << ALoad(1) << CheckCast(cName) << AStore(castSlot) - val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) } - - for((nme, jvmt) <- namesTypes) { - ech << ALoad(0) << GetField(cName, nme, jvmt) - ech << ALoad(castSlot) << GetField(cName, nme, jvmt) + for(vd <- ccd.fields) { + ech << ALoad(0) + instrumentedGetField(ech, ccd, vd.id) + ech << ALoad(castSlot) + instrumentedGetField(ech, ccd, vd.id) - jvmt match { + typeToJVM(vd.id.getType) match { case "I" | "Z" => ech << If_ICmpNe(notEq) diff --git a/src/main/scala/leon/codegen/CompilationEnvironment.scala b/src/main/scala/leon/codegen/CompilationEnvironment.scala index 52f4ec3e8e1e12e099490ec15cc721dca9e017d5..c364b715311f2f0122db1c2ac6654280f15e98e1 100644 --- a/src/main/scala/leon/codegen/CompilationEnvironment.scala +++ b/src/main/scala/leon/codegen/CompilationEnvironment.scala @@ -12,6 +12,10 @@ abstract class CompilationEnvironment() { // - a mapping of function defs to class + method name // - a mapping of class defs to class names // - a mapping of class fields to fields + + val program: Program + + val compileContracts: Boolean // Returns (JVM) name of class, and signature of constructor def classDefToClass(classDef : ClassTypeDef) : Option[String] @@ -24,6 +28,8 @@ abstract class CompilationEnvironment() { /** Augment the environment with new local var. mappings. */ def withVars(pairs : Map[Identifier,Int]) = { new CompilationEnvironment { + val program = self.program + val compileContracts = self.compileContracts def classDefToClass(classDef : ClassTypeDef) = self.classDefToClass(classDef) def funDefToMethod(funDef : FunDef) = self.funDefToMethod(funDef) def varToLocal(v : Identifier) = pairs.get(v).orElse(self.varToLocal(v)) @@ -32,21 +38,25 @@ abstract class CompilationEnvironment() { } object CompilationEnvironment { - def fromProgram(p : Program) : CompilationEnvironment = { + def fromProgram(p : Program, _compileContracts: Boolean) : CompilationEnvironment = { import CodeGeneration.typeToJVM // This should change: it should contain the case classes before // we go and generate function signatures. implicit val initial = new CompilationEnvironment { + val program = p + + val compileContracts = _compileContracts + private val cNames : Map[ClassTypeDef,String] = - p.definedClasses.map(c => (c, CodeGeneration.defToJVMName(p, c))).toMap + p.definedClasses.map(c => (c, CodeGeneration.defToJVMName(c)(this))).toMap def classDefToClass(classDef : ClassTypeDef) = cNames.get(classDef) def funDefToMethod(funDef : FunDef) = None def varToLocal(v : Identifier) = None } - val className = CodeGeneration.defToJVMName(p, p.mainObject) + val className = CodeGeneration.defToJVMName(p.mainObject) val fs = p.definedFunctions.filter(_.hasImplementation) @@ -57,6 +67,10 @@ object CompilationEnvironment { }).toMap new CompilationEnvironment { + val program = p + + val compileContracts = initial.compileContracts + def classDefToClass(classDef : ClassTypeDef) = initial.classDefToClass(classDef) def funDefToMethod(funDef : FunDef) = fMap.get(funDef) def varToLocal(v : Identifier) = None diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index d62fb9fbc3a621b0037b881914fbfb8f046cc943..bbb79d1b3e0b6f3f66a9ce3df7c08f34a45d6039 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -22,7 +22,7 @@ import CodeGeneration._ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFile], implicit val env: CompilationEnvironment) { - private val jvmClassToDef = classes.map { + val jvmClassToDef = classes.map { case (d, cf) => cf.className -> d }.toMap @@ -43,7 +43,12 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi }).toMap } - val a = 42 + private lazy val tupleConstructor: Constructor[_] = { + val tc = loader.loadClass("leon.codegen.runtime.Tuple") + val conss = tc.getConstructors().sortBy(_.getParameterTypes().length) + assert(!conss.isEmpty) + conss.last + } private def writeClassFiles() { for ((d, cl) <- classes) { @@ -67,6 +72,9 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi case BooleanLiteral(v) => new java.lang.Boolean(v) + case Tuple(elems) => + tupleConstructor.newInstance(elems.map(valueToJVM).toArray).asInstanceOf[AnyRef] + case CaseClass(ccd, args) => val cons = caseClassConstructors(ccd) cons.newInstance(args.map(valueToJVM).toArray : _*).asInstanceOf[AnyRef] @@ -179,24 +187,24 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi } object CompilationUnit { - def compileProgram(p: Program): Option[CompilationUnit] = { - implicit val env = CompilationEnvironment.fromProgram(p) + def compileProgram(p: Program, compileContracts: Boolean = true): Option[CompilationUnit] = { + implicit val env = CompilationEnvironment.fromProgram(p, compileContracts) var classes = Map[Definition, ClassFile]() for((parent,children) <- p.algebraicDataTypes) { - classes += parent -> compileAbstractClassDef(p, parent) + classes += parent -> compileAbstractClassDef(parent) for (c <- children) { - classes += c -> compileCaseClassDef(p, c) + classes += c -> compileCaseClassDef(c) } } for(single <- p.singleCaseClasses) { - classes += single -> compileCaseClassDef(p, single) + classes += single -> compileCaseClassDef(single) } - val mainClassName = defToJVMName(p, p.mainObject) + val mainClassName = defToJVMName(p.mainObject) val cf = new ClassFile(mainClassName, None) classes += p.mainObject -> cf diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index cf5369d89ae0ae26d6eb09113ab8909da18992fe..d1b4bbb4d763039ecf67fb583db01a094201bd35 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -22,19 +22,23 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr private val exprType = expression.getType - protected[codegen] def evalToJVM(args: Seq[Expr]): AnyRef = { + def argsToJVM(args: Seq[Expr]): Seq[AnyRef] = { + args.map(unit.valueToJVM) + } + + def evalToJVM(args: Seq[AnyRef]): AnyRef = { assert(args.size == argsDecl.size) if (args.isEmpty) { meth.invoke(null) } else { - meth.invoke(null, args.map(unit.valueToJVM).toArray : _*) + meth.invoke(null, args.toArray : _*) } } // This may throw an exception. We unwrap it if needed. // We also need to reattach a type in some cases (sets, maps). - def eval(args: Seq[Expr]) : Expr = { + def evalFromJVM(args: Seq[AnyRef]) : Expr = { try { val result = unit.jvmToValue(evalToJVM(args)) if(!result.isTyped) { @@ -45,4 +49,12 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr case ite : InvocationTargetException => throw ite.getCause() } } + + def eval(args: Seq[Expr]) : Expr = { + try { + evalFromJVM(argsToJVM(args)) + } catch { + case ite : InvocationTargetException => throw ite.getCause() + } + } } diff --git a/src/main/scala/leon/datagen/DataGenerator.scala b/src/main/scala/leon/datagen/DataGenerator.scala new file mode 100644 index 0000000000000000000000000000000000000000..add01c033cc4ee5113bcf3c06987d73dd6554ff9 --- /dev/null +++ b/src/main/scala/leon/datagen/DataGenerator.scala @@ -0,0 +1,9 @@ +package leon +package datagen + +import purescala.Trees._ +import purescala.Common._ + +trait DataGenerator { + def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]]; +} diff --git a/src/main/scala/leon/purescala/DataGen.scala b/src/main/scala/leon/datagen/NaiveDataGen.scala similarity index 90% rename from src/main/scala/leon/purescala/DataGen.scala rename to src/main/scala/leon/datagen/NaiveDataGen.scala index 02e559d03d5bda9d46e09d3c4d5b87f1e609e97b..f540c59b08ea891ad9b2de31cfbf42f4c2b2f896 100644 --- a/src/main/scala/leon/purescala/DataGen.scala +++ b/src/main/scala/leon/datagen/NaiveDataGen.scala @@ -1,7 +1,7 @@ /* Copyright 2009-2013 EPFL, Lausanne */ package leon -package purescala +package datagen import purescala.Common._ import purescala.Trees._ @@ -16,11 +16,14 @@ import scala.collection.mutable.{Map=>MutableMap} /** Utility functions to generate values of a given type. * In fact, it could be used to generate *terms* of a given type, * e.g. by passing trees representing variables for the "bounds". */ -object DataGen { +class NaiveDataGen(ctx: LeonContext, p: Program, evaluator: Evaluator, _bounds : Option[Map[TypeTree,Seq[Expr]]] = None) extends DataGenerator { + private val defaultBounds : Map[TypeTree,Seq[Expr]] = Map( Int32Type -> Seq(IntLiteral(0), IntLiteral(1), IntLiteral(-1)) ) + val bounds = _bounds.getOrElse(defaultBounds) + private val boolStream : Stream[Expr] = Stream.cons(BooleanLiteral(true), Stream.cons(BooleanLiteral(false), Stream.empty)) @@ -111,20 +114,19 @@ object DataGen { } } - def findModels(expr : Expr, evaluator : Evaluator, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds, forcedFreeVars: Option[Seq[Identifier]] = None) : Stream[Map[Identifier,Expr]] = { - val freeVars : Seq[Identifier] = forcedFreeVars.getOrElse(variablesOf(expr).toSeq) - - evaluator.compile(expr, freeVars).map { evalFun => + //def findModels(expr : Expr, maxModels : Int, maxTries : Int, bounds : Map[TypeTree,Seq[Expr]] = defaultBounds, forcedFreeVars: Option[Seq[Identifier]] = None) : Stream[Map[Identifier,Expr]] = { + def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid : Int, maxEnumerated : Int) : Iterator[Seq[Expr]] = { + evaluator.compile(satisfying, ins).map { evalFun => val sat = EvaluationResults.Successful(BooleanLiteral(true)) - naryProduct(freeVars.map(id => generate(id.getType, bounds))) - .take(maxTries) + naryProduct(ins.map(id => generate(id.getType, bounds))) + .take(maxEnumerated) .filter{s => evalFun(s) == sat } - .take(maxModels) - .map(s => freeVars.zip(s).toMap) + .take(maxValid) + .iterator } getOrElse { - Stream.empty + Stream.empty.iterator } } diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala new file mode 100644 index 0000000000000000000000000000000000000000..5215670e3fc8f7e1e095038165cea124c9e2aa50 --- /dev/null +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -0,0 +1,277 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package datagen + +import purescala.Common._ +import purescala.Definitions._ +import purescala.TreeOps._ +import purescala.Trees._ +import purescala.TypeTrees._ +import purescala.Extractors.TopLevelAnds + +import codegen.CompilationUnit +import vanuatoo.{Pattern => VPattern, _} + +import evaluators._ + +class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { + val unit = CompilationUnit.compileProgram(p, compileContracts = false).get + + val ints = (for (i <- Set(0, 1, 2, 3)) yield { + i -> Constructor[Expr, TypeTree](List(), Int32Type, s => IntLiteral(i), ""+i) + }).toMap + + val booleans = (for (b <- Set(true, false)) yield { + b -> Constructor[Expr, TypeTree](List(), BooleanType, s => BooleanLiteral(b), ""+b) + }).toMap + + def intConstructor(i: Int) = ints(i) + + def boolConstructor(b: Boolean) = booleans(b) + + def cPattern(c: Constructor[Expr, TypeTree], args: Seq[VPattern[Expr, TypeTree]]) = { + ConstructorPattern[Expr, TypeTree](c, args) + } + + private var ccConstructors = Map[CaseClassDef, Constructor[Expr, TypeTree]]() + private var acConstructors = Map[AbstractClassDef, List[Constructor[Expr, TypeTree]]]() + private var tConstructors = Map[TupleType, Constructor[Expr, TypeTree]]() + + private def getConstructorFor(t: CaseClassType, act: AbstractClassType): Constructor[Expr, TypeTree] = { + // We "up-cast" the returnType of the specific caseclass generator to match its superclass + getConstructors(t)(0).copy(retType = act) + } + + + private def getConstructors(t: TypeTree): List[Constructor[Expr, TypeTree]] = t match { + case tt @ TupleType(parts) => + List(tConstructors.getOrElse(tt, { + val c = Constructor[Expr, TypeTree](parts, tt, s => Tuple(s).setType(tt), tt.toString) + tConstructors += tt -> c + c + })) + + case act @ AbstractClassType(acd) => + acConstructors.getOrElse(acd, { + val cs = acd.knownDescendents.collect { + case ccd: CaseClassDef => + getConstructorFor(CaseClassType(ccd), act) + }.toList + + acConstructors += acd -> cs + + cs + }) + + case CaseClassType(ccd) => + List(ccConstructors.getOrElse(ccd, { + val c = Constructor[Expr, TypeTree](ccd.fields.map(_.tpe), CaseClassType(ccd), s => CaseClass(ccd, s), ccd.id.name) + ccConstructors += ccd -> c + c + })) + + case _ => + ctx.reporter.error("Unknown type to generate constructor for: "+t) + Nil + } + + private def valueToPattern(v: AnyRef, expType: TypeTree): (VPattern[Expr, TypeTree], Boolean) = (v, expType) match { + case (i: Integer, Int32Type) => + (cPattern(intConstructor(i), List()), true) + + case (b: java.lang.Boolean, BooleanType) => + (cPattern(boolConstructor(b), List()), true) + + case (cc: codegen.runtime.CaseClass, ct: ClassType) => + val r = cc.__getRead() + + unit.jvmClassToDef.get(cc.getClass.getName) match { + case Some(ccd: CaseClassDef) => + val c = ct match { + case act : AbstractClassType => + getConstructorFor(CaseClassType(ccd), act) + case cct : CaseClassType => + getConstructors(CaseClassType(ccd))(0) + } + + val fields = cc.productElements() + + val elems = for (i <- 0 until fields.length) yield { + if (((r >> i) & 1) == 1) { + // has been read + valueToPattern(fields(i), ccd.fieldsIds(i).getType) + } else { + (AnyPattern[Expr, TypeTree](), false) + } + } + + (ConstructorPattern(c, elems.map(_._1)), elems.forall(_._2)) + + case _ => + sys.error("Could not retreive type for :"+cc.getClass.getName) + } + + case (t: codegen.runtime.Tuple, tt @ TupleType(parts)) => + val r = t.__getRead() + + val c = getConstructors(tt)(0) + + val elems = for (i <- 0 until t.getArity) yield { + if (((r >> i) & 1) == 1) { + // has been read + valueToPattern(t.get(i), parts(i)) + } else { + (AnyPattern[Expr, TypeTree](), false) + } + } + + (ConstructorPattern(c, elems.map(_._1)), elems.forall(_._2)) + + case _ => + sys.error("Unsupported value, can't paternify : "+v+" : "+expType) + } + + type InstrumentedResult = (EvaluationResults.Result, Option[vanuatoo.Pattern[Expr, TypeTree]]) + + def compile(expression : Expr, argorder : Seq[Identifier]) : Option[Tuple=>InstrumentedResult] = { + import leon.codegen.runtime.LeonCodeGenRuntimeException + import leon.codegen.runtime.LeonCodeGenEvaluationException + + try { + val ttype = TupleType(argorder.map(_.getType)) + val tid = FreshIdentifier("tup").setType(ttype) + + val map = argorder.zipWithIndex.map{ case (id, i) => (id -> TupleSelect(Variable(tid), i+1)) }.toMap + + val newExpr = replaceFromIDs(map, expression) + + val ce = unit.compileExpression(newExpr, Seq(tid)) + + Some((args : Tuple) => { + try { + val jvmArgs = ce.argsToJVM(Seq(args)) + + val result = ce.evalFromJVM(jvmArgs) + + // jvmArgs is getting updated by evaluating + val pattern = valueToPattern(jvmArgs(0), ttype) + + (EvaluationResults.Successful(result), if (!pattern._2) Some(pattern._1) else None) + } catch { + case e : ArithmeticException => + (EvaluationResults.RuntimeError(e.getMessage), None) + + case e : ArrayIndexOutOfBoundsException => + (EvaluationResults.RuntimeError(e.getMessage), None) + + case e : LeonCodeGenRuntimeException => + (EvaluationResults.RuntimeError(e.getMessage), None) + + case e : LeonCodeGenEvaluationException => + (EvaluationResults.EvaluatorError(e.getMessage), None) + } + }) + } catch { + case t: Throwable => + ctx.reporter.warning("Error while compiling expression: "+t.getMessage) + None + } + } + + def generateFor(ins: Seq[Identifier], satisfying: Expr, maxValid: Int, maxEnumerated: Int): Iterator[Seq[Expr]] = { + // Split conjunctions + val TopLevelAnds(ands) = satisfying + + val runners = ands.map(a => compile(a, ins) match { + case Some(runner) => Some(runner) + case None => + ctx.reporter.error("Could not compile predicate "+a) + None + }).flatten + + + val gen = new StubGenerator[Expr, TypeTree]((ints.values ++ booleans.values).toSeq, + Some(getConstructors _), + treatEmptyStubsAsChildless = true) + + var found = Set[Seq[Expr]]() + + /** + * Gather at most <n> isomoprhic models before skipping them + * - Too little means skipping many excluding patterns + * - Too large means repetitive (and not useful models) before reaching maxEnumerated + */ + + val maxIsomorphicModels = maxValid+1; + + val it = gen.enumerate(TupleType(ins.map(_.getType))) + + return new Iterator[Seq[Expr]] { + var total = 0 + var found = 0; + + var theNext: Option[Seq[Expr]] = None + + def hasNext() = { + if (total == 0) { + theNext = computeNext() + } + + theNext != None + } + + def next() = { + val res = theNext.get + theNext = computeNext() + res + } + + + def computeNext(): Option[Seq[Expr]] = { + while(total < maxEnumerated && found < maxValid && it.hasNext) { + val model = it.next.asInstanceOf[Tuple] + + if (model eq null) { + total = maxEnumerated + } else { + total += 1 + + var failed = false; + + for (r <- runners) r(model) match { + case (EvaluationResults.Successful(BooleanLiteral(true)), _) => + + case (_, Some(pattern)) => + failed = true; + it.exclude(pattern) + + case (_, None) => + failed = true; + } + + if (!failed) { + println("Got model:") + for ((i, v) <- (ins zip model.exprs)) { + println(" - "+i+" -> "+v) + } + + found += 1 + + if (found % maxIsomorphicModels == 0) { + it.skipIsomorphic() + } + + return Some(model.exprs); + } + + if (total % 1000 == 0) { + println("... "+total+" ...") + } + } + } + None + } + } + } +} diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index d59e67b40aeb9e1944a18bd07d4031deb41ed1dc..3e81b76cbf8a234aa922b5a356d92ed0c7a7f247 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1775,10 +1775,13 @@ object TreeOps { val recSelectors = ccd.fieldsIds.filter(_.getType == on.getType) if (recSelectors.isEmpty) { - None + Seq() } else { val v = Variable(on) - Some(And(And(isType, expr), Not(replace(recSelectors.map(s => v -> CaseClassSelector(ccd, v, s)).toMap, expr)))) + + recSelectors.map{ s => + And(And(isType, expr), Not(replace(Map(v -> CaseClassSelector(ccd, v, s)), expr))) + } } }.flatten @@ -1786,7 +1789,10 @@ object TreeOps { solver.solveSAT(cond) match { case (Some(false), _) => true - case (_, model) => + case (Some(true), model) => + false + case (None, _) => + // Should we be optimistic here? false } } diff --git a/src/main/scala/leon/synthesis/SynthesisOptions.scala b/src/main/scala/leon/synthesis/SynthesisOptions.scala index 37408d3f7513ea09856baf3f5684d05395fc55a2..0e7294f726579b4f22c2ab13b7c3e60cb03cd8de 100644 --- a/src/main/scala/leon/synthesis/SynthesisOptions.scala +++ b/src/main/scala/leon/synthesis/SynthesisOptions.scala @@ -18,5 +18,6 @@ case class SynthesisOptions( cegisGenerateFunCalls: Boolean = false, cegisUseCETests: Boolean = true, cegisUseCEPruning: Boolean = true, - cegisUseBPaths: Boolean = true + cegisUseBPaths: Boolean = true, + cegisUseVanuatoo: Boolean = false ) diff --git a/src/main/scala/leon/synthesis/SynthesisPhase.scala b/src/main/scala/leon/synthesis/SynthesisPhase.scala index 1251de0efb0662132a82f9073a606e026edc550c..9346b5530552316073b2b17688c537d04873d986 100644 --- a/src/main/scala/leon/synthesis/SynthesisPhase.scala +++ b/src/main/scala/leon/synthesis/SynthesisPhase.scala @@ -24,7 +24,8 @@ object SynthesisPhase extends LeonPhase[Program, Program] { LeonValueOptionDef( "timeout", "--timeout=T", "Timeout after T seconds when searching for synthesis solutions .."), LeonValueOptionDef( "costmodel", "--costmodel=cm", "Use a specific cost model for this search"), LeonValueOptionDef( "functions", "--functions=f1:f2", "Limit synthesis of choose found within f1,f2,.."), - LeonFlagOptionDef( "cegis:gencalls", "--cegis:gencalls", "Include function calls in CEGIS generators") + LeonFlagOptionDef( "cegis:gencalls", "--cegis:gencalls", "Include function calls in CEGIS generators"), + LeonFlagOptionDef( "cegis:vanuatoo", "--cegis:vanuatoo", "Generate inputs using new korat-style generator") ) def processOptions(ctx: LeonContext): SynthesisOptions = { @@ -75,6 +76,9 @@ object SynthesisPhase extends LeonPhase[Program, Program] { case LeonFlagOption("cegis:gencalls") => options = options.copy(cegisGenerateFunCalls = true) + case LeonFlagOption("cegis:vanuatoo") => + options = options.copy(cegisUseVanuatoo = true) + case LeonFlagOption("derivtrees") => options = options.copy(generateDerivationTrees = true) diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 6e4879944c6d54120bb7af23de6854cca986323b..20116d24b0fe30966f121f1b60ddbf84ea1bee1c 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -106,6 +106,9 @@ class Synthesizer(val context : LeonContext, if (vcreport.totalValid == vcreport.totalConditions) { (sol, true) + } else if (vcreport.totalValid + vcreport.totalUnknown == vcreport.totalConditions) { + reporter.warning("Solution may be invalid:") + (sol, false) } else { reporter.warning("Solution was invalid:") reporter.warning(fds.map(ScalaPrinter(_)).mkString("\n\n")) diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 294a54f8f4e8caad453149410651fc3d36b492a2..bed95ffd8e682e7360b959eaae7e4215cdb209ff 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -6,7 +6,6 @@ package rules import solvers.TimeoutSolver import purescala.Trees._ -import purescala.DataGen import purescala.Common._ import purescala.Definitions._ import purescala.TypeTrees._ @@ -17,6 +16,7 @@ import purescala.ScalaPrinter import scala.collection.mutable.{Map=>MutableMap} import evaluators._ +import datagen._ import solvers.z3.FairZ3Solver @@ -30,6 +30,7 @@ case object CEGIS extends Rule("CEGIS") { val useOptTimeout = true val useFunGenerators = sctx.options.cegisGenerateFunCalls val useBPaths = sctx.options.cegisUseBPaths + val useVanuatoo = sctx.options.cegisUseVanuatoo val useCETests = sctx.options.cegisUseCETests val useCEPruning = sctx.options.cegisUseCEPruning // Limits the number of programs CEGIS will specifically test for instead of reasonning symbolically @@ -451,11 +452,11 @@ case object CEGIS extends Rule("CEGIS") { val exSolver = new TimeoutSolver(sctx.solver, 3000L) // 3sec val cexSolver = new TimeoutSolver(sctx.solver, 3000L) // 3sec - var exampleInputs = Set[Seq[Expr]]() + var baseExampleInputs: Seq[Seq[Expr]] = Seq() // We populate the list of examples with a predefined one if (p.pc == BooleanLiteral(true)) { - exampleInputs += p.as.map(a => simplestValue(a.getType)) + baseExampleInputs = p.as.map(a => simplestValue(a.getType)) +: baseExampleInputs } else { val solver = exSolver.getNewSolver @@ -464,7 +465,7 @@ case object CEGIS extends Rule("CEGIS") { solver.check match { case Some(true) => val model = solver.getModel - exampleInputs += p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) + baseExampleInputs = p.as.map(a => model.getOrElse(a, simplestValue(a.getType))) +: baseExampleInputs case Some(false) => return RuleApplicationImpossible @@ -476,10 +477,26 @@ case object CEGIS extends Rule("CEGIS") { } - val discoveredInputs = DataGen.findModels(p.pc, evaluator, 20, 1000, forcedFreeVars = Some(p.as)).map{ - m => p.as.map(a => m(a)) + val inputIterator: Iterator[Seq[Expr]] = if (useVanuatoo) { + new VanuatooDataGen(sctx.context, sctx.program).generateFor(p.as, p.pc, 20, 3000) + } else { + new NaiveDataGen(sctx.context, sctx.program, evaluator).generateFor(p.as, p.pc, 20, 1000) + } + + val cachedInputIterator = new Iterator[Seq[Expr]] { + def next() = { + val i = inputIterator.next() + baseExampleInputs = i +: baseExampleInputs + i + } + + def hasNext() = inputIterator.hasNext } + def hasInputExamples() = baseExampleInputs.size > 0 || cachedInputIterator.hasNext + + def allInputExamples() = baseExampleInputs.iterator ++ cachedInputIterator + def checkForPrograms(programs: Set[Set[Identifier]]): RuleApplicationResult = { for (prog <- programs) { val expr = ndProgram.determinize(prog) @@ -500,10 +517,6 @@ case object CEGIS extends Rule("CEGIS") { RuleApplicationImpossible } - // println("Generating tests..") - // println("Found: "+discoveredInputs.size) - exampleInputs ++= discoveredInputs - // Keep track of collected cores to filter programs to test var collectedCores = Set[Set[Identifier]]() @@ -558,9 +571,10 @@ case object CEGIS extends Rule("CEGIS") { //println("#Tests: "+exampleInputs.size) // We further filter the set of working programs to remove those that fail on known examples - if (useCEPruning && !exampleInputs.isEmpty && ndProgram.canTest()) { + if (useCEPruning && hasInputExamples() && ndProgram.canTest()) { + for (p <- prunedPrograms) { - if (!exampleInputs.forall(ndProgram.testForProgram(p))) { + if (!allInputExamples().forall(ndProgram.testForProgram(p))) { // This program failed on at least one example solver1.assertCnstr(Not(And(p.map(Variable(_)).toSeq))) prunedPrograms -= p @@ -630,11 +644,11 @@ case object CEGIS extends Rule("CEGIS") { // println(". "+c+" = "+ex) //} - val validateWithZ3 = if (useCETests && !exampleInputs.isEmpty && ndProgram.canTest()) { + val validateWithZ3 = if (useCETests && hasInputExamples() && ndProgram.canTest()) { val p = bssAssumptions.collect { case Variable(b) => b } - if (exampleInputs.forall(ndProgram.testForProgram(p))) { + if (allInputExamples().forall(ndProgram.testForProgram(p))) { // All valid inputs also work with this, we need to // make sure by validating this candidate with z3 true @@ -661,7 +675,7 @@ case object CEGIS extends Rule("CEGIS") { val newCE = p.as.map(valuateWithModel(invalidModel)) - exampleInputs += newCE + baseExampleInputs = newCE +: baseExampleInputs //println("Found counter example: "+fixedAss) diff --git a/src/test/scala/leon/test/purescala/DataGen.scala b/src/test/scala/leon/test/purescala/DataGen.scala index a3c9d5a40f1816cc403c32589f38de9810c1faaa..65b20298ca0112bdefcdf1017eb7c60370b0980b 100644 --- a/src/test/scala/leon/test/purescala/DataGen.scala +++ b/src/test/scala/leon/test/purescala/DataGen.scala @@ -10,7 +10,7 @@ import leon.purescala.Common._ import leon.purescala.Trees._ import leon.purescala.Definitions._ import leon.purescala.TypeTrees._ -import leon.purescala.DataGen._ +import leon.datagen._ import leon.evaluators._ @@ -41,11 +41,6 @@ class DataGen extends FunSuite { program } - test("Booleans") { - generate(BooleanType).toSet.size === 2 - generate(TupleType(Seq(BooleanType,BooleanType))).toSet.size === 4 - } - test("Lists") { val p = """|object Program { | sealed abstract class List @@ -75,6 +70,12 @@ class DataGen extends FunSuite { val prog = parseString(p) + val eval = new DefaultEvaluator(leonContext, prog) + val generator = new NaiveDataGen(leonContext, prog, eval) + + generator.generate(BooleanType).toSet.size === 2 + generator.generate(TupleType(Seq(BooleanType,BooleanType))).toSet.size === 4 + val listType : TypeTree = classDefToClassType(prog.mainObject.classHierarchyRoots.head) val sizeDef : FunDef = prog.definedFunctions.find(_.id.name == "size").get val sortedDef : FunDef = prog.definedFunctions.find(_.id.name == "isSorted").get @@ -83,7 +84,7 @@ class DataGen extends FunSuite { val consDef : CaseClassDef = prog.mainObject.caseClassDef("Cons") - generate(listType).take(100).toSet.size === 100 + generator.generate(listType).take(100).toSet.size === 100 val evaluator = new CodeGenEvaluator(leonContext, prog) @@ -98,34 +99,34 @@ class DataGen extends FunSuite { val sortedX = FunctionInvocation(sortedDef, Seq(x)) val sortedY = FunctionInvocation(sortedDef, Seq(y)) - assert(findModels( + assert(generator.generateFor( + Seq(x.id), GreaterThan(sizeX, IntLiteral(0)), - evaluator, 10, 500 ).size === 10) - assert(findModels( + assert(generator.generateFor( + Seq(x.id, y.id), And(Equals(contentX, contentY), sortedY), - evaluator, 10, 500 ).size === 10) - assert(findModels( + assert(generator.generateFor( + Seq(x.id, y.id), And(Seq(Equals(contentX, contentY), sortedX, sortedY, Not(Equals(x, y)))), - evaluator, 1, 500 ).isEmpty, "There should be no models for this problem") - assert(findModels( + assert(generator.generateFor( + Seq(x.id, y.id, b.id, a.id), And(Seq( LessThan(a, b), FunctionInvocation(sortedDef, Seq(CaseClass(consDef, Seq(a, x)))), FunctionInvocation(insSpecDef, Seq(b, x, y)) )), - evaluator, 10, 500 ).size >= 5, "There should be at least 5 models for this problem.") diff --git a/testcases/synthesis/oopsla2013/SortedBinaryTree/Batch.scala b/testcases/synthesis/oopsla2013/SortedBinaryTree/Batch.scala new file mode 100644 index 0000000000000000000000000000000000000000..3d430f65c7d418e1e32cfe9cc3e214de91d2431f --- /dev/null +++ b/testcases/synthesis/oopsla2013/SortedBinaryTree/Batch.scala @@ -0,0 +1,80 @@ +import scala.collection.immutable.Set +import leon.Annotations._ +import leon.Utils._ + +object BinaryTree { + sealed abstract class Tree + case class Node(left : Tree, value : Int, right : Tree) extends Tree + case object Leaf extends Tree + + def content(t : Tree): Set[Int] = t match { + case Leaf => Set.empty[Int] + case Node(l, v, r) => content(l) ++ Set(v) ++ content(r) + } + + sealed abstract class OptPair + case class Pair(v1 : Int, v2 : Int) extends OptPair + case object NoPair extends OptPair + + def isSortedX(t : Tree) : (Boolean, OptPair) = (t match { + case Leaf => (true, NoPair) + case Node(Leaf, v, Leaf) => (true, Pair(v, v)) + case Node(Node(_, lv, _), v, _) if lv >= v => (false, NoPair) + case Node(_, v, Node(_, rv, _)) if rv <= v => (false, NoPair) + + case Node(l, v, r) => + val (ls,lb) = isSortedX(l) + + val (lOK,newMin) = lb match { + case NoPair => (ls, v) + case Pair(ll, lh) => (ls && lh < v, ll) + } + + if(lOK) { + val (rs,rb) = isSortedX(r) + val (rOK,newMax) = rb match { + case NoPair => (rs, v) + case Pair(rl, rh) => (rs && v < rl, rh) + } + + if(rOK) { + (true, Pair(newMin, newMax)) + } else { + (false, NoPair) + } + } else { + (false, NoPair) + } + }) ensuring((res : (Boolean,OptPair)) => res match { + case (s, Pair(l,u)) => s && (l <= u) + case _ => true + }) + + def isSorted(t: Tree): Boolean = isSortedX(t)._1 + + def deleteSynth(in : Tree, v : Int) = choose { + (out : Tree) => content(out) == (content(in) -- Set(v)) + } + + // def insertImpl(t : Tree, x : Int) : Tree = { + // require(isSorted(t)) + // t match { + // case Leaf => Node(Leaf, x, Leaf) + // case Node(l, v, r) if v == x => Node(l, v, r) + // case Node(l, v, r) if x < v => Node(insertImpl(l, x), v, r) + // case Node(l, v, r) if v < x => Node(l, v, insertImpl(r, x)) + // } + // } ensuring(isSorted(_)) + + def insertSynth(in : Tree, v : Int) = choose { + (out : Tree) => content(out) == (content(in) ++ Set(v)) + } + + def insertSortedSynth(in : Tree, v : Int) = choose { + (out : Tree) => isSorted(in) && (content(out) == (content(in) ++ Set(v))) && isSorted(out) + } + + def deleteSortedSynth(in : Tree, v : Int) = choose { + (out : Tree) => isSorted(in) && (content(out) == (content(in) -- Set(v))) && isSorted(out) + } +} diff --git a/unmanaged/32/vanuatoo_2.10-0.1.jar b/unmanaged/32/vanuatoo_2.10-0.1.jar new file mode 120000 index 0000000000000000000000000000000000000000..5148a49942a1aed689b0dc4d5b6ec5729aef14e1 --- /dev/null +++ b/unmanaged/32/vanuatoo_2.10-0.1.jar @@ -0,0 +1 @@ +../common/vanuatoo_2.10-0.1.jar \ No newline at end of file diff --git a/unmanaged/64/vanuatoo_2.10-0.1.jar b/unmanaged/64/vanuatoo_2.10-0.1.jar new file mode 120000 index 0000000000000000000000000000000000000000..5148a49942a1aed689b0dc4d5b6ec5729aef14e1 --- /dev/null +++ b/unmanaged/64/vanuatoo_2.10-0.1.jar @@ -0,0 +1 @@ +../common/vanuatoo_2.10-0.1.jar \ No newline at end of file diff --git a/unmanaged/common/vanuatoo_2.10-0.1.jar b/unmanaged/common/vanuatoo_2.10-0.1.jar new file mode 100644 index 0000000000000000000000000000000000000000..47eb1687a76e47f9d117b13b647da86d1eb8b340 Binary files /dev/null and b/unmanaged/common/vanuatoo_2.10-0.1.jar differ