From d1aea47602e17ad5f0d44cfcd17674c5c6ad3aa5 Mon Sep 17 00:00:00 2001 From: Nicolas Voirol <voirol.nicolas@gmail.com> Date: Wed, 13 May 2015 19:52:47 +0200 Subject: [PATCH] Finished up call-by-name --- src/main/scala/leon/purescala/ExprOps.scala | 4 +- .../leon/solvers/smtlib/SMTLIBTarget.scala | 118 ++++++++++-------- .../leon/solvers/smtlib/SMTLIBZ3Target.scala | 4 + .../leon/solvers/z3/AbstractZ3Solver.scala | 54 ++++---- .../scala/leon/synthesis/rules/ADTDual.scala | 4 +- .../scala/leon/synthesis/rules/ADTSplit.scala | 2 +- .../leon/verification/InductionTactic.scala | 2 +- 7 files changed, 98 insertions(+), 90 deletions(-) diff --git a/src/main/scala/leon/purescala/ExprOps.scala b/src/main/scala/leon/purescala/ExprOps.scala index 565f1d3f6..99e246512 100644 --- a/src/main/scala/leon/purescala/ExprOps.scala +++ b/src/main/scala/leon/purescala/ExprOps.scala @@ -1480,8 +1480,8 @@ object ExprOps { val isType = IsInstanceOf(Variable(on), cct) - val recSelectors = cct.fields.collect { - case vd if vd.getType == on.getType => vd.id + val recSelectors = (cct.classDef.fields zip cct.fieldsTypes).collect { + case (vd, tpe) if tpe == on.getType => vd.id } if (recSelectors.isEmpty) { diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index ac5e0fd4b..ed8e75fbf 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -125,6 +125,7 @@ trait SMTLIBTarget extends Interruptible { /* Symbol handling */ protected object SimpleSymbol { + def apply(sym: SSymbol) = QualifiedIdentifier(SMTIdentifier(sym)) def unapply(term: Term): Option[SSymbol] = term match { case QualifiedIdentifier(SMTIdentifier(sym, Seq()), None) => Some(sym) case _ => None @@ -132,9 +133,7 @@ trait SMTLIBTarget extends Interruptible { } import scala.language.implicitConversions - protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = { - QualifiedIdentifier(SMTIdentifier(s)) - } + protected implicit def symbolToQualifiedId(s: SSymbol): QualifiedIdentifier = SimpleSymbol(s) protected val adtManager = new ADTManager(context) @@ -666,65 +665,78 @@ trait SMTLIBTarget extends Interruptible { case (SNumeral(n), Some(ft @ FunctionType(from, to))) => val dynLambda = lambdas.toB(ft) - letDefs.get(dynLambda) match { - case Some(DefineFun(SMTFunDef(a, SortedVar(dispatcher, dkind) +: args, rkind, body))) => - - object EQ { - def unapply(t: Term): Option[(Term, Term)] = t match { - case Core.Equals(e1, e2) => Some((e1, e2)) - case FunctionApplication(f, Seq(e1, e2)) if f.toString == "=" => Some((e1, e2)) - case _ => None - } - } + val DefineFun(SMTFunDef(a, SortedVar(dispatcher, dkind) +: args, rkind, body)) = letDefs(dynLambda) - object Num { - def unapply(t: Term): Option[BigInt] = t match { - case SNumeral(n) => Some(n) - case FunctionApplication(f, Seq(SNumeral(n))) if f.toString == "-" => Some(-n) - case _ => None - } - } + object EQ { + def unapply(t: Term): Option[(Term, Term)] = t match { + case Core.Equals(e1, e2) => Some((e1, e2)) + case FunctionApplication(f, Seq(e1, e2)) if f.toString == "=" => Some((e1, e2)) + case _ => None + } + } - val d = symbolToQualifiedId(dispatcher) - def dispatch(t: Term): Term = t match { - case Core.ITE(EQ(di, Num(ni)), thenn, elze) if di == d => - if (ni == n) thenn else dispatch(elze) - case Core.ITE(Core.And(EQ(di, Num(ni)), _), thenn, elze) if di == d => - if (ni == n) thenn else dispatch(elze) - case _ => t - } + object AND { + def unapply(t: Term): Option[Seq[Term]] = t match { + case Core.And(e1, e2) => Some(Seq(e1, e2)) + case FunctionApplication(SimpleSymbol(SSymbol("and")), args) => Some(args) + case _ => None + } + def apply(ts: Seq[Term]): Term = ts match { + case Seq() => throw new IllegalArgumentException + case Seq(t) => t + case _ => FunctionApplication(SimpleSymbol(SSymbol("and")), ts) + } + } - def extract(t: Term): Expr = { - def recCond(term: Term, index: Int): Seq[Expr] = term match { - case Core.And(e1, e2) => - val e1s = recCond(e1, index) - e1s ++ recCond(e2, index + e1s.size) - case EQ(e1, e2) => - recCond(e2, index) - case _ => Seq(fromSMT(term, from(index))) - } + object Num { + def unapply(t: Term): Option[BigInt] = t match { + case SNumeral(n) => Some(n) + case FunctionApplication(f, Seq(SNumeral(n))) if f.toString == "-" => Some(-n) + case _ => None + } + } - def recCases(term: Term, matchers: Seq[Expr]): Seq[(Seq[Expr], Expr)] = term match { - case Core.ITE(cond, thenn, elze) => - val cs = recCond(cond, matchers.size) - recCases(thenn, matchers ++ cs) ++ recCases(elze, matchers) - case _ => Seq(matchers -> fromSMT(term, to)) - } + val d = symbolToQualifiedId(dispatcher) + def dispatch(t: Term): Term = t match { + case Core.ITE(EQ(di, Num(ni)), thenn, elze) if di == d => + if (ni == n) thenn else dispatch(elze) + case Core.ITE(AND(EQ(di, Num(ni)) +: rest), thenn, elze) if di == d => + if (ni == n) Core.ITE(AND(rest), thenn, dispatch(elze)) else dispatch(elze) + case _ => t + } - val cases = recCases(t, Seq.empty) - val (default, rest) = cases.partition(_._1.isEmpty) - - assert(default.size == 1 && rest.forall(_._1.size == from.size)) - PartialLambda(rest, Some(default.head._2), ft) - } + def extract(t: Term): Expr = { + def recCond(term: Term, index: Int): Seq[Expr] = term match { + case AND(es) => + es.foldLeft(Seq.empty[Expr]) { case (seq, e) => seq ++ recCond(e, index + seq.size) } + case EQ(e1, e2) => + recCond(e2, index) + case _ => Seq(fromSMT(term, from(index))) + } - val lambdaTerm = dispatch(body) - val lambda = extract(lambdaTerm) - lambda + def recCases(term: Term, matchers: Seq[Expr]): Seq[(Seq[Expr], Expr)] = term match { + case Core.ITE(cond, thenn, elze) => + val cs = recCond(cond, matchers.size) + recCases(thenn, matchers ++ cs) ++ recCases(elze, matchers) + case AND(es) if to == BooleanType => + Seq((matchers ++ recCond(term, matchers.size)) -> BooleanLiteral(true)) + case EQ(e1, e2) if to == BooleanType => + Seq((matchers ++ recCond(term, matchers.size)) -> BooleanLiteral(true)) + case _ => Seq(matchers -> fromSMT(term, to)) + } - case None => unsupported(InfiniteIntegerLiteral(n), "Unknown function ref") + val cases = recCases(t, Seq.empty) + val (default, rest) = cases.partition(_._1.isEmpty) + val leonDefault = if (default.isEmpty && to == BooleanType) BooleanLiteral(false) else default.head._2 + + assert(rest.forall(_._1.size == from.size)) + PartialLambda(rest, Some(leonDefault), ft) } + val lambdaTerm = dispatch(body) + val lambda = extract(lambdaTerm) + lambda + case (SNumeral(n), Some(RealType)) => FractionalLiteral(n, 1) diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala index eac8f47f9..334f9f0ad 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBZ3Target.scala @@ -76,6 +76,10 @@ trait SMTLIBZ3Target extends SMTLIBTarget { val n = s.name.split("!").toList.last GenericValue(tp, n.toInt) + // XXX: (NV) Z3 doesn't seem to produce models for uninterpreted functions that + // don't impact satisfiability... + case (SNumeral(n), Some(ft: FunctionType)) if !letDefs.isDefinedAt(lambdas.toB(ft)) => + purescala.ExprOps.simplestValue(ft) case (QualifiedIdentifier(ExtendedIdentifier(SSymbol("as-array"), k: SSymbol), _), Some(tpe)) => if (letDefs contains k) { diff --git a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala index ca2c0e714..d7a3fd8fa 100644 --- a/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala +++ b/src/main/scala/leon/solvers/z3/AbstractZ3Solver.scala @@ -257,7 +257,7 @@ trait AbstractZ3Solver extends Solver { case ft @ FunctionType(from, to) => sorts.cachedB(ft) { - val symbol = z3.mkFreshStringSymbol("fun") + val symbol = z3.mkFreshStringSymbol(ft.toString) z3.mkUninterpretedSort(symbol) } @@ -494,7 +494,7 @@ trait AbstractZ3Solver extends Solver { z3.mkApp(functionDefToDecl(tfd), args.map(rec): _*) case fa @ Application(caller, args) => - val ft @ FunctionType(froms, to) = bestRealType(caller.getType) + val ft @ FunctionType(froms, to) = normalizeType(caller.getType) val funDecl = lambdas.cachedB(ft) { val sortSeq = (ft +: froms).map(tpe => typeToSort(tpe)) val returnSort = typeToSort(to) @@ -569,25 +569,8 @@ trait AbstractZ3Solver extends Solver { def rec(t: Z3AST, tpe: TypeTree): Expr = { val kind = z3.getASTKind(t) - (kind, tpe) match { - case (Z3NumeralIntAST(Some(v)), ft @ FunctionType(fts, tt)) => lambdas.getB(ft) match { - case None => throw new IllegalArgumentException - case Some(decl) => model.getModelFuncInterpretations.find(_._1 == decl) match { - case None => throw new IllegalArgumentException - case Some((_, mapping, elseValue)) => - val lambdaID = InfiniteIntegerLiteral(v) - val leonElseValue = rec(elseValue, tt) - PartialLambda(mapping.flatMap { case (z3Args, z3Result) => - z3.getASTKind(z3Args.head) match { - case Z3NumeralIntAST(Some(v)) if InfiniteIntegerLiteral(v) == lambdaID => - List((z3Args.tail zip fts).map(p => rec(p._1, p._2)) -> rec(z3Result, tt)) - case _ => Nil - } - }, Some(leonElseValue), ft) - } - } - - case (Z3NumeralIntAST(Some(v)), _) => + kind match { + case Z3NumeralIntAST(Some(v)) => val leading = t.toString.substring(0, 2 min t.toString.length) if(leading == "#x") { _root_.smtlib.common.Hexadecimal.fromString(t.toString.substring(2)) match { @@ -604,7 +587,7 @@ trait AbstractZ3Solver extends Solver { InfiniteIntegerLiteral(v) } - case (Z3NumeralIntAST(None), _) => + case Z3NumeralIntAST(None) => _root_.smtlib.common.Hexadecimal.fromString(t.toString.substring(2)) match { case Some(hexa) => tpe match { @@ -615,9 +598,9 @@ trait AbstractZ3Solver extends Solver { case None => unsound(t, "could not translate Z3NumeralIntAST numeral") } - case (Z3NumeralRealAST(n: BigInt, d: BigInt), _) => FractionalLiteral(n, d) + case Z3NumeralRealAST(n: BigInt, d: BigInt) => FractionalLiteral(n, d) - case (Z3AppAST(decl, args), _) => + case Z3AppAST(decl, args) => val argsSize = args.size if(argsSize == 0 && (variables containsB t)) { variables.toA(t) @@ -673,6 +656,22 @@ trait AbstractZ3Solver extends Solver { case None => unsound(t, "invalid array AST") } + case ft @ FunctionType(fts, tt) => lambdas.getB(ft) match { + case None => throw new IllegalArgumentException + case Some(decl) => model.getModelFuncInterpretations.find(_._1 == decl) match { + case None => throw new IllegalArgumentException + case Some((_, mapping, elseValue)) => + val leonElseValue = rec(elseValue, tt) + PartialLambda(mapping.flatMap { case (z3Args, z3Result) => + if (t == z3Args.head) { + List((z3Args.tail zip fts).map(p => rec(p._1, p._2)) -> rec(z3Result, tt)) + } else { + Nil + } + }, Some(leonElseValue), ft) + } + } + case tp: TypeParameter => val id = t.toString.split("!").last.toInt GenericValue(tp, id) @@ -695,13 +694,6 @@ trait AbstractZ3Solver extends Solver { FiniteMap(elems, from, to) } - case ft @ FunctionType(fts, tt) => - rec(t, RawArrayType(tupleTypeWrap(fts), tt)) match { - case r: RawArrayValue => - val elems = r.elems.toSeq.map { case (k, v) => unwrapTuple(k, fts.size) -> v } - PartialLambda(elems, Some(r.default), ft) - } - case tpe @ SetType(dt) => model.getSetValue(t) match { case None => unsound(t, "invalid set AST") diff --git a/src/main/scala/leon/synthesis/rules/ADTDual.scala b/src/main/scala/leon/synthesis/rules/ADTDual.scala index 392670edf..004a88d04 100644 --- a/src/main/scala/leon/synthesis/rules/ADTDual.scala +++ b/src/main/scala/leon/synthesis/rules/ADTDual.scala @@ -18,10 +18,10 @@ case object ADTDual extends NormalizingRule("ADTDual") { val (toRemove, toAdd) = exprs.collect { case eq @ Equals(cc @ CaseClass(ct, args), e) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, IsInstanceOf(e, ct) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) + (eq, IsInstanceOf(e, ct) +: (ct.classDef.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) case eq @ Equals(e, cc @ CaseClass(ct, args)) if (variablesOf(e) -- as).isEmpty && (variablesOf(cc) & xs).nonEmpty => - (eq, IsInstanceOf(e, ct) +: (ct.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) + (eq, IsInstanceOf(e, ct) +: (ct.classDef.fields zip args).map{ case (vd, ex) => Equals(ex, caseClassSelector(ct, e, vd.id)) } ) }.unzip if (toRemove.nonEmpty) { diff --git a/src/main/scala/leon/synthesis/rules/ADTSplit.scala b/src/main/scala/leon/synthesis/rules/ADTSplit.scala index 7ed086087..32848ae1c 100644 --- a/src/main/scala/leon/synthesis/rules/ADTSplit.scala +++ b/src/main/scala/leon/synthesis/rules/ADTSplit.scala @@ -94,7 +94,7 @@ case object ADTSplit extends Rule("ADT Split.") { val cases = for ((sol, (cct, problem, pattern)) <- sols zip subInfo) yield { if (sol.pre != BooleanLiteral(true)) { - val substs = (for ((field,arg) <- cct.fields zip problem.as ) yield { + val substs = (for ((field,arg) <- cct.classDef.fields zip problem.as ) yield { (arg, caseClassSelector(cct, id.toVariable, field.id)) }).toMap globalPre ::= and(IsInstanceOf(Variable(id), cct), replaceFromIDs(substs, sol.pre)) diff --git a/src/main/scala/leon/verification/InductionTactic.scala b/src/main/scala/leon/verification/InductionTactic.scala index 65f96a090..dd437c224 100644 --- a/src/main/scala/leon/verification/InductionTactic.scala +++ b/src/main/scala/leon/verification/InductionTactic.scala @@ -21,7 +21,7 @@ class InductionTactic(vctx: VerificationContext) extends DefaultTactic(vctx) { } private def selectorsOfParentType(parentType: ClassType, cct: CaseClassType, expr: Expr): Seq[Expr] = { - val childrenOfSameType = cct.fields.filter(_.getType == parentType) + val childrenOfSameType = (cct.classDef.fields zip cct.fieldsTypes).collect { case (vd, tpe) if tpe == parentType => vd } for (field <- childrenOfSameType) yield { caseClassSelector(cct, expr, field.id) } -- GitLab