diff --git a/src/main/scala/leon/LeonFatalError.scala b/src/main/scala/leon/LeonFatalError.scala index 8ceb4f3fd1d32b5cd21622bacf15e679eacf6270..e6fb5ca3779ee28d48babcae06fa60c3dcbb3fc6 100644 --- a/src/main/scala/leon/LeonFatalError.scala +++ b/src/main/scala/leon/LeonFatalError.scala @@ -2,4 +2,4 @@ package leon -case class LeonFatalError() extends Exception +case class LeonFatalError(msg: String) extends Exception(msg) diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index 58386b84906b5371fe09e455f3cb1a0f60292d16..ad29c7cb12c6d3280f23ec8b7f1a0f55958571f6 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -207,12 +207,12 @@ object Main { } def main(args : Array[String]) { - try { - // Process options - val timer = new Timer().start + val timer = new Timer().start - val ctx = processOptions(args.toList) + // Process options + val ctx = processOptions(args.toList) + try { ctx.interruptManager.registerSignalHandler() ctx.timers.get("Leon Opts") += timer @@ -246,7 +246,9 @@ object Main { } } catch { - case LeonFatalError() => sys.exit(1) + case LeonFatalError(msg) => + ctx.reporter.error(msg) + sys.exit(1) } } } diff --git a/src/main/scala/leon/Reporter.scala b/src/main/scala/leon/Reporter.scala index c13bfa60e4f802b112a58908249a023177012324..1cd2585a096f62cf5f3d50757524122b408dffd2 100644 --- a/src/main/scala/leon/Reporter.scala +++ b/src/main/scala/leon/Reporter.scala @@ -85,7 +85,7 @@ class DefaultReporter(settings: Settings) extends Reporter(settings) { def debugFunction(msg: Any) = output(reline(debugPfx, msg.toString)) def fatalErrorFunction(msg: Any) = { output(reline(fatalPfx, msg.toString)); - throw LeonFatalError() + throw LeonFatalError(msg.toString) } } diff --git a/src/main/scala/leon/codegen/CodeGeneration.scala b/src/main/scala/leon/codegen/CodeGeneration.scala index ffb30cbb2dd8f79992a71e71f5ced310cf3fb943..63defda0c4ee81e79d05eb5496c245dd7b8bfe22 100644 --- a/src/main/scala/leon/codegen/CodeGeneration.scala +++ b/src/main/scala/leon/codegen/CodeGeneration.scala @@ -54,7 +54,7 @@ trait CodeGeneration { case UnitType => "Z" case c : ClassType => - leonClassToJVMClass(c.classDef).map(n => "L" + n + ";").getOrElse("Unsupported class " + c.id) + leonClassToJVMInfo(c.classDef).map { case (n, _) => "L" + n + ";" }.getOrElse("Unsupported class " + c.id) case _ : TupleType => "L" + TupleClass + ";" @@ -68,6 +68,9 @@ trait CodeGeneration { case ArrayType(base) => "[" + typeToJVM(base) + case TypeParameter(_) => + "Ljava/lang/Object;" + case _ => throw CompilationException("Unsupported type : " + tpe) } @@ -107,7 +110,7 @@ trait CodeGeneration { case Int32Type | BooleanType | UnitType => ch << IRETURN - case _ : ClassType | _ : TupleType | _ : SetType | _ : MapType | _ : ArrayType => + case _ : ClassType | _ : TupleType | _ : SetType | _ : MapType | _ : ArrayType | _: TypeParameter => ch << ARETURN case other => @@ -166,32 +169,35 @@ trait CodeGeneration { ch << Ldc(1) // Case classes - case CaseClass(ccd, as) => - val ccName = leonClassToJVMClass(ccd).getOrElse { - throw CompilationException("Unknown class : " + ccd.id) + case CaseClass(cct, as) => + val (ccName, ccApplySig) = leonClassToJVMInfo(cct.classDef).getOrElse { + throw CompilationException("Unknown class : " + cct.id) } - // TODO FIXME It's a little ugly that we do it each time. Could be in env. - val consSig = "(" + ccd.fields.map(f => typeToJVM(f.tpe)).mkString("") + ")V" ch << New(ccName) << DUP - for(a <- as) { - mkExpr(a, ch) + for((a, vd) <- as zip cct.classDef.fields) { + vd.tpe match { + case TypeParameter(_) => + mkBoxedExpr(a, ch) + case _ => + mkExpr(a, ch) + } } - ch << InvokeSpecial(ccName, constructorName, consSig) + ch << InvokeSpecial(ccName, constructorName, ccApplySig) - case CaseClassInstanceOf(ccd, e) => - val ccName = leonClassToJVMClass(ccd).getOrElse { - throw CompilationException("Unknown class : " + ccd.id) + case CaseClassInstanceOf(cct, e) => + val (ccName, _) = leonClassToJVMInfo(cct.classDef).getOrElse { + throw CompilationException("Unknown class : " + cct.id) } mkExpr(e, ch) ch << InstanceOf(ccName) - case CaseClassSelector(ccd, e, sid) => + case CaseClassSelector(cct, e, sid) => mkExpr(e, ch) - val ccName = leonClassToJVMClass(ccd).getOrElse { - throw CompilationException("Unknown class : " + ccd.id) + val (ccName, _) = leonClassToJVMInfo(cct.classDef).getOrElse { + throw CompilationException("Unknown class : " + cct.id) } ch << CheckCast(ccName) - instrumentedGetField(ch, ccd, sid) + instrumentedGetField(ch, cct, sid) // Tuples (note that instanceOf checks are in mkBranch) case Tuple(es) => @@ -290,18 +296,32 @@ trait CodeGeneration { mkExpr(e, ch) ch << Label(al) - case FunctionInvocation(fd, as) => - val (cn, mn, ms) = leonFunDefToJVMInfo(fd).getOrElse { - throw CompilationException("Unknown method : " + fd.id) + case FunctionInvocation(tfd, as) => + val (cn, mn, ms) = leonFunDefToJVMInfo(tfd.fd).getOrElse { + throw CompilationException("Unknown method : " + tfd.id) } + if (params.requireMonitor) { ch << ALoad(0) } - for(a <- as) { - mkExpr(a, ch) + + for((a, vd) <- as zip tfd.fd.args) { + vd.tpe match { + case TypeParameter(_) => + mkBoxedExpr(a, ch) + case _ => + mkExpr(a, ch) + } } + ch << InvokeStatic(cn, mn, ms) + (tfd.fd.returnType, tfd.returnType) match { + case (TypeParameter(_), tpe) => + mkUnbox(tpe, ch) + case _ => + } + // Arithmetic case Plus(l, r) => mkExpr(l, ch) @@ -466,7 +486,7 @@ trait CodeGeneration { ch << CheckCast(BoxedBoolClass) << InvokeVirtual(BoxedBoolClass, "booleanValue", "()Z") case ct : ClassType => - val cn = leonClassToJVMClass(ct.classDef).getOrElse { + val (cn, _) = leonClassToJVMInfo(ct.classDef).getOrElse { throw new CompilationException("Unsupported class type : " + ct) } ch << CheckCast(cn) @@ -480,6 +500,8 @@ trait CodeGeneration { case mt : MapType => ch << CheckCast(MapClass) + case tp : TypeParameter => + case _ => throw new CompilationException("Unsupported type in unboxing : " + tpe) } @@ -586,9 +608,13 @@ trait CodeGeneration { */ val instrumentedField = "__read" - def instrumentedGetField(ch: CodeHandler, ccd: CaseClassDef, id: Identifier)(implicit locals: Locals): Unit = { + def instrumentedGetField(ch: CodeHandler, cct: CaseClassType, id: Identifier)(implicit locals: Locals): Unit = { + val ccd = cct.classDef + ccd.fields.zipWithIndex.find(_._1.id == id) match { case Some((f, i)) => + val expType = cct.fields(i).tpe + val cName = defToJVMName(ccd) if (params.doInstrument) { ch << DUP << DUP @@ -600,6 +626,12 @@ trait CodeGeneration { ch << PutField(cName, instrumentedField, "I") } ch << GetField(cName, f.id.name, typeToJVM(f.tpe)) + + f.tpe match { + case TypeParameter(_) => + mkUnbox(expType, ch) + case _ => + } case None => throw CompilationException("Unknown field: "+ccd.id.name+"."+id) } @@ -608,7 +640,8 @@ trait CodeGeneration { def compileCaseClassDef(ccd: CaseClassDef) { val cName = defToJVMName(ccd) - val pName = ccd.parent.map(parent => defToJVMName(parent)) + val pName = ccd.parent.map(parent => defToJVMName(parent.classDef)) + val cct = CaseClassType(ccd, ccd.tparams.map(_.tp)) val cf = classes(ccd) @@ -710,7 +743,7 @@ trait CodeGeneration { pech << DUP pech << Ldc(i) pech << ALoad(0) - instrumentedGetField(pech, ccd, f.id)(NoLocals) + instrumentedGetField(pech, cct, f.id)(NoLocals) mkBox(f.tpe, pech)(NoLocals) pech << AASTORE } @@ -745,9 +778,9 @@ trait CodeGeneration { for(vd <- ccd.fields) { ech << ALoad(0) - instrumentedGetField(ech, ccd, vd.id)(NoLocals) + instrumentedGetField(ech, cct, vd.id)(NoLocals) ech << ALoad(castSlot) - instrumentedGetField(ech, ccd, vd.id)(NoLocals) + instrumentedGetField(ech, cct, vd.id)(NoLocals) typeToJVM(vd.id.getType) match { case "I" | "Z" => diff --git a/src/main/scala/leon/codegen/CompilationUnit.scala b/src/main/scala/leon/codegen/CompilationUnit.scala index 1b3ce23c050ef05431ce2fe7f42f16b6ca3746c8..207bd7183982a4d4ac237a53bb42c6d18f655a03 100644 --- a/src/main/scala/leon/codegen/CompilationUnit.scala +++ b/src/main/scala/leon/codegen/CompilationUnit.scala @@ -31,13 +31,13 @@ class CompilationUnit(val ctx: LeonContext, val cf = df match { case ccd: CaseClassDef => - val pName = ccd.parent.map(parent => defToJVMName(parent)) + val pName = ccd.parent.map(parent => defToJVMName(parent.classDef)) new ClassFile(cName, pName) case acd: AbstractClassDef => new ClassFile(cName, None) - case ob: ObjectDef => + case ob: ModuleDef => new ClassFile(cName, None) case _ => @@ -51,8 +51,13 @@ class CompilationUnit(val ctx: LeonContext, classes.find(_._2.className == name).map(_._1) } - def leonClassToJVMClass(cd: Definition): Option[String] = { - classes.get(cd).map(_.className) + def leonClassToJVMInfo(cd: ClassDef): Option[(String, String)] = { + classes.get(cd) match { + case Some(cf) => + val sig = "(" + cd.fields.map(f => typeToJVM(f.tpe)).mkString("") + ")V" + Some((cf.className, sig)) + case _ => None + } } // Returns className, methodName, methodSignature @@ -64,9 +69,9 @@ class CompilationUnit(val ctx: LeonContext, val sig = "(" + monitorType + fd.args.map(a => typeToJVM(a.tpe)).mkString("") + ")" + typeToJVM(fd.returnType) - leonClassToJVMClass(program.mainObject) match { - case Some(cn) => - val res = (cn, fd.id.uniqueName, sig) + classes.get(program.mainModule) match { + case Some(cf) => + val res = (cf.className, fd.id.uniqueName, sig) funDefInfo += fd -> res Some(res) case None => @@ -113,11 +118,14 @@ class CompilationUnit(val ctx: LeonContext, case BooleanLiteral(v) => new java.lang.Boolean(v) + case GenericValue(tp, id) => + e + case Tuple(elems) => tupleConstructor.newInstance(elems.map(exprToJVM).toArray).asInstanceOf[AnyRef] - case CaseClass(ccd, args) => - caseClassConstructor(ccd) match { + case CaseClass(cct, args) => + caseClassConstructor(cct.classDef) match { case Some(cons) => cons.newInstance(args.map(exprToJVM).toArray : _*).asInstanceOf[AnyRef] case None => @@ -147,7 +155,7 @@ class CompilationUnit(val ctx: LeonContext, jvmClassToLeonClass(e.getClass.getName) match { case Some(cc: CaseClassDef) => - CaseClass(cc, fields.map(jvmToExpr)) + CaseClass(CaseClassType(cc, Nil), fields.map(jvmToExpr)) case _ => throw CompilationException("Unsupported return value : " + e) } @@ -158,6 +166,9 @@ class CompilationUnit(val ctx: LeonContext, } Tuple(elems) + case gv : GenericValue => + gv + case set : runtime.Set => FiniteSet(set.getElements().asScala.map(jvmToExpr).toSeq) @@ -226,7 +237,7 @@ class CompilationUnit(val ctx: LeonContext, case Int32Type | BooleanType => ch << IRETURN - case UnitType | _: TupleType | _: SetType | _: MapType | _: AbstractClassType | _: CaseClassType | _: ArrayType => + case UnitType | _: TupleType | _: SetType | _: MapType | _: AbstractClassType | _: CaseClassType | _: ArrayType | _: TypeParameter => ch << ARETURN case other => @@ -240,8 +251,8 @@ class CompilationUnit(val ctx: LeonContext, new CompiledExpression(this, cf, e, args) } - def compileMainObject() { - val cf = classes(program.mainObject) + def compileMainModule() { + val cf = classes(program.mainModule) cf.addDefaultConstructor @@ -295,7 +306,7 @@ class CompilationUnit(val ctx: LeonContext, defineClass(single) } - defineClass(program.mainObject) + defineClass(program.mainModule) } def compile() { @@ -312,7 +323,7 @@ class CompilationUnit(val ctx: LeonContext, compileCaseClassDef(single) } - compileMainObject() + compileMainModule() classes.values.foreach(loader.register _) } diff --git a/src/main/scala/leon/datagen/NaiveDataGen.scala b/src/main/scala/leon/datagen/NaiveDataGen.scala index f540c59b08ea891ad9b2de31cfbf42f4c2b2f896..5ec8577bdb3d0213679ead60a107d88d50ea2d20 100644 --- a/src/main/scala/leon/datagen/NaiveDataGen.scala +++ b/src/main/scala/leon/datagen/NaiveDataGen.scala @@ -84,30 +84,27 @@ class NaiveDataGen(ctx: LeonContext, p: Program, evaluator: Evaluator, _bounds : // We prioritize base cases among the children. // Otherwise we run the risk of infinite recursion when // generating lists. - val ccChildren = act.classDef.knownChildren.collect(_ match { - case ccd : CaseClassDef => ccd - } - ) + val ccChildren = act.knownCCDescendents + val (leafs,conss) = ccChildren.partition(_.fields.size == 0) // The stream for leafs... - val leafsStream = leafs.toStream.flatMap { ccd => - generate(classDefToClassType(ccd), bounds) + val leafsStream = leafs.toStream.flatMap { cct => + generate(cct, bounds) } // ...to which we append the streams for constructors. - leafsStream.append(interleave(conss.map { ccd => - generate(classDefToClassType(ccd), bounds) + leafsStream.append(interleave(conss.map { cct => + generate(cct, bounds) })) case cct : CaseClassType => - val ccd = cct.classDef - if(ccd.fields.isEmpty) { - Stream.cons(CaseClass(ccd, Nil), Stream.empty) + if(cct.fields.isEmpty) { + Stream.cons(CaseClass(cct, Nil), Stream.empty) } else { - val fieldTypes = ccd.fields.map(_.tpe) + val fieldTypes = cct.fieldsTypes val subStream = naryProduct(fieldTypes.map(generate(_, bounds))) - subStream.map(prod => CaseClass(ccd, prod)) + subStream.map(prod => CaseClass(cct, prod)) } case _ => Stream.empty diff --git a/src/main/scala/leon/datagen/VanuatooDataGen.scala b/src/main/scala/leon/datagen/VanuatooDataGen.scala index 6226e9663ff07276b46ad4279ccb51338756ca08..f934ac4bd5d49c0606428029af403ce0d1f45b09 100644 --- a/src/main/scala/leon/datagen/VanuatooDataGen.scala +++ b/src/main/scala/leon/datagen/VanuatooDataGen.scala @@ -34,8 +34,8 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { ConstructorPattern[Expr, TypeTree](c, args) } - private var ccConstructors = Map[CaseClassDef, Constructor[Expr, TypeTree]]() - private var acConstructors = Map[AbstractClassDef, List[Constructor[Expr, TypeTree]]]() + private var ccConstructors = Map[CaseClassType, Constructor[Expr, TypeTree]]() + private var acConstructors = Map[AbstractClassType, List[Constructor[Expr, TypeTree]]]() private var tConstructors = Map[TupleType, Constructor[Expr, TypeTree]]() private def getConstructorFor(t: CaseClassType, act: AbstractClassType): Constructor[Expr, TypeTree] = { @@ -52,22 +52,21 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { c })) - case act @ AbstractClassType(acd) => - acConstructors.getOrElse(acd, { - val cs = acd.knownDescendents.collect { - case ccd: CaseClassDef => - getConstructorFor(CaseClassType(ccd), act) + case act: AbstractClassType => + acConstructors.getOrElse(act, { + val cs = act.knownCCDescendents.map { + cct => getConstructorFor(cct, act) }.toList - acConstructors += acd -> cs + acConstructors += act -> cs cs }) - case CaseClassType(ccd) => - List(ccConstructors.getOrElse(ccd, { - val c = Constructor[Expr, TypeTree](ccd.fields.map(_.tpe), CaseClassType(ccd), s => CaseClass(ccd, s), ccd.id.name) - ccConstructors += ccd -> c + case cct: CaseClassType => + List(ccConstructors.getOrElse(cct, { + val c = Constructor[Expr, TypeTree](cct.fieldsTypes, cct, s => CaseClass(cct, s), cct.id.name) + ccConstructors += cct -> c c })) @@ -90,9 +89,9 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { case Some(ccd: CaseClassDef) => val c = ct match { case act : AbstractClassType => - getConstructorFor(CaseClassType(ccd), act) + getConstructorFor(CaseClassType(ccd, ct.tps), act) case cct : CaseClassType => - getConstructors(CaseClassType(ccd))(0) + getConstructors(CaseClassType(ccd, ct.tps))(0) } val fields = cc.productElements() @@ -100,7 +99,7 @@ class VanuatooDataGen(ctx: LeonContext, p: Program) extends DataGenerator { val elems = for (i <- 0 until fields.length) yield { if (((r >> i) & 1) == 1) { // has been read - valueToPattern(fields(i), ccd.fieldsIds(i).getType) + valueToPattern(fields(i), ct.fieldsTypes(i)) } else { (AnyPattern[Expr, TypeTree](), false) } diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index 0d61de43985b6e345907c5e6a9d581bd691dfd1e..246e7f654e39eb43e8fe0faa4eee8a7e3c01547e 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -94,37 +94,37 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu case _ => throw EvalError(typeErrorMsg(first, BooleanType)) } - case FunctionInvocation(fd, args) => + case FunctionInvocation(tfd, args) => val evArgs = args.map(a => se(a)) // build a mapping for the function... - val frame = rctx.withVars((fd.args.map(_.id) zip evArgs).toMap) + val frame = rctx.withVars((tfd.args.map(_.id) zip evArgs).toMap) - if(fd.hasPrecondition) { - se(matchToIfThenElse(fd.precondition.get))(frame, gctx) match { + if(tfd.hasPrecondition) { + se(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match { case BooleanLiteral(true) => case BooleanLiteral(false) => - throw RuntimeError("Precondition violation for " + fd.id.name + " reached in evaluation.: " + fd.precondition.get) + throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get) case other => throw RuntimeError(typeErrorMsg(other, BooleanType)) } } - if(!fd.hasBody && !rctx.mappings.isDefinedAt(fd.id)) { + if(!tfd.hasBody && !rctx.mappings.isDefinedAt(tfd.id)) { throw EvalError("Evaluation of function with unknown implementation.") } - val body = fd.body.getOrElse(rctx.mappings(fd.id)) + val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) val callResult = se(matchToIfThenElse(body))(frame, gctx) - if(fd.hasPostcondition) { - val (id, post) = fd.postcondition.get + if(tfd.hasPostcondition) { + val (id, post) = tfd.postcondition.get - val freshResID = FreshIdentifier("result").setType(fd.returnType) + val freshResID = FreshIdentifier("result").setType(tfd.returnType) val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post)) se(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match { case BooleanLiteral(true) => - case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + fd.id.name + " reached in evaluation.") + case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") case other => throw EvalError(typeErrorMsg(other, BooleanType)) } } @@ -179,18 +179,18 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu case CaseClass(cd, args) => CaseClass(cd, args.map(se(_))) - case CaseClassInstanceOf(cd, expr) => + case CaseClassInstanceOf(cct, expr) => val le = se(expr) BooleanLiteral(le.getType match { - case CaseClassType(cd2) if cd2 == cd => true + case CaseClassType(cd2, _) if cd2 == cct.classDef => true case _ => false }) - case CaseClassSelector(cd, expr, sel) => + case CaseClassSelector(ct1, expr, sel) => val le = se(expr) le match { - case CaseClass(cd2, args) if cd == cd2 => args(cd.selectorID2Index(sel)) - case _ => throw EvalError(typeErrorMsg(le, CaseClassType(cd))) + case CaseClass(ct2, args) if ct1 == ct2 => args(ct1.classDef.selectorID2Index(sel)) + case _ => throw EvalError(typeErrorMsg(le, ct1)) } case Plus(l,r) => @@ -360,6 +360,9 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program) extends Evalu BooleanLiteral(newArgs.distinct.size == newArgs.size) } + case gv: GenericValue => + gv + case choose: Choose => import solvers.z3.FairZ3Solver import purescala.TreeOps.simplestValue diff --git a/src/main/scala/leon/evaluators/TracingEvaluator.scala b/src/main/scala/leon/evaluators/TracingEvaluator.scala index 183b6cf2f5f79772f51686fb6ae34318121a85fb..ded6d33c310fb7a2db772aef7c35f050559b1269 100644 --- a/src/main/scala/leon/evaluators/TracingEvaluator.scala +++ b/src/main/scala/leon/evaluators/TracingEvaluator.scala @@ -42,38 +42,38 @@ class TracingEvaluator(ctx: LeonContext, prog: Program) extends RecursiveEvaluat val res = se(b)(rctx.withNewVar(i, first), gctx) (res, first) - case fi @ FunctionInvocation(fd, args) => + case fi @ FunctionInvocation(tfd, args) => val evArgs = args.map(a => se(a)) // build a mapping for the function... - val frame = new TracingRecContext((fd.args.map(_.id) zip evArgs).toMap, rctx.tracingFrames-1) + val frame = new TracingRecContext((tfd.args.map(_.id) zip evArgs).toMap, rctx.tracingFrames-1) - if(fd.hasPrecondition) { - se(matchToIfThenElse(fd.precondition.get))(frame, gctx) match { + if(tfd.hasPrecondition) { + se(matchToIfThenElse(tfd.precondition.get))(frame, gctx) match { case BooleanLiteral(true) => case BooleanLiteral(false) => - throw RuntimeError("Precondition violation for " + fd.id.name + " reached in evaluation.: " + fd.precondition.get) + throw RuntimeError("Precondition violation for " + tfd.id.name + " reached in evaluation.: " + tfd.precondition.get) case other => throw RuntimeError(typeErrorMsg(other, BooleanType)) } } - if(!fd.hasBody && !rctx.mappings.isDefinedAt(fd.id)) { + if(!tfd.hasBody && !rctx.mappings.isDefinedAt(tfd.id)) { throw EvalError("Evaluation of function with unknown implementation.") } - val body = fd.body.getOrElse(rctx.mappings(fd.id)) + val body = tfd.body.getOrElse(rctx.mappings(tfd.id)) val callResult = se(matchToIfThenElse(body))(frame, gctx) - if(fd.hasPostcondition) { - val (id, post) = fd.postcondition.get + if(tfd.hasPostcondition) { + val (id, post) = tfd.postcondition.get - val freshResID = FreshIdentifier("result").setType(fd.returnType) + val freshResID = FreshIdentifier("result").setType(tfd.returnType) val postBody = replace(Map(Variable(id) -> Variable(freshResID)), matchToIfThenElse(post)) se(matchToIfThenElse(post))(frame.withNewVar(id, callResult), gctx) match { case BooleanLiteral(true) => - case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + fd.id.name + " reached in evaluation.") + case BooleanLiteral(false) => throw RuntimeError("Postcondition violation for " + tfd.id.name + " reached in evaluation.") case other => throw EvalError(typeErrorMsg(other, BooleanType)) } } diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index 2083d0e36e7eb1eedad58b544c6fcdfbac4dfc45..fa2fc7ce4e00436d045a21b01f7e6aabbffcd61c 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -159,7 +159,7 @@ trait ASTExtractors { * no abstract members. */ def unapply(cd: ClassDef): Option[(String,Symbol)] = cd match { // abstract class - case ClassDef(_, name, tparams, impl) if (cd.symbol.isAbstractClass && tparams.isEmpty && impl.body.size == 1) => Some((name.toString, cd.symbol)) + case ClassDef(_, name, tparams, impl) if (cd.symbol.isAbstractClass && impl.body.size == 1) => Some((name.toString, cd.symbol)) case _ => None } @@ -167,7 +167,7 @@ trait ASTExtractors { object ExCaseClass { def unapply(cd: ClassDef): Option[(String,Symbol,Seq[(String,Tree)])] = cd match { - case ClassDef(_, name, tparams, impl) if (cd.symbol.isCase && !cd.symbol.isAbstractClass && tparams.isEmpty && impl.body.size >= 8) => { + case ClassDef(_, name, tparams, impl) if (cd.symbol.isCase && !cd.symbol.isAbstractClass && impl.body.size >= 8) => { val constructor: DefDef = impl.children.find(child => child match { case ExConstructorDef() => true case _ => false @@ -217,8 +217,9 @@ trait ASTExtractors { object ExFunctionDef { /** Matches a function with a single list of arguments, no type * parameters and regardless of its visibility. */ - def unapply(dd: DefDef): Option[(String,Seq[ValDef],Tree,Tree)] = dd match { - case DefDef(_, name, tparams, vparamss, tpt, rhs) if(tparams.isEmpty && vparamss.size == 1 && name != nme.CONSTRUCTOR) => Some((name.toString, vparamss(0), tpt, rhs)) + def unapply(dd: DefDef): Option[(Symbol, Seq[Symbol], Seq[ValDef], Type, Tree)] = dd match { + case DefDef(_, name, tparams, vparamss, tpt, rhs) if(vparamss.size <= 1 && name != nme.CONSTRUCTOR) => + Some((dd.symbol, tparams.map(_.symbol), vparamss.headOption.getOrElse(Nil), tpt.tpe, rhs)) case _ => None } } @@ -574,9 +575,11 @@ trait ASTExtractors { } object ExLocalCall { - def unapply(tree: Apply): Option[(Symbol,String,List[Tree])] = tree match { - case a @ Apply(Select(This(_), nme), args) => Some((a.symbol, nme.toString, args)) - case a @ Apply(Ident(nme), args) => Some((a.symbol, nme.toString, args)) + def unapply(tree: Apply): Option[(Symbol, List[Tree], List[Tree])] = tree match { + case a @ Apply(Select(This(_), nme), args) => Some((a.symbol, Nil, args)) + case a @ Apply(Ident(nme), args) => Some((a.symbol, Nil, args)) + case a @ Apply(TypeApply(Select(This(_), nme), tps), args) => Some((a.symbol, tps, args)) + case a @ Apply(TypeApply(Ident(nme), tps), args) => Some((a.symbol, tps, args)) case _ => None } } diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 5422da82ceb5a80bd3ffacdb469c8ec030c23fb4..8b72eb86c5706c7870acfb44a643ee74b86a1500 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -9,11 +9,12 @@ import scala.tools.nsc.plugins._ import scala.language.implicitConversions import purescala._ -import purescala.Definitions._ +import purescala.Definitions.{ClassDef => LeonClassDef, ModuleDef => LeonModuleDef, _} import purescala.Trees.{Expr => LeonExpr, _} import purescala.TypeTrees.{TypeTree => LeonType, _} import purescala.Common._ import purescala.TreeOps._ +import purescala.TypeTreeOps._ import xlang.Trees.{Block => LeonBlock, _} import xlang.TreeOps._ @@ -43,12 +44,10 @@ trait CodeExtraction extends ASTExtractors { private val mutableVarSubsts: scala.collection.mutable.Map[Symbol,Function0[LeonExpr]] = scala.collection.mutable.Map.empty[Symbol,Function0[LeonExpr]] - private val varSubsts: scala.collection.mutable.Map[Symbol,Function0[LeonExpr]] = - scala.collection.mutable.Map.empty[Symbol,Function0[LeonExpr]] - private val classesToClasses: scala.collection.mutable.Map[Symbol,ClassTypeDef] = - scala.collection.mutable.Map.empty[Symbol,ClassTypeDef] - private val defsToDefs: scala.collection.mutable.Map[Symbol,FunDef] = - scala.collection.mutable.Map.empty[Symbol,FunDef] + + private var classesToClasses = Map[Symbol, LeonClassDef]() + private var defsToDefs = Map[Symbol, FunDef]() + private var varSubsts = Map[Symbol, () => LeonExpr]() /** An exception thrown when non-purescala compatible code is encountered. */ sealed case class ImpureCodeEncounteredException(tree: Tree) extends Exception @@ -62,7 +61,8 @@ trait CodeExtraction extends ASTExtractors { class Extraction(unit: CompilationUnit) { - def toPureScala(tree: Tree): Option[LeonExpr] = { + + def toPureScala(tree: Tree)(implicit dctx: DefContext): Option[LeonExpr] = { try { Some(extractTree(tree)) } catch { @@ -72,7 +72,7 @@ trait CodeExtraction extends ASTExtractors { } // This one never fails, on error, it returns Untyped - def toPureScalaType(tpt: Type): LeonType = { + def toPureScalaType(tpt: Type)(implicit dctx: DefContext): LeonType = { try { extractType(tpt) } catch { @@ -81,7 +81,7 @@ trait CodeExtraction extends ASTExtractors { } } - private def extractTopLevelDef: Option[ObjectDef] = { + private def extractTopLevelDef: Option[LeonModuleDef] = { unit.body match { case p @ PackageDef(name, lst) if lst.size == 0 => reporter.error(p.pos, "No top-level definition found.") @@ -89,7 +89,7 @@ trait CodeExtraction extends ASTExtractors { case PackageDef(name, lst) => if (lst.size > 1) { - reporter.error(lst(1).pos, "More than one top-level object. Rest will be ignored.") + reporter.error(lst(1).pos, "More than one top-level object. Rest will be ignored.") } lst(0) match { case ExObjectDef(n, templ) => @@ -102,90 +102,148 @@ trait CodeExtraction extends ASTExtractors { } } - private def extractObjectDef(nameStr: String, tmpl: Template): ObjectDef = { + case class DefContext( + tparams: Map[Symbol, TypeParameter] + ) + + private def extractTypeParams(tps: Seq[Type]): Seq[(Symbol, TypeParameter)] = { + tps.flatMap { + case TypeRef(_, sym, Nil) => + Some(sym -> TypeParameter(FreshIdentifier(sym.name.toString))) + case t => + reporter.error(t.typeSymbol.pos, "Unhandled type for parameter: "+t) + None + } + } + + private def extractObjectDef(nameStr: String, tmpl: Template): LeonModuleDef = { // we assume that the template actually corresponds to an object // definition. Typically it should have been obtained from the proper // extractor (ExObjectDef) - var scalaClassSyms = Map[Symbol,Identifier]() - var scalaClassArgs = Map[Symbol,Seq[(String,Tree)]]() - var scalaClassNames = Set[String]() - // we need the new type definitions before we can do anything... - for (t <- tmpl.body) t match { - case ExAbstractClass(o2, sym) => - if(scalaClassNames.contains(o2)) { - reporter.error(t.pos, "A class with the same name already exists.") - } else { - scalaClassSyms += sym -> FreshIdentifier(o2) - scalaClassNames += o2 - } + var seenClasses = Map[Symbol, Seq[(String, Tree)]]() - case ExCaseClass(o2, sym, args) => - if(scalaClassNames.contains(o2)) { - reporter.error(t.pos, "A class with the same name already exists.") - } else { - scalaClassSyms += sym -> FreshIdentifier(o2) - scalaClassNames += o2 - scalaClassArgs += sym -> args - } + def extractClass(sym: Symbol): LeonClassDef = { + classesToClasses.get(sym) match { + case Some(cd) => cd + case None => + val id = FreshIdentifier(sym.name.toString).setPos(sym.pos) - case _ => + val tparamsMap = sym.tpe match { + case TypeRef(_, _, tps) => + extractTypeParams(tps) + case _ => + Nil + } - } + val tparams = tparamsMap.map(t => TypeParameterDef(t._2)) - for ((sym, id) <- scalaClassSyms) { - if (sym.isAbstractClass) { - classesToClasses += sym -> new AbstractClassDef(id) - } else { - val ccd = new CaseClassDef(id) - ccd.isCaseObject = sym.isModuleClass - classesToClasses += sym -> ccd - } - } + val defCtx = DefContext(tparamsMap.toMap) - for ((sym, ctd) <- classesToClasses) { - val superClasses: List[ClassTypeDef] = sym.tpe.baseClasses - .filter(bcs => scalaClassSyms.contains(bcs) && bcs != sym) - .map(s => classesToClasses(s)).toList + val parent = sym.tpe.parents.headOption match { + case Some(TypeRef(_, parentSym, tps)) if seenClasses.contains(parentSym) => + extractClass(parentSym) match { + case acd :AbstractClassDef => + val newTps = tps.map(extractType(_)(defCtx)) + Some(AbstractClassType(acd, newTps)) + case cd => + reporter.error(sym.pos, "Class "+id+" cannot extend "+cd.id) + None + } - val superAClasses: List[AbstractClassDef] = superClasses.flatMap { - case acd: AbstractClassDef => - Some(acd) - case c => - reporter.error(sym.pos, "Class "+ctd.id+" is inheriting from non-abstract class "+c.id+".") - None - } + case _ => + None + } + + if (sym.isAbstractClass) { + val acd = AbstractClassDef(id, tparams, parent).setPos(sym.pos) + + classesToClasses += sym -> acd + + acd + } else { + val ccd = CaseClassDef(id, tparams, parent, sym.isModuleClass).setPos(sym.pos) - if (superAClasses.size > 2) { - reporter.error(sym.pos, "Multiple inheritance is not permitted, ignoring extra parents.") + parent.foreach(_.classDef.registerChildren(ccd)) + + classesToClasses += sym -> ccd + + // Validates type parameters + parent match { + case Some(pct) => + if(pct.classDef.tparams.size == tparams.size) { + val pcd = pct.classDef + val ptps = pcd.tparams.map(_.tp) + + val targetType = AbstractClassType(pcd, ptps) + val fromChild = CaseClassType(ccd, ptps).parent.get + + if (fromChild != targetType) { + reporter.error(sym.pos, "Child type should form a simple bijection with parent class type (e.g. C[T1,T2] extends P[T1,T2])") + } + + } else { + reporter.error(sym.pos, "Child classes should have the same number of type parameters as their parent") + } + case _ => + } + + ccd + } } + } - superAClasses.headOption.foreach{ parent => ctd.setParent(parent) } + // We collect all defined classes + for (t <- tmpl.body) t match { + case ExAbstractClass(o2, sym) => + seenClasses += sym -> Nil + + case ExCaseClass(o2, sym, args) => + seenClasses += sym -> args + + case _ => + } - ctd match { + // Pass 2: we define classDefs + for ((sym, params) <- seenClasses) { + extractClass(sym) + } + + // Pass 3: we define fields + for ((sym, params) <- seenClasses) { + extractClass(sym) match { case ccd: CaseClassDef => - assert(scalaClassArgs contains sym) + val tparamsSym = sym.tpe match { + case TypeRef(_, _, tps) => + extractTypeParams(tps).map(_._1) + case _ => + Nil + } + + val tparamsMap = (tparamsSym zip ccd.tparams.map(_.tp)).toMap + + val defCtx = DefContext(tparamsMap) - ccd.fields = scalaClassArgs(sym).map{ case (name, asym) => - val tpe = toPureScalaType(asym.tpe) - VarDecl(FreshIdentifier(name).setType(tpe).setPos(asym.pos), tpe).setPos(asym.pos) + val fields = params.map { case (aname, asym) => + val tpe = toPureScalaType(asym.tpe)(defCtx) + VarDecl(FreshIdentifier(aname).setType(tpe).setPos(asym.pos), tpe).setPos(asym.pos) } + ccd.setFields(fields) case _ => - // no fields to set } } - var classDefs: List[ClassTypeDef] = classesToClasses.values.toList - // First pass to instanciate all FunDefs for (d <- tmpl.body) d match { case ExMainFunctionDef() => // we ignore the main function - case dd @ ExFunctionDef(name, params, tpe, body) => - val funDef = extractFunSig(name, params, tpe).setPos(dd.pos) + case dd @ ExFunctionDef(sym, tparams, params, ret, body) => + val dctx = DefContext(Map()) + + val funDef = extractFunSig(sym, tparams, params, ret)(dctx) if (dd.mods.isPrivate) { funDef.addAnnotation("private") @@ -207,10 +265,13 @@ trait CodeExtraction extends ASTExtractors { // Second pass to convert function bodies for (d <- tmpl.body) d match { - case dd @ ExFunctionDef(_, _, _, body) if defsToDefs contains dd.symbol => - val fd = defsToDefs(dd.symbol) + case ExFunctionDef(sym, tparams, _, _, body) if defsToDefs contains sym => + val fd = defsToDefs(sym).setPos(d.pos) - extractFunDef(fd, body) + val tparamsMap = (tparams zip fd.tparams.map(_.tp)).toMap + + val dctx = DefContext(tparamsMap) + extractFunDef(fd, body)(dctx) case _ => } @@ -223,32 +284,40 @@ trait CodeExtraction extends ASTExtractors { case ExCaseClass(_,_,_) => case ExConstructorDef() => case ExMainFunctionDef() => - case ExFunctionDef(_,_,_,_) => + case ExFunctionDef(_, _, _, _, _) => case tree => unsupported(tree, "Don't know what to do with this. Not purescala?"); } - new ObjectDef(FreshIdentifier(nameStr), classDefs ::: funDefs, Nil) + new LeonModuleDef(FreshIdentifier(nameStr), classesToClasses.values.toList ::: funDefs, Nil) } - private def extractFunSig(nameStr: String, params: Seq[ValDef], tpt: Tree): FunDef = { - val newParams = params.map(p => { - val ptpe = toPureScalaType(p.tpt.tpe) - val newID = FreshIdentifier(p.name.toString).setType(ptpe).setPos(p.pos) + private def extractFunSig(sym: Symbol, tps: Seq[Symbol], params: Seq[ValDef], ret: Type)(implicit dctx: DefContext): FunDef = { + + val tparams = extractTypeParams(tps.map(_.tpe)) + + val dctx = DefContext(tparams.toMap) + + val newParams = params.map{ vd => + val ptpe = toPureScalaType(vd.tpt.tpe)(dctx) + val newID = FreshIdentifier(vd.symbol.name.toString).setType(ptpe).setPos(vd.pos) owners += (newID -> None) - varSubsts(p.symbol) = (() => Variable(newID)) - VarDecl(newID, ptpe).setPos(p.pos) - }) - new FunDef(FreshIdentifier(nameStr), toPureScalaType(tpt.tpe), newParams) + varSubsts += vd.symbol -> (() => Variable(newID)) + VarDecl(newID, ptpe).setPos(vd.pos) + } + + val tparamsDef = tparams.map(t => TypeParameterDef(t._2)) + + new FunDef(FreshIdentifier(sym.name.toString), tparamsDef, toPureScalaType(ret)(dctx), newParams) } - private def extractFunDef(funDef: FunDef, body: Tree): FunDef = { + private def extractFunDef(funDef: FunDef, body: Tree)(implicit dctx: DefContext): FunDef = { currentFunDef = funDef val (body2, ensuring) = body match { case ExEnsuredExpression(body2, resSym, contract) => val resId = FreshIdentifier(resSym.name.toString).setType(funDef.returnType).setPos(resSym.pos) - varSubsts(resSym) = (() => Variable(resId)) + varSubsts += resSym -> (() => Variable(resId)) (body2, toPureScala(contract).map(r => (resId, r))) case ExHoldsExpression(body2) => @@ -324,21 +393,24 @@ trait CodeExtraction extends ASTExtractors { } - private def extractPattern(p: Tree, binder: Option[Identifier] = None): Pattern = p match { + private def extractPattern(p: Tree, binder: Option[Identifier] = None)(implicit dctx: DefContext): Pattern = p match { case b @ Bind(name, t @ Typed(pat, tpe)) => val newID = FreshIdentifier(name.toString).setType(extractType(tpe.tpe)).setPos(b.pos) - varSubsts(b.symbol) = (() => Variable(newID)) + varSubsts += b.symbol -> (() => Variable(newID)) extractPattern(t, Some(newID)) case b @ Bind(name, pat) => val newID = FreshIdentifier(name.toString).setType(extractType(b.symbol.tpe)).setPos(b.pos) - varSubsts(b.symbol) = (() => Variable(newID)) + varSubsts += b.symbol -> (() => Variable(newID)) extractPattern(pat, Some(newID)) - case t @ Typed(Ident(nme.WILDCARD), tpe) if t.tpe.typeSymbol.isCase && - classesToClasses.contains(t.tpe.typeSymbol) => - val cd = classesToClasses(t.tpe.typeSymbol).asInstanceOf[CaseClassDef] - InstanceOfPattern(binder, cd).setPos(p.pos) + case t @ Typed(Ident(nme.WILDCARD), tpt) if classesToClasses.contains(t.tpe.typeSymbol) => + extractType(tpt.tpe) match { + case ct: ClassType => + InstanceOfPattern(binder, ct).setPos(p.pos) + case _ => + unsupported("Invalid type "+tpt.tpe+" for .isInstanceOf") + } case Ident(nme.WILDCARD) => WildcardPattern(binder).setPos(p.pos) @@ -346,31 +418,39 @@ trait CodeExtraction extends ASTExtractors { case s @ Select(This(_), b) if s.tpe.typeSymbol.isCase && classesToClasses.contains(s.tpe.typeSymbol) => // case Obj => - val cd = classesToClasses(s.tpe.typeSymbol).asInstanceOf[CaseClassDef] - assert(cd.fields.size == 0) - CaseClassPattern(binder, cd, Seq()).setPos(p.pos) + extractType(s.tpe) match { + case ct: CaseClassType => + assert(ct.classDef.fields.size == 0) + CaseClassPattern(binder, ct, Seq()).setPos(p.pos) + case _ => + unsupported("Invalid type "+s.tpe+" for .isInstanceOf") + } case a @ Apply(fn, args) if fn.isType && a.tpe.typeSymbol.isCase && classesToClasses.contains(a.tpe.typeSymbol) => - val cd = classesToClasses(a.tpe.typeSymbol).asInstanceOf[CaseClassDef] - assert(args.size == cd.fields.size) - CaseClassPattern(binder, cd, args.map(extractPattern(_))).setPos(p.pos) + extractType(a.tpe) match { + case ct: CaseClassType => + assert(args.size == ct.classDef.fields.size) + CaseClassPattern(binder, ct, args.map(extractPattern(_))).setPos(p.pos) + case _ => + unsupported("Invalid type "+a.tpe+" for .isInstanceOf") + } case a @ Apply(fn, args) => extractType(a.tpe) match { case TupleType(argsTpes) => TuplePattern(binder, args.map(extractPattern(_))).setPos(p.pos) case _ => - unsupported(p, "Unsupported pattern") + unsupported(p, "Unsupported pattern: "+a) } case _ => - unsupported(p, "Unsupported pattern") + unsupported(p, "Unsupported pattern: "+p) } - private def extractMatchCase(cd: CaseDef): MatchCase = { + private def extractMatchCase(cd: CaseDef)(implicit dctx: DefContext): MatchCase = { val recPattern = extractPattern(cd.pat) val recBody = extractTree(cd.body) @@ -388,7 +468,7 @@ trait CodeExtraction extends ASTExtractors { } } - private def extractTree(tr: Tree): LeonExpr = { + private def extractTree(tr: Tree)(implicit dctx: DefContext): LeonExpr = { val (current, tmpRest) = tr match { case Block(Block(e :: es1, l1) :: es2, l2) => (e, Some(Block(es1 ++ Seq(l1) ++ es2, l2))) @@ -409,7 +489,7 @@ trait CodeExtraction extends ASTExtractors { case ExCaseObject(sym) => classesToClasses.get(sym) match { case Some(ccd: CaseClassDef) => - CaseClass(ccd, Seq()) + CaseClass(CaseClassType(ccd, Seq()), Seq()) case _ => unsupported(tr, "Unknown case object "+sym.name) } @@ -417,11 +497,10 @@ trait CodeExtraction extends ASTExtractors { case ExParameterlessMethodCall(t,n) if extractTree(t).getType.isInstanceOf[CaseClassType] => val selector = extractTree(t) - val selType = selector.getType + val selType = selector.getType.asInstanceOf[CaseClassType] - val selDef: CaseClassDef = selType.asInstanceOf[CaseClassType].classDef - val fieldID = selDef.fields.find(_.id.name == n.toString) match { + val fieldID = selType.fields.find(_.id.name == n.toString) match { case None => unsupported(tr, "Invalid method or field invocation (not a case class arg?)") @@ -429,7 +508,7 @@ trait CodeExtraction extends ASTExtractors { vd.id } - CaseClassSelector(selDef, selector, fieldID).setType(fieldID.getType) + CaseClassSelector(selType, selector, fieldID) case ExTuple(tpes, exprs) => val tupleExprs = exprs.map(e => extractTree(e)) @@ -466,9 +545,9 @@ trait CodeExtraction extends ASTExtractors { val restTree = rest match { case Some(rst) => { - varSubsts(vs) = (() => Variable(newID)) + varSubsts += vs -> (() => Variable(newID)) val res = extractTree(rst) - varSubsts.remove(vs) + varSubsts -= vs res } case None => UnitLiteral @@ -481,9 +560,9 @@ trait CodeExtraction extends ASTExtractors { * XLang Extractors */ - case dd @ ExFunctionDef(n, p, t, b) => - val funDef = extractFunSig(n, p, t) - defsToDefs += (dd.symbol -> funDef) + case ExFunctionDef(symbol, tparams, params, ret, b) => + val funDef = extractFunSig(symbol, tparams, params, ret) + defsToDefs += (symbol -> funDef) val oldMutableVarSubst = mutableVarSubsts.toMap //take an immutable snapshot of the map val oldCurrentFunDef = currentFunDef mutableVarSubsts.clear //reseting the visible mutable vars, we do not handle mutable variable closure in nested functions @@ -494,7 +573,7 @@ trait CodeExtraction extends ASTExtractors { case Some(rst) => extractTree(rst) case None => UnitLiteral } - defsToDefs.remove(dd.symbol) + defsToDefs -= symbol rest = None LetDef(funDefWithBody, restTree) @@ -514,9 +593,9 @@ trait CodeExtraction extends ASTExtractors { } val restTree = rest match { case Some(rst) => { - varSubsts(vs) = (() => Variable(newID)) + varSubsts += vs -> (() => Variable(newID)) val res = extractTree(rst) - varSubsts.remove(vs) + varSubsts -= vs res } case None => UnitLiteral @@ -557,11 +636,11 @@ trait CodeExtraction extends ASTExtractors { case epsi @ ExEpsilonExpression(tpe, varSym, predBody) => val pstpe = extractType(tpe) val previousVarSubst: Option[Function0[LeonExpr]] = varSubsts.get(varSym) //save the previous in case of nested epsilon - varSubsts(varSym) = (() => EpsilonVariable(epsi.pos).setType(pstpe)) + varSubsts += varSym -> (() => EpsilonVariable(epsi.pos).setType(pstpe)) val c1 = extractTree(predBody) previousVarSubst match { - case Some(f) => varSubsts(varSym) = f - case None => varSubsts.remove(varSym) + case Some(f) => varSubsts += varSym -> f + case None => varSubsts -= varSym } if(containsEpsilon(c1)) { unsupported(epsi, "Usage of nested epsilon is not allowed") @@ -629,7 +708,7 @@ trait CodeExtraction extends ASTExtractors { val aTpe = extractType(tpe) val newID = FreshIdentifier(sym.name.toString).setType(aTpe) owners += (newID -> None) - varSubsts(sym) = (() => Variable(newID)) + varSubsts += sym -> (() => Variable(newID)) newID } @@ -641,7 +720,7 @@ trait CodeExtraction extends ASTExtractors { extractType(tpt.tpe) match { case cct: CaseClassType => val nargs = args.map(extractTree(_)) - CaseClass(cct.classDef, nargs) + CaseClass(cct, nargs) case _ => unsupported(tr, "Construction of a non-case class.") @@ -909,20 +988,21 @@ trait CodeExtraction extends ASTExtractors { val ccRec = extractTree(cc) val checkType = extractType(tt.tpe) checkType match { - case CaseClassType(ccd) => { - val rootType: ClassTypeDef = if(ccd.parent != None) ccd.parent.get else ccd + case cct @ CaseClassType(ccd, tps) => { + val rootType: LeonClassDef = if(ccd.parent != None) ccd.parent.get.classDef else ccd + if(!ccRec.getType.isInstanceOf[ClassType]) { reporter.error(tr.pos, "isInstanceOf can only be used with a case class") throw ImpureCodeEncounteredException(tr) } else { val testedExprType = ccRec.getType.asInstanceOf[ClassType].classDef - val testedExprRootType: ClassTypeDef = if(testedExprType.parent != None) testedExprType.parent.get else testedExprType + val testedExprRootType: LeonClassDef = if(testedExprType.parent != None) testedExprType.parent.get.classDef else testedExprType if(rootType != testedExprRootType) { reporter.error(tr.pos, "isInstanceOf can only be used with compatible case classes") throw ImpureCodeEncounteredException(tr) } else { - CaseClassInstanceOf(ccd, ccRec) + CaseClassInstanceOf(cct, ccRec) } } } @@ -933,14 +1013,17 @@ trait CodeExtraction extends ASTExtractors { } } - case lc @ ExLocalCall(sy,nm,ar) => { - if(defsToDefs.keysIterator.find(_ == sy).isEmpty) { - reporter.error(tr.pos, "Invoking an invalid function.") + case lc @ ExLocalCall(sym, tps, ar) => + if (!defsToDefs.contains(sym)) { + reporter.error(lc.pos, "Invoking an invalid function.") throw ImpureCodeEncounteredException(tr) } - val fd = defsToDefs(sy) - FunctionInvocation(fd, ar.map(extractTree(_))).setType(fd.returnType) - } + + val fd = defsToDefs(sym) + + val newTps = tps.map(t => extractType(t.tpe)) + + FunctionInvocation(fd.typed(newTps), ar.map(extractTree(_))).setType(fd.returnType) case pm @ ExPatternMatching(sel, cses) => val rs = extractTree(sel) @@ -964,7 +1047,7 @@ trait CodeExtraction extends ASTExtractors { } } - private def extractType(tpt: Type): LeonType = tpt match { + private def extractType(tpt: Type)(implicit dctx: DefContext): LeonType = tpt match { case tpe if tpe == IntClass.tpe => Int32Type @@ -1001,11 +1084,22 @@ trait CodeExtraction extends ASTExtractors { case TypeRef(_, sym, btt :: Nil) if isArrayClassSym(sym) => ArrayType(extractType(btt)) - case TypeRef(_, sym, Nil) if classesToClasses contains sym => - classDefToClassType(classesToClasses(sym)) + case TypeRef(_, sym, tps) if classesToClasses contains sym => + val leontps = tps.map(extractType) + + classDefToClassType(classesToClasses(sym), leontps) + + case TypeRef(_, sym, Nil) => + if(dctx.tparams contains sym) { + dctx.tparams(sym) + } else { + println(sym.id) + println(classesToClasses.map{ case (sym, cd) => sym.id+" -> "+cd.id }.mkString("\n")) + unsupported("Type parameter "+tpt+" is unknown (Known: "+dctx.tparams.values.mkString(", ")+")") + } case SingleType(_, sym) if classesToClasses contains sym.moduleClass=> - classDefToClassType(classesToClasses(sym.moduleClass)) + classDefToClassType(classesToClasses(sym.moduleClass), Nil) case _ => unsupported("Could not extract type as PureScala: "+tpt+" ("+tpt.getClass+")") diff --git a/src/main/scala/leon/purescala/Definitions.scala b/src/main/scala/leon/purescala/Definitions.scala index cbbeebfedadb38897ccae8e9ce0ad0bb5cb535a9..a02055b7457556ccab7afd560481e9156cb0e910 100644 --- a/src/main/scala/leon/purescala/Definitions.scala +++ b/src/main/scala/leon/purescala/Definitions.scala @@ -9,6 +9,7 @@ object Definitions { import TreeOps._ import Extractors._ import TypeTrees._ + import TypeTreeOps._ sealed abstract class Definition extends Tree { val id: Identifier @@ -36,26 +37,25 @@ object Definitions { def toVariable : Variable = Variable(id).setType(tpe) } - type VarDecls = Seq[VarDecl] - /** A wrapper for a program. For now a program is simply a single object. The * name is meaningless and we just use the package name as id. */ - case class Program(id: Identifier, mainObject: ObjectDef) extends Definition { - def definedFunctions = mainObject.definedFunctions - def definedClasses = mainObject.definedClasses - def classHierarchyRoots = mainObject.classHierarchyRoots - def algebraicDataTypes = mainObject.algebraicDataTypes - def singleCaseClasses = mainObject.singleCaseClasses - def callGraph = mainObject.callGraph - def calls(f1: FunDef, f2: FunDef) = mainObject.calls(f1, f2) - def callers(f1: FunDef) = mainObject.callers(f1) - def callees(f1: FunDef) = mainObject.callees(f1) - def transitiveCallGraph = mainObject.transitiveCallGraph - def transitivelyCalls(f1: FunDef, f2: FunDef) = mainObject.transitivelyCalls(f1, f2) - def transitiveCallers(f1: FunDef) = mainObject.transitiveCallers.getOrElse(f1, Set()) - def transitiveCallees(f1: FunDef) = mainObject.transitiveCallees.getOrElse(f1, Set()) - def isRecursive(f1: FunDef) = mainObject.isRecursive(f1) - def caseClassDef(name: String) = mainObject.caseClassDef(name) + case class Program(id: Identifier, mainModule: ModuleDef) extends Definition { + def definedFunctions = mainModule.definedFunctions + def definedClasses = mainModule.definedClasses + def classHierarchyRoots = mainModule.classHierarchyRoots + def algebraicDataTypes = mainModule.algebraicDataTypes + def singleCaseClasses = mainModule.singleCaseClasses + def callGraph = mainModule.callGraph + + def calls(f1: FunDef, f2: FunDef) = mainModule.calls(f1, f2) + def callers(f1: FunDef) = mainModule.callers(f1) + def callees(f1: FunDef) = mainModule.callees(f1) + def transitiveCallGraph = mainModule.transitiveCallGraph + def transitivelyCalls(f1: FunDef, f2: FunDef) = mainModule.transitivelyCalls(f1, f2) + def transitiveCallers(f1: FunDef) = mainModule.transitiveCallers.getOrElse(f1, Set()) + def transitiveCallees(f1: FunDef) = mainModule.transitiveCallees.getOrElse(f1, Set()) + def isRecursive(f1: FunDef) = mainModule.isRecursive(f1) + def caseClassDef(name: String) = mainModule.caseClassDef(name) def writeScalaFile(filename: String) { import java.io.FileWriter @@ -67,7 +67,7 @@ object Definitions { } def duplicate = { - copy(mainObject = mainObject.copy(defs = mainObject.defs.collect { + copy(mainModule = mainModule.copy(defs = mainModule.defs.collect { case fd: FunDef => fd.duplicate case d => d })) @@ -77,7 +77,7 @@ object Definitions { object Program { lazy val empty : Program = Program( FreshIdentifier("empty"), - ObjectDef( + ModuleDef( FreshIdentifier("empty"), Seq.empty, Seq.empty @@ -85,31 +85,39 @@ object Definitions { ) } + case class TypeParameterDef(tp: TypeParameter) extends Definition { + val id = tp.id + } + /** Objects work as containers for class definitions, functions (def's) and * val's. */ - case class ObjectDef(id: Identifier, defs : Seq[Definition], invariants: Seq[Expr]) extends Definition { - lazy val definedFunctions : Seq[FunDef] = defs.filter(_.isInstanceOf[FunDef]).map(_.asInstanceOf[FunDef]) + case class ModuleDef(id: Identifier, defs : Seq[Definition], invariants: Seq[Expr]) extends Definition { + lazy val definedFunctions : Seq[FunDef] = defs.collect { case fd: FunDef => fd } - lazy val definedClasses : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]) + lazy val definedClasses : Seq[ClassDef] = defs.collect { case ctd: ClassDef => ctd } - def caseClassDef(caseClassName : String) : CaseClassDef = - definedClasses.find(ctd => ctd.id.name == caseClassName).getOrElse(scala.sys.error("Asking for non-existent case class def: " + caseClassName)).asInstanceOf[CaseClassDef] + def caseClassDef(name : String) : CaseClassDef = definedClasses.find(ctd => ctd.id.name == name) match { + case Some(ccd: CaseClassDef) => ccd + case _ => throw new LeonFatalError("Unknown case class '"+name+"'") + } - lazy val classHierarchyRoots : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]).filter(!_.hasParent) + lazy val classHierarchyRoots : Seq[ClassDef] = defs.collect { + case ctd: ClassDef if !ctd.hasParent => ctd + } - lazy val algebraicDataTypes : Map[AbstractClassDef,Seq[CaseClassDef]] = (defs.collect { - case c @ CaseClassDef(_, Some(_), _) => c - }).groupBy(_.parent.get) + lazy val algebraicDataTypes : Map[AbstractClassDef, Seq[CaseClassDef]] = (defs.collect { + case c @ CaseClassDef(_, _, Some(p), _) => c + }).groupBy(_.parent.get.classDef) lazy val singleCaseClasses : Seq[CaseClassDef] = defs.collect { - case c @ CaseClassDef(_, None, _) => c + case c @ CaseClassDef(_, _, None, _) => c } lazy val (callGraph, callers, callees) = { type CallGraph = Set[(FunDef,FunDef)] def collectCalls(fd: FunDef)(e: Expr): CallGraph = e match { - case f @ FunctionInvocation(f2, _) => Set((fd, f2)) + case f @ FunctionInvocation(f2, _) => Set((fd, f2.fd)) case _ => Set() } @@ -169,101 +177,77 @@ object Definitions { /** Useful because case classes and classes are somewhat unified in some * patterns (of pattern-matching, that is) */ - sealed trait ClassTypeDef extends Definition { + sealed trait ClassDef extends Definition { self => val id: Identifier - def parent: Option[AbstractClassDef] - def setParent(parent: AbstractClassDef) : self.type - def hasParent: Boolean = parent.isDefined - val isAbstract: Boolean - - } + val tparams: Seq[TypeParameterDef] + def fields: Seq[VarDecl] + val parent: Option[AbstractClassType] - /** Will be used at some point as a common ground for case classes (which - * implicitely define extractors) and explicitely defined unapply methods. */ - sealed trait ExtractorTypeDef + def hasParent = parent.isDefined - /** Abstract classes. */ - object AbstractClassDef { - def unapply(acd: AbstractClassDef): Option[(Identifier,Option[AbstractClassDef])] = { - if(acd == null) None else Some((acd.id, acd.parent)) - } - } - class AbstractClassDef(val id: Identifier, prnt: Option[AbstractClassDef] = None) extends ClassTypeDef { - private var parent_ = prnt - var fields: VarDecls = Nil - val isAbstract = true + def fieldsIds = fields.map(_.id) - private var children : List[ClassTypeDef] = Nil + private var _children: List[ClassDef] = Nil - private[purescala] def registerChild(child: ClassTypeDef) : Unit = { - children = child :: children + def registerChildren(chd: ClassDef) = { + _children = (chd :: _children).sortBy(_.id.name) } - def knownChildren : Seq[ClassTypeDef] = { - children - } + def knownChildren: Seq[ClassDef] = _children - def knownDescendents : Seq[ClassTypeDef] = { + def knownDescendents: Seq[ClassDef] = { knownChildren ++ (knownChildren.map(c => c match { case acd: AbstractClassDef => acd.knownDescendents case _ => Nil }).reduceLeft(_ ++ _)) } - def setParent(newParent: AbstractClassDef) = { - if(parent_.isDefined) { - scala.sys.error("Resetting parent is forbidden.") - } - newParent.registerChild(this) - parent_ = Some(newParent) - this + def knownCCDescendents: Seq[CaseClassDef] = knownDescendents.collect { + case ccd: CaseClassDef => + ccd } - def parent = parent_ + + val isAbstract: Boolean + val isCaseObject: Boolean } - /** Case classes. */ - object CaseClassDef { - def unapply(ccd: CaseClassDef): Option[(Identifier,Option[AbstractClassDef],VarDecls)] = { - if(ccd == null) None else Some((ccd.id, ccd.parent, ccd.fields)) - } + /** Abstract classes. */ + case class AbstractClassDef(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val parent: Option[AbstractClassType]) extends ClassDef { + + val fields = Nil + + val isAbstract = true + val isCaseObject = false } - class CaseClassDef(val id: Identifier, prnt: Option[AbstractClassDef] = None) extends ClassTypeDef with ExtractorTypeDef { - private var parent_ = prnt - var fields: VarDecls = Nil - var isCaseObject = false - val isAbstract = false + /** Case classes/objects. */ + case class CaseClassDef(val id: Identifier, + val tparams: Seq[TypeParameterDef], + val parent: Option[AbstractClassType], + val isCaseObject: Boolean) extends ClassDef { - def setParent(newParent: AbstractClassDef) = { - if(parent_.isDefined) { - scala.sys.error("Resetting parent is forbidden.") - } - newParent.registerChild(this) - parent_ = Some(newParent) - this + var _fields = Seq[VarDecl]() + + def fields = _fields + + def setFields(fields: Seq[VarDecl]) { + _fields = fields } - def parent = parent_ - def fieldsIds = fields.map(_.id) + + val isAbstract = false + def selectorID2Index(id: Identifier) : Int = { - var i : Int = 0 - var found = false - val fs = fields.size - while(!found && i < fs) { - if(fields(i).id == id) { - found = true - } else { - i += 1 - } - } + val index = fields.zipWithIndex.find(_._1.id == id).map(_._2) - if(found) - i - else - scala.sys.error("Asking for index of field that does not belong to the case class.") + index.getOrElse { + scala.sys.error("Could not find '"+id+"' ("+id.uniqueName+") within "+fields.map(_.id.uniqueName).mkString(", ")) + } } } @@ -273,7 +257,7 @@ object Definitions { } /** Functions (= 'methods' of objects) */ - class FunDef(val id: Identifier, val returnType: TypeTree, val args: VarDecls) extends Definition { + class FunDef(val id: Identifier, val tparams: Seq[TypeParameterDef], val returnType: TypeTree, val args: Seq[VarDecl]) extends Definition { var body: Option[Expr] = None def implementation : Option[Expr] = body var precondition: Option[Expr] = None @@ -284,7 +268,7 @@ object Definitions { var orig: Option[FunDef] = None def duplicate: FunDef = { - val fd = new FunDef(id, returnType, args) + val fd = new FunDef(id, tparams, returnType, args) fd.body = body fd.precondition = precondition fd.postcondition = postcondition @@ -306,5 +290,97 @@ object Definitions { def annotations : Set[String] = annots def isPrivate : Boolean = annots.contains("private") + + def typed(tps: Seq[TypeTree]) = { + assert(tps.size == tparams.size) + TypedFunDef(this, tps) + } + + def typed = { + assert(tparams.isEmpty) + TypedFunDef(this, Nil) + } + + } + + // Wrapper for typing function according to valuations for type parameters + case class TypedFunDef(fd: FunDef, tps: Seq[TypeTree]) extends Tree { + val id = fd.id + + def signature = { + if (tps.nonEmpty) { + id.toString+tps.mkString("[", ", ", "]") + } else { + id.toString + } + } + + private lazy val typesMap = { + (fd.tparams zip tps).toMap + } + + def translated(t: TypeTree): TypeTree = instantiateType(t, typesMap) + + def translated(e: Expr): Expr = instantiateType(e, typesMap, argsMap) + + lazy val (args: Seq[VarDecl], argsMap: Map[Identifier, Identifier]) = { + if (tps.isEmpty) { + (fd.args, Map()) + } else { + val newArgs = fd.args.map { + case vd @ VarDecl(id, tpe) => + val newTpe = translated(tpe) + val newId = FreshIdentifier(id.name, true).setType(newTpe).copiedFrom(id) + + VarDecl(newId, newTpe).setPos(vd) + } + + val argsMap: Map[Identifier, Identifier] = (fd.args zip newArgs).map { case (vd1, vd2) => vd1.id -> vd2.id }.toMap + + (newArgs, argsMap) + } + } + + lazy val functionType = FunctionType(args.map(_.tpe).toList, returnType) + + lazy val returnType: TypeTree = translated(fd.returnType) + + private var trCache = Map[Expr, Expr]() + private var postCache = Map[(Identifier, Expr), (Identifier, Expr)]() + + def body = fd.body.map { b => + trCache.getOrElse(b, { + val res = translated(b) + trCache += b -> res + res + }) + } + + def precondition = fd.precondition.map { pre => + trCache.getOrElse(pre, { + val res = translated(pre) + trCache += pre -> res + res + }) + } + + def postcondition = fd.postcondition.map { + case (id, post) if tps.nonEmpty => + postCache.getOrElse((id, post), { + val nId = FreshIdentifier(id.name).setType(translated(id.getType)).copiedFrom(id) + val res = nId -> instantiateType(post, typesMap, argsMap + (id -> nId)) + postCache += ((id,post) -> res) + res + }) + + case p => p + } + + def hasImplementation = body.isDefined + def hasBody = hasImplementation + def hasPrecondition = precondition.isDefined + def hasPostcondition = postcondition.isDefined + + override def getPos = fd.getPos } } diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index 0720cf2488b248d6af2e7c766c197d7aa208add6..63d5add418c9be1c3a7466877f2f7a0da173e898 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -35,8 +35,8 @@ object FunctionClosure extends TransformationPhase { pathConstraints = fd.precondition.toList fd.body = fd.body.map(b => functionClosure(b, fd.args.map(_.id).toSet, Map(), Map())) }) - val Program(id, ObjectDef(objId, defs, invariants)) = program - val res = Program(id, ObjectDef(objId, defs ++ topLevelFuns, invariants)) + val Program(id, ModuleDef(objId, defs, invariants)) = program + val res = Program(id, ModuleDef(objId, defs ++ topLevelFuns, invariants)) res } @@ -55,7 +55,7 @@ object FunctionClosure extends TransformationPhase { val newBindedVars: Set[Identifier] = bindedVars ++ fd.args.map(_.id) val newFunId = FreshIdentifier(fd.id.uniqueName) //since we hoist this at the top level, we need to make it a unique name - val newFunDef = new FunDef(newFunId, fd.returnType, newVarDecls).copiedFrom(fd) + val newFunDef = new FunDef(newFunId, fd.tparams, fd.returnType, newVarDecls).copiedFrom(fd) topLevelFuns += newFunDef newFunDef.addAnnotation(fd.annotations.toSeq:_*) //TODO: this is still some dangerous side effects newFunDef.parent = Some(parent) @@ -114,12 +114,12 @@ object FunctionClosure extends TransformationPhase { pathConstraints = pathConstraints.tail IfExpr(rCond, rThen, rElze).copiedFrom(i) } - case fi @ FunctionInvocation(fd, args) => fd2FreshFd.get(fd) match { + case fi @ FunctionInvocation(tfd, args) => fd2FreshFd.get(tfd.fd) match { case None => - FunctionInvocation(fd, + FunctionInvocation(tfd, args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd))).copiedFrom(fi) case Some((nfd, extraArgs)) => - FunctionInvocation(nfd, + FunctionInvocation(nfd.typed(tfd.tps), args.map(arg => functionClosure(arg, bindedVars, id2freshId, fd2FreshFd)) ++ extraArgs.map(v => replace(id2freshId.map(p => (p._1.toVariable, p._2.toVariable)), v))).copiedFrom(fi) } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index fa24b126568a3165538a85e4dcc3b3a8ac9d7881..63d302ad2dad607089fceaf2ae106aa9d099e404 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -69,7 +69,16 @@ class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { implicit val p = Some(tree) tree match { - case Variable(id) => sb.append(idToString(id)) + case id: Identifier => + sb.append(idToString(id)) + + case Variable(id) => + //sb.append("(") + pp(id, p) + //sb.append(": ") + //pp(id.getType, p) + //sb.append(")") + case LetTuple(bs,d,e) => sb.append("(let (" + bs.map(idToString _).mkString(",") + " := "); pp(d, p) @@ -105,6 +114,10 @@ class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { case BooleanLiteral(v) => sb.append(v) case StringLiteral(s) => sb.append("\"" + s + "\"") case UnitLiteral => sb.append("()") + case GenericValue(tp, id) => + pp(tp, p) + sb.append("#"+id) + case t@Tuple(exprs) => ppNary(exprs, "(", ", ", ")") case s@TupleSelect(t, i) => pp(t, p) @@ -115,24 +128,31 @@ class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { pp(pred, p) sb.append(")") - case CaseClass(cd, args) => - sb.append(idToString(cd.id)) - if (cd.isCaseObject) { + case CaseClass(cct, args) => + pp(cct, p) + if (cct.classDef.isCaseObject) { ppNary(args, "", "", "") } else { ppNary(args, "(", ", ", ")") } - case CaseClassInstanceOf(cd, e) => + case CaseClassInstanceOf(cct, e) => pp(e, p) - sb.append(".isInstanceOf[" + idToString(cd.id) + "]") + sb.append(".isInstanceOf[") + pp(cct, p) + sb.append("]") case CaseClassSelector(_, cc, id) => pp(cc, p) sb.append("." + idToString(id)) - case FunctionInvocation(fd, args) => - sb.append(idToString(fd.id)) + case FunctionInvocation(tfd, args) => + sb.append(idToString(tfd.id)) + + if (tfd.tps.nonEmpty) { + ppNary(tfd.tps, "[", ",", "]") + } + ppNary(args, "(", ", ", ")") case Plus(l,r) => ppBinary(l, r, " + ") @@ -272,9 +292,10 @@ class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { pp(rhs, p)(lvl+1) // Patterns - case CaseClassPattern(bndr, ccd, subps) => + case CaseClassPattern(bndr, cct, subps) => bndr.foreach(b => sb.append(b + " @ ")) - sb.append(idToString(ccd.id)).append("(") + pp(cct, p) + sb.append("(") var c = 0 val sz = subps.size subps.foreach(sp => { @@ -287,9 +308,9 @@ class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { case WildcardPattern(None) => sb.append("_") case WildcardPattern(Some(id)) => sb.append(idToString(id)) - case InstanceOfPattern(bndr, ccd) => + case InstanceOfPattern(bndr, cct) => bndr.foreach(b => sb.append(b + " : ")) - sb.append(idToString(ccd.id)) + pp(cct, p) case TuplePattern(bndr, subPatterns) => bndr.foreach(b => sb.append(b + " @ ")) @@ -334,7 +355,12 @@ class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { } sb.append(" => ") pp(tt, p) - case c: ClassType => sb.append(idToString(c.classDef.id)) + + case c: ClassType => + sb.append(idToString(c.classDef.id)) + if (c.tps.nonEmpty) { + ppNary(c.tps, "[", ",", "]") + } // Definitions @@ -346,7 +372,7 @@ class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { pp(mainObj, p)(lvl+1) sb.append("}\n") - case ObjectDef(id, defs, invs) => + case ModuleDef(id, defs, invs) => nl sb.append("object ") sb.append(idToString(id)) @@ -366,30 +392,30 @@ class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { nl sb.append("}\n") - case AbstractClassDef(id, parent) => + case AbstractClassDef(id, tparams, parent) => nl sb.append("sealed abstract class ") sb.append(idToString(id)) parent.foreach(p => sb.append(" extends " + idToString(p.id))) - case CaseClassDef(id, parent, varDecls) => + case ccd @ CaseClassDef(id, tparams, parent, isObj) => nl - sb.append("case class ") + if (isObj) { + sb.append("case object ") + } else { + sb.append("case class ") + } + sb.append(idToString(id)) - sb.append("(") - var c = 0 - val sz = varDecls.size - varDecls.foreach(vd => { - sb.append(idToString(vd.id)) - sb.append(": ") - pp(vd.tpe, p) - if(c < sz - 1) { - sb.append(", ") - } - c = c + 1 - }) - sb.append(")") + if (tparams.nonEmpty) { + ppNary(tparams, "[", ", ", "]") + } + + if (!isObj) { + ppNary(ccd.fields, "(", ", ", ")") + } + parent.foreach(p => sb.append(" extends " + idToString(p.id))) case fd: FunDef => @@ -443,6 +469,12 @@ class PrettyPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) { sb.append("[unknown function implementation]") } + case TypeParameterDef(tp) => + pp(tp, p) + + case TypeParameter(id) => + pp(id, p) + case _ => sb.append("Tree? (" + tree.getClass + ")") } if (opts.printPositions) { diff --git a/src/main/scala/leon/purescala/ScalaPrinter.scala b/src/main/scala/leon/purescala/ScalaPrinter.scala index 90bd33aff9794127a2d47ffe5501a7263099cc04..3bbfc02603a94dd60e4cd6d8ae1faef17dbc5e47 100644 --- a/src/main/scala/leon/purescala/ScalaPrinter.scala +++ b/src/main/scala/leon/purescala/ScalaPrinter.scala @@ -111,22 +111,6 @@ class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) ex pp(t, p) sb.append("._" + i) - case CaseClass(cd, args) => - sb.append(idToString(cd.id)) - if (cd.isCaseObject) { - ppNary(args, "", "", "") - } else { - ppNary(args, "(", ", ", ")") - } - - case CaseClassInstanceOf(cd, e) => - pp(e, p) - sb.append(".isInstanceOf[" + idToString(cd.id) + "]") - - case CaseClassSelector(_, cc, id) => - pp(cc, p) - sb.append("." + idToString(id)) - case FunctionInvocation(fd, args) => sb.append(idToString(fd.id)) ppNary(args, "(", ", ", ")") @@ -287,7 +271,7 @@ class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) ex assert(lvl == 0) pp(mainObj, p) - case ObjectDef(id, defs, invs) => + case ModuleDef(id, defs, invs) => sb.append("object ") sb.append(idToString(id)) sb.append(" {\n") @@ -308,50 +292,51 @@ class ScalaPrinter(opts: PrinterOptions, sb: StringBuffer = new StringBuffer) ex ind(lvl) sb.append("}\n") - case AbstractClassDef(id, parent) => + case AbstractClassDef(id, tparams, parent) => sb.append("sealed abstract class ") sb.append(idToString(id)) + + if (tparams.nonEmpty) { + ppNary(tparams, "[", ",", "]") + } + parent.foreach(p => sb.append(" extends " + idToString(p.id))) - case CaseClassDef(id, parent, varDecls) => - sb.append("case class ") + case ccd @ CaseClassDef(id, tparams, parent, isObj) => + if (isObj) { + sb.append("case object ") + } else { + sb.append("case class ") + } + sb.append(idToString(id)) - sb.append("(") - var c = 0 - val sz = varDecls.size - varDecls.foreach(vd => { - sb.append(idToString(vd.id)) - sb.append(": ") - pp(vd.tpe, p) - if(c < sz - 1) { - sb.append(", ") - } - c = c + 1 - }) - sb.append(")") + if (tparams.nonEmpty) { + ppNary(tparams, "[", ", ", "]") + } + + if (!isObj) { + ppNary(ccd.fields, "(", ", ", ")") + } + parent.foreach(p => sb.append(" extends " + idToString(p.id))) + case vd: VarDecl => + pp(vd.id, p) + sb.append(": ") + pp(vd.tpe, p) + case fd: FunDef => sb.append("def ") - sb.append(idToString(fd.id)) - sb.append("(") - - val sz = fd.args.size - var c = 0 + pp(fd.id, p) - fd.args.foreach(arg => { - sb.append(idToString(arg.id)) - sb.append(": ") - pp(arg.tpe, p) + if (fd.tparams.nonEmpty) { + ppNary(fd.tparams, "[", ", ", "]") + } - if(c < sz - 1) { - sb.append(", ") - } - c = c + 1 - }) + ppNary(fd.args, "(", ", ", ")") - sb.append("): ") + sb.append(": ") pp(fd.returnType, p) sb.append(" = {\n") ind(lvl+1) diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index 566ea5a95cd2b632dd045dc6a600fab0dd936256..aef0720a3652162785cab2e33027701c8e794aa9 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -46,7 +46,7 @@ class ScopeSimplifier extends Transformer { VarDecl(newArg, tpe) } - val newFd = new FunDef(newId, fd.returnType, newArgs) + val newFd = new FunDef(newId, fd.tparams, fd.returnType, newArgs) newScope = newScope.registerFunDef(fd -> newFd) @@ -123,11 +123,11 @@ class ScopeSimplifier extends Transformer { case Variable(id) => Variable(scope.oldToNew.getOrElse(id, id)) - case FunctionInvocation(fd, args) => - val newFd = scope.funDefs.getOrElse(fd, fd) + case FunctionInvocation(tfd, args) => + val newFd = scope.funDefs.getOrElse(tfd.fd, tfd.fd) val newArgs = args.map(rec(_, scope)) - FunctionInvocation(newFd, newArgs) + FunctionInvocation(newFd.typed(tfd.tps), newArgs) case UnaryOperator(e, builder) => builder(rec(e, scope)) diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index e6d17543fdb024eab909de6d6c02523e3e37a9b8..59d0ba02c27e326c5993f2e9ad8de7105cf548da 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -12,6 +12,7 @@ object TreeOps { import TypeTrees._ import Definitions._ import Trees._ + import TypeTreeOps._ import Extractors._ /** @@ -557,19 +558,19 @@ object TreeOps { case WildcardPattern(ob) => bind(ob, in) case InstanceOfPattern(ob, ct) => ct match { - case _: AbstractClassDef => + case _: AbstractClassType => bind(ob, in) - case cd: CaseClassDef => - And(CaseClassInstanceOf(cd, in), bind(ob, in)) + case cct: CaseClassType => + And(CaseClassInstanceOf(cct, in), bind(ob, in)) } - case CaseClassPattern(ob, ccd, subps) => { - assert(ccd.fields.size == subps.size) - val pairs = ccd.fields.map(_.id).toList zip subps.toList - val subTests = pairs.map(p => rec(CaseClassSelector(ccd, in, p._1), p._2)) + case CaseClassPattern(ob, cct, subps) => + assert(cct.fields.size == subps.size) + val pairs = cct.fields.map(_.id).toList zip subps.toList + val subTests = pairs.map(p => rec(CaseClassSelector(cct, in, p._1), p._2)) val together = And(bind(ob, in) +: subTests) - And(CaseClassInstanceOf(ccd, in), together) - } + And(CaseClassInstanceOf(cct, in), together) + case TuplePattern(ob, subps) => { val TupleType(tpes) = in.getType assert(tpes.size == subps.size) @@ -678,26 +679,28 @@ object TreeOps { case TupleType(tpes) => Tuple(tpes.map(simplestValue)) case ArrayType(tpe) => ArrayFill(IntLiteral(0), simplestValue(tpe)) - case AbstractClassType(acd) => + case act @ AbstractClassType(acd, tpe) => val children = acd.knownChildren def isRecursive(ccd: CaseClassDef): Boolean = { - ccd.fields.exists(fd => fd.getType match { - case AbstractClassType(fieldAcd) => acd == fieldAcd - case CaseClassType(fieldCcd) => ccd == fieldCcd + act.fieldsTypes.exists{ + case AbstractClassType(fieldAcd, _) => acd == fieldAcd + case CaseClassType(fieldCcd, _) => ccd == fieldCcd case _ => false - }) + } } val nonRecChildren = children.collect { case ccd: CaseClassDef if !isRecursive(ccd) => ccd } val orderedChildren = nonRecChildren.sortBy(_.fields.size) - simplestValue(classDefToClassType(orderedChildren.head)) + simplestValue(classDefToClassType(orderedChildren.head, tpe)) - case CaseClassType(ccd) => - val fields = ccd.fields - CaseClass(ccd, fields.map(f => simplestValue(f.getType))) + case cct: CaseClassType => + CaseClass(cct, cct.fieldsTypes.map(t => simplestValue(t))) + + case tp: TypeParameter => + GenericValue(tp, 0) case _ => throw new Exception("I can't choose simplest value for type " + tpe) } @@ -884,11 +887,11 @@ object TreeOps { } var scrutSet = Set[Expr]() - var conditions = Map[Expr, CaseClassDef]() + var conditions = Map[Expr, CaseClassType]() var matchingOn = cases.collect { case cc : CaseClassInstanceOf => cc } sortBy(cc => selectorDepth(cc.expr)) - for (CaseClassInstanceOf(cd, expr) <- matchingOn) { - conditions += expr -> cd + for (CaseClassInstanceOf(cct, expr) <- matchingOn) { + conditions += expr -> cct expr match { case cd: CaseClassSelector => @@ -904,7 +907,7 @@ object TreeOps { var substMap = Map[Expr, Expr]() - def computePatternFor(cd: CaseClassDef, prefix: Expr): Pattern = { + def computePatternFor(ct: CaseClassType, prefix: Expr): Pattern = { val name = prefix match { case CaseClassSelector(_, _, id) => id.name @@ -912,14 +915,14 @@ object TreeOps { case _ => "tmp" } - val binder = FreshIdentifier(name, true).setType(prefix.getType) // Is it full of women though? + val binder = FreshIdentifier(name, true).setType(prefix.getType) // prefix becomes binder substMap += prefix -> Variable(binder) - substMap += CaseClassInstanceOf(cd, prefix) -> BooleanLiteral(true) + substMap += CaseClassInstanceOf(ct, prefix) -> BooleanLiteral(true) - val subconds = for (id <- cd.fieldsIds) yield { - val fieldSel = CaseClassSelector(cd, prefix, id) + val subconds = for (id <- ct.classDef.fieldsIds) yield { + val fieldSel = CaseClassSelector(ct, prefix, id) if (conditions contains fieldSel) { computePatternFor(conditions(fieldSel), fieldSel) } else { @@ -929,7 +932,7 @@ object TreeOps { } } - CaseClassPattern(Some(binder), cd, subconds) + CaseClassPattern(Some(binder), ct, subconds) } val (scrutinees, patterns) = scrutSet.toSeq.map(s => (s, computePatternFor(conditions(s), s))).unzip @@ -1177,7 +1180,7 @@ object TreeOps { val newFD = mapType(funDef.returnType) match { case None => funDef case Some(rt) => - val fd = new FunDef(FreshIdentifier(funDef.id.name, true), rt, funDef.args) + val fd = new FunDef(FreshIdentifier(funDef.id.name, true), funDef.tparams, rt, funDef.args) // These will be taken care of in the recursive traversal. fd.body = funDef.body fd.precondition = funDef.precondition @@ -1212,8 +1215,8 @@ object TreeOps { case l @ LetDef(fd, bdy) => LetDef(fd2fd(fd), bdy) - case FunctionInvocation(fd, args) => - FunctionInvocation(fd2fd(fd), args) + case FunctionInvocation(tfd, args) => + FunctionInvocation(fd2fd(tfd.fd).typed(tfd.tps), args) case _ => e } @@ -1335,26 +1338,27 @@ object TreeOps { * foo(Cons(h,t), b) => foo(t, b) */ def isInductiveOn(sf: SolverFactory[Solver])(expr: Expr, on: Identifier): Boolean = on match { - case IsTyped(origId, AbstractClassType(cd)) => - def isAlternativeRecursive(cd: CaseClassDef): Boolean = { - cd.fieldsIds.exists(_.getType == origId.getType) - } + case IsTyped(origId, AbstractClassType(cd, tps)) => val toCheck = cd.knownDescendents.collect { case ccd: CaseClassDef => - val isType = CaseClassInstanceOf(ccd, Variable(on)) + val cct = CaseClassType(ccd, tps) - val recSelectors = ccd.fieldsIds.filter(_.getType == on.getType) + val isType = CaseClassInstanceOf(cct, Variable(on)) - if (recSelectors.isEmpty) { - Seq() - } else { - val v = Variable(on) + val recSelectors = cct.fields.collect { + case vd if vd.tpe == on.getType => vd.id + } - recSelectors.map{ s => - And(And(isType, expr), Not(replace(Map(v -> CaseClassSelector(ccd, v, s)), expr))) - } + if (recSelectors.isEmpty) { + Seq() + } else { + val v = Variable(on) + + recSelectors.map{ s => + And(And(isType, expr), Not(replace(Map(v -> CaseClassSelector(cct, v, s)), expr))) } + } }.flatten val solver = SimpleSolverAPI(sf) @@ -1503,8 +1507,9 @@ object TreeOps { false } - case (FunctionInvocation(fd1, args1), FunctionInvocation(fd2, args2)) => - fdHomo(fd1, fd2) && + case (FunctionInvocation(tfd1, args1), FunctionInvocation(tfd2, args2)) => + // TODO: Check type params + fdHomo(tfd1.fd, tfd2.fd) && (args1 zip args2).forall{ case (a1, a2) => isHomo(a1, a2) } case Same(UnaryOperator(e1, _), UnaryOperator(e2, _)) => @@ -1561,6 +1566,9 @@ object TreeOps { * Seq( (T1, Seq(P1, P4)), (T2, Seq(P2, P5)), (T3, Seq(p3, p6))) * * We then check that P1+P4 covers every T1, etc.. + * + * @EK: We ignore type parameters here, we might want to make sure it's + * valid. What's Leon's semantics w.r.t. erasure? */ def areExaustive(pss: Seq[(TypeTree, Seq[Pattern])]): Boolean = pss.forall { case (tpe, ps) => @@ -1576,10 +1584,10 @@ object TreeOps { case _: ClassType => def typesOf(tpe: TypeTree): Set[CaseClassDef] = tpe match { - case AbstractClassType(ctp) => + case AbstractClassType(ctp, _) => ctp.knownDescendents.collect { case c: CaseClassDef => c }.toSet - case CaseClassType(ctd) => + case CaseClassType(ctd, _) => Set(ctd) case _ => @@ -1595,9 +1603,10 @@ object TreeOps { case InstanceOfPattern(_, cct) => // (a: B) covers all Bs - subChecks --= typesOf(classDefToClassType(cct)) + subChecks --= typesOf(cct) - case CaseClassPattern(_, ccd, subs) => + case CaseClassPattern(_, cct, subs) => + val ccd = cct.classDef // We record the patterns per types, if they still need to be checked if (subChecks contains ccd) { subChecks += (ccd -> (subChecks(ccd) :+ subs)) @@ -1651,7 +1660,7 @@ object TreeOps { **/ def flattenFunctions(fdOuter: FunDef): FunDef = { fdOuter.body match { - case Some(LetDef(fdInner, FunctionInvocation(fdInner2, args))) if fdInner == fdInner2 => + case Some(LetDef(fdInner, FunctionInvocation(tfdInner2, args))) if fdInner == tfdInner2.fd => val argsDef = fdOuter.args.map(_.id) val argsCall = args.collect { case Variable(id) => id } @@ -1662,9 +1671,9 @@ object TreeOps { val innerIdsToOuterIds = (fdInner.args.map(_.id) zip argsCall).toMap def pre(e: Expr) = e match { - case FunctionInvocation(fd, args) if fd == fdInner => + case FunctionInvocation(tfd, args) if tfd.fd == fdInner => val newArgs = (args zip rewriteMap).sortBy(_._2) - FunctionInvocation(fdOuter, newArgs.map(_._1)) + FunctionInvocation(fdOuter.typed(tfd.tps), newArgs.map(_._1)) case Variable(id) => Variable(innerIdsToOuterIds.getOrElse(id, id)) case _ => diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index 74f8ead992b73a9eb6a2934e97fedafe89020507..2052034fac6f3a15cee9b7eee1851c66359a6f9a 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -9,6 +9,7 @@ import utils._ object Trees { import Common._ import TypeTrees._ + import TypeTreeOps._ import Definitions._ import Extractors._ @@ -62,15 +63,13 @@ object Trees { /* Control flow */ - case class FunctionInvocation(funDef: FunDef, args: Seq[Expr]) extends Expr with FixedType { - val fixedType = funDef.returnType - - funDef.args.zip(args).foreach { - case (a, c) => typeCheck(c, a.tpe) - } + case class FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) extends Expr with FixedType { + val fixedType = tfd.returnType } case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr with FixedType { - val fixedType = leastUpperBound(thenn.getType, elze.getType).getOrElse(AnyType) + val fixedType = leastUpperBound(thenn.getType, elze.getType).getOrElse{ + AnyType + } } case class Tuple(exprs: Seq[Expr]) extends Expr with FixedType { @@ -118,7 +117,7 @@ object Trees { scrutinee.getType match { case a: AbstractClassType => new MatchExpr(scrutinee, cases) case c: CaseClassType => new MatchExpr(scrutinee, cases.filter(_.pattern match { - case CaseClassPattern(_, ccd, _) if ccd != c.classDef => false + case CaseClassPattern(_, cct, _) if cct.classDef != c.classDef => false case _ => true })) case t: TupleType => new MatchExpr(scrutinee, cases) @@ -130,8 +129,11 @@ object Trees { } class MatchExpr(val scrutinee: Expr, val cases: Seq[MatchCase]) extends Expr with FixedType { + assert(cases.nonEmpty) - val fixedType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(AnyType) + val fixedType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse{ + AnyType + } def scrutineeClassType: ClassType = scrutinee.getType.asInstanceOf[ClassType] @@ -168,17 +170,13 @@ object Trees { def binders: Set[Identifier] = subBinders ++ (if(binder.isDefined) Set(binder.get) else Set.empty) } - case class InstanceOfPattern(binder: Option[Identifier], classTypeDef: ClassTypeDef) extends Pattern { // c: Class + case class InstanceOfPattern(binder: Option[Identifier], ct: ClassType) extends Pattern { // c: Class val subPatterns = Seq.empty } case class WildcardPattern(binder: Option[Identifier]) extends Pattern { // c @ _ val subPatterns = Seq.empty } - case class CaseClassPattern(binder: Option[Identifier], caseClassDef: CaseClassDef, subPatterns: Seq[Pattern]) extends Pattern - // case class ExtractorPattern(binder: Option[Identifier], - // extractor : ExtractorTypeDef, - // subPatterns: Seq[Pattern]) extends Pattern // c @ Extractor(...,...) - // We don't handle Seq stars for now. + case class CaseClassPattern(binder: Option[Identifier], ct: CaseClassType, subPatterns: Seq[Pattern]) extends Pattern case class TuplePattern(binder: Option[Identifier], subPatterns: Seq[Pattern]) extends Pattern @@ -387,6 +385,10 @@ object Trees { val value: T } + case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal with FixedType { + val fixedType = tp + } + case class IntLiteral(value: Int) extends Literal[Int] with FixedType { val fixedType = Int32Type } @@ -401,36 +403,41 @@ object Trees { val value = () } - case class CaseClass(classDef: CaseClassDef, args: Seq[Expr]) extends Expr with FixedType { - val fixedType = CaseClassType(classDef) + case class CaseClass(ct: CaseClassType, args: Seq[Expr]) extends Expr with FixedType { + val fixedType = ct } - case class CaseClassInstanceOf(classDef: CaseClassDef, expr: Expr) extends Expr with FixedType { + case class CaseClassInstanceOf(classType: CaseClassType, expr: Expr) extends Expr with FixedType { val fixedType = BooleanType } object CaseClassSelector { - def apply(classDef: CaseClassDef, caseClass: Expr, selector: Identifier): Expr = { + def apply(classType: CaseClassType, caseClass: Expr, selector: Identifier): Expr = { caseClass match { - case CaseClass(cd, fields) if cd == classDef => fields(cd.selectorID2Index(selector)) - case _ => new CaseClassSelector(classDef, caseClass, selector) + case CaseClass(ct, fields) => + if (ct.classDef == classType.classDef) { + fields(ct.classDef.selectorID2Index(selector)) + } else { + new CaseClassSelector(classType, caseClass, selector) + } + case _ => new CaseClassSelector(classType, caseClass, selector) } } - def unapply(e: CaseClassSelector): Option[(CaseClassDef, Expr, Identifier)] = { - if (e eq null) None else Some((e.classDef, e.caseClass, e.selector)) + def unapply(ccs: CaseClassSelector): Option[(CaseClassType, Expr, Identifier)] = { + Some((ccs.classType, ccs.caseClass, ccs.selector)) } } - class CaseClassSelector(val classDef: CaseClassDef, val caseClass: Expr, val selector: Identifier) extends Expr with FixedType { - val fixedType = classDef.fields.find(_.id == selector).get.getType + class CaseClassSelector(val classType: CaseClassType, val caseClass: Expr, val selector: Identifier) extends Expr with FixedType { + val fixedType = classType.fieldsTypes(classType.classDef.selectorID2Index(selector)) override def equals(that: Any): Boolean = (that != null) && (that match { - case t: CaseClassSelector => (t.classDef, t.caseClass, t.selector) == (classDef, caseClass, selector) + case t: CaseClassSelector => (t.classType, t.caseClass, t.selector) == (classType, caseClass, selector) case _ => false }) - override def hashCode: Int = (classDef, caseClass, selector).hashCode + override def hashCode: Int = (classType, caseClass, selector).hashCode } /* Arithmetic */ @@ -492,13 +499,9 @@ object Trees { leastUpperBound(Seq(set1, set2).map(_.getType)).foreach(setType _) } case class SetMin(set: Expr) extends Expr with FixedType { - typeCheck(set, SetType(Int32Type)) - val fixedType = Int32Type } case class SetMax(set: Expr) extends Expr with FixedType { - typeCheck(set, SetType(Int32Type)) - val fixedType = Int32Type } diff --git a/src/main/scala/leon/purescala/TypeTreeOps.scala b/src/main/scala/leon/purescala/TypeTreeOps.scala new file mode 100644 index 0000000000000000000000000000000000000000..41d9ab08e385bc83f0db4ca6ea9e5d14efe28f1e --- /dev/null +++ b/src/main/scala/leon/purescala/TypeTreeOps.scala @@ -0,0 +1,219 @@ +package leon +package purescala + +import TreeOps.postMap +import TypeTrees._ +import Definitions._ +import Common._ +import Trees._ +import Extractors._ + +object TypeTreeOps { + def canBeSubtypeOf(tpe: TypeTree, freeParams: Seq[TypeParameterDef], stpe: TypeTree): Option[Seq[TypeParameter]] = { + if (freeParams.isEmpty) { + if (isSubtypeOf(tpe, stpe)) { + Some(Nil) + } else { + None + } + } else { + // TODO + None + } + } + + def bestRealType(t: TypeTree) : TypeTree = t match { + case c: ClassType if c.classDef.isInstanceOf[CaseClassDef] => { + c.classDef.parent match { + case None => CaseClassType(c.classDef.asInstanceOf[CaseClassDef], c.tps) + case Some(p) => instantiateType(p, (c.classDef.tparams zip c.tps).toMap) + } + } + case other => other + } + + def leastUpperBound(t1: TypeTree, t2: TypeTree): Option[TypeTree] = (t1,t2) match { + case (c1: ClassType, c2: ClassType) => + import scala.collection.immutable.Set + + + def computeChain(ct: ClassType): List[ClassType] = ct.parent match { + case Some(pct) => + computeChain(pct) ::: List(ct) + case None => + List(ct) + } + + var chain1 = computeChain(c1) + var chain2 = computeChain(c2) + + val prefix = (chain1 zip chain2).takeWhile { case (ct1, ct2) => ct1 == ct2 }.map(_._1) + + prefix.lastOption + + case (TupleType(args1), TupleType(args2)) => + val args = (args1 zip args2).map(p => leastUpperBound(p._1, p._2)) + if (args.forall(_.isDefined)) Some(TupleType(args.map(_.get))) else None + case (o1, o2) if (o1 == o2) => Some(o1) + case (o1,BottomType) => Some(o1) + case (BottomType,o2) => Some(o2) + case (o1,AnyType) => Some(AnyType) + case (AnyType,o2) => Some(AnyType) + + case _ => None + } + + def leastUpperBound(ts: Seq[TypeTree]): Option[TypeTree] = { + def olub(ot1: Option[TypeTree], t2: Option[TypeTree]): Option[TypeTree] = ot1 match { + case Some(t1) => leastUpperBound(t1, t2.get) + case None => None + } + + if (ts.isEmpty) { + None + } else { + ts.map(Some(_)).reduceLeft(olub) + } + } + + def isSubtypeOf(t1: TypeTree, t2: TypeTree): Boolean = { + leastUpperBound(t1, t2) == Some(t2) + } + + + def typeCheck(obj: Expr, exps: TypeTree*) { + val res = exps.exists(e => isSubtypeOf(obj.getType, e)) + + if (!res) { + throw TypeErrorException(obj, exps.toList) + } + } + + def instantiateType(tpe: TypeTree, tps: Map[TypeParameterDef, TypeTree]): TypeTree = { + if (tps.isEmpty) { + tpe + } else { + typeParamSubst(tps.map { case (tpd, tp) => tpd.tp -> tp })(tpe) + } + } + + private def typeParamSubst(map: Map[TypeParameter, TypeTree])(tpe: TypeTree): TypeTree = tpe match { + case (tp: TypeParameter) => map.getOrElse(tp, tp) + case NAryType(tps, builder) => builder(tps.map(typeParamSubst(map))) + } + + def instantiateType(e: Expr, tps: Map[TypeParameterDef, TypeTree], ids: Map[Identifier, Identifier]): Expr = { + if (tps.isEmpty && ids.isEmpty) { + e + } else { + val tpeSub = if (tps.isEmpty) { + { (tpe: TypeTree) => tpe } + } else { + typeParamSubst(tps.map { case (tpd, tp) => tpd.tp -> tp }) _ + } + + def rec(idsMap: Map[Identifier, Identifier])(e: Expr): Expr = { + def freshId(id: Identifier, newTpe: TypeTree) = { + FreshIdentifier(id.name, true).setType(newTpe).copiedFrom(id) + } + + // Simple rec without affecting map + val srec = rec(idsMap) _ + + e match { + case fi @ FunctionInvocation(TypedFunDef(fd, tps), args) => + FunctionInvocation(TypedFunDef(fd, tps.map(tpeSub)), args.map(srec)).copiedFrom(fi) + + case cc @ CaseClass(ct, args) => + CaseClass(tpeSub(ct).asInstanceOf[CaseClassType], args.map(srec)).copiedFrom(cc) + + case cc @ CaseClassSelector(ct, e, sel) => + CaseClassSelector(tpeSub(ct).asInstanceOf[CaseClassType], srec(e), sel).copiedFrom(cc) + + case cc @ CaseClassInstanceOf(ct, e) => + CaseClassInstanceOf(tpeSub(ct).asInstanceOf[CaseClassType], srec(e)).copiedFrom(cc) + + case m @ MatchExpr(e, cases) => + val newTpe = tpeSub(e.getType) + + def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = { + maps.foldLeft(Map[Identifier, Identifier]())(_ ++ _) + } + + def trCase(c: MatchCase) = c match { + case SimpleCase(p, b) => + val (newP, newIds) = trPattern(p, newTpe) + SimpleCase(newP, rec(idsMap ++ newIds)(b)) + + case GuardedCase(p, g, b) => + val (newP, newIds) = trPattern(p, newTpe) + GuardedCase(newP, rec(idsMap ++ newIds)(g), rec(idsMap ++ newIds)(b)) + } + + def trPattern(p: Pattern, expType: TypeTree): (Pattern, Map[Identifier, Identifier]) = (p, expType) match { + case (InstanceOfPattern(ob, ct), _) => + val newCt = tpeSub(ct).asInstanceOf[ClassType] + val newOb = ob.map(id => freshId(id, newCt)) + + (InstanceOfPattern(newOb, newCt), (ob zip newOb).toMap) + + case (TuplePattern(ob, sps), tpt @ TupleType(stps)) => + val newOb = ob.map(id => freshId(id, tpt)) + + val (newSps, newMaps) = (sps zip stps).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip + + (TuplePattern(newOb, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) + + case (CaseClassPattern(ob, cct, sps), _) => + val newCt = tpeSub(cct).asInstanceOf[CaseClassType] + + val newOb = ob.map(id => freshId(id, newCt)) + + val (newSps, newMaps) = (sps zip newCt.fieldsTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip + + (CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) + + case (WildcardPattern(ob), expTpe) => + val newOb = ob.map(id => freshId(id, expTpe)) + + (WildcardPattern(newOb), (ob zip newOb).toMap) + } + + MatchExpr(srec(e), cases.map(trCase)).copiedFrom(m) + + case Error(desc) => + Error(desc).setType(tpeSub(e.getType)).copiedFrom(e) + + case s @ FiniteSet(elements) if elements.isEmpty => + FiniteSet(Nil).setType(tpeSub(s.getType)).copiedFrom(s) + + case v @ Variable(id) if idsMap contains id => + Variable(idsMap(id)).copiedFrom(v) + + case u @ UnaryOperator(e, builder) => + builder(srec(e)).copiedFrom(u) + + case b @ BinaryOperator(e1, e2, builder) => + builder(srec(e1), srec(e2)).copiedFrom(b) + + case n @ NAryOperator(es, builder) => + builder(es.map(srec)).copiedFrom(n) + + case _ => + e + } + } + + //println("\\\\"*80) + //println(tps) + //println(ids.map{ case (k,v) => k.uniqueName+" -> "+v.uniqueName }) + //println("\\\\"*80) + //println(e) + val res = rec(ids)(e) + //println(".."*80) + //println(res) + //println("//"*80) + res + } + } +} diff --git a/src/main/scala/leon/purescala/TypeTrees.scala b/src/main/scala/leon/purescala/TypeTrees.scala index 0c37ca43f298ffb478aca7ec86fbaeb87c2736dd..61918e5a21a12cc5443b756b394951b5cee83fa5 100644 --- a/src/main/scala/leon/purescala/TypeTrees.scala +++ b/src/main/scala/leon/purescala/TypeTrees.scala @@ -7,6 +7,7 @@ object TypeTrees { import Common._ import Trees._ import Definitions._ + import TypeTreeOps._ trait Typed extends Serializable { self => @@ -39,14 +40,6 @@ object TypeTrees { } } - def typeCheck(obj: Expr, exps: TypeTree*) { - val res = exps.exists(e => isSubtypeOf(obj.getType, e)) - - if (!res) { - throw TypeErrorException(obj, exps.toList) - } - } - trait FixedType extends Typed { self => @@ -60,72 +53,6 @@ object TypeTrees { override def toString: String = PrettyPrinter(this) } - // Sort of a quick hack... - def bestRealType(t: TypeTree) : TypeTree = t match { - case c: ClassType if c.classDef.isInstanceOf[CaseClassDef] => { - c.classDef.parent match { - case None => CaseClassType(c.classDef.asInstanceOf[CaseClassDef]) - case Some(p) => AbstractClassType(p) - } - } - case other => other - } - - def leastUpperBound(t1: TypeTree, t2: TypeTree): Option[TypeTree] = (t1,t2) match { - case (c1: ClassType, c2: ClassType) => { - import scala.collection.immutable.Set - var c: ClassTypeDef = c1.classDef - var visited: Set[ClassTypeDef] = Set(c) - - while(c.parent.isDefined) { - c = c.parent.get - visited = visited ++ Set(c) - } - - c = c2.classDef - var found: Option[ClassTypeDef] = if(visited.contains(c)) { - Some(c) - } else { - None - } - - while(found.isEmpty && c.parent.isDefined) { - c = c.parent.get - if(visited.contains(c)) - found = Some(c) - } - - if(found.isEmpty) { - None - } else { - Some(classDefToClassType(found.get)) - } - } - case (TupleType(args1), TupleType(args2)) => - val args = (args1 zip args2).map(p => leastUpperBound(p._1, p._2)) - if (args.forall(_.isDefined)) Some(TupleType(args.map(_.get))) else None - case (o1, o2) if (o1 == o2) => Some(o1) - case (o1,BottomType) => Some(o1) - case (BottomType,o2) => Some(o2) - case (o1,AnyType) => Some(AnyType) - case (AnyType,o2) => Some(AnyType) - - case _ => None - } - - def leastUpperBound(ts: Seq[TypeTree]): Option[TypeTree] = { - def olub(ot1: Option[TypeTree], t2: Option[TypeTree]): Option[TypeTree] = ot1 match { - case Some(t1) => leastUpperBound(t1, t2.get) - case None => None - } - - ts.map(Some(_)).reduceLeft(olub) - } - - def isSubtypeOf(t1: TypeTree, t2: TypeTree): Boolean = { - leastUpperBound(t1, t2) == Some(t2) - } - // returns the number of distinct values that inhabit a type sealed abstract class TypeSize extends Serializable case class FiniteSize(size: Int) extends TypeSize @@ -177,6 +104,8 @@ object TypeTrees { case object Int32Type extends TypeTree case object UnitType extends TypeTree + case class TypeParameter(id: Identifier) extends TypeTree + class TupleType private (val bases: Seq[TypeTree]) extends TypeTree { lazy val dimension: Int = bases.length @@ -221,20 +150,68 @@ object TypeTrees { case class ArrayType(base: TypeTree) extends TypeTree sealed abstract class ClassType extends TypeTree { - val classDef: ClassTypeDef + val classDef: ClassDef val id: Identifier = classDef.id - override def hashCode : Int = id.hashCode + override def hashCode : Int = id.hashCode + tps.hashCode override def equals(that : Any) : Boolean = that match { - case t : ClassType => t.id == this.id + case t : ClassType => t.id == this.id && t.tps == this.tps case _ => false } + + val tps: Seq[TypeTree] + + assert(classDef.tparams.size == tps.size) + + lazy val fields = { + val tmap = (classDef.tparams zip tps).toMap + if (tmap.isEmpty) { + classDef.fields + } else { + classDef.fields.map(vd => VarDecl(vd.id, instantiateType(vd.tpe, tmap))) + } + } + + def knownDescendents = classDef.knownDescendents.map(classDefToClassType(_, tps)) + + def knownCCDescendents = classDef.knownCCDescendents.map(CaseClassType(_, tps)) + + lazy val fieldsTypes = fields.map(_.tpe) + + lazy val parent = classDef.parent.map { + pct => instantiateType(pct, (classDef.tparams zip tps).toMap) match { + case act: AbstractClassType => act + case t => throw new LeonFatalError("Unexpected translated parent type: "+t) + } + } + + } + case class AbstractClassType(classDef: AbstractClassDef, tps: Seq[TypeTree]) extends ClassType + case class CaseClassType(override val classDef: CaseClassDef, tps: Seq[TypeTree]) extends ClassType + + def classDefToClassType(cd: ClassDef, tps: Seq[TypeTree]): ClassType = cd match { + case a: AbstractClassDef => AbstractClassType(a, tps) + case c: CaseClassDef => CaseClassType(c, tps) } - case class AbstractClassType(classDef: AbstractClassDef) extends ClassType - case class CaseClassType(classDef: CaseClassDef) extends ClassType - def classDefToClassType(cd: ClassTypeDef): ClassType = cd match { - case a: AbstractClassDef => AbstractClassType(a) - case c: CaseClassDef => CaseClassType(c) + // Using definition types + def classDefToClassType(cd: ClassDef): ClassType = { + classDefToClassType(cd, cd.tparams.map(_.tp)) + } + + object NAryType { + def unapply(t: TypeTree): Option[(Seq[TypeTree], Seq[TypeTree] => TypeTree)] = t match { + case CaseClassType(ccd, ts) => Some((ts, ts => CaseClassType(ccd, ts))) + case AbstractClassType(acd, ts) => Some((ts, ts => AbstractClassType(acd, ts))) + case TupleType(ts) => Some((ts, TupleType(_))) + case ListType(t) => Some((Seq(t), ts => ListType(ts.head))) + case ArrayType(t) => Some((Seq(t), ts => ArrayType(ts.head))) + case TupleType(ts) => Some((ts, TupleType(_))) + case SetType(t) => Some((Seq(t), ts => SetType(ts.head))) + case MultisetType(t) => Some((Seq(t), ts => MultisetType(ts.head))) + case MapType(from,to) => Some((Seq(from, to), t => MapType(t(0), t(1)))) + case FunctionType(fts, tt) => Some((tt +: fts, ts => FunctionType(ts.tail.toList, ts.head))) + case t => Some(Nil, fake => t) + } } } diff --git a/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala b/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala index 143b1c7557ea76553b1ec84d739a9e4c104d144e..1d8b149129481d77b62f6a779f79fba179d47c42 100644 --- a/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala +++ b/src/main/scala/leon/solvers/combinators/FunctionTemplate.scala @@ -15,14 +15,14 @@ import evaluators._ import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap} class FunctionTemplate private( - val funDef : FunDef, + val tfd : TypedFunDef, val activatingBool : Identifier, condVars : Set[Identifier], exprVars : Set[Identifier], guardedExprs : Map[Identifier,Seq[Expr]], isRealFunDef : Boolean) { - private val funDefArgsIDs : Seq[Identifier] = funDef.args.map(_.id) + private val funDefArgsIDs : Seq[Identifier] = tfd.args.map(_.id) private val asClauses : Seq[Expr] = { (for((b,es) <- guardedExprs; e <- es) yield { @@ -31,7 +31,7 @@ class FunctionTemplate private( } val blockers : Map[Identifier,Set[FunctionInvocation]] = { - val idCall = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) + val idCall = FunctionInvocation(tfd, tfd.args.map(_.toVariable)) Map((for((b, es) <- guardedExprs) yield { val calls = es.foldLeft(Set.empty[FunctionInvocation])((s,e) => s ++ functionCallsOf(e)) - idCall @@ -51,7 +51,7 @@ class FunctionTemplate private( private val cache : MutableMap[Seq[Expr],Map[Identifier,Expr]] = MutableMap.empty def instantiate(aVar : Identifier, args : Seq[Expr]) : (Seq[Expr], Map[Identifier,Set[FunctionInvocation]]) = { - assert(args.size == funDef.args.size) + assert(args.size == tfd.args.size) val (wasHit,baseIDSubstMap) = cache.get(args) match { case Some(m) => (true,m) @@ -85,7 +85,7 @@ class FunctionTemplate private( } override def toString : String = { - "Template for def " + funDef.id + "(" + funDef.args.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + funDef.returnType + " is :\n" + + "Template for def " + tfd.id + "(" + tfd.args.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" + " * Activating boolean : " + activatingBool + "\n" + " * Control booleans : " + condVars.toSeq.map(_.toString).mkString(", ") + "\n" + " * Expression vars : " + exprVars.toSeq.map(_.toString).mkString(", ") + "\n" + @@ -97,7 +97,7 @@ class FunctionTemplate private( object FunctionTemplate { val splitAndOrImplies = false - def mkTemplate(funDef: FunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { + def mkTemplate(tfd: TypedFunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { val condVars : MutableSet[Identifier] = MutableSet.empty val exprVars : MutableSet[Identifier] = MutableSet.empty @@ -257,11 +257,11 @@ object FunctionTemplate { } // The precondition if it exists. - val prec : Option[Expr] = funDef.precondition.map(p => matchToIfThenElse(p)) + val prec : Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p)) - val newBody : Option[Expr] = funDef.body.map(b => matchToIfThenElse(b)) + val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b)) - val invocation : Expr = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) + val invocation : Expr = FunctionInvocation(tfd, tfd.args.map(_.toVariable)) val invocationEqualsBody : Option[Expr] = newBody match { case Some(body) if isRealFunDef => @@ -288,12 +288,12 @@ object FunctionTemplate { } // Now the postcondition. - funDef.postcondition match { + tfd.postcondition match { case Some((id, post)) => val newPost : Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post)) val postHolds : Expr = - if(funDef.hasPrecondition) { + if(tfd.hasPrecondition) { Implies(prec.get, newPost) } else { newPost @@ -305,7 +305,7 @@ object FunctionTemplate { } - new FunctionTemplate(funDef, activatingBool, Set(condVars.toSeq : _*), Set(exprVars.toSeq : _*), Map(guardedExprs.toSeq : _*), + new FunctionTemplate(tfd, activatingBool, Set(condVars.toSeq : _*), Set(exprVars.toSeq : _*), Map(guardedExprs.toSeq : _*), isRealFunDef) } } diff --git a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala index 4f03819a7ec4255fa02b0110816f9335bf2f08c4..3012604406e5db70ae47fdaf41bcd2bc14f267fd 100644 --- a/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala +++ b/src/main/scala/leon/solvers/combinators/UnrollingSolver.scala @@ -52,8 +52,8 @@ class UnrollingSolver(val context: LeonContext, underlyings: SolverFactory[Incre var newClauses : List[Seq[Expr]] = Nil var newBlockers : Map[Identifier,Set[FunctionInvocation]] = Map.empty - for(blocker <- allBlockers.keySet; FunctionInvocation(funDef, args) <- allBlockers(blocker)) { - val (nc, nb) = getTemplate(funDef).instantiate(blocker, args) + for(blocker <- allBlockers.keySet; FunctionInvocation(tfd, args) <- allBlockers(blocker)) { + val (nc, nb) = getTemplate(tfd).instantiate(blocker, args) newClauses = nc :: newClauses newBlockers = newBlockers ++ nb } @@ -63,7 +63,7 @@ class UnrollingSolver(val context: LeonContext, underlyings: SolverFactory[Incre newClauses.flatten } - val (nc, nb) = template.instantiate(aVar, template.funDef.args.map(a => Variable(a.id))) + val (nc, nb) = template.instantiate(aVar, template.tfd.args.map(a => Variable(a.id))) allClauses = nc.reverse allBlockers = nb @@ -137,23 +137,23 @@ class UnrollingSolver(val context: LeonContext, underlyings: SolverFactory[Incre stop = false } - private val funDefTemplateCache : MutableMap[FunDef, FunctionTemplate] = MutableMap.empty + private val tfdTemplateCache : MutableMap[TypedFunDef, FunctionTemplate] = MutableMap.empty private val exprTemplateCache : MutableMap[Expr, FunctionTemplate] = MutableMap.empty - private def getTemplate(funDef: FunDef): FunctionTemplate = { - funDefTemplateCache.getOrElse(funDef, { - val res = FunctionTemplate.mkTemplate(funDef, true) - funDefTemplateCache += funDef -> res + private def getTemplate(tfd: TypedFunDef): FunctionTemplate = { + tfdTemplateCache.getOrElse(tfd, { + val res = FunctionTemplate.mkTemplate(tfd, true) + tfdTemplateCache += tfd -> res res }) } private def getTemplate(body: Expr): FunctionTemplate = { exprTemplateCache.getOrElse(body, { - val fakeFunDef = new FunDef(FreshIdentifier("fake", true), body.getType, variablesOf(body).toSeq.map(id => VarDecl(id, id.getType))) + val fakeFunDef = new FunDef(FreshIdentifier("fake", true), Nil, body.getType, variablesOf(body).toSeq.map(id => VarDecl(id, id.getType))) fakeFunDef.body = Some(body) - val res = FunctionTemplate.mkTemplate(fakeFunDef, false) + val res = FunctionTemplate.mkTemplate(fakeFunDef.typed, false) exprTemplateCache += body -> res res }) diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index 873388137625de6c28ad579a59d24ac861f3c685..9344224129106a0b849a5b0834e60182ec07b2d2 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -10,6 +10,7 @@ import solvers._ import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.TypeTreeOps._ import xlang.Trees._ import purescala.TreeOps._ import purescala.TypeTrees._ @@ -69,51 +70,97 @@ trait AbstractZ3Solver interrupted = false } - protected[leon] def prepareFunctions : Unit - protected[leon] def functionDefToDecl(funDef: FunDef) : Z3FuncDecl - protected[leon] def functionDeclToDef(decl: Z3FuncDecl) : FunDef - protected[leon] def isKnownDecl(decl: Z3FuncDecl) : Boolean + def functionDefToDecl(tfd: TypedFunDef): Z3FuncDecl = { + functions.toZ3OrCompute(tfd) { + val sortSeq = tfd.args.map(vd => typeToSort(vd.tpe)) + val returnSort = typeToSort(tfd.returnType) - // Lifting of common parts starts here - protected[leon] var exprToZ3Id : Map[Expr,Z3AST] = Map.empty - protected[leon] var z3IdToExpr : Map[Z3AST,Expr] = Map.empty + z3.mkFreshFuncDecl(tfd.id.uniqueName, sortSeq, returnSort) + } + } - protected[leon] var intSort: Z3Sort = null - protected[leon] var boolSort: Z3Sort = null - protected[leon] var setSorts: Map[TypeTree, Z3Sort] = Map.empty - protected[leon] var mapSorts: Map[TypeTree, Z3Sort] = Map.empty - protected[leon] var unitSort: Z3Sort = null - protected[leon] var unitValue: Z3AST = null + def genericValueToDecl(gv: GenericValue): Z3FuncDecl = { + generics.toZ3OrCompute(gv) { + z3.mkFreshFuncDecl(gv.tp.toString+"#"+gv.id+"!val", Seq(), typeToSort(gv.tp)) + } + } + + object LeonType { + def unapply(a: Z3Sort): Option[(TypeTree)] = { + sorts.getLeon(a).map(tt => (tt)) + } + } + + class Bijection[A, B] { + var leonToZ3 = Map[A, B]() + var z3ToLeon = Map[B, A]() + + def +=(a: A, b: B): Unit = { + leonToZ3 += a -> b + z3ToLeon += b -> a + } + + def +=(t: (A,B)): Unit = { + this += (t._1, t._2) + } + + + def clear(): Unit = { + z3ToLeon = Map() + leonToZ3 = Map() + } + + def getZ3(a: A): Option[B] = leonToZ3.get(a) + def getLeon(b: B): Option[A] = z3ToLeon.get(b) + + def toZ3(a: A): B = getZ3(a).get + def toLeon(b: B): A = getLeon(b).get - protected[leon] var funSorts: Map[TypeTree, Z3Sort] = Map.empty - protected[leon] var funDomainConstructors: Map[TypeTree, Z3FuncDecl] = Map.empty - protected[leon] var funDomainSelectors: Map[TypeTree, Seq[Z3FuncDecl]] = Map.empty + def toZ3OrCompute(a: A)(c: => B) = { + getZ3(a).getOrElse { + val res = c + this += a -> res + res + } + } + + def toLeonOrCompute(b: B)(c: => A) = { + getLeon(b).getOrElse { + val res = c + this += res -> b + res + } + } - protected[leon] var tupleSorts: Map[TypeTree, Z3Sort] = Map.empty - protected[leon] var tupleConstructors: Map[TypeTree, Z3FuncDecl] = Map.empty - protected[leon] var tupleSelectors: Map[TypeTree, Seq[Z3FuncDecl]] = Map.empty + def containsLeon(a: A): Boolean = leonToZ3 contains a + def containsZ3(b: B): Boolean = z3ToLeon contains b + } - protected[leon] var arraySorts: Map[TypeTree, Z3Sort] = Map.empty - protected[leon] var arrayTupleCons: Map[TypeTree, Z3FuncDecl] = Map.empty - protected[leon] var arrayTupleSelectorArray: Map[TypeTree, Z3FuncDecl] = Map.empty - protected[leon] var arrayTupleSelectorLength: Map[TypeTree, Z3FuncDecl] = Map.empty + // Bijections between Leon Types/Functions/Ids to Z3 Sorts/Decls/ASTs + protected[leon] var functions = new Bijection[TypedFunDef, Z3FuncDecl] + protected[leon] var generics = new Bijection[GenericValue, Z3FuncDecl] + protected[leon] var sorts = new Bijection[TypeTree, Z3Sort] + protected[leon] var variables = new Bijection[Expr, Z3AST] - protected[leon] var reverseTupleConstructors: Map[Z3FuncDecl, TupleType] = Map.empty - protected[leon] var reverseTupleSelectors: Map[Z3FuncDecl, (TupleType, Int)] = Map.empty + // Meta decls and information used by several sorts + case class ArrayDecls(cons: Z3FuncDecl, select: Z3FuncDecl, length: Z3FuncDecl) + case class TupleDecls(cons: Z3FuncDecl, selects: Seq[Z3FuncDecl]) + protected[leon] var unitValue: Z3AST = null protected[leon] var intSetMinFun: Z3FuncDecl = null protected[leon] var intSetMaxFun: Z3FuncDecl = null - protected[leon] var setCardFuns: Map[TypeTree, Z3FuncDecl] = Map.empty - protected[leon] var adtSorts: Map[ClassTypeDef, Z3Sort] = Map.empty - protected[leon] var fallbackSorts: Map[TypeTree, Z3Sort] = Map.empty - protected[leon] var adtTesters: Map[CaseClassDef, Z3FuncDecl] = Map.empty - protected[leon] var adtConstructors: Map[CaseClassDef, Z3FuncDecl] = Map.empty - protected[leon] var adtFieldSelectors: Map[Identifier, Z3FuncDecl] = Map.empty + protected[leon] var arrayMetaDecls: Map[TypeTree, ArrayDecls] = Map.empty + protected[leon] var tupleMetaDecls: Map[TypeTree, TupleDecls] = Map.empty + protected[leon] var setCardDecls: Map[TypeTree, Z3FuncDecl] = Map.empty + + protected[leon] var adtTesters: Map[CaseClassType, Z3FuncDecl] = Map.empty + protected[leon] var adtConstructors: Map[CaseClassType, Z3FuncDecl] = Map.empty + protected[leon] var adtFieldSelectors: Map[(CaseClassType, Identifier), Z3FuncDecl] = Map.empty - protected[leon] var reverseADTTesters: Map[Z3FuncDecl, CaseClassDef] = Map.empty - protected[leon] var reverseADTConstructors: Map[Z3FuncDecl, CaseClassDef] = Map.empty - protected[leon] var reverseADTFieldSelectors: Map[Z3FuncDecl, (CaseClassDef,Identifier)] = Map.empty + protected[leon] var reverseADTTesters: Map[Z3FuncDecl, CaseClassType] = Map.empty + protected[leon] var reverseADTConstructors: Map[Z3FuncDecl, CaseClassType] = Map.empty + protected[leon] var reverseADTFieldSelectors: Map[Z3FuncDecl, (CaseClassType,Identifier)] = Map.empty protected[leon] val mapRangeSorts: MutableMap[TypeTree, Z3Sort] = MutableMap.empty protected[leon] val mapRangeSomeConstructors: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty @@ -139,8 +186,16 @@ trait AbstractZ3Solver z3 = new Z3Context(z3cfg) + functions.clear() + generics.clear() + sorts.clear() + variables.clear() + + arrayMetaDecls = Map() + tupleMetaDecls = Map() + setCardDecls = Map() + prepareSorts - prepareFunctions isInitialized = true @@ -153,9 +208,6 @@ trait AbstractZ3Solver isInitialized = false initZ3() - - exprToZ3Id = Map.empty - z3IdToExpr = Map.empty } protected[leon] def mapRangeSort(toType : TypeTree) : Z3Sort = mapRangeSorts.get(toType) match { @@ -163,26 +215,13 @@ trait AbstractZ3Solver case None => { import Z3Context.{ADTSortReference, RecursiveType, RegularSort} - intSort = z3.mkIntSort - boolSort = z3.mkBoolSort - - def typeToSortRef(tt: TypeTree): ADTSortReference = tt match { - case BooleanType => RegularSort(boolSort) - case Int32Type => RegularSort(intSort) - case AbstractClassType(d) => RegularSort(adtSorts(d)) - case CaseClassType(d) => RegularSort(adtSorts(d)) - case SetType(d) => RegularSort(setSorts(d)) - case mt @ MapType(d,r) => RegularSort(mapSorts(mt)) - case _ => throw UntranslatableTypeException("Can't handle type " + tt) - } - val z3info = z3.mkADTSorts( Seq( ( toType.toString + "Option", Seq(toType.toString + "Some", toType.toString + "None"), Seq( - Seq(("value", typeToSortRef(toType))), + Seq(("value", RegularSort(typeToSort(toType)))), Seq() ) ) @@ -203,67 +242,72 @@ trait AbstractZ3Solver } case class UntranslatableTypeException(msg: String) extends Exception(msg) - // Prepares some of the Z3 sorts, but *not* the tuple sorts; these are created on-demand. - private def prepareSorts: Unit = { + + def rootType(ct: ClassType): ClassType = ct.parent match { + case Some(p) => rootType(p) + case None => ct + } + + def declareADTSort(ct: ClassType): Z3Sort = { import Z3Context.{ADTSortReference, RecursiveType, RegularSort} - // NOTE THAT abstract classes that extend abstract classes are not - // currently supported in the translation - - intSort = z3.mkIntSort - boolSort = z3.mkBoolSort - setSorts = Map.empty - setCardFuns = Map.empty - - //unitSort = z3.mkUninterpretedSort("unit") - //unitValue = z3.mkFreshConst("Unit", unitSort) - //val bound = z3.mkBound(0, unitSort) - //val eq = z3.mkEq(bound, unitValue) - //val decls = Seq((z3.mkFreshStringSymbol("u"), unitSort)) - //val unitAxiom = z3.mkForAll(0, Seq(), decls, eq) - //println(unitAxiom) - //println(unitValue) - //z3.assertCnstr(unitAxiom) - val Seq((us, Seq(unitCons), Seq(unitTester), _)) = z3.mkADTSorts( - Seq( - ( - "Unit", - Seq("Unit"), - Seq(Seq()) - ) - ) - ) - unitSort = us - unitValue = unitCons() - val intSetSort = typeToSort(SetType(Int32Type)) - intSetMinFun = z3.mkFreshFuncDecl("setMin", Seq(intSetSort), intSort) - intSetMaxFun = z3.mkFreshFuncDecl("setMax", Seq(intSetSort), intSort) + def getHierarchy(ct: ClassType): (ClassType, Seq[CaseClassType]) = ct match { + case act: AbstractClassType => + (act, act.knownCCDescendents) + case cct: CaseClassType => + cct.parent match { + case Some(p) => + getHierarchy(p) + case None => + (cct, List(cct)) + } + } - val hierarchies = program.classHierarchyRoots.flatMap { root => root match { - case c: CaseClassDef => - Some((root, List(c))) + var newHierarchiesMap = Map[ClassType, Seq[CaseClassType]]() - case a: AbstractClassDef => - val childs = a.knownChildren.collect{ case a: CaseClassDef => a }.toList - if (childs.isEmpty) { - None - } else { - Some((root, childs)) + def findDependencies(ct: ClassType): Unit = { + val (root, sub) = getHierarchy(ct) + + if (!(newHierarchiesMap contains root) && !(sorts containsLeon root)) { + newHierarchiesMap += root -> sub + + // look for dependencies + for (ct <- root +: sub; f <- ct.fields) f.tpe match { + case fct: ClassType => + findDependencies(fct) + case _ => } - }} + } + } - val indexMap: Map[ClassTypeDef, Int] = Map(hierarchies.map(_._1).zipWithIndex: _*) + // Populates the dependencies of the ADT to define. + findDependencies(ct) + + val newHierarchies = newHierarchiesMap.toSeq + + val indexMap: Map[ClassType, Int] = Map()++newHierarchies.map(_._1).zipWithIndex def typeToSortRef(tt: TypeTree): ADTSortReference = tt match { - case AbstractClassType(d) => RecursiveType(indexMap(d)) - case CaseClassType(d) => indexMap.get(d) match { - case Some(i) => RecursiveType(i) - case None => RecursiveType(indexMap(d.parent.get)) + case ct: ClassType if sorts containsLeon rootType(ct) => + RegularSort(sorts.toZ3(rootType(ct))) + + case act : AbstractClassType => + // It has to be here + RecursiveType(indexMap(act)) + + case cct: CaseClassType => cct.parent match { + case Some(p) => + typeToSortRef(p) + case None => + RecursiveType(indexMap(cct)) } - case _ => RegularSort(typeToSort(tt)) + + case _=> + RegularSort(typeToSort(tt)) } - val defs = for ((root, childrenList) <- hierarchies) yield { + // Define stuff + val defs = for ((root, childrenList) <- newHierarchies) yield { ( root.id.uniqueName, childrenList.map(ccd => ccd.id.uniqueName), @@ -279,110 +323,134 @@ trait AbstractZ3Solver // } //} - // everything should be alright now... val resultingZ3Info = z3.mkADTSorts(defs) - adtSorts = Map.empty - adtTesters = Map.empty - adtConstructors = Map.empty - adtFieldSelectors = Map.empty - reverseADTTesters = Map.empty - reverseADTConstructors = Map.empty - reverseADTFieldSelectors = Map.empty - - for ((z3Inf, (root, childrenList)) <- (resultingZ3Info zip hierarchies)) { - adtSorts += (root -> z3Inf._1) + for ((z3Inf, (root, childrenList)) <- (resultingZ3Info zip newHierarchies)) { + sorts += (root -> z3Inf._1) assert(childrenList.size == z3Inf._2.size) for ((child, (consFun, testFun)) <- childrenList zip (z3Inf._2 zip z3Inf._3)) { adtTesters += (child -> testFun) + reverseADTTesters += (testFun -> child) adtConstructors += (child -> consFun) + reverseADTConstructors += (consFun -> child) } for ((child, fieldFuns) <- childrenList zip z3Inf._4) { assert(child.fields.size == fieldFuns.size) for ((fid, selFun) <- (child.fields.map(_.id) zip fieldFuns)) { - adtFieldSelectors += (fid -> selFun) + adtFieldSelectors += ((child, fid) -> selFun) reverseADTFieldSelectors += (selFun -> (child, fid)) } } } - reverseADTTesters = adtTesters.map(_.swap) - reverseADTConstructors = adtConstructors.map(_.swap) - // ...and now everything should be in there... + sorts.toZ3(ct) + } + + // Prepares some of the Z3 sorts, but *not* the tuple sorts; these are created on-demand. + private def prepareSorts: Unit = { + import Z3Context.{ADTSortReference, RecursiveType, RegularSort} + + val Seq((us, Seq(unitCons), Seq(unitTester), _)) = z3.mkADTSorts( + Seq( + ( + "Unit", + Seq("Unit"), + Seq(Seq()) + ) + ) + ) + + sorts += Int32Type -> z3.mkIntSort + sorts += BooleanType -> z3.mkBoolSort + sorts += UnitType -> us + + unitValue = unitCons() + + val intSetSort = typeToSort(SetType(Int32Type)) + val intSort = typeToSort(Int32Type) + + intSetMinFun = z3.mkFreshFuncDecl("setMin", Seq(intSetSort), intSort) + intSetMaxFun = z3.mkFreshFuncDecl("setMax", Seq(intSetSort), intSort) + + // Empty everything + adtTesters = Map.empty + adtConstructors = Map.empty + adtFieldSelectors = Map.empty + reverseADTTesters = Map.empty + reverseADTConstructors = Map.empty + reverseADTFieldSelectors = Map.empty } // assumes prepareSorts has been called.... protected[leon] def typeToSort(tt: TypeTree): Z3Sort = tt match { - case Int32Type => intSort - case BooleanType => boolSort - case UnitType => unitSort - case AbstractClassType(cd) => adtSorts(cd) - case CaseClassType(cd) => { - if (cd.hasParent) { - adtSorts(cd.parent.get) - } else { - adtSorts(cd) + case Int32Type | BooleanType | UnitType => + sorts.toZ3(tt) + + case act: AbstractClassType => + sorts.toZ3OrCompute(rootType(act)) { + declareADTSort(rootType(act)) } - } - case SetType(base) => setSorts.get(base) match { - case Some(s) => s - case None => { + + case cct: CaseClassType => + sorts.toZ3OrCompute(rootType(cct)) { + declareADTSort(rootType(cct)) + } + + case SetType(base) => + sorts.toZ3OrCompute(tt) { val newSetSort = z3.mkSetSort(typeToSort(base)) - setSorts = setSorts + (base -> newSetSort) - val newCardFun = z3.mkFreshFuncDecl("card", Seq(newSetSort), intSort) - setCardFuns = setCardFuns + (base -> newCardFun) + + val card = z3.mkFreshFuncDecl("card", Seq(newSetSort), typeToSort(Int32Type)) + setCardDecls += tt -> card + newSetSort } - } - case mt @ MapType(fromType, toType) => mapSorts.get(mt) match { - case Some(s) => s - case None => { + + case MapType(fromType, toType) => + sorts.toZ3OrCompute(tt) { val fromSort = typeToSort(fromType) val toSort = mapRangeSort(toType) - val ms = z3.mkArraySort(fromSort, toSort) - mapSorts += ((mt, ms)) - ms + + z3.mkArraySort(fromSort, toSort) } - } - case at @ ArrayType(base) => arraySorts.get(at) match { - case Some(s) => s - case None => { + + case ArrayType(base) => + sorts.toZ3OrCompute(tt) { val intSort = typeToSort(Int32Type) val toSort = typeToSort(base) val as = z3.mkArraySort(intSort, toSort) val tupleSortSymbol = z3.mkFreshStringSymbol("Array") - val (arrayTupleSort, arrayTupleCons_, Seq(arrayTupleSelectorArray_, arrayTupleSelectorLength_)) = z3.mkTupleSort(tupleSortSymbol, as, intSort) - arraySorts += (at -> arrayTupleSort) - arrayTupleCons += (at -> arrayTupleCons_) - arrayTupleSelectorArray += (at -> arrayTupleSelectorArray_) - arrayTupleSelectorLength += (at -> arrayTupleSelectorLength_) - arrayTupleSort + val (ats, atcons, Seq(atsel, atlength)) = z3.mkTupleSort(tupleSortSymbol, as, intSort) + + arrayMetaDecls += tt -> ArrayDecls(atcons, atsel, atlength) + + ats } - } - case tt @ TupleType(tpes) => tupleSorts.get(tt) match { - case Some(s) => s - case None => { + case TupleType(tpes) => + sorts.toZ3OrCompute(tt) { val tpesSorts = tpes.map(typeToSort) val sortSymbol = z3.mkFreshStringSymbol("Tuple") val (tupleSort, consTuple, projsTuple) = z3.mkTupleSort(sortSymbol, tpesSorts: _*) - tupleSorts += (tt -> tupleSort) - tupleConstructors += (tt -> consTuple) - reverseTupleConstructors += (consTuple -> tt) - tupleSelectors += (tt -> projsTuple) - projsTuple.zipWithIndex.foreach{ case (proj, i) => reverseTupleSelectors += (proj -> (tt, i)) } + + tupleMetaDecls += tt -> TupleDecls(consTuple, projsTuple) + tupleSort } - } - case other => fallbackSorts.get(other) match { - case Some(s) => s - case None => { + + case TypeParameter(id) => + sorts.toZ3OrCompute(tt) { + val symbol = z3.mkFreshStringSymbol(id.name) + val newTPSort = z3.mkUninterpretedSort(symbol) + + newTPSort + } + + case other => + sorts.toZ3OrCompute(other) { reporter.warning("Resorting to uninterpreted type for : " + other) val symbol = z3.mkIntSymbol(nextIntForSymbol()) - val newFBSort = z3.mkUninterpretedSort(symbol) - fallbackSorts = fallbackSorts + (other -> newFBSort) - newFBSort + z3.mkUninterpretedSort(symbol) } - } } protected[leon] def toZ3Formula(expr: Expr, initialMap: Map[Identifier,Z3AST] = Map.empty) : Option[Z3AST] = { @@ -395,193 +463,180 @@ trait AbstractZ3Solver } else { // FIXME TODO pleeeeeeeease make this cleaner. Ie. decide what set of // variable has to remain in a map etc. - exprToZ3Id.filter(p => p._1.isInstanceOf[Variable]).map(p => (p._1.asInstanceOf[Variable].id -> p._2)) + variables.leonToZ3.filter(p => p._1.isInstanceOf[Variable]).map(p => (p._1.asInstanceOf[Variable].id -> p._2)) } - def rec(ex: Expr): Z3AST = { - //println("Stacking up call for:") - //println(ex) - val recResult = (ex match { - case tu @ Tuple(args) => { - // This call is required, because the Z3 sort may not have been generated yet. - // If it has, it's just a map lookup and instant return. - typeToSort(tu.getType) - val constructor = tupleConstructors(tu.getType) - constructor(args.map(rec(_)): _*) - } - case ts @ TupleSelect(tu, i) => { - // See comment above for similar code. - typeToSort(tu.getType) - val selector = tupleSelectors(tu.getType)(i-1) - selector(rec(tu)) - } - case Let(i, e, b) => { - val re = rec(e) - z3Vars = z3Vars + (i -> re) - val rb = rec(b) - z3Vars = z3Vars - i - rb - } - case Waypoint(_, e) => rec(e) - case e @ Error(_) => { - val tpe = e.getType - val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) - exprToZ3Id += (e -> newAST) - z3IdToExpr += (newAST -> e) + def rec(ex: Expr): Z3AST = ex match { + case tu @ Tuple(args) => + typeToSort(tu.getType) // Make sure we generate sort & meta info + val meta = tupleMetaDecls(tu.getType) + + meta.cons(args.map(rec(_)): _*) + + case ts @ TupleSelect(tu, i) => + typeToSort(tu.getType) // Make sure we generate sort & meta info + val meta = tupleMetaDecls(tu.getType) + + meta.selects(i-1)(rec(tu)) + + case Let(i, e, b) => { + val re = rec(e) + z3Vars = z3Vars + (i -> re) + val rb = rec(b) + z3Vars = z3Vars - i + rb + } + case Waypoint(_, e) => rec(e) + case e @ Error(_) => { + val tpe = e.getType + val newAST = z3.mkFreshConst("errorValue", typeToSort(tpe)) + // Might introduce dupplicates (e), but no worries here + variables += (e -> newAST) + newAST + } + case v @ Variable(id) => z3Vars.get(id) match { + case Some(ast) => ast + case None => { + val newAST = z3.mkFreshConst(id.uniqueName, typeToSort(v.getType)) + z3Vars = z3Vars + (id -> newAST) + variables += (v -> newAST) newAST } - case v @ Variable(id) => z3Vars.get(id) match { - case Some(ast) => ast - case None => { - // if (id.isLetBinder) { - // scala.sys.error("Error in formula being translated to Z3: identifier " + id + " seems to have escaped its let-definition") - // } - - // Remove this safety check, since choose() expresions are now - // translated to non-unrollable variables, that end up here. - // assert(!this.isInstanceOf[FairZ3Solver], "Trying to convert unknown variable '"+id+"' while using FairZ3") - - val newAST = z3.mkFreshConst(id.uniqueName/*name*/, typeToSort(v.getType)) - z3Vars = z3Vars + (id -> newAST) - exprToZ3Id += (v -> newAST) - z3IdToExpr += (newAST -> v) - newAST - } - } + } - case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) - case And(exs) => z3.mkAnd(exs.map(rec(_)): _*) - case Or(exs) => z3.mkOr(exs.map(rec(_)): _*) - case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) - case Iff(l, r) => { - val rl = rec(l) - val rr = rec(r) - // z3.mkIff used to trigger a bug - // z3.mkAnd(z3.mkImplies(rl, rr), z3.mkImplies(rr, rl)) - z3.mkIff(rl, rr) - } - case Not(Iff(l, r)) => z3.mkXor(rec(l), rec(r)) - case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) - case Not(e) => z3.mkNot(rec(e)) - case IntLiteral(v) => z3.mkInt(v, intSort) - case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() - case UnitLiteral => unitValue - case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) - case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) - case Minus(l, r) => z3.mkSub(rec(l), rec(r)) - case Times(l, r) => z3.mkMul(rec(l), rec(r)) - case Division(l, r) => z3.mkDiv(rec(l), rec(r)) - case Modulo(l, r) => z3.mkMod(rec(l), rec(r)) - case UMinus(e) => z3.mkUnaryMinus(rec(e)) - case LessThan(l, r) => z3.mkLT(rec(l), rec(r)) - case LessEquals(l, r) => z3.mkLE(rec(l), rec(r)) - case GreaterThan(l, r) => z3.mkGT(rec(l), rec(r)) - case GreaterEquals(l, r) => z3.mkGE(rec(l), rec(r)) - case c @ CaseClass(cd, args) => { - val constructor = adtConstructors(cd) - constructor(args.map(rec(_)): _*) - } - case c @ CaseClassSelector(_, cc, sel) => { - val selector = adtFieldSelectors(sel) - selector(rec(cc)) - } - case c @ CaseClassInstanceOf(ccd, e) => { - val tester = adtTesters(ccd) - tester(rec(e)) - } - case f @ FunctionInvocation(fd, args) => { - z3.mkApp(functionDefToDecl(fd), args.map(rec(_)): _*) - } - - case SetEquals(s1, s2) => z3.mkEq(rec(s1), rec(s2)) - case ElementOfSet(e, s) => z3.mkSetSubset(z3.mkSetAdd(z3.mkEmptySet(typeToSort(e.getType)), rec(e)), rec(s)) - case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) - case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) - case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) - case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) - case f @ FiniteSet(elems) => elems.foldLeft(z3.mkEmptySet(typeToSort(f.getType.asInstanceOf[SetType].base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) - case SetCardinality(s) => { - val rs = rec(s) - setCardFuns(s.getType.asInstanceOf[SetType].base)(rs) - } - case SetMin(s) => intSetMinFun(rec(s)) - case SetMax(s) => intSetMaxFun(rec(s)) - case f @ FiniteMap(elems) => f.getType match { - case tpe@MapType(fromType, toType) => - typeToSort(tpe) //had to add this here because the mapRangeNoneConstructors was not yet constructed... - val fromSort = typeToSort(fromType) - val toSort = typeToSort(toType) - elems.foldLeft(z3.mkConstArray(fromSort, mapRangeNoneConstructors(toType)())){ case (ast, (k,v)) => z3.mkStore(ast, rec(k), mapRangeSomeConstructors(toType)(rec(v))) } - case errorType => scala.sys.error("Unexpected type for finite map: " + (ex, errorType)) - } - case mg @ MapGet(m,k) => m.getType match { - case MapType(fromType, toType) => - val selected = z3.mkSelect(rec(m), rec(k)) - mapRangeValueSelectors(toType)(selected) - case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) - } - case MapUnion(m1,m2) => m1.getType match { - case MapType(ft, tt) => m2 match { - case FiniteMap(ss) => - ss.foldLeft(rec(m1)){ - case (ast, (k, v)) => z3.mkStore(ast, rec(k), mapRangeSomeConstructors(tt)(rec(v))) - } - case _ => scala.sys.error("map updates can only be applied with concrete map instances") - } - case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) - } - case MapIsDefinedAt(m,k) => m.getType match { - case MapType(ft, tt) => z3.mkDistinct(z3.mkSelect(rec(m), rec(k)), mapRangeNoneConstructors(tt)()) - case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) - } - case fill @ ArrayFill(length, default) => { - val at@ArrayType(base) = fill.getType - typeToSort(at) - val cons = arrayTupleCons(at) - val ar = z3.mkConstArray(typeToSort(base), rec(default)) - val res = cons(ar, rec(length)) - res - } - case ArraySelect(a, index) => { - typeToSort(a.getType) - val ar = rec(a) - val getArray = arrayTupleSelectorArray(a.getType) - val res = z3.mkSelect(getArray(ar), rec(index)) - res - } - case ArrayUpdated(a, index, newVal) => { - typeToSort(a.getType) - val ar = rec(a) - val getArray = arrayTupleSelectorArray(a.getType) - val getLength = arrayTupleSelectorLength(a.getType) - val cons = arrayTupleCons(a.getType) - val store = z3.mkStore(getArray(ar), rec(index), rec(newVal)) - val res = cons(store, getLength(ar)) - res - } - case ArrayLength(a) => { - typeToSort(a.getType) - val ar = rec(a) - val getLength = arrayTupleSelectorLength(a.getType) - val res = getLength(ar) - res + case ite @ IfExpr(c, t, e) => z3.mkITE(rec(c), rec(t), rec(e)) + case And(exs) => z3.mkAnd(exs.map(rec(_)): _*) + case Or(exs) => z3.mkOr(exs.map(rec(_)): _*) + case Implies(l, r) => z3.mkImplies(rec(l), rec(r)) + case Iff(l, r) => + val rl = rec(l) + val rr = rec(r) + z3.mkIff(rl, rr) + + case Not(Iff(l, r)) => z3.mkXor(rec(l), rec(r)) + case Not(Equals(l, r)) => z3.mkDistinct(rec(l), rec(r)) + case Not(e) => z3.mkNot(rec(e)) + case IntLiteral(v) => z3.mkInt(v, typeToSort(Int32Type)) + case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() + case UnitLiteral => unitValue + case Equals(l, r) => z3.mkEq(rec( l ), rec( r ) ) + case Plus(l, r) => z3.mkAdd(rec(l), rec(r)) + case Minus(l, r) => z3.mkSub(rec(l), rec(r)) + case Times(l, r) => z3.mkMul(rec(l), rec(r)) + case Division(l, r) => z3.mkDiv(rec(l), rec(r)) + case Modulo(l, r) => z3.mkMod(rec(l), rec(r)) + case UMinus(e) => z3.mkUnaryMinus(rec(e)) + case LessThan(l, r) => z3.mkLT(rec(l), rec(r)) + case LessEquals(l, r) => z3.mkLE(rec(l), rec(r)) + case GreaterThan(l, r) => z3.mkGT(rec(l), rec(r)) + case GreaterEquals(l, r) => z3.mkGE(rec(l), rec(r)) + case c @ CaseClass(ct, args) => + typeToSort(ct) // Making sure the sort is defined + val constructor = adtConstructors(ct) + constructor(args.map(rec(_)): _*) + + case c @ CaseClassSelector(cct, cc, sel) => + typeToSort(cct) // Making sure the sort is defined + val selector = adtFieldSelectors(cct, sel) + selector(rec(cc)) + + case c @ CaseClassInstanceOf(cct, e) => + typeToSort(cct) // Making sure the sort is defined + val tester = adtTesters(cct) + tester(rec(e)) + + case f @ FunctionInvocation(tfd, args) => + z3.mkApp(functionDefToDecl(tfd), args.map(rec(_)): _*) + + case SetEquals(s1, s2) => z3.mkEq(rec(s1), rec(s2)) + case ElementOfSet(e, s) => z3.mkSetSubset(z3.mkSetAdd(z3.mkEmptySet(typeToSort(e.getType)), rec(e)), rec(s)) + case SubsetOf(s1, s2) => z3.mkSetSubset(rec(s1), rec(s2)) + case SetIntersection(s1, s2) => z3.mkSetIntersect(rec(s1), rec(s2)) + case SetUnion(s1, s2) => z3.mkSetUnion(rec(s1), rec(s2)) + case SetDifference(s1, s2) => z3.mkSetDifference(rec(s1), rec(s2)) + case f @ FiniteSet(elems) => elems.foldLeft(z3.mkEmptySet(typeToSort(f.getType.asInstanceOf[SetType].base)))((ast, el) => z3.mkSetAdd(ast, rec(el))) + case SetCardinality(s) => + val rs = rec(s) + setCardDecls(s.getType)(rs) + + case SetMin(s) => intSetMinFun(rec(s)) + case SetMax(s) => intSetMaxFun(rec(s)) + case f @ FiniteMap(elems) => f.getType match { + case tpe@MapType(fromType, toType) => + typeToSort(tpe) //had to add this here because the mapRangeNoneConstructors was not yet constructed... + val fromSort = typeToSort(fromType) + val toSort = typeToSort(toType) + elems.foldLeft(z3.mkConstArray(fromSort, mapRangeNoneConstructors(toType)())){ case (ast, (k,v)) => z3.mkStore(ast, rec(k), mapRangeSomeConstructors(toType)(rec(v))) } + case errorType => scala.sys.error("Unexpected type for finite map: " + (ex, errorType)) + } + case mg @ MapGet(m,k) => m.getType match { + case MapType(fromType, toType) => + val selected = z3.mkSelect(rec(m), rec(k)) + mapRangeValueSelectors(toType)(selected) + case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) + } + case MapUnion(m1,m2) => m1.getType match { + case MapType(ft, tt) => m2 match { + case FiniteMap(ss) => + ss.foldLeft(rec(m1)){ + case (ast, (k, v)) => z3.mkStore(ast, rec(k), mapRangeSomeConstructors(tt)(rec(v))) + } + case _ => scala.sys.error("map updates can only be applied with concrete map instances") } + case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) + } + case MapIsDefinedAt(m,k) => m.getType match { + case MapType(ft, tt) => z3.mkDistinct(z3.mkSelect(rec(m), rec(k)), mapRangeNoneConstructors(tt)()) + case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) + } + case fill @ ArrayFill(length, default) => + val at @ ArrayType(base) = fill.getType + typeToSort(at) + val meta = arrayMetaDecls(at) + + val ar = z3.mkConstArray(typeToSort(base), rec(default)) + val res = meta.cons(ar, rec(length)) + res + + case ArraySelect(a, index) => + typeToSort(a.getType) + val ar = rec(a) + val getArray = arrayMetaDecls(a.getType).select + val res = z3.mkSelect(getArray(ar), rec(index)) + res + + case ArrayUpdated(a, index, newVal) => + typeToSort(a.getType) + val ar = rec(a) + val meta = arrayMetaDecls(a.getType) + + val store = z3.mkStore(meta.select(ar), rec(index), rec(newVal)) + val res = meta.cons(store, meta.length(ar)) + res + + case ArrayLength(a) => + typeToSort(a.getType) + val ar = rec(a) + val meta = arrayMetaDecls(a.getType) + val res = meta.length(ar) + res + + case arr @ FiniteArray(exprs) => { + val ArrayType(innerType) = arr.getType + val arrayType = arr.getType + val a: Expr = ArrayFill(IntLiteral(exprs.length), simplestValue(innerType)).setType(arrayType) + val u = exprs.zipWithIndex.foldLeft(a)((array, expI) => ArrayUpdated(array, IntLiteral(expI._2), expI._1).setType(arrayType)) + rec(u) + } + case Distinct(exs) => z3.mkDistinct(exs.map(rec(_)): _*) - case arr @ FiniteArray(exprs) => { - val ArrayType(innerType) = arr.getType - val arrayType = arr.getType - val a: Expr = ArrayFill(IntLiteral(exprs.length), simplestValue(innerType)).setType(arrayType) - val u = exprs.zipWithIndex.foldLeft(a)((array, expI) => ArrayUpdated(array, IntLiteral(expI._2), expI._1).setType(arrayType)) - rec(u) - } - case Distinct(exs) => z3.mkDistinct(exs.map(rec(_)): _*) - - case _ => { - reporter.warning("Can't handle this in translation to Z3: " + ex) - throw new CantTranslateException - } - }) - recResult + case gv @ GenericValue(tp, id) => + z3.mkApp(genericValueToDecl(gv)) + + case _ => { + reporter.warning("Can't handle this in translation to Z3: " + ex) + throw new CantTranslateException + } } try { @@ -592,156 +647,121 @@ trait AbstractZ3Solver } } - protected[leon] def fromZ3Formula(model: Z3Model, tree : Z3AST, expectedType: Option[TypeTree] = None) : Expr = { - def rec(t: Z3AST, expType: Option[TypeTree] = None) : Expr = expType match { - case _ if z3IdToExpr contains t => z3IdToExpr(t) - - case Some(MapType(kt,vt)) => - model.getArrayValue(t) match { - case None => throw new CantTranslateException(t) - case Some((map, elseValue)) => - val singletons: Seq[(Expr, Expr)] = map.map(e => (e, z3.getASTKind(e._2))).collect { - case ((index, value), Z3AppAST(someCons, arg :: Nil)) if someCons == mapRangeSomeConstructors(vt) => (rec(index, Some(kt)), rec(arg, Some(vt))) - }.toSeq - FiniteMap(singletons).setType(expType.get) - } - case Some(SetType(dt)) => - model.getSetValue(t) match { - case None => throw new CantTranslateException(t) - case Some(set) => { - val elems = set.map(e => rec(e, Some(dt))) - FiniteSet(elems.toSeq).setType(expType.get) - } - } - case Some(ArrayType(dt)) => { - val Z3AppAST(decl, args) = z3.getASTKind(t) - assert(args.size == 2) - val IntLiteral(length) = rec(args(1), Some(Int32Type)) - val array = model.getArrayValue(args(0)) match { - case None => throw new CantTranslateException(t) - case Some((map, elseValue)) => { - val exprs = map.foldLeft((1 to length).map(_ => rec(elseValue, Some(dt))).toSeq)((acc, p) => { - val IntLiteral(index) = rec(p._1, Some(Int32Type)) - if(index >= 0 && index < length) - acc.updated(index, rec(p._2, Some(dt))) - else acc - }) - FiniteArray(exprs) - } - } - array - } - case other => - if(t == unitValue) - UnitLiteral - else z3.getASTKind(t) match { - case Z3AppAST(decl, args) => { - val argsSize = args.size - if(argsSize == 0 && z3IdToExpr.isDefinedAt(t)) { - val toRet = z3IdToExpr(t) - // println("Map says I should replace " + t + " by " + toRet) - toRet - } else if(isKnownDecl(decl)) { - val fd = functionDeclToDef(decl) - assert(fd.args.size == argsSize) - FunctionInvocation(fd, (args zip fd.args).map(p => rec(p._1,Some(p._2.tpe)))) - } else if(argsSize == 1 && reverseADTTesters.isDefinedAt(decl)) { - CaseClassInstanceOf(reverseADTTesters(decl), rec(args(0))) - } else if(argsSize == 1 && reverseADTFieldSelectors.isDefinedAt(decl)) { - val (ccd, fid) = reverseADTFieldSelectors(decl) - CaseClassSelector(ccd, rec(args(0)), fid) - } else if(reverseADTConstructors.isDefinedAt(decl)) { - val ccd = reverseADTConstructors(decl) - assert(argsSize == ccd.fields.size) - CaseClass(ccd, (args zip ccd.fields).map(p => rec(p._1, Some(p._2.tpe)))) - } else if(reverseTupleConstructors.isDefinedAt(decl)) { - val TupleType(subTypes) = reverseTupleConstructors(decl) - val rargs = args.zip(subTypes).map(p => rec(p._1, Some(p._2))) - Tuple(rargs) - } else { - import Z3DeclKind._ - val rargs = args.map(rec(_)) - z3.getDeclKind(decl) match { - case OpTrue => BooleanLiteral(true) - case OpFalse => BooleanLiteral(false) - case OpEq => Equals(rargs(0), rargs(1)) - case OpITE => - assert(argsSize == 3) - val r0 = rargs(0) - val r1 = rargs(1) - val r2 = rargs(2) - try { - IfExpr(r0, r1, r2).setType(leastUpperBound(r1.getType, r2.getType).get) - } catch { - case e: Throwable => - println("I was asking for lub because of this.") - println(t) - println("which was translated as") - println(IfExpr(r0,r1,r2)) - throw e - } - - case OpAnd => And(rargs) - case OpOr => Or(rargs) - case OpIff => Iff(rargs(0), rargs(1)) - case OpXor => Not(Iff(rargs(0), rargs(1))) - case OpNot => Not(rargs(0)) - case OpImplies => Implies(rargs(0), rargs(1)) - case OpLE => LessEquals(rargs(0), rargs(1)) - case OpGE => GreaterEquals(rargs(0), rargs(1)) - case OpLT => LessThan(rargs(0), rargs(1)) - case OpGT => GreaterThan(rargs(0), rargs(1)) - case OpAdd => { - assert(argsSize == 2) - Plus(rargs(0), rargs(1)) - } - case OpSub => { - assert(argsSize == 2) - Minus(rargs(0), rargs(1)) - } - case OpUMinus => UMinus(rargs(0)) - case OpMul => { - assert(argsSize == 2) - Times(rargs(0), rargs(1)) - } - case OpDiv => { - assert(argsSize == 2) - Division(rargs(0), rargs(1)) + protected[leon] def fromZ3Formula(model: Z3Model, tree : Z3AST) : Expr = { + def rec(t: Z3AST): Expr = { + val kind = z3.getASTKind(t) + val sort = z3.getSort(t) + + kind match { + case Z3NumeralIntAST(Some(v)) => IntLiteral(v) + case Z3AppAST(decl, args) => + val argsSize = args.size + if(argsSize == 0 && (variables containsZ3 t)) { + variables.toLeon(t) + } else if(functions containsZ3 decl) { + val tfd = functions.toLeon(decl) + assert(tfd.args.size == argsSize) + FunctionInvocation(tfd, args.map(rec)) + } else if(argsSize == 1 && (reverseADTTesters contains decl)) { + val cct = reverseADTTesters(decl) + CaseClassInstanceOf(cct, rec(args(0))) + } else if(argsSize == 1 && (reverseADTFieldSelectors contains decl)) { + val (cct, fid) = reverseADTFieldSelectors(decl) + CaseClassSelector(cct, rec(args(0)), fid) + } else if(reverseADTConstructors contains decl) { + val cct = reverseADTConstructors(decl) + assert(argsSize == cct.fields.size) + CaseClass(cct, args.map(rec)) + } else if (generics containsZ3 decl) { + generics.toLeon(decl) + } else { + sort match { + case LeonType(tp: TypeParameter) => + val id = t.toString.split("!").last.toInt + GenericValue(tp, id) + + case LeonType(tp: TupleType) => + val rargs = args.map(rec) + Tuple(rargs) + + case LeonType(ArrayType(dt)) => + assert(args.size == 2) + val IntLiteral(length) = rec(args(1)) + model.getArrayValue(args(0)) match { + case None => throw new CantTranslateException(t) + case Some((map, elseZ3Value)) => + val elseValue = rec(elseZ3Value) + var valuesMap = map.map { case (k,v) => + val IntLiteral(index) = rec(k) + (index -> rec(v)) + } + + FiniteArray(for (i <- 1 to length) yield { + valuesMap.getOrElse(i, elseValue) + }) } - case OpIDiv => { - assert(argsSize == 2) - Division(rargs(0), rargs(1)) - } - case OpMod => { - assert(argsSize == 2) - Modulo(rargs(0), rargs(1)) + + case LeonType(tpe @ MapType(kt, vt)) => + model.getArrayValue(t) match { + case None => throw new CantTranslateException(t) + case Some((map, elseZ3Value)) => + var values = map.toSeq.map { case (k, v) => (k, z3.getASTKind(v)) }.collect { + case (k, Z3AppAST(cons, arg :: Nil)) if cons == mapRangeSomeConstructors(vt) => + (rec(k), rec(arg)) + } + + FiniteMap(values).setType(tpe) } - case OpAsArray => { - assert(argsSize == 0) - throw new Exception("encountered OpAsArray") + + case LeonType(tpe @ SetType(dt)) => + model.getSetValue(t) match { + case None => throw new CantTranslateException(t) + case Some(set) => + val elems = set.map(e => rec(e)) + FiniteSet(elems.toSeq).setType(tpe) } - case other => { - System.err.println("Don't know what to do with this declKind : " + other) - System.err.println("The arguments are : " + args) - throw new CantTranslateException(t) + + case LeonType(UnitType) => + UnitLiteral + + case _ => + import Z3DeclKind._ + val rargs = args.map(rec(_)) + z3.getDeclKind(decl) match { + case OpTrue => BooleanLiteral(true) + case OpFalse => BooleanLiteral(false) + case OpEq => Equals(rargs(0), rargs(1)) + case OpITE => IfExpr(rargs(0), rargs(1), rargs(2)) + case OpAnd => And(rargs) + case OpOr => Or(rargs) + case OpIff => Iff(rargs(0), rargs(1)) + case OpXor => Not(Iff(rargs(0), rargs(1))) + case OpNot => Not(rargs(0)) + case OpImplies => Implies(rargs(0), rargs(1)) + case OpLE => LessEquals(rargs(0), rargs(1)) + case OpGE => GreaterEquals(rargs(0), rargs(1)) + case OpLT => LessThan(rargs(0), rargs(1)) + case OpGT => GreaterThan(rargs(0), rargs(1)) + case OpAdd => Plus(rargs(0), rargs(1)) + case OpSub => Minus(rargs(0), rargs(1)) + case OpUMinus => UMinus(rargs(0)) + case OpMul => Times(rargs(0), rargs(1)) + case OpDiv => Division(rargs(0), rargs(1)) + case OpIDiv => Division(rargs(0), rargs(1)) + case OpMod => Modulo(rargs(0), rargs(1)) + case other => + System.err.println("Don't know what to do with this declKind : " + other) + System.err.println("The arguments are : " + args) + throw new CantTranslateException(t) } - } } } - - case Z3NumeralIntAST(Some(v)) => IntLiteral(v) - case Z3NumeralIntAST(None) => { - reporter.info("Cannot read exact model from Z3: Integer does not fit in machine word") - reporter.info("Exiting procedure now") - sys.exit(0) - } - case other @ _ => { - System.err.println("Don't know what this is " + other) - throw new CantTranslateException(t) - } - } + case _ => + System.err.println("Can't handle "+t) + throw new CantTranslateException(t) + } } - rec(tree, expectedType) + rec(tree) } // Tries to convert a Z3AST into a *ground* Expr. Doesn't try very hard, because @@ -753,25 +773,30 @@ trait AbstractZ3Solver def rec(t : Z3AST) : Expr = z3.getASTKind(t) match { case Z3AppAST(decl, args) => { val argsSize = args.size - if(isKnownDecl(decl)) { - val fd = functionDeclToDef(decl) - FunctionInvocation(fd, args.map(rec)) + if(functions containsZ3 decl) { + val tfd = functions.toLeon(decl) + FunctionInvocation(tfd, args.map(rec)) } else if(argsSize == 1 && reverseADTTesters.isDefinedAt(decl)) { - CaseClassInstanceOf(reverseADTTesters(decl), rec(args(0))) + val cct = reverseADTTesters(decl) + CaseClassInstanceOf(cct, rec(args(0))) } else if(argsSize == 1 && reverseADTFieldSelectors.isDefinedAt(decl)) { - val (ccd, fid) = reverseADTFieldSelectors(decl) - CaseClassSelector(ccd, rec(args(0)), fid) + val (cct, fid) = reverseADTFieldSelectors(decl) + CaseClassSelector(cct, rec(args(0)), fid) } else if(reverseADTConstructors.isDefinedAt(decl)) { - val ccd = reverseADTConstructors(decl) - CaseClass(ccd, args.map(rec)) - } else if(reverseTupleConstructors.isDefinedAt(decl)) { - Tuple(args.map(rec)) + val cct = reverseADTConstructors(decl) + CaseClass(cct, args.map(rec)) } else { - import Z3DeclKind._ - z3.getDeclKind(decl) match { - case OpTrue => BooleanLiteral(true) - case OpFalse => BooleanLiteral(false) - case _ => throw e + z3.getSort(t) match { + case LeonType(t : TupleType) => + Tuple(args.map(rec)) + + case _ => + import Z3DeclKind._ + z3.getDeclKind(decl) match { + case OpTrue => BooleanLiteral(true) + case OpFalse => BooleanLiteral(false) + case _ => throw e + } } } } @@ -786,9 +811,9 @@ trait AbstractZ3Solver } } - protected[leon] def softFromZ3Formula(model: Z3Model, tree : Z3AST, expectedType: TypeTree) : Option[Expr] = { + protected[leon] def softFromZ3Formula(model: Z3Model, tree : Z3AST) : Option[Expr] = { try { - Some(fromZ3Formula(model, tree, Some(expectedType))) + Some(fromZ3Formula(model, tree)) } catch { case e: CantTranslateException => None } diff --git a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala index 71589d58fa81ede5a40c70afc928ca57d4d40fa8..b3952e4fd4477fdca29f6a96191929b8572578af 100644 --- a/src/main/scala/leon/solvers/z3/FairZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/FairZ3Solver.scala @@ -71,57 +71,32 @@ class FairZ3Solver(val context : LeonContext, val program: Program) ) toggleWarningMessages(true) - def isKnownDef(funDef: FunDef) : Boolean = functionMap.isDefinedAt(funDef) - - def functionDefToDecl(funDef: FunDef) : Z3FuncDecl = - functionMap.getOrElse(funDef, scala.sys.error("No Z3 definition found for function symbol " + funDef.id.name + ".")) - - def isKnownDecl(decl: Z3FuncDecl) : Boolean = reverseFunctionMap.isDefinedAt(decl) - - def functionDeclToDef(decl: Z3FuncDecl) : FunDef = - reverseFunctionMap.getOrElse(decl, scala.sys.error("No FunDef corresponds to Z3 definition " + decl + ".")) - - private var functionMap: Map[FunDef, Z3FuncDecl] = Map.empty - private var reverseFunctionMap: Map[Z3FuncDecl, FunDef] = Map.empty - private var axiomatizedFunctions : Set[FunDef] = Set.empty - - protected[leon] def prepareFunctions: Unit = { - functionMap = Map.empty - reverseFunctionMap = Map.empty - for (funDef <- program.definedFunctions) { - val sortSeq = funDef.args.map(vd => typeToSort(vd.tpe)) - val returnSort = typeToSort(funDef.returnType) - - val z3Decl = z3.mkFreshFuncDecl(funDef.id.name, sortSeq, returnSort) - functionMap = functionMap + (funDef -> z3Decl) - reverseFunctionMap = reverseFunctionMap + (z3Decl -> funDef) - } - } private def validateModel(model: Z3Model, formula: Expr, variables: Set[Identifier], silenceErrors: Boolean) : (Boolean, Map[Identifier,Expr]) = { if(!interrupted) { val functionsModel: Map[Z3FuncDecl, (Seq[(Seq[Z3AST], Z3AST)], Z3AST)] = model.getModelFuncInterpretations.map(i => (i._1, (i._2, i._3))).toMap val functionsAsMap: Map[Identifier, Expr] = functionsModel.flatMap(p => { - if(isKnownDecl(p._1)) { - val fd = functionDeclToDef(p._1) - if(!fd.hasImplementation) { + if(functions containsZ3 p._1) { + val tfd = functions.toLeon(p._1) + if(!tfd.hasImplementation) { val (cses, default) = p._2 - val ite = cses.foldLeft(fromZ3Formula(model, default, Some(fd.returnType)))((expr, q) => IfExpr( + val ite = cses.foldLeft(fromZ3Formula(model, default))((expr, q) => IfExpr( And( - q._1.zip(fd.args).map(a12 => Equals(fromZ3Formula(model, a12._1, Some(a12._2.tpe)), Variable(a12._2.id))) + q._1.zip(tfd.args).map(a12 => Equals(fromZ3Formula(model, a12._1), Variable(a12._2.id))) ), - fromZ3Formula(model, q._2, Some(fd.returnType)), + fromZ3Formula(model, q._2), expr)) - Seq((fd.id, ite)) + Seq((tfd.id, ite)) } else Seq() } else Seq() }).toMap + val constantFunctionsAsMap: Map[Identifier, Expr] = model.getModelConstantInterpretations.flatMap(p => { - if(isKnownDecl(p._1)) { - val fd = functionDeclToDef(p._1) - if(!fd.hasImplementation) { - Seq((fd.id, fromZ3Formula(model, p._2, Some(fd.returnType)))) + if(functions containsZ3 p._1) { + val tfd = functions.toLeon(p._1) + if(!tfd.hasImplementation) { + Seq((tfd.id, fromZ3Formula(model, p._2))) } else Seq() } else Seq() }).toMap @@ -157,23 +132,23 @@ class FairZ3Solver(val context : LeonContext, val program: Program) } } - private val funDefTemplateCache : MutableMap[FunDef, FunctionTemplate] = MutableMap.empty - private val exprTemplateCache : MutableMap[Expr , FunctionTemplate] = MutableMap.empty + private val funDefTemplateCache : MutableMap[TypedFunDef, FunctionTemplate] = MutableMap.empty + private val exprTemplateCache : MutableMap[Expr , FunctionTemplate] = MutableMap.empty - private def getTemplate(funDef: FunDef): FunctionTemplate = { - funDefTemplateCache.getOrElse(funDef, { - val res = FunctionTemplate.mkTemplate(this, funDef, true) - funDefTemplateCache += funDef -> res + private def getTemplate(tfd: TypedFunDef): FunctionTemplate = { + funDefTemplateCache.getOrElse(tfd, { + val res = FunctionTemplate.mkTemplate(this, tfd, true) + funDefTemplateCache += tfd -> res res }) } private def getTemplate(body: Expr): FunctionTemplate = { exprTemplateCache.getOrElse(body, { - val fakeFunDef = new FunDef(FreshIdentifier("fake", true), body.getType, variablesOf(body).toSeq.map(id => VarDecl(id, id.getType))) + val fakeFunDef = new FunDef(FreshIdentifier("fake", true), Nil, body.getType, variablesOf(body).toSeq.map(id => VarDecl(id, id.getType))) fakeFunDef.body = Some(body) - val res = FunctionTemplate.mkTemplate(this, fakeFunDef, false) + val res = FunctionTemplate.mkTemplate(this, fakeFunDef.typed, false) exprTemplateCache += body -> res res }) @@ -209,7 +184,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) reporter.debug("--- "+gen) for (((bast), (gen, origGen, ast, fis)) <- entries) { - reporter.debug(". "+bast +" ~> "+fis.map(_.funDef.id)) + reporter.debug(". "+bast +" ~> "+fis.map(_.tfd.signature)) } } } @@ -250,14 +225,13 @@ class FairZ3Solver(val context : LeonContext, val program: Program) // define an activating boolean... val template = getTemplate(expr) - val z3args = for (vd <- template.funDef.args) yield { - exprToZ3Id.get(Variable(vd.id)) match { + val z3args = for (vd <- template.tfd.args) yield { + variables.getZ3(Variable(vd.id)) match { case Some(ast) => ast case None => val ast = idToFreshZ3Id(vd.id) - exprToZ3Id += Variable(vd.id) -> ast - z3IdToExpr += ast -> Variable(vd.id) + variables += Variable(vd.id) -> ast ast } } @@ -305,7 +279,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) var reintroducedSelf : Boolean = false for(fi <- fis) { - val template = getTemplate(fi.funDef) + val template = getTemplate(fi.tfd) val (newExprs, newBlocks) = template.instantiate(id, fi.args) for((i, fis2) <- newBlocks) { @@ -330,12 +304,6 @@ class FairZ3Solver(val context : LeonContext, val program: Program) val solver = z3.mkSolver - for(funDef <- program.definedFunctions) { - if (funDef.annotations.contains("axiomatize") && !axiomatizedFunctions(funDef)) { - reporter.warning("Function " + funDef.id + " was marked for axiomatization but could not be handled.") - } - } - private var varsInVC = Set[Identifier]() private var frameExpressions = List[List[Expr]](Nil) @@ -404,7 +372,7 @@ class FairZ3Solver(val context : LeonContext, val program: Program) val assumptionsAsZ3Set: Set[Z3AST] = assumptionsAsZ3.toSet def z3CoreToCore(core: Seq[Z3AST]): Set[Expr] = { - core.filter(assumptionsAsZ3Set).map(ast => fromZ3Formula(null, ast, None) match { + core.filter(assumptionsAsZ3Set).map(ast => fromZ3Formula(null, ast) match { case n @ Not(Variable(_)) => n case v @ Variable(_) => v case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") diff --git a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala index 59380d8750cc74552ac497abaffb379026b00ba0..67605150dd7c107a4580dfe91882d18678aaae4e 100644 --- a/src/main/scala/leon/solvers/z3/FunctionTemplate.scala +++ b/src/main/scala/leon/solvers/z3/FunctionTemplate.scala @@ -16,11 +16,11 @@ import z3.scala._ import scala.collection.mutable.{Set=>MutableSet,Map=>MutableMap} -case class Z3FunctionInvocation(funDef: FunDef, args: Seq[Z3AST]) +case class Z3FunctionInvocation(tfd: TypedFunDef, args: Seq[Z3AST]) class FunctionTemplate private( solver: FairZ3Solver, - val funDef : FunDef, + val tfd : TypedFunDef, activatingBool : Identifier, condVars : Set[Identifier], exprVars : Set[Identifier], @@ -29,14 +29,10 @@ class FunctionTemplate private( private def isTerminatingForAllInputs : Boolean = ( isRealFunDef - && !funDef.hasPrecondition - && solver.getTerminator.terminates(funDef).isGuaranteed + && !tfd.hasPrecondition + && solver.getTerminator.terminates(tfd.fd).isGuaranteed ) - // if(isRealFunDef) { - // println("Just created template for %s... Safe? %s".format(funDef.id.name, isTerminatingForAllInputs.toString)) - // } - private val z3 = solver.z3 private val asClauses : Seq[Expr] = { @@ -47,11 +43,11 @@ class FunctionTemplate private( val z3ActivatingBool = solver.idToFreshZ3Id(activatingBool) - private val z3FunDefArgs = funDef.args.map( ad => solver.idToFreshZ3Id(ad.id)) + private val z3FunDefArgs = tfd.args.map( ad => solver.idToFreshZ3Id(ad.id)) private val zippedCondVars = condVars.map(id => (id, solver.idToFreshZ3Id(id))) private val zippedExprVars = exprVars.map(id => (id, solver.idToFreshZ3Id(id))) - private val zippedFunDefArgs = funDef.args.map(_.id) zip z3FunDefArgs + private val zippedFunDefArgs = tfd.args.map(_.id) zip z3FunDefArgs val idToZ3Ids: Map[Identifier, Z3AST] = { Map(activatingBool -> z3ActivatingBool) ++ @@ -65,7 +61,7 @@ class FunctionTemplate private( } private val blockers : Map[Identifier,Set[FunctionInvocation]] = { - val idCall = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) + val idCall = FunctionInvocation(tfd, tfd.args.map(_.toVariable)) Map((for((b, es) <- guardedExprs) yield { val calls = es.foldLeft(Set.empty[FunctionInvocation])((s,e) => s ++ functionCallsOf(e)) - idCall @@ -79,14 +75,14 @@ class FunctionTemplate private( val z3Blockers: Map[Z3AST,Set[Z3FunctionInvocation]] = blockers.map { case (b, funs) => - (idToZ3Ids(b) -> funs.map(fi => Z3FunctionInvocation(fi.funDef, fi.args.map(solver.toZ3Formula(_, idToZ3Ids).get)))) + (idToZ3Ids(b) -> funs.map(fi => Z3FunctionInvocation(fi.tfd, fi.args.map(solver.toZ3Formula(_, idToZ3Ids).get)))) } // We use a cache to create the same boolean variables. private val cache : MutableMap[Seq[Z3AST],Map[Z3AST,Z3AST]] = MutableMap.empty def instantiate(aVar : Z3AST, args : Seq[Z3AST]) : (Seq[Z3AST], Map[Z3AST,Set[Z3FunctionInvocation]]) = { - assert(args.size == funDef.args.size) + assert(args.size == tfd.args.size) // The "isRealFunDef" part is to prevent evaluation of "fake" // function templates, as generated from FairZ3Solver. @@ -94,10 +90,10 @@ class FunctionTemplate private( val ga = args.view.map(solver.asGround) if(ga.forall(_.isDefined)) { val leonArgs = ga.map(_.get).force - val invocation = FunctionInvocation(funDef, leonArgs) + val invocation = FunctionInvocation(tfd, leonArgs) solver.getEvaluator.eval(invocation) match { case EvaluationResults.Successful(result) => - val z3Invocation = z3.mkApp(solver.functionDefToDecl(funDef), args: _*) + val z3Invocation = z3.mkApp(solver.functionDefToDecl(tfd), args: _*) val z3Value = solver.toZ3Formula(result).get val asZ3 = z3.mkEq(z3Invocation, z3Value) return (Seq(asZ3), Map.empty) @@ -140,7 +136,7 @@ class FunctionTemplate private( } override def toString : String = { - "Template for def " + funDef.id + "(" + funDef.args.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + funDef.returnType + " is :\n" + + "Template for def " + tfd.id + "(" + tfd.args.map(a => a.id + " : " + a.tpe).mkString(", ") + ") : " + tfd.returnType + " is :\n" + " * Activating boolean : " + activatingBool + "\n" + " * Control booleans : " + condVars.toSeq.map(_.toString).mkString(", ") + "\n" + " * Expression vars : " + exprVars.toSeq.map(_.toString).mkString(", ") + "\n" + @@ -152,7 +148,7 @@ class FunctionTemplate private( object FunctionTemplate { val splitAndOrImplies = false - def mkTemplate(solver: FairZ3Solver, funDef: FunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { + def mkTemplate(solver: FairZ3Solver, tfd: TypedFunDef, isRealFunDef : Boolean = true) : FunctionTemplate = { val condVars : MutableSet[Identifier] = MutableSet.empty val exprVars : MutableSet[Identifier] = MutableSet.empty @@ -312,11 +308,11 @@ object FunctionTemplate { } // The precondition if it exists. - val prec : Option[Expr] = funDef.precondition.map(p => matchToIfThenElse(p)) + val prec : Option[Expr] = tfd.precondition.map(p => matchToIfThenElse(p)) - val newBody : Option[Expr] = funDef.body.map(b => matchToIfThenElse(b)) + val newBody : Option[Expr] = tfd.body.map(b => matchToIfThenElse(b)) - val invocation : Expr = FunctionInvocation(funDef, funDef.args.map(_.toVariable)) + val invocation : Expr = FunctionInvocation(tfd, tfd.args.map(_.toVariable)) val invocationEqualsBody : Option[Expr] = newBody match { case Some(body) if isRealFunDef => @@ -343,12 +339,12 @@ object FunctionTemplate { } // Now the postcondition. - funDef.postcondition match { + tfd.postcondition match { case Some((id, post)) => val newPost : Expr = replace(Map(Variable(id) -> invocation), matchToIfThenElse(post)) val postHolds : Expr = - if(funDef.hasPrecondition) { + if(tfd.hasPrecondition) { Implies(prec.get, newPost) } else { newPost @@ -360,7 +356,7 @@ object FunctionTemplate { } - new FunctionTemplate(solver, funDef, activatingBool, Set(condVars.toSeq : _*), Set(exprVars.toSeq : _*), Map(guardedExprs.toSeq : _*), + new FunctionTemplate(solver, tfd, activatingBool, Set(condVars.toSeq : _*), Set(exprVars.toSeq : _*), Map(guardedExprs.toSeq : _*), isRealFunDef) } } diff --git a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala index 59b0f79fe1446c1923a86e5d7493922283fa4f1a..40ccf71dd104945ee46ca529d5a4023b35f401ed 100644 --- a/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/UninterpretedZ3Solver.scala @@ -36,24 +36,6 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) ) toggleWarningMessages(true) - private var functionMap: Map[FunDef, Z3FuncDecl] = Map.empty - private var reverseFunctionMap: Map[Z3FuncDecl, FunDef] = Map.empty - protected[leon] def prepareFunctions : Unit = { - functionMap = Map.empty - reverseFunctionMap = Map.empty - for(funDef <- program.definedFunctions) { - val sortSeq = funDef.args.map(vd => typeToSort(vd.tpe)) - val returnSort = typeToSort(funDef.returnType) - - val z3Decl = z3.mkFreshFuncDecl(funDef.id.name, sortSeq, returnSort) - functionMap = functionMap + (funDef -> z3Decl) - reverseFunctionMap = reverseFunctionMap + (z3Decl -> funDef) - } - } - protected[leon] def functionDefToDecl(funDef: FunDef) : Z3FuncDecl = functionMap(funDef) - protected[leon] def functionDeclToDef(decl: Z3FuncDecl) : FunDef = reverseFunctionMap(decl) - protected[leon] def isKnownDecl(decl: Z3FuncDecl) : Boolean = reverseFunctionMap.isDefinedAt(decl) - initZ3 val solver = z3.mkSolver @@ -67,13 +49,13 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) solver.pop(lvl) } - private var variables = Set[Identifier]() + private var freeVariables = Set[Identifier]() private var containsFunCalls = false def assertCnstr(expression: Expr) { - variables ++= variablesOf(expression) + freeVariables ++= variablesOf(expression) containsFunCalls ||= containsFunctionCalls(expression) - solver.assertCnstr(toZ3Formula(expression).get) + solver.assertCnstr(toZ3Formula(expression).getOrElse(scala.sys.error("Failed to compile to Z3: "+expression))) } override def check: Option[Boolean] = { @@ -91,16 +73,16 @@ class UninterpretedZ3Solver(val context : LeonContext, val program: Program) } override def checkAssumptions(assumptions: Set[Expr]): Option[Boolean] = { - variables ++= assumptions.flatMap(variablesOf(_)) + freeVariables ++= assumptions.flatMap(variablesOf(_)) solver.checkAssumptions(assumptions.toSeq.map(toZ3Formula(_).get) : _*) } def getModel = { - modelToMap(solver.getModel, variables) + modelToMap(solver.getModel, freeVariables) } def getUnsatCore = { - solver.getUnsatCore.map(ast => fromZ3Formula(null, ast, None) match { + solver.getUnsatCore.map(ast => fromZ3Formula(null, ast) match { case n @ Not(Variable(_)) => n case v @ Variable(_) => v case x => scala.sys.error("Impossible element extracted from core: " + ast + " (as Leon tree : " + x + ")") diff --git a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala index ab2d7084c5facbdb13e559e26fd0d03348b520c8..569b3a0930ed6100b694dc056dbe1963d7784cfc 100644 --- a/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala +++ b/src/main/scala/leon/solvers/z3/Z3ModelReconstruction.scala @@ -20,31 +20,18 @@ trait Z3ModelReconstruction { def modelValue(model: Z3Model, id: Identifier, tpe: TypeTree = null) : Option[Expr] = { val expectedType = if(tpe == null) id.getType else tpe - if(exprToZ3Id.isDefinedAt(id.toVariable)) { - val z3ID : Z3AST = exprToZ3Id(id.toVariable) - + variables.getZ3(id.toVariable).flatMap { z3ID => expectedType match { case BooleanType => model.evalAs[Boolean](z3ID).map(BooleanLiteral(_)) case Int32Type => model.evalAs[Int](z3ID).map(IntLiteral(_)) case other => model.eval(z3ID) match { case None => None - case Some(t) => softFromZ3Formula(model, t, expectedType) + case Some(t) => softFromZ3Formula(model, t) } } - } else None + } } - // def modelValue(model: Z3Model, id: Identifier, tpe: TypeTree = null) : Option[Expr] = { - // val expectedType = if(tpe == null) id.getType else tpe - // - // if(exprToZ3Id.isDefinedAt(id.toVariable)) { - // val z3ID : Z3AST = exprToZ3Id(id.toVariable) - - - // rec(z3ID, expectedType) - // } else None - // } - def modelToMap(model: Z3Model, ids: Iterable[Identifier]) : Map[Identifier,Expr] = { var asMap = Map.empty[Identifier,Expr] diff --git a/src/main/scala/leon/synthesis/FileInterface.scala b/src/main/scala/leon/synthesis/FileInterface.scala index 0cb67cc294c51dc116237e6001c9b9dbf5d21d47..29445a5ca99b2198612c421fe2c42ba640499191 100644 --- a/src/main/scala/leon/synthesis/FileInterface.scala +++ b/src/main/scala/leon/synthesis/FileInterface.scala @@ -59,8 +59,8 @@ class FileInterface(reporter: Reporter) { before + newCode + after - case _ => - sys.error("Substitution requires RangePos on the input tree: "+fromTree) + case p => + sys.error("Substitution requires RangePos on the input tree: "+fromTree +": "+fromTree.getClass+" GOT" +p) } } diff --git a/src/main/scala/leon/synthesis/SimpleSearch.scala b/src/main/scala/leon/synthesis/SimpleSearch.scala index cb2dd57c192e12ce36d6040ac45ede0a68a686bd..eed1436c5f5a3025ff02191a5b59405bcc8c9eb8 100644 --- a/src/main/scala/leon/synthesis/SimpleSearch.scala +++ b/src/main/scala/leon/synthesis/SimpleSearch.scala @@ -92,7 +92,7 @@ class SimpleSearch(synth: Synthesizer, } def fundefToSol(p: Problem, fd: FunDef): Solution = { - Solution(BooleanLiteral(true), Set(), FunctionInvocation(fd, p.as.map(Variable(_)))) + Solution(BooleanLiteral(true), Set(), FunctionInvocation(fd.typed, p.as.map(Variable(_)))) } def solToSubProgram(p: Problem, s: Solution): SubProgram = { @@ -138,7 +138,7 @@ class SimpleSearch(synth: Synthesizer, Map(Variable(p.xs.head) -> res) } - val fd = new FunDef(FreshIdentifier("chimp", true), ret, freshAs.map(id => VarDecl(id, id.getType))) + val fd = new FunDef(FreshIdentifier("chimp", true), Nil, ret, freshAs.map(id => VarDecl(id, id.getType))) fd.precondition = Some(replace(map, p.pc)) fd.postcondition = Some((res.id, replace(map++mapPost, p.phi))) diff --git a/src/main/scala/leon/synthesis/Synthesizer.scala b/src/main/scala/leon/synthesis/Synthesizer.scala index 9bb0bda005c5d86367441727add33ae10e052570..cd1355e7261929638be069cd696fbd02be046ac6 100644 --- a/src/main/scala/leon/synthesis/Synthesizer.scala +++ b/src/main/scala/leon/synthesis/Synthesizer.scala @@ -105,7 +105,7 @@ class Synthesizer(val context : LeonContext, import purescala.TypeTrees.TupleType import purescala.Definitions.VarDecl - val mainObject = program.mainObject + val mainModule = program.mainModule // Create new fundef for the body val ret = TupleType(problem.xs.map(_.getType)) @@ -116,14 +116,14 @@ class Synthesizer(val context : LeonContext, Variable(id) -> TupleSelect(res, i+1) }.toMap - val fd = new FunDef(FreshIdentifier("finalTerm", true), ret, problem.as.map(id => VarDecl(id, id.getType))) + val fd = new FunDef(FreshIdentifier("finalTerm", true), Nil, ret, problem.as.map(id => VarDecl(id, id.getType))) fd.precondition = Some(And(problem.pc, sol.pre)) fd.postcondition = Some((res.id, replace(mapPost, problem.phi))) fd.body = Some(sol.term) val newDefs = sol.defs + fd - val npr = program.copy(mainObject = mainObject.copy(defs = mainObject.defs ++ newDefs)) + val npr = program.copy(mainModule = mainModule.copy(defs = mainModule.defs ++ newDefs)) (npr, newDefs) } diff --git a/src/main/scala/leon/synthesis/condabd/Report.scala b/src/main/scala/leon/synthesis/condabd/Report.scala index caf0124592c8db3c09cc5f3b39e2cd54d3e56f16..34330f9a5cc6a8cf45d16ffe8f09c4512ee183ae 100755 --- a/src/main/scala/leon/synthesis/condabd/Report.scala +++ b/src/main/scala/leon/synthesis/condabd/Report.scala @@ -1,6 +1,6 @@ package leon.synthesis.condabd -import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ObjectDef } +import leon.purescala.Definitions.{ TypedFunDef, VarDecl, Program, ModuleDef } trait Report { def summaryString: String @@ -21,7 +21,7 @@ case object EmptyReport extends Report { override def isSuccess = false } -case class FullReport(val function: FunDef, val synthInfo: SynthesisInfo) extends Report { +case class FullReport(val function: TypedFunDef, val synthInfo: SynthesisInfo) extends Report { import SynthesisInfo.Action._ import Report._ diff --git a/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala b/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala index 4575df5de360f15d3997fbbd6fde405ad68815d6..3a768c27b5546ca5aaddd5930b724017fc86b1c1 100755 --- a/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala +++ b/src/main/scala/leon/synthesis/condabd/SynthesizerExamples.scala @@ -12,7 +12,7 @@ import leon.solvers.z3._ import leon.purescala.TypeTrees.{ TypeTree => LeonType, _ } import leon.purescala.Trees.{ Variable => LeonVariable, _ } -import leon.purescala.Definitions.{ FunDef, Program } +import leon.purescala.Definitions.{ TypedFunDef, FunDef, Program } import leon.purescala.Common.{ Identifier, FreshIdentifier } import leon.purescala.TreeOps @@ -47,10 +47,10 @@ class SynthesizerForRuleExamples( val mainSolver: SolverFactory[SynthesisSolver], val program: Program, val desiredType: LeonType, - val holeFunDef: FunDef, + val tfd: TypedFunDef, val problem: Problem, val synthesisContext: SynthesisContext, - val evaluationStrategy: EvaluationStrategy, // = DefaultEvaluationStrategy(program, holeFunDef, synthesisContext.context), + val evaluationStrategy: EvaluationStrategy, // = DefaultEvaluationStrategy(program, tfd, synthesisContext.context), // number of condition expressions to try before giving up on that branch expression numberOfBooleanSnippets: Int = 5, numberOfCounterExamplesToGenerate: Int = 5, @@ -78,7 +78,7 @@ class SynthesizerForRuleExamples( info("numberOfCounterExamplesToGenerate: %d".format(numberOfCounterExamplesToGenerate)) // info("leonTimeout: %d".format(leonTimeout)) - info("holeFunDef: %s".format(holeFunDef)) + info("holeFunDef: %s".format(tfd)) info("problem: %s".format(problem.toString)) // flag denoting if a correct body has been synthesized @@ -165,7 +165,7 @@ class SynthesizerForRuleExamples( info("####################################") info("######Iteration #" + iteration + " ###############") info("####################################") - info("# precondition is: " + holeFunDef.precondition.getOrElse(BooleanLiteral(true))) + info("# precondition is: " + tfd.precondition.getOrElse(BooleanLiteral(true))) info("# accumulatingCondition is: " + accumulatingCondition) info("# accumulatingExpression(Unit) is: " + accumulatingExpression(UnitLiteral)) info("####################################") @@ -206,10 +206,10 @@ class SynthesizerForRuleExamples( if (candidates.size > 0) { // save current precondition and the old body since it will can be mutated during evaluation - val oldPreconditionSaved = holeFunDef.precondition - val oldBodySaved = holeFunDef.body + val oldPreconditionSaved = tfd.precondition + val oldBodySaved = tfd.body // set initial precondition - holeFunDef.precondition = Some(initialPrecondition) + tfd.fd.precondition = Some(initialPrecondition) val ranker = evaluationStrategy.getRanker(candidates, accumulatingExpression, gatheredExamples) exampleRunner = evaluationStrategy.getExampleRunner @@ -238,8 +238,8 @@ class SynthesizerForRuleExamples( fine("Failed examples for the maximum candidate: " + examplesPartition) // restore original precondition and body - holeFunDef.precondition = oldPreconditionSaved - holeFunDef.body = oldBodySaved + tfd.fd.precondition = oldPreconditionSaved + tfd.fd.body = oldBodySaved // check for timeouts if (!keepGoing) break @@ -280,7 +280,7 @@ class SynthesizerForRuleExamples( synthInfo.iterations = iteration synthInfo.numberOfEnumeratedExpressions = numberOfTested info("We are done, in time: " + synthInfo.last) - return new FullReport(holeFunDef, synthInfo) + return new FullReport(tfd, synthInfo) } if (variableRefinedBranch) { @@ -322,7 +322,7 @@ class SynthesizerForRuleExamples( var ind = 0 while (ind < number && changed) { // analyze the program - val (solved, map) = analyzeFunction(holeFunDef) + val (solved, map) = analyzeFunction(tfd) // check if solver could solved this instance if (solved == false && !map.isEmpty) { @@ -396,15 +396,15 @@ class SynthesizerForRuleExamples( inSynthBoolean = new InSynth(allDeclarations, BooleanType, true) // funDef of the hole - fine("postcondition is: " + holeFunDef.postcondition.get) + fine("postcondition is: " + tfd.postcondition.get) fine("declarations we see: " + allDeclarations.map(_.toString).mkString("\n")) // interactivePause // accumulate precondition for the remaining branch to synthesize accumulatingCondition = BooleanLiteral(true) // save initial precondition - initialPrecondition = holeFunDef.precondition.getOrElse(BooleanLiteral(true)) - val holeFunDefBody = holeFunDef.body.get + initialPrecondition = tfd.precondition.getOrElse(BooleanLiteral(true)) + val holeFunDefBody = tfd.body.get // accumulate the final expression of the hole accumulatingExpression = (finalExp: Expr) => { def replaceChoose(expr: Expr) = expr match { @@ -430,7 +430,7 @@ class SynthesizerForRuleExamples( // loader.variableDeclarations, loader.classMap, mainSolver, reporter) // calculate cases that should not happen - refiner = new Filter(program, holeFunDef, variableRefiner) + refiner = new Filter(program, tfd, variableRefiner) gatheredExamples = ArrayBuffer(introduceExamples().map(Example(_)): _*) fine("Introduced examples: " + gatheredExamples.mkString("\t")) @@ -439,22 +439,22 @@ class SynthesizerForRuleExamples( def tryToSynthesizeBranch(snippetTree: Expr, examplesPartition: (Seq[Example], Seq[Example])): Boolean = { val (succeededExamples, failedExamples) = examplesPartition // replace hole in the body with the whole if-then-else structure, with current snippet tree - val oldBody = holeFunDef.body.get + val oldBody = tfd.body.get val newBody = accumulatingExpression(snippetTree) - holeFunDef.body = Some(newBody) + tfd.fd.body = Some(newBody) // precondition - val oldPrecondition = holeFunDef.precondition.getOrElse(BooleanLiteral(true)) - holeFunDef.precondition = Some(initialPrecondition) + val oldPrecondition = tfd.precondition.getOrElse(BooleanLiteral(true)) + tfd.fd.precondition = Some(initialPrecondition) snippetTree.setType(desiredType) //holeFunDef.getBody.setType(hole.desiredType) - info("Current candidate solution is:\n" + holeFunDef) + info("Current candidate solution is:\n" + tfd) if (failedExamples.isEmpty) { // check if solver could solved this instance - fine("Analyzing program for funDef:" + holeFunDef) - val (result, map) = analyzeFunction(holeFunDef) + fine("Analyzing program for funDef:" + tfd) + val (result, map) = analyzeFunction(tfd) info("Solver returned: " + result + " with CE " + map) if (result) { @@ -478,7 +478,7 @@ class SynthesizerForRuleExamples( var preconditionToRestore = Some(oldPrecondition) // because first initial test - holeFunDef.precondition = preconditionToRestore + tfd.fd.precondition = preconditionToRestore // get counterexamples // info("Going to generating counterexamples: " + holeFunDef) @@ -567,9 +567,9 @@ class SynthesizerForRuleExamples( } // try finally { // set these to the FunDef - holeFunDef.precondition = preconditionToRestore + tfd.fd.precondition = preconditionToRestore // restore old body (we accumulate expression) - holeFunDef.body = Some(oldBody) + tfd.fd.body = Some(oldBody) } } @@ -607,14 +607,14 @@ class SynthesizerForRuleExamples( // TODO take care of this mess val newFunId = FreshIdentifier("tempIntroducedFunction22") - val newFun = new FunDef(newFunId, holeFunDef.returnType, holeFunDef.args) + val newFun = new FunDef(newFunId, tfd.fd.tparams, tfd.fd.returnType, tfd.fd.args) // newFun.precondition = Some(newCondition) newFun.precondition = Some(initialPrecondition) - newFun.postcondition = holeFunDef.postcondition + newFun.postcondition = tfd.fd.postcondition def replaceFunDef(expr: Expr) = expr match { - case FunctionInvocation(`holeFunDef`, args) => - Some(FunctionInvocation(newFun, args)) + case FunctionInvocation(`tfd`, args) => + Some(FunctionInvocation(newFun.typed(tfd.tps), args)) case _ => None } @@ -644,8 +644,8 @@ class SynthesizerForRuleExamples( finest("New fun for Error evaluation: " + newFun) // println("new candidate: " + newBody) - val newProgram = program.copy(mainObject = - program.mainObject.copy(defs = newFun +: program.mainObject.defs )) + val newProgram = program.copy(mainModule = + program.mainModule.copy(defs = newFun +: program.mainModule.defs )) // println("new program: " + newProgram) val _evaluator = new CodeGenEvaluator(synthesisContext.context, newProgram @@ -681,14 +681,14 @@ class SynthesizerForRuleExamples( // throw new RuntimeException("should not go here") // TODO take care of this mess val newFunId = FreshIdentifier("tempIntroducedFunction22") - val newFun = new FunDef(newFunId, holeFunDef.returnType, holeFunDef.args) + val newFun = new FunDef(newFunId, tfd.fd.tparams, tfd.fd.returnType, tfd.fd.args) // newFun.precondition = Some(newCondition) newFun.precondition = Some(initialPrecondition) - newFun.postcondition = holeFunDef.postcondition + newFun.postcondition = tfd.fd.postcondition def replaceFunDef(expr: Expr) = expr match { - case FunctionInvocation(`holeFunDef`, args) => - Some(FunctionInvocation(newFun, args)) + case FunctionInvocation(`tfd`, args) => + Some(FunctionInvocation(newFun.typed(tfd.tps), args)) case _ => None } @@ -716,8 +716,8 @@ class SynthesizerForRuleExamples( finest("New fun for Error evaluation: " + newFun) // println("new candidate: " + newBody) - val newProgram = program.copy(mainObject = - program.mainObject.copy(defs = newFun +: program.mainObject.defs )) + val newProgram = program.copy(mainModule = + program.mainModule.copy(defs = newFun +: program.mainModule.defs )) // println("new program: " + newProgram) val _evaluator = new CodeGenEvaluator(synthesisContext.context, newProgram, @@ -747,10 +747,10 @@ class SynthesizerForRuleExamples( if (!implyCounterExamples) { // if expression implies counterexamples add it to the precondition and try to validate program - holeFunDef.precondition = Some(newPathCondition) + tfd.fd.precondition = Some(newPathCondition) // do analysis - val (valid, map) = analyzeFunction(holeFunDef) + val (valid, map) = analyzeFunction(tfd) // program is valid, we have a branch if (valid) { // we found a branch diff --git a/src/main/scala/leon/synthesis/condabd/evaluation/CodeGenExampleRunner.scala b/src/main/scala/leon/synthesis/condabd/evaluation/CodeGenExampleRunner.scala index 02efa3d24dc73015acd5889a8680541a38ab8c62..f77ee5683ed96e68741c7e9801f87acdc0e74bc0 100644 --- a/src/main/scala/leon/synthesis/condabd/evaluation/CodeGenExampleRunner.scala +++ b/src/main/scala/leon/synthesis/condabd/evaluation/CodeGenExampleRunner.scala @@ -7,7 +7,7 @@ import leon._ import leon.evaluators._ import leon.evaluators.EvaluationResults._ import leon.purescala.Trees._ -import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ObjectDef } +import leon.purescala.Definitions.{ TypedFunDef, FunDef, VarDecl, Program, ModuleDef } import leon.purescala.Common.{ Identifier, FreshIdentifier } import leon.purescala.TreeOps import leon.codegen.CodeGenParams @@ -17,7 +17,7 @@ import ranking._ import _root_.insynth.util.logging.HasLogger -case class CodeGenExampleRunner(program: Program, funDef: FunDef, ctx: LeonContext, +case class CodeGenExampleRunner(program: Program, tfd: TypedFunDef, ctx: LeonContext, candidates: Seq[Candidate], inputExamples: Seq[Example], params: CodeGenParams = CodeGenParams(maxFunctionInvocations = 200, checkContracts = true)) extends ExampleRunner(inputExamples) with HasLogger { @@ -33,7 +33,7 @@ case class CodeGenExampleRunner(program: Program, funDef: FunDef, ctx: LeonConte examples.map( ex => { val map = ex.map - for(id <- funDef.args.map(_.id)) yield + for(id <- tfd.args.map(_.id)) yield map(id) } ) @@ -44,7 +44,7 @@ case class CodeGenExampleRunner(program: Program, funDef: FunDef, ctx: LeonConte StopwatchCollections.get("Compilation").newStopwatch profile getEvaluator.compile(expr, ids).get } - val candidateClosures = candidates.map(cand => compile(cand.prepareExpression, funDef.args.map(_.id))) + val candidateClosures = candidates.map(cand => compile(cand.prepareExpression, tfd.args.map(_.id))) override def evaluate(candidateInd: Int, exampleInd: Int) = { val closure = candidateClosures(candidateInd) @@ -64,24 +64,24 @@ case class CodeGenExampleRunner(program: Program, funDef: FunDef, ctx: LeonConte def evaluate(expr: Expr, args: Seq[Expr]) { fine("to evaluate: " + expr + " for: " + args) - val closure = compile(expr, funDef.args.map(_.id)) + val closure = compile(expr, tfd.args.map(_.id)) evaluate(closure, args) } override def evaluate(expr: Expr, mapping: Map[Identifier, Expr]) = { fine("to evaluate: " + expr + " for mapping: " + mapping) - val closure = compile(expr, funDef.args.map(_.id)) + val closure = compile(expr, tfd.args.map(_.id)) - evaluate(closure, funDef.args.map(arg => mapping(arg.id))) + evaluate(closure, tfd.args.map(arg => mapping(arg.id))) } override def evaluateToResult(expr: Expr, mapping: Map[Identifier, Expr]): Result = { fine("to evaluate: " + expr + " for mapping: " + mapping) - val closure = compile(expr, funDef.args.map(_.id)) + val closure = compile(expr, tfd.args.map(_.id)) - closure(funDef.args.map(arg => mapping(arg.id))) + closure(tfd.args.map(arg => mapping(arg.id))) } def evaluate(evalClosure: Seq[Expr] => Result, args: Seq[Expr]) = { @@ -106,7 +106,7 @@ case class CodeGenExampleRunner(program: Program, funDef: FunDef, ctx: LeonConte entering("filter(" + prec + ")") fine("Old counterExamples.size: " + examples.size) - val closure = compile(prec, funDef.args.map(_.id)) + val closure = compile(prec, tfd.args.map(_.id)) val (newTransformed, newExamples) = ((_examples zip examples) filter { case ((transformedExample, _)) => @@ -123,7 +123,7 @@ case class CodeGenExampleRunner(program: Program, funDef: FunDef, ctx: LeonConte override def countPassed(expressionToCheck: Expr) = { fine("expressionToCheck: " + expressionToCheck) - val closure = compile(expressionToCheck, funDef.args.map(_.id)) + val closure = compile(expressionToCheck, tfd.args.map(_.id)) val (passed, failed) = (_examples zip examples).partition( pair => evaluate(closure, pair._1) diff --git a/src/main/scala/leon/synthesis/condabd/evaluation/DefaultExampleRunner.scala b/src/main/scala/leon/synthesis/condabd/evaluation/DefaultExampleRunner.scala index e6930d7b5ba86749e2fb5c8ef6c66e6475af6804..1e751ab3ee9cf3746a5bac925333f83ea749471e 100644 --- a/src/main/scala/leon/synthesis/condabd/evaluation/DefaultExampleRunner.scala +++ b/src/main/scala/leon/synthesis/condabd/evaluation/DefaultExampleRunner.scala @@ -7,7 +7,7 @@ import leon._ import leon.evaluators._ import leon.evaluators.EvaluationResults._ import leon.purescala.Trees._ -import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ObjectDef } +import leon.purescala.Definitions.{ TypedFunDef, VarDecl, Program, ModuleDef } import leon.purescala.Common.{ Identifier, FreshIdentifier } import leon.purescala.TreeOps @@ -16,7 +16,7 @@ import ranking._ import _root_.insynth.util.logging.HasLogger -case class DefaultExampleRunner(program: Program, funDef: FunDef, ctx: LeonContext, +case class DefaultExampleRunner(program: Program, tfd: TypedFunDef, ctx: LeonContext, candidates: Seq[Candidate], inputExamples: Seq[Example], maxSteps: Int = 2000) extends ExampleRunner(inputExamples) with HasLogger { @@ -45,7 +45,7 @@ case class DefaultExampleRunner(program: Program, funDef: FunDef, ctx: LeonConte } def evaluate(expr: Expr, args: Seq[Expr]): Boolean = { - evaluate(expr, funDef.args.map(_.id).zip(args).toMap) + evaluate(expr, tfd.args.map(_.id).zip(args).toMap) } override def evaluateToResult(expr: Expr, mapping: Map[Identifier, Expr]) = { diff --git a/src/main/scala/leon/synthesis/condabd/evaluation/EvaluationStrategy.scala b/src/main/scala/leon/synthesis/condabd/evaluation/EvaluationStrategy.scala index ee0e54ff3bfb00f9973c61dc102ae559a3d6c7e3..4516fbbf68541381ffa0d25ed6ab468b40628027 100644 --- a/src/main/scala/leon/synthesis/condabd/evaluation/EvaluationStrategy.scala +++ b/src/main/scala/leon/synthesis/condabd/evaluation/EvaluationStrategy.scala @@ -5,7 +5,7 @@ import leon._ import leon.evaluators._ import leon.evaluators.EvaluationResults._ import leon.purescala.Trees._ -import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ObjectDef } +import leon.purescala.Definitions.{ TypedFunDef, VarDecl, Program, ModuleDef } import leon.purescala.Common.{ Identifier, FreshIdentifier } import leon.purescala.TreeOps import leon.codegen.CodeGenParams @@ -40,7 +40,7 @@ trait EvaluationStrategy extends HasLogger { } -case class DefaultEvaluationStrategy(program: Program, funDef: FunDef, ctx: LeonContext, +case class DefaultEvaluationStrategy(program: Program, tfd: TypedFunDef, ctx: LeonContext, maxSteps: Int = 2000) extends EvaluationStrategy with HasLogger { var exampleRunner: ExampleRunner = _ @@ -49,9 +49,9 @@ case class DefaultEvaluationStrategy(program: Program, funDef: FunDef, ctx: Leon override def getRanker(candidatePairs: IndexedSeq[Output], bodyBuilder: (Expr) => Expr, inputExamples: Seq[Example]) = { - val candidates = Candidate.makeDefaultCandidates(candidatePairs, bodyBuilder, funDef) + val candidates = Candidate.makeDefaultCandidates(candidatePairs, bodyBuilder, tfd) - exampleRunner = DefaultExampleRunner(program, funDef, ctx, candidates, inputExamples) + exampleRunner = DefaultExampleRunner(program, tfd, ctx, candidates, inputExamples) logCounts(candidates, inputExamples) @@ -78,7 +78,7 @@ case class DefaultEvaluationStrategy(program: Program, funDef: FunDef, ctx: Leon override def getEvaluation = evaluation } -case class CodeGenEvaluationStrategy(program: Program, funDef: FunDef, ctx: LeonContext, +case class CodeGenEvaluationStrategy(program: Program, tfd: TypedFunDef, ctx: LeonContext, maxSteps: Int = 200) extends EvaluationStrategy with HasLogger { var exampleRunner: ExampleRunner = _ @@ -87,16 +87,16 @@ case class CodeGenEvaluationStrategy(program: Program, funDef: FunDef, ctx: Leon override def getRanker(candidatePairs: IndexedSeq[Output], bodyBuilder: (Expr) => Expr, inputExamples: Seq[Example]) = { - val candidates = Candidate.makeCodeGenCandidates(candidatePairs, bodyBuilder, funDef) + val candidates = Candidate.makeCodeGenCandidates(candidatePairs, bodyBuilder, tfd) - val newProgram = program.copy(mainObject = program.mainObject.copy(defs = program.mainObject.defs ++ candidates.map(_.newFunDef))) + val newProgram = program.copy(mainModule = program.mainModule.copy(defs = program.mainModule.defs ++ candidates.map(_.newFunDef))) finest("New program looks like: " + newProgram) finest("Candidates look like: " + candidates.map(_.prepareExpression).mkString("\n")) val params = CodeGenParams(maxFunctionInvocations = maxSteps, checkContracts = true) - exampleRunner = CodeGenExampleRunner(newProgram, funDef, ctx, candidates, inputExamples, params) + exampleRunner = CodeGenExampleRunner(newProgram, tfd, ctx, candidates, inputExamples, params) logCounts(candidates, inputExamples) diff --git a/src/main/scala/leon/synthesis/condabd/examples/InputExamples.scala b/src/main/scala/leon/synthesis/condabd/examples/InputExamples.scala index e4af7b2f3ce50a8f22e58d07f4485642169f2747..01c6d69cdfb784600d63e17ab9924f2cc7c807ff 100755 --- a/src/main/scala/leon/synthesis/condabd/examples/InputExamples.scala +++ b/src/main/scala/leon/synthesis/condabd/examples/InputExamples.scala @@ -66,9 +66,9 @@ object InputExamples { // list type val ct = argumentIds(1).getType.asInstanceOf[ClassType] - val setSubclasses = loader.directSubclassesMap(ct).map(_.asInstanceOf[CaseClassType].classDef) + val setSubclasses = loader.directSubclassesMap(ct).map(_.asInstanceOf[CaseClassType]) - val (nilClassSet, consClassSet) = setSubclasses.partition(_.fieldsIds.size == 0) + val (nilClassSet, consClassSet) = setSubclasses.partition(_.fields.size == 0) val nilClass = nilClassSet.head val consClass = consClassSet.head @@ -95,9 +95,9 @@ object InputExamples { // list type val ct = argumentIds(0).getType.asInstanceOf[ClassType] - val setSubclasses = loader.directSubclassesMap(ct).map(_.asInstanceOf[CaseClassType].classDef) + val setSubclasses = loader.directSubclassesMap(ct).map(_.asInstanceOf[CaseClassType]) - val (nilClassSet, consClassSet) = setSubclasses.partition(_.fieldsIds.size == 0) + val (nilClassSet, consClassSet) = setSubclasses.partition(_.fields.size == 0) val nilClass = nilClassSet.head val consClass = consClassSet.head @@ -124,9 +124,9 @@ object InputExamples { goalType match { case ct: ClassType => - val setSubclasses = loader.directSubclassesMap(ct).map(_.asInstanceOf[CaseClassType].classDef) + val setSubclasses = loader.directSubclassesMap(ct).map(_.asInstanceOf[CaseClassType]) - val (nilClassSet, consClassSet) = setSubclasses.partition(_.fieldsIds.size == 0) + val (nilClassSet, consClassSet) = setSubclasses.partition(_.fields.size == 0) val nilClass = nilClassSet.head val consClass = consClassSet.head diff --git a/src/main/scala/leon/synthesis/condabd/insynth/leon/CommonTypes.scala b/src/main/scala/leon/synthesis/condabd/insynth/leon/CommonTypes.scala index 50c78fac03184df7bad4dddf70432da64f604710..f579f3f1967072b56668cdc01a51d7f0986cf0ed 100644 --- a/src/main/scala/leon/synthesis/condabd/insynth/leon/CommonTypes.scala +++ b/src/main/scala/leon/synthesis/condabd/insynth/leon/CommonTypes.scala @@ -8,7 +8,7 @@ import leon.purescala.Definitions.AbstractClassDef object CommonTypes { - val LeonBottomType = AbstractClassType(new AbstractClassDef(FreshIdentifier("$IDontCare$"))) + val LeonBottomType = AbstractClassType(new AbstractClassDef(FreshIdentifier("$IDontCare$"), Nil, None), Nil) val InSynthBottomType = Const("$IDontCare$") -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/condabd/insynth/leon/DomainTypeTransformer.scala b/src/main/scala/leon/synthesis/condabd/insynth/leon/DomainTypeTransformer.scala index 1ff79a4f9639ff0fa0fb84000bd778a06acc84f2..78e6e77fd7241fc8da33a0540c96806db0485ce5 100644 --- a/src/main/scala/leon/synthesis/condabd/insynth/leon/DomainTypeTransformer.scala +++ b/src/main/scala/leon/synthesis/condabd/insynth/leon/DomainTypeTransformer.scala @@ -12,7 +12,7 @@ object DomainTypeTransformer extends ( LeonType => DomainType ) { val InSynthTypeTransformer = TypeTransformer - def apply(typeDef: ClassTypeDef): DomainType = { + def apply(typeDef: ClassDef): DomainType = { implicit def singletonList(x: DomainType) = List(x) typeDef match { @@ -53,4 +53,4 @@ object DomainTypeTransformer extends ( LeonType => DomainType ) { Function( params map this, this(t) ) } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/condabd/insynth/leon/TypeTransformer.scala b/src/main/scala/leon/synthesis/condabd/insynth/leon/TypeTransformer.scala index 9ddff7a91c6301ec1f706f667cb2cbebaf1e79ad..18332e270a969a8a245605eaa43bb4e648a9eec1 100644 --- a/src/main/scala/leon/synthesis/condabd/insynth/leon/TypeTransformer.scala +++ b/src/main/scala/leon/synthesis/condabd/insynth/leon/TypeTransformer.scala @@ -12,7 +12,7 @@ import scala.language.implicitConversions object TypeTransformer extends ( LeonType => SuccinctType ) { - def apply(typeDef: ClassTypeDef): SuccinctType = { + def apply(typeDef: ClassDef): SuccinctType = { implicit def singletonList(x: SuccinctType) = List(x) typeDef match { @@ -53,4 +53,4 @@ object TypeTransformer extends ( LeonType => SuccinctType ) { Arrow( TSet(params map this distinct), this(t) ) } -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/condabd/insynth/leon/loader/DeclarationFactory.scala b/src/main/scala/leon/synthesis/condabd/insynth/leon/loader/DeclarationFactory.scala index 27e1294e53c3d1d767d406ec77d8ccd64091223d..141a13934312e55af18d8859f293878eefe6e8ab 100644 --- a/src/main/scala/leon/synthesis/condabd/insynth/leon/loader/DeclarationFactory.scala +++ b/src/main/scala/leon/synthesis/condabd/insynth/leon/loader/DeclarationFactory.scala @@ -44,7 +44,7 @@ object DeclarationFactory { Declaration(ImmediateExpression(varDecl.id, Variable(varDecl.id)), inSynthType, leonType) } - def makeInheritance(from: ClassTypeDef, to: ClassTypeDef) = { + def makeInheritance(from: ClassDef, to: ClassDef) = { val expr = UnaryReconstructionExpression("[" + from.id.name + "=>" + to.id.name + "]", identity[Expr] _) val inSynthType = Arrow(TSet(TypeTransformer(from)), TypeTransformer(to)) @@ -81,4 +81,4 @@ object DeclarationFactory { // define this for abstract declarations def getAbsExpression(inSynthType: InSynthType) = ErrorExpression -} \ No newline at end of file +} diff --git a/src/main/scala/leon/synthesis/condabd/insynth/leon/loader/Loader.scala b/src/main/scala/leon/synthesis/condabd/insynth/leon/loader/Loader.scala index f703f8ae0534be848c74c43c76b27776b3f513c8..16638be4ee73fa64cac800b4817ec51f395172d7 100644 --- a/src/main/scala/leon/synthesis/condabd/insynth/leon/loader/Loader.scala +++ b/src/main/scala/leon/synthesis/condabd/insynth/leon/loader/Loader.scala @@ -9,7 +9,7 @@ import leon.purescala.Definitions.{ Program, FunDef } import leon.purescala.TypeTrees.{ TypeTree => LeonType, _ } import leon.purescala.Trees.{ Expr, FunctionInvocation, _ } import leon.purescala.Common.{ Identifier } -import leon.purescala.Definitions.{ AbstractClassDef, CaseClassDef, ClassTypeDef } +import leon.purescala.Definitions.{ AbstractClassDef, CaseClassDef, ClassDef } // enable postfix operations import scala.language.postfixOps @@ -42,10 +42,12 @@ case class LeonLoader(program: Program, variables: List[Identifier], loadArithme // add function declarations for( funDef <- program.definedFunctions ) { - val leonFunctionType = FunctionType(funDef.args map { _.tpe } toList, funDef.returnType) + val tfd = funDef.typed(funDef.tparams.map(_.tp)) + + val leonFunctionType = tfd.functionType val newDeclaration = makeDeclaration( - NaryReconstructionExpression( funDef.id, { args: List[Expr] => FunctionInvocation(funDef, args) } ), + NaryReconstructionExpression(tfd.fd.id, { args: List[Expr] => FunctionInvocation(tfd, args) } ), leonFunctionType ) @@ -68,12 +70,12 @@ case class LeonLoader(program: Program, variables: List[Identifier], loadArithme for (variable <- variables; variableType = variable.getType) variableType match { case variableClassType: CaseClassType => variableClassType.classDef match { - case cas@CaseClassDef(id, parent, fields) => + case cas @ CaseClassDef(id, tparams, parent, isObj) => fine("adding fields of variable " + variable) - for (field <- fields) + for (field <- cas.fields) list += makeDeclaration( ImmediateExpression( field.id.name , - CaseClassSelector(cas, variable.toVariable, field.id) ), + CaseClassSelector(CaseClassType(cas, tparams.map(_.tp)), variable.toVariable, field.id) ), field.id.getType ) case _ => @@ -103,8 +105,8 @@ case class LeonLoader(program: Program, variables: List[Identifier], loadArithme ( for (classDef <- program.definedClasses) yield classDef match { - case caseClassDef: CaseClassDef => ( classDef.id, CaseClassType(caseClassDef) ) - case absClassDef: AbstractClassDef => ( absClassDef.id, AbstractClassType(absClassDef) ) + case ccd: CaseClassDef => ( ccd.id, CaseClassType(ccd, ccd.tparams.map(_.tp)) ) + case acd: AbstractClassDef => ( acd.id, AbstractClassType(acd, acd.tparams.map(_.tp)) ) } ) toMap } @@ -112,7 +114,7 @@ case class LeonLoader(program: Program, variables: List[Identifier], loadArithme // TODO add anyref to all and all to bottom ??? def extractInheritances: Seq[Declaration] = { - def extractInheritancesRec(classDef: ClassTypeDef): List[Declaration] = + def extractInheritancesRec(classDef: ClassDef): List[Declaration] = classDef match { case abs: AbstractClassDef => Nil ++ @@ -122,7 +124,7 @@ case class LeonLoader(program: Program, variables: List[Identifier], loadArithme classMap(child.id), classMap(abs.id) ) ) ++ ( - for (child <- abs.knownChildren) + for (child <-abs.knownChildren) yield extractInheritancesRec(child) ).flatten case _ => @@ -137,15 +139,16 @@ case class LeonLoader(program: Program, variables: List[Identifier], loadArithme yield inheritance } - def extractFields(classDef: ClassTypeDef) = classDef match { + def extractFields(classDef: ClassDef) = classDef match { case abs: AbstractClassDef => // this case does not seem to work //abs.fields Seq.empty case cas: CaseClassDef => + val cct = CaseClassType(cas, cas.tparams.map(_.tp)) for (field <- cas.fields) yield makeDeclaration( - UnaryReconstructionExpression(field.id.name, { CaseClassSelector(cas, _: Expr, field.id) }), + UnaryReconstructionExpression(field.id.name, { CaseClassSelector(cct, _: Expr, field.id) }), FunctionType(List(classMap(cas.id)), field.id.getType)) } @@ -156,14 +159,14 @@ case class LeonLoader(program: Program, variables: List[Identifier], loadArithme } def extractCaseClasses: Seq[Declaration] = { - for (caseClassDef@CaseClassDef(id, parent, fields) <- program.definedClasses) - yield fields match { + for (caseClassDef @ CaseClassDef(id, tparams, parent, isObj) <- program.definedClasses) + yield caseClassDef.fields match { case Nil => makeDeclaration( - ImmediateExpression( id.name, { CaseClass(caseClassDef, Nil) } ), + ImmediateExpression( id.name, { CaseClass(CaseClassType(caseClassDef, caseClassDef.tparams.map(_.tp)), Nil) } ), classMap(id) ) - case _ => makeDeclaration( - NaryReconstructionExpression( id.name , { CaseClass(caseClassDef, _: List[Expr]) } ), + case fields => makeDeclaration( + NaryReconstructionExpression( id.name , { CaseClass(CaseClassType(caseClassDef, caseClassDef.tparams.map(_.tp)), _: List[Expr]) } ), FunctionType(fields map { _.id.getType } toList, classMap(id)) ) } @@ -178,10 +181,10 @@ case class LeonLoader(program: Program, variables: List[Identifier], loadArithme // } // } - for ( classDef@CaseClassDef(_, _, _) <- program.definedClasses filter { _.isInstanceOf[CaseClassDef] }; +for ( classDef <- program.definedClasses collect { case ccd: CaseClassDef => ccd }; if classDef.hasParent) yield makeDeclaration( - UnaryReconstructionExpression( "isInstance[" + classDef.id + "]", { CaseClassInstanceOf(classDef, _: Expr) + UnaryReconstructionExpression( "isInstance[" + classDef.id + "]", { CaseClassInstanceOf(CaseClassType(classDef, classDef.tparams.map(_.tp)), _: Expr) } ), FunctionType(List(classMap(classDef.parent.get.id)), BooleanType) ) diff --git a/src/main/scala/leon/synthesis/condabd/ranking/Candidate.scala b/src/main/scala/leon/synthesis/condabd/ranking/Candidate.scala index 938823f178147fe5ec77f5cf7ede0302cd1acb94..9871e7a7df693a27872634b4d17b6be9f917af0d 100644 --- a/src/main/scala/leon/synthesis/condabd/ranking/Candidate.scala +++ b/src/main/scala/leon/synthesis/condabd/ranking/Candidate.scala @@ -17,19 +17,19 @@ object Candidate { def getFreshResultVariable(tpe: TypeTree) = _freshResultVariable = FreshIdentifier("result", true).setType(tpe) - def makeDefaultCandidates(candidatePairs: IndexedSeq[Output], bodyBuilder: Expr => Expr, funDef: FunDef) = { - getFreshResultVariable(funDef.body.get.getType) + def makeDefaultCandidates(candidatePairs: IndexedSeq[Output], bodyBuilder: Expr => Expr, tfd: TypedFunDef) = { + getFreshResultVariable(tfd.returnType) candidatePairs map { pair => - DefaultCandidate(pair.getSnippet, bodyBuilder(pair.getSnippet), pair.getWeight, funDef) + DefaultCandidate(pair.getSnippet, bodyBuilder(pair.getSnippet), pair.getWeight, tfd) } } - def makeCodeGenCandidates(candidatePairs: IndexedSeq[Output], bodyBuilder: Expr => Expr, funDef: FunDef) = { - getFreshResultVariable(funDef.body.get.getType) + def makeCodeGenCandidates(candidatePairs: IndexedSeq[Output], bodyBuilder: Expr => Expr, tfd: TypedFunDef) = { + getFreshResultVariable(tfd.returnType) candidatePairs map { pair => - CodeGenCandidate(pair.getSnippet, bodyBuilder(pair.getSnippet), pair.getWeight, funDef) + CodeGenCandidate(pair.getSnippet, bodyBuilder(pair.getSnippet), pair.getWeight, tfd) } } } @@ -42,7 +42,7 @@ abstract class Candidate(weight: Weight) { def getWeight = weight } -case class DefaultCandidate(expr: Expr, bodyExpr: Expr, weight: Weight, holeFunDef: FunDef) +case class DefaultCandidate(expr: Expr, bodyExpr: Expr, weight: Weight, tfd: TypedFunDef) extends Candidate(weight) with HasLogger { import Candidate._ @@ -54,7 +54,7 @@ case class DefaultCandidate(expr: Expr, bodyExpr: Expr, weight: Weight, holeFunD assert(bodyExpr.getType != Untyped) val resFresh = freshResultVariable//.setType(expr.getType) - val (id, post) = holeFunDef.postcondition.get + val (id, post) = tfd.postcondition.get // body can contain (self-)recursive calls Let(resFresh, bodyExpr, @@ -64,15 +64,15 @@ case class DefaultCandidate(expr: Expr, bodyExpr: Expr, weight: Weight, holeFunD override def prepareExpression = { // set appropriate body to the function for the correct evaluation due to recursive calls - holeFunDef.body = Some(bodyExpr) + tfd.fd.body = Some(bodyExpr) -// finest("going to evaluate candidate for: " + holeFunDef) +// finest("going to evaluate candidate for: " + tfd) // finest("going to evaluate candidate for: " + expressionToEvaluate) expressionToEvaluate } } -case class CodeGenCandidate(expr: Expr, bodyExpr: Expr, weight: Weight, holeFunDef: FunDef) +case class CodeGenCandidate(expr: Expr, bodyExpr: Expr, weight: Weight, tfd: TypedFunDef) extends Candidate(weight) with HasLogger { import Candidate._ @@ -80,15 +80,16 @@ case class CodeGenCandidate(expr: Expr, bodyExpr: Expr, weight: Weight, holeFunD lazy val (expressionToEvaluate, newFunDef) = { import TreeOps._ + val fd = tfd.fd val newFunId = FreshIdentifier("tempIntroducedFunction") - val newFun = new FunDef(newFunId, holeFunDef.returnType, holeFunDef.args) - newFun.precondition = holeFunDef.precondition - newFun.postcondition = holeFunDef.postcondition + val newFun = new FunDef(newFunId, fd.tparams, fd.returnType, fd.args) + newFun.precondition = fd.precondition + newFun.postcondition = fd.postcondition def replaceFunDef(expr: Expr) = expr match { - case FunctionInvocation(`holeFunDef`, args) => - Some(FunctionInvocation(newFun, args)) + case FunctionInvocation(`tfd`, args) => + Some(FunctionInvocation(newFun.typed(tfd.tps), args)) case _ => None } val newBody = postMap(replaceFunDef)(bodyExpr) @@ -109,7 +110,7 @@ case class CodeGenCandidate(expr: Expr, bodyExpr: Expr, weight: Weight, holeFunD } override def prepareExpression = { -// finest("going to evaluate candidate for: " + holeFunDef) +// finest("going to evaluate candidate for: " + tfd) // finest("going to evaluate candidate for: " + expressionToEvaluate) expressionToEvaluate } diff --git a/src/main/scala/leon/synthesis/condabd/refinement/Filter.scala b/src/main/scala/leon/synthesis/condabd/refinement/Filter.scala index 4f7db1aa792c38b8c9a8bd28db564e700a4737e5..7916047444681aa37bed13b2ee59430a1805aa16 100755 --- a/src/main/scala/leon/synthesis/condabd/refinement/Filter.scala +++ b/src/main/scala/leon/synthesis/condabd/refinement/Filter.scala @@ -15,7 +15,7 @@ import insynth.util.logging.HasLogger /** * Class used for filtering out unnecessary candidates during the search */ -class Filter(program: Program, holeFunDef: FunDef, refiner: VariableRefiner) extends HasLogger { +class Filter(program: Program, holeFunDef: TypedFunDef, refiner: VariableRefiner) extends HasLogger { // caching of previously filtered expressions type FilterSet = HashSet[Expr] @@ -130,7 +130,7 @@ class Filter(program: Program, holeFunDef: FunDef, refiner: VariableRefiner) ext } def isUnecessaryInstanceOf(expr: Expr) = { - def isOfClassType(exp: Expr, classDef: ClassTypeDef) = + def isOfClassType(exp: Expr, classDef: ClassDef) = expr.getType match { case tpe: ClassType => tpe.classDef == classDef case _ => false @@ -141,13 +141,13 @@ class Filter(program: Program, holeFunDef: FunDef, refiner: VariableRefiner) ext // true case CaseClassInstanceOf(classDef, _: FunctionInvocation) => true - case CaseClassInstanceOf(classDef, innerExpr) - if isOfClassType(innerExpr, classDef) => + case CaseClassInstanceOf(cct, innerExpr) + if isOfClassType(innerExpr, cct.classDef) => true - case CaseClassInstanceOf(classDef, v@Variable(id)) => { + case CaseClassInstanceOf(cct, v@Variable(id)) => { val possibleTypes = refiner.getPossibleTypes(id) if (possibleTypes.size == 1) - possibleTypes.head.classDef == classDef + possibleTypes.head.classDef == cct.classDef else false } diff --git a/src/main/scala/leon/synthesis/condabd/refinement/VariableRefiner.scala b/src/main/scala/leon/synthesis/condabd/refinement/VariableRefiner.scala index 08c08d34f1c4913f1a1bf13315d6bd8a077bc907..e47b82d97aea1de6f337e3c6628fc45f0c86cced 100755 --- a/src/main/scala/leon/synthesis/condabd/refinement/VariableRefiner.scala +++ b/src/main/scala/leon/synthesis/condabd/refinement/VariableRefiner.scala @@ -49,14 +49,14 @@ trait VariableRefiner extends HasLogger { yield dec match { case Declaration(inSynthType, _, decClassType, imex @ ImmediateExpression(_, Variable(`id`))) => (( - newType.classDef match { - case newTypeCaseClassDef @ CaseClassDef(_, parent, fields) => - fine("matched case class def for refinement " + newTypeCaseClassDef) - for (field <- fields) + newType match { + case cct: CaseClassType => + fine("matched case class def for refinement " + cct) + for (field <- cct.fields) yield Declaration( ImmediateExpression(id.name + "." + field.id.name, - CaseClassSelector(newTypeCaseClassDef, imex.expr, field.id)), - TypeTransformer(field.id.getType), field.id.getType) + CaseClassSelector(cct, imex.expr, field.id)), + TypeTransformer(field.tpe), field.tpe) case _ => Seq.empty }): Seq[Declaration]) :+ Declaration(imex, TypeTransformer(newType), newType) diff --git a/src/main/scala/leon/synthesis/condabd/refinement/VariableRefinerExecution.scala b/src/main/scala/leon/synthesis/condabd/refinement/VariableRefinerExecution.scala index 989fd3ccfa890189732abdf90ca2659a634b8cec..7511e044dcb9020626b8e6c04e73535142022c4a 100755 --- a/src/main/scala/leon/synthesis/condabd/refinement/VariableRefinerExecution.scala +++ b/src/main/scala/leon/synthesis/condabd/refinement/VariableRefinerExecution.scala @@ -29,8 +29,8 @@ class VariableRefinerExecution(variableDeclarations: Seq[Declaration], // TODO use cd.knownDescendents? for (varDec <- variableDeclarations) { varDec match { - case Declaration(_, _, typeOfVar: ClassType, ImmediateExpression(_, IsTyped(Variable(id), AbstractClassType(cd)))) => - variableRefinements += (id -> MutableSet(cd.knownDescendents.map(classDefToClassType _): _*)) + case Declaration(_, _, typeOfVar: ClassType, ImmediateExpression(_, IsTyped(Variable(id), AbstractClassType(cd, tps)))) => + variableRefinements += (id -> MutableSet(cd.knownDescendents.map(classDefToClassType(_, tps)): _*)) case _ => } } @@ -44,16 +44,16 @@ class VariableRefinerExecution(variableDeclarations: Seq[Declaration], if (variables.size == 1) { val variable = variables.head variable match { - case oldId @ IsTyped(id, AbstractClassType(cd)) // do not try to refine if we already know a single type is possible + case oldId @ IsTyped(id, AbstractClassType(cd, tps)) // do not try to refine if we already know a single type is possible if variableRefinements(id).size > 1 => assert(variableRefinements(id).map(_.classDef) subsetOf cd.knownDescendents.toSet) val optCases = - for (dcd <- variableRefinements(id).map(_.classDef)) yield dcd match { - case ccd: CaseClassDef if ccd.fields.isEmpty => + for (cct <- variableRefinements(id)) yield cct match { + case cct: CaseClassType if cct.fields.isEmpty => - val testValue = CaseClass(ccd, Nil) + val testValue = CaseClass(cct, Nil) val conditionToEvaluate = And(Not(expr), condition) fine("Execute condition " + conditionToEvaluate + " on variable " + id + " as " + testValue) @@ -61,9 +61,9 @@ class VariableRefinerExecution(variableDeclarations: Seq[Declaration], case Successful(BooleanLiteral(false)) => fine("EvaluationSuccessful(false)") fine("Refining variableRefinements(id): " + variableRefinements(id)) - variableRefinements(id) -= classDefToClassType(ccd) + variableRefinements(id) -= cct fine("Refined variableRefinements(id): " + variableRefinements(id)) - Some(ccd) + Some(cct) case Successful(BooleanLiteral(true)) => fine("EvaluationSuccessful(true)") None diff --git a/src/main/scala/leon/synthesis/condabd/refinement/VariableRefinerStructure.scala b/src/main/scala/leon/synthesis/condabd/refinement/VariableRefinerStructure.scala index a76c8e9cfc73c19fafbfa734c4e66aa28e5df99f..b3a442a724dee061db660f5b2f8db7d008792625 100755 --- a/src/main/scala/leon/synthesis/condabd/refinement/VariableRefinerStructure.scala +++ b/src/main/scala/leon/synthesis/condabd/refinement/VariableRefinerStructure.scala @@ -44,7 +44,7 @@ class VariableRefinerStructure(directSubclassMap: Map[ClassType, Set[ClassType]] List((id, variableRefinements(id).toSet)) case _ => - Nil + Nil } // inspect the expression if some refinements can be done diff --git a/src/main/scala/leon/synthesis/condabd/refinement/VariableSolverRefiner.scala b/src/main/scala/leon/synthesis/condabd/refinement/VariableSolverRefiner.scala index 262d4528a76a3ef84abb78dd1bf778a7bb6dda2a..73406a3edeaa7594b3e063275478c1ff1d0ff7bc 100755 --- a/src/main/scala/leon/synthesis/condabd/refinement/VariableSolverRefiner.scala +++ b/src/main/scala/leon/synthesis/condabd/refinement/VariableSolverRefiner.scala @@ -31,30 +31,29 @@ class VariableSolverRefiner(directSubclassMap: Map[ClassType, Set[ClassType]], v if (variables.size == 1) { val variable = variables.head variable match { - case oldId@IsTyped(id, AbstractClassType(cd)) if variableRefinements(id).size > 1 => - + case oldId@IsTyped(id, AbstractClassType(cd, tps)) if variableRefinements(id).size > 1 => assert(variableRefinements(id).map(_.classDef) subsetOf cd.knownDescendents.toSet) //val optCases = for (dcd <- cd.knownDescendents.sortBy(_.id.name)) yield dcd match { - val optCases = for (dcd <- variableRefinements(id).map(_.classDef)) yield dcd match { - case ccd : CaseClassDef => + val optCases = for (cct <- variableRefinements(id)) yield cct match { + case cct : CaseClassType => fine("testing variable " + id + " with condition " + condition) - val toSat = And(condition, CaseClassInstanceOf(ccd, Variable(id))) + val toSat = And(condition, CaseClassInstanceOf(cct, Variable(id))) fine("checking satisfiability of: " + toSat) solver.solveSAT(toSat) match { case (Some(false), _) => - fine("variable cannot be of type " + ccd) + fine("variable cannot be of type " + cct) None case _ => - fine("variable can be of type " + ccd) - Some(ccd) + fine("variable can be of type " + cct) + Some(cct) } case _ => None } val cases = optCases.flatten - variableRefinements(id) = variableRefinements(id) & cases.map(CaseClassType(_)).toSet + variableRefinements(id) = variableRefinements(id) & cases.toSet assert(variableRefinements(id).size == cases.size) List((id, variableRefinements(id).toSet)) @@ -72,26 +71,23 @@ class VariableSolverRefiner(directSubclassMap: Map[ClassType, Set[ClassType]], v def refineProblem(p: Problem) = { val newAs = p.as.map { - case oldId @ IsTyped(id, AbstractClassType(cd)) => + case oldId @ IsTyped(id, act : AbstractClassType) => - val optCases = for (dcd <- cd.knownDescendents.sortBy(_.id.name)) yield dcd match { - case ccd: CaseClassDef => - val toSat = And(p.pc, CaseClassInstanceOf(ccd, Variable(id))) + val optCases = for (cct <- act.knownCCDescendents) yield { + val toSat = And(p.pc, CaseClassInstanceOf(cct, Variable(id))) - val isImplied = solver.solveSAT(toSat) match { - case (Some(false), _) => true - case _ => false - } + val isImplied = solver.solveSAT(toSat) match { + case (Some(false), _) => true + case _ => false + } - println(isImplied) + println(isImplied) - if (!isImplied) { - Some(ccd) - } else { - None - } - case _ => + if (!isImplied) { + Some(cct) + } else { None + } } val cases = optCases.flatten @@ -101,7 +97,7 @@ class VariableSolverRefiner(directSubclassMap: Map[ClassType, Set[ClassType]], v if (cases.size == 1) { // id.setType(CaseClassType(cases.head)) - FreshIdentifier(oldId.name).setType(CaseClassType(cases.head)) + FreshIdentifier(oldId.name).setType(cases.head) } else oldId case id => id diff --git a/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala b/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala index e6a486fa70cc9d8446d0d1500be4820485c23ea8..29178fcfe5ac4447302110a5c1f6bb57e5072aac 100755 --- a/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala +++ b/src/main/scala/leon/synthesis/condabd/rules/ConditionAbductionSynthesisTwoPhase.scala @@ -29,13 +29,14 @@ case object ConditionAbductionSynthesisTwoPhase extends Rule("Condition abductio val reporter = sctx.reporter val desiredType = givenVariable.getType - val holeFunDef = sctx.functionContext.get + val fd = sctx.functionContext.get + val tfd = fd.typed(fd.tparams.map(_.tp)) // temporary hack, should not mutate FunDef - val oldPostcondition = holeFunDef.postcondition + val oldPostcondition = fd.postcondition try { - val freshResID = FreshIdentifier("result").setType(holeFunDef.returnType) + val freshResID = FreshIdentifier("result").setType(tfd.returnType) val freshResVar = Variable(freshResID) val codeGenEval = new CodeGenEvaluator(sctx.context, sctx.program) @@ -49,11 +50,11 @@ case object ConditionAbductionSynthesisTwoPhase extends Rule("Condition abductio ,10) } - val evaluationStrategy = new CodeGenEvaluationStrategy(program, holeFunDef, sctx.context, 5000) - holeFunDef.postcondition = Some((givenVariable, p.phi)) + val evaluationStrategy = new CodeGenEvaluationStrategy(program, tfd, sctx.context, 5000) + fd.postcondition = Some((givenVariable, p.phi)) val synthesizer = new SynthesizerForRuleExamples( - solver, program, desiredType, holeFunDef, p, sctx, evaluationStrategy, + solver, program, desiredType, tfd, p, sctx, evaluationStrategy, 20, 1, reporter = reporter, introduceExamples = getInputExamples, @@ -77,7 +78,7 @@ case object ConditionAbductionSynthesisTwoPhase extends Rule("Condition abductio e.printStackTrace RuleApplicationImpossible } finally { - holeFunDef.postcondition = oldPostcondition + fd.postcondition = oldPostcondition } } } diff --git a/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala b/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala index 78a106d746b31dc995c1d5320e2b8bde1cf3b4a5..0e452329fd9092c389ed4a16a3e71e51bcfbe10d 100644 --- a/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala +++ b/src/main/scala/leon/synthesis/condabd/verification/AbstractVerifier.scala @@ -20,12 +20,12 @@ abstract class AbstractVerifier(solverf: SolverFactory[Solver with IncrementalSo import SynthesisInfo.Action._ - def analyzeFunction(funDef: FunDef) = { + def analyzeFunction(tfd: TypedFunDef) = { synthInfo.start(Verification) - fine("Analyzing function: " + funDef) + fine("Analyzing function: " + tfd) // create an expression to verify - val theExpr = generateInductiveVerificationCondition(funDef, funDef.body.get) + val theExpr = generateInductiveVerificationCondition(tfd, tfd.body.get) solver.push val valid = checkValidity(theExpr) @@ -38,11 +38,11 @@ abstract class AbstractVerifier(solverf: SolverFactory[Solver with IncrementalSo (valid, map) } - def analyzeFunction(funDef: FunDef, body: Expr) = { + def analyzeFunction(tfd: TypedFunDef, body: Expr) = { synthInfo.start(Verification) // create an expression to verify - val theExpr = generateInductiveVerificationCondition(funDef, body) + val theExpr = generateInductiveVerificationCondition(tfd, body) solver.push val valid = checkValidity(theExpr) @@ -55,14 +55,14 @@ abstract class AbstractVerifier(solverf: SolverFactory[Solver with IncrementalSo (valid, map) } - protected def generateInductiveVerificationCondition(funDef: FunDef, body: Expr) = { + protected def generateInductiveVerificationCondition(tfd: TypedFunDef, body: Expr) = { // replace recursive calls with fresh variables case class Replacement(id: Identifier, exprReplaced: FunctionInvocation) { def getMapping: Map[Expr, Expr] = { - val funDef = exprReplaced.funDef - val pairList = (Variable(funDef.postcondition.get._1), id.toVariable) :: - (funDef.args.map(_.toVariable).toList zip exprReplaced.args) + val tfd = exprReplaced.tfd + val pairList = (Variable(tfd.postcondition.get._1), id.toVariable) :: + (tfd.args.map(_.toVariable).toList zip exprReplaced.args) pairList.toMap } } @@ -72,9 +72,9 @@ abstract class AbstractVerifier(solverf: SolverFactory[Solver with IncrementalSo var replacements = List[Replacement]() def replaceRecursiveCalls(expr: Expr) = expr match { - case funInv@FunctionInvocation(`funDef`, args) => { + case funInv@FunctionInvocation(`tfd`, args) => { isThereARecursiveCall = true - val inductId = FreshIdentifier("induct", true).setType(funDef.returnType) + val inductId = FreshIdentifier("induct", true).setType(tfd.returnType) replacements :+= Replacement(inductId, funInv) Some(inductId.toVariable) } @@ -85,7 +85,7 @@ abstract class AbstractVerifier(solverf: SolverFactory[Solver with IncrementalSo // build the verification condition val resFresh = FreshIdentifier("result", true).setType(newBody.getType) - val (id, post) = funDef.postcondition.get + val (id, post) = tfd.postcondition.get val bodyAndPost = Let( resFresh, newBody, @@ -93,9 +93,9 @@ abstract class AbstractVerifier(solverf: SolverFactory[Solver with IncrementalSo ) val precondition = if( isThereARecursiveCall ) { - And( funDef.precondition.get :: replacements.map( r => replace(r.getMapping, post)) ) + And( tfd.precondition.get :: replacements.map( r => replace(r.getMapping, post)) ) } else - funDef.precondition.get + tfd.precondition.get // val bodyAndPost = // Let( // resFresh, newBody, diff --git a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala index 724191dd0e45624b7f2a18aa4e4bb5eb4f0b9d90..d1252e93cb192348f6850727a006881d9b61bd79 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/ADTInduction.scala @@ -15,14 +15,13 @@ import purescala.Definitions._ case object ADTInduction extends Rule("ADT Induction") with Heuristic { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { val candidates = p.as.collect { - case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, cd) + case IsTyped(origId, act: AbstractClassType) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, act) } val instances = for (candidate <- candidates) yield { - val (origId, cd) = candidate + val (origId, ct) = candidate val oas = p.as.filterNot(_ == origId) - val resType = TupleType(p.xs.map(_.getType)) val inductOn = FreshIdentifier(origId.name, true).setType(origId.getType) @@ -30,14 +29,11 @@ case object ADTInduction extends Rule("ADT Induction") with Heuristic { val residualMap = (oas zip residualArgs).map{ case (id, id2) => id -> Variable(id2) }.toMap val residualArgDefs = residualArgs.map(a => VarDecl(a, a.getType)) - def isAlternativeRecursive(cd: CaseClassDef): Boolean = { - cd.fieldsIds.exists(_.getType == origId.getType) + def isAlternativeRecursive(ct: CaseClassType): Boolean = { + ct.fields.exists(_.tpe == origId.getType) } - val isRecursive = cd.knownDescendents.exists { - case ccd: CaseClassDef => isAlternativeRecursive(ccd) - case _ => false - } + val isRecursive = ct.knownCCDescendents.exists(isAlternativeRecursive) // Map for getting a formula in the context of within the recursive function val substMap = residualMap + (origId -> Variable(inductOn)) @@ -47,50 +43,47 @@ case object ADTInduction extends Rule("ADT Induction") with Heuristic { val innerPhi = substAll(substMap, p.phi) val innerPC = substAll(substMap, p.pc) - val subProblemsInfo = for (dcd <- cd.knownDescendents.sortBy(_.id.name)) yield dcd match { - case ccd : CaseClassDef => - var recCalls = Map[List[Identifier], List[Expr]]() - var postFs = List[Expr]() + val subProblemsInfo = for (cct <- ct.knownCCDescendents) yield { + var recCalls = Map[List[Identifier], List[Expr]]() + var postFs = List[Expr]() - val newIds = ccd.fieldsIds.map(id => FreshIdentifier(id.name, true).setType(id.getType)).toList + val newIds = cct.fields.map(vd => FreshIdentifier(vd.id.name, true).setType(vd.tpe)).toList - val inputs = (for (id <- newIds) yield { - if (id.getType == origId.getType) { - val postXs = p.xs map (id => FreshIdentifier("r", true).setType(id.getType)) - val postXsMap = (p.xs zip postXs).toMap.mapValues(Variable(_)) + val inputs = (for (id <- newIds) yield { + if (id.getType == origId.getType) { + val postXs = p.xs map (id => FreshIdentifier("r", true).setType(id.getType)) + val postXsMap = (p.xs zip postXs).toMap.mapValues(Variable(_)) - recCalls += postXs -> (Variable(id) +: residualArgs.map(id => Variable(id))) + recCalls += postXs -> (Variable(id) +: residualArgs.map(id => Variable(id))) - postFs ::= substAll(postXsMap + (inductOn -> Variable(id)), innerPhi) - id :: postXs - } else { - List(id) - } - }).flatten + postFs ::= substAll(postXsMap + (inductOn -> Variable(id)), innerPhi) + id :: postXs + } else { + List(id) + } + }).flatten - val subPhi = substAll(Map(inductOn -> CaseClass(ccd, newIds.map(Variable(_)))), innerPhi) - val subPC = substAll(Map(inductOn -> CaseClass(ccd, newIds.map(Variable(_)))), innerPC) + val subPhi = substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable(_)))), innerPhi) + val subPC = substAll(Map(inductOn -> CaseClass(cct, newIds.map(Variable(_)))), innerPC) - val subPre = CaseClassInstanceOf(ccd, Variable(origId)) + val subPre = CaseClassInstanceOf(cct, Variable(origId)) - val subProblem = Problem(inputs ::: residualArgs, And(subPC :: postFs), subPhi, p.xs) + val subProblem = Problem(inputs ::: residualArgs, And(subPC :: postFs), subPhi, p.xs) - (subProblem, subPre, ccd, newIds, recCalls) - case _ => - sys.error("Woops, non case-class as descendent") + (subProblem, subPre, cct, newIds, recCalls) } val onSuccess: List[Solution] => Option[Solution] = { case sols => var globalPre = List[Expr]() - val newFun = new FunDef(FreshIdentifier("rec", true), resType, VarDecl(inductOn, inductOn.getType) +: residualArgDefs) + val newFun = new FunDef(FreshIdentifier("rec", true), Nil, resType, VarDecl(inductOn, inductOn.getType) +: residualArgDefs) - val cases = for ((sol, (problem, pre, ccd, ids, calls)) <- (sols zip subProblemsInfo)) yield { + val cases = for ((sol, (problem, pre, cct, ids, calls)) <- (sols zip subProblemsInfo)) yield { globalPre ::= And(pre, sol.pre) - val caze = CaseClassPattern(None, ccd, ids.map(id => WildcardPattern(Some(id)))) - SimpleCase(caze, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => LetTuple(binders, FunctionInvocation(newFun, callargs), t) }) + val caze = CaseClassPattern(None, cct, ids.map(id => WildcardPattern(Some(id)))) + SimpleCase(caze, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => LetTuple(binders, FunctionInvocation(newFun.typed, callargs), t) }) } // Might be overly picky with obviously true pre (a.is[Cons] OR a.is[Nil]) @@ -112,7 +105,7 @@ case object ADTInduction extends Rule("ADT Induction") with Heuristic { Some(Solution(Or(globalPre), sols.flatMap(_.defs).toSet+newFun, - FunctionInvocation(newFun, Variable(origId) :: oas.map(Variable(_))) + FunctionInvocation(newFun.typed, Variable(origId) :: oas.map(Variable(_))) )) } } diff --git a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala b/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala index 0305450d8d53a53d5e26fb1bfd4f013f1074af79..ace89a1839331895c4f745c63b2ad18ba682b0e1 100644 --- a/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/ADTLongInduction.scala @@ -15,11 +15,11 @@ import purescala.Definitions._ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { val candidates = p.as.collect { - case IsTyped(origId, AbstractClassType(cd)) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, cd) + case IsTyped(origId, act @ AbstractClassType(cd, tpe)) if isInductiveOn(sctx.solverFactory)(p.pc, origId) => (origId, act) } val instances = for (candidate <- candidates) yield { - val (origId, cd) = candidate + val (origId, ct) = candidate val oas = p.as.filterNot(_ == origId) @@ -30,21 +30,11 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { val residualMap = (oas zip residualArgs).map{ case (id, id2) => id -> Variable(id2) }.toMap val residualArgDefs = residualArgs.map(a => VarDecl(a, a.getType)) - def isAlternativeRecursive(cd: CaseClassDef): Boolean = { - cd.fieldsIds.exists(_.getType == origId.getType) + def isAlternativeRecursive(ct: CaseClassType): Boolean = { + ct.fields.exists(_.tpe == origId.getType) } - val isRecursive = cd.knownDescendents.exists { - case ccd: CaseClassDef => isAlternativeRecursive(ccd) - case _ => false - } - - def childsOf(cd: AbstractClassDef): List[CaseClassDef] = { - cd.knownDescendents.sortBy(_.id.name).toList.collect { - case ccd: CaseClassDef => - ccd - } - } + val isRecursive = ct.knownCCDescendents.exists(isAlternativeRecursive) // Map for getting a formula in the context of within the recursive function val substMap = residualMap + (origId -> Variable(inductOn)) @@ -61,12 +51,12 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { def isRec(id: Identifier) = id.getType == origId.getType - def unrollPattern(id: Identifier, ccd: CaseClassDef, withIds: List[Identifier])(on: Pattern): Pattern = on match { + def unrollPattern(id: Identifier, cct: CaseClassType, withIds: List[Identifier])(on: Pattern): Pattern = on match { case WildcardPattern(Some(pid)) if pid == id => - CaseClassPattern(None, ccd, withIds.map(id => WildcardPattern(Some(id)))) + CaseClassPattern(None, cct, withIds.map(id => WildcardPattern(Some(id)))) case CaseClassPattern(binder, sccd, sub) => - CaseClassPattern(binder, sccd, sub.map(unrollPattern(id, ccd, withIds) _)) + CaseClassPattern(binder, sccd, sub.map(unrollPattern(id, cct, withIds) _)) case _ => on } @@ -76,8 +66,8 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { val InductCase(ids, calls, pat, pc, trMap) = ic (for (id <- ids if isRec(id)) yield { - for (ccd <- childsOf(cd)) yield { - val subIds = ccd.fieldsIds.map(id => FreshIdentifier(id.name, true).setType(id.getType)).toList + for (cct <- ct.knownCCDescendents) yield { + val subIds = cct.fields.map(vd => FreshIdentifier(vd.id.name, true).setType(vd.tpe)).toList val newIds = ids.filterNot(_ == id) ++ subIds val newCalls = if (!subIds.isEmpty) { @@ -88,11 +78,11 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { //println(ccd) //println(subIds) - val newPattern = unrollPattern(id, ccd, subIds)(pat) + val newPattern = unrollPattern(id, cct, subIds)(pat) - val newMap = trMap.mapValues(v => substAll(Map(id -> CaseClass(ccd, subIds.map(Variable(_)))), v)) + val newMap = trMap.mapValues(v => substAll(Map(id -> CaseClass(cct, subIds.map(Variable(_)))), v)) - InductCase(newIds, newCalls, newPattern, And(pc, CaseClassInstanceOf(ccd, Variable(id))), newMap) + InductCase(newIds, newCalls, newPattern, And(pc, CaseClassInstanceOf(cct, Variable(id))), newMap) } }).flatten } else { @@ -137,12 +127,12 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { case sols => var globalPre = List[Expr]() - val newFun = new FunDef(FreshIdentifier("rec", true), resType, VarDecl(inductOn, inductOn.getType) +: residualArgDefs) + val newFun = new FunDef(FreshIdentifier("rec", true), Nil, resType, VarDecl(inductOn, inductOn.getType) +: residualArgDefs) val cases = for ((sol, (problem, pat, calls, pc)) <- (sols zip subProblemsInfo)) yield { globalPre ::= And(pc, sol.pre) - SimpleCase(pat, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => LetTuple(binders, FunctionInvocation(newFun, callargs), t) }) + SimpleCase(pat, calls.foldLeft(sol.term){ case (t, (binders, callargs)) => LetTuple(binders, FunctionInvocation(newFun.typed, callargs), t) }) } // Might be overly picky with obviously true pre (a.is[Cons] OR a.is[Nil]) @@ -164,7 +154,7 @@ case object ADTLongInduction extends Rule("ADT Long Induction") with Heuristic { Some(Solution(Or(globalPre), sols.flatMap(_.defs).toSet+newFun, - FunctionInvocation(newFun, Variable(origId) :: oas.map(Variable(_))) + FunctionInvocation(newFun.typed, Variable(origId) :: oas.map(Variable(_))) )) } } diff --git a/src/main/scala/leon/synthesis/heuristics/IntInduction.scala b/src/main/scala/leon/synthesis/heuristics/IntInduction.scala index 1be8dd8eee16e74063a10ef2bbcc3e4c02d68cf0..c5773af4828c6fa6134f0599288e3d6c587b09c8 100644 --- a/src/main/scala/leon/synthesis/heuristics/IntInduction.scala +++ b/src/main/scala/leon/synthesis/heuristics/IntInduction.scala @@ -43,7 +43,7 @@ case object IntInduction extends Rule("Int Induction") with Heuristic { And(LessThan(Variable(inductOn), IntLiteral(0)), lt.pre))) val preOut = subst(inductOn -> Variable(origId), preIn) - val newFun = new FunDef(FreshIdentifier("rec", true), tpe, Seq(VarDecl(inductOn, inductOn.getType))) + val newFun = new FunDef(FreshIdentifier("rec", true), Nil, tpe, Seq(VarDecl(inductOn, inductOn.getType))) val idPost = FreshIdentifier("res").setType(tpe) newFun.precondition = Some(preIn) @@ -53,12 +53,12 @@ case object IntInduction extends Rule("Int Induction") with Heuristic { IfExpr(Equals(Variable(inductOn), IntLiteral(0)), base.toExpr, IfExpr(GreaterThan(Variable(inductOn), IntLiteral(0)), - LetTuple(postXs, FunctionInvocation(newFun, Seq(Minus(Variable(inductOn), IntLiteral(1)))), gt.toExpr) - , LetTuple(postXs, FunctionInvocation(newFun, Seq(Plus(Variable(inductOn), IntLiteral(1)))), lt.toExpr))) + LetTuple(postXs, FunctionInvocation(newFun.typed, Seq(Minus(Variable(inductOn), IntLiteral(1)))), gt.toExpr) + , LetTuple(postXs, FunctionInvocation(newFun.typed, Seq(Plus(Variable(inductOn), IntLiteral(1)))), lt.toExpr))) ) - Some(Solution(preOut, base.defs++gt.defs++lt.defs+newFun, FunctionInvocation(newFun, Seq(Variable(origId))))) + Some(Solution(preOut, base.defs++gt.defs++lt.defs+newFun, FunctionInvocation(newFun.typed, Seq(Variable(origId))))) } case _ => None diff --git a/src/main/scala/leon/synthesis/rules/ADTDual.scala b/src/main/scala/leon/synthesis/rules/ADTDual.scala index 7f73a8198f7e84ef47f5404dce945191eb391c13..e871521b4ac22b5e0932b177046b9bae4038fd92 100644 --- a/src/main/scala/leon/synthesis/rules/ADTDual.scala +++ b/src/main/scala/leon/synthesis/rules/ADTDual.scala @@ -17,11 +17,11 @@ case object ADTDual extends NormalizingRule("ADTDual") { val (toRemove, toAdd) = exprs.collect { - case eq @ Equals(cc @ CaseClass(cd, args), e) if (variablesOf(e) -- as).isEmpty && !(variablesOf(cc) & xs).isEmpty => - (eq, CaseClassInstanceOf(cd, e) +: (cd.fieldsIds zip args).map{ case (id, ex) => Equals(ex, CaseClassSelector(cd, e, id)) } ) + case eq @ Equals(cc @ CaseClass(ct, args), e) if (variablesOf(e) -- as).isEmpty && !(variablesOf(cc) & xs).isEmpty => + (eq, CaseClassInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } ) - case eq @ Equals(e, cc @ CaseClass(cd, args)) if (variablesOf(e) -- as).isEmpty && !(variablesOf(cc) & xs).isEmpty => - (eq, CaseClassInstanceOf(cd, e) +: (cd.fieldsIds zip args).map{ case (id, ex) => Equals(ex, CaseClassSelector(cd, e, id)) } ) + case eq @ Equals(e, cc @ CaseClass(ct, args)) if (variablesOf(e) -- as).isEmpty && !(variablesOf(cc) & xs).isEmpty => + (eq, CaseClassInstanceOf(ct, e) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, CaseClassSelector(ct, e, vd.id)) } ) }.unzip if (!toRemove.isEmpty) { diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 93f584ccda6279c97a9c27586bfd86aa98ebbe45..cdc147b02f63be7347f0a88b9a72f4aeb1f3804a 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -17,11 +17,12 @@ case object ADTSplit extends Rule("ADT Split.") { val solver = SimpleSolverAPI(new TimeoutSolverFactory(sctx.solverFactory, 200L)) val candidates = p.as.collect { - case IsTyped(id, AbstractClassType(cd)) => + case IsTyped(id, act @ AbstractClassType(cd, tpes)) => val optCases = for (dcd <- cd.knownDescendents.sortBy(_.id.name)) yield dcd match { case ccd : CaseClassDef => - val toSat = And(p.pc, CaseClassInstanceOf(ccd, Variable(id))) + val cct = CaseClassType(ccd, tpes) + val toSat = And(p.pc, CaseClassInstanceOf(cct, Variable(id))) val isImplied = solver.solveSAT(toSat) match { case (Some(false), _) => true @@ -40,25 +41,27 @@ case object ADTSplit extends Rule("ADT Split.") { val cases = optCases.flatten if (!cases.isEmpty) { - Some((id, cases)) + Some((id, act, cases)) } else { None } } candidates.collect{ _ match { - case Some((id, cases)) => + case Some((id, act, cases)) => val oas = p.as.filter(_ != id) val subInfo = for(ccd <- cases) yield { - val args = ccd.fieldsIds.map(id => FreshIdentifier(id.name, true).setType(id.getType)).toList + val cct = CaseClassType(ccd, act.tps) - val subPhi = subst(id -> CaseClass(ccd, args.map(Variable(_))), p.phi) - val subPC = subst(id -> CaseClass(ccd, args.map(Variable(_))), p.pc) + val args = cct.fields.map { vd => FreshIdentifier(vd.id.name, true).setType(vd.tpe) }.toList + + val subPhi = subst(id -> CaseClass(cct, args.map(Variable(_))), p.phi) + val subPC = subst(id -> CaseClass(cct, args.map(Variable(_))), p.pc) val subProblem = Problem(args ::: oas, subPC, subPhi, p.xs) - val subPattern = CaseClassPattern(None, ccd, args.map(id => WildcardPattern(Some(id)))) + val subPattern = CaseClassPattern(None, cct, args.map(id => WildcardPattern(Some(id)))) - (ccd, subProblem, subPattern) + (cct, subProblem, subPattern) } @@ -66,8 +69,8 @@ case object ADTSplit extends Rule("ADT Split.") { case sols => var globalPre = List[Expr]() - val cases = for ((sol, (ccd, problem, pattern)) <- (sols zip subInfo)) yield { - globalPre ::= And(CaseClassInstanceOf(ccd, Variable(id)), sol.pre) + val cases = for ((sol, (cct, problem, pattern)) <- (sols zip subInfo)) yield { + globalPre ::= And(CaseClassInstanceOf(cct, Variable(id)), sol.pre) SimpleCase(pattern, sol.term) } diff --git a/src/main/scala/leon/synthesis/rules/Cegis.scala b/src/main/scala/leon/synthesis/rules/Cegis.scala index 59a8e70ab759a5f76806f91fb1b90fa0c50b6bb4..b588357fc46804d08b36fad776736f64dc3bc462 100644 --- a/src/main/scala/leon/synthesis/rules/Cegis.scala +++ b/src/main/scala/leon/synthesis/rules/Cegis.scala @@ -12,6 +12,7 @@ import purescala.Common._ import purescala.Definitions._ import purescala.TypeTrees._ import purescala.TreeOps._ +import purescala.TypeTreeOps._ import purescala.Extractors._ import purescala.ScalaPrinter @@ -60,21 +61,22 @@ case object CEGIS extends Rule("CEGIS") { List((Tuple(ids.map(Variable(_))), ids.toSet)) } - case CaseClassType(cd) => + case cct @ CaseClassType(cd, _) => { () => - val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType)) - List((CaseClass(cd, ids.map(Variable(_))), ids.toSet)) + val ids = cct.fields.map { vd => FreshIdentifier("c", true).setType(vd.tpe) } + List((CaseClass(cct, ids.map(Variable(_))), ids.toSet)) } - case AbstractClassType(cd) => + case AbstractClassType(cd, tpes) => { () => val alts: Seq[(Expr, Set[Identifier])] = cd.knownDescendents.flatMap(i => i match { case acd: AbstractClassDef => sctx.reporter.error("Unnexpected abstract class in descendants!") None case cd: CaseClassDef => - val ids = cd.fieldsIds.map(i => FreshIdentifier("c", true).setType(i.getType)) - Some((CaseClass(cd, ids.map(Variable(_))), ids.toSet)) + val cct = CaseClassType(cd, tpes) + val ids = cct.fields.map{ vd => FreshIdentifier("c", true).setType(vd.tpe) } + Some((CaseClass(cct, ids.map(Variable(_))), ids.toSet)) }) alts.toList } @@ -92,11 +94,11 @@ case object CEGIS extends Rule("CEGIS") { p.as.filter(a => isSubtypeOf(a.getType, t)).map(id => (Variable(id) : Expr, Set[Identifier]())) } - val funcCache: MutableMap[TypeTree, Seq[FunDef]] = MutableMap.empty + val funcCache: MutableMap[TypeTree, Seq[TypedFunDef]] = MutableMap.empty def funcAlternatives(t: TypeTree): List[(Expr, Set[Identifier])] = { if (useFunGenerators) { - def isCandidate(fd: FunDef): Boolean = { + def isCandidate(fd: FunDef): Option[TypedFunDef] = { // Prevents recursive calls val isRecursiveCall = sctx.functionContext match { case Some(cfd) => @@ -114,23 +116,30 @@ case object CEGIS extends Rule("CEGIS") { false } - - - isSubtypeOf(fd.returnType, t) && !isRecursiveCall && isNotSynthesizable + if (!isRecursiveCall && isNotSynthesizable) { + canBeSubtypeOf(fd.returnType, fd.tparams, t) match { + case Some(tps) => + Some(fd.typed(tps)) + case None => + None + } + } else { + None + } } val funcs = funcCache.get(t) match { case Some(alts) => alts case None => - val alts = sctx.program.definedFunctions.filter(isCandidate) + val alts = sctx.program.definedFunctions.flatMap(isCandidate) funcCache += t -> alts alts } - funcs.map{ fd => - val ids = fd.args.map(vd => FreshIdentifier("c", true).setType(vd.getType)) - (FunctionInvocation(fd, ids.map(Variable(_))), ids.toSet) + funcs.map{ tfd => + val ids = tfd.args.map(vd => FreshIdentifier("c", true).setType(vd.tpe)) + (FunctionInvocation(tfd, ids.map(Variable(_))), ids.toSet) }.toList } else { Nil diff --git a/src/main/scala/leon/synthesis/rules/DetupleInput.scala b/src/main/scala/leon/synthesis/rules/DetupleInput.scala index 5bb4045b088aa225cf7b7f34824bb953945911b3..008968e44bd1c84017d4a729da8670b657e4e5ca 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleInput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleInput.scala @@ -15,19 +15,18 @@ case object DetupleInput extends NormalizingRule("Detuple In") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { def isDecomposable(id: Identifier) = id.getType match { - case CaseClassType(t) if !t.isAbstract => true + case CaseClassType(t, _) if !t.isAbstract => true case TupleType(ts) => true case _ => false } def decompose(id: Identifier): (List[Identifier], Expr, Map[Identifier, Expr]) = id.getType match { - case CaseClassType(ccd) if !ccd.isAbstract => - val CaseClassDef(name, _, fields) = ccd - val newIds = fields.map(vd => FreshIdentifier(vd.id.name, true).setType(vd.getType)) + case cct @ CaseClassType(ccd, _) if !ccd.isAbstract => + val newIds = cct.fields.map{ vd => FreshIdentifier(vd.id.name, true).setType(vd.tpe) } - val map = (fields zip newIds).map{ case (f, nid) => nid -> CaseClassSelector(ccd, Variable(id), f.id) }.toMap + val map = (ccd.fields zip newIds).map{ case (vd, nid) => nid -> CaseClassSelector(cct, Variable(id), vd.id) }.toMap - (newIds.toList, CaseClass(ccd, newIds.map(Variable(_))), map) + (newIds.toList, CaseClass(cct, newIds.map(Variable(_))), map) case TupleType(ts) => val newIds = ts.zipWithIndex.map{ case (t, i) => FreshIdentifier(id.name+"_"+(i+1), true).setType(t) } diff --git a/src/main/scala/leon/synthesis/rules/DetupleOutput.scala b/src/main/scala/leon/synthesis/rules/DetupleOutput.scala index 6c026ad806554e9af4208f95321a126938e2e569..de4aa1809e5b6f4181dfd78acd9ea64951f4efb6 100644 --- a/src/main/scala/leon/synthesis/rules/DetupleOutput.scala +++ b/src/main/scala/leon/synthesis/rules/DetupleOutput.scala @@ -15,7 +15,7 @@ case object DetupleOutput extends Rule("Detuple Out") { def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { def isDecomposable(id: Identifier) = id.getType match { - case CaseClassType(t) if !t.isAbstract => true + case CaseClassType(t, _) if !t.isAbstract => true case _ => false } @@ -24,11 +24,11 @@ case object DetupleOutput extends Rule("Detuple Out") { val (subOuts, outerOuts) = p.xs.map { x => if (isDecomposable(x)) { - val CaseClassType(ccd @ CaseClassDef(name, _, fields)) = x.getType + val ct @ CaseClassType(ccd @ CaseClassDef(name, _, _, _), tpes) = x.getType - val newIds = fields.map(vd => FreshIdentifier(vd.id.name, true).setType(vd.getType)) + val newIds = ct.fields.map{ vd => FreshIdentifier(vd.id.name, true).setType(vd.tpe) } - val newCC = CaseClass(ccd, newIds.map(Variable(_))) + val newCC = CaseClass(ct, newIds.map(Variable(_))) subProblem = subst(x -> newCC, subProblem) diff --git a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala index 6def2cdfb455ede996d162e0e7faa86dfc52cb32..ec66a047eac5c3f9dc844d6870913ccba5dc08e0 100644 --- a/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala +++ b/src/main/scala/leon/synthesis/rules/IntegerInequalities.scala @@ -95,7 +95,7 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { //define max function val maxVarDecls: Seq[VarDecl] = lowerBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type)) - val maxFun = new FunDef(FreshIdentifier("max"), Int32Type, maxVarDecls) + val maxFun = new FunDef(FreshIdentifier("max"), Nil, Int32Type, maxVarDecls) def maxRec(bounds: List[Expr]): Expr = bounds match { case (x1 :: x2 :: xs) => { val v = FreshIdentifier("m").setType(Int32Type) @@ -106,10 +106,10 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { } if(!lowerBounds.isEmpty) maxFun.body = Some(maxRec(maxVarDecls.map(vd => Variable(vd.id)).toList)) - def max(xs: Seq[Expr]): Expr = FunctionInvocation(maxFun, xs) + def max(xs: Seq[Expr]): Expr = FunctionInvocation(maxFun.typed, xs) //define min function val minVarDecls: Seq[VarDecl] = upperBounds.map(_ => VarDecl(FreshIdentifier("b"), Int32Type)) - val minFun = new FunDef(FreshIdentifier("min"), Int32Type, minVarDecls) + val minFun = new FunDef(FreshIdentifier("min"), Nil, Int32Type, minVarDecls) def minRec(bounds: List[Expr]): Expr = bounds match { case (x1 :: x2 :: xs) => { val v = FreshIdentifier("m").setType(Int32Type) @@ -120,16 +120,16 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { } if(!upperBounds.isEmpty) minFun.body = Some(minRec(minVarDecls.map(vd => Variable(vd.id)).toList)) - def min(xs: Seq[Expr]): Expr = FunctionInvocation(minFun, xs) - val floorFun = new FunDef(FreshIdentifier("floorDiv"), Int32Type, Seq( + def min(xs: Seq[Expr]): Expr = FunctionInvocation(minFun.typed, xs) + val floorFun = new FunDef(FreshIdentifier("floorDiv"), Nil, Int32Type, Seq( VarDecl(FreshIdentifier("x"), Int32Type), VarDecl(FreshIdentifier("x"), Int32Type))) - val ceilingFun = new FunDef(FreshIdentifier("ceilingDiv"), Int32Type, Seq( + val ceilingFun = new FunDef(FreshIdentifier("ceilingDiv"), Nil, Int32Type, Seq( VarDecl(FreshIdentifier("x"), Int32Type), VarDecl(FreshIdentifier("x"), Int32Type))) ceilingFun.body = Some(IntLiteral(0)) - def floorDiv(x: Expr, y: Expr): Expr = FunctionInvocation(floorFun, Seq(x, y)) - def ceilingDiv(x: Expr, y: Expr): Expr = FunctionInvocation(ceilingFun, Seq(x, y)) + def floorDiv(x: Expr, y: Expr): Expr = FunctionInvocation(floorFun.typed, Seq(x, y)) + def ceilingDiv(x: Expr, y: Expr): Expr = FunctionInvocation(ceilingFun.typed, Seq(x, y)) val witness: Expr = if(upperBounds.isEmpty) { if(lowerBounds.size > 1) max(lowerBounds.map{case (b, c) => ceilingDiv(b, IntLiteral(c))}) @@ -192,7 +192,7 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { val concretePre = replace(Map(Variable(k) -> loopCounter), pre) val concreteTerm = replace(Map(Variable(k) -> loopCounter), term) val returnType = TupleType(problem.xs.map(_.getType)) - val funDef = new FunDef(FreshIdentifier("rec", true), returnType, Seq(VarDecl(loopCounter.id, Int32Type))) + val funDef = new FunDef(FreshIdentifier("rec", true), Nil, returnType, Seq(VarDecl(loopCounter.id, Int32Type))) val funBody = expandAndSimplifyArithmetic(IfExpr( LessThan(loopCounter, IntLiteral(0)), Error("No solution exists"), @@ -202,12 +202,12 @@ case object IntegerInequalities extends Rule("Integer Inequalities") { Let(processedVar, witness, Tuple(problem.xs.map(Variable(_)))) ), - FunctionInvocation(funDef, Seq(Minus(loopCounter, IntLiteral(1)))) + FunctionInvocation(funDef.typed, Seq(Minus(loopCounter, IntLiteral(1)))) ) )) funDef.body = Some(funBody) - Some(Solution(And(newPre, pre), defs + funDef, FunctionInvocation(funDef, Seq(IntLiteral(L-1))))) + Some(Solution(And(newPre, pre), defs + funDef, FunctionInvocation(funDef.typed, Seq(IntLiteral(L-1))))) } } case _ => diff --git a/src/main/scala/leon/termination/ChainBuilder.scala b/src/main/scala/leon/termination/ChainBuilder.scala index 936e6a9a02fccd862f1f4c921ed3fbb9d11a2b5b..6d0427796af77fee9b6a050c3249ddc9d80cc459 100644 --- a/src/main/scala/leon/termination/ChainBuilder.scala +++ b/src/main/scala/leon/termination/ChainBuilder.scala @@ -29,16 +29,16 @@ final case class Chain(chain: List[Relation]) { def loop(initialSubst: Map[Identifier, Expr] = Map(), finalSubst: Map[Identifier, Expr] = Map()) : Seq[Expr] = { def rec(relations: List[Relation], subst: Map[Identifier, Expr]): Seq[Expr] = relations match { - case Relation(_, path, FunctionInvocation(fd, args)) :: Nil => - assert(fd == funDef) + case Relation(_, path, FunctionInvocation(tfd, args)) :: Nil => + assert(tfd.fd == funDef) val newPath = path.map(replaceFromIDs(subst, _)) val equalityConstraints = if (finalSubst.isEmpty) Seq() else { val newArgs = args.map(replaceFromIDs(subst, _)) - (fd.args.map(arg => finalSubst(arg.id)) zip newArgs).map(p => Equals(p._1, p._2)) + (tfd.args.map(arg => finalSubst(arg.id)) zip newArgs).map(p => Equals(p._1, p._2)) } newPath ++ equalityConstraints - case Relation(_, path, FunctionInvocation(fd, args)) :: xs => - val formalArgs = fd.args.map(_.id) + case Relation(_, path, FunctionInvocation(tfd, args)) :: xs => + val formalArgs = tfd.args.map(_.id) val freshFormalArgVars = formalArgs.map(_.freshen.toVariable) val formalArgsMap: Map[Identifier, Expr] = (formalArgs zip freshFormalArgVars).toMap val (newPath, newArgs) = (path.map(replaceFromIDs(subst, _)), args.map(replaceFromIDs(subst, _))) @@ -92,15 +92,15 @@ class ChainBuilder(relationBuilder: RelationBuilder) { def chains(partials: List[(Relation, List[Relation])]): List[List[Relation]] = if (partials.isEmpty) Nil else { // Note that chains in partials are reversed to profit from O(1) insertion val (results, newPartials) = partials.foldLeft(List[List[Relation]](),List[(Relation, List[Relation])]())({ - case ((results, partials), (first, chain @ Relation(_, _, FunctionInvocation(fd, _)) :: xs)) => - val cycle = relationBuilder.run(fd).contains(first) + case ((results, partials), (first, chain @ Relation(_, _, FunctionInvocation(tfd, _)) :: xs)) => + val cycle = relationBuilder.run(tfd.fd).contains(first) // reverse the chain when "returning" it since we're working on reversed chains val newResults = if (cycle) chain.reverse :: results else results // Partial chains can fall back onto a transition that was already taken (thus creating a cycle // inside the chain). Since this cycle will be discovered elsewhere, such partial chains should be // dropped from the partial chain list - val transitions = relationBuilder.run(fd) -- chain.toSet + val transitions = relationBuilder.run(tfd.fd) -- chain.toSet val newPartials = transitions.map(transition => (first, transition :: chain)).toList (newResults, partials ++ newPartials) diff --git a/src/main/scala/leon/termination/ChainComparator.scala b/src/main/scala/leon/termination/ChainComparator.scala index dc61531d30279f61c61b42b6dd575f270f206ade..b6720b6aa2e259171ebf2bd75bea63435abbfc2d 100644 --- a/src/main/scala/leon/termination/ChainComparator.scala +++ b/src/main/scala/leon/termination/ChainComparator.scala @@ -4,6 +4,7 @@ package termination import purescala.Trees._ import purescala.TreeOps._ import purescala.TypeTrees._ +import purescala.TypeTreeOps._ import purescala.Definitions._ import purescala.Common._ @@ -11,20 +12,27 @@ class ChainComparator(structuralSize: StructuralSize) { import structuralSize.size private object ContainerType { - def unapply(c: ClassType): Option[(CaseClassDef, Seq[(Identifier, TypeTree)])] = c match { - case CaseClassType(classDef) => - if (classDef.fields.exists(arg => isSubtypeOf(arg.tpe, classDef.parent.map(AbstractClassType(_)).getOrElse(c)))) None - else if (classDef.hasParent && classDef.parent.get.knownChildren.size > 1) None - else Some((classDef, classDef.fields.map(arg => arg.id -> arg.tpe))) + def unapply(c: ClassType): Option[(CaseClassType, Seq[(Identifier, TypeTree)])] = c match { + case act @ CaseClassType(classDef, tpes) => + val ftps = act.fields + val parentType = classDef.parent.getOrElse(c) + + if (ftps.exists(ad => isSubtypeOf(ad.tpe, parentType))) { + None + } else if (classDef.parent.map(_.classDef.knownChildren.size > 1).getOrElse(false)) { + None + } else { + Some((act, ftps.map{ ad => ad.id -> ad.tpe })) + } case _ => None } } def sizeDecreasing(e1: Expr, e2s: Seq[(Seq[Expr], Expr)]) : Expr = e1.getType match { - case ContainerType(def1, fields1) => Or(fields1.zipWithIndex map { case ((id1, type1), index) => - sizeDecreasing(CaseClassSelector(def1, e1, id1), e2s.map { case (path, e2) => + case ContainerType(ct1, fields1) => Or(fields1.zipWithIndex map { case ((id1, type1), index) => + sizeDecreasing(CaseClassSelector(ct1, e1, id1), e2s.map { case (path, e2) => e2.getType match { - case ContainerType(def2, fields2) => (path, CaseClassSelector(def2, e2, fields2(index)._1)) + case ContainerType(ct2, fields2) => (path, CaseClassSelector(ct2, e2, fields2(index)._1)) case _ => scala.sys.error("Unexpected input combinations: " + e1 + " " + e2) } }) @@ -128,8 +136,8 @@ class ChainComparator(structuralSize: StructuralSize) { case NoEndpoint => endpoint(thenn) min endpoint(elze) case ep => - val terminatingThen = functionCallsOf(thenn).forall(fi => checker.terminates(fi.funDef).isGuaranteed) - val terminatingElze = functionCallsOf(elze).forall(fi => checker.terminates(fi.funDef).isGuaranteed) + val terminatingThen = functionCallsOf(thenn).forall(fi => checker.terminates(fi.tfd.fd).isGuaranteed) + val terminatingElze = functionCallsOf(elze).forall(fi => checker.terminates(fi.tfd.fd).isGuaranteed) val thenEndpoint = if (terminatingThen) ep max endpoint(thenn) else endpoint(thenn) val elzeEndpoint = if (terminatingElze) ep.inverse max endpoint(elze) else endpoint(elze) thenEndpoint max elzeEndpoint diff --git a/src/main/scala/leon/termination/LoopProcessor.scala b/src/main/scala/leon/termination/LoopProcessor.scala index e33f1c055e4f53a4a5e6106674572664c5b43f94..0bede44f8a80dc9779d158493e1c4f6689733ca8 100644 --- a/src/main/scala/leon/termination/LoopProcessor.scala +++ b/src/main/scala/leon/termination/LoopProcessor.scala @@ -25,7 +25,7 @@ class LoopProcessor(checker: TerminationChecker, val formula = And(path :+ Equals(Tuple(chain.funDef.args.map(_.toVariable)), Tuple(freshArgs))) val solvable = functionCallsOf(formula).forall({ - case FunctionInvocation(fd, args) => checker.terminates(fd).isGuaranteed + case FunctionInvocation(tfd, args) => checker.terminates(tfd.fd).isGuaranteed }) if (!solvable) None else getModel(formula) match { diff --git a/src/main/scala/leon/termination/Processor.scala b/src/main/scala/leon/termination/Processor.scala index 6c72adb980c72933092ebd16878147f1ff1eb251..8b3a8e36e00468438f5324976660179aeb5afb07 100644 --- a/src/main/scala/leon/termination/Processor.scala +++ b/src/main/scala/leon/termination/Processor.scala @@ -95,8 +95,8 @@ trait Solvable { self: Processor => val structDefs = structuralSize.defs if (structDefs != lastDefs || solvers == null) { val program : Program = self.checker.program - val allDefs : Seq[Definition] = program.mainObject.defs ++ structDefs - val newProgram : Program = program.copy(mainObject = program.mainObject.copy(defs = allDefs)) + val allDefs : Seq[Definition] = program.mainModule.defs ++ structDefs + val newProgram : Program = program.copy(mainModule = program.mainModule.copy(defs = allDefs)) val context : LeonContext = self.checker.context solvers = new TimeoutSolverFactory(SolverFactory(() => new FairZ3Solver(context, newProgram) with TimeoutSolver), 500) :: Nil @@ -111,7 +111,7 @@ trait Solvable { self: Processor => // make Leon unroll them forever...) val dangerousCallsMap : Map[Expr, Expr] = functionCallsOf(problem).collect({ // extra definitions (namely size functions) are quaranteed to terminate because structures are non-looping - case fi @ FunctionInvocation(fd, args) if !structuralSize.defs(fd) && !self.checker.terminates(fd).isGuaranteed => + case fi @ FunctionInvocation(tfd, args) if !structuralSize.defs(tfd.fd) && !self.checker.terminates(tfd.fd).isGuaranteed => fi -> FreshIdentifier("noRun", true).setType(fi.getType).toVariable }).toMap diff --git a/src/main/scala/leon/termination/RecursionProcessor.scala b/src/main/scala/leon/termination/RecursionProcessor.scala index 1bbd88fd4eb4054b26326ab05522d2210b7d75c9..6ea2048a4916497372055f78ceefe70c6cae329d 100644 --- a/src/main/scala/leon/termination/RecursionProcessor.scala +++ b/src/main/scala/leon/termination/RecursionProcessor.scala @@ -26,7 +26,7 @@ class RecursionProcessor(checker: TerminationChecker, relationBuilder: RelationB val relations = relationBuilder.run(funDef) val (recursive, others) = relations.partition({ case Relation(_, _, FunctionInvocation(fd, _)) => fd == funDef }) - if (others.exists({ case Relation(_, _, FunctionInvocation(fd, _)) => !checker.terminates(fd).isGuaranteed })) (Nil, List(problem)) else { + if (others.exists({ case Relation(_, _, FunctionInvocation(tfd, _)) => !checker.terminates(tfd.fd).isGuaranteed })) (Nil, List(problem)) else { val decreases = funDef.args.zipWithIndex.exists({ case (arg, index) => recursive.forall({ case Relation(_, _, FunctionInvocation(_, args)) => isSubtreeOf(args(index), arg.id) diff --git a/src/main/scala/leon/termination/RelationBuilder.scala b/src/main/scala/leon/termination/RelationBuilder.scala index 94e10516cd646cb980d57cf9ede29d2cbb9d32fe..a5e75a622963df60a820acff836dadd5b78ce875 100644 --- a/src/main/scala/leon/termination/RelationBuilder.scala +++ b/src/main/scala/leon/termination/RelationBuilder.scala @@ -10,7 +10,7 @@ import purescala.Common._ import scala.collection.mutable.{Map => MutableMap} final case class Relation(funDef: FunDef, path: Seq[Expr], call: FunctionInvocation) { - override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.funDef.id + call.args.mkString("(",",",")") + ")" + override def toString : String = "Relation(" + funDef.id + "," + path + ", " + call.tfd.fd.id + call.args.mkString("(",",",")") + ")" } class RelationBuilder { diff --git a/src/main/scala/leon/termination/RelationProcessor.scala b/src/main/scala/leon/termination/RelationProcessor.scala index 7cb0b911126b5367c135acb5f935e0612ddb32d0..0fcd3a538526326f29cd0ddf93e5d6175ee2058d 100644 --- a/src/main/scala/leon/termination/RelationProcessor.scala +++ b/src/main/scala/leon/termination/RelationProcessor.scala @@ -22,12 +22,12 @@ class RelationProcessor(checker: TerminationChecker, val formulas = problem.funDefs.map({ funDef => funDef -> relationBuilder.run(funDef).collect({ - case Relation(_, path, FunctionInvocation(fd, args)) if problem.funDefs(fd) => + case Relation(_, path, FunctionInvocation(tfd, args)) if problem.funDefs(tfd.fd) => val (e1, e2) = (Tuple(funDef.args.map(_.toVariable)), Tuple(args)) def constraint(expr: Expr) = Implies(And(path.toSeq), expr) val greaterThan = relationComparator.sizeDecreasing(e1, e2) val greaterEquals = relationComparator.softDecreasing(e1, e2) - (fd, (constraint(greaterThan), constraint(greaterEquals))) + (tfd.fd, (constraint(greaterThan), constraint(greaterEquals))) }) }) diff --git a/src/main/scala/leon/termination/SimpleTerminationChecker.scala b/src/main/scala/leon/termination/SimpleTerminationChecker.scala index d8c5fce249bd4e2f575e514f29f03623d7134700..08fd6bea9ec56fc895ce35049ead39ce76d3a4d3 100644 --- a/src/main/scala/leon/termination/SimpleTerminationChecker.scala +++ b/src/main/scala/leon/termination/SimpleTerminationChecker.scala @@ -91,7 +91,7 @@ class SimpleTerminationChecker(context: LeonContext, program: Program) extends T oe.map { e => functionCallsOf( simplifyLets( - matchToIfThenElse(e))).filter(_.funDef == funDef) + matchToIfThenElse(e))).filter(_.tfd.fd == funDef) } getOrElse Set.empty[FunctionInvocation] } diff --git a/src/main/scala/leon/termination/StructuralSize.scala b/src/main/scala/leon/termination/StructuralSize.scala index e10a30c2824d2d06d54b8489577b66cbcc7d2f75..666d990adcb6684cedbba7b2685f6be395f02fbe 100644 --- a/src/main/scala/leon/termination/StructuralSize.scala +++ b/src/main/scala/leon/termination/StructuralSize.scala @@ -7,15 +7,16 @@ import purescala.TypeTrees._ import purescala.Definitions._ import purescala.Common._ -class StructuralSize { +class StructuralSize() { import scala.collection.mutable.{Map => MutableMap} - private val sizeFunctionCache : MutableMap[TypeTree, FunDef] = MutableMap() + private val sizeFunctionCache : MutableMap[TypeTree, TypedFunDef] = MutableMap() + def size(expr: Expr) : Expr = { def funDef(tpe: TypeTree, cases: => Seq[MatchCase]) = { // we want to reuse generic size functions for sub-types val argumentType = tpe match { - case CaseClassType(cd) if cd.parent.isDefined => classDefToClassType(cd.parent.get) + case CaseClassType(cd, tpes) if cd.parent.isDefined => classDefToClassType(cd.parent.get.classDef, tpes) case _ => tpe } @@ -23,8 +24,9 @@ class StructuralSize { case Some(fd) => fd case None => val argument = VarDecl(FreshIdentifier("x"), argumentType) - val fd = new FunDef(FreshIdentifier("size", true), Int32Type, Seq(argument)) - sizeFunctionCache(argumentType) = fd + val fd = new FunDef(FreshIdentifier("size", true), Nil, Int32Type, Seq(argument)) + val tfd = fd.typed(Nil) + sizeFunctionCache(argumentType) = tfd val body = simplifyLets(matchToIfThenElse(MatchExpr(argument.toVariable, cases))) val postId = FreshIdentifier("res", false).setType(Int32Type) @@ -34,25 +36,28 @@ class StructuralSize { fd.body = Some(body) fd.postcondition = Some(postId, postcondition) - fd + + tfd } } - def caseClassType2MatchCase(_c: ClassTypeDef): MatchCase = { - val c = _c.asInstanceOf[CaseClassDef] // required by leon framework - val arguments = c.fields.map(f => f -> f.id.freshen) - val argumentPatterns = arguments.map(p => WildcardPattern(Some(p._2))) - val sizes = arguments.map(p => size(Variable(p._2))) - val result = sizes.foldLeft[Expr](IntLiteral(1))(Plus(_,_)) - SimpleCase(CaseClassPattern(None, c, argumentPatterns), result) + def caseClassType2MatchCase(ct: ClassType): MatchCase = ct match { + case cct: CaseClassType => + val arguments = cct.fields.map(f => f -> f.id.freshen) + val argumentPatterns = arguments.map(p => WildcardPattern(Some(p._2))) + val sizes = arguments.map(p => size(Variable(p._2))) + val result = sizes.foldLeft[Expr](IntLiteral(1))(Plus(_,_)) + SimpleCase(CaseClassPattern(None, cct, argumentPatterns), result) + case _ => + sys.error("woot?") } expr.getType match { case a: AbstractClassType => - val sizeFd = funDef(a, a.classDef.knownChildren map caseClassType2MatchCase) + val sizeFd = funDef(a, a.knownCCDescendents.map(caseClassType2MatchCase)) FunctionInvocation(sizeFd, Seq(expr)) case c: CaseClassType => - val sizeFd = funDef(c, Seq(caseClassType2MatchCase(c.classDef))) + val sizeFd = funDef(c, Seq(caseClassType2MatchCase(c))) FunctionInvocation(sizeFd, Seq(expr)) case TupleType(argTypes) => argTypes.zipWithIndex.map({ case (_, index) => size(TupleSelect(expr, index + 1)) @@ -61,7 +66,7 @@ class StructuralSize { } } - def defs : Set[FunDef] = Set(sizeFunctionCache.values.toSeq : _*) + def defs : Set[FunDef] = sizeFunctionCache.values.map(_.fd).toSet } // vim: set ts=4 sw=4 et: diff --git a/src/main/scala/leon/testgen/CallGraph.scala b/src/main/scala/leon/testgen/CallGraph.scala index a8e4f367e26989e7c2351f3a9caeedd8d949d71a..997b6a667a94c3b1d1984341e79fc227d9d3be7e 100644 --- a/src/main/scala/leon/testgen/CallGraph.scala +++ b/src/main/scala/leon/testgen/CallGraph.scala @@ -80,9 +80,9 @@ class CallGraph(val program: Program) { var augmentedGraph = graph graph.foreach{ - case (point@ExpressionPoint(FunctionInvocation(fd, args), _), edges) => { - val newPoint = FunctionStart(fd) - val newTransition = TransitionLabel(BooleanLiteral(true), fd.args.zip(args).map{ case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap) + case (point@ExpressionPoint(FunctionInvocation(tfd, args), _), edges) => { + val newPoint = FunctionStart(tfd.fd) + val newTransition = TransitionLabel(BooleanLiteral(true), tfd.args.zip(args).map{ case (VarDecl(id, _), arg) => (id.toVariable, arg) }.toMap) augmentedGraph += (point -> (edges + ((newPoint, newTransition)))) } case _ => ; diff --git a/src/main/scala/leon/testgen/TestGeneration.scala b/src/main/scala/leon/testgen/TestGeneration.scala index ce6d2fa6df8661678cc3a8e79c990b8854f235fe..ea62e67e6b4a51afef74be56f3affe115ad8687d 100644 --- a/src/main/scala/leon/testgen/TestGeneration.scala +++ b/src/main/scala/leon/testgen/TestGeneration.scala @@ -33,7 +33,7 @@ class TestGeneration(context : LeonContext) { val topFunDef = program.definedFunctions.find(fd => isMain(fd)).get - val testFun = new FunDef(FreshIdentifier("test"), UnitType, Seq()) + val testFun = new FunDef(FreshIdentifier("test"), Nil, UnitType, Seq()) val funInvocs = testcases.map(testcase => { val params = topFunDef.args val args = topFunDef.args.map{ @@ -42,12 +42,12 @@ class TestGeneration(context : LeonContext) { case None => simplestValue(tpe) } } - FunctionInvocation(topFunDef, args) + FunctionInvocation(topFunDef.typed, args) }).toSeq testFun.body = Some(Block(funInvocs, UnitLiteral)) - val Program(id, ObjectDef(objId, defs, invariants)) = program - val testProgram = Program(id, ObjectDef(objId, testFun +: defs , invariants)) + val Program(id, ModuleDef(objId, defs, invariants)) = program + val testProgram = Program(id, ModuleDef(objId, testFun +: defs , invariants)) testProgram.writeScalaFile("TestGen.scalax") reporter.info("Running from waypoint with the following testcases:\n") diff --git a/src/main/scala/leon/utils/SubtypingPhase.scala b/src/main/scala/leon/utils/SubtypingPhase.scala index 37230604196a655cf68acdbe23eb6066e4bfcd6f..434db7e24cbe46d950cf6167d1f17ee8ee97501d 100644 --- a/src/main/scala/leon/utils/SubtypingPhase.scala +++ b/src/main/scala/leon/utils/SubtypingPhase.scala @@ -20,7 +20,7 @@ object SubtypingPhase extends LeonPhase[Program, Program] { fd.precondition = { val argTypesPreconditions = fd.args.flatMap(arg => arg.tpe match { - case cct@CaseClassType(cd) => Seq(CaseClassInstanceOf(cd, arg.id.toVariable)) + case cct : CaseClassType => Seq(CaseClassInstanceOf(cct, arg.id.toVariable)) case _ => Seq() }) argTypesPreconditions match { @@ -33,16 +33,16 @@ object SubtypingPhase extends LeonPhase[Program, Program] { } fd.postcondition = fd.returnType match { - case cct@CaseClassType(cd) => { + case cct : CaseClassType => { fd.postcondition match { case Some((id, p)) => - Some((id, And(CaseClassInstanceOf(cd, Variable(id)), p))) + Some((id, And(CaseClassInstanceOf(cct, Variable(id)), p))) case None => val resId = FreshIdentifier("res").setType(cct) - Some((resId, CaseClassInstanceOf(cd, Variable(resId)))) + Some((resId, CaseClassInstanceOf(cct, Variable(resId)))) } } case _ => fd.postcondition diff --git a/src/main/scala/leon/utils/UnitElimination.scala b/src/main/scala/leon/utils/UnitElimination.scala index dfce1fbfea895ff7e48d164f5c249024a3365d7f..18953969e174a9094e4ef1589703ec134c8b8810 100644 --- a/src/main/scala/leon/utils/UnitElimination.scala +++ b/src/main/scala/leon/utils/UnitElimination.scala @@ -25,7 +25,7 @@ object UnitElimination extends TransformationPhase { //first introduce new signatures without Unit parameters allFuns.foreach(fd => { if(fd.returnType != UnitType && fd.args.exists(vd => vd.tpe == UnitType)) { - val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPos(fd) + val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.tparams, fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPos(fd) freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. freshFunDef.addAnnotation(fd.annotations.toSeq:_*) @@ -43,9 +43,9 @@ object UnitElimination extends TransformationPhase { Seq(newFd) }) - val Program(id, ObjectDef(objId, _, invariants)) = pgm + val Program(id, ModuleDef(objId, _, invariants)) = pgm val allClasses = pgm.definedClasses - Program(id, ObjectDef(objId, allClasses ++ newFuns, invariants)) + Program(id, ModuleDef(objId, allClasses ++ newFuns, invariants)) } private def simplifyType(tpe: TypeTree): TypeTree = tpe match { @@ -61,9 +61,9 @@ object UnitElimination extends TransformationPhase { private def removeUnit(expr: Expr): Expr = { assert(expr.getType != UnitType) expr match { - case fi@FunctionInvocation(fd, args) => { + case fi@FunctionInvocation(tfd, args) => { val newArgs = args.filterNot(arg => arg.getType == UnitType) - FunctionInvocation(fun2FreshFun(fd), newArgs).setPos(fi) + FunctionInvocation(fun2FreshFun(tfd.fd).typed(tfd.tps), newArgs).setPos(fi) } case t@Tuple(args) => { val TupleType(tpes) = t.getType @@ -101,7 +101,7 @@ object UnitElimination extends TransformationPhase { removeUnit(b) else { val (newFd, rest) = if(fd.args.exists(vd => vd.tpe == UnitType)) { - val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPos(fd) + val freshFunDef = new FunDef(FreshIdentifier(fd.id.name), fd.tparams, fd.returnType, fd.args.filterNot(vd => vd.tpe == UnitType)).setPos(fd) freshFunDef.addAnnotation(fd.annotations.toSeq:_*) freshFunDef.precondition = fd.precondition //TODO: maybe removing unit from the conditions as well.. freshFunDef.postcondition = fd.postcondition//TODO: maybe removing unit from the conditions as well.. diff --git a/src/main/scala/leon/verification/DefaultTactic.scala b/src/main/scala/leon/verification/DefaultTactic.scala index f2ec7534ac4f08607c7324f91598d299320c6f16..e9463b03f4f63dff63b56c247ae3ea09dce864f7 100644 --- a/src/main/scala/leon/verification/DefaultTactic.scala +++ b/src/main/scala/leon/verification/DefaultTactic.scala @@ -48,16 +48,17 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { withPrec } - Seq(new VerificationCondition(theExpr, functionDefinition, VCKind.Postcondition, this)) + Seq(new VerificationCondition(theExpr, functionDefinition, VCKind.Postcondition, this).setPos(post)) } } def generatePreconditions(function: FunDef) : Seq[VerificationCondition] = { val toRet = if(function.hasBody) { - val cleanBody = expandLets(matchToIfThenElse(function.body.get)) + val pre = matchToIfThenElse(function.body.get) + val cleanBody = expandLets(pre) val allPathConds = collectWithPathCondition((t => t match { - case FunctionInvocation(fd, _) if(fd.hasPrecondition) => true + case FunctionInvocation(tfd, _) if(tfd.hasPrecondition) => true case _ => false }), cleanBody) @@ -70,20 +71,13 @@ class DefaultTactic(reporter: Reporter) extends Tactic(reporter) { allPathConds.map(pc => { val path : Seq[Expr] = pc._1 val fi = pc._2.asInstanceOf[FunctionInvocation] - val FunctionInvocation(fd, args) = fi - val prec : Expr = freshenLocals(matchToIfThenElse(fd.precondition.get)) - val newLetIDs = fd.args.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)) - val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) + val FunctionInvocation(tfd, args) = fi + val prec : Expr = freshenLocals(matchToIfThenElse(tfd.precondition.get)) + val newLetIDs = tfd.args.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)) + val substMap = Map[Expr,Expr]((tfd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) val newBody : Expr = replace(substMap, prec) val newCall : Expr = (newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e)) - //if(fd.fromLoop) - // new VerificationCondition( - // withPrecIfDefined(path, newCall), - // fd.parent.get, - // if(fd == function) VCKind.InvariantInd else VCKind.InvariantInit, - // this.asInstanceOf[DefaultTactic]).setPosInfo(fd) - //else new VerificationCondition( withPrecIfDefined(path, newCall), function, diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index 71d7a31c5568d3ffa8cbdce2c84b5a83bff15689..91b92d981b45bab0b05fb06d8273a39ba6ffe276 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -13,22 +13,22 @@ class InductionTactic(reporter: Reporter) extends DefaultTactic(reporter) { override val description = "Induction tactic for suitable functions" override val shortDescription = "induction" - private def firstAbsClassDef(args: VarDecls) : Option[(AbstractClassDef, VarDecl)] = { + private def firstAbsClassDef(args: Seq[VarDecl]) : Option[(AbstractClassDef, VarDecl)] = { val filtered = args.filter(arg => arg.getType match { - case AbstractClassType(_) => true + case AbstractClassType(_, _) => true case _ => false }) if (filtered.size == 0) None else (filtered.head.getType match { - case AbstractClassType(classDef) => Some((classDef, filtered.head)) + case AbstractClassType(classDef, _) => Some((classDef, filtered.head)) case _ => scala.sys.error("This should not happen.") }) } - private def selectorsOfParentType(parentType: ClassType, ccd: CaseClassDef, expr: Expr) : Seq[Expr] = { - val childrenOfSameType = ccd.fields.filter(field => field.getType == parentType) + private def selectorsOfParentType(parentType: ClassType, cct: CaseClassType, expr: Expr) : Seq[Expr] = { + val childrenOfSameType = cct.fields.filter(_.tpe == parentType) for (field <- childrenOfSameType) yield { - CaseClassSelector(ccd, expr, field.id).setType(parentType) + CaseClassSelector(cct, expr, field.id) } } @@ -41,36 +41,33 @@ class InductionTactic(reporter: Reporter) extends DefaultTactic(reporter) { val optPost = funDef.postcondition val body = matchToIfThenElse(funDef.body.get) val argAsVar = arg.toVariable + val parentType = classDefToClassType(classDef) optPost match { case None => Seq.empty case Some((pid, post)) => - val children = classDef.knownChildren - val conditionsForEachChild = (for (child <- classDef.knownChildren) yield (child match { - case ccd @ CaseClassDef(id, prnt, vds) => - val selectors = selectorsOfParentType(classDefToClassType(classDef), ccd, argAsVar) + for (cct <- parentType.knownCCDescendents) yield { + val selectors = selectorsOfParentType(parentType, cct, argAsVar) // if no subtrees of parent type, assert property for base case - val resFresh = FreshIdentifier("result", true).setType(body.getType) - val bodyAndPostForArg = Let(resFresh, body, replace(Map(Variable(pid) -> Variable(resFresh)), matchToIfThenElse(post))) - val withPrec = if (prec.isEmpty) bodyAndPostForArg else Implies(matchToIfThenElse(prec.get), bodyAndPostForArg) - - val conditionForChild = - if (selectors.size == 0) + val resFresh = FreshIdentifier("result", true).setType(body.getType) + val bodyAndPostForArg = Let(resFresh, body, replace(Map(Variable(pid) -> Variable(resFresh)), matchToIfThenElse(post))) + val withPrec = if (prec.isEmpty) bodyAndPostForArg else Implies(matchToIfThenElse(prec.get), bodyAndPostForArg) + + val conditionForChild = + if (selectors.size == 0) + withPrec + else { + val inductiveHypothesis = (for (sel <- selectors) yield { + val resFresh = FreshIdentifier("result", true).setType(body.getType) + val bodyAndPost = Let(resFresh, replace(Map(argAsVar -> sel), body), replace(Map(Variable(pid) -> Variable(resFresh), argAsVar -> sel), matchToIfThenElse(post))) + val withPrec = if (prec.isEmpty) bodyAndPost else Implies(replace(Map(argAsVar -> sel), matchToIfThenElse(prec.get)), bodyAndPost) withPrec - else { - val inductiveHypothesis = (for (sel <- selectors) yield { - val resFresh = FreshIdentifier("result", true).setType(body.getType) - val bodyAndPost = Let(resFresh, replace(Map(argAsVar -> sel), body), replace(Map(Variable(pid) -> Variable(resFresh), argAsVar -> sel), matchToIfThenElse(post))) - val withPrec = if (prec.isEmpty) bodyAndPost else Implies(replace(Map(argAsVar -> sel), matchToIfThenElse(prec.get)), bodyAndPost) - withPrec - }) - Implies(And(inductiveHypothesis), withPrec) - } - new VerificationCondition(Implies(CaseClassInstanceOf(ccd, argAsVar), conditionForChild), funDef, VCKind.Postcondition, this) - case _ => scala.sys.error("Abstract class has non-case class subtype.") - })) - conditionsForEachChild + }) + Implies(And(inductiveHypothesis), withPrec) + } + new VerificationCondition(Implies(CaseClassInstanceOf(cct, argAsVar), conditionForChild), funDef, VCKind.Postcondition, this) + } } case None => @@ -84,10 +81,11 @@ class InductionTactic(reporter: Reporter) extends DefaultTactic(reporter) { firstAbsClassDef(function.args) match { case Some((classDef, arg)) => { val toRet = if(function.hasBody) { + val parentType = classDefToClassType(classDef) val cleanBody = expandLets(matchToIfThenElse(function.body.get)) val allPathConds = collectWithPathCondition((t => t match { - case FunctionInvocation(fd, _) if(fd.hasPrecondition) => true + case FunctionInvocation(tfd, _) if(tfd.hasPrecondition) => true case _ => false }), cleanBody) @@ -100,42 +98,38 @@ class InductionTactic(reporter: Reporter) extends DefaultTactic(reporter) { val conditionsForAllPaths : Seq[Seq[VerificationCondition]] = allPathConds.map(pc => { val path : Seq[Expr] = pc._1 val fi = pc._2.asInstanceOf[FunctionInvocation] - val FunctionInvocation(fd, args) = fi - - val conditionsForEachChild = (for (child <- classDef.knownChildren) yield (child match { - case ccd @ CaseClassDef(id, prnt, vds) => { - val argAsVar = arg.toVariable - val selectors = selectorsOfParentType(classDefToClassType(classDef), ccd, argAsVar) - - val prec : Expr = freshenLocals(matchToIfThenElse(fd.precondition.get)) - val newLetIDs = fd.args.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)) - val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) - val newBody : Expr = replace(substMap, prec) - val newCall : Expr = (newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e)) - - val toProve = withPrec(path, newCall) - - val conditionForChild = - if (selectors.isEmpty) - toProve - else { - val inductiveHypothesis = (for (sel <- selectors) yield { - val prec : Expr = freshenLocals(matchToIfThenElse(fd.precondition.get)) - val newLetIDs = fd.args.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)) - val substMap = Map[Expr,Expr]((fd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) - val newBody : Expr = replace(substMap, prec) - val newCall : Expr = (newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e)) - - val toReplace = withPrec(path, newCall) - replace(Map(argAsVar -> sel), toReplace) - }) - Implies(And(inductiveHypothesis), toProve) - } - new VerificationCondition(Implies(CaseClassInstanceOf(ccd, argAsVar), conditionForChild), function, VCKind.Precondition, this).setPos(fi) - } - case _ => scala.sys.error("Abstract class has non-case class subtype") - })) - conditionsForEachChild + val FunctionInvocation(tfd, args) = fi + + for (cct <- parentType.knownCCDescendents) yield { + val argAsVar = arg.toVariable + val selectors = selectorsOfParentType(parentType, cct, argAsVar) + + val prec : Expr = freshenLocals(matchToIfThenElse(tfd.precondition.get)) + val newLetIDs = tfd.args.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)) + val substMap = Map[Expr,Expr]((tfd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) + val newBody : Expr = replace(substMap, prec) + val newCall : Expr = (newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e)) + + val toProve = withPrec(path, newCall) + + val conditionForChild = + if (selectors.isEmpty) + toProve + else { + val inductiveHypothesis = (for (sel <- selectors) yield { + val prec : Expr = freshenLocals(matchToIfThenElse(tfd.precondition.get)) + val newLetIDs = tfd.args.map(a => FreshIdentifier("arg_" + a.id.name, true).setType(a.tpe)) + val substMap = Map[Expr,Expr]((tfd.args.map(_.toVariable) zip newLetIDs.map(Variable(_))) : _*) + val newBody : Expr = replace(substMap, prec) + val newCall : Expr = (newLetIDs zip args).foldRight(newBody)((iap, e) => Let(iap._1, iap._2, e)) + + val toReplace = withPrec(path, newCall) + replace(Map(argAsVar -> sel), toReplace) + }) + Implies(And(inductiveHypothesis), toProve) + } + new VerificationCondition(Implies(CaseClassInstanceOf(cct, argAsVar), conditionForChild), function, VCKind.Precondition, this).setPos(fi) + } }).toSeq conditionsForAllPaths.flatten diff --git a/src/main/scala/leon/xlang/EpsilonElimination.scala b/src/main/scala/leon/xlang/EpsilonElimination.scala index 1c92fc659204e07354f19b5482af74e111b91e77..771dfd022e62c38e7b8d9bc3a036daae1677cf67 100644 --- a/src/main/scala/leon/xlang/EpsilonElimination.scala +++ b/src/main/scala/leon/xlang/EpsilonElimination.scala @@ -23,12 +23,12 @@ object EpsilonElimination extends TransformationPhase { val newBody = postMap{ case eps@Epsilon(pred) => val freshName = FreshIdentifier("epsilon") - val newFunDef = new FunDef(freshName, eps.getType, Seq()) + val newFunDef = new FunDef(freshName, Nil, eps.getType, Seq()) val epsilonVar = EpsilonVariable(eps.getPos) val resId = FreshIdentifier("res").setType(eps.getType) val postcondition = replace(Map(epsilonVar -> Variable(resId)), pred) newFunDef.postcondition = Some((resId, postcondition)) - Some(LetDef(newFunDef, FunctionInvocation(newFunDef, Seq()))) + Some(LetDef(newFunDef, FunctionInvocation(newFunDef.typed, Seq()))) case _ => None diff --git a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala index d77bf920805be4de79aa733169e918067421e7bf..db5bd025af9339c4d287f1f287ad4d6ac120a505 100644 --- a/src/main/scala/leon/xlang/ImperativeCodeElimination.scala +++ b/src/main/scala/leon/xlang/ImperativeCodeElimination.scala @@ -11,6 +11,7 @@ import leon.purescala.Trees._ import leon.purescala.Extractors._ import leon.purescala.TypeTrees._ import leon.purescala.TreeOps._ +import leon.purescala.TypeTreeOps._ import leon.xlang.Trees._ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef])] { @@ -155,12 +156,12 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef val modifiedVars2WhileFunVars = modifiedVars.zip(whileFunVars).toMap val whileFunVarDecls = whileFunVars.map(id => VarDecl(id, id.getType)) val whileFunReturnType = if(whileFunVars.size == 1) whileFunVars.head.getType else TupleType(whileFunVars.map(_.getType)) - val whileFunDef = new FunDef(FreshIdentifier(parent.id.name), whileFunReturnType, whileFunVarDecls).setPos(wh) + val whileFunDef = new FunDef(FreshIdentifier(parent.id.name), Nil, whileFunReturnType, whileFunVarDecls).setPos(wh) wasLoop += whileFunDef val whileFunCond = condRes val whileFunRecursiveCall = replaceNames(condFun, - bodyScope(FunctionInvocation(whileFunDef, modifiedVars.map(id => condBodyFun(id).toVariable)).setPos(wh))) + bodyScope(FunctionInvocation(whileFunDef.typed, modifiedVars.map(id => condBodyFun(id).toVariable)).setPos(wh))) val whileFunBaseCase = (if(whileFunVars.size == 1) condFun.get(modifiedVars.head).getOrElse(whileFunVars.head).toVariable @@ -199,7 +200,7 @@ object ImperativeCodeElimination extends LeonPhase[Program, (Program, Set[FunDef LetDef( whileFunDef, Let(tupleId, - FunctionInvocation(whileFunDef, modifiedVars.map(_.toVariable)).setPos(wh), + FunctionInvocation(whileFunDef.typed, modifiedVars.map(_.toVariable)).setPos(wh), if(finalVars.size == 1) Let(finalVars.head, tupleId.toVariable, body) else diff --git a/src/test/resources/regression/verification/purescala/invalid/Generics.scala b/src/test/resources/regression/verification/purescala/invalid/Generics.scala new file mode 100644 index 0000000000000000000000000000000000000000..0afd38848fe180d3112c1730ac0a2f81baa155d1 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Generics.scala @@ -0,0 +1,13 @@ +import leon.Utils._ + +object Generics1 { + abstract class List[T] + case class Cons[A](head: A, tail: List[A]) extends List[A] + case class Nil[B]() extends List[B] + + def size[T](l: List[T]): Int = (l match { + case Nil() => 0 + case Cons(h, t) => 1+size(t) + })ensuring { _ > 0 } + +} diff --git a/src/test/resources/regression/verification/purescala/invalid/Generics2.scala b/src/test/resources/regression/verification/purescala/invalid/Generics2.scala new file mode 100644 index 0000000000000000000000000000000000000000..ae8f9f458858b23687167372a2ef65cabcc5b43f --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Generics2.scala @@ -0,0 +1,23 @@ +import leon.Utils._ + +object Generics1 { + abstract class List[T] + case class Cons[A](head: A, tail: List[A]) extends List[A] + case class Nil[B]() extends List[B] + + def size[T](l: List[T]): Int = (l match { + case Nil() => 0 + case Cons(h, t) => 1+size(t) + })ensuring { _ >= 0 } + + def foo[T](l: List[T]): List[T] = { + require(size(l) < 2) + + l + } + + def bar(l: List[Int]) = { + foo(l) + } + +} diff --git a/src/test/resources/regression/verification/purescala/valid/Generics.scala b/src/test/resources/regression/verification/purescala/valid/Generics.scala new file mode 100644 index 0000000000000000000000000000000000000000..b930b5f2c9932d06347bd40d89004f8ef3f83c7b --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Generics.scala @@ -0,0 +1,21 @@ +import leon.Utils._ + +object Generics1 { + abstract class List[T] + case class Cons[A](head: A, tail: List[A]) extends List[A] + case class Nil[B]() extends List[B] + + def size[T](l: List[T]): Int = (l match { + case Nil() => 0 + case Cons(h, t) => 1+size(t) + })ensuring { _ >= 0 } + + def content[T](l: List[T]): Set[T] = l match { + case Nil() => Set() + case Cons(h, t) => Set(h) ++ content(t) + } + + def insert[T](a: T, l: List[T]): List[T] = { + Cons(a, l) + } ensuring { res => (size(res) == size(l) + 1) && (content(res) == content(l) ++ Set(a))} +} diff --git a/src/test/resources/regression/verification/purescala/valid/Generics2.scala b/src/test/resources/regression/verification/purescala/valid/Generics2.scala new file mode 100644 index 0000000000000000000000000000000000000000..f1078cca5a0ef04d4fa76ecfaab6f161acaaa92f --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Generics2.scala @@ -0,0 +1,25 @@ +import leon.Utils._ + +object Generics1 { + abstract class List[T] + case class Cons[A](head: A, tail: List[A]) extends List[A] + case class Nil[B]() extends List[B] + + def size[T](l: List[T]): Int = (l match { + case Nil() => 0 + case Cons(h, t) => 1+size(t) + })ensuring { _ >= 0 } + + def content[T](l: List[T]): Set[T] = l match { + case Nil() => Set() + case Cons(h, t) => Set(h) ++ content(t) + } + + def insert[T](a: T, l: List[T]): List[T] = { + Cons(a, l) + } ensuring { res => (size(res) == size(l) + 1) && (content(res) == content(l) ++ Set(a))} + + def insertInt(a: Int, l: List[Int]): List[Int] = { + insert(a,l) + } ensuring { res => (size(res) == size(l) + 1) && (content(res) == content(l) ++ Set(a))} +} diff --git a/src/test/scala/leon/test/condabd/EvaluationTest.scala b/src/test/scala/leon/test/condabd/EvaluationTest.scala index c6b8a4ed86fd449f0a4fba781b44edc66a507878..b0bf490eab5016fc1910475a1a30799a40954af7 100644 --- a/src/test/scala/leon/test/condabd/EvaluationTest.scala +++ b/src/test/scala/leon/test/condabd/EvaluationTest.scala @@ -159,7 +159,7 @@ class EvaluationTest extends FunSuite { postMap(replaceFun)( program.definedFunctions.find(_.id.name == "testFun2").get.body.get) - val evaluationStrategy = new CodeGenEvaluationStrategy(program, funDef, sctx.context, 500) + val evaluationStrategy = new CodeGenEvaluationStrategy(program, funDef.typed, sctx.context, 500) val candidates = IndexedSeq(body1, body2) map (b => Output(b, 0)) @@ -219,13 +219,13 @@ class EvaluationTest extends FunSuite { import TreeOps._ val newFunId = FreshIdentifier("tempIntroducedFunction") - val newFun = new FunDef(newFunId, funDef.returnType, funDef.args) + val newFun = new FunDef(newFunId, Nil, funDef.returnType, funDef.args) newFun.precondition = funDef.precondition newFun.postcondition = funDef.postcondition def replaceFunDef(expr: Expr) = expr match { case FunctionInvocation(`funDef`, args) => - Some(FunctionInvocation(newFun, args)) + Some(FunctionInvocation(newFun.typed, args)) case _ => None } val newBody = postMap(replaceFunDef)(expr) @@ -251,7 +251,7 @@ class EvaluationTest extends FunSuite { val params = CodeGenParams(maxFunctionInvocations = 500, checkContracts = true) val evaluator = new CodeGenEvaluator(sctx.context, - program.copy(mainObject = program.mainObject.copy(defs = program.mainObject.defs ++ pairs.map(_._2))) + program.copy(mainModule = program.mainModule.copy(defs = program.mainModule.defs ++ pairs.map(_._2))) , params) val eval1 = (for (ind <- 0 until inputExamples.size) yield { @@ -278,13 +278,13 @@ class EvaluationTest extends FunSuite { import TreeOps._ val newFunId = FreshIdentifier("tempIntroducedFunction") - val newFun = new FunDef(newFunId, funDef.returnType, funDef.args) + val newFun = new FunDef(newFunId, Nil, funDef.returnType, funDef.args) newFun.precondition = funDef.precondition newFun.postcondition = funDef.postcondition def replaceFunDef(expr: Expr) = expr match { case FunctionInvocation(`funDef`, args) => - Some(FunctionInvocation(newFun, args)) + Some(FunctionInvocation(newFun.typed, args)) case _ => None } val newBody = postMap(replaceFunDef)(expr) @@ -305,7 +305,7 @@ class EvaluationTest extends FunSuite { val params = CodeGenParams(maxFunctionInvocations = 500, checkContracts = true) val evaluator = new CodeGenEvaluator(sctx.context, - program.copy(mainObject = program.mainObject.copy(defs = program.mainObject.defs ++ pairs.map(_._2))) + program.copy(mainModule = program.mainModule.copy(defs = program.mainModule.defs ++ pairs.map(_._2))) , params) val eval1 = (for (ind <- 0 until inputExamples.size) yield { diff --git a/src/test/scala/leon/test/condabd/VerifierTest.scala b/src/test/scala/leon/test/condabd/VerifierTest.scala index b5d5c3fdac7759f5c53bf5c2347f11675939451c..dc59a28912ab05a538cf90b3d389a56c0fa7611a 100644 --- a/src/test/scala/leon/test/condabd/VerifierTest.scala +++ b/src/test/scala/leon/test/condabd/VerifierTest.scala @@ -61,7 +61,7 @@ class VerifierTest extends FunSpec { val verifier = new Verifier(timeoutSolver, problem) - assert( verifier.analyzeFunction(funDef)._1 ) + assert( verifier.analyzeFunction(funDef.typed)._1 ) verifier.solver.free() } @@ -75,7 +75,7 @@ class VerifierTest extends FunSpec { val verifier = new Verifier(timeoutSolver, problem) - assert( verifier.analyzeFunction(funDef)._1 ) + assert( verifier.analyzeFunction(funDef.typed)._1 ) verifier.solver.free() } @@ -89,7 +89,7 @@ class VerifierTest extends FunSpec { val verifier = new Verifier(timeoutSolver, problem) - assert( ! verifier.analyzeFunction(funDef)._1 ) + assert( ! verifier.analyzeFunction(funDef.typed)._1 ) verifier.solver.free() } } @@ -132,7 +132,7 @@ class VerifierTest extends FunSpec { val verifier = new RelaxedVerifier(timeoutSolver, problem) - assert( verifier.analyzeFunction(funDef)._1 ) + assert( verifier.analyzeFunction(funDef.typed)._1 ) verifier.solver.free() } @@ -146,7 +146,7 @@ class VerifierTest extends FunSpec { val verifier = new Verifier(timeoutSolver, problem) - assert( verifier.analyzeFunction(funDef)._1 ) + assert( verifier.analyzeFunction(funDef.typed)._1 ) verifier.solver.free() } } diff --git a/src/test/scala/leon/test/condabd/enumeration/EnumeratorTest.scala b/src/test/scala/leon/test/condabd/enumeration/EnumeratorTest.scala index 7881846ae98f6915db2cc9baf85ad7db57d643cf..3bce2e5a7521d73fd2964a70dbaabcb39a434f89 100644 --- a/src/test/scala/leon/test/condabd/enumeration/EnumeratorTest.scala +++ b/src/test/scala/leon/test/condabd/enumeration/EnumeratorTest.scala @@ -90,14 +90,14 @@ class EnumeratorTest extends JUnitSuite { val nilAbstractClassDef = program.definedClasses.find(_.id.name == "Nil"). get.asInstanceOf[CaseClassDef] val listVal = funDef.args.head.toVariable - + val variableRefiner = new VariableRefinerStructure(loader.directSubclassesMap, loader.variableDeclarations, loader.classMap, reporter) val (refined, newDeclarations) = variableRefiner.refine( - CaseClassInstanceOf(nilAbstractClassDef, listVal), BooleanLiteral(true), allDeclarations) + CaseClassInstanceOf(CaseClassType(nilAbstractClassDef, Nil), listVal), BooleanLiteral(true), allDeclarations) assertTrue(refined) assert(allDeclarations.size + 2 == newDeclarations.size) @@ -412,4 +412,4 @@ class EnumeratorTest extends JUnitSuite { } } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/test/condabd/insynth/reconstruction/CodeGeneratorTest.scala b/src/test/scala/leon/test/condabd/insynth/reconstruction/CodeGeneratorTest.scala index c5a7f5bcbb27817d0b93c9b7fb80020eec79a9c0..961b18b18c870aa1f8b613740db20f53febc6a51 100644 --- a/src/test/scala/leon/test/condabd/insynth/reconstruction/CodeGeneratorTest.scala +++ b/src/test/scala/leon/test/condabd/insynth/reconstruction/CodeGeneratorTest.scala @@ -25,7 +25,7 @@ class CodeGeneratorTest extends JUnitSuite { val codeGenResult = codeGenerator(constructBooleanToIntIntermediateLambda.head) assertEquals( - FunctionInvocation(functionBoolToIntFunDef, List(BooleanLiteral(false))), + FunctionInvocation(functionBoolToIntFunDef.typed, List(BooleanLiteral(false))), codeGenResult ) } @@ -37,24 +37,24 @@ class CodeGeneratorTest extends JUnitSuite { for (intermediateTree <- constructThreeParFunctionIntermediateLambda(4)) yield codeGenerator(intermediateTree) - val baseCase = FunctionInvocation(threeParFunctionDef, List(IntLiteral(0), IntLiteral(0), BooleanLiteral(false))) + val baseCase = FunctionInvocation(threeParFunctionDef.typed, List(IntLiteral(0), IntLiteral(0), BooleanLiteral(false))) val message = "Generated:\n" + generated.mkString("\n") assertTrue(baseCase + " not found. " + message, generated contains baseCase) - val oneLevCase1 = FunctionInvocation(threeParFunctionDef, List(baseCase, IntLiteral(0), BooleanLiteral(false))) - val oneLevCase2 = FunctionInvocation(threeParFunctionDef, List(baseCase, baseCase, BooleanLiteral(false))) + val oneLevCase1 = FunctionInvocation(threeParFunctionDef.typed, List(baseCase, IntLiteral(0), BooleanLiteral(false))) + val oneLevCase2 = FunctionInvocation(threeParFunctionDef.typed, List(baseCase, baseCase, BooleanLiteral(false))) assertTrue(oneLevCase1 + " not found. " + message, generated contains oneLevCase1) assertTrue(oneLevCase2 + " not found. " + message, generated contains oneLevCase2) - val twoLevCase1 = FunctionInvocation(threeParFunctionDef, List(oneLevCase1, IntLiteral(0), BooleanLiteral(false))) - val twoLevCase2 = FunctionInvocation(threeParFunctionDef, List(baseCase, oneLevCase2, BooleanLiteral(false))) + val twoLevCase1 = FunctionInvocation(threeParFunctionDef.typed, List(oneLevCase1, IntLiteral(0), BooleanLiteral(false))) + val twoLevCase2 = FunctionInvocation(threeParFunctionDef.typed, List(baseCase, oneLevCase2, BooleanLiteral(false))) assertTrue(twoLevCase1 + " not found. " + message, generated contains twoLevCase1) assertTrue(twoLevCase2 + " not found. " + message, generated contains twoLevCase2) } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/test/condabd/insynth/reconstruction/ReconstructorTest.scala b/src/test/scala/leon/test/condabd/insynth/reconstruction/ReconstructorTest.scala index c57ef877e4428151282be6acb5a0d651968aaf5b..b7a9d2a2090493be9ea964b9e295bef0ed463b3e 100644 --- a/src/test/scala/leon/test/condabd/insynth/reconstruction/ReconstructorTest.scala +++ b/src/test/scala/leon/test/condabd/insynth/reconstruction/ReconstructorTest.scala @@ -8,7 +8,7 @@ import org.scalatest.junit.JUnitSuite import leon.synthesis.condabd.insynth.reconstruction.codegen.CodeGenerator import leon.synthesis.condabd.insynth.reconstruction._ -import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ObjectDef } +import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ModuleDef } import leon.purescala.Common.{ FreshIdentifier } import leon.purescala.TypeTrees._ import leon.purescala.Trees.{ Variable => LeonVariable, _ } @@ -34,7 +34,7 @@ class ReconstructorTest extends JUnitSuite { val codeGenResult = expStream.head - assertEquals(FunctionInvocation(functionBoolToIntFunDef, List(BooleanLiteral(false))), codeGenResult.snippet) + assertEquals(FunctionInvocation(functionBoolToIntFunDef.typed, List(BooleanLiteral(false))), codeGenResult.snippet) assertEquals(0f, codeGenResult.weight, 0f) } diff --git a/src/test/scala/leon/test/condabd/insynth/testutil/CommonDeclarations.scala b/src/test/scala/leon/test/condabd/insynth/testutil/CommonDeclarations.scala index 8fb24261f795c2b4a142b43c7e859bce75ff8ff6..7d4eaf3abad98e8d117e91bccfb318aee5d0c414 100644 --- a/src/test/scala/leon/test/condabd/insynth/testutil/CommonDeclarations.scala +++ b/src/test/scala/leon/test/condabd/insynth/testutil/CommonDeclarations.scala @@ -9,7 +9,7 @@ import org.junit.Ignore import leon.synthesis.condabd.insynth.leon.loader.DeclarationFactory._ import leon.synthesis.condabd.insynth.leon._ -import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ObjectDef } +import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ModuleDef } import leon.purescala.Common.{ FreshIdentifier } import leon.purescala.TypeTrees._ import leon.purescala.Trees._ @@ -40,6 +40,7 @@ object CommonDeclarations { val functionBoolToIntFunDef = new FunDef( FreshIdentifier("function1"), + Nil, Int32Type, List( VarDecl(FreshIdentifier("var"), BooleanType)) ) @@ -48,13 +49,14 @@ object CommonDeclarations { FunctionType(List(BooleanType), Int32Type) val functionBoolToIntDeclaration = makeDeclaration( - NaryReconstructionExpression("function1", { args: List[Expr] => FunctionInvocation(functionBoolToIntFunDef, args) }), + NaryReconstructionExpression("function1", { args: List[Expr] => FunctionInvocation(functionBoolToIntFunDef.typed, args) }), functionBoolToIntType ) val functionFun1ToUnitFunDef = new FunDef( FreshIdentifier("function2"), + Nil, UnitType, List( VarDecl(FreshIdentifier("var"), functionBoolToIntType)) ) @@ -63,7 +65,7 @@ object CommonDeclarations { FunctionType(List(UnitType), functionBoolToIntType) val functionFun1ToUnitDeclaration = makeDeclaration( - NaryReconstructionExpression("function2", { args: List[Expr] => FunctionInvocation(functionFun1ToUnitFunDef, args) }), + NaryReconstructionExpression("function2", { args: List[Expr] => FunctionInvocation(functionFun1ToUnitFunDef.typed, args) }), functionFun1ToUnitType ) @@ -76,6 +78,7 @@ object CommonDeclarations { val funDef = new FunDef( FreshIdentifier("functionRec"), + Nil, Int32Type, List( varDec ) ) @@ -86,7 +89,7 @@ object CommonDeclarations { } val functionIntToIntDeclaration = makeDeclaration( - NaryReconstructionExpression("functionRec", { args: List[Expr] => FunctionInvocation(functionIntToIntFunDef, args) }), + NaryReconstructionExpression("functionRec", { args: List[Expr] => FunctionInvocation(functionIntToIntFunDef.typed, args) }), functionIntToIntType ) @@ -96,6 +99,7 @@ object CommonDeclarations { val threeParFunctionDef = new FunDef( FreshIdentifier("function3"), + Nil, Int32Type, List( VarDecl(FreshIdentifier("var_1"), Int32Type), @@ -105,8 +109,8 @@ object CommonDeclarations { ) val threeParFunctionDeclaration = makeDeclaration( - NaryReconstructionExpression("function3", { args: List[Expr] => FunctionInvocation(threeParFunctionDef, args) }), + NaryReconstructionExpression("function3", { args: List[Expr] => FunctionInvocation(threeParFunctionDef.typed, args) }), threeParFunctionType ) -} \ No newline at end of file +} diff --git a/src/test/scala/leon/test/condabd/insynth/testutil/CommonLambda.scala b/src/test/scala/leon/test/condabd/insynth/testutil/CommonLambda.scala index 4ff98d9225ff8c47bf998e2064f89fee64e897b8..8f76beca54b861bfea77487eebac3a1c63a282ce 100644 --- a/src/test/scala/leon/test/condabd/insynth/testutil/CommonLambda.scala +++ b/src/test/scala/leon/test/condabd/insynth/testutil/CommonLambda.scala @@ -4,7 +4,7 @@ import leon.synthesis.condabd.insynth.leon.query.{ LeonQueryBuilder => QueryBuil import leon.synthesis.condabd.insynth.leon._ import insynth.reconstruction.stream._ -import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ObjectDef } +import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ModuleDef } import leon.purescala.Common.{ FreshIdentifier } import leon.purescala.TypeTrees._ import leon.purescala.Trees.{ Variable => _, _ } @@ -184,4 +184,4 @@ object CommonLambda { // intermediateTree // } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/test/condabd/insynth/testutil/CommonLeonExpressions.scala b/src/test/scala/leon/test/condabd/insynth/testutil/CommonLeonExpressions.scala index 7576d3a08c1b5cf4c32b3ea8189250dd52df69eb..0238c6b9e26f6e4fc8e9cf998e434c55e21f9272 100644 --- a/src/test/scala/leon/test/condabd/insynth/testutil/CommonLeonExpressions.scala +++ b/src/test/scala/leon/test/condabd/insynth/testutil/CommonLeonExpressions.scala @@ -1,6 +1,6 @@ package leon.test.condabd.insynth.testutil -import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ObjectDef } +import leon.purescala.Definitions.{ FunDef, VarDecl, Program, ModuleDef } import leon.purescala.Common.{ FreshIdentifier } import leon.purescala.TypeTrees._ import leon.purescala.Trees.{ Variable => _, _ } @@ -9,13 +9,16 @@ object CommonLeonExpressions { import CommonDeclarations._ - val inv1boolInv = FunctionInvocation(functionBoolToIntFunDef, List(booleanLiteral)) - val inv2WithBoolInv = FunctionInvocation(functionIntToIntFunDef, List(inv1boolInv)) - val inv1WithInt = FunctionInvocation(functionIntToIntFunDef, List(intLiteral)) - val inv2WithInt = FunctionInvocation(functionIntToIntFunDef, List(inv1WithInt)) - val inv3WithInt = FunctionInvocation(functionIntToIntFunDef, List(inv2WithInt)) - val inv3WithBoolInv = FunctionInvocation(functionIntToIntFunDef, List(inv2WithBoolInv)) - val inv4WithBoolInv = FunctionInvocation(functionIntToIntFunDef, List(inv3WithBoolInv)) - val inv4WithInt = FunctionInvocation(functionIntToIntFunDef, List(inv3WithInt)) + val boolToInt = functionBoolToIntFunDef.typed + val intToInt = functionIntToIntFunDef.typed -} \ No newline at end of file + val inv1boolInv = FunctionInvocation(boolToInt, List(booleanLiteral)) + val inv2WithBoolInv = FunctionInvocation(intToInt, List(inv1boolInv)) + val inv1WithInt = FunctionInvocation(intToInt, List(intLiteral)) + val inv2WithInt = FunctionInvocation(intToInt, List(inv1WithInt)) + val inv3WithInt = FunctionInvocation(intToInt, List(inv2WithInt)) + val inv3WithBoolInv = FunctionInvocation(intToInt, List(inv2WithBoolInv)) + val inv4WithBoolInv = FunctionInvocation(intToInt, List(inv3WithBoolInv)) + val inv4WithInt = FunctionInvocation(intToInt, List(inv3WithInt)) + +} diff --git a/src/test/scala/leon/test/condabd/refinement/FilterTest.scala b/src/test/scala/leon/test/condabd/refinement/FilterTest.scala index fc18cb9bae82f27430b8d333aa29dfa8f52fbd47..7ff35871df849adf740b6fe3d39ca5e6c8509e09 100644 --- a/src/test/scala/leon/test/condabd/refinement/FilterTest.scala +++ b/src/test/scala/leon/test/condabd/refinement/FilterTest.scala @@ -32,6 +32,7 @@ class FilterTest extends JUnitSuite { var prog: Program = _ var funDef: FunDef = _ + var tfunDef: TypedFunDef = _ var variableRefiner: VariableRefiner = _ var tail: UnaryReconstructionExpression = _ @@ -49,6 +50,7 @@ class FilterTest extends JUnitSuite { prog = sctx.program this.funDef = funDef + this.tfunDef = funDef.typed val loader = new LeonLoader(prog, problem.as, true) @@ -99,7 +101,7 @@ class FilterTest extends JUnitSuite { case _ => fail("could not extract cons"); null } - filter = new Filter(prog, funDef, variableRefiner) + filter = new Filter(prog, tfunDef, variableRefiner) } @Test @@ -107,10 +109,10 @@ class FilterTest extends JUnitSuite { val filter = this.filter import filter.isLess - assertEquals(2, funDef.args.size) + assertEquals(2, tfunDef.args.size) - val variable1 = funDef.args.head - val variable2 = funDef.args(1) + val variable1 = tfunDef.args.head + val variable2 = tfunDef.args(1) assertEquals(+1, isLess(cons(List(UnitLiteral, variable1.toVariable)), variable1.id)) assertEquals(+1, isLess(cons(List(UnitLiteral, variable1.toVariable)), variable2.id)) @@ -130,14 +132,14 @@ class FilterTest extends JUnitSuite { val filter = this.filter import filter.isCallAvoidableBySize - assertEquals(2, funDef.args.size) + assertEquals(2, tfunDef.args.size) - val arg1 = funDef.args.head.toVariable - val arg2 = funDef.args(1).toVariable + val arg1 = tfunDef.args.head.toVariable + val arg2 = tfunDef.args(1).toVariable - def makeFunctionCall(arg1: Expr, arg2: Expr) = FunctionInvocation(funDef, Seq(arg1, arg2)) + def makeFunctionCall(arg1: Expr, arg2: Expr) = FunctionInvocation(tfunDef, Seq(arg1, arg2)) - val arguments = funDef.args.map(_.id).toList + val arguments = tfunDef.args.map(_.id).toList assertEquals(true, isCallAvoidableBySize(makeFunctionCall(nil, nil), arguments)) assertEquals(true, isCallAvoidableBySize(makeFunctionCall(arg1, arg2), arguments)) @@ -157,12 +159,12 @@ class FilterTest extends JUnitSuite { val filter = this.filter import filter.hasDoubleRecursion - assertEquals(2, funDef.args.size) + assertEquals(2, tfunDef.args.size) - val arg1 = funDef.args.head.toVariable - val arg2 = funDef.args(1).toVariable + val arg1 = tfunDef.args.head.toVariable + val arg2 = tfunDef.args(1).toVariable - def makeFunctionCall(arg1: Expr, arg2: Expr) = FunctionInvocation(funDef, Seq(arg1, arg2)) + def makeFunctionCall(arg1: Expr, arg2: Expr) = FunctionInvocation(tfunDef, Seq(arg1, arg2)) assertEquals(false, hasDoubleRecursion(makeFunctionCall(nil, nil))) assertEquals(false, hasDoubleRecursion(makeFunctionCall(arg1, arg2))) @@ -185,14 +187,14 @@ class FilterTest extends JUnitSuite { val filter = this.filter import filter.isAvoidable - assertEquals(2, funDef.args.size) + assertEquals(2, tfunDef.args.size) - val arg1 = funDef.args.head.toVariable - val arg2 = funDef.args(1).toVariable + val arg1 = tfunDef.args.head.toVariable + val arg2 = tfunDef.args(1).toVariable - def makeFunctionCall(arg1: Expr, arg2: Expr) = FunctionInvocation(funDef, Seq(arg1, arg2)) + def makeFunctionCall(arg1: Expr, arg2: Expr) = FunctionInvocation(tfunDef, Seq(arg1, arg2)) - val arguments = funDef.args.map(_.id).toList + val arguments = tfunDef.args.map(_.id).toList assertEquals(true, isAvoidable(makeFunctionCall(nil, nil), arguments)) assertEquals(true, isAvoidable(makeFunctionCall(arg1, arg2), arguments)) @@ -215,17 +217,17 @@ class FilterTest extends JUnitSuite { import filter.isAvoidable val arg1 = funDef.args.head.toVariable - val arg2 = funDef.args(1).toVariable - val arguments = funDef.args.map(_.id).toList + val arg2 = tfunDef.args(1).toVariable + val arguments = tfunDef.args.map(_.id).toList val tpe = cons(List(Error("temp"))).getType match { - case cct: CaseClassType => cct.classDef + case cct: CaseClassType => cct case _ => fail(arg1 + " should have a class type") null } - assertEquals(false, isAvoidable(CaseClassInstanceOf(tpe, arg1), arguments)) + assertEquals(false, isAvoidable(CaseClassInstanceOf(tpe, arg1), arguments)) assertEquals(false, isAvoidable(CaseClassInstanceOf(tpe, arg2), arguments)) assertEquals(false, isAvoidable(CaseClassInstanceOf(tpe, cons(List(arg1, nil))), arguments)) assertEquals(true, isAvoidable(CaseClassInstanceOf(tpe, tail(arg1)), arguments)) @@ -233,4 +235,4 @@ class FilterTest extends JUnitSuite { assertEquals(true, isAvoidable(CaseClassInstanceOf(tpe, tail(tail(tail(tail(arg1))))), arguments)) } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/test/condabd/refinement/RefinementExamples.scala b/src/test/scala/leon/test/condabd/refinement/RefinementExamples.scala index d7fe41fdbced37129e9c574c52cc7a19eee2cad3..1c4d7bcb6c38cf8a09f76194ff1e5d7b2c490a1e 100644 --- a/src/test/scala/leon/test/condabd/refinement/RefinementExamples.scala +++ b/src/test/scala/leon/test/condabd/refinement/RefinementExamples.scala @@ -16,18 +16,18 @@ import leon.synthesis.condabd.refinement._ object RefinementExamples { val listClassId = FreshIdentifier("List") - val listAbstractClassDef = new AbstractClassDef(listClassId) - val listAbstractClass = new AbstractClassType(listAbstractClassDef) + val listAbstractClassDef = new AbstractClassDef(listClassId, Nil, None) + val listAbstractClass = new AbstractClassType(listAbstractClassDef, Nil) val nilClassId = FreshIdentifier("Nil") - val nilAbstractClassDef = new CaseClassDef(nilClassId).setParent(listAbstractClassDef) - val nilAbstractClass = new CaseClassType(nilAbstractClassDef) + val nilAbstractClassDef = new CaseClassDef(nilClassId, Nil, None, false) + val nilAbstractClass = new CaseClassType(nilAbstractClassDef, Nil) val consClassId = FreshIdentifier("Cons") - val consAbstractClassDef = new CaseClassDef(consClassId).setParent(listAbstractClassDef) + val consAbstractClassDef = new CaseClassDef(consClassId, Nil, None, false) val headId = FreshIdentifier("head").setType(Int32Type) - consAbstractClassDef.fields = Seq(VarDecl(headId, Int32Type)) - val consAbstractClass = new CaseClassType(consAbstractClassDef) + consAbstractClassDef.setFields(Seq(VarDecl(headId, Int32Type))) + val consAbstractClass = new CaseClassType(consAbstractClassDef, Nil) val directSubclassMap: Map[ClassType, Set[ClassType]] = Map( listAbstractClass -> Set(nilAbstractClass, consAbstractClass) @@ -48,25 +48,28 @@ object RefinementExamples { def buildClassMap(program: Program) = { val listAbstractClassDef = program.definedClasses.find(_.id.name == "List"). get.asInstanceOf[AbstractClassDef] + val listAbstractClass = AbstractClassType(listAbstractClassDef, Nil) val nilAbstractClassDef = program.definedClasses.find(_.id.name == "Nil"). get.asInstanceOf[CaseClassDef] + val nilAbstractClass = CaseClassType(nilAbstractClassDef, Nil) val consAbstractClassDef = program.definedClasses.find(_.id.name == "Cons"). get.asInstanceOf[CaseClassDef] + val consAbstractClass = CaseClassType(consAbstractClassDef, Nil) val directSubclassMap: Map[ClassType, Set[ClassType]] = Map( - AbstractClassType(listAbstractClassDef) -> - Set(CaseClassType(nilAbstractClassDef), CaseClassType(consAbstractClassDef)) + listAbstractClass -> + Set(nilAbstractClass, consAbstractClass) ) val classMap: Map[Identifier, ClassType] = Map( - listAbstractClassDef.id -> AbstractClassType(listAbstractClassDef), - nilAbstractClassDef.id -> CaseClassType(nilAbstractClassDef), - consAbstractClassDef.id -> CaseClassType(consAbstractClassDef) + listAbstractClassDef.id -> listAbstractClass, + nilAbstractClassDef.id -> nilAbstractClass, + consAbstractClassDef.id -> consAbstractClass ) - (directSubclassMap, AbstractClassType(listAbstractClassDef), classMap) + (directSubclassMap, listAbstractClass, classMap) } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/test/condabd/refinement/VariableRefinerComposeTest.scala b/src/test/scala/leon/test/condabd/refinement/VariableRefinerComposeTest.scala index 833206dcbd48c6bcb84bec8bd99aab1c14b436c4..e470b9e19f858ca5093602cba5fb52c510398720 100644 --- a/src/test/scala/leon/test/condabd/refinement/VariableRefinerComposeTest.scala +++ b/src/test/scala/leon/test/condabd/refinement/VariableRefinerComposeTest.scala @@ -43,7 +43,7 @@ class VariableRefinerComposeTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isSorted" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isSorted"); null } @@ -114,7 +114,7 @@ class VariableRefinerComposeTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isSorted" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isSorted"); null } @@ -147,4 +147,4 @@ class VariableRefinerComposeTest extends FunSpec with GivenWhenThen { } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/test/condabd/refinement/VariableRefinerExecutionTest.scala b/src/test/scala/leon/test/condabd/refinement/VariableRefinerExecutionTest.scala index 202072a3da0c7be930b5395132f10cfe9b83a43a..04c8f47ad8780a3e3ed5d4e718b78ba130e96638 100644 --- a/src/test/scala/leon/test/condabd/refinement/VariableRefinerExecutionTest.scala +++ b/src/test/scala/leon/test/condabd/refinement/VariableRefinerExecutionTest.scala @@ -43,7 +43,7 @@ class VariableRefinerExecutionTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isSorted" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isSorted"); null } @@ -109,7 +109,7 @@ class VariableRefinerExecutionTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isEmpty" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isEmpty"); null } @@ -117,7 +117,7 @@ class VariableRefinerExecutionTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isEmptyBad" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isEmpty"); null } @@ -182,7 +182,7 @@ class VariableRefinerExecutionTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "hasContent" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract hasContent"); null } @@ -249,7 +249,7 @@ class VariableRefinerExecutionTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isEmpty" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isEmpty"); null } @@ -257,7 +257,7 @@ class VariableRefinerExecutionTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isEmptyBad" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isEmpty"); null } @@ -295,4 +295,4 @@ class VariableRefinerExecutionTest extends FunSpec with GivenWhenThen { } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/test/condabd/refinement/VariableRefinerStructureTest.scala b/src/test/scala/leon/test/condabd/refinement/VariableRefinerStructureTest.scala index 18d266db7be482ac495940876be8cadda7931b22..501ae166209294c402e94e5ee3edb299a0e2691d 100644 --- a/src/test/scala/leon/test/condabd/refinement/VariableRefinerStructureTest.scala +++ b/src/test/scala/leon/test/condabd/refinement/VariableRefinerStructureTest.scala @@ -30,27 +30,28 @@ class VariableRefinerTest extends FunSpec with GivenWhenThen { ) Then("it should return appropriate id And class def") - expectResult(Some((listVal.id, nilAbstractClassDef))) { - variableRefiner.getIdAndClassDef(CaseClassInstanceOf(nilAbstractClassDef, listVal)) + expectResult(Some((listVal.id, nilAbstractClass))) { + variableRefiner.getIdAndClassDef(CaseClassInstanceOf(nilAbstractClass, listVal)) } And("return None for some unknown expression") expectResult(None) { variableRefiner.getIdAndClassDef(listVal) } + Then("declarations should be updated accordingly") val allDeclarations = List(listLeonDeclaration) expectResult((true, LeonDeclaration( ImmediateExpression( listVal + "." + headId, - CaseClassSelector(consAbstractClassDef, listVal, headId) ), + CaseClassSelector(consAbstractClass, listVal, headId) ), TypeTransformer(Int32Type), Int32Type ) :: LeonDeclaration( listLeonDeclaration.expression, TypeTransformer(consAbstractClass), consAbstractClass ) :: Nil )) { - variableRefiner.refine(CaseClassInstanceOf(nilAbstractClassDef, listVal), + variableRefiner.refine(CaseClassInstanceOf(nilAbstractClass, listVal), BooleanLiteral(true), allDeclarations ) @@ -58,7 +59,7 @@ class VariableRefinerTest extends FunSpec with GivenWhenThen { And("after 2nd consequtive call, nothing should happen") expectResult((false, allDeclarations)) { - variableRefiner.refine(CaseClassInstanceOf(nilAbstractClassDef, listVal), + variableRefiner.refine(CaseClassInstanceOf(nilAbstractClass, listVal), BooleanLiteral(true), allDeclarations) } @@ -66,4 +67,4 @@ class VariableRefinerTest extends FunSpec with GivenWhenThen { } -} \ No newline at end of file +} diff --git a/src/test/scala/leon/test/condabd/refinement/VariableSolverRefinerTest.scala b/src/test/scala/leon/test/condabd/refinement/VariableSolverRefinerTest.scala index c6e63bab0beff80a73487183fca725f8e09321da..4c47654a13bdaf8a2008d5cb59826319b0557cdd 100644 --- a/src/test/scala/leon/test/condabd/refinement/VariableSolverRefinerTest.scala +++ b/src/test/scala/leon/test/condabd/refinement/VariableSolverRefinerTest.scala @@ -44,7 +44,7 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isEmpty" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isEmpty"); null } @@ -52,7 +52,7 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isEmptyBad" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isEmpty"); null } @@ -108,7 +108,7 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "hasContent" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract hasContent"); null } @@ -164,7 +164,7 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isEmpty" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isEmpty"); null } @@ -172,7 +172,7 @@ class VariableSolverRefinerTest extends FunSpec with GivenWhenThen { program.definedFunctions.find { _.id.name == "isEmptyBad" } match { - case Some(found) => (x: Expr) => FunctionInvocation(found, Seq(x)) + case Some(found) => (x: Expr) => FunctionInvocation(found.typed, Seq(x)) case _ => fail("could not extract isEmpty"); null } diff --git a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala index 51894fc064acc224c078a19ff9b50540e50db317..6a125cb14a03d29ba69763e12d84940cfd6d4b14 100644 --- a/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala +++ b/src/test/scala/leon/test/evaluators/EvaluatorsTests.scala @@ -44,12 +44,12 @@ class EvaluatorsTests extends LeonTestSuite { throw new AssertionError("No function named '%s' defined in program.".format(name)) } - FunctionInvocation(fDef, args.toSeq) + FunctionInvocation(fDef.typed, args.toSeq) } private def mkCaseClass(name : String, args : Expr*)(implicit p : Program) = { - val ccDef = p.mainObject.caseClassDef(name) - CaseClass(ccDef, args.toSeq) + val ccDef = p.mainModule.caseClassDef(name) + CaseClass(CaseClassType(ccDef, Nil), args.toSeq) } private def checkCompSuccess(evaluator : Evaluator, in : Expr) : Expr = { diff --git a/src/test/scala/leon/test/purescala/DataGen.scala b/src/test/scala/leon/test/purescala/DataGen.scala index a84687ebc01bdb6bc96fa2be0b1acdfce847194a..2de78cdce43d8790a6bf74ec5847dc6d4e04dff6 100644 --- a/src/test/scala/leon/test/purescala/DataGen.scala +++ b/src/test/scala/leon/test/purescala/DataGen.scala @@ -67,13 +67,13 @@ class DataGen extends LeonTestSuite { generator.generate(BooleanType).toSet.size === 2 generator.generate(TupleType(Seq(BooleanType,BooleanType))).toSet.size === 4 - val listType : TypeTree = classDefToClassType(prog.mainObject.classHierarchyRoots.head) + val listType : TypeTree = classDefToClassType(prog.mainModule.classHierarchyRoots.head) val sizeDef : FunDef = prog.definedFunctions.find(_.id.name == "size").get val sortedDef : FunDef = prog.definedFunctions.find(_.id.name == "isSorted").get val contentDef : FunDef = prog.definedFunctions.find(_.id.name == "content").get val insSpecDef : FunDef = prog.definedFunctions.find(_.id.name == "insertSpec").get - val consDef : CaseClassDef = prog.mainObject.caseClassDef("Cons") + val consDef : CaseClassDef = prog.mainModule.caseClassDef("Cons") generator.generate(listType).take(100).toSet.size === 100 @@ -84,11 +84,11 @@ class DataGen extends LeonTestSuite { val x = Variable(FreshIdentifier("x").setType(listType)) val y = Variable(FreshIdentifier("y").setType(listType)) - val sizeX = FunctionInvocation(sizeDef, Seq(x)) - val contentX = FunctionInvocation(contentDef, Seq(x)) - val contentY = FunctionInvocation(contentDef, Seq(y)) - val sortedX = FunctionInvocation(sortedDef, Seq(x)) - val sortedY = FunctionInvocation(sortedDef, Seq(y)) + val sizeX = FunctionInvocation(sizeDef.typed, Seq(x)) + val contentX = FunctionInvocation(contentDef.typed, Seq(x)) + val contentY = FunctionInvocation(contentDef.typed, Seq(y)) + val sortedX = FunctionInvocation(sortedDef.typed, Seq(x)) + val sortedY = FunctionInvocation(sortedDef.typed, Seq(y)) assert(generator.generateFor( Seq(x.id), @@ -115,8 +115,8 @@ class DataGen extends LeonTestSuite { Seq(x.id, y.id, b.id, a.id), And(Seq( LessThan(a, b), - FunctionInvocation(sortedDef, Seq(CaseClass(consDef, Seq(a, x)))), - FunctionInvocation(insSpecDef, Seq(b, x, y)) + FunctionInvocation(sortedDef.typed, Seq(CaseClass(CaseClassType(consDef, Nil), Seq(a, x)))), + FunctionInvocation(insSpecDef.typed, Seq(b, x, y)) )), 10, 500 diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala index affcad2d2e83d7237fe6606492387e99626fbc94..8a3fd8d8abee7d8048d8217e2d7b378a6dfbcb14 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTests.scala @@ -39,19 +39,19 @@ class FairZ3SolverTests extends LeonTestSuite { // def f(fx : Int) : Int = fx + 1 private val fx : Identifier = FreshIdentifier("x").setType(Int32Type) - private val fDef : FunDef = new FunDef(FreshIdentifier("f"), Int32Type, VarDecl(fx, Int32Type) :: Nil) + private val fDef : FunDef = new FunDef(FreshIdentifier("f"), Nil, Int32Type, VarDecl(fx, Int32Type) :: Nil) fDef.body = Some(Plus(Variable(fx), IntLiteral(1))) private val minimalProgram = Program( FreshIdentifier("Minimal"), - ObjectDef(FreshIdentifier("Minimal"), Seq( + ModuleDef(FreshIdentifier("Minimal"), Seq( fDef ), Seq.empty) ) private val x : Expr = Variable(FreshIdentifier("x").setType(Int32Type)) private val y : Expr = Variable(FreshIdentifier("y").setType(Int32Type)) - private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) + private def f(e : Expr) : Expr = FunctionInvocation(fDef.typed, e :: Nil) private val solver = SimpleSolverAPI(SolverFactory(() => new FairZ3Solver(testContext, minimalProgram))) diff --git a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala index 4d9f46185f99da00b460f46e01a366ede24194c9..4aae6a644eaf9caa5df80bd61b9bd9a4f6747b3a 100644 --- a/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala +++ b/src/test/scala/leon/test/solvers/z3/FairZ3SolverTestsNewAPI.scala @@ -47,19 +47,19 @@ class FairZ3SolverTestsNewAPI extends LeonTestSuite { // def f(fx : Int) : Int = fx + 1 private val fx : Identifier = FreshIdentifier("x").setType(Int32Type) - private val fDef : FunDef = new FunDef(FreshIdentifier("f"), Int32Type, VarDecl(fx, Int32Type) :: Nil) + private val fDef : FunDef = new FunDef(FreshIdentifier("f"), Nil, Int32Type, VarDecl(fx, Int32Type) :: Nil) fDef.body = Some(Plus(Variable(fx), IntLiteral(1))) private val minimalProgram = Program( FreshIdentifier("Minimal"), - ObjectDef(FreshIdentifier("Minimal"), Seq( + ModuleDef(FreshIdentifier("Minimal"), Seq( fDef ), Seq.empty) ) private val x : Expr = Variable(FreshIdentifier("x").setType(Int32Type)) private val y : Expr = Variable(FreshIdentifier("y").setType(Int32Type)) - private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) + private def f(e : Expr) : Expr = FunctionInvocation(fDef.typed, e :: Nil) private val solver = SolverFactory(() => new FairZ3Solver(testContext, minimalProgram)) diff --git a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala index 84c7589e8f0272058164e550ee126705298413c8..52fd55a0284683c3541a85f27aba2580e466aaec 100644 --- a/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala +++ b/src/test/scala/leon/test/solvers/z3/UninterpretedZ3SolverTests.scala @@ -39,24 +39,24 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { // def f(fx : Int) : Int = fx + 1 private val fx : Identifier = FreshIdentifier("x").setType(Int32Type) - private val fDef : FunDef = new FunDef(FreshIdentifier("f"), Int32Type, VarDecl(fx, Int32Type) :: Nil) + private val fDef : FunDef = new FunDef(FreshIdentifier("f"), Nil, Int32Type, VarDecl(fx, Int32Type) :: Nil) fDef.body = Some(Plus(Variable(fx), IntLiteral(1))) // g is a function that is not in the program (on purpose) - private val gDef : FunDef = new FunDef(FreshIdentifier("g"), Int32Type, VarDecl(fx, Int32Type) :: Nil) + private val gDef : FunDef = new FunDef(FreshIdentifier("g"), Nil, Int32Type, VarDecl(fx, Int32Type) :: Nil) gDef.body = Some(Plus(Variable(fx), IntLiteral(1))) private val minimalProgram = Program( FreshIdentifier("Minimal"), - ObjectDef(FreshIdentifier("Minimal"), Seq( + ModuleDef(FreshIdentifier("Minimal"), Seq( fDef ), Seq.empty) ) private val x : Expr = Variable(FreshIdentifier("x").setType(Int32Type)) private val y : Expr = Variable(FreshIdentifier("y").setType(Int32Type)) - private def f(e : Expr) : Expr = FunctionInvocation(fDef, e :: Nil) - private def g(e : Expr) : Expr = FunctionInvocation(gDef, e :: Nil) + private def f(e : Expr) : Expr = FunctionInvocation(fDef.typed, e :: Nil) + private def g(e : Expr) : Expr = FunctionInvocation(gDef.typed, e :: Nil) private val solver = SimpleSolverAPI(SolverFactory(() => new UninterpretedZ3Solver(testContext, minimalProgram))) @@ -84,9 +84,5 @@ class UninterpretedZ3SolverTests extends LeonTestSuite { private val unknown1 : Expr = Equals(f(x), Plus(x, IntLiteral(1))) assertUnknown(solver, unknown1) - test("Expected crash on undefined functions.") { - intercept[Exception] { - solver.solveVALID(Equals(g(x), g(x))) - } - } + assertValid(solver, Equals(g(x), g(x))) } diff --git a/unmanaged/64/scalaz3-unix-64b-2.1.jar b/unmanaged/64/scalaz3-unix-64b-2.1.jar index c59bbc1f8c765eadbc6d86d0ea7e26778f78b904..ad00c9ba1a45d17eab0cda9a0c20eb603507d9e9 100644 Binary files a/unmanaged/64/scalaz3-unix-64b-2.1.jar and b/unmanaged/64/scalaz3-unix-64b-2.1.jar differ