diff --git a/src/main/scala/leon/codegen/CodeGenPhase.scala b/src/main/scala/leon/codegen/CodeGenPhase.scala index 4a6110d893d190e111f901e1de4850831f5dbd8f..d905b179c8376e2bf47f9ac8551f35ccb49345b6 100644 --- a/src/main/scala/leon/codegen/CodeGenPhase.scala +++ b/src/main/scala/leon/codegen/CodeGenPhase.scala @@ -3,6 +3,8 @@ package leon package codegen +import scala.util.control.NonFatal + import purescala.Common._ import purescala.Definitions._ @@ -15,14 +17,12 @@ object CodeGenPhase extends LeonPhase[Program,CompilationResult] { val description = "Compiles a Leon program into Java methods" def run(ctx : LeonContext)(p : Program) : CompilationResult = { - import CodeGeneration._ - - CompilationUnit.compileProgram(p) match { - case Some(unit) => - //unit.writeClassFiles() - CompilationResult(successful = true) - case None => - CompilationResult(successful = false) + try { + val unit = new CompilationUnit(ctx, p); + unit.writeClassFiles() + CompilationResult(successful = true) + } catch { + case NonFatal(e) => CompilationResult(successful = false) } } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index c3b799b5b3230c26eda72017c0a6a1e12bd9992e..1a846991f3570ebf07b38c1498d0d87a16d085f3 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -15,22 +15,38 @@ import cafebabe.ClassFileTypes._ import cafebabe.Defaults.constructorName import cafebabe.Flags._ -object CodeGeneration { - private val BoxedIntClass = "java/lang/Integer" - private val BoxedBoolClass = "java/lang/Boolean" - - private val TupleClass = "leon/codegen/runtime/Tuple" - private val SetClass = "leon/codegen/runtime/Set" - private val MapClass = "leon/codegen/runtime/Map" - private val CaseClassClass = "leon/codegen/runtime/CaseClass" - private val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" - private val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" - private val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" - private[codegen] val MonitorClass = "leon/codegen/runtime/LeonCodeGenRuntimeMonitor" - - def defToJVMName(d : Definition)(implicit env : CompilationEnvironment) : String = "Leon$CodeGen$" + d.id.uniqueName - - def typeToJVM(tpe : TypeTree)(implicit env : CompilationEnvironment) : String = tpe match { +trait CodeGeneration { + self: CompilationUnit => + + case class Locals(vars: Map[Identifier, Int]) { + def varToLocal(v: Identifier): Option[Int] = vars.get(v) + + def withVars(newVars: Map[Identifier, Int]) = { + Locals(vars ++ newVars) + } + + def withVar(nv: (Identifier, Int)) = { + Locals(vars + nv) + } + } + + object NoLocals extends Locals(Map()) + + private[codegen] val BoxedIntClass = "java/lang/Integer" + private[codegen] val BoxedBoolClass = "java/lang/Boolean" + private[codegen] val TupleClass = "leon/codegen/runtime/Tuple" + private[codegen] val SetClass = "leon/codegen/runtime/Set" + private[codegen] val MapClass = "leon/codegen/runtime/Map" + private[codegen] val CaseClassClass = "leon/codegen/runtime/CaseClass" + private[codegen] val ErrorClass = "leon/codegen/runtime/LeonCodeGenRuntimeException" + private[codegen] val ImpossibleEvaluationClass = "leon/codegen/runtime/LeonCodeGenEvaluationException" + private[codegen] val HashingClass = "leon/codegen/runtime/LeonCodeGenRuntimeHashing" + private[codegen] val ChooseEntryPointClass = "leon/codegen/runtime/ChooseEntryPoint" + private[codegen] val MonitorClass = "leon/codegen/runtime/LeonCodeGenRuntimeMonitor" + + def defToJVMName(d : Definition) : String = "Leon$CodeGen$" + d.id.uniqueName + + def typeToJVM(tpe : TypeTree) : String = tpe match { case Int32Type => "I" case BooleanType => "Z" @@ -38,7 +54,7 @@ object CodeGeneration { case UnitType => "Z" case c : ClassType => - env.classDefToClass(c.classDef).map(n => "L" + n + ";").getOrElse("Unsupported class " + c.id) + leonClassToJVMClass(c.classDef).map(n => "L" + n + ";").getOrElse("Unsupported class " + c.id) case _ : TupleType => "L" + TupleClass + ";" @@ -57,8 +73,8 @@ object CodeGeneration { // Assumes the CodeHandler has never received any bytecode. // Generates method body, and freezes the handler at the end. - def compileFunDef(funDef : FunDef, ch : CodeHandler)(implicit env : CompilationEnvironment) { - val newMapping = if (env.params.requireMonitor) { + def compileFunDef(funDef : FunDef, ch : CodeHandler) { + val newMapping = if (params.requireMonitor) { funDef.args.map(_.id).zipWithIndex.toMap.mapValues(_ + 1) } else { funDef.args.map(_.id).zipWithIndex.toMap @@ -66,13 +82,13 @@ object CodeGeneration { val body = funDef.body.getOrElse(throw CompilationException("Can't compile a FunDef without body")) - val bodyWithPre = if(funDef.hasPrecondition && env.params.checkContracts) { + val bodyWithPre = if(funDef.hasPrecondition && params.checkContracts) { IfExpr(funDef.precondition.get, body, Error("Precondition failed")) } else { body } - val bodyWithPost = if(funDef.hasPostcondition && env.params.checkContracts) { + val bodyWithPost = if(funDef.hasPostcondition && params.checkContracts) { val Some((id, post)) = funDef.postcondition Let(id, bodyWithPre, IfExpr(post, Variable(id), Error("Postcondition failed")) ) } else { @@ -81,11 +97,11 @@ object CodeGeneration { val exprToCompile = purescala.TreeOps.matchToIfThenElse(bodyWithPost) - if (env.params.recordInvocations) { + if (params.recordInvocations) { ch << ALoad(0) << InvokeVirtual(MonitorClass, "onInvoke", "()V") } - mkExpr(exprToCompile, ch)(env.withVars(newMapping)) + mkExpr(exprToCompile, ch)(Locals(newMapping)) funDef.returnType match { case Int32Type | BooleanType | UnitType => @@ -101,7 +117,7 @@ object CodeGeneration { ch.freeze } - private[codegen] def mkExpr(e : Expr, ch : CodeHandler, canDelegateToMkBranch : Boolean = true)(implicit env : CompilationEnvironment) { + private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)(implicit locals: Locals) { e match { case Variable(id) => val slot = slotFor(id) @@ -119,7 +135,7 @@ object CodeGeneration { case _ => AStore(slot) } ch << instr - mkExpr(b, ch)(env.withVars(Map(i -> slot))) + mkExpr(b, ch)(locals.withVar(i -> slot)) case LetTuple(is,d,b) => mkExpr(d, ch) // the tuple @@ -138,7 +154,7 @@ object CodeGeneration { count += 1 } ch << POP - mkExpr(b, ch)(env.withVars(withSlots.toMap)) + mkExpr(b, ch)(locals.withVars(withSlots.toMap)) case IntLiteral(v) => ch << Ldc(v) @@ -151,7 +167,7 @@ object CodeGeneration { // Case classes case CaseClass(ccd, as) => - val ccName = env.classDefToClass(ccd).getOrElse { + val ccName = leonClassToJVMClass(ccd).getOrElse { throw CompilationException("Unknown class : " + ccd.id) } // TODO FIXME It's a little ugly that we do it each time. Could be in env. @@ -163,7 +179,7 @@ object CodeGeneration { ch << InvokeSpecial(ccName, constructorName, consSig) case CaseClassInstanceOf(ccd, e) => - val ccName = env.classDefToClass(ccd).getOrElse { + val ccName = leonClassToJVMClass(ccd).getOrElse { throw CompilationException("Unknown class : " + ccd.id) } mkExpr(e, ch) @@ -171,7 +187,7 @@ object CodeGeneration { case CaseClassSelector(ccd, e, sid) => mkExpr(e, ch) - val ccName = env.classDefToClass(ccd).getOrElse { + val ccName = leonClassToJVMClass(ccd).getOrElse { throw CompilationException("Unknown class : " + ccd.id) } ch << CheckCast(ccName) @@ -275,10 +291,10 @@ object CodeGeneration { ch << Label(al) case FunctionInvocation(fd, as) => - val (cn, mn, ms) = env.funDefToMethod(fd).getOrElse { + val (cn, mn, ms) = leonFunDefToJVMInfo(fd).getOrElse { throw CompilationException("Unknown method : " + fd.id) } - if (env.params.requireMonitor) { + if (params.requireMonitor) { ch << ALoad(0) } for(a <- as) { @@ -351,11 +367,27 @@ object CodeGeneration { ch << InvokeSpecial(ErrorClass, constructorName, "(Ljava/lang/String;)V") ch << ATHROW - case Choose(_, _) => - ch << New(ImpossibleEvaluationClass) << DUP - ch << Ldc("Cannot execute choose.") - ch << InvokeSpecial(ImpossibleEvaluationClass, constructorName, "(Ljava/lang/String;)V") - ch << ATHROW + case choose @ Choose(_, _) => + val prob = synthesis.Problem.fromChoose(choose) + + val id = runtime.ChooseEntryPoint.register(prob, this); + ch << Ldc(id) + + + ch << Ldc(prob.as.size) + ch << NewArray("java/lang/Object") + + for ((id, i) <- prob.as.zipWithIndex) { + ch << DUP + ch << Ldc(i) + mkExpr(Variable(id), ch) + mkBox(id.getType, ch) + ch << AASTORE + } + + ch << InvokeStatic(ChooseEntryPointClass, "invoke", "(I[Ljava/lang/Object;)Ljava/lang/Object;") + + mkUnbox(choose.getType, ch) case b if b.getType == BooleanType && canDelegateToMkBranch => val fl = ch.getFreshLabel("boolfalse") @@ -369,7 +401,7 @@ object CodeGeneration { } // Leaves on the stack a value equal to `e`, always of a type compatible with java.lang.Object. - private[codegen] def mkBoxedExpr(e : Expr, ch : CodeHandler)(implicit env : CompilationEnvironment) { + private[codegen] def mkBoxedExpr(e: Expr, ch: CodeHandler)(implicit locals: Locals) { e.getType match { case Int32Type => ch << New(BoxedIntClass) << DUP @@ -388,7 +420,7 @@ object CodeGeneration { // Assumes the top of the stack contains of value of the right type, and makes it // compatible with java.lang.Object. - private[codegen] def mkBox(tpe : TypeTree, ch : CodeHandler)(implicit env : CompilationEnvironment) { + private[codegen] def mkBox(tpe: TypeTree, ch: CodeHandler)(implicit locals: Locals) { tpe match { case Int32Type => ch << New(BoxedIntClass) << DUP_X1 << SWAP << InvokeSpecial(BoxedIntClass, constructorName, "(I)V") @@ -401,7 +433,7 @@ object CodeGeneration { } // Assumes that the top of the stack contains a value that should be of type `tpe`, and unboxes it to the right (JVM) type. - private[codegen] def mkUnbox(tpe : TypeTree, ch : CodeHandler)(implicit env : CompilationEnvironment) { + private[codegen] def mkUnbox(tpe: TypeTree, ch: CodeHandler)(implicit locals: Locals) { tpe match { case Int32Type => ch << CheckCast(BoxedIntClass) << InvokeVirtual(BoxedIntClass, "intValue", "()I") @@ -410,7 +442,7 @@ object CodeGeneration { ch << CheckCast(BoxedBoolClass) << InvokeVirtual(BoxedBoolClass, "booleanValue", "()Z") case ct : ClassType => - val cn = env.classDefToClass(ct.classDef).getOrElse { + val cn = leonClassToJVMClass(ct.classDef).getOrElse { throw new CompilationException("Unsupported class type : " + ct) } ch << CheckCast(cn) @@ -429,7 +461,7 @@ object CodeGeneration { } } - private[codegen] def mkBranch(cond : Expr, thenn : String, elze : String, ch : CodeHandler, canDelegateToMkExpr : Boolean = true)(implicit env : CompilationEnvironment) { + private[codegen] def mkBranch(cond: Expr, thenn: String, elze: String, ch: CodeHandler, canDelegateToMkExpr: Boolean = true)(implicit locals: Locals) { cond match { case BooleanLiteral(true) => ch << Goto(thenn) @@ -503,16 +535,17 @@ object CodeGeneration { } } - private[codegen] def slotFor(id : Identifier)(implicit env : CompilationEnvironment) : Int = { - env.varToLocal(id).getOrElse { - throw CompilationException("Unknown variable : " + id) + private[codegen] def slotFor(id: Identifier)(implicit locals: Locals) : Int = { + locals.varToLocal(id).getOrElse { + throw CompilationException("Unknown variable: " + id) } } - def compileAbstractClassDef(acd : AbstractClassDef)(implicit env : CompilationEnvironment) : ClassFile = { + def compileAbstractClassDef(acd : AbstractClassDef) { val cName = defToJVMName(acd) - val cf = new ClassFile(cName, None) + val cf = classes(acd) + cf.setFlags(( CLASS_ACC_SUPER | CLASS_ACC_PUBLIC | @@ -522,8 +555,6 @@ object CodeGeneration { cf.addInterface(CaseClassClass) cf.addDefaultConstructor - - cf } /** @@ -531,11 +562,11 @@ object CodeGeneration { */ val instrumentedField = "__read" - def instrumentedGetField(ch: CodeHandler, ccd: CaseClassDef, id: Identifier)(implicit env : CompilationEnvironment): Unit = { + def instrumentedGetField(ch: CodeHandler, ccd: CaseClassDef, id: Identifier)(implicit locals: Locals): Unit = { ccd.fields.zipWithIndex.find(_._1.id == id) match { case Some((f, i)) => val cName = defToJVMName(ccd) - if (env.params.doInstrument) { + if (params.doInstrument) { ch << DUP << DUP ch << GetField(cName, instrumentedField, "I") ch << Ldc(1) @@ -550,12 +581,13 @@ object CodeGeneration { } } - def compileCaseClassDef(ccd : CaseClassDef)(implicit env : CompilationEnvironment) : ClassFile = { + def compileCaseClassDef(ccd: CaseClassDef) { val cName = defToJVMName(ccd) val pName = ccd.parent.map(parent => defToJVMName(parent)) - val cf = new ClassFile(cName, pName) + val cf = classes(ccd) + cf.setFlags(( CLASS_ACC_SUPER | CLASS_ACC_PUBLIC | @@ -569,10 +601,9 @@ object CodeGeneration { val namesTypes = ccd.fields.map { vd => (vd.id.name, typeToJVM(vd.tpe)) } // definition of the constructor - if(!env.params.doInstrument && ccd.fields.isEmpty) { + if(!params.doInstrument && ccd.fields.isEmpty) { cf.addDefaultConstructor } else { - for((nme, jvmt) <- namesTypes) { val fh = cf.addField(jvmt, nme) fh.setFlags(( @@ -581,7 +612,7 @@ object CodeGeneration { ).asInstanceOf[U2]) } - if (env.params.doInstrument) { + if (params.doInstrument) { val fh = cf.addField("I", instrumentedField) fh.setFlags(FIELD_ACC_PUBLIC) } @@ -591,7 +622,7 @@ object CodeGeneration { cch << ALoad(0) cch << InvokeSpecial(pName.getOrElse("java/lang/Object"), constructorName, "()V") - if (env.params.doInstrument) { + if (params.doInstrument) { cch << ALoad(0) cch << Ldc(0) cch << PutField(cName, instrumentedField, "I") @@ -655,8 +686,8 @@ object CodeGeneration { pech << DUP pech << Ldc(i) pech << ALoad(0) - instrumentedGetField(pech, ccd, f.id) - mkBox(f.tpe, pech) + instrumentedGetField(pech, ccd, f.id)(NoLocals) + mkBox(f.tpe, pech)(NoLocals) pech << AASTORE } @@ -690,9 +721,9 @@ object CodeGeneration { for(vd <- ccd.fields) { ech << ALoad(0) - instrumentedGetField(ech, ccd, vd.id) + instrumentedGetField(ech, ccd, vd.id)(NoLocals) ech << ALoad(castSlot) - instrumentedGetField(ech, ccd, vd.id) + instrumentedGetField(ech, ccd, vd.id)(NoLocals) typeToJVM(vd.id.getType) match { case "I" | "Z" => @@ -736,6 +767,5 @@ object CodeGeneration { hch.freeze } - cf } } diff --git a/src/main/scala/leon/codegen/CompilationEnvironment.scala b/src/main/scala/leon/codegen/CompilationEnvironment.scala deleted file mode 100644 index 21732db3d294ee4f1e2349de371d7c32620eb8b7..0000000000000000000000000000000000000000 --- a/src/main/scala/leon/codegen/CompilationEnvironment.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2009-2013 EPFL, Lausanne */ - -package leon -package codegen - -import purescala.Common._ -import purescala.Definitions._ - -abstract class CompilationEnvironment() { - self => - // Should contain: - // - a mapping of function defs to class + method name - // - a mapping of class defs to class names - // - a mapping of class fields to fields - - val program: Program - - val params: CodeGenParams - - // Returns (JVM) name of class, and signature of constructor - def classDefToClass(classDef : ClassTypeDef) : Option[String] - - // Return (JVM) name of enclosing class, name of method, and signature - def funDefToMethod(funDef : FunDef) : Option[(String,String,String)] - - def varToLocal(v : Identifier) : Option[Int] - - /** Augment the environment with new local var. mappings. */ - def withVars(pairs : Map[Identifier,Int]) = { - new CompilationEnvironment { - val program = self.program - val params = self.params - def classDefToClass(classDef : ClassTypeDef) = self.classDefToClass(classDef) - def funDefToMethod(funDef : FunDef) = self.funDefToMethod(funDef) - def varToLocal(v : Identifier) = pairs.get(v).orElse(self.varToLocal(v)) - } - } -} - -object CompilationEnvironment { - def fromProgram(p : Program, _params: CodeGenParams) : CompilationEnvironment = { - import CodeGeneration.typeToJVM - - // This should change: it should contain the case classes before - // we go and generate function signatures. - implicit val initial = new CompilationEnvironment { - val program = p - - val params = _params - - private val cNames : Map[ClassTypeDef,String] = - p.definedClasses.map(c => (c, CodeGeneration.defToJVMName(c)(this))).toMap - - def classDefToClass(classDef : ClassTypeDef) = cNames.get(classDef) - def funDefToMethod(funDef : FunDef) = None - def varToLocal(v : Identifier) = None - } - - val className = CodeGeneration.defToJVMName(p.mainObject) - - val fs = p.definedFunctions.filter(_.hasImplementation) - - val monitorType = if (_params.requireMonitor) { - "L" + CodeGeneration.MonitorClass + ";" - } else { - "" - } - - val fMap : Map[FunDef,(String,String,String)] = (fs.map { fd => - val sig = "(" + monitorType + fd.args.map(a => typeToJVM(a.tpe)).mkString("") + ")" + typeToJVM(fd.returnType) - - fd -> (className, fd.id.uniqueName, sig) - }).toMap - - new CompilationEnvironment { - val program = p - - val params = initial.params - - def classDefToClass(classDef : ClassTypeDef) = initial.classDefToClass(classDef) - def funDefToMethod(funDef : FunDef) = fMap.get(funDef) - def varToLocal(v : Identifier) = None - } - } -} diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index c612803e6eeb77cbd24be8193394313fe0045c71..1b3ce23c050ef05431ce2fe7f42f16b6ca3746c8 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -18,54 +18,95 @@ import scala.collection.JavaConverters._ import java.lang.reflect.Constructor -import CodeGeneration._ +class CompilationUnit(val ctx: LeonContext, + val program: Program, + val params: CodeGenParams = CodeGenParams()) extends CodeGeneration { -class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFile], implicit val env: CompilationEnvironment) { + val loader = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader) - val jvmClassToDef = classes.map { - case (d, cf) => cf.className -> d - }.toMap + var classes = Map[Definition, ClassFile]() - protected[codegen] val loader = { - val l = new CafebabeClassLoader(classOf[CompilationUnit].getClassLoader) - classes.values.foreach(l.register(_)) - l + def defineClass(df: Definition) { + val cName = defToJVMName(df) + + val cf = df match { + case ccd: CaseClassDef => + val pName = ccd.parent.map(parent => defToJVMName(parent)) + new ClassFile(cName, pName) + + case acd: AbstractClassDef => + new ClassFile(cName, None) + + case ob: ObjectDef => + new ClassFile(cName, None) + + case _ => + sys.error("Unhandled definition type") + } + + classes += df -> cf } - private val caseClassConstructors : Map[CaseClassDef,Constructor[_]] = { - (classes collect { - case (ccd : CaseClassDef, cf) => - val klass = loader.loadClass(cf.className) - // This is a hack: we pick the constructor with the most arguments. - val conss = klass.getConstructors().sortBy(_.getParameterTypes().length) - assert(!conss.isEmpty) - (ccd -> conss.last) - }).toMap + def jvmClassToLeonClass(name: String): Option[Definition] = { + classes.find(_._2.className == name).map(_._1) } - private lazy val tupleConstructor: Constructor[_] = { - val tc = loader.loadClass("leon.codegen.runtime.Tuple") - val conss = tc.getConstructors().sortBy(_.getParameterTypes().length) - assert(!conss.isEmpty) - conss.last + def leonClassToJVMClass(cd: Definition): Option[String] = { + classes.get(cd).map(_.className) } - private def writeClassFiles() { - for ((d, cl) <- classes) { - cl.writeToFile(cl.className + ".class") + // Returns className, methodName, methodSignature + private[this] var funDefInfo = Map[FunDef, (String, String, String)]() + + def leonFunDefToJVMInfo(fd: FunDef): Option[(String, String, String)] = { + funDefInfo.get(fd).orElse { + val monitorType = if (params.requireMonitor) "L"+MonitorClass+";" else "" + + val sig = "(" + monitorType + fd.args.map(a => typeToJVM(a.tpe)).mkString("") + ")" + typeToJVM(fd.returnType) + + leonClassToJVMClass(program.mainObject) match { + case Some(cn) => + val res = (cn, fd.id.uniqueName, sig) + funDefInfo += fd -> res + Some(res) + case None => + None + } } } - private var _nextExprId = 0 - private def nextExprId = { - _nextExprId += 1 - _nextExprId + // Get the Java constructor corresponding to the Case class + private[this] var ccdConstructors = Map[CaseClassDef, Constructor[_]]() + + private[this] def caseClassConstructor(ccd: CaseClassDef): Option[Constructor[_]] = { + ccdConstructors.get(ccd).orElse { + classes.get(ccd) match { + case Some(cf) => + val klass = loader.loadClass(cf.className) + // This is a hack: we pick the constructor with the most arguments. + val conss = klass.getConstructors().sortBy(_.getParameterTypes().length) + assert(!conss.isEmpty) + val cons = conss.last + + ccdConstructors += ccd -> cons + Some(cons) + case None => + None + } + } + } + + private[this] lazy val tupleConstructor: Constructor[_] = { + val tc = loader.loadClass("leon.codegen.runtime.Tuple") + val conss = tc.getConstructors().sortBy(_.getParameterTypes().length) + assert(!conss.isEmpty) + conss.last } // Currently, this method is only used to prepare arguments to reflective calls. // This means it is safe to return AnyRef (as opposed to primitive types), because // reflection needs this anyway. - private[codegen] def valueToJVM(e: Expr): AnyRef = e match { + private[codegen] def exprToJVM(e: Expr): AnyRef = e match { case IntLiteral(v) => new java.lang.Integer(v) @@ -73,16 +114,20 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi new java.lang.Boolean(v) case Tuple(elems) => - tupleConstructor.newInstance(elems.map(valueToJVM).toArray).asInstanceOf[AnyRef] + tupleConstructor.newInstance(elems.map(exprToJVM).toArray).asInstanceOf[AnyRef] case CaseClass(ccd, args) => - val cons = caseClassConstructors(ccd) - cons.newInstance(args.map(valueToJVM).toArray : _*).asInstanceOf[AnyRef] + caseClassConstructor(ccd) match { + case Some(cons) => + cons.newInstance(args.map(exprToJVM).toArray : _*).asInstanceOf[AnyRef] + case None => + ctx.reporter.fatalError("Case class constructor not found?!?") + } // For now, we only treat boolean arrays separately. // We have a use for these, mind you. case f @ FiniteArray(exprs) if f.getType == ArrayType(BooleanType) => - exprs.map(e => valueToJVM(e).asInstanceOf[java.lang.Boolean].booleanValue).toArray + exprs.map(e => exprToJVM(e).asInstanceOf[java.lang.Boolean].booleanValue).toArray // Just slightly overkill... case _ => @@ -90,7 +135,7 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi } // Note that this may produce untyped expressions! (typically: sets, maps) - private[codegen] def jvmToValue(e: AnyRef): Expr = e match { + private[codegen] def jvmToExpr(e: AnyRef): Expr = e match { case i: Integer => IntLiteral(i.toInt) @@ -100,26 +145,26 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi case cc: runtime.CaseClass => val fields = cc.productElements() - jvmClassToDef.get(e.getClass.getName) match { + jvmClassToLeonClass(e.getClass.getName) match { case Some(cc: CaseClassDef) => - CaseClass(cc, fields.map(jvmToValue)) + CaseClass(cc, fields.map(jvmToExpr)) case _ => throw CompilationException("Unsupported return value : " + e) } case tpl: runtime.Tuple => val elems = for (i <- 0 until tpl.getArity) yield { - jvmToValue(tpl.get(i)) + jvmToExpr(tpl.get(i)) } Tuple(elems) case set : runtime.Set => - FiniteSet(set.getElements().asScala.map(jvmToValue).toSeq) + FiniteSet(set.getElements().asScala.map(jvmToExpr).toSeq) case map : runtime.Map => val pairs = map.getElements().asScala.map { entry => - val k = jvmToValue(entry.getKey()) - val v = jvmToValue(entry.getValue()) + val k = jvmToExpr(entry.getKey()) + val v = jvmToExpr(entry.getValue()) (k, v) } FiniteMap(pairs.toSeq) @@ -133,7 +178,7 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi throw new IllegalArgumentException("Cannot compile untyped expression [%s].".format(e)) } - val id = nextExprId + val id = CompilationUnit.nextExprId val cName = "Leon$CodeGen$Expr$"+id @@ -147,8 +192,8 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi val argsTypes = args.map(a => typeToJVM(a.getType)) - val realArgs = if (env.params.requireMonitor) { - ("L" + CodeGeneration.MonitorClass + ";") +: argsTypes + val realArgs = if (params.requireMonitor) { + ("L" + MonitorClass + ";") +: argsTypes } else { argsTypes } @@ -167,7 +212,7 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi val ch = m.codeHandler - val newMapping = if (env.params.requireMonitor) { + val newMapping = if (params.requireMonitor) { args.zipWithIndex.toMap.mapValues(_ + 1) } else { args.zipWithIndex.toMap @@ -175,7 +220,7 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi val exprToCompile = purescala.TreeOps.matchToIfThenElse(e) - mkExpr(e, ch)(env.withVars(newMapping)) + mkExpr(e, ch)(Locals(newMapping)) e.getType match { case Int32Type | BooleanType => @@ -195,31 +240,8 @@ class CompilationUnit(val program: Program, val classes: Map[Definition, ClassFi new CompiledExpression(this, cf, e, args) } - // writeClassFiles -} - -object CompilationUnit { - def compileProgram(p: Program, params: CodeGenParams = CodeGenParams()): Option[CompilationUnit] = { - implicit val env = CompilationEnvironment.fromProgram(p, params) - - var classes = Map[Definition, ClassFile]() - - for((parent,children) <- p.algebraicDataTypes) { - classes += parent -> compileAbstractClassDef(parent) - - for (c <- children) { - classes += c -> compileCaseClassDef(c) - } - } - - for(single <- p.singleCaseClasses) { - classes += single -> compileCaseClassDef(single) - } - - val mainClassName = defToJVMName(p.mainObject) - val cf = new ClassFile(mainClassName, None) - - classes += p.mainObject -> cf + def compileMainObject() { + val cf = classes(program.mainObject) cf.addDefaultConstructor @@ -231,12 +253,13 @@ object CompilationUnit { // This assumes that all functions of a given program get compiled // as methods of a single class file. - for(funDef <- p.definedFunctions; - (_,mn,_) <- env.funDefToMethod(funDef)) { + for(funDef <- program.definedFunctions; + (_,mn,_) <- leonFunDefToJVMInfo(funDef)) { val argsTypes = funDef.args.map(a => typeToJVM(a.tpe)) - val realArgs = if (env.params.requireMonitor) { - ("L" + CodeGeneration.MonitorClass + ";") +: argsTypes + + val realArgs = if (params.requireMonitor) { + ("L" + MonitorClass + ";") +: argsTypes } else { argsTypes } @@ -253,8 +276,62 @@ object CompilationUnit { ).asInstanceOf[U2]) compileFunDef(funDef, m.codeHandler) + + } + } + + + def init() { + // First define all classes + for ((parent, children) <- program.algebraicDataTypes) { + defineClass(parent) + + for (c <- children) { + defineClass(c) + } } - Some(new CompilationUnit(p, classes, env)) + for(single <- program.singleCaseClasses) { + defineClass(single) + } + + defineClass(program.mainObject) } + + def compile() { + // Compile everything + for ((parent, children) <- program.algebraicDataTypes) { + compileAbstractClassDef(parent) + + for (c <- children) { + compileCaseClassDef(c) + } + } + + for(single <- program.singleCaseClasses) { + compileCaseClassDef(single) + } + + compileMainObject() + + classes.values.foreach(loader.register _) + } + + def writeClassFiles() { + for ((d, cl) <- classes) { + cl.writeToFile(cl.className + ".class") + } + } + + init() + compile() } + +object CompilationUnit { + private var _nextExprId = 0 + private def nextExprId = { + _nextExprId += 1 + _nextExprId + } +} + diff --git a/src/main/scala/leon/codegen/CompiledExpression.scala b/src/main/scala/leon/codegen/CompiledExpression.scala index 0f4c8f855d8230d3ac885198693dc93303258828..4c2035921d2138eaecf09b39ea8a0989393b576b 100644 --- a/src/main/scala/leon/codegen/CompiledExpression.scala +++ b/src/main/scala/leon/codegen/CompiledExpression.scala @@ -24,10 +24,10 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr private val exprType = expression.getType - private val params = unit.env.params + private val params = unit.params def argsToJVM(args: Seq[Expr]): Seq[AnyRef] = { - args.map(unit.valueToJVM) + args.map(unit.exprToJVM) } def evalToJVM(args: Seq[AnyRef]): AnyRef = { @@ -50,7 +50,7 @@ class CompiledExpression(unit: CompilationUnit, cf: ClassFile, expression : Expr // We also need to reattach a type in some cases (sets, maps). def evalFromJVM(args: Seq[AnyRef]) : Expr = { try { - val result = unit.jvmToValue(evalToJVM(args)) + val result = unit.jvmToExpr(evalToJVM(args)) if(!result.isTyped) { result.setType(exprType) } diff --git a/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala new file mode 100644 index 0000000000000000000000000000000000000000..c5ed3070e4e2d7526d5ec755f55eb6230325b1ed --- /dev/null +++ b/src/main/scala/leon/codegen/runtime/ChooseEntryPoint.scala @@ -0,0 +1,82 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package codegen.runtime + +import utils._ +import purescala.Common._ +import purescala.Definitions._ +import purescala.Trees.{Tuple => LeonTuple, _} +import purescala.TreeOps.valuateWithModel +import purescala.TypeTrees._ +import solvers.z3._ + +import codegen.CompilationUnit + +import scala.collection.immutable.{Map => ScalaMap} + +import synthesis._ + +object ChooseEntryPoint { + private[this] var map = ScalaMap[Int, (Problem, CompilationUnit)]() + + implicit val debugSection = DebugSectionSynthesis + + def register(p: Problem, unit: CompilationUnit): Int = { + val stored = (p, unit) + val hash = stored.## + + map += hash -> stored + + hash + } + + def invoke(i: Int, inputs: Array[AnyRef]): java.lang.Object = { + val (p, unit) = map(i) + + val program = unit.program + val ctx = unit.ctx + + ctx.reporter.debug("Executing choose!") + + val tStart = System.currentTimeMillis; + + val solver = new FairZ3Solver(ctx, program).setTimeout(10000L) + + val inputsMap = (p.as zip inputs).map { + case (id, v) => + Equals(Variable(id), unit.jvmToExpr(v)) + } + + solver.assertCnstr(And(Seq(p.pc, p.phi) ++ inputsMap)) + + try { + solver.check match { + case Some(true) => + val model = solver.getModel; + + val valModel = valuateWithModel(model) _ + + val res = p.xs.map(valModel) + val leonRes = if (res.size > 1) { + LeonTuple(res) + } else { + res(0) + } + + val total = System.currentTimeMillis-tStart; + + ctx.reporter.debug("Synthesis took "+total+"ms") + ctx.reporter.debug("Finished synthesis with "+leonRes) + + unit.exprToJVM(leonRes) + case Some(false) => + throw new LeonCodeGenRuntimeException("Constraint is UNSAT") + case _ => + throw new LeonCodeGenRuntimeException("Timeout exceeded") + } + } finally { + solver.free() + } + } +} diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 7a579b2373c2cea2a0238e23fe7c6333cb90d6ab..6226e9663ff07276b46ad4279ccb51338756ca08 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -16,7 +16,7 @@ import vanuatoo.{Pattern => VPattern, _} import evaluators._ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { - val unit = CompilationUnit.compileProgram(p).get + val unit = new CompilationUnit(ctx, p) val ints = (for (i <- Set(0, 1, 2, 3)) yield { i -> Constructor[Expr, TypeTree](List(), Int32Type, s => IntLiteral(i), ""+i) @@ -86,7 +86,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case (cc: codegen.runtime.CaseClass, ct: ClassType) => val r = cc.__getRead() - unit.jvmClassToDef.get(cc.getClass.getName) match { + unit.jvmClassToLeonClass(cc.getClass.getName) match { case Some(ccd: CaseClassDef) => val c = ct match { case act : AbstractClassType => diff --git a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala index 6d290c56955bf9c238024a70b5bc7975395bbd7d..12d864cef45ef5f8d221b2d74dcbf6ce31d00ce7 100644 --- a/src/main/scala/leon/evaluators/CodeGenEvaluator.scala +++ b/src/main/scala/leon/evaluators/CodeGenEvaluator.scala @@ -18,12 +18,10 @@ class CodeGenEvaluator(ctx : LeonContext, val unit : CompilationUnit) extends Ev /** Another constructor to make it look more like other `Evaluator`s. */ def this(ctx : LeonContext, prog : Program, params: CodeGenParams = CodeGenParams()) { - this(ctx, CompilationUnit.compileProgram(prog, params).get) // this .get is dubious... + this(ctx, new CompilationUnit(ctx, prog, params)) } def eval(expression : Expr, mapping : Map[Identifier,Expr]) : EvaluationResult = { - // ctx.reporter.warning("Using `eval` in CodeGenEvaluator is discouraged. Use `compile` whenever applicable.") - val toPairs = mapping.toSeq compile(expression, toPairs.map(_._1)).map(e => e(toPairs.map(_._2))).getOrElse(EvaluationResults.EvaluatorError("Couldn't compile expression.")) } diff --git a/src/main/scala/leon/evaluators/DefaultEvaluator.scala b/src/main/scala/leon/evaluators/DefaultEvaluator.scala index dd55f2b5d65c333b575a54605d685ba0eda596f4..259ecd85ecdadd48d91f1c22395f3020558e00f0 100644 --- a/src/main/scala/leon/evaluators/DefaultEvaluator.scala +++ b/src/main/scala/leon/evaluators/DefaultEvaluator.scala @@ -4,322 +4,19 @@ package leon package evaluators import purescala.Common._ -import purescala.Definitions._ -import purescala.TreeOps._ import purescala.Trees._ -import purescala.TypeTrees._ - -import xlang.Trees._ - -class DefaultEvaluator(ctx : LeonContext, prog : Program) extends Evaluator(ctx, prog) { - val name = "evaluator" - val description = "Recursive interpreter for PureScala expressions" - - private def typeErrorMsg(tree : Expr, expected : TypeTree) : String = "Type error : expected %s, found %s.".format(expected, tree) - private case class EvalError(msg : String) extends Exception - private case class RuntimeError(msg : String) extends Exception - - private val maxSteps = 50000 - - def eval(expression: Expr, mapping : Map[Identifier,Expr]) : EvaluationResults.Result = { - var left: Int = maxSteps - - def rec(ctx: Map[Identifier,Expr], expr: Expr) : Expr = if(left <= 0) { - throw RuntimeError("Diverging computation.") - } else { - // println("Step on : " + expr) - // println(ctx) - left -= 1 - expr match { - case Variable(id) => { - if(ctx.isDefinedAt(id)) { - val res = ctx(id) - if(!isGround(res)) { - throw EvalError("Substitution for identifier " + id.name + " is not ground.") - } else { - res - } - } else { - throw EvalError("No value for identifier " + id.name + " in mapping.") - } - } - case Tuple(ts) => { - val tsRec = ts.map(rec(ctx, _)) - Tuple(tsRec) - } - case TupleSelect(t, i) => { - val Tuple(rs) = rec(ctx, t) - rs(i-1) - } - case Let(i,e,b) => { - val first = rec(ctx, e) - rec(ctx + ((i -> first)), b) - } - case Error(desc) => throw RuntimeError("Error reached in evaluation: " + desc) - case IfExpr(cond, thenn, elze) => { - val first = rec(ctx, cond) - first match { - case BooleanLiteral(true) => rec(ctx, thenn) - case BooleanLiteral(false) => rec(ctx, elze) - case _ => throw EvalError(typeErrorMsg(first, BooleanType)) - } - } - case Waypoint(_, arg) => rec(ctx, arg) - case FunctionInvocation(fd, args) => { - val evArgs = args.map(a => rec(ctx, a)) - // build a mapping for the function... - val frame = Map[Identifier,Expr]((fd.args.map(_.id) zip evArgs) : _*) - - if(fd.hasPrecondition) { - rec(frame, matchToIfThenElse(fd.precondition.get)) match { - case BooleanLiteral(true) => ; - case BooleanLiteral(false) => { - throw RuntimeError("Precondition violation for " + fd.id.name + " reached in evaluation.: " + fd.precondition.get) - } - case other => throw RuntimeError(typeErrorMsg(other, BooleanType)) - } - } - - if(!fd.hasBody && !mapping.isDefinedAt(fd.id)) { - throw EvalError("Evaluation of function with unknown implementation.") - } - val body = fd.body.getOrElse(mapping(fd.id)) - val callResult = rec(frame, matchToIfThenElse(body)) - - if(fd.hasPostcondition) { - val (id, post) = fd.postcondition.get - - val freshResID = FreshIdentifier("result").setType(fd.returnType) - val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post)) - rec(frame + ((freshResID -> callResult)), postBody) match { - case BooleanLiteral(true) => ; - case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + fd.id.name + " reached in evaluation.") - case other => throw EvalError(typeErrorMsg(other, BooleanType)) - } - } - - callResult - } - case And(args) if args.isEmpty => BooleanLiteral(true) - case And(args) => { - rec(ctx, args.head) match { - case BooleanLiteral(false) => BooleanLiteral(false) - case BooleanLiteral(true) => rec(ctx, And(args.tail)) - case other => throw EvalError(typeErrorMsg(other, BooleanType)) - } - } - case Or(args) if args.isEmpty => BooleanLiteral(false) - case Or(args) => { - rec(ctx, args.head) match { - case BooleanLiteral(true) => BooleanLiteral(true) - case BooleanLiteral(false) => rec(ctx, Or(args.tail)) - case other => throw EvalError(typeErrorMsg(other, BooleanType)) - } - } - case Not(arg) => rec(ctx, arg) match { - case BooleanLiteral(v) => BooleanLiteral(!v) - case other => throw EvalError(typeErrorMsg(other, BooleanType)) - } - case Implies(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (BooleanLiteral(b1),BooleanLiteral(b2)) => BooleanLiteral(!b1 || b2) - case (le,re) => throw EvalError(typeErrorMsg(le, BooleanType)) - } - case Iff(le,re) => (rec(ctx,le),rec(ctx,re)) match { - case (BooleanLiteral(b1),BooleanLiteral(b2)) => BooleanLiteral(b1 == b2) - case _ => throw EvalError(typeErrorMsg(le, BooleanType)) - } - case Equals(le,re) => { - val lv = rec(ctx,le) - val rv = rec(ctx,re) - - (lv,rv) match { - case (FiniteSet(el1),FiniteSet(el2)) => BooleanLiteral(el1.toSet == el2.toSet) - case (FiniteMap(el1),FiniteMap(el2)) => BooleanLiteral(el1.toSet == el2.toSet) - case _ => BooleanLiteral(lv == rv) - } - } - case CaseClass(cd, args) => CaseClass(cd, args.map(rec(ctx,_))) - case CaseClassInstanceOf(cd, expr) => { - val le = rec(ctx,expr) - BooleanLiteral(le.getType match { - case CaseClassType(cd2) if cd2 == cd => true - case _ => false - }) - } - case CaseClassSelector(cd, expr, sel) => { - val le = rec(ctx, expr) - le match { - case CaseClass(cd2, args) if cd == cd2 => args(cd.selectorID2Index(sel)) - case _ => throw EvalError(typeErrorMsg(le, CaseClassType(cd))) - } - } - case Plus(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 + i2) - case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) - } - case Minus(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 - i2) - case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) - } - case UMinus(e) => rec(ctx,e) match { - case IntLiteral(i) => IntLiteral(-i) - case re => throw EvalError(typeErrorMsg(re, Int32Type)) - } - case Times(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 * i2) - case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) - } - case Division(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => - if(i2 != 0) IntLiteral(i1 / i2) else throw RuntimeError("Division by 0.") - case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) - } - case Modulo(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => - if(i2 != 0) IntLiteral(i1 % i2) else throw RuntimeError("Modulo by 0.") - case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) - } - case LessThan(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 < i2) - case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) - } - case GreaterThan(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 > i2) - case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) - } - case LessEquals(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 <= i2) - case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) - } - case GreaterEquals(l,r) => (rec(ctx,l), rec(ctx,r)) match { - case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 >= i2) - case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) - } - - case SetUnion(s1,s2) => (rec(ctx,s1), rec(ctx,s2)) match { - case (f@FiniteSet(els1),FiniteSet(els2)) => FiniteSet((els1 ++ els2).distinct).setType(f.getType) - case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) - } - case SetIntersection(s1,s2) => (rec(ctx,s1), rec(ctx,s2)) match { - case (f@FiniteSet(els1),FiniteSet(els2)) => { - val newElems = (els1.toSet intersect els2.toSet).toSeq - val baseType = f.getType.asInstanceOf[SetType].base - FiniteSet(newElems).setType(f.getType) - } - case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) - } - case SetDifference(s1,s2) => (rec(ctx,s1), rec(ctx,s2)) match { - case (f@FiniteSet(els1),FiniteSet(els2)) => { - val newElems = (els1.toSet -- els2.toSet).toSeq - val baseType = f.getType.asInstanceOf[SetType].base - FiniteSet(newElems).setType(f.getType) - } - case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) - } - case ElementOfSet(el,s) => (rec(ctx,el), rec(ctx,s)) match { - case (e, f @ FiniteSet(els)) => BooleanLiteral(els.contains(e)) - case (l,r) => throw EvalError(typeErrorMsg(r, SetType(l.getType))) - } - case SubsetOf(s1,s2) => (rec(ctx,s1), rec(ctx,s2)) match { - case (f@FiniteSet(els1),FiniteSet(els2)) => BooleanLiteral(els1.toSet.subsetOf(els2.toSet)) - case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) - } - case SetCardinality(s) => { - val sr = rec(ctx, s) - sr match { - case FiniteSet(els) => IntLiteral(els.size) - case _ => throw EvalError(typeErrorMsg(sr, SetType(AnyType))) - } - } - - case f @ FiniteSet(els) => FiniteSet(els.map(rec(ctx,_)).distinct).setType(f.getType) - case i @ IntLiteral(_) => i - case b @ BooleanLiteral(_) => b - case u @ UnitLiteral => u - - case f @ ArrayFill(length, default) => { - val rDefault = rec(ctx, default) - val rLength = rec(ctx, length) - val IntLiteral(iLength) = rLength - FiniteArray((1 to iLength).map(_ => rDefault).toSeq) - } - case ArrayLength(a) => { - var ra = rec(ctx, a) - while(!ra.isInstanceOf[FiniteArray]) - ra = ra.asInstanceOf[ArrayUpdated].array - IntLiteral(ra.asInstanceOf[FiniteArray].exprs.size) - } - case ArrayUpdated(a, i, v) => { - val ra = rec(ctx, a) - val ri = rec(ctx, i) - val rv = rec(ctx, v) - - val IntLiteral(index) = ri - val FiniteArray(exprs) = ra - FiniteArray(exprs.updated(index, rv)) - } - case ArraySelect(a, i) => { - val IntLiteral(index) = rec(ctx, i) - val FiniteArray(exprs) = rec(ctx, a) - try { - exprs(index) - } catch { - case e : IndexOutOfBoundsException => throw RuntimeError(e.getMessage) - } - } - case FiniteArray(exprs) => { - FiniteArray(exprs.map(e => rec(ctx, e))) - } - - case f @ FiniteMap(ss) => FiniteMap(ss.map{ case (k, v) => (rec(ctx, k), rec(ctx, v)) }.distinct).setType(f.getType) - case g @ MapGet(m,k) => (rec(ctx,m), rec(ctx,k)) match { - case (FiniteMap(ss), e) => ss.find(_._1 == e) match { - case Some((_, v0)) => v0 - case None => throw RuntimeError("Key not found: " + e) - } - case (l,r) => throw EvalError(typeErrorMsg(l, MapType(r.getType, g.getType))) - } - case u @ MapUnion(m1,m2) => (rec(ctx,m1), rec(ctx,m2)) match { - case (f1@FiniteMap(ss1), FiniteMap(ss2)) => { - val filtered1 = ss1.filterNot(s1 => ss2.exists(s2 => s2._1 == s1._1)) - val newSs = filtered1 ++ ss2 - FiniteMap(newSs).setType(f1.getType) - } - case (l, r) => throw EvalError(typeErrorMsg(l, m1.getType)) - } - case i @ MapIsDefinedAt(m,k) => (rec(ctx,m), rec(ctx,k)) match { - case (FiniteMap(ss), e) => BooleanLiteral(ss.exists(_._1 == e)) - case (l, r) => throw EvalError(typeErrorMsg(l, m.getType)) - } - case Distinct(args) => { - val newArgs = args.map(rec(ctx, _)) - BooleanLiteral(newArgs.distinct.size == newArgs.size) - } +import purescala.Definitions._ - case Choose(_, _) => - throw EvalError("Cannot evaluate choose.") +class DefaultEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog) { + type RC = DefaultRecContext + type GC = GlobalContext - case other => { - context.reporter.error("Error: don't know how to handle " + other + " in Evaluator.") - throw EvalError("Unhandled case in Evaluator : " + other) - } - } - } + def initRC(mappings: Map[Identifier, Expr]) = DefaultRecContext(mappings) + def initGC = new GlobalContext(stepsLeft = 50000) - try { - EvaluationResults.Successful(rec(mapping, expression)) - } catch { - case so: StackOverflowError => - EvaluationResults.EvaluatorError("Stack overflow") - case EvalError(msg) => - EvaluationResults.EvaluatorError(msg) - case RuntimeError(msg) => - EvaluationResults.RuntimeError(msg) - } - } + case class DefaultRecContext(mappings: Map[Identifier, Expr]) extends RecContext { + def withNewVar(id: Identifier, v: Expr) = copy(mappings + (id -> v)) - // quick and dirty.. don't overuse. - private def isGround(expr: Expr) : Boolean = { - variablesOf(expr) == Set.empty + def withVars(news: Map[Identifier, Expr]) = copy(news) } } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..ab78f94dcac73ebf335398305c9dc87313398ec3 --- /dev/null +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -0,0 +1,423 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package evaluators + +import purescala.Common._ +import purescala.Definitions._ +import purescala.TreeOps._ +import purescala.Trees._ +import purescala.TypeTrees._ + +import xlang.Trees._ + +abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evaluator(ctx, prog) { + val name = "evaluator" + val description = "Recursive interpreter for PureScala expressions" + + type RC <: RecContext + type GC <: GlobalContext + + case class EvalError(msg : String) extends Exception + case class RuntimeError(msg : String) extends Exception + + abstract class RecContext { + val mappings: Map[Identifier, Expr] + + def withNewVar(id: Identifier, v: Expr): RC; + + def withVars(news: Map[Identifier, Expr]): RC; + } + + class GlobalContext(var stepsLeft: Int) + + def initRC(mappings: Map[Identifier, Expr]): RC + def initGC: GC + + def eval(e: Expr, mappings: Map[Identifier, Expr]) = { + try { + EvaluationResults.Successful(se(e)(initRC(mappings), initGC)) + } catch { + case so: StackOverflowError => + EvaluationResults.EvaluatorError("Stack overflow") + case EvalError(msg) => + EvaluationResults.EvaluatorError(msg) + case RuntimeError(msg) => + EvaluationResults.RuntimeError(msg) + } + } + + def se(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = { + if (gctx.stepsLeft < 0) { + throw RuntimeError("Exceeded number of allocated steps") + } else { + gctx.stepsLeft -= 1 + e(expr) + } + } + + def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = expr match { + case Variable(id) => + rctx.mappings.get(id) match { + case Some(v) => + if(!isGround(v)) { + throw EvalError("Substitution for identifier " + id.name + " is not ground.") + } else { + v + } + case None => + throw EvalError("No value for identifier " + id.name + " in mapping.") + } + + case Tuple(ts) => + val tsRec = ts.map(se) + Tuple(tsRec) + + case TupleSelect(t, i) => + val Tuple(rs) = se(t) + rs(i-1) + + case Let(i,e,b) => + val first = se(e) + se(b)(rctx.withNewVar(i, first), gctx) + + case Error(desc) => + throw RuntimeError("Error reached in evaluation: " + desc) + + case IfExpr(cond, thenn, elze) => + val first = se(cond) + first match { + case BooleanLiteral(true) => se(thenn) + case BooleanLiteral(false) => se(elze) + case _ => throw EvalError(typeErrorMsg(first, BooleanType)) + } + + case FunctionInvocation(fd, args) => + val evArgs = args.map(a => se(a)) + + // build a mapping for the function... + val frame = rctx.withVars((fd.args.map(_.id) zip evArgs).toMap) + + if(fd.hasPrecondition) { + se(matchToIfThenElse(fd.precondition.get))(frame, gctx) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => + throw RuntimeError("Precondition violation for " + fd.id.name + " reached in evaluation.: " + fd.precondition.get) + case other => throw RuntimeError(typeErrorMsg(other, BooleanType)) + } + } + + if(!fd.hasBody && !rctx.mappings.isDefinedAt(fd.id)) { + throw EvalError("Evaluation of function with unknown implementation.") + } + + val body = fd.body.getOrElse(rctx.mappings(fd.id)) + val callResult = se(matchToIfThenElse(body))(frame, gctx) + + if(fd.hasPostcondition) { + val (id, post) = fd.postcondition.get + + val freshResID = FreshIdentifier("result").setType(fd.returnType) + val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post)) + + se(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + fd.id.name + " reached in evaluation.") + case other => throw EvalError(typeErrorMsg(other, BooleanType)) + } + } + + callResult + + case And(args) if args.isEmpty => + BooleanLiteral(true) + + case And(args) => + se(args.head) match { + case BooleanLiteral(false) => BooleanLiteral(false) + case BooleanLiteral(true) => se(And(args.tail)) + case other => throw EvalError(typeErrorMsg(other, BooleanType)) + } + + case Or(args) if args.isEmpty => BooleanLiteral(false) + case Or(args) => + se(args.head) match { + case BooleanLiteral(true) => BooleanLiteral(true) + case BooleanLiteral(false) => se(Or(args.tail)) + case other => throw EvalError(typeErrorMsg(other, BooleanType)) + } + + case Not(arg) => + se(arg) match { + case BooleanLiteral(v) => BooleanLiteral(!v) + case other => throw EvalError(typeErrorMsg(other, BooleanType)) + } + + case Implies(l,r) => + (se(l), se(r)) match { + case (BooleanLiteral(b1),BooleanLiteral(b2)) => BooleanLiteral(!b1 || b2) + case (le, re) => throw EvalError(typeErrorMsg(le, BooleanType)) + } + + case Iff(le,re) => + (se(le), se(re)) match { + case (BooleanLiteral(b1),BooleanLiteral(b2)) => BooleanLiteral(b1 == b2) + case _ => throw EvalError(typeErrorMsg(le, BooleanType)) + } + case Equals(le,re) => + val lv = se(le) + val rv = se(re) + + (lv,rv) match { + case (FiniteSet(el1),FiniteSet(el2)) => BooleanLiteral(el1.toSet == el2.toSet) + case (FiniteMap(el1),FiniteMap(el2)) => BooleanLiteral(el1.toSet == el2.toSet) + case _ => BooleanLiteral(lv == rv) + } + + case CaseClass(cd, args) => + CaseClass(cd, args.map(se(_))) + + case CaseClassInstanceOf(cd, expr) => + val le = se(expr) + BooleanLiteral(le.getType match { + case CaseClassType(cd2) if cd2 == cd => true + case _ => false + }) + + case CaseClassSelector(cd, expr, sel) => + val le = se(expr) + le match { + case CaseClass(cd2, args) if cd == cd2 => args(cd.selectorID2Index(sel)) + case _ => throw EvalError(typeErrorMsg(le, CaseClassType(cd))) + } + + case Plus(l,r) => + (se(l), se(r)) match { + case (IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 + i2) + case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case Minus(l,r) => + (se(l), se(r)) match { + case (IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 - i2) + case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case UMinus(e) => + se(e) match { + case IntLiteral(i) => IntLiteral(-i) + case re => throw EvalError(typeErrorMsg(re, Int32Type)) + } + + case Times(l,r) => + (se(l), se(r)) match { + case (IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 * i2) + case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case Division(l,r) => + (se(l), se(r)) match { + case (IntLiteral(i1), IntLiteral(i2)) => + if(i2 != 0) IntLiteral(i1 / i2) else throw RuntimeError("Division by 0.") + case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case Modulo(l,r) => + (se(l), se(r)) match { + case (IntLiteral(i1), IntLiteral(i2)) => + if(i2 != 0) IntLiteral(i1 % i2) else throw RuntimeError("Modulo by 0.") + case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + case LessThan(l,r) => + (se(l), se(r)) match { + case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 < i2) + case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case GreaterThan(l,r) => + (se(l), se(r)) match { + case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 > i2) + case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case LessEquals(l,r) => + (se(l), se(r)) match { + case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 <= i2) + case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case GreaterEquals(l,r) => + (se(l), se(r)) match { + case (IntLiteral(i1), IntLiteral(i2)) => BooleanLiteral(i1 >= i2) + case (le,re) => throw EvalError(typeErrorMsg(le, Int32Type)) + } + + case SetUnion(s1,s2) => + (se(s1), se(s2)) match { + case (f@FiniteSet(els1),FiniteSet(els2)) => FiniteSet((els1 ++ els2).distinct).setType(f.getType) + case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) + } + + case SetIntersection(s1,s2) => + (se(s1), se(s2)) match { + case (f @ FiniteSet(els1), FiniteSet(els2)) => { + val newElems = (els1.toSet intersect els2.toSet).toSeq + val baseType = f.getType.asInstanceOf[SetType].base + FiniteSet(newElems).setType(f.getType) + } + case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) + } + + case SetDifference(s1,s2) => + (se(s1), se(s2)) match { + case (f @ FiniteSet(els1),FiniteSet(els2)) => { + val newElems = (els1.toSet -- els2.toSet).toSeq + val baseType = f.getType.asInstanceOf[SetType].base + FiniteSet(newElems).setType(f.getType) + } + case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) + } + + case ElementOfSet(el,s) => (se(el), se(s)) match { + case (e, f @ FiniteSet(els)) => BooleanLiteral(els.contains(e)) + case (l,r) => throw EvalError(typeErrorMsg(r, SetType(l.getType))) + } + case SubsetOf(s1,s2) => (se(s1), se(s2)) match { + case (f@FiniteSet(els1),FiniteSet(els2)) => BooleanLiteral(els1.toSet.subsetOf(els2.toSet)) + case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) + } + case SetCardinality(s) => { + val sr = se(s) + sr match { + case FiniteSet(els) => IntLiteral(els.size) + case _ => throw EvalError(typeErrorMsg(sr, SetType(AnyType))) + } + } + + case f @ FiniteSet(els) => FiniteSet(els.map(se(_)).distinct).setType(f.getType) + case i @ IntLiteral(_) => i + case b @ BooleanLiteral(_) => b + case u @ UnitLiteral => u + + case f @ ArrayFill(length, default) => { + val rDefault = se(default) + val rLength = se(length) + val IntLiteral(iLength) = rLength + FiniteArray((1 to iLength).map(_ => rDefault).toSeq) + } + case ArrayLength(a) => { + var ra = se(a) + while(!ra.isInstanceOf[FiniteArray]) + ra = ra.asInstanceOf[ArrayUpdated].array + IntLiteral(ra.asInstanceOf[FiniteArray].exprs.size) + } + case ArrayUpdated(a, i, v) => { + val ra = se(a) + val ri = se(i) + val rv = se(v) + + val IntLiteral(index) = ri + val FiniteArray(exprs) = ra + FiniteArray(exprs.updated(index, rv)) + } + case ArraySelect(a, i) => { + val IntLiteral(index) = se(i) + val FiniteArray(exprs) = se(a) + try { + exprs(index) + } catch { + case e : IndexOutOfBoundsException => throw RuntimeError(e.getMessage) + } + } + case FiniteArray(exprs) => { + FiniteArray(exprs.map(e => se(e))) + } + + case f @ FiniteMap(ss) => FiniteMap(ss.map{ case (k, v) => (se(k), se(v)) }.distinct).setType(f.getType) + case g @ MapGet(m,k) => (se(m), se(k)) match { + case (FiniteMap(ss), e) => ss.find(_._1 == e) match { + case Some((_, v0)) => v0 + case None => throw RuntimeError("Key not found: " + e) + } + case (l,r) => throw EvalError(typeErrorMsg(l, MapType(r.getType, g.getType))) + } + case u @ MapUnion(m1,m2) => (se(m1), se(m2)) match { + case (f1@FiniteMap(ss1), FiniteMap(ss2)) => { + val filtered1 = ss1.filterNot(s1 => ss2.exists(s2 => s2._1 == s1._1)) + val newSs = filtered1 ++ ss2 + FiniteMap(newSs).setType(f1.getType) + } + case (l, r) => throw EvalError(typeErrorMsg(l, m1.getType)) + } + case i @ MapIsDefinedAt(m,k) => (se(m), se(k)) match { + case (FiniteMap(ss), e) => BooleanLiteral(ss.exists(_._1 == e)) + case (l, r) => throw EvalError(typeErrorMsg(l, m.getType)) + } + case Distinct(args) => { + val newArgs = args.map(se(_)) + BooleanLiteral(newArgs.distinct.size == newArgs.size) + } + + case choose: Choose => + import solvers.z3.FairZ3Solver + import purescala.TreeOps.simplestValue + + implicit val debugSection = DebugSectionSynthesis + + val p = synthesis.Problem.fromChoose(choose) + + ctx.reporter.debug("Executing choose!") + + val tStart = System.currentTimeMillis; + + val solver = new FairZ3Solver(ctx, program).setTimeout(10000L) + + val inputsMap = p.as.map { + case id => + Equals(Variable(id), rctx.mappings(id)) + } + + solver.assertCnstr(And(Seq(p.pc, p.phi) ++ inputsMap)) + + try { + solver.check match { + case Some(true) => + val model = solver.getModel; + + val valModel = valuateWithModel(model) _ + + val res = p.xs.map(valModel) + val leonRes = if (res.size > 1) { + Tuple(res) + } else { + res(0) + } + + val total = System.currentTimeMillis-tStart; + + ctx.reporter.debug("Synthesis took "+total+"ms") + ctx.reporter.debug("Finished synthesis with "+leonRes) + + leonRes + case Some(false) => + throw RuntimeError("Constraint is UNSAT") + case _ => + throw RuntimeError("Timeout exceeded") + } + } finally { + solver.free() + } + + case other => { + context.reporter.error("Error: don't know how to handle " + other + " in Evaluator.") + throw EvalError("Unhandled case in Evaluator : " + other) + } + } + + def typeErrorMsg(tree : Expr, expected : TypeTree) : String = "Type error : expected %s, found %s.".format(expected, tree) + + // quick and dirty.. don't overuse. + private def isGround(expr: Expr) : Boolean = { + variablesOf(expr) == Set.empty + } +} diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..183b6cf2f5f79772f51686fb6ae34318121a85fb --- /dev/null +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -0,0 +1,106 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package evaluators + +import purescala.Common._ +import purescala.Trees._ +import purescala.Definitions._ +import purescala.TreeOps._ +import purescala.TypeTrees._ + +class TracingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluator(ctx, prog) { + type RC = TracingRecContext + type GC = TracingGlobalContext + + var lastGlobalContext: Option[GC] = None + + def initRC(mappings: Map[Identifier, Expr]) = { + TracingRecContext(mappings, 2) + } + + def initGC = { + val gc = new TracingGlobalContext(stepsLeft = 50000, Nil) + lastGlobalContext = Some(gc) + gc + } + + class TracingGlobalContext(stepsLeft: Int, var values: List[(Expr, Expr)]) extends GlobalContext(stepsLeft) + + case class TracingRecContext(mappings: Map[Identifier, Expr], tracingFrames: Int) extends RecContext { + def withNewVar(id: Identifier, v: Expr) = copy(mappings = mappings + (id -> v)) + + def withVars(news: Map[Identifier, Expr]) = copy(mappings = news) + } + + override def e(expr: Expr)(implicit rctx: RC, gctx: GC): Expr = { + try { + val (res, recordedRes) = expr match { + case Let(i,e,b) => + // We record the value of the val at the position of Let, not the value of the body. + val first = se(e) + val res = se(b)(rctx.withNewVar(i, first), gctx) + (res, first) + + case fi @ FunctionInvocation(fd, args) => + + val evArgs = args.map(a => se(a)) + + // build a mapping for the function... + val frame = new TracingRecContext((fd.args.map(_.id) zip evArgs).toMap, rctx.tracingFrames-1) + + if(fd.hasPrecondition) { + se(matchToIfThenElse(fd.precondition.get))(frame, gctx) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => + throw RuntimeError("Precondition violation for " + fd.id.name + " reached in evaluation.: " + fd.precondition.get) + case other => throw RuntimeError(typeErrorMsg(other, BooleanType)) + } + } + + if(!fd.hasBody && !rctx.mappings.isDefinedAt(fd.id)) { + throw EvalError("Evaluation of function with unknown implementation.") + } + + val body = fd.body.getOrElse(rctx.mappings(fd.id)) + val callResult = se(matchToIfThenElse(body))(frame, gctx) + + if(fd.hasPostcondition) { + val (id, post) = fd.postcondition.get + + val freshResID = FreshIdentifier("result").setType(fd.returnType) + val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post)) + + se(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match { + case BooleanLiteral(true) => + case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + fd.id.name + " reached in evaluation.") + case other => throw EvalError(typeErrorMsg(other, BooleanType)) + } + } + + (callResult, callResult) + case _ => + val res = super.e(expr) + (res, res) + } + if (rctx.tracingFrames > 0) { + gctx.values ::= (expr -> recordedRes) + } + + res + } catch { + case ee @ EvalError(e) => + if (rctx.tracingFrames > 0) { + gctx.values ::= (expr -> Error(e)) + } + throw ee; + + case re @ RuntimeError(e) => + if (rctx.tracingFrames > 0) { + gctx.values ::= (expr -> Error(e)) + } + throw re; + } + } + +} diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 7b0a88a4e4fe08d113543121c4040ffaceed69cd..a31378248717a2b425a515dc56b421687e509194 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -58,7 +58,7 @@ trait CodeExtraction extends ASTExtractors { //This is a bit missleading, if an expr is not mapped then it has no owner, if it is mapped to None it means //that it can have any owner - private var owners: Map[LeonExpr, Option[FunDef]] = Map() + private var owners: Map[Identifier, Option[FunDef]] = Map() class Extraction(unit: CompilationUnit) { @@ -170,7 +170,7 @@ trait CodeExtraction extends ASTExtractors { ccd.fields = scalaClassArgs(sym).map{ case (name, asym) => val tpe = toPureScalaType(asym.tpe) - VarDecl(FreshIdentifier(name).setType(tpe), tpe) + VarDecl(FreshIdentifier(name).setType(tpe).setPos(asym.pos), tpe).setPos(asym.pos) } case _ => // no fields to set @@ -234,10 +234,10 @@ trait CodeExtraction extends ASTExtractors { private def extractFunSig(nameStr: String, params: Seq[ValDef], tpt: Tree): FunDef = { val newParams = params.map(p => { val ptpe = toPureScalaType(p.tpt.tpe) - val newID = FreshIdentifier(p.name.toString).setType(ptpe) - owners += (Variable(newID) -> None) + val newID = FreshIdentifier(p.name.toString).setType(ptpe).setPos(p.pos) + owners += (newID -> None) varSubsts(p.symbol) = (() => Variable(newID)) - VarDecl(newID, ptpe) + VarDecl(newID, ptpe).setPos(p.pos) }) new FunDef(FreshIdentifier(nameStr), toPureScalaType(tpt.tpe), newParams) } @@ -247,7 +247,7 @@ trait CodeExtraction extends ASTExtractors { val (body2, ensuring) = body match { case ExEnsuredExpression(body2, resSym, contract) => - val resId = FreshIdentifier(resSym.name.toString).setType(funDef.returnType) + val resId = FreshIdentifier(resSym.name.toString).setType(funDef.returnType).setPos(resSym.pos) varSubsts(resSym) = (() => Variable(resId)) (body2, toPureScala(contract).map(r => (resId, r))) @@ -326,29 +326,29 @@ trait CodeExtraction extends ASTExtractors { private def extractPattern(p: Tree, binder: Option[Identifier] = None): Pattern = p match { case b @ Bind(name, t @ Typed(pat, tpe)) => - val newID = FreshIdentifier(name.toString).setType(extractType(tpe.tpe)) + val newID = FreshIdentifier(name.toString).setType(extractType(tpe.tpe)).setPos(b.pos) varSubsts(b.symbol) = (() => Variable(newID)) extractPattern(t, Some(newID)) case b @ Bind(name, pat) => - val newID = FreshIdentifier(name.toString).setType(extractType(b.symbol.tpe)) + val newID = FreshIdentifier(name.toString).setType(extractType(b.symbol.tpe)).setPos(b.pos) varSubsts(b.symbol) = (() => Variable(newID)) extractPattern(pat, Some(newID)) case t @ Typed(Ident(nme.WILDCARD), tpe) if t.tpe.typeSymbol.isCase && classesToClasses.contains(t.tpe.typeSymbol) => val cd = classesToClasses(t.tpe.typeSymbol).asInstanceOf[CaseClassDef] - InstanceOfPattern(binder, cd) + InstanceOfPattern(binder, cd).setPos(p.pos) case Ident(nme.WILDCARD) => - WildcardPattern(binder) + WildcardPattern(binder).setPos(p.pos) case s @ Select(This(_), b) if s.tpe.typeSymbol.isCase && classesToClasses.contains(s.tpe.typeSymbol) => // case Obj => val cd = classesToClasses(s.tpe.typeSymbol).asInstanceOf[CaseClassDef] assert(cd.fields.size == 0) - CaseClassPattern(binder, cd, Seq()) + CaseClassPattern(binder, cd, Seq()).setPos(p.pos) case a @ Apply(fn, args) if fn.isType && a.tpe.typeSymbol.isCase && @@ -356,12 +356,12 @@ trait CodeExtraction extends ASTExtractors { val cd = classesToClasses(a.tpe.typeSymbol).asInstanceOf[CaseClassDef] assert(args.size == cd.fields.size) - CaseClassPattern(binder, cd, args.map(extractPattern(_))) + CaseClassPattern(binder, cd, args.map(extractPattern(_))).setPos(p.pos) case a @ Apply(fn, args) => extractType(a.tpe) match { case TupleType(argsTpes) => - TuplePattern(binder, args.map(extractPattern(_))) + TuplePattern(binder, args.map(extractPattern(_))).setPos(p.pos) case _ => unsupported(p, "Unsupported pattern") } @@ -375,7 +375,7 @@ trait CodeExtraction extends ASTExtractors { val recBody = extractTree(cd.body) if(cd.guard == EmptyTree) { - SimpleCase(recPattern, recBody) + SimpleCase(recPattern, recBody).setPos(cd.pos) } else { val recGuard = extractTree(cd.guard) @@ -384,7 +384,7 @@ trait CodeExtraction extends ASTExtractors { throw ImpureCodeEncounteredException(cd) } - GuardedCase(recPattern, recGuard, recBody) + GuardedCase(recPattern, recGuard, recBody).setPos(cd.pos) } } @@ -458,7 +458,7 @@ trait CodeExtraction extends ASTExtractors { if(valTree.getType.isInstanceOf[ArrayType]) { getOwner(valTree) match { case None => - owners += (Variable(newID) -> Some(currentFunDef)) + owners += (newID -> Some(currentFunDef)) case _ => unsupported(tr, "Cannot alias array") } @@ -507,7 +507,7 @@ trait CodeExtraction extends ASTExtractors { if(valTree.getType.isInstanceOf[ArrayType]) { getOwner(valTree) match { case None => - owners += (Variable(newID) -> Some(currentFunDef)) + owners += (newID -> Some(currentFunDef)) case Some(_) => unsupported(tr, "Cannot alias array") } @@ -613,10 +613,10 @@ trait CodeExtraction extends ASTExtractors { // TODO: refine type here? extractTree(e) - case ExIdentifier(sym,tpt) => varSubsts.get(sym) match { - case Some(fun) => fun() + case ex @ ExIdentifier(sym,tpt) => varSubsts.get(sym) match { + case Some(fun) => fun().setPos(ex.pos) case None => mutableVarSubsts.get(sym) match { - case Some(fun) => fun() + case Some(fun) => fun().setPos(ex.pos) case None => unsupported(tr, "Unidentified variable.") } @@ -628,7 +628,7 @@ trait CodeExtraction extends ASTExtractors { val vars = args map { case (tpe, sym) => val aTpe = extractType(tpe) val newID = FreshIdentifier(sym.name.toString).setType(aTpe) - owners += (Variable(newID) -> None) + owners += (newID -> None) varSubsts(sym) = (() => Variable(newID)) newID } @@ -1024,7 +1024,13 @@ trait CodeExtraction extends ASTExtractors { } def getOwner(exprs: Seq[LeonExpr]): Option[Option[FunDef]] = { - val exprOwners: Seq[Option[Option[FunDef]]] = exprs.map(owners.get(_)) + val exprOwners: Seq[Option[Option[FunDef]]] = exprs.map { + case Variable(id) => + owners.get(id) + case _ => + None + } + if(exprOwners.exists(_ == None)) None else if(exprOwners.exists(_ == Some(None))) @@ -1037,17 +1043,28 @@ trait CodeExtraction extends ASTExtractors { def getOwner(expr: LeonExpr): Option[Option[FunDef]] = getOwner(getReturnedExpr(expr)) - def extractProgram: Option[Program] = { - val topLevelObjDef = extractTopLevelDef - - val programName: Identifier = unit.body match { - case PackageDef(name, _) => FreshIdentifier(name.toString) - case _ => FreshIdentifier("<program>") - } + def extractProgram: Option[Program] = { + val topLevelObjDef = extractTopLevelDef - topLevelObjDef.map(obj => Program(programName, obj)) + val programName: Identifier = unit.body match { + case PackageDef(name, _) => FreshIdentifier(name.toString) + case _ => FreshIdentifier("<program>") } - } + topLevelObjDef.map(obj => Program(programName, obj)) + } + } + def containsLetDef(expr: LeonExpr): Boolean = { + def convert(t : LeonExpr) : Boolean = t match { + case (l : LetDef) => true + case _ => false + } + def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 + def compute(t : LeonExpr, c : Boolean) = t match { + case (l : LetDef) => true + case _ => c + } + treeCatamorphism(convert, combine, compute, expr) + } } diff --git a/src/main/scala/leon/purescala/Common.scala b/src/main/scala/leon/purescala/Common.scala index 7c3e7271aeab567e4012085cefe8ebdd4ed36397..e8838bee75a1ef7f0474cdc3c83221f8f5483b43 100644 --- a/src/main/scala/leon/purescala/Common.scala +++ b/src/main/scala/leon/purescala/Common.scala @@ -8,7 +8,18 @@ object Common { import Trees.Variable import TypeTrees.Typed - abstract class Tree extends Positioned with Serializable + abstract class Tree extends Positioned with Serializable { + def copiedFrom(o: Tree): this.type = { + setPos(o) + (this, o) match { + // do not force if already set + case (t1: Typed, t2: Typed) if !t1.isTyped => + t1.setType(t2.getType) + case _ => + } + this + } + } // the type is left blank (Untyped) for Identifiers that are not variables class Identifier private[Common](val name: String, private val globalId: Int, val id: Int, alwaysShowUniqueID: Boolean = false) extends Tree with Typed { diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 4e1e748022fee54aa6223f90e4317c7bd2931cb2..0720cf2488b248d6af2e7c766c197d7aa208add6 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -23,7 +23,6 @@ object FunctionClosure extends TransformationPhase { private var parent: FunDef = null //refers to the current toplevel parent def apply(ctx: LeonContext, program: Program): Program = { - pathConstraints = Nil enclosingLets = Nil newFunDefs = Map() @@ -46,7 +45,7 @@ object FunctionClosure extends TransformationPhase { val capturedVars: Set[Identifier] = bindedVars.diff(enclosingLets.map(_._1).toSet) val capturedConstraints: Set[Expr] = pathConstraints.toSet - val freshIds: Map[Identifier, Identifier] = capturedVars.map(id => (id, FreshIdentifier(id.name).setType(id.getType))).toMap + val freshIds: Map[Identifier, Identifier] = capturedVars.map(id => (id, FreshIdentifier(id.name).copiedFrom(id))).toMap val freshVars: Map[Expr, Expr] = freshIds.map(p => (p._1.toVariable, p._2.toVariable)) val extraVarDeclOldIds: Seq[Identifier] = capturedVars.toSeq @@ -56,7 +55,7 @@ object FunctionClosure extends TransformationPhase { val newBindedVars: Set[Identifier] = bindedVars ++ fd.args.map(_.id) val newFunId = FreshIdentifier(fd.id.uniqueName) //since we hoist this at the top level, we need to make it a unique name - val newFunDef = new FunDef(newFunId, fd.returnType, newVarDecls).setPos(fd) + val newFunDef = new FunDef(newFunId, fd.returnType, newVarDecls).copiedFrom(fd) topLevelFuns += newFunDef newFunDef.addAnnotation(fd.annotations.toSeq:_*) //TODO: this is still some dangerous side effects newFunDef.parent = Some(parent) @@ -65,7 +64,7 @@ object FunctionClosure extends TransformationPhase { def introduceLets(expr: Expr, fd2FreshFd: Map[FunDef, (FunDef, Seq[Variable])]): Expr = { val (newExpr, _) = enclosingLets.foldLeft((expr, Map[Identifier, Identifier]()))((acc, p) => { - val newId = FreshIdentifier(p._1.name).setType(p._1.getType) + val newId = FreshIdentifier(p._1.name).copiedFrom(p._1) val newMap = acc._2 + (p._1 -> newId) val newBody = functionClosure(acc._1, newBindedVars, freshIds ++ newMap, fd2FreshFd) (Let(newId, p._2, newBody), newMap) @@ -89,7 +88,7 @@ object FunctionClosure extends TransformationPhase { //val freshRest = functionClosure(rest, bindedVars, id2freshId, fd2FreshFd + (fd -> // ((newFunDef, extraVarDeclOldIds.map(id => id2freshId.get(id).getOrElse(id).toVariable))))) val freshRest = functionClosure(rest, bindedVars, id2freshId, fd2FreshFd + (fd -> ((newFunDef, extraVarDeclOldIds.map(_.toVariable))))) - freshRest.setType(l.getType) + freshRest.copiedFrom(l) } case l @ Let(i,e,b) => { val re = functionClosure(e, bindedVars, id2freshId, fd2FreshFd) @@ -99,7 +98,7 @@ object FunctionClosure extends TransformationPhase { val rb = functionClosure(b, bindedVars + i, id2freshId, fd2FreshFd) enclosingLets = enclosingLets.tail //pathConstraints = pathConstraints.tail - Let(i, re, rb).setType(l.getType) + Let(i, re, rb).copiedFrom(l) } case i @ IfExpr(cond,thenn,elze) => { /* @@ -113,26 +112,29 @@ object FunctionClosure extends TransformationPhase { pathConstraints ::= Not(cond)//Not(rCond) val rElze = functionClosure(elze, bindedVars, id2freshId, fd2FreshFd) pathConstraints = pathConstraints.tail - IfExpr(rCond, rThen, rElze).setType(i.getType) + IfExpr(rCond, rThen, rElze).copiedFrom(i) } case fi @ FunctionInvocation(fd, args) => fd2FreshFd.get(fd) match { - case None => FunctionInvocation(fd, args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd))).setPos(fi) + case None => + FunctionInvocation(fd, + args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd))).copiedFrom(fi) case Some((nfd, extraArgs)) => - FunctionInvocation(nfd, args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd)) ++ - extraArgs.map(v => replace(id2freshId.map(p => (p._1.toVariable, p._2.toVariable)), v))).setPos(fi) + FunctionInvocation(nfd, + args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd)) ++ + extraArgs.map(v => replace(id2freshId.map(p => (p._1.toVariable, p._2.toVariable)), v))).copiedFrom(fi) } case n @ NAryOperator(args, recons) => { val rargs = args.map(a => functionClosure(a, bindedVars, id2freshId, fd2FreshFd)) - recons(rargs).setType(n.getType) + recons(rargs).copiedFrom(n) } case b @ BinaryOperator(t1,t2,recons) => { val r1 = functionClosure(t1, bindedVars, id2freshId, fd2FreshFd) val r2 = functionClosure(t2, bindedVars, id2freshId, fd2FreshFd) - recons(r1,r2).setType(b.getType) + recons(r1,r2).copiedFrom(b) } case u @ UnaryOperator(t,recons) => { val r = functionClosure(t, bindedVars, id2freshId, fd2FreshFd) - recons(r).setType(u.getType) + recons(r).copiedFrom(u) } case m @ MatchExpr(scrut,cses) => { val scrutRec = functionClosure(scrut, bindedVars, id2freshId, fd2FreshFd) @@ -156,7 +158,7 @@ object FunctionClosure extends TransformationPhase { } } val tpe = csesRec.head.rhs.getType - MatchExpr(scrutRec, csesRec).setType(tpe).setPos(m) + MatchExpr(scrutRec, csesRec).copiedFrom(m).setType(tpe) } case v @ Variable(id) => id2freshId.get(id) match { case None => v diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index c03798d6243da1b0cb552ad754381dad042f1137..fa24b126568a3165538a85e4dcc3b3a8ac9d7881 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -8,11 +8,13 @@ import Trees._ import TypeTrees._ import Definitions._ +import utils._ + import java.lang.StringBuffer /** This pretty-printer uses Unicode for some operators, to make sure we * distinguish PureScala from "real" Scala (and also because it's cute). */ -class PrettyPrinter(sb: StringBuffer = new StringBuffer) { +class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { override def toString = sb.toString def append(str: String) { @@ -55,14 +57,19 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { sb.append(post) } - def idToString(id: Identifier): String = id.toString + def idToString(id: Identifier): String = { + if (opts.printUniqueIds) { + id.uniqueName + } else { + id.toString + } + } def pp(tree: Tree, parent: Option[Tree])(implicit lvl: Int): Unit = { implicit val p = Some(tree) tree match { case Variable(id) => sb.append(idToString(id)) - case DeBruijnIndex(idx) => sb.append("_" + idx) case LetTuple(bs,d,e) => sb.append("(let (" + bs.map(idToString _).mkString(",") + " := "); pp(d, p) @@ -438,6 +445,17 @@ class PrettyPrinter(sb: StringBuffer = new StringBuffer) { case _ => sb.append("Tree? (" + tree.getClass + ")") } + if (opts.printPositions) { + ppos(tree.getPos) + } + } + + def ppos(p: Position) = p match { + case op: OffsetPosition => + sb.append("@"+op.toString) + case rp: RangePosition => + sb.append("@"+rp.focusBegin.toString+"--"+rp.focusEnd.toString) + case _ => } } @@ -447,24 +465,24 @@ trait PrettyPrintable { def printWith(printer: PrettyPrinter)(implicit lvl: Int): Unit } -class EquivalencePrettyPrinter() extends PrettyPrinter() { +class EquivalencePrettyPrinter(opts: PrinterOptions) extends PrettyPrinter(opts) { override def idToString(id: Identifier) = id.name } abstract class PrettyPrinterFactory { - def create: PrettyPrinter + def create(opts: PrinterOptions): PrettyPrinter - def apply(tree: Tree, ind: Int = 0): String = { - val printer = create - printer.pp(tree, None)(ind) + def apply(tree: Tree, opts: PrinterOptions = PrinterOptions()): String = { + val printer = create(opts) + printer.pp(tree, None)(opts.baseIndent) printer.toString } } object PrettyPrinter extends PrettyPrinterFactory { - def create = new PrettyPrinter() + def create(opts: PrinterOptions) = new PrettyPrinter(opts) } object EquivalencePrettyPrinter extends PrettyPrinterFactory { - def create = new EquivalencePrettyPrinter() + def create(opts: PrinterOptions) = new EquivalencePrettyPrinter(opts) } diff --git a/src/main/scala/leon/purescala/PrinterOptions.scala b/src/main/scala/leon/purescala/PrinterOptions.scala new file mode 100644 index 0000000000000000000000000000000000000000..dc7b77aa2e1237c8341d95f7086da46c73522695 --- /dev/null +++ b/src/main/scala/leon/purescala/PrinterOptions.scala @@ -0,0 +1,7 @@ +package leon.purescala + +case class PrinterOptions ( + baseIndent: Int = 0, + printPositions: Boolean = false, + printUniqueIds: Boolean = false +) diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index eaa7bca9a3011abb9c684af763ea174c78bf4f62..90bd33aff9794127a2d47ffe5501a7263099cc04 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -8,7 +8,7 @@ import TypeTrees._ import Definitions._ /** This pretty-printer only print valid scala syntax */ -class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb) { +class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) extends PrettyPrinter(opts, sb) { import Common._ import Trees._ import TypeTrees._ @@ -50,9 +50,10 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb } } + var printPos = opts.printPositions + tree match { case Variable(id) => sb.append(idToString(id)) - case DeBruijnIndex(idx) => sys.error("Not Valid Scala") case LetTuple(ids,d,e) => optBraces { implicit lvl => sb.append("val (" ) @@ -386,7 +387,14 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb sb.append(" }") } - case _ => super.pp(tree, parent)(lvl) + case _ => + super.pp(tree, parent)(lvl) + // Parent will already print + printPos = false + } + + if (printPos) { + ppos(tree.getPos) } } @@ -420,5 +428,5 @@ class ScalaPrinter(sb: StringBuffer = new StringBuffer) extends PrettyPrinter(sb } object ScalaPrinter extends PrettyPrinterFactory { - def create = new ScalaPrinter() + def create(opts: PrinterOptions) = new ScalaPrinter(opts) } diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala new file mode 100644 index 0000000000000000000000000000000000000000..566ea5a95cd2b632dd045dc6a600fab0dd936256 --- /dev/null +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -0,0 +1,150 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package purescala + +import Common._ +import Definitions._ +import Trees._ +import TypeTrees._ +import TreeOps._ +import Extractors._ + +class ScopeSimplifier extends Transformer { + case class Scope(inScope: Set[Identifier] = Set(), oldToNew: Map[Identifier, Identifier] = Map(), funDefs: Map[FunDef, FunDef] = Map()) { + + def register(oldNew: (Identifier, Identifier)): Scope = { + val (oldId, newId) = oldNew + copy(inScope = inScope + newId, oldToNew = oldToNew + oldNew) + } + + def registerFunDef(oldNew: (FunDef, FunDef)): Scope = { + copy(funDefs = funDefs + oldNew) + } + } + + protected def genId(id: Identifier, scope: Scope): Identifier = { + val existCount = scope.inScope.count(_.name == id.name) + + FreshIdentifier(id.name, existCount).setType(id.getType) + } + + protected def rec(e: Expr, scope: Scope): Expr = e match { + case Let(i, e, b) => + val si = genId(i, scope) + val se = rec(e, scope) + val sb = rec(b, scope.register(i -> si)) + Let(si, se, sb) + + case LetDef(fd: FunDef, body: Expr) => + val newId = genId(fd.id, scope) + var newScope = scope.register(fd.id -> newId) + + val newArgs = for(VarDecl(id, tpe) <- fd.args) yield { + val newArg = genId(id, newScope) + newScope = newScope.register(id -> newArg) + VarDecl(newArg, tpe) + } + + val newFd = new FunDef(newId, fd.returnType, newArgs) + + newScope = newScope.registerFunDef(fd -> newFd) + + newFd.body = fd.body.map(b => rec(b, newScope)) + newFd.precondition = fd.precondition.map(pre => rec(pre, newScope)) + + newFd.postcondition = fd.postcondition.map { + case (id, post) => + val nid = genId(id, newScope) + val postScope = newScope.register(id -> nid) + (nid, rec(post, postScope)) + } + + LetDef(newFd, rec(body, newScope)) + + case LetTuple(is, e, b) => + var newScope = scope + val sis = for (i <- is) yield { + val si = genId(i, newScope) + newScope = newScope.register(i -> si) + si + } + + val se = rec(e, scope) + val sb = rec(b, newScope) + LetTuple(sis, se, sb) + + case MatchExpr(scrut, cases) => + val rs = rec(scrut, scope) + + def trPattern(p: Pattern, scope: Scope): (Pattern, Scope) = { + val (newBinder, newScope) = p.binder match { + case Some(id) => + val newId = genId(id, scope) + val newScope = scope.register(id -> newId) + (Some(newId), newScope) + case None => + (None, scope) + } + + var curScope = newScope + var newSubPatterns = for (sp <- p.subPatterns) yield { + val (subPattern, subScope) = trPattern(sp, curScope) + curScope = subScope + subPattern + } + + val newPattern = p match { + case InstanceOfPattern(b, ctd) => + InstanceOfPattern(newBinder, ctd) + case WildcardPattern(b) => + WildcardPattern(newBinder) + case CaseClassPattern(b, ccd, sub) => + CaseClassPattern(newBinder, ccd, newSubPatterns) + case TuplePattern(b, sub) => + TuplePattern(newBinder, newSubPatterns) + } + + + (newPattern, curScope) + } + + MatchExpr(rs, cases.map { c => + val (newP, newScope) = trPattern(c.pattern, scope) + + c match { + case SimpleCase(p, rhs) => + SimpleCase(newP, rec(rhs, newScope)) + case GuardedCase(p, g, rhs) => + GuardedCase(newP, rec(g, newScope), rec(rhs, newScope)) + } + }) + + case Variable(id) => + Variable(scope.oldToNew.getOrElse(id, id)) + + case FunctionInvocation(fd, args) => + val newFd = scope.funDefs.getOrElse(fd, fd) + val newArgs = args.map(rec(_, scope)) + + FunctionInvocation(newFd, newArgs) + + case UnaryOperator(e, builder) => + builder(rec(e, scope)) + + case BinaryOperator(e1, e2, builder) => + builder(rec(e1, scope), rec(e2, scope)) + + case NAryOperator(es, builder) => + builder(es.map(rec(_, scope))) + + case t : Terminal => t + + case _ => + sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") + } + + def transform(e: Expr): Expr = { + rec(e, Scope()) + } +} diff --git a/src/main/scala/leon/purescala/SimplifierWithPaths.scala b/src/main/scala/leon/purescala/SimplifierWithPaths.scala new file mode 100644 index 0000000000000000000000000000000000000000..4f6670898b556192833922ead45bde4c977eb728 --- /dev/null +++ b/src/main/scala/leon/purescala/SimplifierWithPaths.scala @@ -0,0 +1,118 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package purescala + +import Trees._ +import TypeTrees._ +import TreeOps._ +import Extractors._ +import solvers._ + +class SimplifierWithPaths(sf: SolverFactory[Solver]) extends TransformerWithPC { + type C = List[Expr] + + val initC = Nil + + val solver = SimpleSolverAPI(sf) + + protected def register(e: Expr, c: C) = e :: c + + def impliedBy(e : Expr, path : Seq[Expr]) : Boolean = try { + solver.solveVALID(Implies(And(path), e)) match { + case Some(true) => true + case _ => false + } + } catch { + case _ : Exception => false + } + + def contradictedBy(e : Expr, path : Seq[Expr]) : Boolean = try { + solver.solveVALID(Implies(And(path), Not(e))) match { + case Some(true) => true + case _ => false + } + } catch { + case _ : Exception => false + } + + protected override def rec(e: Expr, path: C) = e match { + case IfExpr(cond, thenn, elze) => + super.rec(e, path) match { + case IfExpr(BooleanLiteral(true) , t, _) => t + case IfExpr(BooleanLiteral(false), _, e) => e + case ite => ite + } + + case And(es) => + var soFar = path + var continue = true + var r = And(for(e <- es if continue) yield { + val se = rec(e, soFar) + if(se == BooleanLiteral(false)) continue = false + soFar = register(se, soFar) + se + }).copiedFrom(e) + + if (continue) { + r + } else { + BooleanLiteral(false).copiedFrom(e) + } + + case MatchExpr(scrut, cases) => + val rs = rec(scrut, path) + + var stillPossible = true + + if (cases.exists(_.hasGuard)) { + // unsupported for now + e + } else { + MatchExpr(rs, cases.flatMap { c => + val patternExpr = conditionForPattern(rs, c.pattern, includeBinders = true) + + if (stillPossible && !contradictedBy(patternExpr, path)) { + + if (impliedBy(patternExpr, path)) { + stillPossible = false + } + + c match { + case SimpleCase(p, rhs) => + Some(SimpleCase(p, rec(rhs, patternExpr +: path)).copiedFrom(c)) + case GuardedCase(_, _, _) => + sys.error("woot.") + } + } else { + None + } + }).copiedFrom(e) + } + + case Or(es) => + var soFar = path + var continue = true + var r = Or(for(e <- es if continue) yield { + val se = rec(e, soFar) + if(se == BooleanLiteral(true)) continue = false + soFar = register(Not(se), soFar) + se + }).copiedFrom(e) + + if (continue) { + r + } else { + BooleanLiteral(true).copiedFrom(e) + } + + case b if b.getType == BooleanType && impliedBy(b, path) => + BooleanLiteral(true).copiedFrom(b) + + case b if b.getType == BooleanType && contradictedBy(b, path) => + BooleanLiteral(false).copiedFrom(b) + + case _ => + super.rec(e, path) + } +} diff --git a/src/main/scala/leon/purescala/Transformer.scala b/src/main/scala/leon/purescala/Transformer.scala new file mode 100644 index 0000000000000000000000000000000000000000..54b9245b6475816961ca720b7f39888c03e61b3c --- /dev/null +++ b/src/main/scala/leon/purescala/Transformer.scala @@ -0,0 +1,11 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package purescala + +import purescala.Trees._ + + +trait Transformer { + def transform(e: Expr): Expr +} diff --git a/src/main/scala/leon/purescala/TransformerWithPC.scala b/src/main/scala/leon/purescala/TransformerWithPC.scala new file mode 100644 index 0000000000000000000000000000000000000000..0551aba59777e8a166f9e2da4ddff0940c5f5c99 --- /dev/null +++ b/src/main/scala/leon/purescala/TransformerWithPC.scala @@ -0,0 +1,87 @@ +/* Copyright 2009-2013 EPFL, Lausanne */ + +package leon +package purescala + +import Trees._ +import TreeOps._ +import Extractors._ + +abstract class TransformerWithPC extends Transformer { + type C + + protected val initC: C + + protected def register(cond: Expr, path: C): C + + protected def rec(e: Expr, path: C): Expr = e match { + case Let(i, e, b) => + val se = rec(e, path) + val sb = rec(b, register(Equals(Variable(i), se), path)) + Let(i, se, sb).copiedFrom(e) + + case MatchExpr(scrut, cases) => + val rs = rec(scrut, path) + + var soFar = path + + MatchExpr(rs, cases.map { c => + val patternExpr = conditionForPattern(rs, c.pattern, includeBinders = true) + + val subPath = register(patternExpr, soFar) + soFar = register(Not(patternExpr), soFar) + + c match { + case SimpleCase(p, rhs) => + SimpleCase(p, rec(rhs, subPath)).copiedFrom(c) + case GuardedCase(p, g, rhs) => + GuardedCase(p, g, rec(rhs, subPath)).copiedFrom(c) + } + }).copiedFrom(e) + + case LetTuple(is, e, b) => + val se = rec(e, path) + val sb = rec(b, register(Equals(Tuple(is.map(Variable(_))), se), path)) + LetTuple(is, se, sb).copiedFrom(e) + + case IfExpr(cond, thenn, elze) => + val rc = rec(cond, path) + + IfExpr(rc, rec(thenn, register(rc, path)), rec(elze, register(Not(rc), path))).copiedFrom(e) + + case And(es) => + var soFar = path + And(for(e <- es) yield { + val se = rec(e, soFar) + soFar = register(se, soFar) + se + }).copiedFrom(e) + + case Or(es) => + var soFar = path + Or(for(e <- es) yield { + val se = rec(e, soFar) + soFar = register(Not(se), soFar) + se + }).copiedFrom(e) + + case o @ UnaryOperator(e, builder) => + builder(rec(e, path)).copiedFrom(o) + + case o @ BinaryOperator(e1, e2, builder) => + builder(rec(e1, path), rec(e2, path)).copiedFrom(o) + + case o @ NAryOperator(es, builder) => + builder(es.map(rec(_, path))).copiedFrom(o) + + case t : Terminal => t + + case _ => + sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") + } + + def transform(e: Expr): Expr = { + rec(e, initC) + } +} + diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 02b771757bba5f0d4866165ea67e4bfd59a539df..c6e82ea0fee23b158231af26f2fd2f77100e3520 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -14,7 +14,7 @@ object TreeOps { import Trees._ import Extractors._ - def negate(expr: Expr) : Expr = expr match { + def negate(expr: Expr) : Expr = (expr match { case Let(i,b,e) => Let(i,b,negate(e)) case Not(e) => e case Iff(e1,e2) => Iff(negate(e1),e2) @@ -28,7 +28,7 @@ object TreeOps { case i @ IfExpr(c,e1,e2) => IfExpr(c, negate(e1), negate(e2)).setType(i.getType) case BooleanLiteral(b) => BooleanLiteral(!b) case _ => Not(expr) - } + }).setType(expr.getType).setPos(expr) // Warning ! This may loop forever if the substitutions are not // well-formed! @@ -57,23 +57,16 @@ object TreeOps { val re = rec(e) val rb = rec(b) if(re != e || rb != b) - Let(i, re, rb).setType(l.getType) + Let(i, re, rb).copiedFrom(l) else l } - //case l @ LetDef(fd, b) => { - // //TODO, not sure, see comment for the next LetDef - // fd.body = fd.body.map(rec(_)) - // fd.precondition = fd.precondition.map(rec(_)) - // fd.postcondition = fd.postcondition.map(rec(_)) - // LetDef(fd, rec(b)).setType(l.getType) - //} case lt @ LetTuple(ids, expr, body) => { val re = rec(expr) val rb = rec(body) if (re != expr || rb != body) { - LetTuple(ids, re, rb).setType(lt.getType) + LetTuple(ids, re, rb).copiedFrom(lt) } else { lt } @@ -90,7 +83,7 @@ object TreeOps { } }) if(change) - recons(rargs).setType(n.getType) + recons(rargs).copiedFrom(n) else n } @@ -98,14 +91,14 @@ object TreeOps { val r1 = rec(t1) val r2 = rec(t2) if(r1 != t1 || r2 != t2) - recons(r1,r2).setType(b.getType) + recons(r1,r2).copiedFrom(b) else b } case u @ UnaryOperator(t,recons) => { val r = rec(t) if(r != t) - recons(r).setType(u.getType) + recons(r).copiedFrom(u) else u } @@ -114,17 +107,17 @@ object TreeOps { val r2 = rec(t2) val r3 = rec(t3) if(r1 != t1 || r2 != t2 || r3 != t3) - IfExpr(rec(t1),rec(t2),rec(t3)).setType(i.getType) + IfExpr(rec(t1),rec(t2),rec(t3)).copiedFrom(i) else i } - case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).setType(m.getType).setPos(m) + case m @ MatchExpr(scrut,cses) => MatchExpr(rec(scrut), cses.map(inCase(_))).copiedFrom(m) case c @ Choose(args, body) => val body2 = rec(body) if (body != body2) { - Choose(args, body2).setType(c.getType) + Choose(args, body2).copiedFrom(c) } else { c } @@ -135,8 +128,8 @@ object TreeOps { } def inCase(cse: MatchCase) : MatchCase = cse match { - case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs)) - case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs)) + case SimpleCase(pat, rhs) => SimpleCase(pat, rec(rhs)).copiedFrom(cse) + case GuardedCase(pat, guard, rhs) => GuardedCase(pat, rec(guard), rec(rhs)).copiedFrom(cse) } rec(expr) @@ -166,7 +159,7 @@ object TreeOps { val re = rec(e) val rb = rec(b) applySubst(if(re != e || rb != b) { - Let(i,re,rb).setType(l.getType) + Let(i,re,rb).copiedFrom(l) } else { l }) @@ -175,19 +168,11 @@ object TreeOps { val re = rec(e) val rb = rec(b) applySubst(if(re != e || rb != b) { - LetTuple(ids,re,rb).setType(l.getType) + LetTuple(ids,re,rb).copiedFrom(l) } else { l }) } - //case l @ LetDef(fd,b) => { - // //TODO: Not sure: I actually need the replace to occurs even in the pre/post condition, hope this is correct - // fd.body = fd.body.map(rec(_)) - // fd.precondition = fd.precondition.map(rec(_)) - // fd.postcondition = fd.postcondition.map(rec(_)) - // val rl = LetDef(fd, rec(b)).setType(l.getType) - // applySubst(rl) - //} case n @ NAryOperator(args, recons) => { var change = false val rargs = args.map(a => { @@ -200,7 +185,7 @@ object TreeOps { } }) applySubst(if(change) { - recons(rargs).setType(n.getType) + recons(rargs).copiedFrom(n) } else { n }) @@ -209,7 +194,7 @@ object TreeOps { val r1 = rec(t1) val r2 = rec(t2) applySubst(if(r1 != t1 || r2 != t2) { - recons(r1,r2).setType(b.getType) + recons(r1,r2).copiedFrom(b) } else { b }) @@ -217,7 +202,7 @@ object TreeOps { case u @ UnaryOperator(t,recons) => { val r = rec(t) applySubst(if(r != t) { - recons(r).setType(u.getType) + recons(r).copiedFrom(u) } else { u }) @@ -227,7 +212,7 @@ object TreeOps { val r2 = rec(t2) val r3 = rec(t3) applySubst(if(r1 != t1 || r2 != t2 || r3 != t3) { - IfExpr(r1,r2,r3).setType(i.getType) + IfExpr(r1,r2,r3).copiedFrom(i) } else { i }) @@ -236,7 +221,7 @@ object TreeOps { val rscrut = rec(scrut) val (newCses,changes) = cses.map(inCase(_)).unzip applySubst(if(rscrut != scrut || changes.exists(res=>res)) { - MatchExpr(rscrut, newCses).setType(m.getType).setPos(m) + MatchExpr(rscrut, newCses).copiedFrom(m) } else { m }) @@ -246,12 +231,12 @@ object TreeOps { val body2 = rec(body) applySubst(if (body != body2) { - Choose(args, body2).setType(c.getType).setPos(c) + Choose(args, body2).copiedFrom(c) } else { c }) - case t if t.isInstanceOf[Terminal] => applySubst(t) + case t if t.isInstanceOf[Terminal] => applySubst(t).setPos(t) case unhandled => scala.sys.error("Non-terminal case should be handled in searchAndReplaceDFS: " + unhandled) } @@ -300,9 +285,9 @@ object TreeOps { } def applyToTree(e : Expr) : Option[Expr] = e match { - case m @ MatchExpr(s, cses) => Some(MatchExpr(s, cses.map(freshenCase(_))).setType(m.getType).setPos(m)) + case m @ MatchExpr(s, cses) => Some(MatchExpr(s, cses.map(freshenCase(_))).copiedFrom(m)) case l @ Let(i,e,b) => { - val newID = FreshIdentifier(i.name, true).setType(i.getType) + val newID = FreshIdentifier(i.name, true).copiedFrom(i) Some(Let(newID, e, replace(Map(Variable(i) -> Variable(newID)), b))) } case _ => None @@ -317,14 +302,11 @@ object TreeOps { def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, expression: Expr) : A = { treeCatamorphism(convert, combine, (e:Expr,a:A)=>a, expression) } + // compute allows the catamorphism to change the combined value depending on the tree def treeCatamorphism[A](convert: Expr=>A, combine: (A,A)=>A, compute: (Expr,A)=>A, expression: Expr) : A = { def rec(expr: Expr) : A = expr match { case l @ Let(_, e, b) => compute(l, combine(rec(e), rec(b))) - //case l @ LetDef(fd, b) => {//TODO, still not sure about the semantic - // val exprs: Seq[Expr] = fd.precondition.toSeq ++ fd.body.toSeq ++ fd.postcondition.toSeq ++ Seq(b) - // compute(l, exprs.map(rec(_)).reduceLeft(combine)) - //} case n @ NAryOperator(args, _) => { if(args.size == 0) compute(n, convert(n)) @@ -343,19 +325,6 @@ object TreeOps { rec(expression) } - def containsIfExpr(expr: Expr): Boolean = { - def convert(t : Expr) : Boolean = t match { - case (i: IfExpr) => true - case _ => false - } - def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 - def compute(t : Expr, c : Boolean) = t match { - case (i: IfExpr) => true - case _ => c - } - treeCatamorphism(convert, combine, compute, expr) - } - def variablesOf(expr: Expr) : Set[Identifier] = { def convert(t: Expr) : Set[Identifier] = t match { case Variable(i) => Set(i) @@ -371,47 +340,16 @@ object TreeOps { treeCatamorphism(convert, combine, compute, expr) } - def containsFunctionCalls(expr : Expr) : Boolean = { - def convert(t : Expr) : Boolean = t match { - case f : FunctionInvocation => true - case _ => false - } - def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 - def compute(t : Expr, c : Boolean) = t match { - case f : FunctionInvocation => true - case _ => c - } - treeCatamorphism(convert, combine, compute, expr) - } - - def topLevelFunctionCallsOf(expr: Expr, barring : Set[FunDef] = Set.empty) : Set[FunctionInvocation] = { - def convert(t: Expr) : Set[FunctionInvocation] = t match { - case f @ FunctionInvocation(fd, _) if(!barring(fd)) => Set(f) - case _ => Set.empty - } - def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2 - def compute(t: Expr, s: Set[FunctionInvocation]) = t match { - case f @ FunctionInvocation(fd, _) if(!barring(fd)) => Set(f) // ++ s that's the difference with the one below - case _ => s - } - treeCatamorphism(convert, combine, compute, expr) - } - - def allNonRecursiveFunctionCallsOf(expr: Expr, program: Program) : Set[FunctionInvocation] = { - def convert(t: Expr) : Set[FunctionInvocation] = t match { - case f @ FunctionInvocation(fd, _) if program.isRecursive(fd) => Set(f) - case _ => Set.empty - } - - def combine(s1: Set[FunctionInvocation], s2: Set[FunctionInvocation]) = s1 ++ s2 - - def compute(t: Expr, s: Set[FunctionInvocation]) = t match { - case f @ FunctionInvocation(fd,_) if program.isRecursive(fd) => Set(f) ++ s - case _ => s - } - treeCatamorphism(convert, combine, compute, expr) + def containsFunctionCalls(expr: Expr): Boolean = { + contains(expr, { + case _: FunctionInvocation => true + case _ => false + }) } + /** + * Returns all Function calls found in an expression + */ def functionCallsOf(expr: Expr) : Set[FunctionInvocation] = { def convert(t: Expr) : Set[FunctionInvocation] = t match { case f @ FunctionInvocation(_, _) => Set(f) @@ -425,24 +363,20 @@ object TreeOps { treeCatamorphism(convert, combine, compute, expr) } - def contains(expr: Expr, matcher: Expr=>Boolean) : Boolean = { - treeCatamorphism[Boolean]( - matcher, - (b1: Boolean, b2: Boolean) => b1 || b2, - (t: Expr, b: Boolean) => b || matcher(t), - expr) - } + /** + * Returns true if matcher(se) == true where se is any sub-expression of e + */ - def allDeBruijnIndices(expr: Expr) : Set[DeBruijnIndex] = { - def convert(t: Expr) : Set[DeBruijnIndex] = t match { - case i @ DeBruijnIndex(idx) => Set(i) - case _ => Set.empty - } - def combine(s1: Set[DeBruijnIndex], s2: Set[DeBruijnIndex]) = s1 ++ s2 - treeCatamorphism(convert, combine, expr) + def contains(e: Expr, matcher: Expr => Boolean): Boolean = { + simplePreTransform{ + case e if matcher(e) => return true + case e => e + }(e) + false } - /* Simplifies let expressions: + /** + * Simplifies let expressions: * - removes lets when expression never occurs * - simplifies when expressions occurs exactly once * - expands when expression is just a variable. @@ -504,7 +438,7 @@ object TreeOps { case l @ LetTuple(ids, tExpr: Terminal, body) if !containsChoose(body) => val substMap : Map[Expr,Expr] = ids.map(Variable(_) : Expr).zipWithIndex.toMap.map { - case (v,i) => (v -> TupleSelect(tExpr, i + 1).setType(v.getType)) + case (v,i) => (v -> TupleSelect(tExpr, i + 1).copiedFrom(v)) } Some(replace(substMap, body)) @@ -531,7 +465,7 @@ object TreeOps { Some(body) } else if(total == 1) { val substMap : Map[Expr,Expr] = ids.map(Variable(_) : Expr).zipWithIndex.toMap.map { - case (v,i) => (v -> TupleSelect(tExpr, i + 1).setType(v.getType)) + case (v,i) => (v -> TupleSelect(tExpr, i + 1).copiedFrom(v)) } Some(replace(substMap, body)) @@ -545,74 +479,6 @@ object TreeOps { searchAndReplaceDFS(simplerLet)(expr) } - // Pulls out all let constructs to the top level, and makes sure they're - // properly ordered. - private type DefPair = (Identifier,Expr) - private type DefPairs = List[DefPair] - private def allLetDefinitions(expr: Expr) : DefPairs = treeCatamorphism[DefPairs]( - (e: Expr) => Nil, - (s1: DefPairs, s2: DefPairs) => s1 ::: s2, - (e: Expr, dps: DefPairs) => e match { - case Let(i, e, _) => (i,e) :: dps - case _ => dps - }, - expr) - - private def killAllLets(expr: Expr) : Expr = searchAndReplaceDFS((e: Expr) => e match { - case Let(_,_,ex) => Some(ex) - case _ => None - })(expr) - - def liftLets(expr: Expr) : Expr = { - val initialDefinitionPairs = allLetDefinitions(expr) - val definitionPairs = initialDefinitionPairs.map(p => (p._1, killAllLets(p._2))) - val occursLists : Map[Identifier,Set[Identifier]] = Map(definitionPairs.map((dp: DefPair) => (dp._1 -> variablesOf(dp._2).toSet.filter(_.isLetBinder))) : _*) - var newList : DefPairs = Nil - var placed : Set[Identifier] = Set.empty - val toPlace = definitionPairs.size - var placedC = 0 - var traversals = 0 - - while(placedC < toPlace) { - if(traversals > toPlace + 1) { - scala.sys.error("Cycle in let definitions or multiple definition for the same identifier in liftLets : " + definitionPairs.mkString("\n")) - } - for((id,ex) <- definitionPairs) if (!placed(id)) { - if((occursLists(id) -- placed) == Set.empty) { - placed = placed + id - newList = (id,ex) :: newList - placedC = placedC + 1 - } - } - traversals = traversals + 1 - } - - val noLets = killAllLets(expr) - - val res = (newList.foldLeft(noLets)((e,iap) => Let(iap._1, iap._2, e))) - simplifyLets(res) - } - - def wellOrderedLets(tree : Expr) : Boolean = { - val pairs = allLetDefinitions(tree) - val definitions: Set[Identifier] = Set(pairs.map(_._1) : _*) - val vars: Set[Identifier] = variablesOf(tree) - val intersection = vars intersect definitions - if(!intersection.isEmpty) { - intersection.foreach(id => { - sys.error("Variable with identifier '" + id + "' has escaped its let-definition !") - }) - false - } else { - vars.forall(id => if(id.isLetBinder) { - sys.error("Variable with identifier '" + id + "' has lost its let-definition (it disappeared??)") - false - } else { - true - }) - } - } - /* Fully expands all let expressions. */ def expandLets(expr: Expr) : Expr = { def rec(ex: Expr, s: Map[Identifier,Expr]) : Expr = ex match { @@ -625,11 +491,11 @@ object TreeOps { val rargs = args.map(a => { val ra = rec(a, s) if(ra != a) { - change = true + change = true ra } else { a - } + } }) if(change) recons(rargs).setType(n.getType) @@ -774,11 +640,13 @@ object TreeOps { } case _ => None } - + searchAndReplaceDFS(rewritePM)(expr) } - /** Rewrites all map accesses with additional error conditions. */ + /** + * Rewrites all map accesses with additional error conditions. + */ val cacheMGWC = new TrieMap[Expr, Expr]() def mapGetWithChecks(expr: Expr) : Expr = { @@ -786,97 +654,72 @@ object TreeOps { case Some(res) => res case None => - val r = convertMapGet(expr) + def rewriteMapGet(e: Expr) : Option[Expr] = e match { + case mg @ MapGet(m,k) => + val ida = MapIsDefinedAt(m, k) + Some(IfExpr(ida, mg, Error("key not found for map access").copiedFrom(mg)).copiedFrom(mg)) + case _ => None + } + + val r = searchAndReplaceDFS(rewriteMapGet)(expr) cacheMGWC += expr -> r r } } - private def convertMapGet(expr: Expr) : Expr = { - def rewriteMapGet(e: Expr) : Option[Expr] = e match { - case mg @ MapGet(m,k) => - val ida = MapIsDefinedAt(m, k) - Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType).setPos(mg)).setType(mg.getType)) - case _ => None - } - - searchAndReplaceDFS(rewriteMapGet)(expr) - } - - // prec: expression does not contain match expressions - def measureADTChildrenDepth(expression: Expr) : Int = { - import scala.math.max - - def rec(ex: Expr, lm: Map[Identifier,Int]) : Int = ex match { - case Let(i,e,b) => rec(b,lm + (i -> rec(e,lm))) - case Variable(id) => lm.getOrElse(id, 0) - case CaseClassSelector(_, e, _) => rec(e,lm) + 1 - case NAryOperator(args, _) => if(args.isEmpty) 0 else args.map(rec(_,lm)).max - case BinaryOperator(e1,e2,_) => max(rec(e1,lm), rec(e2,lm)) - case UnaryOperator(e,_) => rec(e,lm) - case IfExpr(c,t,e) => max(max(rec(c,lm),rec(t,lm)),rec(e,lm)) - case t: Terminal => 0 - case _ => scala.sys.error("Not handled in measureChildrenDepth : " + ex) - } - - rec(expression,Map.empty) - } - - private val random = new scala.util.Random() - - def randomValue(v: Variable) : Expr = randomValue(v.getType) - def simplestValue(v: Variable) : Expr = simplestValue(v.getType) + /** + * Returns simplest value of a given type + */ + def simplestValue(tpe: TypeTree) : Expr = tpe match { + case Int32Type => IntLiteral(0) + case BooleanType => BooleanLiteral(false) + case SetType(baseType) => FiniteSet(Seq()).setType(tpe) + case MapType(fromType, toType) => FiniteMap(Seq()).setType(tpe) + case TupleType(tpes) => Tuple(tpes.map(simplestValue)) + case ArrayType(tpe) => ArrayFill(IntLiteral(0), simplestValue(tpe)) - def randomValue(tpe: TypeTree) : Expr = tpe match { - case Int32Type => IntLiteral(random.nextInt(42)) - case BooleanType => BooleanLiteral(random.nextBoolean()) case AbstractClassType(acd) => val children = acd.knownChildren - randomValue(classDefToClassType(children(random.nextInt(children.size)))) - case CaseClassType(cd) => - val fields = cd.fields - CaseClass(cd, fields.map(f => randomValue(f.getType))) - case _ => throw new Exception("I can't choose random value for type " + tpe) - } - def simplestValue(tpe: TypeTree) : Expr = tpe match { - case Int32Type => IntLiteral(0) - case BooleanType => BooleanLiteral(false) - case AbstractClassType(acd) => { - val children = acd.knownChildren - val simplerChildren = children.filter{ - case ccd @ CaseClassDef(id, Some(parent), fields) => - !fields.exists(vd => vd.getType match { - case AbstractClassType(fieldAcd) => acd == fieldAcd - case CaseClassType(fieldCcd) => ccd == fieldCcd - case _ => false - }) - case _ => false - } - def orderByNumberOfFields(fst: ClassTypeDef, snd: ClassTypeDef) : Boolean = (fst, snd) match { - case (CaseClassDef(_, _, flds1), CaseClassDef(_, _, flds2)) => flds1.size <= flds2.size - case _ => true + def isRecursive(ccd: CaseClassDef): Boolean = { + ccd.fields.exists(fd => fd.getType match { + case AbstractClassType(fieldAcd) => acd == fieldAcd + case CaseClassType(fieldCcd) => ccd == fieldCcd + case _ => false + }) } - val orderedChildren = simplerChildren.sortWith(orderByNumberOfFields) + + val nonRecChildren = children.collect { case ccd: CaseClassDef if !isRecursive(ccd) => ccd } + + val orderedChildren = nonRecChildren.sortBy(_.fields.size) + simplestValue(classDefToClassType(orderedChildren.head)) - } + case CaseClassType(ccd) => val fields = ccd.fields CaseClass(ccd, fields.map(f => simplestValue(f.getType))) - case SetType(baseType) => FiniteSet(Seq()).setType(tpe) - case MapType(fromType, toType) => FiniteMap(Seq()).setType(tpe) - case TupleType(tpes) => Tuple(tpes.map(simplestValue)) - case ArrayType(tpe) => ArrayFill(IntLiteral(0), simplestValue(tpe)) + case _ => throw new Exception("I can't choose simplest value for type " + tpe) } - //guarentee that all IfExpr will be at the top level and as soon as you encounter a non-IfExpr, then no more IfExpr can be found in the sub-expressions - //require no-match, no-ets and only pure code + /** + * Guarentees that all IfExpr will be at the top level and as soon as you + * encounter a non-IfExpr, then no more IfExpr can be found in the + * sub-expressions + * + * Assumes no match expressions + */ def hoistIte(expr: Expr): Expr = { def transform(expr: Expr): Option[Expr] = expr match { - case uop@UnaryOperator(IfExpr(c, t, e), op) => Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType)) - case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType)) - case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType)) + case uop@UnaryOperator(IfExpr(c, t, e), op) => + Some(IfExpr(c, op(t).copiedFrom(uop), op(e).copiedFrom(uop)).copiedFrom(uop)) + + case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => + Some(IfExpr(c, op(t, t2).copiedFrom(bop), op(e, t2).copiedFrom(bop)).copiedFrom(bop)) + + case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => + Some(IfExpr(c, op(t1, t).copiedFrom(bop), op(t1, e).copiedFrom(bop)).copiedFrom(bop)) + case nop@NAryOperator(ts, op) => { val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } if(iteIndex == -1) None else { @@ -884,8 +727,8 @@ object TreeOps { val afterIte = startIte.tail val IfExpr(c, t, e) = startIte.head Some(IfExpr(c, - op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType), - op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType) + op(beforeIte ++ Seq(t) ++ afterIte).copiedFrom(nop), + op(beforeIte ++ Seq(e) ++ afterIte).copiedFrom(nop) ).setType(nop.getType)) } } @@ -913,20 +756,20 @@ object TreeOps { case UnaryOperator(e, builder) => val (e1, c) = rec(e, ctx) - val newE = builder(e1) + val newE = builder(e1).copiedFrom(expr) (newE, combiner(Seq(c))) case BinaryOperator(e1, e2, builder) => val (ne1, c1) = rec(e1, ctx) val (ne2, c2) = rec(e2, ctx) - val newE = builder(ne1, ne2) + val newE = builder(ne1, ne2).copiedFrom(expr) (newE, combiner(Seq(c1, c2))) case NAryOperator(es, builder) => val (nes, cs) = es.map{ rec(_, ctx)}.unzip - val newE = builder(nes) + val newE = builder(nes).copiedFrom(expr) (newE, combiner(cs)) @@ -940,6 +783,7 @@ object TreeOps { rec(expr, init) } + private def noCombiner(subCs: Seq[Unit]) = () def simpleTransform(pre: Expr => Expr, post: Expr => Expr)(expr: Expr) = { @@ -961,19 +805,6 @@ object TreeOps { genericTransform[Unit]((e,c) => (e, None), newPost, noCombiner)(())(expr)._1 } - def toCNF(e: Expr): Expr = { - def pre(e: Expr) = e match { - case Or(Seq(l, And(Seq(r1, r2)))) => - And(Or(l, r1), Or(l, r2)) - case Or(Seq(And(Seq(l1, l2)), r)) => - And(Or(l1, r), Or(l2, r)) - case _ => - e - } - - simplePreTransform(pre)(e) - } - /* * Transforms complicated Ifs into multiple nested if blocks * It will decompose every OR clauses, and it will group AND clauses checking @@ -1006,6 +837,8 @@ object TreeOps { * } * * This transformation runs immediately before patternMatchReconstruction. + * + * Notes: positions are lost. */ def decomposeIfs(e: Expr): Expr = { def pre(e: Expr): Expr = e match { @@ -1036,7 +869,11 @@ object TreeOps { simplePreTransform(pre)(e) } - // This transformation assumes IfExpr of the form generated by decomposeIfs + /** + * Reconstructs match expressions from if-then-elses. + * + * Notes: positions are lost. + */ def patternMatchReconstruction(e: Expr): Expr = { def post(e: Expr): Expr = e match { case IfExpr(cond, thenn, elze) => @@ -1199,6 +1036,10 @@ object TreeOps { simplePostTransform(post)(e) } + /** + * Simplify If expressions when the branch is predetermined by the path + * condition + */ def simplifyTautologies(sf: SolverFactory[Solver])(expr : Expr) : Expr = { val solver = SimpleSolverAPI(sf) @@ -1211,9 +1052,9 @@ object TreeOps { case Some(true) => fd.precondition = None - case Some(false) => solver.solveVALID(Not(pre)) match { - case Some(true) => - fd.precondition = Some(BooleanLiteral(false)) + case Some(false) => solver.solveSAT(pre) match { + case (Some(false), _) => + fd.precondition = Some(BooleanLiteral(false).copiedFrom(e)) case _ => } case None => @@ -1246,205 +1087,10 @@ object TreeOps { new SimplifierWithPaths(sf).transform _ } - trait Transformer { - def transform(e: Expr): Expr - } - trait Traverser[T] { def traverse(e: Expr): T } - abstract class TransformerWithPC extends Transformer { - type C - - protected val initC: C - - protected def register(cond: Expr, path: C): C - - protected def rec(e: Expr, path: C): Expr = e match { - case Let(i, e, b) => - val se = rec(e, path) - val sb = rec(b, register(Equals(Variable(i), se), path)) - Let(i, se, sb) - - case MatchExpr(scrut, cases) => - val rs = rec(scrut, path) - - var soFar = path - - MatchExpr(rs, cases.map { c => - val patternExpr = conditionForPattern(rs, c.pattern, includeBinders = true) - - val subPath = register(patternExpr, soFar) - soFar = register(Not(patternExpr), soFar) - - c match { - case SimpleCase(p, rhs) => - SimpleCase(p, rec(rhs, subPath)) - case GuardedCase(p, g, rhs) => - GuardedCase(p, g, rec(rhs, subPath)) - } - }) - - case LetTuple(is, e, b) => - val se = rec(e, path) - val sb = rec(b, register(Equals(Tuple(is.map(Variable(_))), se), path)) - LetTuple(is, se, sb) - - case IfExpr(cond, thenn, elze) => - val rc = rec(cond, path) - - IfExpr(rc, rec(thenn, register(rc, path)), rec(elze, register(Not(rc), path))) - - case And(es) => { - var soFar = path - And(for(e <- es) yield { - val se = rec(e, soFar) - soFar = register(se, soFar) - se - }) - } - - case Or(es) => { - var soFar = path - Or(for(e <- es) yield { - val se = rec(e, soFar) - soFar = register(Not(se), soFar) - se - }) - } - - - case UnaryOperator(e, builder) => - builder(rec(e, path)) - - case BinaryOperator(e1, e2, builder) => - builder(rec(e1, path), rec(e2, path)) - - case NAryOperator(es, builder) => - builder(es.map(rec(_, path))) - - case t : Terminal => t - - case _ => - sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") - } - - def transform(e: Expr): Expr = { - rec(e, initC) - } - } - - class SimplifierWithPaths(sf: SolverFactory[Solver]) extends TransformerWithPC { - type C = List[Expr] - - val initC = Nil - - val solver = SimpleSolverAPI(sf) - - protected def register(e: Expr, c: C) = e :: c - - def impliedBy(e : Expr, path : Seq[Expr]) : Boolean = try { - solver.solveVALID(Implies(And(path), e)) match { - case Some(true) => true - case _ => false - } - } catch { - case _ : Exception => false - } - - def contradictedBy(e : Expr, path : Seq[Expr]) : Boolean = try { - solver.solveVALID(Implies(And(path), Not(e))) match { - case Some(true) => true - case _ => false - } - } catch { - case _ : Exception => false - } - - protected override def rec(e: Expr, path: C) = e match { - case IfExpr(cond, thenn, elze) => - super.rec(e, path) match { - case IfExpr(BooleanLiteral(true) , t, _) => t - case IfExpr(BooleanLiteral(false), _, e) => e - case ite => ite - } - - case And(es) => { - var soFar = path - var continue = true - var r = And(for(e <- es if continue) yield { - val se = rec(e, soFar) - if(se == BooleanLiteral(false)) continue = false - soFar = register(se, soFar) - se - }) - - if (continue) { - r - } else { - BooleanLiteral(false) - } - } - - case MatchExpr(scrut, cases) => - val rs = rec(scrut, path) - - var stillPossible = true - - if (cases.exists(_.hasGuard)) { - // unsupported for now - e - } else { - MatchExpr(rs, cases.flatMap { c => - val patternExpr = conditionForPattern(rs, c.pattern, includeBinders = true) - - if (stillPossible && !contradictedBy(patternExpr, path)) { - - if (impliedBy(patternExpr, path)) { - stillPossible = false - } - - c match { - case SimpleCase(p, rhs) => - Some(SimpleCase(p, rec(rhs, patternExpr +: path))) - case GuardedCase(_, _, _) => - sys.error("woot.") - } - } else { - None - } - }) - } - - case Or(es) => { - var soFar = path - var continue = true - var r = Or(for(e <- es if continue) yield { - val se = rec(e, soFar) - if(se == BooleanLiteral(true)) continue = false - soFar = register(Not(se), soFar) - se - }) - - if (continue) { - r - } else { - BooleanLiteral(true) - } - } - - case b if b.getType == BooleanType && impliedBy(b, path) => - BooleanLiteral(true) - - case b if b.getType == BooleanType && contradictedBy(b, path) => - BooleanLiteral(false) - - case _ => - super.rec(e, path) - } - } - class ChooseCollectorWithPaths extends TransformerWithPC with Traverser[Seq[(Choose, Expr)]] { type C = Seq[Expr] val initC = Nil @@ -1467,148 +1113,12 @@ object TreeOps { } } - class ScopeSimplifier extends Transformer { - - case class Scope(inScope: Set[Identifier] = Set(), oldToNew: Map[Identifier, Identifier] = Map(), funDefs: Map[FunDef, FunDef] = Map()) { - - def register(oldNew: (Identifier, Identifier)): Scope = { - val (oldId, newId) = oldNew - copy(inScope = inScope + newId, oldToNew = oldToNew + oldNew) - } - - def registerFunDef(oldNew: (FunDef, FunDef)): Scope = { - copy(funDefs = funDefs + oldNew) - } - } - - protected def genId(id: Identifier, scope: Scope): Identifier = { - val existCount = scope.inScope.count(_.name == id.name) - - FreshIdentifier(id.name, existCount).setType(id.getType) - } - - protected def rec(e: Expr, scope: Scope): Expr = e match { - case Let(i, e, b) => - val si = genId(i, scope) - val se = rec(e, scope) - val sb = rec(b, scope.register(i -> si)) - Let(si, se, sb) - - case LetDef(fd: FunDef, body: Expr) => - val newId = genId(fd.id, scope) - var newScope = scope.register(fd.id -> newId) - - val newArgs = for(VarDecl(id, tpe) <- fd.args) yield { - val newArg = genId(id, newScope) - newScope = newScope.register(id -> newArg) - VarDecl(newArg, tpe) - } - - val newFd = new FunDef(newId, fd.returnType, newArgs) - - newScope = newScope.registerFunDef(fd -> newFd) - - newFd.body = fd.body.map(b => rec(b, newScope)) - newFd.precondition = fd.precondition.map(pre => rec(pre, newScope)) - - newFd.postcondition = fd.postcondition.map { - case (id, post) => - val nid = genId(id, newScope) - val postScope = newScope.register(id -> nid) - (nid, rec(post, postScope)) - } - - LetDef(newFd, rec(body, newScope)) - - case LetTuple(is, e, b) => - var newScope = scope - val sis = for (i <- is) yield { - val si = genId(i, newScope) - newScope = newScope.register(i -> si) - si - } - - val se = rec(e, scope) - val sb = rec(b, newScope) - LetTuple(sis, se, sb) - - case MatchExpr(scrut, cases) => - val rs = rec(scrut, scope) - - def trPattern(p: Pattern, scope: Scope): (Pattern, Scope) = { - val (newBinder, newScope) = p.binder match { - case Some(id) => - val newId = genId(id, scope) - val newScope = scope.register(id -> newId) - (Some(newId), newScope) - case None => - (None, scope) - } - - var curScope = newScope - var newSubPatterns = for (sp <- p.subPatterns) yield { - val (subPattern, subScope) = trPattern(sp, curScope) - curScope = subScope - subPattern - } - - val newPattern = p match { - case InstanceOfPattern(b, ctd) => - InstanceOfPattern(newBinder, ctd) - case WildcardPattern(b) => - WildcardPattern(newBinder) - case CaseClassPattern(b, ccd, sub) => - CaseClassPattern(newBinder, ccd, newSubPatterns) - case TuplePattern(b, sub) => - TuplePattern(newBinder, newSubPatterns) - } - - - (newPattern, curScope) - } - - MatchExpr(rs, cases.map { c => - val (newP, newScope) = trPattern(c.pattern, scope) - - c match { - case SimpleCase(p, rhs) => - SimpleCase(newP, rec(rhs, newScope)) - case GuardedCase(p, g, rhs) => - GuardedCase(newP, rec(g, newScope), rec(rhs, newScope)) - } - }) - - case Variable(id) => - Variable(scope.oldToNew.getOrElse(id, id)) - - case FunctionInvocation(fd, args) => - val newFd = scope.funDefs.getOrElse(fd, fd) - val newArgs = args.map(rec(_, scope)) - - FunctionInvocation(newFd, newArgs) - - case UnaryOperator(e, builder) => - builder(rec(e, scope)) - - case BinaryOperator(e1, e2, builder) => - builder(rec(e1, scope), rec(e2, scope)) - - case NAryOperator(es, builder) => - builder(es.map(rec(_, scope))) - - case t : Terminal => t - - case _ => - sys.error("Expression "+e+" ["+e.getClass+"] is not extractable") - } - - def transform(e: Expr): Expr = { - rec(e, Scope()) - } - } - - // Eliminates tuples of arity 0 and 1. This function also affects types! - // Only rewrites local fundefs (i.e. LetDef's). + /** + * Eliminates tuples of arity 0 and 1. + * Used to simplify synthesis solutions + * + * Only rewrites local fundefs. + */ def rewriteTuples(expr: Expr) : Expr = { def mapType(tt : TypeTree) : Option[TypeTree] = tt match { case TupleType(ts) => ts.size match { @@ -1719,22 +1229,6 @@ object TreeOps { es.map(formulaSize).foldRight(0)(_ + _)+1 } - def collect[C](f: PartialFunction[Expr, C])(e: Expr): List[C] = { - def post(e: Expr, cs: List[C]) = { - if (f.isDefinedAt(e)) { - (e, f(e) :: cs) - } else { - (e, cs) - } - } - - def combiner(cs: Seq[List[C]]) = { - cs.foldLeft(List[C]())(_ ::: _) - } - - genericTransform[List[C]]((_, _), post, combiner)(List())(e)._2 - } - def collectChooses(e: Expr): List[Choose] = { new ChooseCollectorWithPaths().traverse(e).map(_._1).toList } @@ -1747,22 +1241,33 @@ object TreeOps { false } + /** + * Returns the value for an identifier given a model. + */ def valuateWithModel(model: Map[Identifier, Expr])(id: Identifier): Expr = { model.getOrElse(id, simplestValue(id.getType)) } + /** + * Substitute (free) variables in an expression with values form a model. + * + * Complete with simplest values in case of incomplete model. + */ def valuateWithModelIn(expr: Expr, vars: Set[Identifier], model: Map[Identifier, Expr]): Expr = { val valuator = valuateWithModel(model) _ replace(vars.map(id => Variable(id) -> valuator(id)).toMap, expr) } - //simple, local simplifications on arithmetic - //you should not assume anything smarter than some constant folding and simple cancelation - //to avoid infinite cycle we only apply simplification that reduce the size of the tree - //The only guarentee from this function is to not augment the size of the expression and to be sound - //(note that an identity function would meet this specification) + /** + * Simple, local simplification on arithmetic + * + * You should not assume anything smarter than some constant folding and + * simple cancelation. To avoid infinite cycle we only apply simplification + * that reduce the size of the tree. The only guarentee from this function is + * to not augment the size of the expression and to be sound. + */ def simplifyArithmetic(expr: Expr): Expr = { - def simplify0(expr: Expr): Expr = expr match { + def simplify0(expr: Expr): Expr = (expr match { case Plus(IntLiteral(i1), IntLiteral(i2)) => IntLiteral(i1 + i2) case Plus(IntLiteral(0), e) => e case Plus(e, IntLiteral(0)) => e @@ -1806,45 +1311,23 @@ object TreeOps { //default case e => e - } + }).copiedFrom(expr) + def fix[A](f: (A) => A)(a: A): A = { val na = f(a) if(a == na) a else fix(f)(na) } - - val res = fix(simplePostTransform(simplify0))(expr) - res - } - - def expandAndSimplifyArithmetic(expr: Expr): Expr = { - val expr0 = try { - val freeVars: Array[Identifier] = variablesOf(expr).toArray - val coefs: Array[Expr] = TreeNormalizations.linearArithmeticForm(expr, freeVars) - coefs.toList.zip(IntLiteral(1) :: freeVars.toList.map(Variable(_))).foldLeft[Expr](IntLiteral(0))((acc, t) => { - if(t._1 == IntLiteral(0)) acc else Plus(acc, Times(t._1, t._2)) - }) - } catch { - case _: Throwable => - expr - } - simplifyArithmetic(expr0) - } - - //If the formula consist of some top level AND, find a top level - //Equals and extract it, return the remaining formula as well - def extractEquals(expr: Expr): (Option[Equals], Expr) = expr match { - case And(es) => - // OK now I'm just messing with you. - val (r, nes) = es.foldLeft[(Option[Equals],Seq[Expr])]((None, Seq())) { - case ((None, nes), eq @ Equals(_,_)) => (Some(eq), nes) - case ((o, nes), e) => (o, e +: nes) - } - (r, And(nes.reverse)) - - case e => (None, e) + fix(simplePostTransform(simplify0))(expr) } + /** + * Checks whether a predicate is inductive on a certain identfier. + * + * isInductive(foo(a, b), a) where a: List will check whether + * foo(Nil, b) and + * foo(Cons(h,t), b) => foo(t, b) + */ def isInductiveOn(sf: SolverFactory[Solver])(expr: Expr, on: Identifier): Boolean = on match { case IsTyped(origId, AbstractClassType(cd)) => def isAlternativeRecursive(cd: CaseClassDef): Boolean = { @@ -1885,19 +1368,11 @@ object TreeOps { false } - def containsLetDef(expr: Expr): Boolean = { - def convert(t : Expr) : Boolean = t match { - case (l : LetDef) => true - case _ => false - } - def combine(c1 : Boolean, c2 : Boolean) : Boolean = c1 || c2 - def compute(t : Expr, c : Boolean) = t match { - case (l : LetDef) => true - case _ => c - } - treeCatamorphism(convert, combine, compute, expr) - } - + /** + * Checks whether two trees are homomoprhic modulo an identifier map. + * + * Used for transformation tests. + */ def isHomomorphic(t1: Expr, t2: Expr)(implicit map: Map[Identifier, Identifier]): Boolean = { object Same { def unapply(tt: (Expr, Expr)): Option[(Expr, Expr)] = { @@ -2045,20 +1520,29 @@ object TreeOps { false } - //if (!res) { - // println("@"*80) - // println("MISMATCH:") - // println("t1:"+t1) - // println("t2:"+t2) - // println("map:"+map) - //} - res } isHomo(t1,t2) } + /** + * Checks whether the match cases cover all possible inputs + * Used when reconstructing pattern matching from ITE. + * + * e.g. The following: + * + * list match { + * case Cons(_, Cons(_, a)) => + * + * case Cons(_, Nil) => + * + * case Nil => + * + * } + * + * is exaustive. + */ def isMatchExhaustive(m: MatchExpr): Boolean = { /** * Takes the matrix of the cases per position/types: @@ -2142,6 +1626,23 @@ object TreeOps { areExaustive(Seq((m.scrutinee.getType, patterns))) } + /** + * Flattens a function that contains a LetDef with a direct call to it + * Used for merging synthesis results. + * + * def foo(a, b) { + * def bar(c, d) { + * if (..) { bar(c, d) } else { .. } + * } + * bar(b, a) + * } + * + * becomes + * + * def foo(a, b) { + * if (..) { foo(b, a) } else { .. } + * } + **/ def flattenFunctions(fdOuter: FunDef): FunDef = { fdOuter.body match { case Some(LetDef(fdInner, FunctionInvocation(fdInner2, args))) if fdInner == fdInner2 => @@ -2200,4 +1701,18 @@ object TreeOps { fdOuter } } + + def expandAndSimplifyArithmetic(expr: Expr): Expr = { + val expr0 = try { + val freeVars: Array[Identifier] = variablesOf(expr).toArray + val coefs: Array[Expr] = TreeNormalizations.linearArithmeticForm(expr, freeVars) + coefs.toList.zip(IntLiteral(1) :: freeVars.toList.map(Variable(_))).foldLeft[Expr](IntLiteral(0))((acc, t) => { + if(t._1 == IntLiteral(0)) acc else Plus(acc, Times(t._1, t._2)) + }) + } catch { + case _: Throwable => + expr + } + simplifyArithmetic(expr0) + } } diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index e60a31cfa2a3fff0458e18671cde6bdb6e985402..74f8ead992b73a9eb6a2934e97fedafe89020507 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -382,8 +382,6 @@ object Trees { override def setType(tt: TypeTree) = { id.setType(tt); this } } - case class DeBruijnIndex(index: Int) extends Expr with Terminal - /* Literals */ sealed abstract class Literal[T] extends Expr with Terminal { val value: T diff --git a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala index 558274db58ea37f712abdcbb1618b872d93e9638..ab2d7084c5facbdb13e559e26fd0d03348b520c8 100644 --- a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala +++ b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala @@ -16,7 +16,6 @@ trait Z3ModelReconstruction { // exprToZ3Id, softFromZ3Formula, reporter private final val AUTOCOMPLETEMODELS : Boolean = true - private final val SIMPLESTCOMPLETION : Boolean = true // if true, use 0, Nil(), etc., else random def modelValue(model: Z3Model, id: Identifier, tpe: TypeTree = null) : Option[Expr] = { val expectedType = if(tpe == null) id.getType else tpe @@ -49,12 +48,9 @@ trait Z3ModelReconstruction { def modelToMap(model: Z3Model, ids: Iterable[Identifier]) : Map[Identifier,Expr] = { var asMap = Map.empty[Identifier,Expr] - def completeID(id : Identifier) : Unit = if (SIMPLESTCOMPLETION) { - asMap = asMap + ((id -> simplestValue(id.toVariable))) + def completeID(id : Identifier) : Unit = { + asMap = asMap + ((id -> simplestValue(id.getType))) reporter.info("Completing variable '" + id + "' to simplest value") - } else { - asMap = asMap + ((id -> randomValue(id.toVariable))) - reporter.info("Completing variable '" + id + "' to random value") } for(id <- ids) { diff --git a/src/main/scala/leon/synthesis/FileInterface.scala b/src/main/scala/leon/synthesis/FileInterface.scala index 49ff882540e8c27798bf7e50c96cb50427d44cdd..0cb67cc294c51dc116237e6001c9b9dbf5d21d47 100644 --- a/src/main/scala/leon/synthesis/FileInterface.scala +++ b/src/main/scala/leon/synthesis/FileInterface.scala @@ -7,6 +7,7 @@ import purescala.Trees._ import purescala.Common.Tree import purescala.Definitions.FunDef import purescala.ScalaPrinter +import purescala.PrinterOptions import leon.utils.RangePosition @@ -54,7 +55,7 @@ class FileInterface(reporter: Reporter) { val before = str.substring(0, from) val after = str.substring(to, str.length) - val newCode = ScalaPrinter(toTree, fromTree.getPos.col/2) + val newCode = ScalaPrinter(toTree, PrinterOptions(baseIndent = fromTree.getPos.col/2)) before + newCode + after diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 782a72049d84cfc0c8c09e8ab477e983434fd1ac..b6b68d5dbd665467bbd618b5ab0daaa2c6a55fb4 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -7,6 +7,7 @@ import purescala.Trees._ import purescala.TypeTrees.{TypeTree,TupleType} import purescala.Definitions._ import purescala.TreeOps._ +import purescala.ScopeSimplifier import solvers.z3._ import solvers._ diff --git a/src/main/scala/leon/synthesis/rules/Assert.scala b/src/main/scala/leon/synthesis/rules/Assert.scala index 240ed46850b200ab3b666296a1e4f8b23e843fd3..51afa91d6e44a6c8775d344647c3ca04f2c879e1 100644 --- a/src/main/scala/leon/synthesis/rules/Assert.scala +++ b/src/main/scala/leon/synthesis/rules/Assert.scala @@ -18,7 +18,7 @@ case object Assert extends NormalizingRule("Assert") { if (!exprsA.isEmpty) { if (others.isEmpty) { - List(RuleInstantiation.immediateSuccess(p, this, Solution(And(exprsA), Set(), Tuple(p.xs.map(id => simplestValue(Variable(id))))))) + List(RuleInstantiation.immediateSuccess(p, this, Solution(And(exprsA), Set(), Tuple(p.xs.map(id => simplestValue(id.getType)))))) } else { val sub = p.copy(pc = And(p.pc +: exprsA), phi = And(others)) diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index ae065bc4631060e6d804c29a1fdc583a9c7b2ee6..59a8e70ab759a5f76806f91fb1b90fa0c50b6bb4 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -108,7 +108,7 @@ case object CEGIS extends Rule("CEGIS") { val isNotSynthesizable = fd.body match { case Some(b) => - collectChooses(b).isEmpty + !containsChoose(b) case None => false diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala index 34902c9cbeb170ec481024a1e5d82cb643f37530..6def2cdfb455ede996d162e0e7faa86dfc52cb32 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala @@ -19,6 +19,9 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { def instantiateOn(sctx: SynthesisContext, problem: Problem): Traversable[RuleInstantiation] = { val TopLevelAnds(exprs) = problem.phi + + + //assume that we only have inequalities var lhsSides: List[Expr] = Nil var exprNotUsed: List[Expr] = Nil diff --git a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala index 4d5e0add31e2394feb10a96e73af1ac8340e3c1d..9584311f70423dee4cd4f0f6ad83dfc298216a57 100644 --- a/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala +++ b/src/main/scala/leon/synthesis/rules/UnconstrainedOutput.scala @@ -17,7 +17,7 @@ case object UnconstrainedOutput extends NormalizingRule("Unconstr.Output") { val onSuccess: List[Solution] => Option[Solution] = { case List(s) => - Some(Solution(s.pre, s.defs, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(Variable(id)) else Variable(id)))))) + Some(Solution(s.pre, s.defs, LetTuple(sub.xs, s.term, Tuple(p.xs.map(id => if (unconstr(id)) simplestValue(id.getType) else Variable(id)))))) case _ => None } diff --git a/src/main/scala/leon/utils/Positions.scala b/src/main/scala/leon/utils/Positions.scala index 1ca2f9faf279537f9bff8965d85f1c2a7d5b716a..6fda9a100b7cf67fabb3a337dad6edcc849edffc 100644 --- a/src/main/scala/leon/utils/Positions.scala +++ b/src/main/scala/leon/utils/Positions.scala @@ -21,10 +21,12 @@ case class OffsetPosition(line: Int, col: Int, point: Int, file: File) extends P case class RangePosition(lineFrom: Int, colFrom: Int, pointFrom: Int, lineTo: Int, colTo: Int, pointTo: Int, file: File) extends Position { + + def focusEnd = OffsetPosition(lineTo, colTo, pointTo, file) + def focusBegin = OffsetPosition(lineFrom, colFrom, pointFrom, file) + val line = lineFrom val col = colFrom - - override def toString = lineFrom+":"+colFrom+"->"+lineTo+":"+colTo } case object NoPosition extends Position { diff --git a/src/main/scala/leon/xlang/ArrayTransformation.scala b/src/main/scala/leon/xlang/ArrayTransformation.scala index aa259e76c068dc59f2b5c6f96e1bbf5cb7fdd89a..4d6ad43251a5d139a27caaf235e091f0eb054096 100644 --- a/src/main/scala/leon/xlang/ArrayTransformation.scala +++ b/src/main/scala/leon/xlang/ArrayTransformation.scala @@ -32,7 +32,7 @@ object ArrayTransformation extends TransformationPhase { } - def transform(expr: Expr): Expr = expr match { + def transform(expr: Expr): Expr = (expr match { case sel@ArraySelect(a, i) => { val ra = transform(a) val ri = transform(i) @@ -129,6 +129,6 @@ object ArrayTransformation extends TransformationPhase { case (t: Terminal) => t case unhandled => scala.sys.error("Non-terminal case should be handled in ArrayTransformation: " + unhandled) - } + }).setPos(expr) } diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index 94cdd463d110a654db6b7f548684698b6262ed58..d77bf920805be4de79aa733169e918067421e7bf 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -43,19 +43,19 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef private def toFunction(expr: Expr): (Expr, Expr => Expr, Map[Identifier, Identifier]) = { val res = expr match { case LetVar(id, e, b) => { - val newId = FreshIdentifier(id.name).setType(id.getType) + val newId = FreshIdentifier(id.name).copiedFrom(id) val (rhsVal, rhsScope, rhsFun) = toFunction(e) varInScope += id val (bodyRes, bodyScope, bodyFun) = toFunction(b) varInScope -= id - val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, replaceNames(rhsFun + (id -> newId), bodyScope(body)))) + val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, replaceNames(rhsFun + (id -> newId), bodyScope(body))).copiedFrom(expr)) (bodyRes, scope, (rhsFun + (id -> newId)) ++ bodyFun) } case Assignment(id, e) => { assert(varInScope.contains(id)) - val newId = FreshIdentifier(id.name).setType(id.getType) + val newId = FreshIdentifier(id.name).copiedFrom(id) val (rhsVal, rhsScope, rhsFun) = toFunction(e) - val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, body)) + val scope = (body: Expr) => rhsScope(Let(newId, rhsVal, body).copiedFrom(expr)) (UnitLiteral, scope, rhsFun + (id -> newId)) } @@ -81,7 +81,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef case None => vId.toVariable })).setType(iteType) - val iteExpr = IfExpr(cRes, replaceNames(cFun, tScope(thenVal)), replaceNames(cFun, eScope(elseVal))).setType(iteType) + val iteExpr = IfExpr(cRes, replaceNames(cFun, tScope(thenVal)), replaceNames(cFun, eScope(elseVal))).setType(iteType).copiedFrom(ite) val scope = ((body: Expr) => { val tupleId = FreshIdentifier("t").setType(iteType) @@ -94,7 +94,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef freshIds.zipWithIndex.foldLeft(body)((b, id) => Let(id._1, TupleSelect(tupleId.toVariable, id._2 + 2).setType(id._1.getType), - b))))) + b)))).copiedFrom(expr)) }) (resId.toVariable, scope, cFun ++ modifiedVars.zip(freshIds).toMap) @@ -117,12 +117,12 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef })).setType(matchType) } - val newRhs = csesVals.zip(csesScope).map{ + val newRhs = csesVals.zip(csesScope).map{ case (cVal, cScope) => replaceNames(scrutFun, cScope(cVal)) } val matchExpr = MatchExpr(scrutRes, cses.zip(newRhs).map{ - case (SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs) - case (GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs) + case (sc @ SimpleCase(pat, _), newRhs) => SimpleCase(pat, newRhs).setPos(sc) + case (gc @ GuardedCase(pat, guard, _), newRhs) => GuardedCase(pat, replaceNames(scrutFun, guard), newRhs).setPos(gc) }).setType(matchType).setPos(m) val scope = ((body: Expr) => { @@ -233,7 +233,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val (bindRes, bindScope, bindFun) = toFunction(e) val (bodyRes, bodyScope, bodyFun) = toFunction(b) (bodyRes, - (b2: Expr) => bindScope(Let(id, bindRes, replaceNames(bindFun, bodyScope(b2)))), + (b2: Expr) => bindScope(Let(id, bindRes, replaceNames(bindFun, bodyScope(b2))).copiedFrom(expr)), bindFun ++ bodyFun) } case LetDef(fd, b) => { @@ -247,12 +247,12 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef fd } val (bodyRes, bodyScope, bodyFun) = toFunction(b) - (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)), bodyFun) + (bodyRes, (b2: Expr) => LetDef(newFd, bodyScope(b2)).copiedFrom(expr), bodyFun) } case c @ Choose(ids, b) => { //Recall that Choose cannot mutate variables from the scope val (bodyRes, bodyScope, bodyFun) = toFunction(b) - (bodyRes, (b2: Expr) => Choose(ids, bodyScope(b2)).setPos(c), bodyFun) + (bodyRes, (b2: Expr) => Choose(ids, bodyScope(b2)).copiedFrom(c), bodyFun) } case n @ NAryOperator(Seq(), recons) => (n, (body: Expr) => body, Map()) case n @ NAryOperator(args, recons) => { @@ -262,7 +262,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val newScope = (body: Expr) => argScope(replaceNames(argFun, accScope(body))) (argVal +: accArgs, newScope, argFun ++ accFun) }) - (recons(recArgs).setType(n.getType), scope, fun) + (recons(recArgs).copiedFrom(n), scope, fun) } case b @ BinaryOperator(a1, a2, recons) => { val (argVal1, argScope1, argFun1) = toFunction(a1) @@ -272,11 +272,11 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val lhs = argScope1(replaceNames(argFun1, rhs)) lhs } - (recons(argVal1, argVal2).setType(b.getType), scope, argFun1 ++ argFun2) + (recons(argVal1, argVal2).copiedFrom(b), scope, argFun1 ++ argFun2) } case u @ UnaryOperator(a, recons) => { val (argVal, argScope, argFun) = toFunction(a) - (recons(argVal).setType(u.getType), argScope, argFun) + (recons(argVal).copiedFrom(u), argScope, argFun) } case (t: Terminal) => (t, (body: Expr) => body, Map()) diff --git a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala index 972d0d73271e59747413810ec2255fb79cd59885..6967f38def16bf3ff7503c4f54ea9b88ce26b1f6 100644 --- a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala +++ b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala @@ -420,11 +420,11 @@ class EvaluatorsTests extends LeonTestSuite { } } - test("Misc") { + test("Executing Chooses") { val p = """|object Program { | import leon.Utils._ | - | def c(i : Int) : Int = choose { (j : Int) => j > i } + | def c(i : Int) : Int = choose { (j : Int) => j > i && j < i + 2 } |} |""".stripMargin @@ -432,7 +432,7 @@ class EvaluatorsTests extends LeonTestSuite { val evaluators = prepareEvaluators for(e <- evaluators) { - checkEvaluatorError(e, mkCall("c", IL(42))) + checkComp(e, mkCall("c", IL(42)), IL(43)) } } diff --git a/src/test/scala/leon/test/purescala/TreeOpsTests.scala b/src/test/scala/leon/test/purescala/TreeOpsTests.scala index f261d246b52e286038ff68ff6a285f5fb7598d27..790bf7d577e782e55c67fa3311796889ae7c174f 100644 --- a/src/test/scala/leon/test/purescala/TreeOpsTests.scala +++ b/src/test/scala/leon/test/purescala/TreeOpsTests.scala @@ -22,6 +22,22 @@ class TreeOpsTests extends LeonTestSuite { assert(true) } + /** + * If the formula consist of some top level AND, find a top level + * Equals and extract it, return the remaining formula as well + */ + def extractEquals(expr: Expr): (Option[Equals], Expr) = expr match { + case And(es) => + // OK now I'm just messing with you. + val (r, nes) = es.foldLeft[(Option[Equals],Seq[Expr])]((None, Seq())) { + case ((None, nes), eq @ Equals(_,_)) => (Some(eq), nes) + case ((o, nes), e) => (o, e +: nes) + } + (r, And(nes.reverse)) + + case e => (None, e) + } + def i(x: Int) = IntLiteral(x)