diff --git a/src/main/scala/leon/Z3ModelReconstruction.scala b/src/main/scala/leon/Z3ModelReconstruction.scala index 913da2fa32b64056026f24e3d1ea1cbd8dd81672..87c75ddd5fb6667764727daa85f018deeeb4b6ec 100644 --- a/src/main/scala/leon/Z3ModelReconstruction.scala +++ b/src/main/scala/leon/Z3ModelReconstruction.scala @@ -4,6 +4,7 @@ import z3.scala._ import purescala.Common._ import purescala.Definitions._ import purescala.Trees._ +import purescala.TreeOps._ import purescala.TypeTrees._ import Extensions._ diff --git a/src/main/scala/leon/isabelle/Main.scala b/src/main/scala/leon/isabelle/Main.scala index 75e1c86fb1f4181a0e2834ac628dcb20e1a866a2..ac95c8095ec4a9af8543e8d0cecc795a5f8acd19 100644 --- a/src/main/scala/leon/isabelle/Main.scala +++ b/src/main/scala/leon/isabelle/Main.scala @@ -8,6 +8,7 @@ import leon.purescala.Common.Identifier import leon.purescala.Definitions._ import leon.purescala.PrettyPrinter import leon.purescala.Trees._ +import leon.purescala.TreeOps._ import leon.purescala.Extractors._ import leon.purescala.TypeTrees._ diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 2949fa0361e8f50ef321750860166c6c66b408ea..a0494a8e440a799ca8d2d600a5a6fda6d9268170 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -727,4 +727,227 @@ object TreeOps { rec(expr, Map.empty) } + + private var matchConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() + /** Rewrites all pattern-matching expressions into if-then-else expressions, + * with additional error conditions. Does not introduce additional variables. + * We use a cache because we can. */ + def matchToIfThenElse(expr: Expr) : Expr = { + val toRet = if(matchConverterCache.isDefinedAt(expr)) { + matchConverterCache(expr) + } else { + val converted = convertMatchToIfThenElse(expr) + matchConverterCache(expr) = converted + converted + } + + toRet + } + + def conditionForPattern(in: Expr, pattern: Pattern) : Expr = pattern match { + case WildcardPattern(_) => BooleanLiteral(true) + case InstanceOfPattern(_,_) => scala.sys.error("InstanceOfPattern not yet supported.") + case CaseClassPattern(_, ccd, subps) => { + assert(ccd.fields.size == subps.size) + val pairs = ccd.fields.map(_.id).toList zip subps.toList + val subTests = pairs.map(p => conditionForPattern(CaseClassSelector(ccd, in, p._1), p._2)) + val together = And(subTests) + And(CaseClassInstanceOf(ccd, in), together) + } + case TuplePattern(_, subps) => { + val TupleType(tpes) = in.getType + assert(tpes.size == subps.size) + val subTests = subps.zipWithIndex.map{case (p, i) => conditionForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)} + And(subTests) + } + } + + private def convertMatchToIfThenElse(expr: Expr) : Expr = { + def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match { + case WildcardPattern(None) => Map.empty + case WildcardPattern(Some(id)) => Map(id -> in) + case InstanceOfPattern(None, _) => Map.empty + case InstanceOfPattern(Some(id), _) => Map(id -> in) + case CaseClassPattern(b, ccd, subps) => { + assert(ccd.fields.size == subps.size) + val pairs = ccd.fields.map(_.id).toList zip subps.toList + val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2)) + val together = subMaps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) + b match { + case Some(id) => Map(id -> in) ++ together + case None => together + } + } + case TuplePattern(b, subps) => { + val TupleType(tpes) = in.getType + assert(tpes.size == subps.size) + + val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)} + val map = maps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) + b match { + case Some(id) => map + (id -> in) + case None => map + } + } + } + + def rewritePM(e: Expr) : Option[Expr] = e match { + case m @ MatchExpr(scrut, cases) => { + // println("Rewriting the following PM: " + e) + + val condsAndRhs = for(cse <- cases) yield { + val map = mapForPattern(scrut, cse.pattern) + val patCond = conditionForPattern(scrut, cse.pattern) + val realCond = cse.theGuard match { + case Some(g) => And(patCond, replaceFromIDs(map, g)) + case None => patCond + } + val newRhs = replaceFromIDs(map, cse.rhs) + (realCond, newRhs) + } + + val optCondsAndRhs = if(SimplePatternMatching.isSimple(m)) { + // this is a hackish optimization: because we know all cases are covered, we replace the last condition by true (and that drops the check) + val lastExpr = condsAndRhs.last._2 + + condsAndRhs.dropRight(1) ++ Seq((BooleanLiteral(true),lastExpr)) + } else { + condsAndRhs + } + + val bigIte = optCondsAndRhs.foldRight[Expr](Error("non-exhaustive match").setType(bestRealType(m.getType)).setPosInfo(m))((p1, ex) => { + if(p1._1 == BooleanLiteral(true)) { + p1._2 + } else { + IfExpr(p1._1, p1._2, ex).setType(m.getType) + } + }) + + Some(bigIte) + } + case _ => None + } + + searchAndReplaceDFS(rewritePM)(expr) + } + + private var mapGetConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() + /** Rewrites all map accesses with additional error conditions. */ + def mapGetWithChecks(expr: Expr) : Expr = { + val toRet = if (mapGetConverterCache.isDefinedAt(expr)) { + mapGetConverterCache(expr) + } else { + val converted = convertMapGet(expr) + mapGetConverterCache(expr) = converted + converted + } + + toRet + } + + private def convertMapGet(expr: Expr) : Expr = { + def rewriteMapGet(e: Expr) : Option[Expr] = e match { + case mg @ MapGet(m,k) => + val ida = MapIsDefinedAt(m, k) + Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType).setPosInfo(mg)).setType(mg.getType)) + case _ => None + } + + searchAndReplaceDFS(rewriteMapGet)(expr) + } + + // prec: expression does not contain match expressions + def measureADTChildrenDepth(expression: Expr) : Int = { + import scala.math.max + + def rec(ex: Expr, lm: Map[Identifier,Int]) : Int = ex match { + case Let(i,e,b) => rec(b,lm + (i -> rec(e,lm))) + case Variable(id) => lm.getOrElse(id, 0) + case CaseClassSelector(_, e, _) => rec(e,lm) + 1 + case NAryOperator(args, _) => if(args.isEmpty) 0 else args.map(rec(_,lm)).max + case BinaryOperator(e1,e2,_) => max(rec(e1,lm), rec(e2,lm)) + case UnaryOperator(e,_) => rec(e,lm) + case IfExpr(c,t,e) => max(max(rec(c,lm),rec(t,lm)),rec(e,lm)) + case t: Terminal => 0 + case _ => scala.sys.error("Not handled in measureChildrenDepth : " + ex) + } + + rec(expression,Map.empty) + } + + private val random = new scala.util.Random() + + def randomValue(v: Variable) : Expr = randomValue(v.getType) + def simplestValue(v: Variable) : Expr = simplestValue(v.getType) + + def randomValue(tpe: TypeTree) : Expr = tpe match { + case Int32Type => IntLiteral(random.nextInt(42)) + case BooleanType => BooleanLiteral(random.nextBoolean()) + case AbstractClassType(acd) => + val children = acd.knownChildren + randomValue(classDefToClassType(children(random.nextInt(children.size)))) + case CaseClassType(cd) => + val fields = cd.fields + CaseClass(cd, fields.map(f => randomValue(f.getType))) + case _ => throw new Exception("I can't choose random value for type " + tpe) + } + + def simplestValue(tpe: TypeTree) : Expr = tpe match { + case Int32Type => IntLiteral(0) + case BooleanType => BooleanLiteral(false) + case AbstractClassType(acd) => { + val children = acd.knownChildren + val simplerChildren = children.filter{ + case ccd @ CaseClassDef(id, Some(parent), fields) => + !fields.exists(vd => vd.getType match { + case AbstractClassType(fieldAcd) => acd == fieldAcd + case CaseClassType(fieldCcd) => ccd == fieldCcd + case _ => false + }) + case _ => false + } + def orderByNumberOfFields(fst: ClassTypeDef, snd: ClassTypeDef) : Boolean = (fst, snd) match { + case (CaseClassDef(_, _, flds1), CaseClassDef(_, _, flds2)) => flds1.size <= flds2.size + case _ => true + } + val orderedChildren = simplerChildren.sortWith(orderByNumberOfFields) + simplestValue(classDefToClassType(orderedChildren.head)) + } + case CaseClassType(ccd) => + val fields = ccd.fields + CaseClass(ccd, fields.map(f => simplestValue(f.getType))) + case SetType(baseType) => EmptySet(baseType).setType(tpe) + case MapType(fromType, toType) => EmptyMap(fromType, toType).setType(tpe) + case FunctionType(fromTypes, toType) => AnonymousFunction(Seq.empty, simplestValue(toType)).setType(tpe) + case _ => throw new Exception("I can't choose simplest value for type " + tpe) + } + + //guarentee that all IfExpr will be at the top level and as soon as you encounter a non-IfExpr, then no more IfExpr can be find in the sub-expressions + //require no-match, no-ets and only pure code + def hoistIte(expr: Expr): Expr = { + def transform(expr: Expr): Option[Expr] = expr match { + case uop@UnaryOperator(IfExpr(c, t, e), op) => Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType)) + case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType)) + case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType)) + case nop@NAryOperator(ts, op) => { + val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } + if(iteIndex == -1) None else { + val (beforeIte, startIte) = ts.splitAt(iteIndex) + val afterIte = startIte.tail + val IfExpr(c, t, e) = startIte.head + Some(IfExpr(c, + op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType), + op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType) + ).setType(nop.getType)) + } + } + case _ => None + } + + def fix[A](f: (A) => A, a: A): A = { + val na = f(a) + if(a == na) a else fix(f, na) + } + fix(searchAndReplaceDFS(transform), expr) + } } diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index b985b2ac358714e2d497d232c935e604af9004ba..c8e1dd18f465adc574269bfce174273e26ba3630 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -441,227 +441,4 @@ object Trees { val fixedType = BooleanType } - private var matchConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() - /** Rewrites all pattern-matching expressions into if-then-else expressions, - * with additional error conditions. Does not introduce additional variables. - * We use a cache because we can. */ - def matchToIfThenElse(expr: Expr) : Expr = { - val toRet = if(matchConverterCache.isDefinedAt(expr)) { - matchConverterCache(expr) - } else { - val converted = convertMatchToIfThenElse(expr) - matchConverterCache(expr) = converted - converted - } - - toRet - } - - def conditionForPattern(in: Expr, pattern: Pattern) : Expr = pattern match { - case WildcardPattern(_) => BooleanLiteral(true) - case InstanceOfPattern(_,_) => scala.sys.error("InstanceOfPattern not yet supported.") - case CaseClassPattern(_, ccd, subps) => { - assert(ccd.fields.size == subps.size) - val pairs = ccd.fields.map(_.id).toList zip subps.toList - val subTests = pairs.map(p => conditionForPattern(CaseClassSelector(ccd, in, p._1), p._2)) - val together = And(subTests) - And(CaseClassInstanceOf(ccd, in), together) - } - case TuplePattern(_, subps) => { - val TupleType(tpes) = in.getType - assert(tpes.size == subps.size) - val subTests = subps.zipWithIndex.map{case (p, i) => conditionForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)} - And(subTests) - } - } - - private def convertMatchToIfThenElse(expr: Expr) : Expr = { - def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match { - case WildcardPattern(None) => Map.empty - case WildcardPattern(Some(id)) => Map(id -> in) - case InstanceOfPattern(None, _) => Map.empty - case InstanceOfPattern(Some(id), _) => Map(id -> in) - case CaseClassPattern(b, ccd, subps) => { - assert(ccd.fields.size == subps.size) - val pairs = ccd.fields.map(_.id).toList zip subps.toList - val subMaps = pairs.map(p => mapForPattern(CaseClassSelector(ccd, in, p._1), p._2)) - val together = subMaps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) - b match { - case Some(id) => Map(id -> in) ++ together - case None => together - } - } - case TuplePattern(b, subps) => { - val TupleType(tpes) = in.getType - assert(tpes.size == subps.size) - - val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(TupleSelect(in, i+1).setType(tpes(i)), p)} - val map = maps.foldLeft(Map.empty[Identifier,Expr])(_ ++ _) - b match { - case Some(id) => map + (id -> in) - case None => map - } - } - } - - def rewritePM(e: Expr) : Option[Expr] = e match { - case m @ MatchExpr(scrut, cases) => { - // println("Rewriting the following PM: " + e) - - val condsAndRhs = for(cse <- cases) yield { - val map = mapForPattern(scrut, cse.pattern) - val patCond = conditionForPattern(scrut, cse.pattern) - val realCond = cse.theGuard match { - case Some(g) => And(patCond, replaceFromIDs(map, g)) - case None => patCond - } - val newRhs = replaceFromIDs(map, cse.rhs) - (realCond, newRhs) - } - - val optCondsAndRhs = if(SimplePatternMatching.isSimple(m)) { - // this is a hackish optimization: because we know all cases are covered, we replace the last condition by true (and that drops the check) - val lastExpr = condsAndRhs.last._2 - - condsAndRhs.dropRight(1) ++ Seq((BooleanLiteral(true),lastExpr)) - } else { - condsAndRhs - } - - val bigIte = optCondsAndRhs.foldRight[Expr](Error("non-exhaustive match").setType(bestRealType(m.getType)).setPosInfo(m))((p1, ex) => { - if(p1._1 == BooleanLiteral(true)) { - p1._2 - } else { - IfExpr(p1._1, p1._2, ex).setType(m.getType) - } - }) - - Some(bigIte) - } - case _ => None - } - - searchAndReplaceDFS(rewritePM)(expr) - } - - private var mapGetConverterCache = new scala.collection.mutable.HashMap[Expr,Expr]() - /** Rewrites all map accesses with additional error conditions. */ - def mapGetWithChecks(expr: Expr) : Expr = { - val toRet = if (mapGetConverterCache.isDefinedAt(expr)) { - mapGetConverterCache(expr) - } else { - val converted = convertMapGet(expr) - mapGetConverterCache(expr) = converted - converted - } - - toRet - } - - private def convertMapGet(expr: Expr) : Expr = { - def rewriteMapGet(e: Expr) : Option[Expr] = e match { - case mg @ MapGet(m,k) => - val ida = MapIsDefinedAt(m, k) - Some(IfExpr(ida, mg, Error("key not found for map access").setType(mg.getType).setPosInfo(mg)).setType(mg.getType)) - case _ => None - } - - searchAndReplaceDFS(rewriteMapGet)(expr) - } - - // prec: expression does not contain match expressions - def measureADTChildrenDepth(expression: Expr) : Int = { - import scala.math.max - - def rec(ex: Expr, lm: Map[Identifier,Int]) : Int = ex match { - case Let(i,e,b) => rec(b,lm + (i -> rec(e,lm))) - case Variable(id) => lm.getOrElse(id, 0) - case CaseClassSelector(_, e, _) => rec(e,lm) + 1 - case NAryOperator(args, _) => if(args.isEmpty) 0 else args.map(rec(_,lm)).max - case BinaryOperator(e1,e2,_) => max(rec(e1,lm), rec(e2,lm)) - case UnaryOperator(e,_) => rec(e,lm) - case IfExpr(c,t,e) => max(max(rec(c,lm),rec(t,lm)),rec(e,lm)) - case t: Terminal => 0 - case _ => scala.sys.error("Not handled in measureChildrenDepth : " + ex) - } - - rec(expression,Map.empty) - } - - private val random = new scala.util.Random() - - def randomValue(v: Variable) : Expr = randomValue(v.getType) - def simplestValue(v: Variable) : Expr = simplestValue(v.getType) - - def randomValue(tpe: TypeTree) : Expr = tpe match { - case Int32Type => IntLiteral(random.nextInt(42)) - case BooleanType => BooleanLiteral(random.nextBoolean()) - case AbstractClassType(acd) => - val children = acd.knownChildren - randomValue(classDefToClassType(children(random.nextInt(children.size)))) - case CaseClassType(cd) => - val fields = cd.fields - CaseClass(cd, fields.map(f => randomValue(f.getType))) - case _ => throw new Exception("I can't choose random value for type " + tpe) - } - - def simplestValue(tpe: TypeTree) : Expr = tpe match { - case Int32Type => IntLiteral(0) - case BooleanType => BooleanLiteral(false) - case AbstractClassType(acd) => { - val children = acd.knownChildren - val simplerChildren = children.filter{ - case ccd @ CaseClassDef(id, Some(parent), fields) => - !fields.exists(vd => vd.getType match { - case AbstractClassType(fieldAcd) => acd == fieldAcd - case CaseClassType(fieldCcd) => ccd == fieldCcd - case _ => false - }) - case _ => false - } - def orderByNumberOfFields(fst: ClassTypeDef, snd: ClassTypeDef) : Boolean = (fst, snd) match { - case (CaseClassDef(_, _, flds1), CaseClassDef(_, _, flds2)) => flds1.size <= flds2.size - case _ => true - } - val orderedChildren = simplerChildren.sortWith(orderByNumberOfFields) - simplestValue(classDefToClassType(orderedChildren.head)) - } - case CaseClassType(ccd) => - val fields = ccd.fields - CaseClass(ccd, fields.map(f => simplestValue(f.getType))) - case SetType(baseType) => EmptySet(baseType).setType(tpe) - case MapType(fromType, toType) => EmptyMap(fromType, toType).setType(tpe) - case FunctionType(fromTypes, toType) => AnonymousFunction(Seq.empty, simplestValue(toType)).setType(tpe) - case _ => throw new Exception("I can't choose simplest value for type " + tpe) - } - - //guarentee that all IfExpr will be at the top level and as soon as you encounter a non-IfExpr, then no more IfExpr can be find in the sub-expressions - //require no-match, no-ets and only pure code - def hoistIte(expr: Expr): Expr = { - def transform(expr: Expr): Option[Expr] = expr match { - case uop@UnaryOperator(IfExpr(c, t, e), op) => Some(IfExpr(c, op(t).setType(uop.getType), op(e).setType(uop.getType)).setType(uop.getType)) - case bop@BinaryOperator(IfExpr(c, t, e), t2, op) => Some(IfExpr(c, op(t, t2).setType(bop.getType), op(e, t2).setType(bop.getType)).setType(bop.getType)) - case bop@BinaryOperator(t1, IfExpr(c, t, e), op) => Some(IfExpr(c, op(t1, t).setType(bop.getType), op(t1, e).setType(bop.getType)).setType(bop.getType)) - case nop@NAryOperator(ts, op) => { - val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } - if(iteIndex == -1) None else { - val (beforeIte, startIte) = ts.splitAt(iteIndex) - val afterIte = startIte.tail - val IfExpr(c, t, e) = startIte.head - Some(IfExpr(c, - op(beforeIte ++ Seq(t) ++ afterIte).setType(nop.getType), - op(beforeIte ++ Seq(e) ++ afterIte).setType(nop.getType) - ).setType(nop.getType)) - } - } - case _ => None - } - - def fix[A](f: (A) => A, a: A): A = { - val na = f(a) - if(a == na) a else fix(f, na) - } - fix(searchAndReplaceDFS(transform), expr) - } - } diff --git a/src/main/scala/leon/testgen/TestGeneration.scala b/src/main/scala/leon/testgen/TestGeneration.scala index 1b48d72579f0ff620ef7d35f41e9bcfe841f3ac4..8404294eed8975f25a21f2fcfa95882de2f48d0f 100644 --- a/src/main/scala/leon/testgen/TestGeneration.scala +++ b/src/main/scala/leon/testgen/TestGeneration.scala @@ -3,6 +3,7 @@ package leon.testgen import leon.purescala.Common._ import leon.purescala.Definitions._ import leon.purescala.Trees._ +import leon.purescala.TreeOps._ import leon.purescala.TypeTrees._ import leon.purescala.ScalaPrinter import leon.Extensions._