From bb37e92e10c125b5e4d2df7513ac7d6864c64f89 Mon Sep 17 00:00:00 2001 From: Ivan Kuraj <ivan.kuraj@epfl.ch> Date: Thu, 25 Apr 2013 04:21:03 +0200 Subject: [PATCH] More flexible variable refinement - direct field refinement, operator refinement --- src/main/scala/insynth/InSynth.scala | 2 + .../scala/insynth/leon/loader/Loader.scala | 40 ++- .../scala/insynth/leon/loader/PreLoader.scala | 2 + src/main/scala/lesynth/ExampleRunner.scala | 1 + src/main/scala/lesynth/Refiner.scala | 16 +- src/main/scala/lesynth/Report.scala | 8 +- src/main/scala/lesynth/SynthesisInfo.scala | 63 +++++ .../scala/lesynth/SynthesizerExamples.scala | 232 +++++++----------- src/main/scala/lesynth/VariableRefiner.scala | 88 +++++++ .../ConditionAbductionSynthesisTwoPhase.scala | 7 +- .../scala/lesynth/VariableRefinerTest.scala | 90 +++++++ 11 files changed, 387 insertions(+), 162 deletions(-) create mode 100644 src/main/scala/lesynth/SynthesisInfo.scala create mode 100755 src/main/scala/lesynth/VariableRefiner.scala create mode 100644 src/test/scala/lesynth/VariableRefinerTest.scala diff --git a/src/main/scala/insynth/InSynth.scala b/src/main/scala/insynth/InSynth.scala index bcba1f478..42d386bc9 100644 --- a/src/main/scala/insynth/InSynth.scala +++ b/src/main/scala/insynth/InSynth.scala @@ -32,6 +32,7 @@ class InSynth(declarations: List[Declaration], goalType: Type, ordered: Boolean lazy val solver = new Solver(declarations, new LeonQueryBuilder(goalType)) def getExpressions = { + info("InSynth synthesizing type + " + goalType + " with declarations " + solver.allDeclarations.mkString("\n")) val proofTree = solver.getProofTree assert(proofTree != null, "Proof tree is null" ) @@ -43,6 +44,7 @@ class InSynth(declarations: List[Declaration], goalType: Type, ordered: Boolean } def getExpressions(builder: InitialEnvironmentBuilder) = { + info("InSynth synthesizing type + " + goalType + " with declarations " + builder.getAllDeclarations.mkString("\n")) val proofTree = solver.getProofTree(builder) assert(proofTree != null, "Proof tree is null" ) diff --git a/src/main/scala/insynth/leon/loader/Loader.scala b/src/main/scala/insynth/leon/loader/Loader.scala index f318a7a36..44175bd8e 100644 --- a/src/main/scala/insynth/leon/loader/Loader.scala +++ b/src/main/scala/insynth/leon/loader/Loader.scala @@ -63,6 +63,21 @@ case class LeonLoader(program: Program, hole: Hole, list ++= variableDeclarations + for (variable <- variables; variableType = variable.getType) variableType match { + case variableClassType: CaseClassType => variableClassType.classDef match { + case cas@CaseClassDef(id, parent, fields) => + fine("adding fields of variable " + variable) + for (field <- fields) + list += makeDeclaration( + ImmediateExpression( "Field(" + cas + "." + field.id + ")", + CaseClassSelector(cas, variable.toVariable, field.id) ), + field.id.getType + ) + case _ => + } + case _ => + } + list.toList // no need for doing this (we will have everything from the synthesis problem context) @@ -128,21 +143,20 @@ case class LeonLoader(program: Program, hole: Hole, inheritance <- extractInheritancesRec(classHierarchyRoot)) yield inheritance } + + def extractFields(classDef: ClassTypeDef) = classDef match { + case abs: AbstractClassDef => + // this case does not seem to work + //abs.fields + Seq.empty + case cas: CaseClassDef => + for (field <- cas.fields) + yield makeDeclaration( + UnaryReconstructionExpression("Field(" + cas + "." + field.id + ")", { CaseClassSelector(cas, _: Expr, field.id) }), + FunctionType(List(classMap(cas.id)), field.id.getType)) + } def extractFields: Seq[Declaration] = { - def extractFields(classDef: ClassTypeDef) = classDef match { - case abs: AbstractClassDef => - // this case does not seem to work - //abs.fields - Seq.empty - case cas: CaseClassDef => - for (field <- cas.fields) - yield makeDeclaration( - UnaryReconstructionExpression( "Field(" + cas + "." + field.id + ")", { CaseClassSelector(cas, _: Expr, field.id) } ), - FunctionType(List(classMap(cas.id)), field.id.getType) - ) - } - for (classDef <- program.definedClasses; decl <- extractFields(classDef)) yield decl diff --git a/src/main/scala/insynth/leon/loader/PreLoader.scala b/src/main/scala/insynth/leon/loader/PreLoader.scala index 1f72ac43b..b362af6be 100644 --- a/src/main/scala/insynth/leon/loader/PreLoader.scala +++ b/src/main/scala/insynth/leon/loader/PreLoader.scala @@ -22,6 +22,8 @@ object PreLoader extends ( (Boolean) => List[Declaration] ) { list += getNot list += getLessEquals + list += getLessThan + list += getGreaterThan list ++= getEquals(supportedBaseTypes) if (loadArithmeticOps) diff --git a/src/main/scala/lesynth/ExampleRunner.scala b/src/main/scala/lesynth/ExampleRunner.scala index 111101188..3d82d3c03 100644 --- a/src/main/scala/lesynth/ExampleRunner.scala +++ b/src/main/scala/lesynth/ExampleRunner.scala @@ -74,6 +74,7 @@ class ExampleRunner(program: Program, maxSteps: Int = 2000) extends HasLogger { // replace(Map(ResultVariable() -> Variable(resFresh)), matchToIfThenElse(holeFunDef.getPostcondition))) // } fine("expressionToCheck: " + expressionToCheck) + finest("program: " + program) (0 /: counterExamples) { (res, ce) => diff --git a/src/main/scala/lesynth/Refiner.scala b/src/main/scala/lesynth/Refiner.scala index c5dda1dbd..2fc2cbd44 100755 --- a/src/main/scala/lesynth/Refiner.scala +++ b/src/main/scala/lesynth/Refiner.scala @@ -18,9 +18,10 @@ class Refiner(program: Program, hole: Hole, holeFunDef: FunDef) extends HasLogge "Results for refining " + expr + ", are: " + " ,recurentExpression == expr " + (recurentExpression == expr) + " ,isCallAvoidableBySize(expr) " + isCallAvoidableBySize(expr, funDefArgs) + - " ,hasDoubleRecursion(expr) " + hasDoubleRecursion(expr) + " ,hasDoubleRecursion(expr) " + hasDoubleRecursion(expr) + + " ,isOperatorAvoidable(expr) " + isOperatorAvoidable(expr) ) - recurentExpression == expr || isCallAvoidableBySize(expr, funDefArgs) || hasDoubleRecursion(expr) + recurentExpression == expr || isCallAvoidableBySize(expr, funDefArgs) || hasDoubleRecursion(expr) || isOperatorAvoidable(expr) } //val holeFunDef = Globals.holeFunDef @@ -95,4 +96,15 @@ class Refiner(program: Program, hole: Hole, holeFunDef: FunDef) extends HasLogge found } + def isOperatorAvoidable(expr: Expr) = expr match { + case And(expr1 :: expr2) if expr1 == expr2 => true + case Or(expr1 :: expr2) if expr1 == expr2 => true + case LessThan(expr1, expr2) if expr1 == expr2 => true + case LessEquals(expr1, expr2) if expr1 == expr2 => true + case GreaterThan(expr1, expr2) if expr1 == expr2 => true + case GreaterEquals(expr1, expr2) if expr1 == expr2 => true + case Equals(expr1, expr2) if expr1 == expr2 => true + case _ => false + } + } \ No newline at end of file diff --git a/src/main/scala/lesynth/Report.scala b/src/main/scala/lesynth/Report.scala index 993a9f862..372784d03 100755 --- a/src/main/scala/lesynth/Report.scala +++ b/src/main/scala/lesynth/Report.scala @@ -21,10 +21,13 @@ case object EmptyReport extends Report { override def isSuccess = false } -case class FullReport(val function: FunDef, val totalTime: Long/*, innerVerificationReport */) extends Report { +case class FullReport(val function: FunDef, val synthInfo: SynthesisInfo) extends Report { + import SynthesisInfo.Action._ import Report._ + val totalTime = synthInfo.getTime(Synthesis) + implicit val width = 70 override def summaryString : String = { @@ -33,6 +36,9 @@ case class FullReport(val function: FunDef, val totalTime: Long/*, innerVerifica ("║ %-" + (width - 2) + "s ║\n").format(_) }.mkString + infoSep + + ("║ Generation: %" + (width - 15) + ".2fs ║\n").format(synthInfo.getTime(Generation).toDouble/1000) + + ("║ Evaluation: %" + (width - 15) + ".2fs ║\n").format(synthInfo.getTime(Evaluation).toDouble/1000) + + ("║ Verification: %" + (width - 15) + ".2fs ║\n").format(synthInfo.getTime(Verification).toDouble/1000) + ("║ Total time: %" + (width - 15) + ".2fs ║\n").format(totalTime.toDouble/1000) + infoFooter } diff --git a/src/main/scala/lesynth/SynthesisInfo.scala b/src/main/scala/lesynth/SynthesisInfo.scala new file mode 100644 index 000000000..82a5487d8 --- /dev/null +++ b/src/main/scala/lesynth/SynthesisInfo.scala @@ -0,0 +1,63 @@ +package lesynth + +/** + * Contains information about the synthesis process + */ +class SynthesisInfo { + + import SynthesisInfo.Action + + // times + private var times = new Array[Long](Action.values.size) + private var startTimes = new Array[Long](Action.values.size) + private var lastTimes = new Array[Long](Action.values.size) + + private var lastAction: Action.Action = Action.Evaluation + + def getTime(a: Action.Action) = times(a.id) + + def start(a: Action.Action) = { + lastAction = a + startTimes(a.id) = System.currentTimeMillis() + } + + def end: Unit = end(lastAction) + + def end(a: Action.Action) = { + lastAction = a + lastTimes(a.id) = System.currentTimeMillis() - startTimes(a.id) + times(a.id) += lastTimes(a.id) + } + + def end[T](returnValue: => T): T = { + val result = returnValue + end + result + } + + def last(a: Action.Action) = lastTimes(a.id) + + def last: Long = last(lastAction) + + def profile[T](a: Action.Action)(block: => T): T = { + lastAction = a + startTimes(a.id) = System.currentTimeMillis() + val result = block // call-by-name + lastTimes(a.id) = System.currentTimeMillis() - startTimes(a.id) + times(a.id) += lastTimes(a.id) + result + } + +} + +object SynthesisInfo { + object Action extends Enumeration { + type Action = Value + val Synthesis, + Verification, Generation, Evaluation = Value +// VerificationBranch, VerificationCondition, +// EvaluationGeneration, EvaluationBranch, EvaluationCondition, +// VerificationCounterExampleGen +// = Value + } +} \ No newline at end of file diff --git a/src/main/scala/lesynth/SynthesizerExamples.scala b/src/main/scala/lesynth/SynthesizerExamples.scala index c88f9521c..7f5cd95ac 100755 --- a/src/main/scala/lesynth/SynthesizerExamples.scala +++ b/src/main/scala/lesynth/SynthesizerExamples.scala @@ -34,6 +34,9 @@ import scala.collection.mutable.{Set => MutableSet} import scala.util.control.Breaks.break import scala.util.control.Breaks.breakable +import SynthesisInfo._ +import SynthesisInfo.Action._ + class SynthesizerForRuleExamples( // some synthesis instance information val solver: Solver, @@ -83,7 +86,8 @@ class SynthesizerForRuleExamples( //private var solver: Solver = _ private var ctx: LeonContext = _ private var initialPrecondition: Expr = _ - private var variableRefinements: MutableMap[Identifier, MutableSet[ClassType]] = _ + + private var variableRefiner: VariableRefiner = _ // can be used to unnecessary syntheses private var variableRefinedBranch = false private var variableRefinedCondition = true // assure initial synthesis @@ -100,20 +104,15 @@ class SynthesizerForRuleExamples( private var accumulatingExpression: Expr => Expr = _ //private var accumulatingExpressionMatch: Expr => Expr = _ - var flag1 = false - var flag2 = false - - // time - var startTime: Long = _ - var verTime: Long = 0 - var synTime: Long = 0 + // information about the synthesis + private val synthInfo = new SynthesisInfo // filtering/ranking with examples support var exampleRunner: ExampleRunner = _ def analyzeProgram = { - - val temp = System.currentTimeMillis + + synthInfo.start(Verification) Globals.allSolved = Some(true) import TreeOps._ @@ -145,26 +144,24 @@ class SynthesizerForRuleExamples( Array(fileName, "--timeout=" + leonTimeout) info("Leon context array: " + args.mkString(",")) ctx = processOptions(reporter, args.toList) - val solver = new TimeoutSolver(new FairZ3Solver(ctx), 1000L) + val solver = new TimeoutSolver(new FairZ3Solver(ctx), 4000L) + //new TimeoutSolver(synthesisContext.solver.getNewSolver, 2000L) solver.setProgram(program) Globals.allSolved = solver.solve(theExpr) fine("solver said " + Globals.allSolved + " for " + theExpr) //interactivePause - val time = System.currentTimeMillis - temp - //fine("Analysis took: " + time + ", from report: " + report.totalTime) - - // accumulate - verTime += time + // measure time + synthInfo.end + fine("Analysis took of theExpr: " + synthInfo.last) } // TODO return boolean (do not do unecessary analyze) def generateCounterexamples(program: Program, funDef: FunDef, number: Int): (Seq[Map[Identifier, Expr]], Expr) = { fine("generate counter examples with funDef.prec= " + funDef.precondition.getOrElse(BooleanLiteral(true))) - val temp = System.currentTimeMillis - + // get current precondition var precondition = funDef.precondition.getOrElse(BooleanLiteral(true)) // where we will accumulate counterexamples as sequence of maps @@ -196,30 +193,32 @@ class SynthesizerForRuleExamples( ind += 1 } - val temptime = System.currentTimeMillis - temp - fine("Generation of counter-examples took: " + temptime) - verTime += temptime +// val temptime = System.currentTimeMillis - temp +// fine("Generation of counter-examples took: " + temptime) +// verTime += temptime // return found counterexamples and the formed precondition (maps, precondition) } - - + def getCurrentBuilder = new InitialEnvironmentBuilder(allDeclarations) - def synthesizeBranchExpressions = - inSynth.getExpressions(getCurrentBuilder) + def synthesizeBranchExpressions = { + synthInfo.profile(Generation) { inSynth.getExpressions(getCurrentBuilder) } + } def synthesizeBooleanExpressions = { + synthInfo.start(Generation) if ( variableRefinedCondition ) { // store for later fetch (will memoize values) booleanExpressionsSaved = - inSynthBoolean.getExpressions(getCurrentBuilder) take numberOfBooleanSnippets + inSynthBoolean.getExpressions(getCurrentBuilder). + filterNot(expr => refiner.isAvoidable(expr.getSnippet, problem.as)) take numberOfBooleanSnippets // reset flag variableRefinedCondition = false } - booleanExpressionsSaved + synthInfo end booleanExpressionsSaved } def interactivePause = { @@ -267,6 +266,8 @@ class SynthesizerForRuleExamples( // funDef of the hole fine("postcondition is: " + holeFunDef.getPostcondition) + fine("declarations we see: " + allDeclarations.map(_.toString).mkString("\n")) + interactivePause // accumulate precondition for the remaining branch to synthesize accumulatingPrecondition = holeFunDef.precondition.getOrElse(BooleanLiteral(true)) @@ -277,18 +278,9 @@ class SynthesizerForRuleExamples( //accumulatingExpressionMatch = accumulatingExpression // each variable of super type can actually have a subtype - // get sine declaration maps to be able to refine them - val directSubclassMap = loader.directSubclassesMap - val variableDeclarations = loader.variableDeclarations - // map from identifier into a set of possible subclasses - variableRefinements = MutableMap.empty - for (varDec <- variableDeclarations) { - varDec match { - case LeonDeclaration(_, _, typeOfVar: ClassType, ImmediateExpression(_, LeonVariable(id))) => - variableRefinements += (id -> MutableSet(directSubclassMap(typeOfVar).toList: _*)) - case _ => - } - } + // get sine declaration maps to be able to refine them + variableRefiner = new VariableRefiner(loader.directSubclassesMap, + loader.variableDeclarations, loader.classMap, reporter) // calculate cases that should not happen refiner = new Refiner(program, hole, holeFunDef) @@ -302,6 +294,8 @@ class SynthesizerForRuleExamples( } def countPassedExamples(snippet: Expr) = { + synthInfo.start(Action.Evaluation) + val oldPreconditionSaved = holeFunDef.precondition val oldBodySaved = holeFunDef.body @@ -324,8 +318,8 @@ class SynthesizerForRuleExamples( replace(Map(ResultVariable() -> LeonVariable(resFresh)), matchToIfThenElse(holeFunDef.getPostcondition))) } - fine("going to count passed for: " + holeFunDef) - fine("going to count passed for: " + expressionToCheck) + finest("going to count passed for: " + holeFunDef) + finest("going to count passed for: " + expressionToCheck) val count = exampleRunner.countPassed(expressionToCheck) // if (snippet.toString == "Cons(l1.head, concat(l1.tail, l2))") @@ -334,11 +328,14 @@ class SynthesizerForRuleExamples( holeFunDef.precondition = oldPreconditionSaved holeFunDef.body = oldBodySaved - count + synthInfo end count } def evaluateCandidate(snippet: Expr, mapping: Map[Identifier, Expr]) = { + + synthInfo.start(Action.Evaluation) + val oldPreconditionSaved = holeFunDef.precondition val oldBodySaved = holeFunDef.body @@ -349,8 +346,7 @@ class SynthesizerForRuleExamples( val accumulatedExpression = accumulatingExpression(snippet) // set appropriate body to the function for the correct evaluation holeFunDef.body = Some(accumulatedExpression) - - + import TreeOps._ val expressionToCheck = //Globals.bodyAndPostPlug(exp) @@ -361,24 +357,25 @@ class SynthesizerForRuleExamples( replace(Map(ResultVariable() -> LeonVariable(resFresh)), matchToIfThenElse(holeFunDef.getPostcondition))) } - fine("going to evaluate candidate for: " + holeFunDef) - fine("going to evaluate candidate for: " + expressionToCheck) + finest("going to evaluate candidate for: " + holeFunDef) + finest("going to evaluate candidate for: " + expressionToCheck) val count = exampleRunner.evaluate(expressionToCheck, mapping) -// if (snippet.toString == "Cons(l1.head, concat(l1.tail, l2))") -// interactivePause holeFunDef.precondition = oldPreconditionSaved holeFunDef.body = oldBodySaved - count +// if(snippet.toString == "checkf(f2, reverse(r2))") +// interactivePause + + synthInfo end count } def synthesize: Report = { reporter.info("Synthesis called on file: " + fileName) - // get start time - startTime = System.currentTimeMillis + // profile + synthInfo start Synthesis reporter.info("Initializing synthesizer: ") reporter.info("numberOfBooleanSnippets: %d".format(numberOfBooleanSnippets)) @@ -427,18 +424,6 @@ class SynthesizerForRuleExamples( var numberOfTested = 0 - // just printing of expressions and pass counts - fine( { - val (it1, it2) = snippetsIterator.duplicate // we are dealing with iterators, need to duplicate - val logString = ((it1 zip Iterator.range(0, numberOfTestsInIteration)) map { - case ((snippet: Output, ind: Int)) => ind + ": snippet is " + snippet.getSnippet + - " pass count is " + countPassedExamples(snippet.getSnippet) - }).mkString("\n") - snippetsIterator = it2 - logString - }) - //interactivePause - reporter.info("Going into a enumeration/testing phase.") fine("evaluating examples: " + exampleRunner.counterExamples.mkString("\n")) @@ -454,14 +439,22 @@ class SynthesizerForRuleExamples( it1.take(batchSize). map(_.getSnippet).filterNot( snip => { - if (snip.toString == "merge(sort(split(list).fst), sort(split(list).snd))") println("AAA") - (seenBranchExpressions contains snip.toString) || refiner.isAvoidable(snip, problem.as) } ).toSeq } info("got candidates of size: " + candidates.size) //interactivePause + + // printing candidates and pass counts + fine( { + val logString = ((candidates.zipWithIndex) map { + case ((snippet: Expr, ind: Int)) => ind + ": snippet is " + snippet.toString + + " pass count is " + countPassedExamples(snippet) + }).mkString("\n") + logString + }) + //interactivePause if (candidates.size > 0) { val ranker = new Ranker(candidates, @@ -472,23 +465,27 @@ class SynthesizerForRuleExamples( info("maxCandidate is: " + maxCandidate) numberOfTested += batchSize -// if (candidates.exists(_.toString == "merge(sort(split(list).fst), sort(split(list).snd))")) { -// println(ranker.printTuples) -// println("AAA2") -// println("Candidates: " + candidates.zipWithIndex.map({ -// case (cand, ind) => "[" + ind + "]" + cand.toString -// }).mkString(", ")) -// println("Examples: " + exampleRunner.counterExamples.zipWithIndex.map({ -// case (example, ind) => "[" + ind + "]" + example.toString -// }).mkString(", ")) -// interactivePause -// } + if ( + candidates.exists(_.toString contains "checkf(f2, Cons(x, r2))") + ) { + println("maxCandidate is: " + maxCandidate) + println(ranker.printTuples) + println("AAA2") + println("Candidates: " + candidates.zipWithIndex.map({ + case (cand, ind) => "[" + ind + "]" + cand.toString + }).mkString("\n")) + println("Examples: " + exampleRunner.counterExamples.zipWithIndex.map({ + case (example, ind) => "[" + ind + "]" + example.toString + }).mkString("\n")) + interactivePause + } - //interactivePause + interactivePause if (tryToSynthesizeBranch(maxCandidate)) { noBranchFoundIteration = 0 break } + interactivePause noBranchFoundIteration += 1 } @@ -500,9 +497,9 @@ class SynthesizerForRuleExamples( // if did not found for any of the branch expressions if (found) { - val endTime = System.currentTimeMillis - reporter.info("We are done, in time: " + (endTime - startTime)) - return new FullReport(holeFunDef, (endTime - startTime)) + synthInfo end Synthesis + reporter.info("We are done, in time: " + synthInfo.last) + return new FullReport(holeFunDef, synthInfo) } if ( variableRefinedBranch ) { @@ -545,7 +542,7 @@ class SynthesizerForRuleExamples( // TODO spare one analyzing step // analyze the program fine("analyzing program for funDef:" + holeFunDef) - solver.setProgram(program) +// solver.setProgram(program) analyzeProgram // check if solver could solved this instance @@ -605,7 +602,7 @@ class SynthesizerForRuleExamples( try { { //if (!maps.isEmpty) { // proceed with synthesizing boolean expressions - //solver.setProgram(program) + solver.setProgram(program) // reconstruct (only defined number of boolean expressions) val innerSnippets = synthesizeBooleanExpressions @@ -659,23 +656,6 @@ class SynthesizerForRuleExamples( def tryToSynthesizeBooleanCondition(snippetTree: Expr, innerSnippetTree: Expr, precondition: Expr): (Boolean, Option[Expr]) = { - // trying some examples that cannot be verified - if (snippetTree.toString == "Cons(l.head, insert(e, l.tail))" //&& - //innerSnippetTree.toString.contains("aList.head < bList.head") -) { - val endTime = System.currentTimeMillis - reporter.info("We are done, in time: " + (endTime - startTime)) - interactivePause -} - - if (snippetTree.toString == "Cons(aList.head, merge(aList.tail, bList))" //&& - //innerSnippetTree.toString.contains("aList.head < bList.head") -) { - val endTime = System.currentTimeMillis - reporter.info("We are done, in time: " + (endTime - startTime)) - interactivePause -} - // new condition together with existing precondition val newCondition = And(Seq(accumulatingPrecondition, innerSnippetTree)) @@ -705,7 +685,7 @@ class SynthesizerForRuleExamples( // if expression implies counterexamples add it to the precondition and try to validate program holeFunDef.precondition = Some(newCondition) // do analysis - solver.setProgram(program) +// solver.setProgram(program) analyzeProgram // program is valid, we have a branch if (Globals.allSolved == Some(true)) { @@ -732,41 +712,15 @@ class SynthesizerForRuleExamples( // set to set new precondition val preconditionToRestore = Some(accumulatingPrecondition) - // check for refinements - checkRefinements(innerSnippetTree) match { - case Some(refinementPair @ (id, classType)) => - fine("And now we have refinement type: " + refinementPair) - fine("variableRefinements(id) before" + variableRefinements(id)) - variableRefinements(id) -= loader.classMap(classType.id) - fine("variableRefinements(id) after" + variableRefinements(id)) - - // if we have a single subclass possible to refine - if (variableRefinements(id).size == 1) { - reporter.info("We do variable refinement for " + id) - - val newType = variableRefinements(id).head - fine("new type is: " + newType) - - // update declarations - allDeclarations = - for (dec <- allDeclarations) - yield dec match { - case LeonDeclaration(inSynthType, _, decClassType, imex @ ImmediateExpression(_, LeonVariable(`id`))) => - LeonDeclaration( - imex, TypeTransformer(newType), newType) - case _ => - dec - } - - // the reason for two flags is for easier management of re-syntheses only if needed - variableRefinedBranch = true - variableRefinedCondition = true - - } else - fine("we cannot do variable refinement :(") - case _ => + val variableRefinementResult = variableRefiner.checkRefinements(innerSnippetTree, allDeclarations) + if (variableRefinementResult._1) { + allDeclarations = variableRefinementResult._2 + + // the reason for two flags is for easier management of re-syntheses only if needed + variableRefinedBranch = true + variableRefinedCondition = true } - + // found a boolean snippet, break (true, preconditionToRestore) } else { @@ -783,12 +737,4 @@ class SynthesizerForRuleExamples( } // notFalseSolveReturn match { } - // inspect the expression if some refinements can be done - def checkRefinements(expr: Expr) = expr match { - case CaseClassInstanceOf(classDef, LeonVariable(id)) => - Some((id, classDef)) - case _ => - None - } - -} +} \ No newline at end of file diff --git a/src/main/scala/lesynth/VariableRefiner.scala b/src/main/scala/lesynth/VariableRefiner.scala new file mode 100755 index 000000000..cc3b664cd --- /dev/null +++ b/src/main/scala/lesynth/VariableRefiner.scala @@ -0,0 +1,88 @@ +package lesynth + +import scala.collection.mutable.{Map => MutableMap} +import scala.collection.mutable.{Set => MutableSet} + +import leon._ +import leon.purescala.Trees._ +import leon.purescala.TypeTrees._ +import leon.purescala.Definitions._ +import leon.purescala.Common.{ Identifier, FreshIdentifier } + +import insynth.interfaces._ +import insynth.leon.loader._ +import insynth.leon._ + +import insynth.util.logging.HasLogger + +// each variable of super type can actually have a subtype +// get sine declaration maps to be able to refine them +class VariableRefiner(directSubclassMap: Map[ClassType, Set[ClassType]], variableDeclarations: Seq[LeonDeclaration], + classMap: Map[Identifier, ClassType], reporter: Reporter = new DefaultReporter) extends HasLogger { + + // map from identifier into a set of possible subclasses + private var variableRefinements: MutableMap[Identifier, MutableSet[ClassType]] = MutableMap.empty + for (varDec <- variableDeclarations) { + varDec match { + case LeonDeclaration(_, _, typeOfVar: ClassType, ImmediateExpression(_, Variable(id))) => + variableRefinements += (id -> MutableSet(directSubclassMap(typeOfVar).toList: _*)) + case _ => + } + } + + def checkRefinements(expr: Expr, allDeclarations: List[Declaration]) = + // check for refinements + getIdAndClassDef(expr) match { + case Some(refinementPair @ (id, classType)) if variableRefinements(id).size > 1 => + fine("And now we have refinement type: " + refinementPair) + fine("variableRefinements(id) before" + variableRefinements(id)) + variableRefinements(id) -= classMap(classType.id) + fine("variableRefinements(id) after" + variableRefinements(id)) + + // if we have a single subclass possible to refine + if (variableRefinements(id).size == 1) { + reporter.info("We do variable refinement for " + id) + + val newType = variableRefinements(id).head + fine("new type is: " + newType) + + // update declarations + val newDeclarations = + for (dec <- allDeclarations) + yield dec match { + case LeonDeclaration(inSynthType, _, decClassType, imex @ ImmediateExpression(_, Variable(`id`))) => + (( + newType.classDef match { + case newTypeCaseClassDef@CaseClassDef(id, parent, fields) => + for (field <- fields) + yield LeonDeclaration( + ImmediateExpression( "Field(" + newTypeCaseClassDef + "." + field.id + ")", + CaseClassSelector(newTypeCaseClassDef, imex.expr, field.id) ), + TypeTransformer(field.id.getType), field.id.getType + ) + case _ => + Seq.empty + } + ): Seq[Declaration]) :+ LeonDeclaration(imex, TypeTransformer(newType), newType) + case _ => + Seq(dec) + } + + (true, newDeclarations.flatten) + } else { + fine("we cannot do variable refinement :(") + (false, allDeclarations) + } + case _ => + (false, allDeclarations) + } + + // inspect the expression if some refinements can be done + def getIdAndClassDef(expr: Expr) = expr match { + case CaseClassInstanceOf(classDef, Variable(id)) => + Some((id, classDef)) + case _ => + None + } + +} \ No newline at end of file diff --git a/src/main/scala/lesynth/rules/ConditionAbductionSynthesisTwoPhase.scala b/src/main/scala/lesynth/rules/ConditionAbductionSynthesisTwoPhase.scala index ded1a568d..0b28a9986 100755 --- a/src/main/scala/lesynth/rules/ConditionAbductionSynthesisTwoPhase.scala +++ b/src/main/scala/lesynth/rules/ConditionAbductionSynthesisTwoPhase.scala @@ -49,16 +49,17 @@ case object ConditionAbductionSynthesisTwoPhase extends Rule("Condition abductio val synthesizer = new SynthesizerForRuleExamples( solver, program, desiredType, holeFunDef, p, sctx, freshResVar, - 20, 2, 1, + 40, 2, 1, reporter = reporter, introduceExamples = getInputExamples, numberOfTestsInIteration = 50, - numberOfCheckInIteration = 5 + numberOfCheckInIteration = 2 ) synthesizer.synthesize match { case EmptyReport => RuleApplicationImpossible - case FullReport(resFunDef, _) => + case fr@FullReport(resFunDef, _) => + println(fr.summaryString) RuleSuccess(Solution(resFunDef.getPrecondition, Set.empty, resFunDef.body.get)) } } catch { diff --git a/src/test/scala/lesynth/VariableRefinerTest.scala b/src/test/scala/lesynth/VariableRefinerTest.scala new file mode 100644 index 000000000..f1be15d79 --- /dev/null +++ b/src/test/scala/lesynth/VariableRefinerTest.scala @@ -0,0 +1,90 @@ +package lesynth + +import scala.util.Random + +import org.scalatest.FunSpec +import org.scalatest.GivenWhenThen + +import leon.purescala.Definitions._ +import leon.purescala.Common._ +import leon.purescala.TypeTrees._ +import leon.purescala.Trees._ + +import insynth.leon._ + +class VariableRefinerTest extends FunSpec with GivenWhenThen { + + val listClassId = FreshIdentifier("List") + val listAbstractClassDef = new AbstractClassDef(listClassId) + val listAbstractClass = new AbstractClassType(listAbstractClassDef) + + val nilClassId = FreshIdentifier("Nil") + val nilAbstractClassDef = new CaseClassDef(nilClassId).setParent(listAbstractClassDef) + val nilAbstractClass = new CaseClassType(nilAbstractClassDef) + + val consClassId = FreshIdentifier("Cons") + val consAbstractClassDef = new CaseClassDef(consClassId).setParent(listAbstractClassDef) + val headId = FreshIdentifier("head").setType(Int32Type) + consAbstractClassDef.fields = Seq(VarDecl(headId, Int32Type)) + val consAbstractClass = new CaseClassType(consAbstractClassDef) + + val directSubclassMap: Map[ClassType, Set[ClassType]] = Map( + listAbstractClass -> Set(nilAbstractClass, consAbstractClass) + ) + + val listVal = Variable(FreshIdentifier("tempVar")) + val listLeonDeclaration = LeonDeclaration( + ImmediateExpression( "tempVar", listVal ), + TypeTransformer(listAbstractClass), listAbstractClass + ) + + val classMap: Map[Identifier, ClassType] = Map( + listClassId -> listAbstractClass, + nilClassId -> nilAbstractClass, + consClassId -> consAbstractClass + ) + + describe("A variable refiner with list ADT") { + + it("should refine if variable is not Nil") { + + given("a VariableRefiner") + val variableRefiner = new VariableRefiner( + directSubclassMap, Seq(listLeonDeclaration), classMap + ) + + then("it should return appropriate id and class def") + expect(Some((listVal.id, nilAbstractClassDef))) { + variableRefiner.getIdAndClassDef(CaseClassInstanceOf(nilAbstractClassDef, listVal)) + } + and("return None for some unknown expression") + expect(None) { + variableRefiner.getIdAndClassDef(listVal) + } + + then("declarations should be updated accordingly") + val allDeclarations = List(listLeonDeclaration) + expect((true, + LeonDeclaration( + ImmediateExpression( "Field(" + consAbstractClassDef + "." + headId + ")", + CaseClassSelector(consAbstractClassDef, listVal, headId) ), + TypeTransformer(Int32Type), Int32Type + ) :: + LeonDeclaration( + listLeonDeclaration.expression, TypeTransformer(consAbstractClass), consAbstractClass + ) :: Nil + )) { + variableRefiner.checkRefinements(CaseClassInstanceOf(nilAbstractClassDef, listVal), + allDeclarations) + } + + and("after 2nd consequtive call, nothing should happen") + expect((false, allDeclarations)) { + variableRefiner.checkRefinements(CaseClassInstanceOf(nilAbstractClassDef, listVal), + allDeclarations) + } + } + + } + +} \ No newline at end of file -- GitLab