diff --git a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala index c1b622242e9365a8302d4a0aaca6d3125ba7a137..b5b81a3e37d0aa905a80ff2bfd457d1757e8aba1 100644 --- a/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/leon/solvers/smtlib/SMTLIBTarget.scala @@ -624,194 +624,174 @@ trait SMTLIBTarget extends Interruptible { /* Translate an SMTLIB term back to a Leon Expr */ protected def fromSMT(pair: (Term, TypeTree))(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - fromSMT(pair._1, pair._2) + fromSMT(pair._1, Some(pair._2)) } - protected def fromUntypedSMT(t: Term)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = t match { - case SimpleSymbol(s) if constructors.containsB(s) => - constructors.toA(s) match { - case cct: CaseClassType => - CaseClass(cct, Nil) - case t => - unsupported(t, "woot? for a single constructor for non-case-object") - } + protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { + fromSMT(s, Some(tpe)) + } - case SimpleSymbol(s) if lets contains s => - fromUntypedSMT(lets(s)) + protected def fromSMT(t: Term, otpe: Option[TypeTree] = None) + (implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = { - case SimpleSymbol(s) => - variables.getA(s).map(_.toVariable).getOrElse { - reporter.fatalError("Unknown symbol: "+s) - } - case _ => - reporter.fatalError("Unhandled case in fromUntypedSMT: " + t) + // Use as much information as there is, if there is an expected type, great, but it might not always be there + (t, otpe) match { + case (_, Some(UnitType)) => + UnitLiteral() - } + case (FixedSizeBitVectors.BitVectorConstant(n, b), Some(CharType)) if b == BigInt(32) => + CharLiteral(n.toInt.toChar) - protected def fromSMT(s: Term, tpe: TypeTree)(implicit lets: Map[SSymbol, Term], letDefs: Map[SSymbol, DefineFun]): Expr = (s, tpe) match { - case (_, UnitType) => - UnitLiteral() + case (FixedSizeBitVectors.BitVectorConstant(n, b), Some(Int32Type)) if b == BigInt(32) => + IntLiteral(n.toInt) - case (FixedSizeBitVectors.BitVectorConstant(n, b), CharType) if b == BigInt(32) => - CharLiteral(n.toInt.toChar) + case (SHexadecimal(h), Some(CharType)) => + CharLiteral(h.toInt.toChar) - case (SHexadecimal(h), CharType) => - CharLiteral(h.toInt.toChar) + case (SHexadecimal(hexa), Some(Int32Type)) => + IntLiteral(hexa.toInt) - case (SNumeral(n), IntegerType) => - InfiniteIntegerLiteral(n) + case (SDecimal(d), Some(RealType)) => + RealLiteral(d) - case (SDecimal(d), RealType) => - RealLiteral(d) + case (SNumeral(n), Some(RealType)) => + RealLiteral(BigDecimal(n)) - case (SNumeral(n), RealType) => - RealLiteral(BigDecimal(n)) + case (FunctionApplication(SimpleSymbol(SSymbol("ite")), Seq(cond, thenn, elze)), t) => + IfExpr( + fromSMT(cond, Some(BooleanType)), + fromSMT(thenn, t), + fromSMT(elze, t) + ) - case (Core.True(), BooleanType) => BooleanLiteral(true) - case (Core.False(), BooleanType) => BooleanLiteral(false) + // Best-effort case + case (SNumeral(n), _) => + InfiniteIntegerLiteral(n) + + // EK: Since we have no type information, we cannot do type-directed + // extraction of defs, instead, we expand them in smt-world + case (SMTLet(binding, bindings, body), tpe) => + val defsMap: Map[SSymbol, Term] = (binding +: bindings).map { + case VarBinding(s, value) => (s, value) + }.toMap + + fromSMT(body, tpe)(lets ++ defsMap, letDefs) + case (SimpleSymbol(s), _) if constructors.containsB(s) => + constructors.toA(s) match { + case cct: CaseClassType => + CaseClass(cct, Nil) + case t => + unsupported(t, "woot? for a single constructor for non-case-object") + } - case (FixedSizeBitVectors.BitVectorConstant(n, b), Int32Type) if b == BigInt(32) => IntLiteral(n.toInt) - case (SHexadecimal(hexa), Int32Type) => IntLiteral(hexa.toInt) + case (FunctionApplication(SimpleSymbol(s), List(e)), _) if testers.containsB(s) => + testers.toA(s) match { + case cct: CaseClassType => + IsInstanceOf(fromSMT(e, cct), cct) + } - case (SimpleSymbol(s), _: ClassType) if constructors.containsB(s) => - constructors.toA(s) match { - case cct: CaseClassType => - CaseClass(cct, Nil) - case t => - unsupported(t, "woot? for a single constructor for non-case-object") - } + case (FunctionApplication(SimpleSymbol(s), List(e)), _) if selectors.containsB(s) => + selectors.toA(s) match { + case (cct: CaseClassType, i) => + CaseClassSelector(cct, fromSMT(e, cct), cct.fields(i).id) + } - case (SimpleSymbol(s), tpe) if lets contains s => - fromSMT(lets(s), tpe) + case (FunctionApplication(SimpleSymbol(s), args), _) if constructors.containsB(s) => + constructors.toA(s) match { + case cct: CaseClassType => + val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT) + CaseClass(cct, rargs) + case tt: TupleType => + val rargs = args.zip(tt.bases).map(fromSMT) + tupleWrap(rargs) + + case ArrayType(baseType) => + val IntLiteral(size) = fromSMT(args(0), Int32Type) + val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, baseType)) + + if(size > 10) { + val definedElements = elems.collect { + case (IntLiteral(i), value) => (i, value) + } + finiteArray(definedElements, Some(default, IntLiteral(size)), baseType) + + } else { + val entries = for (i <- 0 to size-1) yield elems.getOrElse(IntLiteral(i), default) + + finiteArray(entries, None, baseType) + } - case (SimpleSymbol(s), _) => - variables.getA(s).map(_.toVariable).getOrElse { - reporter.fatalError("Unknown symbol: "+s) - } + case t => + unsupported(t, "Woot? structural type that is non-structural") + } - case (FunctionApplication(SimpleSymbol(SSymbol("ite")), Seq(cond, thenn, elze)), t) => - IfExpr( - fromSMT(cond, BooleanType), - fromSMT(thenn, t), - fromSMT(elze, t) - ) + case (FunctionApplication(SimpleSymbol(s @ SSymbol(app)), args), _) => + (app, args) match { + case (">=", List(a, b)) => + GreaterEquals(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - case (FunctionApplication(SimpleSymbol(s), args), tpe) if constructors.containsB(s) => - constructors.toA(s) match { - case cct: CaseClassType => - val rargs = args.zip(cct.fields.map(_.getType)).map(fromSMT) - CaseClass(cct, rargs) - case tt: TupleType => - val rargs = args.zip(tt.bases).map(fromSMT) - tupleWrap(rargs) - - case ArrayType(baseType) => - val IntLiteral(size) = fromSMT(args(0), Int32Type) - val RawArrayValue(_, elems, default) = fromSMT(args(1), RawArrayType(Int32Type, baseType)) - - if(size > 10) { - val definedElements = elems.collect { - case (IntLiteral(i), value) => (i, value) - } - finiteArray(definedElements, Some(default, IntLiteral(size)), baseType) + case ("<=", List(a, b)) => + LessEquals(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - } else { - val entries = for (i <- 0 to size-1) yield elems.getOrElse(IntLiteral(i), default) + case (">", List(a, b)) => + GreaterThan(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - finiteArray(entries, None, baseType) - } + case (">", List(a, b)) => + LessThan(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - case t => - unsupported(t, "Woot? structural type that is non-structural") - } + case ("+", args) => + args.map(fromSMT(_, IntegerType)).reduceLeft(plus _) - // EK: Since we have no type information, we cannot do type-directed - // extraction of defs, instead, we expand them in smt-world - case (SMTLet(binding, bindings, body), tpe) => - val defsMap: Map[SSymbol, Term] = (binding +: bindings).map { - case VarBinding(s, value) => (s, value) - }.toMap + case ("-", List(a)) => + UMinus(fromSMT(a, IntegerType)) - fromSMT(body, tpe)(lets ++ defsMap, letDefs) + case ("-", List(a, b)) => + Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - case (FunctionApplication(SimpleSymbol(SSymbol(app)), args), tpe) => { - app match { - case ">=" => - (args, tpe) match { - case (List(a, b), BooleanType) => GreaterEquals(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - } + case ("*", args) => + args.map(fromSMT(_, IntegerType)).reduceLeft(times _) - case "<=" => - (args, tpe) match { - case (List(a, b), BooleanType) => LessEquals(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - } + case ("/", List(a, b)) => + Division(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - case ">" => - (args, tpe) match { - case (List(a, b), BooleanType) => GreaterThan(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - } + case ("div", List(a, b)) => + Division(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - case "<" => - (args, tpe) match { - case (List(a, b), BooleanType) => LessThan(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - } + case ("not", List(a)) => + Not(fromSMT(a, BooleanType)) - case "+" => - (args, tpe) match { - case (List(a, b), IntegerType) => Plus(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - case (List(a, b), RealType) => RealPlus(fromSMT(a, RealType), fromSMT(b, RealType)) - } + case ("or", args) => + orJoin(args.map(fromSMT(_, BooleanType))) - case "not" => - (args, tpe) match { - case (List(a), BooleanType) => Not(fromSMT(a, BooleanType)) - } + case ("and", args) => + andJoin(args.map(fromSMT(_, BooleanType))) - case "or" => - (args, tpe) match { - case (List(a, b), BooleanType) => Or(fromSMT(a, BooleanType), fromSMT(b, BooleanType)) - } + case ("=", List(a, b)) => + val ra = fromSMT(a, None) + Equals(ra, fromSMT(b, ra.getType)) - case "and" => - (args, tpe) match { - case (List(a, b), BooleanType) => And(fromSMT(a, BooleanType), fromSMT(b, BooleanType)) - } + case _ => + reporter.fatalError("Function "+app+" not handled in fromSMT: "+s) + } - case "=" => - (args, tpe) match { - case (List(a, b), BooleanType) => - val ra = fromUntypedSMT(a) - Equals(ra, fromSMT(b, ra.getType)) - } + case (SimpleSymbol(s), otpe) if lets contains s => + fromSMT(lets(s), otpe) - case "*" => - (args, tpe) match { - case (List(a, b), IntegerType) => Times(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - case (List(a, b), RealType) => RealTimes(fromSMT(a, RealType), fromSMT(b, RealType)) - } + case (SimpleSymbol(s), otpe) => + variables.getA(s).map(_.toVariable).getOrElse { + reporter.fatalError("Unknown symbol: "+s) + } - case "-" => - (args, tpe) match { - case (List(a), IntegerType) => UMinus(fromSMT(a, IntegerType)) - case (List(a, b), IntegerType) => Minus(fromSMT(a, IntegerType), fromSMT(b, IntegerType)) - case (List(a), RealType) => RealUMinus(fromSMT(a, RealType)) - case (List(a, b), RealType) => RealMinus(fromSMT(a, RealType), fromSMT(b, RealType)) - } - case "/" => - (args, tpe) match { - case (List(a, b), RealType) => RealDivision(fromSMT(a, RealType), fromSMT(b, RealType)) - } + case (Core.True(), Some(BooleanType)) => BooleanLiteral(true) + case (Core.False(), Some(BooleanType)) => BooleanLiteral(false) + + case _ => + reporter.fatalError("Unhandled case in fromSMT: " + t+" (_ :"+otpe+")") - case _ => - reporter.fatalError("Function "+app+" not handled in fromSMT: "+s) - } } - case (QualifiedIdentifier(id, sort), tpe) => - reporter.fatalError("Unhandled case in fromSMT: " + id +": "+sort +" ("+tpe+")") - case _ => - reporter.fatalError("Unhandled case in fromSMT: " + (s, tpe)) } + } // Unique numbers diff --git a/src/main/scala/leon/solvers/sygus/SygusSolver.scala b/src/main/scala/leon/solvers/sygus/SygusSolver.scala index 952745b338a77f5cc5db9936da368808e2b5fe03..3cd370fefad7167cb7ba06fb3fb7a37800f2b655 100644 --- a/src/main/scala/leon/solvers/sygus/SygusSolver.scala +++ b/src/main/scala/leon/solvers/sygus/SygusSolver.scala @@ -25,6 +25,7 @@ import _root_.smtlib.common._ import _root_.smtlib.parser.Terms.{Identifier => SMTIdentifier, _} import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.parser.CommandsResponses.{Error => _, _} +import _root_.smtlib.parser.Parser.UnexpectedEOFException abstract class SygusSolver(val context: LeonContext, val program: Program, val p: Problem) extends SMTLIBTarget { implicit val ctx = context @@ -68,45 +69,48 @@ abstract class SygusSolver(val context: LeonContext, val program: Program, val p val synthPhi = replaceFromIDs(xToFdCall, p.phi) - val TopLevelAnds(clauses) = synthPhi + val constraint = implies(p.pc, synthPhi) - for(c <- clauses) { - emit(FunctionApplication(constraintId, Seq(toSMT(c)(bindings)))) - } + emit(FunctionApplication(constraintId, Seq(toSMT(constraint)(bindings)))) emit(SList(SSymbol("check-synth"))) // check-synth emits: success; unsat; fdef* // We currently cannot predict the amount of success we will get, so we read as many as possible - var lastRes = interpreter.parser.parseSExpr - while(lastRes == SSymbol("success")) { - lastRes = interpreter.parser.parseSExpr - } - - lastRes match { - case SSymbol("unsat") => - - val solutions = (for (x <- p.xs) yield { - interpreter.parser.parseCommand match { - case DefineFun(SMTFunDef(name, params, retSort, body)) => - val res = fromSMT(body, sorts.toA(retSort))(Map(), Map()) - Some(res) - case r => - reporter.warning("Unnexpected result from cvc4-sygus: "+r) - None + try { + var lastRes = interpreter.parser.parseSExpr + while(lastRes == SSymbol("success")) { + lastRes = interpreter.parser.parseSExpr + } + + lastRes match { + case SSymbol("unsat") => + + val solutions = (for (x <- p.xs) yield { + interpreter.parser.parseCommand match { + case DefineFun(SMTFunDef(name, params, retSort, body)) => + val res = fromSMT(body, sorts.toA(retSort))(Map(), Map()) + Some(res) + case r => + reporter.warning("Unnexpected result from cvc4-sygus: "+r) + None + } + }).flatten + + if (solutions.size == p.xs.size) { + Some(tupleWrap(solutions)) + } else { + None } - }).flatten - if (solutions.size == p.xs.size) { - Some(tupleWrap(solutions)) - } else { + case SSymbol("unknown") => None - } - case SSymbol("unknown") => - None - - case r => - reporter.warning("Unnexpected result from cvc4-sygus: "+r+" expected unsat") + case r => + reporter.warning("Unnexpected result from cvc4-sygus: "+r+" expected unsat") + None + } + } catch { + case _: UnexpectedEOFException => None } } diff --git a/src/main/scala/leon/synthesis/rules/SygusCVC4.scala b/src/main/scala/leon/synthesis/rules/SygusCVC4.scala index 2e8c40dd393cca442dd9758172bbd7aa2bade23a..7a0e7edf495ef20597e4c9cde1939b5778ebcbbb 100644 --- a/src/main/scala/leon/synthesis/rules/SygusCVC4.scala +++ b/src/main/scala/leon/synthesis/rules/SygusCVC4.scala @@ -22,7 +22,7 @@ case object SygusCVC4 extends Rule("SygusCVC4") { s.checkSynth() match { case Some(expr) => - RuleClosed(Solution.term(expr)) + RuleClosed(Solution.term(expr, isTrusted = false)) case None => RuleFailed() }