From 6bffff4f17bc2b881018b3e90a845ba845419d01 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Wed, 3 Aug 2016 15:04:52 +0200 Subject: [PATCH] Finished up unrolling solver and theory encoders --- src/main/scala/inox/ast/DSL.scala | 147 +++++----- src/main/scala/inox/ast/Definitions.scala | 17 +- src/main/scala/inox/ast/Expressions.scala | 41 --- src/main/scala/inox/ast/Extractors.scala | 24 -- src/main/scala/inox/ast/GenTreeOps.scala | 2 +- src/main/scala/inox/ast/Printers.scala | 19 +- src/main/scala/inox/ast/SymbolOps.scala | 4 +- src/main/scala/inox/ast/TreeOps.scala | 18 +- src/main/scala/inox/ast/Trees.scala | 6 +- .../inox/evaluators/RecursiveEvaluator.scala | 39 +-- .../inox/solvers/theories/BagEncoder.scala | 179 +++++++----- .../inox/solvers/theories/StringEncoder.scala | 254 ++++++++---------- .../inox/solvers/theories/TheoryEncoder.scala | 37 ++- .../solvers/unrolling/FunctionTemplates.scala | 4 +- .../solvers/unrolling/UnrollingSolver.scala | 13 +- 15 files changed, 354 insertions(+), 450 deletions(-) diff --git a/src/main/scala/inox/ast/DSL.scala b/src/main/scala/inox/ast/DSL.scala index 03380ea15..052abc20e 100644 --- a/src/main/scala/inox/ast/DSL.scala +++ b/src/main/scala/inox/ast/DSL.scala @@ -18,90 +18,54 @@ import scala.language.implicitConversions * Instead one-letter constructors are provided. */ trait DSL { - val program: Program - import program._ + protected val trees: Trees import trees._ - import symbols._ /* Expressions */ - trait SimplificationLevel - case object NoSimplify extends SimplificationLevel - case object SafeSimplify extends SimplificationLevel - private def simp(e1: => Expr, e2: => Expr)(implicit simpLvl: SimplificationLevel): Expr = simpLvl match { - case NoSimplify => e1 - case SafeSimplify => e2 - } - - implicit class ExprOps(e: Expr)(implicit simpLvl: SimplificationLevel) { - - private def binOp( - e1: (Expr, Expr) => Expr, - e2: (Expr, Expr) => Expr - ) = { - (other: Expr) => simp(e1(e, other), e2(e,other)) - } - - private def unOp( - e1: (Expr) => Expr, - e2: (Expr) => Expr - ) = { - simp(e1(e), e2(e)) - } + implicit class ExpressionWrapper(e: Expr) { // Arithmetic - def + = binOp(Plus, plus) - def - = binOp(Minus, minus) - def % = binOp(Modulo, Modulo) - def / = binOp(Division, Division) + def + = Plus(e, _: Expr) + def - = Minus(e, _: Expr) + def % = Modulo(e, _: Expr) + def / = Division(e, _: Expr) // Comparisons - def < = binOp(LessThan, LessThan) - def <= = binOp(LessEquals, LessEquals) - def > = binOp(GreaterThan, GreaterThan) - def >= = binOp(GreaterEquals, GreaterEquals) - def === = binOp(Equals, equality) + def < = LessThan(e, _: Expr) + def <= = LessEquals(e, _: Expr) + def > = GreaterThan(e, _: Expr) + def >= = GreaterEquals(e, _: Expr) + def === = Equals(e, _: Expr) // Boolean - def && = binOp(And(_, _), and(_, _)) - def || = binOp(Or(_, _), or(_, _)) - def ==> = binOp(Implies, implies) - def unary_! = unOp(Not, not) + def && = And(e, _: Expr) + def || = Or(e, _: Expr) + def ==> = Implies(e, _: Expr) + def unary_! = Not(e) // Tuple selections - def _1 = unOp(TupleSelect(_, 1), tupleSelect(_, 1, true)) - def _2 = unOp(TupleSelect(_, 2), tupleSelect(_, 2, true)) - def _3 = unOp(TupleSelect(_, 3), tupleSelect(_, 3, true)) - def _4 = unOp(TupleSelect(_, 4), tupleSelect(_, 4, true)) + def _1 = TupleSelect(e, 1) + def _2 = TupleSelect(e, 2) + def _3 = TupleSelect(e, 3) + def _4 = TupleSelect(e, 4) // Sets def size = SetCardinality(e) - def subsetOf = binOp(SubsetOf, SubsetOf) - def insert = binOp(SetAdd, SetAdd) - def ++ = binOp(SetUnion, SetUnion) - def -- = binOp(SetDifference, SetDifference) - def & = binOp(SetIntersection, SetIntersection) + def subsetOf = SubsetOf(e, _: Expr) + def insert = SetAdd(e, _: Expr) + def ++ = SetUnion(e, _: Expr) + def -- = SetDifference(e, _: Expr) + def & = SetIntersection(e, _: Expr) // Misc. - def getField(selector: String) = { - val id = for { - ct <- scala.util.Try(e.getType.asInstanceOf[ClassType]).toOption - tcd <- ct.lookupClass - tccd <- scala.util.Try(tcd.toCase).toOption - field <- tccd.cd.fields.find(_.id.name == selector) - } yield { - field.id - } - CaseClassSelector(e, id.get) - } - - def ensures(other: Expr) = Ensuring(e, other) + def getField(selector: Identifier) = CaseClassSelector(e, selector) def apply(es: Expr*) = Application(e, es.toSeq) - def isInstOf(tp: ClassType) = unOp(IsInstanceOf(_, tp), symbols.isInstOf(_, tp)) - def asInstOf(tp: ClassType) = unOp(AsInstanceOf(_, tp), symbols.asInstOf(_, tp)) + def isInstOf(tp: ClassType) = IsInstanceOf(e, tp) + def asInstOf(tp: ClassType) = AsInstanceOf(e, tp) } // Literals @@ -113,28 +77,38 @@ trait DSL { def E(n: BigInt, d: BigInt) = FractionLiteral(n, d) def E(s: String): Expr = StringLiteral(s) def E(e1: Expr, e2: Expr, es: Expr*): Expr = Tuple(e1 :: e2 :: es.toList) + /* def E(s: Set[Expr]) = { require(s.nonEmpty) FiniteSet(s.toSeq, leastUpperBound(s.toSeq map (_.getType)).get) } + */ def E(vd: ValDef) = vd.toVariable // TODO: We should be able to remove this def E(id: Identifier) = new IdToFunInv(id) class IdToFunInv(id: Identifier) { - def apply(tps: Type*)(args: Expr*) = - FunctionInvocation(id, tps.toSeq, args.toSeq) + def apply(tp1: Type, tps: Type*)(args: Expr*) = + FunctionInvocation(id, tp1 +: tps.toSeq, args.toSeq) + def apply(args: Expr*) = + FunctionInvocation(id, Seq.empty, args.toSeq) } // if-then-else - class DanglingElse(cond: Expr, thenn: Expr) { - def else_ (theElse: Expr) = IfExpr(cond, thenn, theElse) + class DanglingElse private[DSL] (condThens: Seq[(Expr, Expr)]) { + def else_if (cond2: Expr)(thenn2: Expr) = new DanglingElse(condThens :+ (cond2 -> thenn2)) + def else_ (theElse: Expr) = condThens.foldRight(theElse) { + case ((cond, thenn), elze) =>IfExpr(cond, thenn, elze) + } } - def if_ (cond: Expr)(thenn: Expr) = new DanglingElse(cond, thenn) + def if_ (cond: Expr)(thenn: Expr) = new DanglingElse(Seq(cond -> thenn)) def ite(cond: Expr, thenn: Expr, elze: Expr) = IfExpr(cond, thenn, elze) implicit class FunctionInv(fd: FunDef) { - def apply(args: Expr*) = functionInvocation(fd, args.toSeq) + def apply(tp1: Type, tps: Type*)(args: Expr*) = + FunctionInvocation(fd.id, tp1 +: tps.toSeq, args.toSeq) + def apply(args: Expr*) = + FunctionInvocation(fd.id, Seq.empty, args.toSeq) } implicit class CaseClassToExpr(ct: ClassType) { @@ -157,11 +131,8 @@ trait DSL { * @param v The value bound to the let-variable * @param body The context which will be filled with the let-variable */ - def let(vd: ValDef, v: Expr)(body: Variable => Expr)(implicit simpLvl: SimplificationLevel) = { - simp( - Let(vd, v, body(vd.toVariable)), - symbols.let(vd, v, body(vd.toVariable)) - ) + def let(vd: ValDef, v: Expr)(body: Variable => Expr) = { + Let(vd, v, body(vd.toVariable)) } // Lambdas @@ -187,6 +158,28 @@ trait DSL { ) } + // Foralls + def forall(vd: ValDef)(body: Variable => Expr) = { + Forall(Seq(vd), body(vd.toVariable)) + } + + def forall(vd1: ValDef, vd2: ValDef) + (body: (Variable, Variable) => Expr) = { + Forall(Seq(vd1, vd2), body(vd1.toVariable, vd2.toVariable)) + } + + def forall(vd1: ValDef, vd2: ValDef, vd3: ValDef) + (body: (Variable, Variable, Variable) => Expr) = { + Forall(Seq(vd1, vd2, vd3), body(vd1.toVariable, vd2.toVariable, vd3.toVariable)) + } + + def forall(vd1: ValDef, vd2: ValDef, vd3: ValDef, vd4: ValDef) + (body: (Variable, Variable, Variable, Variable) => Expr) = { + Forall( + Seq(vd1, vd2, vd3, vd4), + body(vd1.toVariable, vd2.toVariable, vd3.toVariable, vd4.toVariable)) + } + // Block-like class BlockSuspension(susp: Expr => Expr) { def in(e: Expr) = susp(e) @@ -195,9 +188,11 @@ trait DSL { /* Types */ def T(tp1: Type, tp2: Type, tps: Type*) = TupleType(tp1 :: tp2 :: tps.toList) def T(id: Identifier) = new IdToClassType(id) + class IdToClassType(id: Identifier) { def apply(tps: Type*) = ClassType(id, tps.toSeq) } + implicit class FunctionTypeBuilder(to: Type) { def =>: (from: Type) = FunctionType(Seq(from), to) @@ -241,9 +236,9 @@ trait DSL { val tParams = tParamNames map TypeParameter.fresh val tParamDefs = tParams map TypeParameterDef val (params, retType, bodyBuilder) = builder(tParams) - val fullBody = bodyBuilder(params map (_.toVariable)) + val body = bodyBuilder(params map (_.toVariable)) - new FunDef(id, tParamDefs, params, retType, fullBody, Set()) + new FunDef(id, tParamDefs, params, retType, Some(body), Set()) } def mkAbstractClass(id: Identifier) diff --git a/src/main/scala/inox/ast/Definitions.scala b/src/main/scala/inox/ast/Definitions.scala index 143033d38..5de66e6fb 100644 --- a/src/main/scala/inox/ast/Definitions.scala +++ b/src/main/scala/inox/ast/Definitions.scala @@ -24,7 +24,7 @@ trait Definitions { self: Trees => case class ClassLookupException(id: Identifier) extends LookupException(id, "class") case class NotWellFormedException(id: Identifier, s: Symbols) - extends Exception(s"$id not well formed in ${s.asString}") + extends Exception(s"$id not well formed in $s") /** Common super-type for [[ValDef]] and [[Expressions.Variable]]. * @@ -77,7 +77,7 @@ trait Definitions { self: Trees => override def hashCode: Int = super[VariableSymbol].hashCode } - type Symbols <: AbstractSymbols + type Symbols >: Null <: AbstractSymbols /** A wrapper for a program. For now a program is simply a single object. */ trait AbstractSymbols @@ -86,10 +86,10 @@ trait Definitions { self: Trees => with SymbolOps with CallGraph with Constructors - with Paths { + with Paths { self0: Symbols => - protected val classes: Map[Identifier, ClassDef] - protected val functions: Map[Identifier, FunDef] + protected[ast] val classes: Map[Identifier, ClassDef] + protected[ast] val functions: Map[Identifier, FunDef] private[ast] val trees: self.type = self protected val symbols: this.type = this @@ -120,7 +120,12 @@ trait Definitions { self: Trees => def getFunction(id: Identifier): FunDef = lookupFunction(id).getOrElse(throw FunctionLookupException(id)) def getFunction(id: Identifier, tps: Seq[Type]): TypedFunDef = lookupFunction(id, tps).getOrElse(throw FunctionLookupException(id)) - def asString: String = asString(PrinterOptions.fromSymbols(this, InoxContext.printNames)) + override def toString: String = asString(PrinterOptions.fromSymbols(this, InoxContext.printNames)) + override def asString(implicit opts: PrinterOptions): String = { + classes.map(p => PrettyPrinter(p._2, opts)).mkString("\n\n") + + "\n\n-----------\n\n" + + functions.map(p => PrettyPrinter(p._2, opts)).mkString("\n\n") + } def transform(t: TreeTransformer): Symbols } diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala index 23cc7170e..7049e7c65 100644 --- a/src/main/scala/inox/ast/Expressions.scala +++ b/src/main/scala/inox/ast/Expressions.scala @@ -334,27 +334,6 @@ trait Expressions { self: Trees => /* String Theory */ - abstract class ConverterToString(fromType: Type, toType: Type) extends Expr with CachingTyped { - val expr: Expr - protected def computeType(implicit s: Symbols): Type = - if (expr.getType == fromType) toType else Untyped - } - - /** $encodingof `expr.toString` for Int32 to String */ - case class Int32ToString(expr: Expr) extends ConverterToString(Int32Type, StringType) - - /** $encodingof `expr.toString` for boolean to String */ - case class BooleanToString(expr: Expr) extends ConverterToString(BooleanType, StringType) - - /** $encodingof `expr.toString` for BigInt to String */ - case class IntegerToString(expr: Expr) extends ConverterToString(IntegerType, StringType) - - /** $encodingof `expr.toString` for char to String */ - case class CharToString(expr: Expr) extends ConverterToString(CharType, StringType) - - /** $encodingof `expr.toString` for real to String */ - case class RealToString(expr: Expr) extends ConverterToString(RealType, StringType) - /** $encodingof `lhs + rhs` for strings */ case class StringConcat(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { protected def computeType(implicit s: Symbols): Type = { @@ -365,17 +344,6 @@ trait Expressions { self: Trees => /** $encodingof `lhs.subString(start, end)` for strings */ case class SubString(expr: Expr, start: Expr, end: Expr) extends Expr with CachingTyped { - protected def computeType(implicit s: Symbols): Type = { - val ext = expr.getType - val st = start.getType - val et = end.getType - if (ext == StringType && st == Int32Type && et == Int32Type) StringType - else Untyped - } - } - - /** $encodingof `lhs.subString(start, end)` for strings */ - case class BigSubString(expr: Expr, start: Expr, end: Expr) extends Expr with CachingTyped { protected def computeType(implicit s: Symbols): Type = { val ext = expr.getType val st = start.getType @@ -387,21 +355,12 @@ trait Expressions { self: Trees => /** $encodingof `lhs.length` for strings */ case class StringLength(expr: Expr) extends Expr with CachingTyped { - protected def computeType(implicit s: Symbols): Type = { - if (expr.getType == StringType) Int32Type - else Untyped - } - } - - /** $encodingof `lhs.length` for strings */ - case class StringBigLength(expr: Expr) extends Expr with CachingTyped { protected def computeType(implicit s: Symbols): Type = { if (expr.getType == StringType) IntegerType else Untyped } } - /* General arithmetic */ def numericType(tpe: Type, tpes: Type*)(implicit s: Symbols): Type = { diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala index 4795bd3b1..e7cd86bbd 100644 --- a/src/main/scala/inox/ast/Extractors.scala +++ b/src/main/scala/inox/ast/Extractors.scala @@ -30,18 +30,6 @@ trait Extractors { self: Trees => Some((Seq(t), (es: Seq[Expr]) => UMinus(es.head))) case StringLength(t) => Some((Seq(t), (es: Seq[Expr]) => StringLength(es.head))) - case StringBigLength(t) => - Some((Seq(t), (es: Seq[Expr]) => StringBigLength(es.head))) - case Int32ToString(t) => - Some((Seq(t), (es: Seq[Expr]) => Int32ToString(es.head))) - case BooleanToString(t) => - Some((Seq(t), (es: Seq[Expr]) => BooleanToString(es.head))) - case IntegerToString(t) => - Some((Seq(t), (es: Seq[Expr]) => IntegerToString(es.head))) - case CharToString(t) => - Some((Seq(t), (es: Seq[Expr]) => CharToString(es.head))) - case RealToString(t) => - Some((Seq(t), (es: Seq[Expr]) => RealToString(es.head))) case SetCardinality(t) => Some((Seq(t), (es: Seq[Expr]) => SetCardinality(es.head))) case CaseClassSelector(e, sel) => @@ -126,7 +114,6 @@ trait Extractors { self: Trees => case And(args) => Some((args, es => And(es))) case Or(args) => Some((args, es => Or(es))) case SubString(t1, a, b) => Some((t1::a::b::Nil, es => SubString(es(0), es(1), es(2)))) - case BigSubString(t1, a, b) => Some((t1::a::b::Nil, es => BigSubString(es(0), es(1), es(2)))) case FiniteSet(els, base) => Some((els, els => FiniteSet(els, base))) case FiniteBag(els, base) => @@ -202,17 +189,6 @@ trait Extractors { self: Trees => def unapply[T <: Typed](e: T)(implicit p: Symbols): Option[(T, Type)] = Some((e, e.getType)) } - object WithStringconverter { - def unapply(t: Type): Option[Expr => Expr] = t match { - case BooleanType => Some(BooleanToString) - case Int32Type => Some(Int32ToString) - case IntegerType => Some(IntegerToString) - case CharType => Some(CharToString) - case RealType => Some(RealToString) - case _ => None - } - } - def unwrapTuple(e: Expr, isTuple: Boolean)(implicit s: Symbols): Seq[Expr] = e.getType match { case TupleType(subs) if isTuple => for (ind <- 1 to subs.size) yield { s.tupleSelect(e, ind, isTuple) } diff --git a/src/main/scala/inox/ast/GenTreeOps.scala b/src/main/scala/inox/ast/GenTreeOps.scala index d75706378..d1690fbb3 100644 --- a/src/main/scala/inox/ast/GenTreeOps.scala +++ b/src/main/scala/inox/ast/GenTreeOps.scala @@ -23,7 +23,7 @@ trait TreeExtractor { * @tparam SubTree The type of the tree */ trait GenTreeOps { - private[ast] val trees: Trees + protected val trees: Trees import trees._ type SubTree <: Tree diff --git a/src/main/scala/inox/ast/Printers.scala b/src/main/scala/inox/ast/Printers.scala index 3d7f09c6b..de9b43264 100644 --- a/src/main/scala/inox/ast/Printers.scala +++ b/src/main/scala/inox/ast/Printers.scala @@ -89,9 +89,6 @@ trait Printers { self: Trees => case Let(vd, expr, SubString(v2: Variable, start, StringLength(v3: Variable))) if vd == v2 && v2 == v3 => p"$expr.substring($start)" - case Let(vd, expr, BigSubString(v2: Variable, start, StringLength(v3: Variable))) if vd == v2 && v2 == v3 => - p"$expr.bigSubstring($start)" - case Let(b,d,e) => p"""|val $b = $d |$e""" @@ -108,19 +105,10 @@ trait Printers { self: Trees => case Implies(l,r) => optP { p"$l ==> $r" } case UMinus(expr) => p"-$expr" case Equals(l,r) => optP { p"$l == $r" } - - - case Int32ToString(expr) => p"$expr.toString" - case BooleanToString(expr) => p"$expr.toString" - case IntegerToString(expr) => p"$expr.toString" - case CharToString(expr) => p"$expr.toString" - case RealToString(expr) => p"$expr.toString" + case StringConcat(lhs, rhs) => optP { p"$lhs + $rhs" } - case SubString(expr, start, end) => p"$expr.substring($start, $end)" - case BigSubString(expr, start, end) => p"$expr.bigSubstring($start, $end)" case StringLength(expr) => p"$expr.length" - case StringBigLength(expr) => p"$expr.bigLength" case IntLiteral(v) => p"$v" case BVLiteral(bits, size) => p"x${(size to 1 by -1).map(i => if (bits(i)) "1" else "0")}" @@ -237,11 +225,6 @@ trait Printers { self: Trees => p"${c.id}${nary(c.tps, ", ", "[", "]")}" // Definitions - case Symbols(classes, functions) => - p"""${nary(classes.map(_._2).toSeq, "\n\n")}""" - p"\n\n" - p"""${nary(functions.map(_._2).toSeq, "\n\n")}""" - case acd: AbstractClassDef => p"abstract class ${acd.id}${nary(acd.tparams, ", ", "[", "]")}" diff --git a/src/main/scala/inox/ast/SymbolOps.scala b/src/main/scala/inox/ast/SymbolOps.scala index 40e2618c7..0cc2b401b 100644 --- a/src/main/scala/inox/ast/SymbolOps.scala +++ b/src/main/scala/inox/ast/SymbolOps.scala @@ -361,7 +361,7 @@ trait SymbolOps { self: TypeOps => def transform(expr: Expr): Option[Expr] = expr match { case IfExpr(c, t, e) => None - case nop@Deconstructor(ts, op) => { + case nop @ Deconstructor(ts, op) => { val iteIndex = ts.indexWhere{ case IfExpr(_, _, _) => true case _ => false } if(iteIndex == -1) None else { val (beforeIte, startIte) = ts.splitAt(iteIndex) @@ -455,9 +455,7 @@ trait SymbolOps { self: TypeOps => case StringConcat(b, StringLiteral("")) => b case StringConcat(StringLiteral(a), StringLiteral(b)) => StringLiteral(a + b) case StringLength(StringLiteral(a)) => IntLiteral(a.length) - case StringBigLength(StringLiteral(a)) => IntegerLiteral(a.length) case SubString(StringLiteral(a), IntLiteral(start), IntLiteral(end)) => StringLiteral(a.substring(start.toInt, end.toInt)) - case BigSubString(StringLiteral(a), IntegerLiteral(start), IntegerLiteral(end)) => StringLiteral(a.substring(start.toInt, end.toInt)) case _ => expr }).copiedFrom(expr) simplify0(expr) diff --git a/src/main/scala/inox/ast/TreeOps.scala b/src/main/scala/inox/ast/TreeOps.scala index 81729beb7..d2da99b30 100644 --- a/src/main/scala/inox/ast/TreeOps.scala +++ b/src/main/scala/inox/ast/TreeOps.scala @@ -26,7 +26,7 @@ trait TreeOps { self: Trees => } @inline - private def transformChanged(vds: Seq[ValDef]): (Seq[ValDef], Boolean) = { + private def transformValDefs(vds: Seq[ValDef]): (Seq[ValDef], Boolean) = { var changed = false val newVds = vds.map { vd => val newVd = transform(vd) @@ -38,7 +38,7 @@ trait TreeOps { self: Trees => } @inline - private def transformChanged(args: Seq[Expr]): (Seq[Expr], Boolean) = { + private def transformExprs(args: Seq[Expr]): (Seq[Expr], Boolean) = { var changed = false val newArgs = args.map { arg => val newArg = transform(arg) @@ -50,7 +50,7 @@ trait TreeOps { self: Trees => } @inline - private def transformChanged(tps: Seq[Type]): (Seq[Type], Boolean) = { + private def transformTypes(tps: Seq[Type]): (Seq[Type], Boolean) = { var changed = false val newTps = tps.map { tp => val newTp = transform(tp) @@ -65,7 +65,7 @@ trait TreeOps { self: Trees => case v: Variable => transform(v) case Lambda(args, body) => - val (newArgs, changedArgs) = transformChanged(args) + val (newArgs, changedArgs) = transformValDefs(args) val newBody = transform(body) if (changedArgs || (body ne newBody)) { Lambda(newArgs, newBody).copiedFrom(e) @@ -74,7 +74,7 @@ trait TreeOps { self: Trees => } case Forall(args, body) => - val (newArgs, changedArgs) = transformChanged(args) + val (newArgs, changedArgs) = transformValDefs(args) val newBody = transform(body) if (changedArgs || (body ne newBody)) { Forall(newArgs, newBody).copiedFrom(e) @@ -94,7 +94,7 @@ trait TreeOps { self: Trees => case CaseClass(ct, args) => val newCt = transform(ct) - val (newArgs, changedArgs) = transformChanged(args) + val (newArgs, changedArgs) = transformExprs(args) if ((ct ne newCt) || changedArgs) { CaseClass(newCt.asInstanceOf[ClassType], newArgs).copiedFrom(e) } else { @@ -110,8 +110,8 @@ trait TreeOps { self: Trees => } case FunctionInvocation(id, tps, args) => - val (newTps, changedTps) = transformChanged(tps) - val (newArgs, changedArgs) = transformChanged(args) + val (newTps, changedTps) = transformTypes(tps) + val (newArgs, changedArgs) = transformExprs(args) if (changedTps || changedArgs) { FunctionInvocation(id, newTps, newArgs).copiedFrom(e) } else { @@ -137,7 +137,7 @@ trait TreeOps { self: Trees => } case FiniteSet(es, tpe) => - val (newArgs, changed) = transformChanged(es) + val (newArgs, changed) = transformExprs(es) val newTpe = transform(tpe) if (changed || (tpe ne newTpe)) { FiniteSet(newArgs, newTpe).copiedFrom(e) diff --git a/src/main/scala/inox/ast/Trees.scala b/src/main/scala/inox/ast/Trees.scala index 69cc86967..83297cc39 100644 --- a/src/main/scala/inox/ast/Trees.scala +++ b/src/main/scala/inox/ast/Trees.scala @@ -33,9 +33,13 @@ trait Trees } object exprOps extends { - private[ast] val trees: Trees.this.type = Trees.this + protected val trees: Trees.this.type = Trees.this } with ExprOps + object dsl extends { + protected val trees: Trees.this.type = Trees.this + } with DSL + /** Represents a unique symbol in Inox. * * The name is stored in the decoded (source code) form rather than encoded (JVM) form. diff --git a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala index 794fd17fa..6c457f342 100644 --- a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala @@ -217,53 +217,16 @@ trait RecursiveEvaluator } case StringLength(a) => e(a) match { - case StringLiteral(a) => IntLiteral(a.length) - case res => throw EvalError(typeErrorMsg(res, Int32Type)) - } - - case StringBigLength(a) => e(a) match { case StringLiteral(a) => IntegerLiteral(a.length) - case res => throw EvalError(typeErrorMsg(res, IntegerType)) + case res => throw EvalError(typeErrorMsg(res, Int32Type)) } case SubString(a, start, end) => (e(a), e(start), e(end)) match { - case (StringLiteral(a), IntLiteral(b), IntLiteral(c)) => - StringLiteral(a.substring(b, c)) - case res => throw EvalError(typeErrorMsg(res._1, StringType)) - } - - case BigSubString(a, start, end) => (e(a), e(start), e(end)) match { case (StringLiteral(a), IntegerLiteral(b), IntegerLiteral(c)) => StringLiteral(a.substring(b.toInt, c.toInt)) case res => throw EvalError(typeErrorMsg(res._1, StringType)) } - case Int32ToString(a) => e(a) match { - case IntLiteral(i) => StringLiteral(i.toString) - case res => throw EvalError(typeErrorMsg(res, Int32Type)) - } - - case CharToString(a) => - e(a) match { - case CharLiteral(i) => StringLiteral(i.toString) - case res => throw EvalError(typeErrorMsg(res, CharType)) - } - - case IntegerToString(a) => e(a) match { - case IntegerLiteral(i) => StringLiteral(i.toString) - case res => throw EvalError(typeErrorMsg(res, IntegerType)) - } - - case BooleanToString(a) => e(a) match { - case BooleanLiteral(i) => StringLiteral(i.toString) - case res => throw EvalError(typeErrorMsg(res, BooleanType)) - } - - case RealToString(a) => e(a) match { - case FractionLiteral(n, d) => StringLiteral(n.toString + "/" + d.toString) - case res => throw EvalError(typeErrorMsg(res, RealType)) - } - case UMinus(ex) => e(ex) match { case BVLiteral(b, s) => diff --git a/src/main/scala/inox/solvers/theories/BagEncoder.scala b/src/main/scala/inox/solvers/theories/BagEncoder.scala index d6ba2e33d..57f38edf4 100644 --- a/src/main/scala/inox/solvers/theories/BagEncoder.scala +++ b/src/main/scala/inox/solvers/theories/BagEncoder.scala @@ -1,118 +1,159 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -package leon +package inox package solvers package theories -import purescala.Common._ -import purescala.Expressions._ -import purescala.Definitions._ -import purescala.Types._ +trait BagEncoder extends TheoryEncoder { + import trees._ + import trees.dsl._ -class BagEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { - val Bag = p.library.lookupUnique[CaseClassDef]("leon.theories.Bag") + val BagID = FreshIdentifier("Bag") + val f = FreshIdentifier("f") - val Get = p.library.lookupUnique[FunDef]("leon.theories.Bag.get") - val Add = p.library.lookupUnique[FunDef]("leon.theories.Bag.add") - val Union = p.library.lookupUnique[FunDef]("leon.theories.Bag.union") - val Difference = p.library.lookupUnique[FunDef]("leon.theories.Bag.difference") - val Intersect = p.library.lookupUnique[FunDef]("leon.theories.Bag.intersect") - val BagEquals = p.library.lookupUnique[FunDef]("leon.theories.Bag.equals") + val bagClass = mkCaseClass(BagID)("T")(None)( + { case Seq(aT) => Seq(ValDef(f, aT =>: IntegerType)) }) - val encoder = new Encoder { - override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { + val Bag = T(BagID) + + val GetID = FreshIdentifier("get") + val Get = mkFunDef(GetID)("T") { case Seq(aT) => ( + Seq("bag" :: Bag(aT), "x" :: aT), + IntegerType, { case Seq(bag, x) => bag.getField(f)(x) }) + } + + val AddID = FreshIdentifier("add") + val Add = mkFunDef(AddID)("T") { case Seq(aT) => ( + Seq("bag" :: Bag(aT), "x" :: aT), + Bag(aT), { case Seq(bag, x) => Bag(aT)( + \("y" :: aT)(y => bag.getField(f)(y) + { + if_ (y === x) { E(BigInt(1)) } else_ { E(BigInt(0)) } + })) + }) + } + + val UnionID = FreshIdentifier("union") + val Union = mkFunDef(UnionID)("T") { case Seq(aT) => ( + Seq("b1" :: Bag(aT), "b2" :: Bag(aT)), + Bag(aT), { case Seq(b1, b2) => Bag(aT)( + \("y" :: aT)(y => b1.getField(f)(y) + b2.getField(f)(y))) + }) + } + + val DifferenceID = FreshIdentifier("difference") + val Difference = mkFunDef(DifferenceID)("T") { case Seq(aT) => ( + Seq("b1" :: Bag(aT), "b2" :: Bag(aT)), + Bag(aT), { case Seq(b1, b2) => Bag(aT)( + \("y" :: aT)(y => let("res" :: IntegerType, b1.getField(f)(y) - b2.getField(f)(y)) { + res => if_ (res < E(BigInt(0))) { E(BigInt(0)) } else_ { res } + })) + }) + } + + val IntersectID = FreshIdentifier("intersect") + val Intersect = mkFunDef(IntersectID)("T") { case Seq(aT) => ( + Seq("b1" :: Bag(aT), "b2" :: Bag(aT)), + Bag(aT), { case Seq(b1, b2) => Bag(aT)( + \("y" :: aT)(y => let("r1" :: IntegerType, b1.getField(f)(y)) { r1 => + let("r2" :: IntegerType, b2.getField(f)(y)) { r2 => + if_ (r1 > r2) { r2 } else_ { r1 } + } + })) + }) + } + + val EqualsID = FreshIdentifier("equals") + val BagEquals = mkFunDef(EqualsID)("T") { case Seq(aT) => ( + Seq("b1" :: Bag(aT), "b2" :: Bag(aT)), + BooleanType, { case Seq(b1, b2) => + forall("x" :: aT)(x => b1.getField(f)(x) === b2.getField(f)(x)) + }) + } + + val encoder = new TreeTransformer { + import sourceProgram._ + + override def transform(e: Expr): Expr = e match { case FiniteBag(elems, tpe) => val newTpe = transform(tpe) - val id = FreshIdentifier("x", newTpe, true) - Some(CaseClass(Bag.typed(Seq(newTpe)), Seq(Lambda(Seq(ValDef(id)), - elems.foldRight[Expr](InfiniteIntegerLiteral(0).copiedFrom(e)) { case ((k, v), ite) => - IfExpr(Equals(Variable(id), transform(k)), transform(v), ite).copiedFrom(e) - }))).copiedFrom(e)) + Bag(newTpe)(\("x" :: newTpe)(x => elems.foldRight[Expr](IntegerLiteral(0).copiedFrom(e)) { + case ((k, v), ite) => IfExpr(x === transform(k), transform(v), ite).copiedFrom(e) + })).copiedFrom(e) case BagAdd(bag, elem) => val BagType(base) = bag.getType - Some(FunctionInvocation(Add.typed(Seq(transform(base))), Seq(transform(bag), transform(elem))).copiedFrom(e)) + Add(transform(base))(transform(bag), transform(elem)).copiedFrom(e) case MultiplicityInBag(elem, bag) => val BagType(base) = bag.getType - Some(FunctionInvocation(Get.typed(Seq(transform(base))), Seq(transform(bag), transform(elem))).copiedFrom(e)) + Get(transform(base))(transform(bag), transform(elem)).copiedFrom(e) case BagIntersection(b1, b2) => val BagType(base) = b1.getType - Some(FunctionInvocation(Intersect.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e)) + Intersect(transform(base))(transform(b1), transform(b2)).copiedFrom(e) case BagUnion(b1, b2) => val BagType(base) = b1.getType - Some(FunctionInvocation(Union.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e)) + Union(transform(base))(transform(b1), transform(b2)).copiedFrom(e) case BagDifference(b1, b2) => val BagType(base) = b1.getType - Some(FunctionInvocation(Difference.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e)) + Difference(transform(base))(transform(b1), transform(b2)) case Equals(b1, b2) if b1.getType.isInstanceOf[BagType] => val BagType(base) = b1.getType - Some(FunctionInvocation(BagEquals.typed(Seq(transform(base))), Seq(transform(b1), transform(b2))).copiedFrom(e)) + BagEquals(transform(base))(transform(b1), transform(b2)).copiedFrom(e) - case _ => None + case _ => super.transform(e) } - override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { - case BagType(base) => Some(Bag.typed(Seq(transform(base))).copiedFrom(tpe)) - case _ => None + override def transform(tpe: Type): Type = tpe match { + case BagType(base) => Bag(transform(base)).copiedFrom(tpe) + case _ => super.transform(tpe) } } - val decoder = new Decoder { - override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { - case cc @ CaseClass(CaseClassType(Bag, Seq(tpe)), args) => - Some(FiniteBag(args(0) match { - case FiniteLambda(mapping, dflt, tpe) => - if (dflt != InfiniteIntegerLiteral(0)) - throw new Unsupported(cc, "Bags can't have default value " + dflt.asString(ctx))(ctx) - mapping.map { case (ks, v) => transform(ks.head) -> transform(v) }.toMap + val decoder = new TreeTransformer { + import targetProgram._ - case Lambda(Seq(ValDef(id)), body) => - def rec(expr: Expr): Map[Expr, Expr] = expr match { - case IfExpr(Equals(`id`, k), v, elze) => rec(elze) + (transform(k) -> transform(v)) - case InfiniteIntegerLiteral(v) if v == 0 => Map.empty - case _ => throw new Unsupported(expr, "Bags can't have default value " + expr.asString(ctx))(ctx) - } + override def transform(e: Expr): Expr = e match { + case cc @ CaseClass(ClassType(BagID, Seq(tpe)), Seq(Lambda(Seq(vd), body))) => + val Variable = vd.toVariable + def rec(expr: Expr): Seq[(Expr, Expr)] = expr match { + case IfExpr(Equals(Variable, k), v, elze) => rec(elze) :+ (transform(k) -> transform(v)) + case IntegerLiteral(v) if v == 0 => Seq.empty + case _ => throw new Unsupported(expr, "Bags can't have default value " + expr.asString) + } - rec(body) + FiniteBag(rec(body), transform(tpe)).copiedFrom(e) - case f => scala.sys.error("Unexpected function " + f.asString(ctx)) - }, transform(tpe)).copiedFrom(e)) + case cc @ CaseClass(ClassType(BagID, Seq(tpe)), args) => + throw new Unsupported(e, "Unexpected argument to bag constructor") - case FunctionInvocation(TypedFunDef(Add, Seq(_)), Seq(bag, elem)) => - Some(BagAdd(transform(bag), transform(elem)).copiedFrom(e)) + case FunctionInvocation(AddID, Seq(_), Seq(bag, elem)) => + BagAdd(transform(bag), transform(elem)).copiedFrom(e) - case FunctionInvocation(TypedFunDef(Get, Seq(_)), Seq(bag, elem)) => - Some(MultiplicityInBag(transform(elem), transform(bag)).copiedFrom(e)) + case FunctionInvocation(GetID, Seq(_), Seq(bag, elem)) => + MultiplicityInBag(transform(elem), transform(bag)).copiedFrom(e) - case FunctionInvocation(TypedFunDef(Intersect, Seq(_)), Seq(b1, b2)) => - Some(BagIntersection(transform(b1), transform(b2)).copiedFrom(e)) + case FunctionInvocation(IntersectID, Seq(_), Seq(b1, b2)) => + BagIntersection(transform(b1), transform(b2)).copiedFrom(e) - case FunctionInvocation(TypedFunDef(Union, Seq(_)), Seq(b1, b2)) => - Some(BagUnion(transform(b1), transform(b2)).copiedFrom(e)) + case FunctionInvocation(UnionID, Seq(_), Seq(b1, b2)) => + BagUnion(transform(b1), transform(b2)).copiedFrom(e) - case FunctionInvocation(TypedFunDef(Difference, Seq(_)), Seq(b1, b2)) => - Some(BagDifference(transform(b1), transform(b2)).copiedFrom(e)) + case FunctionInvocation(DifferenceID, Seq(_), Seq(b1, b2)) => + BagDifference(transform(b1), transform(b2)).copiedFrom(e) - case FunctionInvocation(TypedFunDef(BagEquals, Seq(_)), Seq(b1, b2)) => - Some(Equals(transform(b1), transform(b2)).copiedFrom(e)) - - case _ => None - } + case FunctionInvocation(EqualsID, Seq(_), Seq(b1, b2)) => + Equals(transform(b1), transform(b2)).copiedFrom(e) - override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { - case CaseClassType(Bag, Seq(base)) => Some(BagType(transform(base)).copiedFrom(tpe)) - case _ => None + case _ => super.transform(e) } - override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { - case CaseClassPattern(b, CaseClassType(Bag, Seq(tpe)), Seq(sub)) => - throw new Unsupported(pat, "Can't translate Bag case class pattern")(ctx) - case _ => super.transform(pat) + override def transform(tpe: Type): Type = tpe match { + case ClassType(BagID, Seq(base)) => BagType(transform(base)).copiedFrom(tpe) + case _ => super.transform(tpe) } } } diff --git a/src/main/scala/inox/solvers/theories/StringEncoder.scala b/src/main/scala/inox/solvers/theories/StringEncoder.scala index 4c6f3a51d..7199069aa 100644 --- a/src/main/scala/inox/solvers/theories/StringEncoder.scala +++ b/src/main/scala/inox/solvers/theories/StringEncoder.scala @@ -1,172 +1,140 @@ /* Copyright 2009-2015 EPFL, Lausanne */ -package leon +package inox package solvers package theories -import purescala.Common._ -import purescala.Expressions._ -import purescala.Constructors._ -import purescala.Types._ -import purescala.Definitions._ -import leon.utils.Bijection -import leon.purescala.TypeOps - -class StringEncoder(ctx: LeonContext, p: Program) extends TheoryEncoder { - - val StringID = FreshIdentifier("String") - val StringNilID = FreshIdentifier("StringNil") - val StringConsID = FreshIdentifier("StringCons") - - val StringConsHeadID = FreshIdentifier("head") - val StringConsTailID = FreshIdentifier("tail") - - val String = new AbstractClassDef(StringID, Seq.empty Seq(StringConsID, StringNilID), Set.empty).typed - val StringNil = new CaseClassDef(StringNilID, Seq.empty, Some(StringID), Seq.empty, Set.empty).typed - val StringCons = new CaseClassDef(StringConsID, Seq.empty, Some(StringID), Seq( - ValDef(StringConsHeadID, CharType), - ValDef(StringConsTailID, ClassType(StringID, Seq.empty)) - ), Set.empty).typed +import utils._ + +trait StringEncoder extends TheoryEncoder { + import trees._ + import trees.dsl._ + + val stringID = FreshIdentifier("String") + val stringNilID = FreshIdentifier("StringNil") + val stringConsID = FreshIdentifier("StringCons") + + val head = FreshIdentifier("head") + val tail = FreshIdentifier("tail") + + val stringClass = new AbstractClassDef(stringID, Seq.empty, Seq(stringConsID, stringNilID), Set.empty) + val stringNilClass = new CaseClassDef(stringNilID, Seq.empty, Some(stringID), Seq.empty, Set.empty) + val stringConsClass = new CaseClassDef(stringConsID, Seq.empty, Some(stringID), Seq( + ValDef(head, CharType), + ValDef(tail, ClassType(stringID, Seq.empty)) + ), Set.empty) + + val String : ClassType = T(stringID)() + val StringNil : ClassType = T(stringNilID)() + val StringCons : ClassType = T(stringConsID)() val SizeID = FreshIdentifier("size") - val Size = new FunDef() - - val Size = p.library.lookupUnique[FunDef]("leon.theories.String.size").typed - val Take = p.library.lookupUnique[FunDef]("leon.theories.String.take").typed - val Drop = p.library.lookupUnique[FunDef]("leon.theories.String.drop").typed - val Slice = p.library.lookupUnique[FunDef]("leon.theories.String.slice").typed - val Concat = p.library.lookupUnique[FunDef]("leon.theories.String.concat").typed - - val SizeI = p.library.lookupUnique[FunDef]("leon.theories.String.sizeI").typed - val TakeI = p.library.lookupUnique[FunDef]("leon.theories.String.takeI").typed - val DropI = p.library.lookupUnique[FunDef]("leon.theories.String.dropI").typed - val SliceI = p.library.lookupUnique[FunDef]("leon.theories.String.sliceI").typed - - val FromInt = p.library.lookupUnique[FunDef]("leon.theories.String.fromInt").typed - val FromChar = p.library.lookupUnique[FunDef]("leon.theories.String.fromChar").typed - val FromBoolean = p.library.lookupUnique[FunDef]("leon.theories.String.fromBoolean").typed - val FromBigInt = p.library.lookupUnique[FunDef]("leon.theories.String.fromBigInt").typed - + val Size = mkFunDef(SizeID)()(_ => ( + Seq("s" :: String), + IntegerType, { case Seq(s) => + if_ (s.isInstOf(StringCons)) { + E(BigInt(1)) + E(SizeID)(s.asInstOf(StringCons).getField(tail)) + } else_ { + E(BigInt(0)) + } + })) + + val TakeID = FreshIdentifier("take") + val Take = mkFunDef(TakeID)()(_ => ( + Seq("s" :: String, "i" :: IntegerType), + String, { case Seq(s, i) => + if_ (s.isInstOf(StringCons) && i > E(BigInt(0))) { + StringCons( + s.asInstOf(StringCons).getField(head), + E(TakeID)(s.asInstOf(StringCons).getField(tail), i - E(BigInt(1)))) + } else_ { + StringNil() + } + })) + + val DropID = FreshIdentifier("drop") + val Drop = mkFunDef(DropID)()(_ => ( + Seq("s" :: String, "i" :: IntegerType), + String, { case Seq(s, i) => + if_ (s.isInstOf(StringCons) && i > E(BigInt(0))) { + E(DropID)(s.asInstOf(StringCons).getField(tail), i - E(BigInt(1))) + } else_ { + s + } + })) + + val SliceID = FreshIdentifier("slice") + val Slice = mkFunDef(SliceID)()(_ => ( + Seq("s" :: String, "from" :: IntegerType, "to" :: IntegerType), + String, { case Seq(s, from, to) => Take(Drop(s, from), to - from) })) + + val ConcatID = FreshIdentifier("concat") + val Concat = mkFunDef(ConcatID)()(_ => ( + Seq("s1" :: String, "s2" :: String), + String, { case Seq(s1, s2) => + if_ (s1.isInstOf(StringCons)) { + StringCons( + s1.asInstOf(StringCons).getField(head), + E(ConcatID)(s1.asInstOf(StringCons).getField(tail), s2)) + } else_ { + s2 + } + })) private val stringBijection = new Bijection[String, Expr]() - + private def convertToString(e: Expr): String = stringBijection.cachedA(e)(e match { - case CaseClass(_, Seq(CharLiteral(c), l)) => (if(c < 31) (c + 97).toChar else c) + convertToString(l) - case CaseClass(_, Seq()) => "" + case CaseClass(StringCons, Seq(CharLiteral(c), l)) => (if(c < 31) (c + 97).toChar else c) + convertToString(l) + case CaseClass(StringNil, Seq()) => "" }) private def convertFromString(v: String): Expr = stringBijection.cachedB(v) { - v.toList.foldRight(CaseClass(StringNil, Seq())){ - case (char, l) => CaseClass(StringCons, Seq(CharLiteral(char), l)) - } + v.toList.foldRight(StringNil()){ case (char, l) => StringCons(E(char), l) } } - val encoder = new Encoder { - override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { - case StringLiteral(v) => - Some(convertFromString(v)) - case StringBigLength(a) => - Some(FunctionInvocation(Size, Seq(transform(a))).copiedFrom(e)) - case StringLength(a) => - Some(FunctionInvocation(SizeI, Seq(transform(a))).copiedFrom(e)) - case StringConcat(a, b) => - Some(FunctionInvocation(Concat, Seq(transform(a), transform(b))).copiedFrom(e)) + val encoder = new TreeTransformer { + override def transform(e: Expr): Expr = e match { + case StringLiteral(v) => convertFromString(v) + case StringLength(a) => Size(transform(a)).copiedFrom(e) + case StringConcat(a, b) => Concat(transform(a), transform(b)).copiedFrom(e) case SubString(a, start, Plus(start2, length)) if start == start2 => - Some(FunctionInvocation(TakeI, Seq(FunctionInvocation(DropI, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e)) - case SubString(a, start, end) => - Some(FunctionInvocation(SliceI, Seq(transform(a), transform(start), transform(end))).copiedFrom(e)) - case BigSubString(a, start, Plus(start2, length)) if start == start2 => - Some(FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(transform(a), transform(start))), transform(length))).copiedFrom(e)) - case BigSubString(a, start, end) => - Some(FunctionInvocation(Slice, Seq(transform(a), transform(start), transform(end))).copiedFrom(e)) - case Int32ToString(a) => - Some(FunctionInvocation(FromInt, Seq(transform(a))).copiedFrom(e)) - case IntegerToString(a) => - Some(FunctionInvocation(FromBigInt, Seq(transform(a))).copiedFrom(e)) - case CharToString(a) => - Some(FunctionInvocation(FromChar, Seq(transform(a))).copiedFrom(e)) - case BooleanToString(a) => - Some(FunctionInvocation(FromBoolean, Seq(transform(a))).copiedFrom(e)) - case _ => None - } - - override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { - case StringType => Some(String) - case _ => None + Take(Drop(transform(a), transform(start)), transform(length)).copiedFrom(e) + case SubString(a, start, end) => + Slice(transform(a), transform(start), transform(end)).copiedFrom(e) + case _ => super.transform(e) } - override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { - case LiteralPattern(binder, StringLiteral(s)) => - val newBinder = binder map transform - val newPattern = s.foldRight(CaseClassPattern(None, StringNil, Seq())) { - case (elem, pattern) => CaseClassPattern(None, StringCons, Seq(LiteralPattern(None, CharLiteral(elem)), pattern)) - } - (newPattern.copy(binder = newBinder), (binder zip newBinder).filter(p => p._1 != p._2).toMap) - case _ => super.transform(pat) + override def transform(tpe: Type): Type = tpe match { + case StringType => String + case _ => super.transform(tpe) } } - val decoder = new Decoder { - override def transformExpr(e: Expr)(implicit binders: Map[Identifier, Identifier]): Option[Expr] = e match { - case cc @ CaseClass(cct, args) if TypeOps.isSubtypeOf(cct, String)=> - Some(StringLiteral(convertToString(cc)).copiedFrom(cc)) - case FunctionInvocation(SizeI, Seq(a)) => - Some(StringLength(transform(a)).copiedFrom(e)) - case FunctionInvocation(Size, Seq(a)) => - Some(StringBigLength(transform(a)).copiedFrom(e)) - case FunctionInvocation(Concat, Seq(a, b)) => - Some(StringConcat(transform(a), transform(b)).copiedFrom(e)) - case FunctionInvocation(SliceI, Seq(a, from, to)) => - Some(SubString(transform(a), transform(from), transform(to)).copiedFrom(e)) - case FunctionInvocation(Slice, Seq(a, from, to)) => - Some(BigSubString(transform(a), transform(from), transform(to)).copiedFrom(e)) - case FunctionInvocation(TakeI, Seq(FunctionInvocation(DropI, Seq(a, start)), length)) => + val decoder = new TreeTransformer { + override def transform(e: Expr): Expr = e match { + case cc @ CaseClass(ct, args) if ct == StringNil || ct == StringCons => + StringLiteral(convertToString(cc)).copiedFrom(cc) + case FunctionInvocation(SizeID, Seq(), Seq(a)) => + StringLength(transform(a)).copiedFrom(e) + case FunctionInvocation(ConcatID, Seq(), Seq(a, b)) => + StringConcat(transform(a), transform(b)).copiedFrom(e) + case FunctionInvocation(SliceID, Seq(), Seq(a, from, to)) => + SubString(transform(a), transform(from), transform(to)).copiedFrom(e) + case FunctionInvocation(TakeID, Seq(), Seq(FunctionInvocation(DropID, Seq(), Seq(a, start)), length)) => val rstart = transform(start) - Some(SubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e)) - case FunctionInvocation(Take, Seq(FunctionInvocation(Drop, Seq(a, start)), length)) => - val rstart = transform(start) - Some(BigSubString(transform(a), rstart, plus(rstart, transform(length))).copiedFrom(e)) - case FunctionInvocation(TakeI, Seq(a, length)) => - Some(SubString(transform(a), IntLiteral(0), transform(length)).copiedFrom(e)) - case FunctionInvocation(Take, Seq(a, length)) => - Some(BigSubString(transform(a), InfiniteIntegerLiteral(0), transform(length)).copiedFrom(e)) - case FunctionInvocation(DropI, Seq(a, count)) => - val ra = transform(a) - Some(SubString(ra, transform(count), StringLength(ra)).copiedFrom(e)) - case FunctionInvocation(Drop, Seq(a, count)) => + SubString(transform(a), rstart, Plus(rstart, transform(length))).copiedFrom(e) + case FunctionInvocation(TakeID, Seq(), Seq(a, length)) => + SubString(transform(a), IntegerLiteral(0), transform(length)).copiedFrom(e) + case FunctionInvocation(DropID, Seq(), Seq(a, count)) => val ra = transform(a) - Some(BigSubString(ra, transform(count), StringBigLength(ra)).copiedFrom(e)) - case FunctionInvocation(FromInt, Seq(a)) => - Some(Int32ToString(transform(a)).copiedFrom(e)) - case FunctionInvocation(FromBigInt, Seq(a)) => - Some(IntegerToString(transform(a)).copiedFrom(e)) - case FunctionInvocation(FromChar, Seq(a)) => - Some(CharToString(transform(a)).copiedFrom(e)) - case FunctionInvocation(FromBoolean, Seq(a)) => - Some(BooleanToString(transform(a)).copiedFrom(e)) - case _ => None + SubString(ra, transform(count), StringLength(ra)).copiedFrom(e) + case _ => super.transform(e) } - override def transformType(tpe: TypeTree): Option[TypeTree] = tpe match { - case String | StringCons | StringNil => Some(StringType) - case _ => None - } - - override def transform(pat: Pattern): (Pattern, Map[Identifier, Identifier]) = pat match { - case CaseClassPattern(b, StringNil, Seq()) => - val newBinder = b map transform - (LiteralPattern(newBinder , StringLiteral("")), (b zip newBinder).filter(p => p._1 != p._2).toMap) - - case CaseClassPattern(b, StringCons, Seq(LiteralPattern(_, CharLiteral(elem)), sub)) => transform(sub) match { - case (LiteralPattern(_, StringLiteral(s)), binders) => - val newBinder = b map transform - (LiteralPattern(newBinder, StringLiteral(elem + s)), (b zip newBinder).filter(p => p._1 != p._2).toMap ++ binders) - case (e, binders) => - throw new Unsupported(pat, "Failed to parse pattern back as string: " + e)(ctx) - } - - case _ => super.transform(pat) + override def transform(tpe: Type): Type = tpe match { + case String | StringCons | StringNil => StringType + case _ => super.transform(tpe) } } } diff --git a/src/main/scala/inox/solvers/theories/TheoryEncoder.scala b/src/main/scala/inox/solvers/theories/TheoryEncoder.scala index def0c9869..5c78ad27b 100644 --- a/src/main/scala/inox/solvers/theories/TheoryEncoder.scala +++ b/src/main/scala/inox/solvers/theories/TheoryEncoder.scala @@ -7,9 +7,14 @@ package theories import utils._ trait TheoryEncoder { - val trees: Trees + val trees: ast.Trees import trees._ + val sourceProgram: Program { val trees: TheoryEncoder.this.trees.type } + lazy val targetProgram: Program { val trees: TheoryEncoder.this.trees.type } = { + sourceProgram.transform(encoder) + } + private type SameTrees = TheoryEncoder { val trees: TheoryEncoder.this.trees.type } @@ -28,32 +33,40 @@ trait TheoryEncoder { def >>(that: SameTrees): SameTrees = new TheoryEncoder { val trees: TheoryEncoder.this.trees.type = TheoryEncoder.this.trees + val sourceProgram: TheoryEncoder.this.sourceProgram.type = TheoryEncoder.this.sourceProgram val encoder = new TreeTransformer { override def transform(id: Identifier, tpe: Type): (Identifier, Type) = { - val (id1, tpe1) = TheoryEncoder.this.transform(id, tpe) - that.transform(id1, tpe1) + val (id1, tpe1) = TheoryEncoder.this.encoder.transform(id, tpe) + that.encoder.transform(id1, tpe1) } - override def transform(expr: Expr): Expr = that.transform(TheoryEncoder.this.transform(expr)) - override def transform(tpe: Type): Type = that.transform(TheoryEncoder.this.transform(expr)) + override def transform(expr: Expr): Expr = + that.encoder.transform(TheoryEncoder.this.encoder.transform(expr)) + + override def transform(tpe: Type): Type = + that.encoder.transform(TheoryEncoder.this.encoder.transform(tpe)) } val decoder = new TreeTransformer { override def transform(id: Identifier, tpe: Type): (Identifier, Type) = { - val (id1, tpe1) = that.transform(id, tpe) - TheoryEncoder.this.transform(id1, tpe1) + val (id1, tpe1) = that.decoder.transform(id, tpe) + TheoryEncoder.this.decoder.transform(id1, tpe1) } - override def transform(expr: Expr): Expr = TheoryEncoder.this.transform(that.transform(expr)) - override def transform(tpe: Type): Type = TheoryEncoder.this.transform(that.transform(tpe)) + override def transform(expr: Expr): Expr = + TheoryEncoder.this.decoder.transform(that.decoder.transform(expr)) + + override def transform(tpe: Type): Type = + TheoryEncoder.this.decoder.transform(that.decoder.transform(tpe)) } } } trait NoEncoder extends TheoryEncoder { + import trees._ - private object NoTransformer extends trees.TreeTransformer { + private object NoTransformer extends TreeTransformer { override def transform(id: Identifier, tpe: Type): (Identifier, Type) = (id, tpe) override def transform(v: Variable): Variable = v override def transform(vd: ValDef): ValDef = vd @@ -61,7 +74,7 @@ trait NoEncoder extends TheoryEncoder { override def transform(tpe: Type): Type = tpe } - val encoder = NoTransformer - val decoder = NoTransformer + val encoder: TreeTransformer = NoTransformer + val decoder: TreeTransformer = NoTransformer } diff --git a/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala index 5466a41d8..1b06f3de2 100644 --- a/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala +++ b/src/main/scala/inox/solvers/unrolling/FunctionTemplates.scala @@ -97,8 +97,6 @@ trait FunctionTemplates { self: Templates => } private[unrolling] object functionsManager extends Manager { - val incrementals: Seq[IncrementalState] = Seq(callInfos, defBlockers) - // Function instantiations have their own defblocker private[FunctionTemplates] val defBlockers = new IncrementalMap[Call, Encoded]() @@ -106,6 +104,8 @@ trait FunctionTemplates { self: Templates => // also specify the generation of the blocker. private[FunctionTemplates] val callInfos = new IncrementalMap[Encoded, (Int, Int, Encoded, Set[Call])]() + val incrementals: Seq[IncrementalState] = Seq(callInfos, defBlockers) + def unrollGeneration: Option[Int] = if (callInfos.isEmpty) None else Some(callInfos.values.map(_._1).min) diff --git a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala index f75a9ff68..adfdf3a9f 100644 --- a/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala +++ b/src/main/scala/inox/solvers/unrolling/UnrollingSolver.scala @@ -21,14 +21,13 @@ trait AbstractUnrollingSolver import program.trees._ import program.symbols._ - val theories: TheoryEncoder - lazy val encodedProgram: Program { val trees: program.trees.type } = theories.encode(program) + val theories: TheoryEncoder { val trees: program.trees.type } type Encoded implicit val printable: Encoded => Printable val templates: Templates { - val program: encodedProgram.type + val program: theories.targetProgram.type type Encoded = AbstractUnrollingSolver.this.Encoded } @@ -140,13 +139,13 @@ trait AbstractUnrollingSolver def eval(elem: Encoded, tpe: Type): Option[Expr] = modelEval(elem, theories.encode(tpe)).flatMap { expr => try { - Some(theories.decode(expr)(Map.empty)) + Some(theories.decode(expr)) } catch { case u: Unsupported => None } } - def get(v: Variable): Option[Expr] = eval(freeVars(v), theories.encode(id.getType)).filter { + def get(v: Variable): Option[Expr] = eval(freeVars(v), theories.encode(v.getType)).filter { case v: Variable => false case _ => true } @@ -495,7 +494,7 @@ trait UnrollingSolver extends AbstractUnrollingSolver { import program.symbols._ type Encoded = Expr - val solver: Solver { val program: encodedProgram.type } + val solver: Solver { val program: theories.targetProgram.type } override val name = "U:"+solver.name @@ -506,7 +505,7 @@ trait UnrollingSolver extends AbstractUnrollingSolver { val printable = (e: Expr) => e val templates = new Templates { - val program = encodedProgram + val program: theories.targetProgram.type = theories.targetProgram type Encoded = Expr def encodeSymbol(v: Variable): Expr = v.freshen -- GitLab