diff --git a/src/main/scala/inox/Program.scala b/src/main/scala/inox/Program.scala index 94bd3b43e09544fe30ee6a4e0b8e5ca146daa39f..7dfbf25ee296d45885093fabe22b1d4c958adf4c 100644 --- a/src/main/scala/inox/Program.scala +++ b/src/main/scala/inox/Program.scala @@ -30,10 +30,15 @@ trait Program { val ctx = Program.this.ctx } - def extend(functions: Seq[trees.FunDef] = Seq.empty, adts: Seq[trees.ADTDefinition] = Seq.empty): - Program { val trees: Program.this.trees.type } = new Program { + def withFunctions(functions: Seq[trees.FunDef]): Program { val trees: Program.this.trees.type } = new Program { val trees: Program.this.trees.type = Program.this.trees - val symbols = Program.this.symbols.extend(functions, adts) + val symbols = Program.this.symbols.withFunctions(functions) + val ctx = Program.this.ctx + } + + def withADTs(adts: Seq[trees.ADTDefinition]): Program { val trees: Program.this.trees.type } = new Program { + val trees: Program.this.trees.type = Program.this.trees + val symbols = Program.this.symbols.withADTs(adts) val ctx = Program.this.ctx } } diff --git a/src/main/scala/inox/ast/DSL.scala b/src/main/scala/inox/ast/DSL.scala index 691b0a35bd904fccfbbd1ce770a8f09e1fe426fb..46778edfffb4b6a22681436680cb99829085f2d2 100644 --- a/src/main/scala/inox/ast/DSL.scala +++ b/src/main/scala/inox/ast/DSL.scala @@ -51,7 +51,6 @@ trait DSL { def _4 = TupleSelect(e, 4) // Sets - def size = SetCardinality(e) def subsetOf = SubsetOf(e, _: Expr) def insert = SetAdd(e, _: Expr) def ++ = SetUnion(e, _: Expr) diff --git a/src/main/scala/inox/ast/Definitions.scala b/src/main/scala/inox/ast/Definitions.scala index 57bf9d6ddb0c985f31e92b96ed89906fd808e43e..8e2417f3f3925eae0d1a22d8bc9b5661c8875c69 100644 --- a/src/main/scala/inox/ast/Definitions.scala +++ b/src/main/scala/inox/ast/Definitions.scala @@ -9,7 +9,7 @@ import scala.collection.mutable.{Map => MutableMap} trait Definitions { self: Trees => /** The base trait for Inox definitions */ - sealed trait Definition extends Tree { + trait Definition extends Tree { val id: Identifier override def equals(that: Any): Boolean = that match { @@ -23,7 +23,7 @@ trait Definitions { self: Trees => abstract class LookupException(id: Identifier, what: String) extends Exception("Lookup failed for " + what + " with symbol " + id) case class FunctionLookupException(id: Identifier) extends LookupException(id, "function") - case class ADTLookupException(id: Identifier) extends LookupException(id, "class") + case class ADTLookupException(id: Identifier) extends LookupException(id, "adt") case class NotWellFormedException(id: Identifier, s: Symbols) extends Exception(s"$id not well formed in $s") @@ -33,7 +33,7 @@ trait Definitions { self: Trees => * Both types share much in common and being able to reason about them * in a uniform manner can be useful in certain cases. */ - private[ast] trait VariableSymbol extends Typed { + protected[ast] trait VariableSymbol extends Typed { val id: Identifier val tpe: Type @@ -52,7 +52,7 @@ trait Definitions { self: Trees => implicit def variableSymbolOrdering[VS <: VariableSymbol]: Ordering[VS] = Ordering.by(e => e.id) - sealed abstract class VariableConverter[B <: VariableSymbol] { + abstract class VariableConverter[B <: VariableSymbol] { def convert(a: VariableSymbol): B } @@ -63,7 +63,7 @@ trait Definitions { self: Trees => } } - implicit def convertToVar = new VariableConverter[Variable] { + implicit def convertToVariable = new VariableConverter[Variable] { def convert(vs: VariableSymbol): Variable = vs match { case v: Variable => v case _ => Variable(vs.id, vs.tpe) @@ -78,12 +78,16 @@ trait Definitions { self: Trees => def toVariable: Variable = to[Variable] def freshen: ValDef = ValDef(id.freshen, tpe).copiedFrom(this) + val flags: Set[Annotation] = Set.empty + override def equals(that: Any): Boolean = super[VariableSymbol].equals(that) override def hashCode: Int = super[VariableSymbol].hashCode } type Symbols >: Null <: AbstractSymbols + val NoSymbols: Symbols + /** Provides the class and function definitions of a program and lookups on them */ trait AbstractSymbols extends Printable @@ -132,8 +136,35 @@ trait Definitions { self: Trees => functions.map(p => PrettyPrinter(p._2, opts)).mkString("\n\n") } - def transform(t: TreeTransformer): Symbols - def extend(functions: Seq[FunDef] = Seq.empty, adts: Seq[ADTDefinition] = Seq.empty): Symbols + def transform(t: TreeTransformer): Symbols = NoSymbols.withFunctions { + functions.values.toSeq.map(fd => new FunDef( + fd.id, + fd.tparams, // type parameters can't be transformed! + fd.params.map(vd => t.transform(vd)), + t.transform(fd.returnType), + t.transform(fd.fullBody), + fd.flags)) + }.withADTs { + adts.values.toSeq.map { + case sort: ADTSort => sort + case cons: ADTConstructor => new ADTConstructor( + cons.id, + cons.tparams, + cons.sort, + cons.fields.map(t.transform), + cons.flags) + } + } + + override def equals(that: Any): Boolean = that match { + case sym: AbstractSymbols => functions == sym.functions && adts == sym.adts + case _ => false + } + + override def hashCode: Int = functions.hashCode * 61 + adts.hashCode + + def withFunctions(functions: Seq[FunDef]): Symbols + def withADTs(adts: Seq[ADTDefinition]): Symbols } case class TypeParameterDef(tp: TypeParameter) extends Definition { @@ -141,38 +172,24 @@ trait Definitions { self: Trees => val id = tp.id } - /** A trait that represents flags that annotate an ADTDefinition with different attributes */ - sealed trait ADTFlag - - object ADTFlag { - def fromName(name: String, args: Seq[Option[Any]]): ADTFlag = Annotation(name, args) - } - - /** A trait that represents flags that annotate a FunDef with different attributes */ - sealed trait FunctionFlag - - object FunctionFlag { - def fromName(name: String, args: Seq[Option[Any]]): FunctionFlag = name match { - case "inline" => IsInlined - case _ => Annotation(name, args) + // Compiler annotations given in the source code as @annot + class Annotation(val annot: String, val args: Seq[Option[Any]]) { + override def equals(that: Any): Boolean = that match { + case o: Annotation => annot == o.annot && args == o.args + case _ => false } + + override def hashCode: Int = annot.hashCode + 31 * args.hashCode } - // Compiler annotations given in the source code as @annot - case class Annotation(annot: String, args: Seq[Option[Any]]) extends FunctionFlag with ADTFlag /** Denotes that this adt is refined by invariant ''id'' */ - case class HasADTInvariant(id: Identifier) extends ADTFlag - // Is inlined - case object IsInlined extends FunctionFlag + case class HasADTInvariant(id: Identifier) extends Annotation("invariant", Seq(Some(id))) /** Represents an ADT definition (either the ADT sort or a constructor). */ sealed trait ADTDefinition extends Definition { val id: Identifier val tparams: Seq[TypeParameterDef] - val flags: Set[ADTFlag] - - def annotations: Set[String] = extAnnotations.keySet - def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { case Annotation(s, args) => s -> args }.toMap + val flags: Set[Annotation] /** The root of the class hierarchy */ def root(implicit s: Symbols): ADTDefinition @@ -199,7 +216,7 @@ trait Definitions { self: Trees => class ADTSort(val id: Identifier, val tparams: Seq[TypeParameterDef], val cons: Seq[Identifier], - val flags: Set[ADTFlag]) extends ADTDefinition { + val flags: Set[Annotation]) extends ADTDefinition { val isSort = true def constructors(implicit s: Symbols): Seq[ADTConstructor] = cons @@ -250,7 +267,7 @@ trait Definitions { self: Trees => val tparams: Seq[TypeParameterDef], val sort: Option[Identifier], val fields: Seq[ValDef], - val flags: Set[ADTFlag]) extends ADTDefinition { + val flags: Set[Annotation]) extends ADTDefinition { val isSort = false /** Returns the index of the field with the specified id */ @@ -280,7 +297,7 @@ trait Definitions { self: Trees => val tps: Seq[Type] implicit val symbols: Symbols - val id: Identifier = definition.id + lazy val id: Identifier = definition.id /** The root of the class hierarchy */ lazy val root: TypedADTDefinition = definition.root.typed(tps) lazy val invariant: Option[TypedFunDef] = definition.invariant.map(_.typed(tps)) @@ -301,7 +318,7 @@ trait Definitions { self: Trees => /** Represents an [[ADTSort]] whose type parameters have been instantiated to ''tps'' */ case class TypedADTSort(definition: ADTSort, tps: Seq[Type])(implicit val symbols: Symbols) extends TypedADTDefinition { - def constructors: Seq[TypedADTConstructor] = definition.constructors.map(_.typed(tps)) + lazy val constructors: Seq[TypedADTConstructor] = definition.constructors.map(_.typed(tps)) } /** Represents an [[ADTConstructor]] whose type parameters have been instantiated to ''tps'' */ @@ -309,7 +326,7 @@ trait Definitions { self: Trees => lazy val fields: Seq[ValDef] = { val tmap = (definition.typeArgs zip tps).toMap if (tmap.isEmpty) definition.fields - else definition.fields.map(vd => vd.copy(tpe = symbols.instantiateType(vd.getType, tmap))) + else definition.fields.map(vd => vd.copy(tpe = symbols.instantiateType(vd.tpe, tmap))) } lazy val fieldsTypes = fields.map(_.tpe) @@ -333,14 +350,9 @@ trait Definitions { self: Trees => val params: Seq[ValDef], val returnType: Type, val fullBody: Expr, - val flags: Set[FunctionFlag] + val flags: Set[Annotation] ) extends Definition { - def annotations: Set[String] = extAnnotations.keySet - def extAnnotations: Map[String, Seq[Option[Any]]] = flags.collect { - case Annotation(s, args) => s -> args - }.toMap - /** Wraps this [[FunDef]] in a in [[TypedFunDef]] with the specified type parameters */ def typed(tps: Seq[Type])(implicit s: Symbols): TypedFunDef = { assert(tps.size == tparams.size) diff --git a/src/main/scala/inox/ast/Expressions.scala b/src/main/scala/inox/ast/Expressions.scala index 377f5ea22cd66794f337b4daffefd29e6ad9548d..57085d556e3bff30428bbb8d099678c3e791e0f3 100644 --- a/src/main/scala/inox/ast/Expressions.scala +++ b/src/main/scala/inox/ast/Expressions.scala @@ -19,7 +19,7 @@ import scala.collection.BitSet */ trait Expressions { self: Trees => - private def checkParamTypes(real: Seq[Type], formal: Seq[Type], result: Type)(implicit s: Symbols): Type = { + protected def checkParamTypes(real: Seq[Type], formal: Seq[Type], result: Type)(implicit s: Symbols): Type = { if (real zip formal forall { case (real, formal) => s.isSubtypeOf(real, formal)} ) { result.unveilUntyped } else { @@ -507,7 +507,7 @@ trait Expressions { self: Trees => } /** $encodingof `... ^ ...` $noteBitvector */ - case class BVXOr(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { + case class BVXor(lhs: Expr, rhs: Expr) extends Expr with CachingTyped { protected def computeType(implicit s: Symbols): Type = bitVectorType(lhs.getType, rhs.getType) } @@ -587,14 +587,6 @@ trait Expressions { self: Trees => }), BooleanType) } - /** $encodingof `set.length` */ - case class SetCardinality(set: Expr) extends Expr with CachingTyped { - protected def computeType(implicit s: Symbols): Type = set.getType match { - case SetType(_) => IntegerType - case _ => Untyped - } - } - /** $encodingof `set.subsetOf(set2)` */ case class SubsetOf(set1: Expr, set2: Expr) extends Expr with CachingTyped { protected def computeType(implicit s: Symbols): Type = (set1.getType, set2.getType) match { diff --git a/src/main/scala/inox/ast/Extractors.scala b/src/main/scala/inox/ast/Extractors.scala index 70f6eb66e1b88b494422b75645ead189a5f33210..7e5a18d1cdaf054359f7d505523ea439bfecc6ed 100644 --- a/src/main/scala/inox/ast/Extractors.scala +++ b/src/main/scala/inox/ast/Extractors.scala @@ -18,8 +18,6 @@ trait TreeDeconstructor { (Seq(e), Seq(), (es, tps) => t.UMinus(es.head)) case s.StringLength(e) => (Seq(e), Seq(), (es, tps) => t.StringLength(es.head)) - case s.SetCardinality(e) => - (Seq(e), Seq(), (es, tps) => t.SetCardinality(es.head)) case s.ADTSelector(e, sel) => (Seq(e), Seq(), (es, tps) => t.ADTSelector(es.head, sel)) case s.IsInstanceOf(e, ct) => @@ -66,8 +64,8 @@ trait TreeDeconstructor { (Seq(t1, t2), Seq(), (es, tps) => t.BVOr(es(0), es(1))) case s.BVAnd(t1, t2) => (Seq(t1, t2), Seq(), (es, tps) => t.BVAnd(es(0), es(1))) - case s.BVXOr(t1, t2) => - (Seq(t1, t2), Seq(), (es, tps) => t.BVXOr(es(0), es(1))) + case s.BVXor(t1, t2) => + (Seq(t1, t2), Seq(), (es, tps) => t.BVXor(es(0), es(1))) case s.BVShiftLeft(t1, t2) => (Seq(t1, t2), Seq(), (es, tps) => t.BVShiftLeft(es(0), es(1))) case s.BVAShiftRight(t1, t2) => diff --git a/src/main/scala/inox/ast/Printers.scala b/src/main/scala/inox/ast/Printers.scala index 080a5c47751793b5cda6cf7d83b8fe3c370dfc66..fecbe01df372c03eee42972cdd1cdcf63aa16cdf 100644 --- a/src/main/scala/inox/ast/Printers.scala +++ b/src/main/scala/inox/ast/Printers.scala @@ -207,7 +207,7 @@ trait Printers { case BVNot(e) => optP { p"~$e" } - case BVXOr(l, r) => optP { + case BVXor(l, r) => optP { p"$l ^ $r" } case BVOr(l, r) => optP { @@ -239,7 +239,6 @@ trait Printers { case BagDifference(l, r) => p"$l \\ $r" case SetIntersection(l, r) => p"$l \u2229 $r" case BagIntersection(l, r) => p"$l \u2229 $r" - case SetCardinality(s) => p"$s.size" case BagAdd(b, e) => p"$b + $e" case MultiplicityInBag(e, b) => p"$b($e)" case MapApply(m, k) => p"$m($k)" @@ -302,7 +301,7 @@ trait Printers { } case fd: FunDef => - for (an <- fd.annotations) { + for (an <- fd.flags) { p"""|@$an |""" } @@ -373,7 +372,7 @@ trait Printers { // 1: | case (_: Or | _: BVOr) => 1 // 2: ^ - case (_: BVXOr) => 2 + case (_: BVXor) => 2 // 3: & case (_: And | _: BVAnd | _: SetIntersection) => 3 // 4: < > diff --git a/src/main/scala/inox/ast/SimpleSymbols.scala b/src/main/scala/inox/ast/SimpleSymbols.scala index e02d0f379871d01eb95d8ca84260bf5862a83de8..bd5e6ddf2a18c06dcb52a27d9af6e5f119d28466 100644 --- a/src/main/scala/inox/ast/SimpleSymbols.scala +++ b/src/main/scala/inox/ast/SimpleSymbols.scala @@ -5,32 +5,21 @@ package ast trait SimpleSymbols { self: Trees => + val NoSymbols = new Symbols(Map.empty, Map.empty) + class Symbols( val functions: Map[Identifier, FunDef], val adts: Map[Identifier, ADTDefinition] ) extends AbstractSymbols { - def transform(t: TreeTransformer) = new Symbols( - functions.mapValues(fd => new FunDef( - fd.id, - fd.tparams, // type parameters can't be transformed! - fd.params.map(vd => t.transform(vd)), - t.transform(fd.returnType), - t.transform(fd.fullBody), - fd.flags)), - adts.mapValues { - case sort: ADTSort => sort - case cons: ADTConstructor => new ADTConstructor( - cons.id, - cons.tparams, - cons.sort, - cons.fields.map(t.transform), - cons.flags) - }) - - def extend(functions: Seq[FunDef] = Seq.empty, adts: Seq[ADTDefinition] = Seq.empty) = new Symbols( + def withFunctions(functions: Seq[FunDef]): Symbols = new Symbols( this.functions ++ functions.map(fd => fd.id -> fd), - this.adts ++ adts.map(cd => cd.id -> cd) + this.adts + ) + + def withADTs(adts: Seq[ADTDefinition]): Symbols = new Symbols( + this.functions, + this.adts ++ adts.map(adt => adt.id -> adt) ) } } diff --git a/src/main/scala/inox/ast/Trees.scala b/src/main/scala/inox/ast/Trees.scala index cd1963046eb7aabd80265091543f68b2bc6ceebf..5f342c5a85fa8895e3c3440fc556de2adbacb459 100644 --- a/src/main/scala/inox/ast/Trees.scala +++ b/src/main/scala/inox/ast/Trees.scala @@ -23,10 +23,7 @@ trait Trees extends Exception(s"${t.asString(PrinterOptions.fromContext(ctx))}@${t.getPos} $msg") abstract class Tree extends utils.Positioned with Serializable { - def copiedFrom(o: Trees#Tree): this.type = { - setPos(o) - this - } + def copiedFrom(o: Trees#Tree): this.type = setPos(o) // @EK: toString is considered harmful for non-internal things. Use asString(ctx) instead. diff --git a/src/main/scala/inox/ast/TypeOps.scala b/src/main/scala/inox/ast/TypeOps.scala index 467c8953999dbdb6c8d5ff0d40f914ac923ce27e..d4006f908bc0e862c05f77746866cdc9245fd4a7 100644 --- a/src/main/scala/inox/ast/TypeOps.scala +++ b/src/main/scala/inox/ast/TypeOps.scala @@ -31,6 +31,17 @@ trait TypeOps { subs.flatMap(typeParamsOf).toSet } + protected def flattenTypeMappings(res: Seq[Option[(Type, Map[TypeParameter, Type])]]): Option[(Seq[Type], Map[TypeParameter, Type])] = { + val (tps, subst) = res.map(_.getOrElse(return None)).unzip + val flat = subst.flatMap(_.toSeq).groupBy(_._1) + Some((tps, flat.mapValues { vs => + vs.map(_._2).distinct match { + case Seq(unique) => unique + case _ => return None + } + })) + } + /** Generic type bounds between two types. Serves as a base for a set of subtyping/unification functions. * It will allow subtyping between classes (but type parameters are invariant). * It will also allow a set of free parameters to be unified if needed. @@ -47,17 +58,6 @@ trait TypeOps { def typeBound(t1: Type, t2: Type, isLub: Boolean, allowSub: Boolean) (implicit freeParams: Seq[TypeParameter]): Option[(Type, Map[TypeParameter, Type])] = { - def flatten(res: Seq[Option[(Type, Map[TypeParameter, Type])]]): Option[(Seq[Type], Map[TypeParameter, Type])] = { - val (tps, subst) = res.map(_.getOrElse(return None)).unzip - val flat = subst.flatMap(_.toSeq).groupBy(_._1) - Some((tps, flat.mapValues { vs => - vs.map(_._2).distinct match { - case Seq(unique) => unique - case _ => return None - } - })) - } - (t1, t2) match { case (_: TypeParameter, _: TypeParameter) if t1 == t2 => Some((t1, Map())) @@ -74,6 +74,9 @@ trait TypeOps { case (_, _: TypeParameter) => None + case (adt: ADTType, _) if !adt.lookupADT.isDefined => None + case (_, adt: ADTType) if !adt.lookupADT.isDefined => None + case (adt1: ADTType, adt2: ADTType) => val def1 = adt1.getADT.definition val def2 = adt2.getADT.definition @@ -96,7 +99,7 @@ trait TypeOps { for { adtDef <- bound - (subs, map) <- flatten((adt1.tps zip adt2.tps).map { case (tp1, tp2) => + (subs, map) <- flattenTypeMappings((adt1.tps zip adt2.tps).map { case (tp1, tp2) => // Class types are invariant! typeBound(tp1, tp2, isLub, allowSub = false) }) @@ -109,7 +112,7 @@ trait TypeOps { typeBound(tp1, tp2, !isLub, allowSub) // Contravariant args } val out = typeBound(to1, to2, isLub, allowSub) // Covariant result - flatten(out +: in) map { + flattenTypeMappings(out +: in) map { case (Seq(newTo, newFrom@_*), map) => (FunctionType(newFrom, newTo), map) } @@ -124,7 +127,7 @@ trait TypeOps { val NAryType(ts1, recon) = t1 val NAryType(ts2, _) = t2 if (ts1.size == ts2.size) { - flatten((ts1 zip ts2).map { case (tp1, tp2) => + flattenTypeMappings((ts1 zip ts2).map { case (tp1, tp2) => typeBound(tp1, tp2, isLub, allowSub = allowVariance) }).map { case (subs, map) => (recon(subs), map) } } else None diff --git a/src/main/scala/inox/datagen/SolverDataGen.scala b/src/main/scala/inox/datagen/SolverDataGen.scala index d13360915f17bd6ec6d19937bb09c2258ca37232..2576742684986384139a28a7b875d26e79a1d168 100644 --- a/src/main/scala/inox/datagen/SolverDataGen.scala +++ b/src/main/scala/inox/datagen/SolverDataGen.scala @@ -77,7 +77,7 @@ trait SolverDataGen extends DataGenerator { self => val sizeOf = sizeFor(tupleWrap(ins.map(_.toVariable))) // We need to synthesize a size function for ins' types. - val pgm1 = program.extend(functions = fds) + val pgm1 = program.withFunctions(fds) val modelEnum = ModelEnumerator(pgm1)(factory(pgm1), evaluator(pgm1)) val enum = modelEnum.enumVarying(ins, satisfying, sizeOf, 5) diff --git a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala index 0ecd015b65b633773157ab868dd2ae7d185bdf19..1b40732590cb3ffaa2b2020f5154ac6bae50f043 100644 --- a/src/main/scala/inox/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/inox/evaluators/RecursiveEvaluator.scala @@ -299,7 +299,7 @@ trait RecursiveEvaluator case (le, re) => throw EvalError("Unexpected operation: (" + le.asString + ") | (" + re.asString + ")") } - case BVXOr(l,r) => + case BVXor(l,r) => (e(l), e(r)) match { case (BVLiteral(i1, s1), BVLiteral(i2, s2)) if s1 == s2 => BVLiteral(i1 ^ i2, s1) case (le,re) => throw EvalError("Unexpected operation: (" + le.asString + ") ^ (" + re.asString + ")") @@ -410,13 +410,6 @@ trait RecursiveEvaluator case (le,re) => throw EvalError(typeErrorMsg(le, s1.getType)) } - case SetCardinality(s) => - val sr = e(s) - sr match { - case FiniteSet(els, _) => IntegerLiteral(els.size) - case _ => throw EvalError(typeErrorMsg(sr, SetType(Untyped))) - } - case f @ FiniteSet(els, base) => FiniteSet(els.map(e).distinct, base) diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala index 6b6dcb0634411b0b21491df4e207642d6233e93d..c07d63da17159368fdf2988211190ce1e8921a77 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala @@ -458,7 +458,7 @@ trait SMTLIBTarget extends Interruptible with ADTManagers { case BVNot(u) => FixedSizeBitVectors.Not(toSMT(u)) case BVAnd(a, b) => FixedSizeBitVectors.And(toSMT(a), toSMT(b)) case BVOr(a, b) => FixedSizeBitVectors.Or(toSMT(a), toSMT(b)) - case BVXOr(a, b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b)) + case BVXor(a, b) => FixedSizeBitVectors.XOr(toSMT(a), toSMT(b)) case BVShiftLeft(a, b) => FixedSizeBitVectors.ShiftLeft(toSMT(a), toSMT(b)) case BVAShiftRight(a, b) => FixedSizeBitVectors.AShiftRight(toSMT(a), toSMT(b)) case BVLShiftRight(a, b) => FixedSizeBitVectors.LShiftRight(toSMT(a), toSMT(b)) diff --git a/src/main/scala/inox/solvers/theories/BagEncoder.scala b/src/main/scala/inox/solvers/theories/BagEncoder.scala index b1be97cc57fb640e790e83ae23c4de849801c42c..6a026a1e3fdf9174f465761fabfa68a41bed88b3 100644 --- a/src/main/scala/inox/solvers/theories/BagEncoder.scala +++ b/src/main/scala/inox/solvers/theories/BagEncoder.scala @@ -71,6 +71,9 @@ trait BagEncoder extends TheoryEncoder { }) } + override val newFunctions = Seq(Get, Add, Union, Difference, Intersect, BagEquals) + override val newADTs = Seq(bagADT) + val encoder = new TreeTransformer { import sourceProgram._ diff --git a/src/main/scala/inox/solvers/theories/StringEncoder.scala b/src/main/scala/inox/solvers/theories/StringEncoder.scala index 58573b4f92ca98e2f134d9cdee8b3612470630a4..ddf7c4287cc2d21068ad0a2de1db93119ce7ba6f 100644 --- a/src/main/scala/inox/solvers/theories/StringEncoder.scala +++ b/src/main/scala/inox/solvers/theories/StringEncoder.scala @@ -81,6 +81,9 @@ trait StringEncoder extends TheoryEncoder { } })) + override val newFunctions = Seq(Size, Take, Drop, Slice, Concat) + override val newADTs = Seq(stringADT, stringNilADT, stringConsADT) + private val stringBijection = new Bijection[String, Expr]() private def convertToString(e: Expr): String = stringBijection.cachedA(e)(e match { diff --git a/src/main/scala/inox/solvers/theories/TheoryEncoder.scala b/src/main/scala/inox/solvers/theories/TheoryEncoder.scala index d358d0c7b72dfe0067ca448de09da6e543d57f5c..816291d41148edbedb60963298169d6fc6c9f68f 100644 --- a/src/main/scala/inox/solvers/theories/TheoryEncoder.scala +++ b/src/main/scala/inox/solvers/theories/TheoryEncoder.scala @@ -10,7 +10,7 @@ trait TheoryEncoder { val sourceProgram: Program lazy val trees: sourceProgram.trees.type = sourceProgram.trees lazy val targetProgram: Program { val trees: TheoryEncoder.this.trees.type } = { - sourceProgram.transform(encoder) + sourceProgram.transform(encoder).withFunctions(newFunctions).withADTs(newADTs) } import trees._ @@ -18,6 +18,9 @@ trait TheoryEncoder { protected val encoder: TreeTransformer protected val decoder: TreeTransformer + val newFunctions: Seq[FunDef] = Seq.empty + val newADTs: Seq[ADTDefinition] = Seq.empty + def encode(v: Variable): Variable = encoder.transform(v) def decode(v: Variable): Variable = decoder.transform(v) diff --git a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala index 5c18892492296e4a6c070e41955164dea7f0f9b1..65471acbf9948744d73a6831765c47f2aa8f28f6 100644 --- a/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/inox/solvers/z3/AbstractZ3Solver.scala @@ -323,7 +323,7 @@ trait AbstractZ3Solver case BVNot(e) => z3.mkBVNot(rec(e)) case BVAnd(l, r) => z3.mkBVAnd(rec(l), rec(r)) case BVOr(l, r) => z3.mkBVOr(rec(l), rec(r)) - case BVXOr(l, r) => z3.mkBVXor(rec(l), rec(r)) + case BVXor(l, r) => z3.mkBVXor(rec(l), rec(r)) case BVShiftLeft(l, r) => z3.mkBVShl(rec(l), rec(r)) case BVAShiftRight(l, r) => z3.mkBVAshr(rec(l), rec(r)) case BVLShiftRight(l, r) => z3.mkBVLshr(rec(l), rec(r)) diff --git a/src/test/scala/inox/evaluators/EvaluatorSuite.scala b/src/test/scala/inox/evaluators/EvaluatorSuite.scala index c5db3fd3ea26c70c21f8bb244581373ed09cef72..c835e9dbcd14879323887c65488a6b2c25678b99 100644 --- a/src/test/scala/inox/evaluators/EvaluatorSuite.scala +++ b/src/test/scala/inox/evaluators/EvaluatorSuite.scala @@ -56,8 +56,8 @@ class EvaluatorSuite extends FunSuite { eval(e, BVOr(IntLiteral(5), IntLiteral(4))) === IntLiteral(5) eval(e, BVOr(IntLiteral(5), IntLiteral(2))) === IntLiteral(7) - eval(e, BVXOr(IntLiteral(3), IntLiteral(1))) === IntLiteral(2) - eval(e, BVXOr(IntLiteral(3), IntLiteral(3))) === IntLiteral(0) + eval(e, BVXor(IntLiteral(3), IntLiteral(1))) === IntLiteral(2) + eval(e, BVXor(IntLiteral(3), IntLiteral(3))) === IntLiteral(0) eval(e, BVNot(IntLiteral(1))) === IntLiteral(-2)