diff --git a/src/main/scala/inox/tip/Parser.scala b/src/main/scala/inox/tip/Parser.scala index 2b612203654e136cbb362443e7b64a21b4f10584..aebbbbc90e59148b7900ab96b59c139c4d522913 100644 --- a/src/main/scala/inox/tip/Parser.scala +++ b/src/main/scala/inox/tip/Parser.scala @@ -5,15 +5,17 @@ package tip import utils._ -import smtlib.lexer._ +import smtlib.lexer.{Tokens => LT, _} import smtlib.parser.Commands.{FunDef => SMTFunDef, _} import smtlib.parser.Terms.{Let => SMTLet, Forall => SMTForall, Identifier => SMTIdentifier, _} import smtlib.theories._ import smtlib.theories.experimental._ -import smtlib.extensions.tip.{Parser => SMTParser, Lexer => SMTLexer} import smtlib.extensions.tip.Terms.{Lambda => SMTLambda, Application => SMTApplication, _} import smtlib.extensions.tip.Commands._ +import Terms.{Assume => SMTAssume} +import Commands._ + import scala.collection.BitSet import java.io.{Reader, File, BufferedReader, FileReader} @@ -32,7 +34,7 @@ class Parser(file: File) { } def parseScript: (Symbols, Expr) = { - val parser = new SMTParser(new SMTLexer(positions.reader)) + val parser = new TipParser(new TipLexer(positions.reader)) val script = parser.parseScript var assertions: Seq[Expr] = Seq.empty @@ -52,7 +54,7 @@ class Parser(file: File) { funs: Map[SSymbol, Identifier], adts: Map[SSymbol, Identifier], selectors: Map[SSymbol, Identifier], - vars: Map[SSymbol, Variable], + vars: Map[SSymbol, Expr], tps: Map[SSymbol, TypeParameter], val symbols: Symbols) { @@ -84,12 +86,12 @@ class Parser(file: File) { new Locals(funs, adts, selectors, vars, tps ++ seq, symbols) def isVariable(sym: SSymbol): Boolean = vars.isDefinedAt(sym) - def getVariable(sym: SSymbol): Variable = vars.get(sym).getOrElse { + def getVariable(sym: SSymbol): Expr = vars.get(sym).getOrElse { throw new MissformedTIPException("unknown variable " + sym, sym.optPos) } - def withVariable(sym: SSymbol, v: Variable): Locals = withVariables(Seq(sym -> v)) - def withVariables(seq: Seq[(SSymbol, Variable)]): Locals = + def withVariable(sym: SSymbol, v: Expr): Locals = withVariables(Seq(sym -> v)) + def withVariables(seq: Seq[(SSymbol, Expr)]): Locals = new Locals(funs, adts, selectors, vars ++ seq, tps, symbols) def isFunction(sym: SSymbol): Boolean = funs.isDefinedAt(sym) @@ -105,11 +107,21 @@ class Parser(file: File) { def registerADT(adt: ADTDefinition): Locals = registerADTs(Seq(adt)) def registerADTs(defs: Seq[ADTDefinition]): Locals = new Locals(funs, adts, selectors, vars, tps, symbols.withADTs(defs)) + + def withSymbols(symbols: Symbols) = new Locals(funs, adts, selectors, vars, tps, symbols) } protected val NoLocals: Locals = new Locals( Map.empty, Map.empty, Map.empty, Map.empty, Map.empty, NoSymbols) + protected object DatatypeInvariantExtractor { + def unapply(cmd: Command): Option[(Seq[SSymbol], SSymbol, Sort, Term)] = cmd match { + case DatatypeInvariantPar(syms, s, sort, pred) => Some((syms, s, sort, pred)) + case DatatypeInvariant(s, sort, pred) => Some((Seq.empty, s, sort, pred)) + case _ => None + } + } + protected def extractCommand(cmd: Command) (implicit locals: Locals): (Option[Expr], Locals) = cmd match { case Assert(term) => @@ -233,11 +245,57 @@ class Parser(file: File) { } val field = ValDef(FreshIdentifier("val"), IntegerType).setPos(sym.optPos) - new ADTConstructor(id, tparams, None, Seq(field), Set.empty) + new ADTConstructor(id, tparams, None, Seq(field), Set.empty).setPos(sym.optPos) }) + case DatatypeInvariantExtractor(syms, s, sort, pred) => + val tps = syms.map(s => TypeParameter.fresh(s.name).setPos(s.optPos)) + val adt = extractSort(sort)(locals.withGenerics(syms zip tps)) match { + case adt @ ADTType(id, typeArgs) if tps == typeArgs => adt.getADT(locals.symbols).definition + case _ => throw new MissformedTIPException(s"Unexpected type parameters $syms", sort.optPos) + } + + val root = adt.root(locals.symbols) + val rootType = root.typed(locals.symbols).toType + val vd = ValDef(FreshIdentifier(s.name), rootType).setPos(s.optPos) + + val body = if (root != adt) { + val adtType = adt.typed(root.typeArgs)(locals.symbols).toType + Implies( + IsInstanceOf(vd.toVariable, adtType).setPos(pred.optPos), + extractTerm(pred)( + locals.withGenerics(syms zip root.typeArgs) + .withVariable(s, AsInstanceOf(vd.toVariable, adtType).setPos(s.optPos)) + ) + ).setPos(pred.optPos) + } else { + extractTerm(pred)(locals.withVariable(s, vd.toVariable)) + } + + val (optAdt, fd) = root.invariant(locals.symbols) match { + case Some(fd) => + val Seq(v) = fd.params + val fullBody = locals.symbols.and( + fd.fullBody, + exprOps.replaceFromSymbols(Map(v.toVariable -> vd.toVariable), body).setPos(body) + ).setPos(body) + (None, fd.copy(fullBody = fullBody)) + + case None => + val id = FreshIdentifier("inv$" + root.id.name) + val newAdt = root match { + case sort: ADTSort => sort.copy(flags = sort.flags + HasADTInvariant(id)) + case cons: ADTConstructor => cons.copy(flags = cons.flags + HasADTInvariant(id)) + } + val fd = new FunDef(id, root.tparams, Seq(vd), BooleanType, body, Set.empty).setPos(s.optPos) + (Some(newAdt), fd) + } + + (None, locals.withSymbols( + locals.symbols.withFunctions(Seq(fd)).withADTs(optAdt.toSeq))) + case CheckSat() => - // TODO: what do I do with this?? + // FIXME: what do I do with this?? (None, locals) case _ => @@ -335,7 +393,7 @@ class Parser(file: File) { case QualifiedIdentifier(SimpleIdentifier(sym), None) if locals.isVariable(sym) => locals.getVariable(sym) - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("assume")), None), Seq(pred, body)) => + case SMTAssume(pred, body) => Assume(extractTerm(pred), extractTerm(body)) case SMTLet(binding, bindings, term) => diff --git a/src/main/scala/inox/tip/Printer.scala b/src/main/scala/inox/tip/Printer.scala index e33a078f45d11f40b3e9717dca0ce41b77ee7083..6778349a5e2f0c09350f41d9596c31b03bb0e9b1 100644 --- a/src/main/scala/inox/tip/Printer.scala +++ b/src/main/scala/inox/tip/Printer.scala @@ -9,6 +9,9 @@ import smtlib.extensions.tip.Terms.{Lambda => SMTLambda, Application => SMTAppli import smtlib.extensions.tip.Commands._ import smtlib.Interpreter +import Terms.{Assume => SMTAssume} +import Commands._ + import java.io.Writer import scala.collection.mutable.{Map => MutableMap} @@ -109,6 +112,30 @@ class Printer(val program: InoxProgram, writer: Writer) extends solvers.smtlib.S }).toList }).toList )) + + val invariants = adts + .collect { case (adt: ADTType, _) => adt } + .map(_.getADT.definition) + .flatMap(_.invariant) + + for (fd <- invariants) { + val Seq(vd) = fd.params + if (fd.tparams.isEmpty) { + emit(DatatypeInvariant( + id2sym(vd.id), + declareSort(vd.tpe), + toSMT(fd.fullBody)(Map(vd.id -> id2sym(vd.id))) + )) + } else { + val tps = fd.tparams.map(tpd => declareSort(tpd.tp).id.symbol) + emit(DatatypeInvariantPar( + tps, + id2sym(vd.id), + declareSort(vd.tpe), + toSMT(fd.fullBody)(Map(vd.id -> id2sym(vd.id))) + )) + } + } } } @@ -188,10 +215,7 @@ class Printer(val program: InoxProgram, writer: Writer) extends solvers.smtlib.S SMTApplication(toSMT(caller), args.map(toSMT)) case Assume(pred, body) => - FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("assume")), None), - Seq(toSMT(pred), toSMT(body)) - ) + SMTAssume(toSMT(pred), toSMT(body)) case _ => super.toSMT(e) } diff --git a/src/main/scala/inox/tip/TipExtensions.scala b/src/main/scala/inox/tip/TipExtensions.scala new file mode 100644 index 0000000000000000000000000000000000000000..cdbdae3f3e1cbb0168ba3c1df37f8fcd72bbffef --- /dev/null +++ b/src/main/scala/inox/tip/TipExtensions.scala @@ -0,0 +1,104 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package tip + +import smtlib.printer._ +import smtlib.parser.Terms._ +import smtlib.parser.Commands._ +import smtlib.lexer.{Tokens => LT, _} +import smtlib.extensions.tip.{Parser => SMTParser, Lexer => SMTLexer} + +object Tokens { + import LT.ReservedWord + + case object Assume extends ReservedWord + case object DatatypeInvariant extends ReservedWord +} + +object Terms { + case class Assume(pred: Term, body: Term) extends TermExtension { + def print(ctx: PrintingContext): Unit = { + ctx.print("(assume ") + ctx.print(pred) + ctx.print(" ") + ctx.print(body) + ctx.print(")") + } + } +} + +object Commands { + case class DatatypeInvariant(name: SSymbol, sort: Sort, pred: Term) extends CommandExtension { + def print(ctx: PrintingContext): Unit = { + ctx.print("(datatype-invariant ") + ctx.print(name) + ctx.print(" ") + ctx.print(sort) + ctx.print(" ") + ctx.print(pred) + ctx.print(")") + } + } + + case class DatatypeInvariantPar(syms: Seq[SSymbol], name: SSymbol, sort: Sort, pred: Term) extends CommandExtension { + def print(ctx: PrintingContext): Unit = { + ctx.print("(datatype-invariant (par ") + ctx.printNary(syms, "(", " ", ") ") + ctx.print(name) + ctx.print(" ") + ctx.print(sort) + ctx.print(" ") + ctx.print(pred) + ctx.print("))") + } + } +} + +class TipLexer(reader: java.io.Reader) extends SMTLexer(reader) { + import LT.Token + + override protected def toReserved(s: String): Option[Token] = s match { + case "assume" => Some(Token(Tokens.Assume)) + case "datatype-invariant" => Some(Token(Tokens.DatatypeInvariant)) + case _ => super.toReserved(s) + } +} + +class TipParser(lexer: TipLexer) extends SMTParser(lexer) { + import Terms._ + import Commands._ + + override protected def parseTermWithoutParens: Term = getPeekToken.kind match { + case Tokens.Assume => + val pred = parseTerm + val body = parseTerm + Assume(pred, body) + + case _ => super.parseTermWithoutParens + } + + override protected def parseCommandWithoutParens: Command = getPeekToken.kind match { + case Tokens.DatatypeInvariant => + eat(Tokens.DatatypeInvariant) + getPeekToken.kind match { + case LT.OParen => + eat(LT.OParen) + eat(LT.Par) + val tps = parseMany(parseSymbol _) + val name = parseSymbol + val sort = parseSort + val pred = parseTerm + eat(LT.CParen) + DatatypeInvariantPar(tps, name, sort, pred) + + case _ => + val name = parseSymbol + val sort = parseSort + val pred = parseTerm + DatatypeInvariant(name, sort, pred) + } + + case _ => super.parseCommandWithoutParens + } +}