diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index fc8e08951384d68917c32089cf2ca16775935911..af9756e60bb900fcbbdd4f75fd6678e5cbf5a30b 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -23,8 +23,9 @@ trait CodeExtraction extends Extractors { } catch { case _ => null } - private lazy val optionClassSym = definitions.getClass("scala.Option") - private lazy val someClassSym = definitions.getClass("scala.Some") + private lazy val optionClassSym = definitions.getClass("scala.Option") + private lazy val someClassSym = definitions.getClass("scala.Some") + private lazy val function1TraitSym = definitions.getClass("scala.Function1") def isSetTraitSym(sym : Symbol) : Boolean = { sym == setTraitSym || sym.tpe.toString.startsWith("scala.Predef.Set") @@ -42,6 +43,10 @@ trait CodeExtraction extends Extractors { sym == optionClassSym || sym == someClassSym } + def isFunction1TraitSym(sym : Symbol) : Boolean = { + sym == function1TraitSym + } + private val varSubsts: scala.collection.mutable.Map[Symbol,Function0[Expr]] = scala.collection.mutable.Map.empty[Symbol,Function0[Expr]] private val classesToClasses: scala.collection.mutable.Map[Symbol,ClassTypeDef] = @@ -540,18 +545,6 @@ trait CodeExtraction extends Extractors { } } } - case ExMapApply(m,f) => { - val rm = rec(m) - val rf = rec(f) - val tpe = rm.getType match { - case MapType(_,toType) => toType - case _ => { - if (!silent) unit.error(tree.pos, "apply on non-map expression") - throw ImpureCodeEncounteredException(tree) - } - } - MapGet(rm, rf).setType(tpe) - } case ExMapIsDefinedAt(m,k) => { val rm = rec(m) val rk = rec(k) @@ -563,6 +556,29 @@ trait CodeExtraction extends Extractors { val rr = rec(t2) MultisetPlus(rl, rr).setType(rl.getType) } + case ExApply(lhs,args) => { + val rlhs = rec(lhs) + val rargs = args map rec + rlhs.getType match { + case MapType(_,tt) => + assert(rargs.size == 1) + MapGet(rlhs, rargs.head).setType(tt) + case FunctionType(fts, tt) => { + rlhs match { + case Variable(id) => + AnonymousFunctionInvocation(id, rargs).setType(tt) + case _ => { + if (!silent) unit.error(tree.pos, "apply on non-variable or non-map expression") + throw ImpureCodeEncounteredException(tree) + } + } + } + case _ => { + if (!silent) unit.error(tree.pos, "apply on unexpected type") + throw ImpureCodeEncounteredException(tree) + } + } + } case ExIfThenElse(t1,t2,t3) => { val r1 = rec(t1) val r2 = rec(t2) @@ -632,8 +648,8 @@ trait CodeExtraction extends Extractors { case TypeRef(_, sym, btt :: Nil) if isMultisetTraitSym(sym) => MultisetType(rec(btt)) case TypeRef(_, sym, btt :: Nil) if isOptionClassSym(sym) => OptionType(rec(btt)) case TypeRef(_, sym, List(ftt,ttt)) if isMapTraitSym(sym) => MapType(rec(ftt),rec(ttt)) + case TypeRef(_, sym, ftt :: ttt :: Nil) if isFunction1TraitSym(sym) => FunctionType(List(rec(ftt)), rec(ttt)) case TypeRef(_, sym, Nil) if classesToClasses.keySet.contains(sym) => classDefToClassType(classesToClasses(sym)) - case _ => { if(!silent) { unit.error(NoPosition, "Could not extract type as PureScala. [" + tr + "]") diff --git a/src/funcheck/Extractors.scala b/src/funcheck/Extractors.scala index 1ea44d13bd5aaab1f466eb93cb6e88fa7bfe7921..30cf4d43cba7104e764bc57687689c229b43f51f 100644 --- a/src/funcheck/Extractors.scala +++ b/src/funcheck/Extractors.scala @@ -321,6 +321,13 @@ trait Extractors { } } +// object ExAnonymousFunctionInvocation { +// def unapply(tree: Apply): Option[(Ident,List[Tree])] = tree match { +// case a @ Apply(Select(i @ Ident(_), applyName), args) if applyName.toString == "apply" => Some((i, args)) +// case _ => None +// } +// } +// // used for case classes selectors. object ExParameterlessMethodCall { def unapply(tree: Select): Option[(Tree,Name)] = tree match { @@ -499,9 +506,9 @@ trait Extractors { } } - object ExMapApply { - def unapply(tree: Apply): Option[(Tree,Tree)] = tree match { - case Apply(Select(lhs, n), List(rhs)) if (n.toString == "apply") => Some((lhs, rhs)) + object ExApply { + def unapply(tree: Apply): Option[(Tree,List[Tree])] = tree match { + case Apply(Select(lhs, n), rhs) if (n.toString == "apply") => Some((lhs, rhs)) case _ => None } } diff --git a/src/purescala/AbstractZ3Solver.scala b/src/purescala/AbstractZ3Solver.scala index 6d95da313ca2f5fe4529889d4628dec1956b0fc4..885d1b98cd1a7b7ae39a71a3765b8ec8d305f980 100644 --- a/src/purescala/AbstractZ3Solver.scala +++ b/src/purescala/AbstractZ3Solver.scala @@ -34,6 +34,8 @@ trait AbstractZ3Solver { protected[purescala] val mapRangeNoneTesters: MutableMap[TypeTree, Z3FuncDecl] protected[purescala] val mapRangeValueSelectors: MutableMap[TypeTree, Z3FuncDecl] + protected[purescala] var anonymousFuns: Map[Identifier, Z3FuncDecl] + protected[purescala] var exprToZ3Id : Map[Expr,Z3AST] protected[purescala] def fromZ3Formula(tree : Z3AST) : Expr diff --git a/src/purescala/Evaluator.scala b/src/purescala/Evaluator.scala index ac0b983cff830357373e2a110e7bc87c1c093167..ebaea193ee660b0adb5a9e343ed185036aa40258 100644 --- a/src/purescala/Evaluator.scala +++ b/src/purescala/Evaluator.scala @@ -256,7 +256,18 @@ object Evaluator { case (FiniteMap(ss), e) => BooleanLiteral(ss.exists(_.from == e)) case (l, r) => throw TypeErrorEx(TypeError(l, m.getType)) } - + case AnonymousFunctionInvocation(i,as) => { + val fun = ctx(i) + fun match { + case AnonymousFunction(es, ev) => { + es.find(_._1 == as) match { + case Some(res) => res._2 + case None => ev + } + } + case _ => scala.sys.error("function id has non-function interpretation") + } + } case Distinct(args) => { val newArgs = args.map(rec(ctx, _)) BooleanLiteral(newArgs.distinct.size == newArgs.size) diff --git a/src/purescala/FairZ3Solver.scala b/src/purescala/FairZ3Solver.scala index 0be2fcbfd2da21ac10d4eedc6818de0b9691fd57..60e7fd81a637f5605d7cf73d7011f56aaf90cd4b 100644 --- a/src/purescala/FairZ3Solver.scala +++ b/src/purescala/FairZ3Solver.scala @@ -54,6 +54,9 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac exprToZ3Id = Map.empty z3IdToExpr = Map.empty + anonymousFuns = Map.empty + fallbackSorts = Map.empty + mapSorts = Map.empty mapRangeSorts.clear @@ -92,6 +95,8 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac private var adtSorts: Map[ClassTypeDef, Z3Sort] = Map.empty private var fallbackSorts: Map[TypeTree, Z3Sort] = Map.empty + protected[purescala] var anonymousFuns: Map[Identifier, Z3FuncDecl] = Map.empty + protected[purescala] var adtTesters: Map[CaseClassDef, Z3FuncDecl] = Map.empty protected[purescala] var adtConstructors: Map[CaseClassDef, Z3FuncDecl] = Map.empty protected[purescala] var adtFieldSelectors: Map[Identifier, Z3FuncDecl] = Map.empty @@ -514,7 +519,9 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac } } case (Some(true), m) => { // SAT + println("MODEL IS: " + m) validatingStopwatch.start + println("VARS IN VC: " + varsInVC) val (trueModel, model) = validateAndDeleteModel(m, toCheckAgainstModels, varsInVC, evaluator) validatingStopwatch.stop @@ -914,6 +921,19 @@ class FairZ3Solver(val reporter: Reporter) extends Solver(reporter) with Abstrac case MapType(ft, tt) => z3.mkDistinct(z3.mkSelect(rec(m), rec(k)), mapRangeNoneConstructors(tt)()) case errorType => scala.sys.error("Unexpected type for map: " + (ex, errorType)) } + case AnonymousFunctionInvocation(id, args) => anonymousFuns.get(id) match { + case Some(fd) => fd(args map rec: _*) + case None => { + id.getType match { + case FunctionType(fts, tt) => { + val newFD = z3.mkFreshFuncDecl(id.uniqueName, fts map typeToSort, typeToSort(tt)) + anonymousFuns = anonymousFuns + (id -> newFD) + newFD(args map rec: _*) + } + case errorType => scala.sys.error("Unexpected type for function: " + (id, errorType)) + } + } + } case Distinct(exs) => z3.mkDistinct(exs.map(rec(_)): _*) diff --git a/src/purescala/PrettyPrinter.scala b/src/purescala/PrettyPrinter.scala index a654580a7c780af282e057b1b0626f6fbb55ca6e..16d811e913df7497f0530bd15f829bee19a53958 100644 --- a/src/purescala/PrettyPrinter.scala +++ b/src/purescala/PrettyPrinter.scala @@ -111,6 +111,26 @@ object PrettyPrinter { nsb = ppNary(nsb, args, "(", ", ", ")", lvl) nsb } + case AnonymousFunction(es, ev) => { + var nsb = sb + nsb.append("{") + es.foreach { + case (as, res) => + nsb = ppNary(nsb, as, "", " ", "", lvl) + nsb.append(" -> ") + nsb = pp(res, nsb, lvl) + nsb.append(", ") + } + nsb.append("else -> ") + nsb = pp(ev, nsb, lvl) + nsb.append("}") + } + case AnonymousFunctionInvocation(id, args) => { + var nsb = sb + nsb.append(id) + nsb = ppNary(nsb, args, "(", ", ", ")", lvl) + nsb + } case Plus(l,r) => ppBinary(sb, l, r, " + ", lvl) case Minus(l,r) => ppBinary(sb, l, r, " - ", lvl) case Times(l,r) => ppBinary(sb, l, r, " * ", lvl) @@ -245,6 +265,20 @@ object PrettyPrinter { // TYPE TREES // all type trees are printed in-line + private def ppNaryType(sb: StringBuffer, tpes: Seq[TypeTree], pre: String, op: String, post: String, lvl: Int): StringBuffer = { + var nsb = sb + nsb.append(pre) + val sz = tpes.size + var c = 0 + + tpes.foreach(t => { + nsb = pp(t, nsb, lvl) ; c += 1 ; if(c < sz) nsb.append(op) + }) + + nsb.append(post) + nsb + } + private def pp(tpe: TypeTree, sb: StringBuffer, lvl: Int): StringBuffer = tpe match { case Untyped => sb.append("???") case Int32Type => sb.append("Int") @@ -253,6 +287,15 @@ object PrettyPrinter { case MapType(ft,tt) => pp(tt, pp(ft, sb.append("Map["), lvl).append(","), lvl).append("]") case MultisetType(bt) => pp(bt, sb.append("Multiset["), lvl).append("]") case OptionType(bt) => pp(bt, sb.append("Option["), lvl).append("]") + case FunctionType(fts, tt) => { + var nsb = sb + if (fts.size > 1) + nsb = ppNaryType(nsb, fts, "(", ", ", ")", lvl) + else if (fts.size == 1) + nsb = pp(fts.head, nsb, lvl) + nsb.append(" => ") + pp(tt, nsb, lvl) + } case c: ClassType => sb.append(c.classDef.id) case _ => sb.append("Type?") } diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index d711d0e2f244234b834df849d80bbe8c98e549fd..1551a79fce01b29c2bbae66814cd5567b5d4dcc4 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -331,6 +331,10 @@ object Trees { case class Concat(list1: Expr, list2: Expr) extends Expr case class ListAt(list: Expr, index: Expr) extends Expr + /* Function operations */ + case class AnonymousFunction(entries: Seq[(Seq[Expr],Expr)], elseValue: Expr) extends Expr + case class AnonymousFunctionInvocation(id: Identifier, args: Seq[Expr]) extends Expr + /* Constraint programming */ case class Distinct(exprs: Seq[Expr]) extends Expr with FixedType { val fixedType = BooleanType @@ -391,6 +395,7 @@ object Trees { object NAryOperator { def unapply(expr: Expr) : Option[(Seq[Expr],(Seq[Expr])=>Expr)] = expr match { case fi @ FunctionInvocation(fd, args) => Some((args, (as => FunctionInvocation(fd, as).setPosInfo(fi)))) + case AnonymousFunctionInvocation(id, args) => Some((args, (as => AnonymousFunctionInvocation(id, as)))) case CaseClass(cd, args) => Some((args, CaseClass(cd, _))) case And(args) => Some((args, And.apply)) case Or(args) => Some((args, Or.apply)) @@ -681,6 +686,7 @@ object Trees { def compute(t: Expr, s: Set[Identifier]) = t match { case Let(i,_,_) => s -- Set(i) case MatchExpr(_, cses) => s -- (cses.map(_.pattern.binders).foldLeft(Set[Identifier]())((a, b) => a ++ b)) + case AnonymousFunctionInvocation(i,_) => s ++ Set[Identifier](i) case _ => s } treeCatamorphism(convert, combine, compute, expr) diff --git a/src/purescala/TypeTrees.scala b/src/purescala/TypeTrees.scala index 53c12102d89d1914e92577f4945e9f8763d2f612..eab21e5cb9219832fe3a54a321c7de79bdc550f7 100644 --- a/src/purescala/TypeTrees.scala +++ b/src/purescala/TypeTrees.scala @@ -119,6 +119,16 @@ object TypeTrees { case InfiniteSize => InfiniteSize case FiniteSize(n) => FiniteSize(n+1) } + case FunctionType(fts, tt) => { + val fromSizes = fts map domainSize + val toSize = domainSize(tt) + if (fromSizes.exists(_ == InfiniteSize) || toSize == InfiniteSize) + InfiniteSize + else { + val n = toSize.asInstanceOf[FiniteSize].size + FiniteSize(scala.math.pow(n, fromSizes.foldLeft(1)((acc, s) => acc * s.asInstanceOf[FiniteSize].size)).toInt) + } + } case c: ClassType => InfiniteSize } @@ -134,6 +144,7 @@ object TypeTrees { case class MultisetType(base: TypeTree) extends TypeTree case class MapType(from: TypeTree, to: TypeTree) extends TypeTree case class OptionType(base: TypeTree) extends TypeTree + case class FunctionType(from: List[TypeTree], to: TypeTree) extends TypeTree sealed abstract class ClassType extends TypeTree { val classDef: ClassTypeDef diff --git a/src/purescala/Z3ModelReconstruction.scala b/src/purescala/Z3ModelReconstruction.scala index 5316608876deb161b042b41d13b0dca794fee8f3..ae294770241f86052bebc614661eccb6143b138f 100644 --- a/src/purescala/Z3ModelReconstruction.scala +++ b/src/purescala/Z3ModelReconstruction.scala @@ -18,7 +18,7 @@ trait Z3ModelReconstruction { def modelValue(model: Z3Model, id: Identifier, tpe: TypeTree = null) : Option[Expr] = { val expectedType = if(tpe == null) id.getType else tpe - if(!exprToZ3Id.isDefinedAt(id.toVariable)) None else { + if(exprToZ3Id.isDefinedAt(id.toVariable)) { val z3ID : Z3AST = exprToZ3Id(id.toVariable) expectedType match { @@ -42,12 +42,32 @@ trait Z3ModelReconstruction { if (singletons.isEmpty) Some(EmptyMap(kt, vt)) else Some(FiniteMap(singletons.toSeq)) } } + case FunctionType(fts, tt) => scala.sys.error("should not have reached this case, function interpretations are handled differently.") case other => model.eval(z3ID) match { case None => None case Some(t) => softFromZ3Formula(t) } } - } + } else if (anonymousFuns.isDefinedAt(id)) { + val z3fd: Z3FuncDecl = anonymousFuns(id) + + expectedType match { + case FunctionType(fts, tt) => { + // TODO change ScalaZ3 to avoid recomputing this + model.getModelFuncInterpretations.find(_._1 == z3fd) match { + case Some((_, es, ev)) => { + val entries = es.map { + case (args, value) => (args map fromZ3Formula, fromZ3Formula(value)) + } + val elseValue = fromZ3Formula(ev) + Some(AnonymousFunction(entries, elseValue)) + } + case None => None + } + } + case errorType => scala.sys.error("unexpected type for function: " + errorType) + } + } else None } def modelToMap(model: Z3Model, ids: Iterable[Identifier]) : Map[Identifier,Expr] = { diff --git a/src/purescala/Z3Solver.scala b/src/purescala/Z3Solver.scala index 1dd504834761cbda6a54af52094bc11ae5a1a12e..1f4cc06acbe4a8366c5d22ff78777906ca9eb450 100644 --- a/src/purescala/Z3Solver.scala +++ b/src/purescala/Z3Solver.scala @@ -107,6 +107,8 @@ class Z3Solver(val reporter: Reporter) extends Solver(reporter) with AbstractZ3S protected[purescala] val mapRangeNoneTesters: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty protected[purescala] val mapRangeValueSelectors: MutableMap[TypeTree, Z3FuncDecl] = MutableMap.empty + protected[purescala] var anonymousFuns: Map[Identifier, Z3FuncDecl] = Map.empty + case class UntranslatableTypeException(msg: String) extends Exception(msg) private def prepareSorts: Unit = { import Z3Context.{ADTSortReference, RecursiveType, RegularSort}