diff --git a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala index b5db1105561a998dd7e4ee96c798c8f8ac18ae7b..2b2cd3df44a2707c446fd64108dbc759a1bc94d5 100644 --- a/src/main/scala/leon/evaluators/RecursiveEvaluator.scala +++ b/src/main/scala/leon/evaluators/RecursiveEvaluator.scala @@ -215,7 +215,6 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int (lv,rv) match { case (FiniteSet(el1, _),FiniteSet(el2, _)) => BooleanLiteral(el1 == el2) case (FiniteMap(el1, _, _),FiniteMap(el2, _, _)) => BooleanLiteral(el1.toSet == el2.toSet) - case (BooleanLiteral(b1),BooleanLiteral(b2)) => BooleanLiteral(b1 == b2) case _ => BooleanLiteral(lv == rv) } @@ -576,7 +575,7 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int def matchesCase(scrut: Expr, caze: MatchCase)(implicit rctx: RC, gctx: GC): Option[(MatchCase, Map[Identifier, Expr])] = { import purescala.TypeOps.isSubtypeOf - def matchesPattern(pat: Pattern, e: Expr): Option[Map[Identifier, Expr]] = (pat, e) match { + def matchesPattern(pat: Pattern, expr: Expr): Option[Map[Identifier, Expr]] = (pat, expr) match { case (InstanceOfPattern(ob, pct), e) => if (isSubtypeOf(e.getType, pct)) { Some(obind(ob, e)) @@ -590,18 +589,34 @@ abstract class RecursiveEvaluator(ctx: LeonContext, prog: Program, maxSteps: Int if (pct == ct) { val res = (subs zip args).map{ case (s, a) => matchesPattern(s, a) } if (res.forall(_.isDefined)) { - Some(obind(ob, e) ++ res.flatten.flatten) + Some(obind(ob, expr) ++ res.flatten.flatten) } else { None } } else { None } + case (up@UnapplyPattern(ob, _, subs), scrut) => + e(FunctionInvocation(up.unapplyFun, Seq(scrut))) match { + case CaseClass(CaseClassType(cd, _), Seq()) if cd == program.library.Nil.get => + None + case CaseClass(CaseClassType(cd, _), Seq(arg)) if cd == program.library.Cons.get => + val res = subs zip unwrapTuple(arg, up.unapplyFun.returnType.isInstanceOf[TupleType]) map { + case (s,a) => matchesPattern(s,a) + } + if (res.forall(_.isDefined)) { + Some(obind(ob, expr) ++ res.flatten.flatten) + } else { + None + } + case other => + throw EvalError(typeErrorMsg(other, up.unapplyFun.returnType)) + } case (TuplePattern(ob, subs), Tuple(args)) => if (subs.size == args.size) { val res = (subs zip args).map{ case (s, a) => matchesPattern(s, a) } if (res.forall(_.isDefined)) { - Some(obind(ob, e) ++ res.flatten.flatten) + Some(obind(ob, expr) ++ res.flatten.flatten) } else { None } diff --git a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala index da0a4074da8470c952fc6d80aaff2022f4bd1f19..402a7d7cf724659e0b240516281f28e22cc8e87f 100644 --- a/src/main/scala/leon/frontends/scalac/ASTExtractors.scala +++ b/src/main/scala/leon/frontends/scalac/ASTExtractors.scala @@ -223,6 +223,14 @@ trait ASTExtractors { } } + object ExUnapplyPattern { + def unapply(tree: Tree): Option[(Symbol, Seq[Tree])] = tree match { + case UnApply(Apply(s, _), args) => + Some((s.symbol, args)) + case _ => None + } + } + object ExBigIntLiteral { def unapply(tree: Tree): Option[Tree] = tree match { case Apply(ExSelected("scala", "package", "BigInt", "apply"), n :: Nil) => diff --git a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala index 4d4a0af1629b0318bf4b3e2ba05af1f229665442..acd964b9028fbc0be2fe4f6a60befc65ee966dc1 100644 --- a/src/main/scala/leon/frontends/scalac/CodeExtraction.scala +++ b/src/main/scala/leon/frontends/scalac/CodeExtraction.scala @@ -741,7 +741,7 @@ trait CodeExtraction extends ASTExtractors { private def extractTypeParams(tps: Seq[Type]): Seq[(Symbol, TypeParameter)] = { tps.flatMap { case TypeRef(_, sym, Nil) => - Some(sym -> TypeParameter(FreshIdentifier(sym.name.toString))) + Some(sym -> TypeParameter.fresh(sym.name.toString)) case t => outOfSubsetError(t.typeSymbol.pos, "Unhandled type for parameter: "+t) None @@ -892,6 +892,28 @@ trait CodeExtraction extends ASTExtractors { def charPat(ch : Char) = LiteralPattern(None, CharLiteral(ch)) (chars.foldRight(nil)( (ch: Char, p : Pattern) => mkCons( charPat(ch), p)), dctx) + case up@ExUnapplyPattern(s, args) => + implicit val p: Position = NoPosition + val fd = getFunDef(s, up.pos) + val (sub, ctx) = args.map (extractPattern(_)).unzip + val unapplyMethod = defsToDefs(s) + val formalTypes = tupleTypeWrap( + unapplyMethod.params.map { _.getType } ++ + unapplyMethod.returnType.asInstanceOf[ClassType].tps + ) + val realTypes = tupleTypeWrap(Seq( + extractType(up.tpe), + tupleTypeWrap(args map { tr => extractType(tr.tpe)}) + )) + val newTps = canBeSubtypeOf(realTypes, typeParamsOf(formalTypes).toSeq, formalTypes) match { + case Some(tmap) => + fd.tparams map { tpd => tmap.getOrElse(tpd.tp, tpd.tp) } + case None => + reporter.fatalError("Could not instantiate type of unapply method") + } + + (UnapplyPattern(binder, fd.typed(newTps), sub).setPos(up.pos), ctx.foldLeft(dctx)(_ union _)) + case _ => outOfSubsetError(p, "Unsupported pattern: "+p.getClass) } @@ -1292,7 +1314,7 @@ trait CodeExtraction extends ASTExtractors { val rr = extractTree(r) (rl, rr) match { - case (IsTyped(_, rt), IsTyped(_, lt)) if isSubtypeOf(rt, lt) || isSubtypeOf(lt, rt) => + case (IsTyped(_, rt), IsTyped(_, lt)) if typesCompatible(lt, rt) => Not(Equals(rl, rr)) case (IntLiteral(v), IsTyped(_, IntegerType)) => @@ -1311,7 +1333,7 @@ trait CodeExtraction extends ASTExtractors { val rr = extractTree(r) (rl, rr) match { - case (IsTyped(_, rt), IsTyped(_, lt)) if isSubtypeOf(rt, lt) || isSubtypeOf(lt, rt) => + case (IsTyped(_, rt), IsTyped(_, lt)) if typesCompatible(lt, rt) => Equals(rl, rr) case (IntLiteral(v), IsTyped(_, IntegerType)) => diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 57d7ed48610f5ba582a5a1c19af1b471b393a1c8..94b8c952caab1c476b395035fa6b83c1d00f105b 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -7,7 +7,7 @@ import Common._ import Types._ import Definitions._ import Expressions._ -import leon.purescala.Extractors._ +import Extractors._ import Constructors._ import utils.Simplifiers import solvers._ @@ -305,12 +305,12 @@ object ExprOps { // ATTENTION: Unused, and untested def freshenLocals(expr: Expr) : Expr = { def rewritePattern(p: Pattern, sm: Map[Identifier,Identifier]) : Pattern = p match { - case InstanceOfPattern(Some(b), ctd) => InstanceOfPattern(Some(sm(b)), ctd) - case WildcardPattern(Some(b)) => WildcardPattern(Some(sm(b))) + case InstanceOfPattern(ob, ctd) => InstanceOfPattern(ob map sm, ctd) + case WildcardPattern(ob) => WildcardPattern(ob map sm) case TuplePattern(ob, sps) => TuplePattern(ob.map(sm(_)), sps.map(rewritePattern(_, sm))) case CaseClassPattern(ob, ccd, sps) => CaseClassPattern(ob.map(sm(_)), ccd, sps.map(rewritePattern(_, sm))) - case LiteralPattern(Some(bind), lit) => LiteralPattern(Some(sm(bind)), lit) - case other => other + case UnapplyPattern(ob, obj, sps) => UnapplyPattern(ob.map(sm(_)), obj, sps.map(rewritePattern(_, sm))) + case LiteralPattern(ob, lit) => LiteralPattern(ob map sm, lit) } def freshenCase(cse: MatchCase) : MatchCase = { @@ -691,6 +691,13 @@ object ExprOps { val subTests = subps.zipWithIndex.map{case (p, i) => rec(tupleSelect(in, i+1, subps.size), p)} and(bind(ob, in) +: subTests: _*) + case up@UnapplyPattern(ob, fd, subps) => + def someCase(e: Expr) = { + // In the case where unapply returns a Some, it is enough that the subpatterns match + andJoin(unwrapTuple(e, subps.size) zip subps map { case (ex, p) => rec(ex, p).setPos(p) }).setPos(e) + } + and(up.patternMatch(in, BooleanLiteral(false), someCase).setPos(in), bind(ob, in)) + case LiteralPattern(ob,lit) => and(Equals(in,lit), bind(ob,in)) } @@ -699,33 +706,35 @@ object ExprOps { rec(in, pattern) } - def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = pattern match { - case CaseClassPattern(b, ccd, subps) => - assert(ccd.fields.size == subps.size) - val pairs = ccd.fields.map(_.id).toList zip subps.toList - val subMaps = pairs.map(p => mapForPattern(caseClassSelector(ccd, in, p._1), p._2)) - val together = subMaps.flatten.toMap - b match { - case Some(id) => Map(id -> in) ++ together - case None => together - } - - case TuplePattern(b, subps) => - val TupleType(tpes) = in.getType - assert(tpes.size == subps.size) - - val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(tupleSelect(in, i+1, subps.size), p)} - val map = maps.flatten.toMap - b match { - case Some(id) => map + (id -> in) - case None => map - } - - case other => - other.binder match { - case None => Map.empty - case Some(b) => Map(b -> in) - } + def mapForPattern(in: Expr, pattern: Pattern) : Map[Identifier,Expr] = { + def bindIn(id: Option[Identifier]): Map[Identifier,Expr] = id match { + case None => Map() + case Some(id) => Map(id -> in) + } + pattern match { + case CaseClassPattern(b, ccd, subps) => + assert(ccd.fields.size == subps.size) + val pairs = ccd.fields.map(_.id).toList zip subps.toList + val subMaps = pairs.map(p => mapForPattern(caseClassSelector(ccd, in, p._1), p._2)) + val together = subMaps.flatten.toMap + bindIn(b) ++ together + + case TuplePattern(b, subps) => + val TupleType(tpes) = in.getType + assert(tpes.size == subps.size) + + val maps = subps.zipWithIndex.map{case (p, i) => mapForPattern(tupleSelect(in, i+1, subps.size), p)} + val map = maps.flatten.toMap + bindIn(b) ++ map + + case up@UnapplyPattern(b, _, subps) => + bindIn(b) ++ unwrapTuple(up.getUnsafe(in), subps.size).zip(subps).map{ + case (e, p) => mapForPattern(e, p) + }.flatten.toMap + + case other => + bindIn(other.binder) + } } /** Rewrites all pattern-matching expressions into if-then-else expressions, @@ -830,7 +839,7 @@ object ExprOps { * Also returns a sequence of (Identifier -> Expr) pairs which * represent the bindings for intermediate binders (from outermost to innermost) */ - def patternToExpression(p : Pattern, expectedType : TypeTree) : (Expr, Seq[(Identifier, Expr)]) = { + def patternToExpression(p: Pattern, expectedType: TypeTree): (Expr, Seq[(Identifier, Expr)]) = { def fresh(tp : TypeTree) = FreshIdentifier("binder", tp, true) var ieMap = Seq[(Identifier, Expr)]() def addBinding(b : Option[Identifier], e : Expr) = b foreach { ieMap +:= (_, e) } @@ -864,10 +873,12 @@ object ExprOps { addBinding(b, e) e case CaseClassPattern(b, cct, subs) => - val subTypes = cct.fields map { _.getType } - val e = CaseClass(cct, subs zip subTypes map { case (sub,tp) => rec(sub,tp) }) + val e = CaseClass(cct, subs zip cct.fieldsTypes map { case (sub,tp) => rec(sub,tp) }) addBinding(b, e) e + case up@UnapplyPattern(b, fd, subs) => + // TODO: Support this + NoTree(expectedType) } (rec(p, expectedType), ieMap) @@ -1109,19 +1120,11 @@ object ExprOps { } } - private object ChooseMatch extends PartialFunction[Expr, Choose] { - override def apply(e: Expr): Choose = e match { - case (c: Choose) => c - } - override def isDefinedAt(e: Expr): Boolean = e match { - case (c: Choose) => true - case _ => false - } - } - class ChooseCollectorWithPaths extends CollectorWithPaths[(Choose,Expr)] { - val matcher = ChooseMatch.lift - def collect(e: Expr, path: Seq[Expr]) = matcher(e).map(_ -> and(path: _*)) + def collect(e: Expr, path: Seq[Expr]) = e match { + case c: Choose => Some(c -> and(path: _*)) + case _ => None + } } def patternSize(p: Pattern): Int = p match { @@ -1335,6 +1338,17 @@ object ExprOps { (false, Map()) } + case (UnapplyPattern(ob1, fd1, subs1), UnapplyPattern(ob2, fd2, subs2)) => + val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) + + if (ob1.size == ob2.size && fd1 == fd2 && subs1.size == subs2.size) { + (subs1 zip subs2).map { case (p1, p2) => patternHomo(p1, p2) }.foldLeft((true, m)) { + case ((b1, m1), (b2,m2)) => (b1 && b2, m1 ++ m2) + } + } else { + (false, Map()) + } + case (TuplePattern(ob1, subs1), TuplePattern(ob2, subs2)) => val m = Map[Identifier, Identifier]() ++ (ob1 zip ob2) @@ -1374,9 +1388,6 @@ object ExprOps { case (Variable(i1), Variable(i2)) => idHomo(i1, i2) - case (Choose(e1), Choose(e2)) => - isHomo(e1, e2) - case (Let(id1, v1, e1), Let(id2, v2, e2)) => isHomo(v1, v2) && isHomo(e1, e2)(map + (id1 -> id2)) @@ -1401,14 +1412,11 @@ object ExprOps { fdHomo(tfd1.fd, tfd2.fd) && (args1 zip args2).forall{ case (a1, a2) => isHomo(a1, a2) } + // TODO: Seems a lot is missing, like Literals + case Same(Operator(es1, _), Operator(es2, _)) => - if (es1.size == es2.size) { - (es1 zip es2).forall{ case (e1, e2) => isHomo(e1, e2) } - } else { - false - } - case Same(t1 : Terminal, t2: Terminal) => - true + (es1.size == es2.size) && + (es1 zip es2).forall{ case (e1, e2) => isHomo(e1, e2) } case _ => false @@ -1436,6 +1444,8 @@ object ExprOps { * } * * is exaustive. + * + * WARNING: Unused and unmaintained */ def isMatchExhaustive(m: MatchExpr): Boolean = { /** diff --git a/src/main/scala/leon/purescala/Expressions.scala b/src/main/scala/leon/purescala/Expressions.scala index a415c4c600eb58d79e791158e34823861b3caf03..c25f9be4322e3494b75776c937414272515c8f14 100644 --- a/src/main/scala/leon/purescala/Expressions.scala +++ b/src/main/scala/leon/purescala/Expressions.scala @@ -14,7 +14,24 @@ import ExprOps.replaceFromIDs /** Expression definitions for Pure Scala. */ object Expressions { - abstract class Expr extends Tree with Typed + private def checkParamTypes(real: Seq[Typed], formal: Seq[Typed], result: TypeTree): TypeTree = { + if (real zip formal forall { case (real, formal) => isSubtypeOf(real.getType, formal.getType)} ) { + result.unveilUntyped + } else { + //println(real map { r => s"$r: ${r.getType}"} mkString ", " ) + //println(formal map { r => s"$r: ${r.getType}" } mkString ", " ) + Untyped + } + } + + abstract class Expr extends Tree with Typed { + def untyped = { + //println("@" + this.getPos) + //println(this) + //println + Untyped + } + } trait Terminal { self: Expr => @@ -38,24 +55,37 @@ object Expressions { // Preconditions case class Require(pred: Expr, body: Expr) extends Expr { - val getType = body.getType + val getType = { + if (pred.getType == BooleanType) + body.getType + else untyped + } } // Postconditions case class Ensuring(body: Expr, pred: Expr) extends Expr { val getType = pred.getType match { - case FunctionType(Seq(bodyType), BooleanType) if bodyType == body.getType => bodyType - case _ => Untyped + case FunctionType(Seq(bodyType), BooleanType) if isSubtypeOf(body.getType, bodyType) => + body.getType + case _ => + untyped } def toAssert: Expr = { val res = FreshIdentifier("res", getType, true) - Let(res, body, Assert(application(pred, Seq(Variable(res))), Some("Postcondition failed @" + this.getPos), Variable(res))) + Let(res, body, Assert( + application(pred, Seq(Variable(res))), + Some("Postcondition failed @" + this.getPos), Variable(res) + )) } } // Local assertions case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr { - val getType = body.getType + val getType = { + if (pred.getType == BooleanType) + body.getType + else untyped + } } @@ -67,7 +97,15 @@ object Expressions { /* Local val's and def's */ case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr { - val getType = body.getType + val getType = { + // We can't demand anything sticter here, because some binders are + // typed context-wise + if (typesCompatible(value.getType, binder.getType)) + body.getType + else { + untyped + } + } } case class LetDef(fd: FunDef, body: Expr) extends Expr { @@ -83,15 +121,14 @@ object Expressions { */ case class MethodInvocation(rec: Expr, cd: ClassDef, tfd: TypedFunDef, args: Seq[Expr]) extends Expr { val getType = { - // We need ot instanciate the type based on the type of the function as well as receiver + // We need ot instantiate the type based on the type of the function as well as receiver val fdret = tfd.returnType val extraMap: Map[TypeParameterDef, TypeTree] = rec.getType match { case ct: ClassType => - (cd.tparams zip ct.tps).toMap + (cd.tparams zip ct.tps).toMap case _ => Map() } - instantiateType(fdret, extraMap) } } @@ -103,8 +140,12 @@ object Expressions { /* HOFs */ case class Application(callee: Expr, args: Seq[Expr]) extends Expr { - require(callee.getType.isInstanceOf[FunctionType]) - val getType = callee.getType.asInstanceOf[FunctionType].to + val getType = callee.getType match { + case FunctionType(from, to) => + checkParamTypes(args, from, to) + case _ => + untyped + } } case class Lambda(args: Seq[ValDef], body: Expr) extends Expr { @@ -121,16 +162,17 @@ object Expressions { /* Control flow */ case class FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) extends Expr { - val getType = tfd.returnType + require(tfd.params.size == args.size) + val getType = checkParamTypes(args, tfd.params, tfd.returnType) } case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr { - val getType = leastUpperBound(thenn.getType, elze.getType).getOrElse(Untyped).unveilUntyped + val getType = leastUpperBound(thenn.getType, elze.getType).getOrElse(untyped).unveilUntyped } case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr { require(cases.nonEmpty) - val getType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(Untyped).unveilUntyped + val getType = leastUpperBound(cases.map(_.rhs.getType)).getOrElse(untyped).unveilUntyped } case class MatchCase(pattern : Pattern, optGuard : Option[Expr], rhs: Expr) extends Tree { @@ -164,12 +206,51 @@ object Expressions { val subPatterns = Seq() } + case class UnapplyPattern(binder: Option[Identifier], unapplyFun: TypedFunDef, subPatterns: Seq[Pattern]) extends Pattern { + // Hacky, but ok + lazy val optionType = unapplyFun.returnType.asInstanceOf[AbstractClassType] + lazy val Seq(noneType, someType) = optionType.knownCCDescendants.sortBy(_.fields.size) + lazy val someValue = someType.fields.head + // Pattern match unapply(scrut) + // In case of None, return noneCase. + // In case of Some(v), return someCase(v). + def patternMatch(scrut: Expr, noneCase: Expr, someCase: Expr => Expr): Expr = { + // We use this hand-coded if-then-else because we don't want to generate + // match exhaustiveness checks in the program + val binder = FreshIdentifier("unap", optionType, true) + Let( + binder, + FunctionInvocation(unapplyFun, Seq(scrut)), + IfExpr( + IsInstanceOf(someType, Variable(binder)), + someCase(CaseClassSelector(someType, Variable(binder), someValue.id)), + noneCase + ) + ) + } + // Inlined .get method + def get(scrut: Expr) = patternMatch( + scrut, + Error(optionType.tps.head, "None.get"), + e => e + ) + // Selects Some.v field without type-checking. + // Use in a context where scrut.isDefined returns true. + def getUnsafe(scrut: Expr) = CaseClassSelector( + someType, + FunctionInvocation(unapplyFun, Seq(scrut)), + someValue.id + ) + } /* Symbolic IO examples */ case class Passes(in: Expr, out : Expr, cases : Seq[MatchCase]) extends Expr { require(cases.nonEmpty) - val getType = BooleanType + val getType = leastUpperBound(cases.map(_.rhs.getType)) match { + case None => untyped + case Some(_) => BooleanType + } def asConstraint = { val defaultCase = SimpleCase(WildcardPattern(None), out) @@ -183,10 +264,6 @@ object Expressions { val value: T } - case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal { - val getType = tp - } - case class CharLiteral(value: Char) extends Literal[Char] { val getType = CharType } @@ -208,9 +285,16 @@ object Expressions { } + /* Generic values. Represent values of the generic type tp */ + // TODO: Is it valid that GenericValue(tp, 0) != GenericValue(tp, 1)? + case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal { + val getType = tp + } + + /* Case classes */ case class CaseClass(ct: CaseClassType, args: Seq[Expr]) extends Expr { - val getType = ct + val getType = checkParamTypes(args, ct.fieldsTypes, ct) } case class IsInstanceOf(classType: ClassType, expr: Expr) extends Expr { @@ -219,21 +303,37 @@ object Expressions { case class CaseClassSelector(classType: CaseClassType, caseClass: Expr, selector: Identifier) extends Expr { val selectorIndex = classType.classDef.selectorID2Index(selector) - val getType = classType.fieldsTypes(selectorIndex) + val getType = { + // We don't demand equality because we may construct a mistyped field retrieval + // (retrieving from a supertype before) passing it to the solver. + // E.g. l.head where l:List[A] or even l: Nil[A]. This is ok for the solvers. + if (typesCompatible(classType, caseClass.getType)) { + classType.fieldsTypes(selectorIndex) + } else { + untyped + } + } } /* Equality */ case class Equals(lhs: Expr, rhs: Expr) extends Expr { - val getType = BooleanType + val getType = { + if (typesCompatible(lhs.getType, rhs.getType)) BooleanType + else { + untyped + } + } } /* Propositional logic */ case class And(exprs: Seq[Expr]) extends Expr { - val getType = BooleanType - require(exprs.size >= 2) + val getType = { + if (exprs forall (_.getType == BooleanType)) BooleanType + else untyped + } } object And { @@ -241,9 +341,11 @@ object Expressions { } case class Or(exprs: Seq[Expr]) extends Expr { - val getType = BooleanType - require(exprs.size >= 2) + val getType = { + if (exprs forall (_.getType == BooleanType)) BooleanType + else untyped + } } object Or { @@ -251,30 +353,44 @@ object Expressions { } case class Implies(lhs: Expr, rhs: Expr) extends Expr { - val getType = BooleanType + val getType = { + if(lhs.getType == BooleanType && rhs.getType == BooleanType) BooleanType + else untyped + } } case class Not(expr: Expr) extends Expr { - val getType = BooleanType + val getType = { + if (expr.getType == BooleanType) BooleanType + else untyped + } } /* Integer arithmetic */ case class Plus(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == IntegerType && rhs.getType == IntegerType) - val getType = IntegerType + val getType = { + if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType + else untyped + } } - case class Minus(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == IntegerType && rhs.getType == IntegerType) - val getType = IntegerType + case class Minus(lhs: Expr, rhs: Expr) extends Expr { + val getType = { + if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType + else untyped + } } - case class UMinus(expr: Expr) extends Expr { - require(expr.getType == IntegerType) - val getType = IntegerType + case class UMinus(expr: Expr) extends Expr { + val getType = { + if (expr.getType == IntegerType) IntegerType + else untyped + } } - case class Times(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == IntegerType && rhs.getType == IntegerType) - val getType = IntegerType + case class Times(lhs: Expr, rhs: Expr) extends Expr { + val getType = { + if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType + else untyped + } } /* * Division and Remainder follows Java/Scala semantics. Division corresponds @@ -286,19 +402,25 @@ object Expressions { * * Division(x, y) * y + Remainder(x, y) == x */ - case class Division(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == IntegerType && rhs.getType == IntegerType) - val getType = IntegerType + case class Division(lhs: Expr, rhs: Expr) extends Expr { + val getType = { + if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType + else untyped + } } - case class Remainder(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == IntegerType && rhs.getType == IntegerType) - val getType = IntegerType + case class Remainder(lhs: Expr, rhs: Expr) extends Expr { + val getType = { + if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType + else untyped + } } - case class Modulo(lhs: Expr, rhs: Expr) extends Expr { - require(lhs.getType == IntegerType && rhs.getType == IntegerType) - val getType = IntegerType + case class Modulo(lhs: Expr, rhs: Expr) extends Expr { + val getType = { + if (lhs.getType == IntegerType && rhs.getType == IntegerType) IntegerType + else untyped + } } - case class LessThan(lhs: Expr, rhs: Expr) extends Expr { + case class LessThan(lhs: Expr, rhs: Expr) extends Expr { val getType = BooleanType } case class GreaterThan(lhs: Expr, rhs: Expr) extends Expr { @@ -383,7 +505,7 @@ object Expressions { ts(index - 1) case _ => - Untyped + untyped } } @@ -402,15 +524,16 @@ object Expressions { val getType = BooleanType } case class SetIntersection(set1: Expr, set2: Expr) extends Expr { - val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(untyped).unveilUntyped } case class SetUnion(set1: Expr, set2: Expr) extends Expr { - val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(untyped).unveilUntyped } case class SetDifference(set1: Expr, set2: Expr) extends Expr { - val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(Untyped).unveilUntyped + val getType = leastUpperBound(Seq(set1, set2).map(_.getType)).getOrElse(untyped).unveilUntyped } + // TODO: Add checks for these expressions too /* Map operations */ case class FiniteMap(singletons: Seq[(Expr, Expr)], keyType: TypeTree, valueType: TypeTree) extends Expr { @@ -418,12 +541,13 @@ object Expressions { } case class MapGet(map: Expr, key: Expr) extends Expr { val getType = map.getType match { - case MapType(_, to) => to - case _ => Untyped + case MapType(from, to) if isSubtypeOf(key.getType, from) => + to + case _ => untyped } } case class MapUnion(map1: Expr, map2: Expr) extends Expr { - val getType = leastUpperBound(Seq(map1, map2).map(_.getType)).getOrElse(Untyped).unveilUntyped + val getType = leastUpperBound(Seq(map1, map2).map(_.getType)).getOrElse(untyped).unveilUntyped } case class MapDifference(map: Expr, keys: Expr) extends Expr { val getType = map.getType @@ -439,16 +563,16 @@ object Expressions { case ArrayType(base) => base case _ => - Untyped + untyped } } case class ArrayUpdated(array: Expr, index: Expr, newValue: Expr) extends Expr { val getType = array.getType match { case ArrayType(base) => - leastUpperBound(base, newValue.getType).map(ArrayType).getOrElse(Untyped).unveilUntyped + leastUpperBound(base, newValue.getType).map(ArrayType).getOrElse(untyped).unveilUntyped case _ => - Untyped + untyped } } @@ -470,10 +594,10 @@ object Expressions { case class Choose(pred: Expr) extends Expr { val getType = pred.getType match { - case FunctionType(from, to) if from.nonEmpty => // @mk why nonEmpty? + case FunctionType(from, BooleanType) if from.nonEmpty => // @mk why nonEmpty? tupleTypeWrap(from) case _ => - Untyped + untyped } } diff --git a/src/main/scala/leon/purescala/Extractors.scala b/src/main/scala/leon/purescala/Extractors.scala index d17b58a969967b773f3c328a5ad4886cf66d1eb6..c22a175758c6f4a5ffb8ec669427e0ccac9d8207 100644 --- a/src/main/scala/leon/purescala/Extractors.scala +++ b/src/main/scala/leon/purescala/Extractors.scala @@ -320,6 +320,7 @@ object Extractors { case CaseClassPattern(b, ct, subs) => (b, subs, (b, sp) => CaseClassPattern(b, ct, sp)) case TuplePattern(b,subs) => (b, subs, (b, sp) => TuplePattern(b, sp)) case LiteralPattern(b, l) => (b, Seq(), (b, _) => LiteralPattern(b, l)) + case UnapplyPattern(b, fd, subs) => (b, subs, (b, sp) => UnapplyPattern(b, fd, sp)) } } diff --git a/src/main/scala/leon/purescala/FunctionClosure.scala b/src/main/scala/leon/purescala/FunctionClosure.scala index d27f90983931d5e1ce960800844e6669c75479b7..96ef8e2b05568dc9c19e5c5874351452d24d2689 100644 --- a/src/main/scala/leon/purescala/FunctionClosure.scala +++ b/src/main/scala/leon/purescala/FunctionClosure.scala @@ -151,6 +151,7 @@ object FunctionClosure extends TransformationPhase { case WildcardPattern(binder) => WildcardPattern(binder.map(id2freshId(_))) case CaseClassPattern(binder, caseClassDef, subPatterns) => CaseClassPattern(binder.map(id2freshId(_)), caseClassDef, subPatterns.map(freshIdInPat(_, id2freshId))) case TuplePattern(binder, subPatterns) => TuplePattern(binder.map(id2freshId(_)), subPatterns.map(freshIdInPat(_, id2freshId))) + case UnapplyPattern(binder, fd, subPatterns) => UnapplyPattern(binder.map(id2freshId(_)), fd, subPatterns.map(freshIdInPat(_, id2freshId))) case LiteralPattern(binder, lit) => LiteralPattern(binder.map(id2freshId(_)), lit) } diff --git a/src/main/scala/leon/purescala/PrettyPrinter.scala b/src/main/scala/leon/purescala/PrettyPrinter.scala index 5e5d3407d4faf8bdd4e978d67c59931104eaaf11..1767e38db86b75f6287c73e2d8cc2f00ddd58754 100644 --- a/src/main/scala/leon/purescala/PrettyPrinter.scala +++ b/src/main/scala/leon/purescala/PrettyPrinter.scala @@ -403,6 +403,17 @@ class PrettyPrinter(opts: PrinterOptions, ob.foreach { b => p"$b @ " } p"($subps)" + case UnapplyPattern(ob, tfd, subps) => + ob.foreach { b => p"$b @ " } + + // @mk: I admit this is pretty ugly + val id = for { + p <- opgm + mod <- p.modules.find( _.definedFunctions contains tfd.fd ) + } yield mod.id + + p"${id.getOrElse("<unknown object>")}(${nary(subps)})" + case LiteralPattern(ob, lit) => ob foreach { b => p"$b @ " } p"$lit" diff --git a/src/main/scala/leon/purescala/ScopeSimplifier.scala b/src/main/scala/leon/purescala/ScopeSimplifier.scala index 5559308fe6f3e50b1ee22bc46a20d57d9339caf6..40102267f0d9c3eadada7fb43cb22487661beb4c 100644 --- a/src/main/scala/leon/purescala/ScopeSimplifier.scala +++ b/src/main/scala/leon/purescala/ScopeSimplifier.scala @@ -82,6 +82,8 @@ class ScopeSimplifier extends Transformer { CaseClassPattern(newBinder, ccd, newSubPatterns) case TuplePattern(b, sub) => TuplePattern(newBinder, newSubPatterns) + case UnapplyPattern(b, obj, sub) => + UnapplyPattern(newBinder, obj, newSubPatterns) case LiteralPattern(_, lit) => LiteralPattern(newBinder, lit) } diff --git a/src/main/scala/leon/purescala/TypeOps.scala b/src/main/scala/leon/purescala/TypeOps.scala index 4b77380bf64475ad5156c762591e5e3bebd7a926..005c877d0e207e503b526ed65166f6f557d68c69 100644 --- a/src/main/scala/leon/purescala/TypeOps.scala +++ b/src/main/scala/leon/purescala/TypeOps.scala @@ -137,6 +137,15 @@ object TypeOps { case (TupleType(args1), TupleType(args2)) => val args = (args1 zip args2).map(p => leastUpperBound(p._1, p._2)) if (args.forall(_.isDefined)) Some(TupleType(args.map(_.get))) else None + + case (FunctionType(from1, to1), FunctionType(from2, to2)) => + // TODO: make functions contravariant to arg. types + if (from1 == from2) { + leastUpperBound(to1, to2) map { FunctionType(from1, _) } + } else { + None + } + case (o1, o2) if o1 == o2 => Some(o1) case _ => None } @@ -158,6 +167,9 @@ object TypeOps { leastUpperBound(t1, t2) == Some(t2) } + def typesCompatible(t1: TypeTree, t2: TypeTree) = { + leastUpperBound(t1, t2).isDefined + } def typeCheck(obj: Expr, exps: TypeTree*) { val res = exps.exists(e => isSubtypeOf(obj.getType, e)) @@ -207,7 +219,7 @@ object TypeOps { val newTpe = tpeSub(e.getType) def mapsUnion(maps: Seq[Map[Identifier, Identifier]]): Map[Identifier, Identifier] = { - maps.foldLeft(Map[Identifier, Identifier]())(_ ++ _) + maps.flatten.toMap } def trCase(c: MatchCase) = c match { @@ -243,6 +255,14 @@ object TypeOps { (CaseClassPattern(newOb, newCt, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) + case (up@UnapplyPattern(ob, fd, sps), tp) => + val newFd = if ((fd.tps map tpeSub) == fd.tps) fd else fd.fd.typed(fd.tps map tpeSub) + val newOb = ob.map(id => freshId(id,tp)) + val exType = tpeSub(up.someType.tps.head) + val exTypes = unwrapTupleType(exType, exType.isInstanceOf[TupleType]) + val (newSps, newMaps) = (sps zip exTypes).map { case (sp, stpe) => trPattern(sp, stpe) }.unzip + (UnapplyPattern(newOb, newFd, newSps), (ob zip newOb).toMap ++ mapsUnion(newMaps)) + case (WildcardPattern(ob), expTpe) => val newOb = ob.map(id => freshId(id, expTpe)) @@ -279,9 +299,6 @@ object TypeOps { val newId = freshId(id, tpeSub(id.getType)) Let(newId, srec(value), rec(idsMap + (id -> newId))(body)).copiedFrom(l) - case c @ Choose(pred) => - Choose(rec(idsMap)(pred)).copiedFrom(c) - case l @ Lambda(args, body) => val newArgs = args.map { arg => val tpe = tpeSub(arg.getType) @@ -308,9 +325,6 @@ object TypeOps { case other => // FIXME any better ideas? sys.error(s"Tried to substitute $tpar with $other within GenericValue $g") } - - case ens @ Ensuring(body, pred) => - Ensuring(srec(body), rec(idsMap)(pred)).copiedFrom(ens) case s @ FiniteSet(elems, tpe) => FiniteSet(elems.map(srec), tpeSub(tpe)).copiedFrom(s) diff --git a/src/main/scala/leon/purescala/Types.scala b/src/main/scala/leon/purescala/Types.scala index bbd1884d071c32c25c2ff6aa03fc3f288abbe1ab..6bb1f921dbb53d41937d4a6989d00148d8ca9734 100644 --- a/src/main/scala/leon/purescala/Types.scala +++ b/src/main/scala/leon/purescala/Types.scala @@ -51,8 +51,21 @@ object Types { case class BitVectorType(size: Int) extends TypeTree case object Int32Type extends TypeTree - case class TypeParameter(id: Identifier) extends TypeTree { - def freshen = TypeParameter(id.freshen) + class TypeParameter private (name: String) extends TypeTree { + val id = FreshIdentifier(name, this) + def freshen = new TypeParameter(name) + + override def equals(that: Any) = that match { + case TypeParameter(id) => this.id == id + case _ => false + } + + override def hashCode = id.hashCode + } + + object TypeParameter { + def unapply(tp: TypeParameter): Option[Identifier] = Some(tp.id) + def fresh(name: String) = new TypeParameter(name) } /* diff --git a/src/main/scala/leon/solvers/ADTManager.scala b/src/main/scala/leon/solvers/ADTManager.scala index 96c8ad60b762c76775cc3c7b21ae5171aa8cd710..3339fe2b79de7c8fa1d2b7c0509ffa2da1be4f75 100644 --- a/src/main/scala/leon/solvers/ADTManager.scala +++ b/src/main/scala/leon/solvers/ADTManager.scala @@ -49,7 +49,7 @@ class ADTManager(ctx: LeonContext) { if (conflicts(t)) { // There is no way to solve this, the type we requested is in conflict - reporter.warning("Encountered ADT '"+t+"' that can't be defined.") + reporter.warning(s"Encountered ADT '$t' that can't be defined.") reporter.warning("It appears it has recursive references through non-structural types (such as arrays, maps, or sets).") throw new IllegalArgumentException } else { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala index a261fc10c38cf8b9ab3adfd3fba4b1b25e000526..4fdf7b6e1e2b9fb4462569be413f0111af79a321 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBSolver.scala @@ -62,7 +62,11 @@ abstract class SMTLIBSolver(val context: LeonContext, dir.mkdir } - new java.io.FileWriter(s"vcs/$targetName-$file-$n.smt2", false) + val fileName = s"vcs/$targetName-$file-$n.smt2" + + reporter.debug(s"Outputting VC into $fileName" ) + + new java.io.FileWriter(fileName, false) } else None diff --git a/src/main/scala/leon/synthesis/ExamplesFinder.scala b/src/main/scala/leon/synthesis/ExamplesFinder.scala index 2fbf1d87d741bf8f3defcf2df162a67035daa82d..955805148e63e3a802a175538a406287a450ae72 100644 --- a/src/main/scala/leon/synthesis/ExamplesFinder.scala +++ b/src/main/scala/leon/synthesis/ExamplesFinder.scala @@ -141,7 +141,10 @@ class ExamplesFinder(ctx: LeonContext, program: Program) { // The pattern as expression (input expression)(may contain free variables) val (pattExpr, ieMap) = patternToExpression(cs.pattern, in.getType) val freeVars = variablesOf(pattExpr).toSeq - if (freeVars.isEmpty) { + if (exists(_.isInstanceOf[NoTree])(pattExpr)) { + reporter.warning(cs.pattern.getPos, "Unapply patterns are not supported in IO-example extraction") + Seq() + } else if (freeVars.isEmpty) { // The input contains no free vars. Trivially return input-output pair Seq((pattExpr, doSubstitute(ieMap,cs.rhs))) } else { diff --git a/src/test/resources/regression/verification/purescala/invalid/Unapply1.scala b/src/test/resources/regression/verification/purescala/invalid/Unapply1.scala new file mode 100644 index 0000000000000000000000000000000000000000..20ff95383b4f7042b0eba32fcc4e5c8b0cea3680 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Unapply1.scala @@ -0,0 +1,15 @@ +import leon.lang._ + +object Unap1 { + def unapply[A, B](i: (Int, B, A)): Option[(A, B)] = + if (i._1 == 0) None() else Some((i._3, i._2)) +} + +object Unapply1 { + + def bar: Boolean = { (42, false, ()) match { + case Unap1(_, b) if b => b + case Unap1((), b) => b + }} ensuring { res => res } + +} diff --git a/src/test/resources/regression/verification/purescala/invalid/Unapply2.scala b/src/test/resources/regression/verification/purescala/invalid/Unapply2.scala new file mode 100644 index 0000000000000000000000000000000000000000..ae4167c20026f0e926494af6cf3a11f25e9399e5 --- /dev/null +++ b/src/test/resources/regression/verification/purescala/invalid/Unapply2.scala @@ -0,0 +1,11 @@ +import leon.lang._ +object Unap2 { + def unapply[A, B](i: (Int, B, A)): Option[(A, B)] = + if (i._1 == 0) None() else Some((i._3, i._2)) +} + +object Unapply { + def bar: Boolean = { (42, false, ()) match { + case Unap2(_, b) if b => b + }} ensuring { res => res } +} diff --git a/src/test/resources/regression/verification/purescala/valid/Unapply.scala b/src/test/resources/regression/verification/purescala/valid/Unapply.scala new file mode 100644 index 0000000000000000000000000000000000000000..941b1f370d740204e8fa850a9d5e5a787b1c401e --- /dev/null +++ b/src/test/resources/regression/verification/purescala/valid/Unapply.scala @@ -0,0 +1,12 @@ +import leon.lang._ +object Unap { + def unapply[A, B](i: (Int, B, A)): Option[(A, B)] = + if (i._1 == 0) None() else Some((i._3, i._2)) +} + +object Unapply { + def bar: Boolean = { (42, true, ()) match { + case Unap(_, b) if b => b + case Unap((), b) => !b + }} ensuring { res => res } +} diff --git a/src/test/scala/leon/test/codegen/CodeGenSuite.scala b/src/test/scala/leon/test/codegen/CodeGenSuite.scala index 82c0a1161e69b0ae461410c85dc08afea78eebc2..a33d868be873f925e6d52bbd859d33f7393cb090 100644 --- a/src/test/scala/leon/test/codegen/CodeGenSuite.scala +++ b/src/test/scala/leon/test/codegen/CodeGenSuite.scala @@ -143,10 +143,9 @@ class CodeGenSuite extends test.LeonTestSuite { object simple2 { abstract class Abs case class Conc(x : BigInt) extends Abs - def test = { + def test = { val c = Conc(1) c.x - } } object eager { @@ -287,7 +286,7 @@ class CodeGenSuite extends test.LeonTestSuite { val l = Cons(1, Cons(2, Cons(3, Nil()))) - def test = l.length + Nil().length + def test = l.length + Nil[Int]().length } object ListWithSumMono { abstract class List