diff --git a/cp-demo/ChooseCalls.scala b/cp-demo/ChooseCalls.scala index 44bbdf2c9094500ed600b7ea449ed97b655520b6..c0d4ab277c07fc8dcd83ae50f5c046f693b203a1 100644 --- a/cp-demo/ChooseCalls.scala +++ b/cp-demo/ChooseCalls.scala @@ -40,16 +40,12 @@ object ChooseCalls { case Node(Red(), l, _, _) => blackHeight(l) } - // def ct2(t : Tree) : Boolean = { - // !(blackBalanced(t) && size(t) == 4) - // } holds - def chooseTree(height : Int) : Tree = { - choose((t: Tree) => blackBalanced(t) && size(t) == 5) + choose((t: Tree) => blackBalanced(t) && size(t) == height) } def main(args: Array[String]) : Unit = { - val height = 3 + val height = 5 println("The chosen tree (of height " + height + ") is : " + chooseTree(height)) } } diff --git a/src/cp/CallTransformation.scala b/src/cp/CallTransformation.scala index 3c1f05e2c586f383e3407cd32c3808ff7f635a5b..50bdc92d6e271df2600c32ad4943281e0452d758 100644 --- a/src/cp/CallTransformation.scala +++ b/src/cp/CallTransformation.scala @@ -56,6 +56,8 @@ trait CallTransformation val (programAssignment, progSym) = codeGen.assignProgram(programFilename) val (exprAssignment, exprSym) = codeGen.assignExpr(exprFilename) + val skipCounter = codeGen.skipCounter(progSym) + // compute input variables and assert equalities val inputVars = variablesOf(b).filter{ v => !outputVarList.contains(v.name) }.toList println("Input variables: " + inputVars.mkString(", ")) @@ -87,16 +89,7 @@ trait CallTransformation New(tupleTypeTree,List(returnExpressions map (Ident(_)))) } - val code = BLOCK(List(programAssignment, exprAssignment, andExprAssignment) ::: solverInvocation ::: List(modelAssignment) ::: valueAssignments ::: List(returnExpr) : _*) - - /** generated code: */ - val prog1: purescala.Definitions.Program = cp.Serialization.getProgram(programFilename); - val expr1: purescala.Trees.Expr = cp.Serialization.getExpr(exprFilename); - val andExpr1: purescala.Trees.Expr = new And(scala.collection.immutable.List.apply[purescala.Trees.Expr](expr1)); - val solver1: purescala.FairZ3Solver = new FairZ3Solver(new DefaultReporter()); - solver1.setProgram(prog1); - val outcome1: (Option[Boolean], Map[purescala.Common.Identifier,purescala.Trees.Expr]) = solver1.restartAndDecideWithModel(expr1, false); - println("the outcome: " + outcome1) + val code = BLOCK(List(programAssignment, exprAssignment, skipCounter, andExprAssignment) ::: solverInvocation ::: List(modelAssignment) ::: valueAssignments ::: List(returnExpr) : _*) typer.typed(atOwner(currentOwner) { code @@ -165,4 +158,9 @@ object CallTransformation { def inputVar(inputVarList : List[Variable], varName : String) : Variable = { inputVarList.find(_.id.name == varName).getOrElse(scala.Predef.error("Could not find input variable '" + varName + "' in list " + inputVarList)) } + + def skipCounter(prog: Program) : Unit = { + val maxId = prog.allIdentifiers max Ordering[Int].on[Identifier](_.id) + purescala.Common.FreshIdentifier.forceSkip(maxId.id) + } } diff --git a/src/cp/CodeGeneration.scala b/src/cp/CodeGeneration.scala index 829367c51603c36b1a82caec865182a0c789acd7..183b2051bf20c9124f796e1dd03fdf00321d8ecc 100644 --- a/src/cp/CodeGeneration.scala +++ b/src/cp/CodeGeneration.scala @@ -29,6 +29,7 @@ trait CodeGeneration { private lazy val modelFunction = definitions.getMember(callTransformationModule, "model") private lazy val modelValueFunction = definitions.getMember(callTransformationModule, "modelValue") private lazy val inputVarFunction = definitions.getMember(callTransformationModule, "inputVar") + private lazy val skipCounterFunction = definitions.getMember(callTransformationModule, "skipCounter") private lazy val serializationModule = definitions.getModule("cp.Serialization") private lazy val getProgramFunction = definitions.getMember(serializationModule, "getProgram") @@ -261,5 +262,9 @@ trait CodeGeneration { (statement, andSym) } + def skipCounter(progSym : Symbol) : Tree = { + (cpPackage DOT callTransformationModule DOT skipCounterFunction) APPLY ID(progSym) + } + } } diff --git a/src/purescala/Common.scala b/src/purescala/Common.scala index fb01de0f1f1e53b7603f49ad3344b806f7dbb93c..4b195eb9660f573a2443cd39c61803ad6e4db901 100644 --- a/src/purescala/Common.scala +++ b/src/purescala/Common.scala @@ -42,9 +42,19 @@ object Common { soFar = soFar + 1 soFar } + + def last: Int = { + soFar + } } object FreshIdentifier { + def forceSkip(i : Int) : Unit = { + while(UniqueCounter.last < i) { + UniqueCounter.next + } + } + def apply(name: String, alwaysShowUniqueID: Boolean = false) : Identifier = new Identifier(name, UniqueCounter.next, alwaysShowUniqueID) } diff --git a/src/purescala/Definitions.scala b/src/purescala/Definitions.scala index 547ac316eb81dbd608b623bf8066327adbabe409..ee59e388ad6ddb8981fa9c6d495f7bbdceb1b3dd 100644 --- a/src/purescala/Definitions.scala +++ b/src/purescala/Definitions.scala @@ -13,6 +13,7 @@ object Definitions { case t : Definition => t.id == this.id case _ => false } + def allIdentifiers : Set[Identifier] } /** A VarDecl declares a new identifier to be of a certain type. */ @@ -48,6 +49,7 @@ object Definitions { def isRecursive(f1: FunDef) = mainObject.isRecursive(f1) def isCatamorphism(f1: FunDef) = mainObject.isCatamorphism(f1) def caseClassDef(name: String) = mainObject.caseClassDef(name) + def allIdentifiers : Set[Identifier] = mainObject.allIdentifiers + id } /** Objects work as containers for class definitions, functions (def's) and @@ -60,6 +62,11 @@ object Definitions { def caseClassDef(caseClassName : String) : CaseClassDef = definedClasses.find(ctd => ctd.id.name == caseClassName).getOrElse(scala.Predef.error("Asking for non-existent case class def: " + caseClassName)).asInstanceOf[CaseClassDef] + def allIdentifiers : Set[Identifier] = { + (defs map (_.allIdentifiers)).foldLeft(Set[Identifier]())((a, b) => a ++ b) ++ + (invariants map (Trees.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) + id + } + lazy val classHierarchyRoots : Seq[ClassTypeDef] = defs.filter(_.isInstanceOf[ClassTypeDef]).map(_.asInstanceOf[ClassTypeDef]).filter(!_.hasParent) lazy val (callGraph, callers, callees) = { @@ -174,6 +181,10 @@ object Definitions { children = child :: children } + def allIdentifiers : Set[Identifier] = { + fields.map(f => f.id).toSet + id + } + def knownChildren : Seq[ClassTypeDef] = { children } @@ -218,6 +229,10 @@ object Definitions { } def parent = parent_ + def allIdentifiers : Set[Identifier] = { + fields.map(f => f.id).toSet + } + def selectorID2Index(id: Identifier) : Int = { var i : Int = 0 var found = false @@ -246,6 +261,7 @@ object Definitions { /** Values */ @serializable case class ValDef(varDecl: VarDecl, value: Expr) extends Definition { val id: Identifier = varDecl.id + def allIdentifiers : Set[Identifier] = Trees.allIdentifiers(value) + id } /** Functions (= 'methods' of objects) */ @@ -267,6 +283,13 @@ object Definitions { def hasBody = hasImplementation def hasPrecondition : Boolean = precondition.isDefined def hasPostcondition : Boolean = postcondition.isDefined + + def allIdentifiers : Set[Identifier] = { + args.map(_.id).toSet ++ + body.map(Trees.allIdentifiers(_)).getOrElse(Set[Identifier]()) ++ + precondition.map(Trees.allIdentifiers(_)).getOrElse(Set[Identifier]()) ++ + postcondition.map(Trees.allIdentifiers(_)).getOrElse(Set[Identifier]()) + id + } private var annots: Set[String] = Set.empty[String] def addAnnotation(as: String*) : FunDef = { diff --git a/src/purescala/Trees.scala b/src/purescala/Trees.scala index 3e037816931f9a2443386d5035a53827142ae284..a9f6b2438fa529477e6b68f73d911695086ba909 100644 --- a/src/purescala/Trees.scala +++ b/src/purescala/Trees.scala @@ -60,6 +60,13 @@ object Trees { val theGuard: Option[Expr] def hasGuard = theGuard.isDefined def expressions: Seq[Expr] + + def allIdentifiers : Set[Identifier] = { + pattern.allIdentifiers ++ + Trees.allIdentifiers(rhs) ++ + theGuard.map(Trees.allIdentifiers(_)).getOrElse(Set[Identifier]()) ++ + (expressions map (Trees.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) + } } @serializable case class SimpleCase(pattern: Pattern, rhs: Expr) extends MatchCase { @@ -77,6 +84,10 @@ object Trees { private def subBinders = subPatterns.map(_.binders).foldLeft[Set[Identifier]](Set.empty)(_ ++ _) def binders: Set[Identifier] = subBinders ++ (if(binder.isDefined) Set(binder.get) else Set.empty) + + def allIdentifiers : Set[Identifier] = { + ((subPatterns map (_.allIdentifiers)).foldLeft(Set[Identifier]())((a, b) => a ++ b)) ++ binders + } } @serializable case class InstanceOfPattern(binder: Option[Identifier], classTypeDef: ClassTypeDef) extends Pattern { // c: Class val subPatterns = Seq.empty @@ -162,6 +173,10 @@ object Trees { @serializable class Implies(val left: Expr, val right: Expr) extends Expr with FixedType { val fixedType = BooleanType + // if(left.getType != BooleanType || right.getType != BooleanType) { + // println("culprits: " + left.getType + ", " + right.getType) + // assert(false) + // } } @serializable case class Not(expr: Expr) extends Expr with FixedType { @@ -703,6 +718,19 @@ object Trees { expr) } + def allIdentifiers(expr: Expr) : Set[Identifier] = expr match { + case l @ Let(binder, e, b) => allIdentifiers(e) ++ allIdentifiers(b) + binder + case n @ NAryOperator(args, _) => + (args map (Trees.allIdentifiers(_))).foldLeft(Set[Identifier]())((a, b) => a ++ b) + case b @ BinaryOperator(a1,a2,_) => allIdentifiers(a1) ++ allIdentifiers(a2) + case u @ UnaryOperator(a,_) => allIdentifiers(a) + case i @ IfExpr(a1,a2,a3) => allIdentifiers(a1) ++ allIdentifiers(a2) ++ allIdentifiers(a3) + case m @ MatchExpr(scrut, cses) => + (cses map (_.allIdentifiers)).foldLeft(Set[Identifier]())((a, b) => a ++ b) ++ allIdentifiers(scrut) + case t: Terminal => Set.empty + case unhandled => scala.Predef.error("Non-terminal case should be handled in treeCatamorphism: " + unhandled) + } + /* Simplifies let expressions: * - removes lets when expression never occurs * - simplifies when expressions occurs exactly once