diff --git a/src/main/scala/leon/purescala/TreeOps.scala b/src/main/scala/leon/purescala/TreeOps.scala index 476f0ca853837bf01b604032f089e4f7dd67ac77..9a44926166161496d39b2c99a9934aefce28817b 100644 --- a/src/main/scala/leon/purescala/TreeOps.scala +++ b/src/main/scala/leon/purescala/TreeOps.scala @@ -1192,7 +1192,6 @@ object TreeOps { } def rec(e : Expr, path : Seq[Expr]): Expr = e match { - case Let(i, e, b) => // The path condition for the body of the Let is the same as outside, plus an equality to constrain the newly bound variable. val se = rec(e, path) @@ -1241,25 +1240,29 @@ object TreeOps { case _ => IfExpr(rc, rec(then, rc +: path), rec(elze, Not(rc) +: path)) } - case And(es) => + case And(es) => { var extPath = path var continue = true - And(for(e <- es if continue) yield { + var r = And(for(e <- es if continue) yield { val se = rec(e, extPath) if(se == BooleanLiteral(false)) continue = false extPath = se +: extPath se }) + if (continue) r else BooleanLiteral(false) + } - case Or(es) => + case Or(es) => { var extPath = path var continue = true - Or(for(e <- es if continue) yield { + val r = Or(for(e <- es if continue) yield { val se = rec(e, extPath) if(se == BooleanLiteral(true)) continue = false extPath = Not(se) +: extPath - se + se }) + if (continue) r else BooleanLiteral(true) + } case b if b.getType == BooleanType && impliedBy(b, path) => BooleanLiteral(true) diff --git a/src/main/scala/leon/purescala/Trees.scala b/src/main/scala/leon/purescala/Trees.scala index adabf268d41d53230fbfd4b91fd2460a32fc2fc5..f522f4f0dd1ebe88d9918d2ef338bb76b81498e2 100644 --- a/src/main/scala/leon/purescala/Trees.scala +++ b/src/main/scala/leon/purescala/Trees.scala @@ -24,7 +24,7 @@ object Trees { case class Error(description: String) extends Expr with Terminal with ScalacPositional case class Choose(vars: List[Identifier], pred: Expr) extends Expr with ScalacPositional with UnaryExtractable { - def extract = Some((pred, (e: Expr) => Choose(vars, e).setPosInfo(this))) + def extract = Some((pred, (e: Expr) => Choose(vars, e).setPosInfo(this).setType(TupleType(vars.map(_.getType))))) } /* Like vals */ diff --git a/src/main/scala/leon/synthesis/Rules.scala b/src/main/scala/leon/synthesis/Rules.scala index 6f02e767afecf632012e55c62b00501481cdb3cc..e289e6f838acbf266265a1b459f748e284b116b0 100644 --- a/src/main/scala/leon/synthesis/Rules.scala +++ b/src/main/scala/leon/synthesis/Rules.scala @@ -22,6 +22,7 @@ object Rules { EqualitySplit, CEGIS, Assert, + DetupleOutput, ADTSplit, IntegerEquation, IntegerInequalities diff --git a/src/main/scala/leon/synthesis/Solution.scala b/src/main/scala/leon/synthesis/Solution.scala index 902ec8e767f1d2723610b69bcc64ad586af3bbba..6dbfc3844b52a0eb873ea7304c6436ee9c67ca89 100644 --- a/src/main/scala/leon/synthesis/Solution.scala +++ b/src/main/scala/leon/synthesis/Solution.scala @@ -2,7 +2,7 @@ package leon package synthesis import leon.purescala.Trees._ -import leon.purescala.TypeTrees.TypeTree +import leon.purescala.TypeTrees.{TypeTree,TupleType} import leon.purescala.Definitions._ import leon.purescala.TreeOps._ import leon.xlang.Trees.LetDef @@ -38,7 +38,7 @@ object Solution { def unapply(s: Solution): Option[(Expr, Set[FunDef], Expr)] = if (s eq null) None else Some((s.pre, s.defs, s.term)) def choose(p: Problem): Solution = { - new Solution(BooleanLiteral(true), Set(), Choose(p.xs, p.phi)) + new Solution(BooleanLiteral(true), Set(), Choose(p.xs, p.phi).setType(TupleType(p.xs.map(_.getType)))) } // Generate the simplest, wrongest solution, used for complexity lowerbound diff --git a/src/main/scala/leon/synthesis/rules/DetupleOutput.scala b/src/main/scala/leon/synthesis/rules/DetupleOutput.scala new file mode 100644 index 0000000000000000000000000000000000000000..b5746897e6cddd6900ceebc0cf09eafe4d61851c --- /dev/null +++ b/src/main/scala/leon/synthesis/rules/DetupleOutput.scala @@ -0,0 +1,57 @@ +package leon +package synthesis +package rules + +import purescala.Trees._ +import purescala.Definitions._ +import purescala.Common._ +import purescala.TypeTrees._ +import purescala.TreeOps._ +import purescala.Extractors._ + +case object DetupleOutput extends Rule("Detuple Out") { + + def instantiateOn(sctx: SynthesisContext, p: Problem): Traversable[RuleInstantiation] = { + def isDecomposable(id: Identifier) = id.getType match { + case CaseClassType(t) if !t.isAbstract => true + case _ => false + } + + if (p.xs.exists(isDecomposable)) { + var subProblem = p.phi + + val (subOuts, outerOuts) = p.xs.map { x => + if (isDecomposable(x)) { + val CaseClassType(ccd @ CaseClassDef(name, _, fields)) = x.getType + + val newIds = fields.map(vd => FreshIdentifier(vd.id.name, true).setType(vd.getType)) + + val newCC = CaseClass(ccd, newIds.map(Variable(_))) + + subProblem = subst(x -> newCC, subProblem) + + (newIds, newCC) + } else { + (List(x), Variable(x)) + } + }.unzip + + val newOuts = subOuts.flatten + //sctx.reporter.warning("newOuts: " + newOuts.toString) + + val sub = Problem(p.as, p.pc, subProblem, newOuts) + + val onSuccess: List[Solution] => Option[Solution] = { + case List(sol) => + Some(Solution(sol.pre, sol.defs, LetTuple(newOuts, sol.term, Tuple(outerOuts)))) + case _ => + None + } + + + Some(RuleInstantiation.immediateDecomp(p, this, List(sub), onSuccess)) + } else { + Nil + } + } +} diff --git a/testcases/synthesis/cav2013/ManyTimeSec.scala b/testcases/synthesis/cav2013/ManyTimeSec.scala index ef521e14699601dc0349c353213835cae4636359..8e36f1009a22256422ef7b68e5fa1d209848b699 100644 --- a/testcases/synthesis/cav2013/ManyTimeSec.scala +++ b/testcases/synthesis/cav2013/ManyTimeSec.scala @@ -14,9 +14,41 @@ object ManyTimeSec { choose((seconds:Seconds) => timeAndSec(t,seconds)) def sec2time(seconds:Seconds):Time = choose((t:Time) => timeAndSec(t,seconds)) + def incTime(t0:Time,k:Int) : Time = choose((t1:Time) => time2sec(t1).total == time2sec(t0).total + k) + def testDetuple1(k:Int) : Seconds = { + choose((seconds0:Seconds) => + k == 2*seconds0.total + ) + } + def testDetuple2(total:Int) : Time = { + require(0 <= total) + choose((t:Time) => + 3600*t.h + 60*t.m + t.s == total && + t.h >= 0 && t.m >= 0 && t.m < 60 && t.s >= 0 && t.s < 60 + ) + } + + def incTimeUnfolded(t0:Time,k:Int) : Time = { + require(0 <= t0.h && 0 <= t0.m && t0.m < 60 && 0 <= t0.s && t0.s < 60) + choose((t1:Time,seconds0:Seconds) => + 3600*t0.h + 60*t0.m + t0.s == seconds0.total && + 3600*t1.h + 60*t1.m + t1.s == seconds0.total + k && + t1.h >= 0 && t1.m >= 0 && t1.m < 60 && t1.s >= 0 && t1.s < 60 + )._1 + } + + def incTimeUnfoldedOutOnly(t0:Time,k:Int) : Time = { + require(0 <= t0.h && 0 <= t0.m && t0.m < 60 && 0 <= t0.s && t0.s < 60) + val total = k + 3600*t0.h + 60*t0.m + t0.s + choose((t1:Time) => + 3600*t1.h + 60*t1.m + t1.s == total + k && + t1.h >= 0 && t1.m >= 0 && t1.m < 60 && t1.s >= 0 && t1.s < 60 + ) + } + def incTime2(h1:Int,m1:Int,s1:Int,k:Int) : (Int,Int,Int) = { require(0 <= k && 0 <= h1 && 0 <= m1 && m1 < 60 && 0 <= s1 && s1 < 60) choose((h2:Int,m2:Int,s2:Int) =>