diff --git a/src/main/scala/inox/parsers/TIPParser.scala b/src/main/scala/inox/parsers/TIPParser.scala deleted file mode 100644 index afd40476a1705acdcaf44637681fd7906e6fe6ad..0000000000000000000000000000000000000000 --- a/src/main/scala/inox/parsers/TIPParser.scala +++ /dev/null @@ -1,714 +0,0 @@ -/* Copyright 2009-2016 EPFL, Lausanne */ - -package inox -package parsers - -import _root_.smtlib.lexer._ -import _root_.smtlib.parser.{Parser => SMTParser} -import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} -import _root_.smtlib.parser.Terms.{Let => SMTLet, Forall => SMTForall, Identifier => SMTIdentifier, _} -import _root_.smtlib.theories._ -import _root_.smtlib.theories.experimental._ - -import scala.collection.BitSet -import java.io.{Reader, File} - -import utils._ - -import scala.language.implicitConversions - -trait TIPParser { - val trees: ast.Trees - import trees._ - - def parse(file: File): (Symbols, Expr) = { - val pos = new PositionProvider(new java.io.BufferedReader(new java.io.FileReader(file)), Some(file)) - val parser = new Parser(new Lexer(pos.reader), pos) - parser.parseTIPScript - } - - def parse(reader: Reader): (Symbols, Expr) = { - val pos = new PositionProvider(reader, None) - val parser = new Parser(new Lexer(pos.reader), pos) - parser.parseTIPScript - } - - private class PositionProvider(_reader: Reader, _file: Option[File]) { - val (reader, file): (Reader, File) = _file match { - case Some(file) => (_reader, file) - case None => - val file = File.createTempFile("input", ".tip") - val writer = new java.io.BufferedWriter(new java.io.FileWriter(file)) - - val buffer = new Array[Char](1024) - var count: Int = 0 - while ((count = _reader.read(buffer)) != -1) { - writer.write(buffer, 0, count) - } - - val reader = new java.io.BufferedReader(new java.io.FileReader(file)) - (reader, file) - } - - private val fileLines: List[String] = scala.io.Source.fromFile(file).getLines.toList - - def get(line: Int, col: Int): OffsetPosition = { - val point = fileLines.take(line).map(_.length).sum + col - new OffsetPosition(line, col, point, file) - } - } - - class MissformedTIPException(reason: String, pos: Position) - extends Exception("Missfomed TIP source @" + pos + ":\n" + reason) - - protected class Parser(lex: Lexer, positions: PositionProvider) extends SMTParser(lex) { - - implicit def smtlibPositionToPosition(pos: Option[_root_.smtlib.common.Position]): Position = { - pos.map(p => positions.get(p.line, p.col)).getOrElse(NoPosition) - } - - protected class Locals ( - funs: Map[SSymbol, Identifier], - adts: Map[SSymbol, Identifier], - selectors: Map[SSymbol, Identifier], - vars: Map[SSymbol, Variable], - tps: Map[SSymbol, TypeParameter], - val symbols: Symbols) { - - def isADT(sym: SSymbol): Boolean = adts.isDefinedAt(sym) - def lookupADT(sym: SSymbol): Option[Identifier] = adts.get(sym) - def getADT(sym: SSymbol): Identifier = adts.get(sym).getOrElse { - throw new MissformedTIPException("unknown ADT " + sym, sym.optPos) - } - - def withADT(sym: SSymbol, id: Identifier): Locals = withADTs(Seq(sym -> id)) - def withADTs(seq: Seq[(SSymbol, Identifier)]): Locals = - new Locals(funs, adts ++ seq, selectors, vars, tps, symbols) - - def isADTSelector(sym: SSymbol): Boolean = selectors.isDefinedAt(sym) - def getADTSelector(sym: SSymbol): Identifier = selectors.get(sym).getOrElse { - throw new MissformedTIPException("unknown ADT selector " + sym, sym.optPos) - } - - def withADTSelectors(seq: Seq[(SSymbol, Identifier)]): Locals = - new Locals(funs, adts, selectors ++ seq, vars, tps, symbols) - - def isGeneric(sym: SSymbol): Boolean = tps.isDefinedAt(sym) - def getGeneric(sym: SSymbol): TypeParameter = tps.get(sym).getOrElse { - throw new MissformedTIPException("unknown generic type " + sym, sym.optPos) - } - - def withGeneric(sym: SSymbol, tp: TypeParameter): Locals = withGenerics(Seq(sym -> tp)) - def withGenerics(seq: Seq[(SSymbol, TypeParameter)]): Locals = - new Locals(funs, adts, selectors, vars, tps ++ seq, symbols) - - def isVariable(sym: SSymbol): Boolean = vars.isDefinedAt(sym) - def getVariable(sym: SSymbol): Variable = vars.get(sym).getOrElse { - throw new MissformedTIPException("unknown variable " + sym, sym.optPos) - } - - def withVariable(sym: SSymbol, v: Variable): Locals = withVariables(Seq(sym -> v)) - def withVariables(seq: Seq[(SSymbol, Variable)]): Locals = - new Locals(funs, adts, selectors, vars ++ seq, tps, symbols) - - def isFunction(sym: SSymbol): Boolean = funs.isDefinedAt(sym) - def getFunction(sym: SSymbol): Identifier = funs.get(sym).getOrElse { - throw new MissformedTIPException("unknown function " + sym, sym.optPos) - } - - def withFunction(sym: SSymbol, fd: FunDef): Locals = withFunctions(Seq(sym -> fd)) - def withFunctions(fds: Seq[(SSymbol, FunDef)]): Locals = - new Locals(funs ++ fds.map(p => p._1 -> p._2.id), adts, selectors, vars, tps, - symbols.withFunctions(fds.map(_._2))) - - def registerADT(adt: ADTDefinition): Locals = registerADTs(Seq(adt)) - def registerADTs(defs: Seq[ADTDefinition]): Locals = - new Locals(funs, adts, selectors, vars, tps, symbols.withADTs(defs)) - } - - protected val NoLocals: Locals = new Locals( - Map.empty, Map.empty, Map.empty, Map.empty, Map.empty, NoSymbols) - - protected def getIdentifier(sym: SSymbol): Identifier = { - // TODO: check keywords! - FreshIdentifier(sym.name) - } - - def parseTIPScript: (Symbols, Expr) = { - - var assertions: Seq[Expr] = Seq.empty - implicit var locals: Locals = NoLocals - - while (peekToken != null) { - eat(Tokens.OParen) - val (newAssertions, newLocals) = parseTIPCommand(nextToken) - assertions ++= newAssertions - locals = newLocals - eat(Tokens.CParen) - } - - val expr: Expr = locals.symbols.andJoin(assertions) - (locals.symbols, expr) - } - - protected def parseParTerm(implicit locals: Locals): Expr = getPeekToken.kind match { - case Tokens.OParen => - eat(Tokens.OParen) - getPeekToken.kind match { - case Tokens.Par => - eat(Tokens.Par) - val tps = parseMany(parseSymbol _) - val res = parseTerm - eat(Tokens.CParen) - extractExpr(res)(locals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos)))) - - case _ => - extractExpr(parseBefore(Tokens.CParen)(parseTermWithoutParens _)) - } - case _ => - extractExpr(parseTerm) - } - - protected def parseTIPCommand(token: Tokens.Token) - (implicit locals: Locals): (Option[Expr], Locals) = token match { - case Tokens.SymbolLit("assert-not") => - (Some(Not(parseParTerm)), locals) - - case Tokens.Token(Tokens.Assert) => - (Some(parseParTerm), locals) - - case Tokens.Token(Tokens.DefineFun) | Tokens.Token(Tokens.DefineFunRec) => - val isRec = token.kind == Tokens.DefineFunRec - val (tps, funDef) = getPeekToken.kind match { - case Tokens.OParen => - eat(Tokens.OParen) - eat(Tokens.Par) - val tps = parseMany(parseSymbol _) - val res = parseWithin(Tokens.OParen, Tokens.CParen)(parseFunDef _) - eat(Tokens.CParen) - (tps, res) - - case _ => - (Seq.empty[SSymbol], parseFunDef) - } - - val tpsLocals = locals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos))) - val fdsLocals = if (!isRec) tpsLocals else { - tpsLocals.withFunction(funDef.name, extractSignature(funDef, tps)(tpsLocals)) - } - val fd = extractFunction(funDef, tps)(fdsLocals) - (None, locals.withFunction(funDef.name, fd)) - - case Tokens.Token(Tokens.DefineFunsRec) => - val (funDec, funDecs) = parseOneOrMore(() => { - eat(Tokens.OParen) - val (tps, funDec) = getPeekToken.kind match { - case Tokens.Par => - eat(Tokens.Par) - val tps = parseMany(parseSymbol _) - val funDec = parseWithin(Tokens.OParen, Tokens.CParen)(parseFunDec _) - (tps -> funDec) - - case _ => - (Seq.empty[SSymbol], parseFunDec) - } - eat(Tokens.CParen) - (tps, funDec) - }) - - val (body, bodies) = parseOneOrMore(parseTerm _) - assert(funDecs.size == bodies.size) - - val funDefs = ((funDec -> body) +: (funDecs zip bodies)).map { - case ((tps, FunDec(name, params, returnSort)), body) => - tps -> SMTFunDef(name, params, returnSort, body) - } - - val bodyLocals = locals.withFunctions(for ((tps, funDef) <- funDefs) yield { - val tpsLocals = locals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos))) - funDef.name -> extractSignature(funDef, tps)(tpsLocals) - }) - - (None, locals.withFunctions(for ((tps, funDef) <- funDefs) yield { - val tpsLocals = bodyLocals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos))) - funDef.name -> extractFunction(funDef, tps)(tpsLocals) - })) - - case Tokens.Token(Tokens.DeclareDatatypes) => - val tps = parseMany(parseSymbol _) - val datatypes = parseMany(parseDatatypes _) - - var locs = locals.withADTs(datatypes - .flatMap { case (sym, conss) => - val tpeId = getIdentifier(sym) - val cids = if (conss.size == 1) { - Seq(conss.head.sym -> tpeId) - } else { - conss.map(c => c.sym -> getIdentifier(c.sym)) - } - (sym -> tpeId) +: cids - }) - - val generics = tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos)) - for ((sym, conss) <- datatypes) { - val adtLocals = locs.withGenerics(generics) - val children = for (Constructor(sym, fields) <- conss) yield { - val id = locs.getADT(sym) - val vds = fields.map { case (s, sort) => - ValDef(getIdentifier(s), extractType(sort)(adtLocals)).setPos(s.optPos) - } - - (id, vds) - } - - val allVds: Set[ValDef] = children.flatMap(_._2).toSet - val allTparams: Set[TypeParameter] = children.flatMap(_._2).toSet.flatMap { - (vd: ValDef) => locs.symbols.typeParamsOf(vd.tpe): Set[TypeParameter] - } - - val tparams: Seq[TypeParameterDef] = tps.flatMap { sym => - val tp = adtLocals.getGeneric(sym) - if (allTparams(tp)) Some(TypeParameterDef(tp).setPos(sym.optPos)) else None - } - - val parent = if (children.size > 1) { - val id = adtLocals.getADT(sym) - locs = locs.registerADT( - new ADTSort(id, tparams, children.map(_._1), Set.empty).setPos(sym.optPos)) - Some(id) - } else { - None - } - - locs = locs.registerADTs((conss zip children).map { case (cons, (cid, vds)) => - new ADTConstructor(cid, tparams, parent, vds, Set.empty).setPos(cons.sym.optPos) - }).withADTSelectors((conss zip children).flatMap { case (Constructor(_, fields), (_, vds)) => - (fields zip vds).map(p => p._1._1 -> p._2.id) - }) - } - - (None, locs) - - case Tokens.Token(Tokens.DeclareConst) => - val sym = parseSymbol - val sort = parseSort - (None, locals.withVariable(sym, - Variable(getIdentifier(sym), extractType(sort)).setPos(sym.optPos))) - - case Tokens.Token(Tokens.DeclareSort) => - val sym = parseSymbol - val arity = parseNumeral.value.toInt - val id = getIdentifier(sym) - (None, locals.withADT(sym, id).registerADT { - val tparams = List.range(0, arity).map { - i => TypeParameterDef(TypeParameter.fresh("A" + i).setPos(sym.optPos)).setPos(sym.optPos) - } - val field = ValDef(FreshIdentifier("val"), IntegerType).setPos(sym.optPos) - - new ADTConstructor(id, tparams, None, Seq(field), Set.empty) - }) - - case Tokens.Token(Tokens.CheckSat) => - // TODO: what do I do with this?? - (None, locals) - - case token => - throw new MissformedTIPException("unknown TIP command " + token, token.optPos) - } - - override protected def parseTermWithoutParens: Term = getPeekToken match { - case Tokens.SymbolLit("lambda") => - nextToken // eat the "lambda" token - val vars = parseMany(parseSortedVar _) - val term = parseTerm - FunctionApplication( - QualifiedIdentifier(SMTIdentifier(SSymbol("lambda"))), - vars.map { case SortedVar(sym, sort) => QualifiedIdentifier(SMTIdentifier(sym), Some(sort)) } :+ term - ) - - case _ => super.parseTermWithoutParens - } - - private def extractSignature(fd: SMTFunDef, tps: Seq[SSymbol])(implicit locals: Locals): FunDef = { - assert(!locals.isFunction(fd.name)) - val id = getIdentifier(fd.name) - val tparams = tps.map(sym => TypeParameterDef(locals.getGeneric(sym)).setPos(sym.optPos)) - - val params = fd.params.map { case SortedVar(s, sort) => - ValDef(getIdentifier(s), extractType(sort)).setPos(s.optPos) - } - - val returnType = extractType(fd.returnSort) - val body = Choose(ValDef(FreshIdentifier("res"), returnType), BooleanLiteral(true)) - - new FunDef(id, tparams, params, returnType, body, Set.empty).setPos(fd.name.optPos) - } - - private def extractFunction(fd: SMTFunDef, tps: Seq[SSymbol])(implicit locals: Locals): FunDef = { - val sig = if (locals.isFunction(fd.name)) { - locals.symbols.getFunction(locals.getFunction(fd.name)) - } else { - extractSignature(fd, tps) - } - - val bodyLocals = locals - .withVariables((fd.params zip sig.params).map(p => p._1.name -> p._2.toVariable)) - .withFunctions(if (locals.isFunction(fd.name)) Seq(fd.name -> sig) else Seq.empty) - - val fullBody = extractExpr(fd.body)(bodyLocals) - - new FunDef(sig.id, sig.tparams, sig.params, sig.returnType, fullBody, Set.empty).setPos(fd.name.optPos) - } - - private def isInstanceOfSymbol(sym: SSymbol)(implicit locals: Locals): Option[Identifier] = { - if (sym.name.startsWith("is-")) { - val adtSym = SSymbol(sym.name.split("-").tail.mkString("-")) - locals.lookupADT(adtSym) - } else { - None - } - } - - private def typeADTConstructor(id: Identifier, superType: Type)(implicit locals: Locals): ADTType = { - val tcons = locals.symbols.getADT(id).typed(locals.symbols).toConstructor - val troot = tcons.root.toType - locals.symbols.canBeSupertypeOf(troot, superType) match { - case Some(tmap) => locals.symbols.instantiateType(tcons.toType, tmap).asInstanceOf[ADTType] - case None => throw new MissformedTIPException( - "cannot construct full typing for " + tcons, - superType.getPos - ) - } - } - - private def instantiateTypeParams(tps: Seq[TypeParameterDef], formals: Seq[Type], actuals: Seq[Type]) - (implicit locals: Locals): Seq[Type] = { - assert(formals.size == actuals.size) - - import locals.symbols._ - val formal = bestRealType(tupleTypeWrap(formals)) - val actual = bestRealType(tupleTypeWrap(actuals)) - - // freshen the type parameters in case we're building a substitution that includes params from `tps` - val tpSubst: Map[Type, Type] = locals.symbols.typeParamsOf(actual).map(tp => tp -> tp.freshen).toMap - val tpRSubst = tpSubst.map(_.swap) - val substActual = locals.symbols.typeOps.replace(tpSubst, actual) - - canBeSupertypeOf(formal, substActual) match { - case Some(tmap) => tps.map(tpd => tmap.get(tpd.tp).map { - tpe => locals.symbols.typeOps.replace(tpRSubst, tpe) - }.getOrElse(tpd.tp)) - - case None => throw new MissformedTIPException( - s"could not instantiate $tps in $formals given $actuals", - actuals.headOption.map(_.getPos).getOrElse(NoPosition) - ) - } - } - - private def wrapAsInstanceOf(formals: Seq[Type], exprs: Seq[Expr])(implicit locals: Locals): Seq[Expr] = { - (formals zip exprs).map { case (tpe, e) => - (tpe, e.getType(locals.symbols)) match { - case (tp1: ADTType, tp2: ADTType) if tp1 != tp2 && locals.symbols.isSubtypeOf(tp1, tp2) => - AsInstanceOf(e, tp1) - case _ => e - } - } - } - - protected def extractExpr(term: Term)(implicit locals: Locals): Expr = (term match { - case QualifiedIdentifier(SimpleIdentifier(sym), None) if locals.isVariable(sym) => - locals.getVariable(sym) - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("assume")), None), Seq(pred, body)) => - Assume(extractExpr(pred), extractExpr(body)) - - case SMTLet(binding, bindings, term) => - var locs = locals - val mapping = for (VarBinding(name, term) <- (binding +: bindings)) yield { - val e = extractExpr(term)(locs) - val tpe = e.getType(locs.symbols) - val vd = ValDef(getIdentifier(name), tpe).setPos(name.optPos) - locs = locs.withVariable(name, vd.toVariable) - vd -> e - } - - mapping.foldRight(extractExpr(term)(locs)) { case ((vd, e), body) => Let(vd, e, body).setPos(vd.getPos) } - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("@")), None), fun +: args) => - Application(extractExpr(fun), args.map(extractExpr)) - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("lambda")), None), args :+ body) => - val (vds, bindings) = args.map { case QualifiedIdentifier(SimpleIdentifier(s), Some(sort)) => - val vd = ValDef(getIdentifier(s), extractType(sort)).setPos(s.optPos) - (vd, s -> vd.toVariable) - }.unzip - - Lambda(vds, extractExpr(body)(locals.withVariables(bindings))) - - case SMTForall(sv, svs, term) => - val (vds, bindings) = (sv +: svs).map { case SortedVar(s, sort) => - val vd = ValDef(getIdentifier(s), extractType(sort)).setPos(s.optPos) - (vd, s -> vd.toVariable) - }.unzip - - Forall(vds, extractExpr(term)(locals.withVariables(bindings))) - - case Exists(sv, svs, term) => - val (vds, bindings) = (sv +: svs).map { case SortedVar(s, sort) => - val vd = ValDef(getIdentifier(s), extractType(sort)).setPos(s.optPos) - (vd, s -> vd.toVariable) - }.unzip - - val body = Not(extractExpr(term)(locals.withVariables(bindings))).setPos(term.optPos) - Forall(vds, body) - - case Core.ITE(cond, thenn, elze) => - IfExpr(extractExpr(cond), extractExpr(thenn), extractExpr(elze)) - - case SNumeral(n) => - IntegerLiteral(n) - - // TODO: hexadecimal case - //case SHexadecimal(value) => BVLiteral() - - case SBinary(bs) => - BVLiteral(BitSet.empty ++ bs.reverse.zipWithIndex.collect { case (true, i) => i }, bs.size) - - case SDecimal(value) => - FractionLiteral( - value.bigDecimal.movePointRight(value.scale).toBigInteger, - BigInt(10).pow(value.scale)) - - case SString(value) => - StringLiteral(value) - - case QualifiedIdentifier(SimpleIdentifier(sym), optSort) if locals.isADT(sym) => - val cons = locals.symbols.getADT(locals.getADT(sym)).asInstanceOf[ADTConstructor] - val tpe = optSort match { - case Some(sort) => - val tps = instantiateTypeParams( - cons.tparams, - Seq(cons.typed(locals.symbols).toType), - Seq(extractType(sort))) - cons.typed(tps)(locals.symbols).toType - case _ => - assert(cons.tparams.isEmpty) - cons.typed(locals.symbols).toType - } - ADT(tpe, Seq.empty) - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(sym), None), args) - if locals.isADT(sym) => - val es = args.map(extractExpr) - val cons = locals.symbols.getADT(locals.getADT(sym)).asInstanceOf[ADTConstructor] - val tps = instantiateTypeParams(cons.tparams, cons.fields.map(_.tpe), es.map(_.getType(locals.symbols))) - val tcons = cons.typed(tps)(locals.symbols) - ADT(tcons.toType, wrapAsInstanceOf(tcons.fieldsTypes, es)) - - case QualifiedIdentifier(SimpleIdentifier(sym), optSort) if locals.isFunction(sym) => - val fd = locals.symbols.getFunction(locals.getFunction(sym)) - val tfd = optSort match { - case Some(sort) => - val tpe = extractType(sort) - val tps = instantiateTypeParams(fd.tparams, Seq(fd.returnType), Seq(tpe)) - fd.typed(tps)(locals.symbols) - - case None => - fd.typed(locals.symbols) - } - tfd.applied - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(sym), None), args) - if locals.isFunction(sym) => - val es = args.map(extractExpr) - val fd = locals.symbols.getFunction(locals.getFunction(sym)) - val tps = instantiateTypeParams(fd.tparams, fd.params.map(_.tpe), es.map(_.getType(locals.symbols))) - val tfd = fd.typed(tps)(locals.symbols) - tfd.applied(wrapAsInstanceOf(tfd.params.map(_.tpe), es)) - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(sym), None), Seq(term)) - if isInstanceOfSymbol(sym).isDefined => - val e = extractExpr(term) - val tpe = typeADTConstructor(isInstanceOfSymbol(sym).get, e.getType(locals.symbols)) - IsInstanceOf(e, tpe) - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(sym), None), Seq(term)) - if locals.isADTSelector(sym) => - ADTSelector(extractExpr(term), locals.getADTSelector(sym)) - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("distinct")), None), args) => - val es = args.map(extractExpr).toArray - val indexPairs = args.indices.flatMap(i1 => args.indices.map(i2 => (i1, i2))).filter(p => p._1 != p._2) - locals.symbols.andJoin( - indexPairs.map(p => Not(Equals(es(p._1), es(p._2)).setPos(term.optPos)).setPos(term.optPos))) - - case Core.Equals(e1, e2) => Equals(extractExpr(e1), extractExpr(e2)) - case Core.And(es @ _*) => And(es.map(extractExpr)) - case Core.Or(es @ _*) => Or(es.map(extractExpr)) - case Core.Implies(e1, e2) => Implies(extractExpr(e1), extractExpr(e2)) - case Core.Not(e) => Not(extractExpr(e)) - - case Core.True() => BooleanLiteral(true) - case Core.False() => BooleanLiteral(false) - - case Strings.Length(s) => StringLength(extractExpr(s)) - case Strings.Concat(e1, e2, es @ _*) => - es.foldLeft(StringConcat(extractExpr(e1), extractExpr(e2)).setPos(term.optPos)) { - (c,e) => StringConcat(c, extractExpr(e)).setPos(term.optPos) - } - - case Strings.Substring(e, start, end) => - SubString(extractExpr(e), extractExpr(start), extractExpr(end)) - - /* Ints extractors cover the Reals operations as well */ - - case Ints.Neg(e) => UMinus(extractExpr(e)) - case Ints.Add(e1, e2) => Plus(extractExpr(e1), extractExpr(e2)) - case Ints.Sub(e1, e2) => Minus(extractExpr(e1), extractExpr(e2)) - case Ints.Mul(e1, e2) => Times(extractExpr(e1), extractExpr(e2)) - case Ints.Div(e1, e2) => Division(extractExpr(e1), extractExpr(e2)) - case Ints.Mod(e1, e2) => Modulo(extractExpr(e1), extractExpr(e2)) - case Ints.Abs(e) => - val ie = extractExpr(e) - IfExpr( - LessThan(ie, IntegerLiteral(BigInt(0)).setPos(term.optPos)).setPos(term.optPos), - UMinus(ie).setPos(term.optPos), - ie - ) - - case Ints.LessThan(e1, e2) => LessThan(extractExpr(e1), extractExpr(e2)) - case Ints.LessEquals(e1, e2) => LessEquals(extractExpr(e1), extractExpr(e2)) - case Ints.GreaterThan(e1, e2) => GreaterThan(extractExpr(e1), extractExpr(e2)) - case Ints.GreaterEquals(e1, e2) => GreaterEquals(extractExpr(e1), extractExpr(e2)) - - case FixedSizeBitVectors.Not(e) => BVNot(extractExpr(e)) - case FixedSizeBitVectors.Neg(e) => UMinus(extractExpr(e)) - case FixedSizeBitVectors.And(e1, e2) => BVAnd(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.Or(e1, e2) => BVOr(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.XOr(e1, e2) => BVXor(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.Add(e1, e2) => Plus(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.Sub(e1, e2) => Minus(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.Mul(e1, e2) => Times(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.SDiv(e1, e2) => Division(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.SRem(e1, e2) => Remainder(extractExpr(e1), extractExpr(e2)) - - case FixedSizeBitVectors.SLessThan(e1, e2) => LessThan(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.SLessEquals(e1, e2) => LessEquals(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.SGreaterThan(e1, e2) => GreaterThan(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.SGreaterEquals(e1, e2) => GreaterEquals(extractExpr(e1), extractExpr(e2)) - - case FixedSizeBitVectors.ShiftLeft(e1, e2) => BVShiftLeft(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.AShiftRight(e1, e2) => BVAShiftRight(extractExpr(e1), extractExpr(e2)) - case FixedSizeBitVectors.LShiftRight(e1, e2) => BVLShiftRight(extractExpr(e1), extractExpr(e2)) - - case ArraysEx.Select(e1, e2) => MapApply(extractExpr(e1), extractExpr(e2)) - case ArraysEx.Store(e1, e2, e3) => MapUpdated(extractExpr(e1), extractExpr(e2), extractExpr(e3)) - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("const")), Some(sort)), Seq(dflt)) => - FiniteMap(Seq.empty, extractExpr(dflt), extractType(sort)) - - case Sets.Union(e1, e2) => SetUnion(extractExpr(e1), extractExpr(e2)) - case Sets.Intersection(e1, e2) => SetIntersection(extractExpr(e1), extractExpr(e2)) - case Sets.Setminus(e1, e2) => SetDifference(extractExpr(e1), extractExpr(e2)) - case Sets.Member(e1, e2) => ElementOfSet(extractExpr(e1), extractExpr(e2)) - case Sets.Subset(e1, e2) => SubsetOf(extractExpr(e1), extractExpr(e2)) - - case Sets.Singleton(e) => - val elem = extractExpr(e) - FiniteSet(Seq(elem), locals.symbols.bestRealType(elem.getType(locals.symbols))) - - case Sets.Insert(set, es @ _*) => - es.foldLeft(extractExpr(set))((s,e) => SetAdd(s, extractExpr(e))) - - // TODO: bags - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("match")), None), s +: cases) => - val scrut = extractExpr(s) - val matchCases: Seq[(Option[Expr], Expr)] = cases.map { - case FunctionApplication( - QualifiedIdentifier(SimpleIdentifier(SSymbol("case")), None), - Seq(pat, term) - ) => pat match { - case QualifiedIdentifier(SimpleIdentifier(SSymbol("default")), None) => - (None, extractExpr(term)) - - case QualifiedIdentifier(SimpleIdentifier(sym), None) => - val id = locals.getADT(sym) - val tpe = typeADTConstructor(id, scrut.getType(locals.symbols)) - (Some(IsInstanceOf(scrut, tpe).setPos(sym.optPos)), extractExpr(term)) - - case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(sym), None), args) => - val id = locals.getADT(sym) - val tpe = typeADTConstructor(id, scrut.getType(locals.symbols)) - - val tcons = tpe.getADT(locals.symbols).toConstructor - val bindings = (tcons.fields zip args).map { - case (vd, QualifiedIdentifier(SimpleIdentifier(sym), None)) => - (sym, vd.id, vd.freshen) - } - - val expr = extractExpr(term)(locals.withVariables(bindings.map(p => p._1 -> p._3.toVariable))) - val fullExpr = bindings.foldRight(expr) { case ((s, id, vd), e) => - val selector = ADTSelector(AsInstanceOf(scrut, tpe).setPos(s.optPos), id).setPos(s.optPos) - Let(vd, selector, e).setPos(s.optPos) - } - - (Some(IsInstanceOf(scrut, tpe).setPos(sym.optPos)), fullExpr) - - case _ => throw new MissformedTIPException("unexpected match pattern " + pat, pat.optPos) - } - - case cse => throw new MissformedTIPException("unexpected match case " + cse, cse.optPos) - } - - val (withCond, withoutCond) = matchCases.partition(_._1.isDefined) - val (ifs, last) = if (withoutCond.size > 1) { - throw new MissformedTIPException("unexpected multiple defaults in " + term, term.optPos) - } else if (withoutCond.size == 1) { - (withCond.map(p => p._1.get -> p._2), withoutCond.head._2) - } else { - val wc = withCond.map(p => p._1.get -> p._2) - (wc.init, wc.last._2) - } - - ifs.foldRight(last) { case ((cond, body), elze) => IfExpr(cond, body, elze).setPos(cond.getPos) } - - }).setPos(term.optPos) - - protected def extractType(sort: Sort)(implicit locals: Locals): Type = (sort match { - case Sort(SMTIdentifier(SSymbol("bitvector"), Seq(SNumeral(n))), Seq()) => BVType(n.toInt) - case Sort(SimpleIdentifier(SSymbol("Bool")), Seq()) => BooleanType - case Sort(SimpleIdentifier(SSymbol("Int")), Seq()) => IntegerType - - case Sort(SimpleIdentifier(SSymbol("Array")), Seq(from, to)) => - MapType(extractType(from), extractType(to)) - - case Sort(SimpleIdentifier(SSymbol("Set")), Seq(base)) => - SetType(extractType(base)) - - case Sort(SimpleIdentifier(SSymbol("Bag")), Seq(base)) => - BagType(extractType(base)) - - case Sort(SimpleIdentifier(SSymbol("=>")), params :+ res) => - FunctionType(params.map(extractType), extractType(res)) - - case Sort(SimpleIdentifier(sym), Seq()) if locals.isGeneric(sym) => - locals.getGeneric(sym) - - case Sort(SimpleIdentifier(sym), tps) if locals.isADT(sym) => - ADTType(locals.getADT(sym), tps.map(extractType)) - - case _ => throw new MissformedTIPException("unexpected sort: " + sort, sort.id.symbol.optPos) - }).setPos(sort.id.symbol.optPos) - } -} - -object TIPParser { - def parse(file: File): (inox.trees.Symbols, inox.trees.Expr) = new TIPParser { - val trees: inox.trees.type = inox.trees - }.parse(file) - - def parse(reader: Reader): (inox.trees.Symbols, inox.trees.Expr) = new TIPParser { - val trees: inox.trees.type = inox.trees - }.parse(reader) -} diff --git a/src/main/scala/inox/solvers/smtlib/CVC4Solver.scala b/src/main/scala/inox/solvers/smtlib/CVC4Solver.scala index b76460324fbc3b6e6651eb5428c17a8edfa21d8c..738e344faf789615154301f5c44b3f1eb4a2f28d 100644 --- a/src/main/scala/inox/solvers/smtlib/CVC4Solver.scala +++ b/src/main/scala/inox/solvers/smtlib/CVC4Solver.scala @@ -10,7 +10,7 @@ trait CVC4Solver extends SMTLIBSolver with CVC4Target { import program.trees._ import SolverResponses._ - def interpreterOps(ctx: InoxContext) = { + def interpreterOpts = { Seq( "-q", "--produce-models", diff --git a/src/main/scala/inox/solvers/smtlib/CVC4Target.scala b/src/main/scala/inox/solvers/smtlib/CVC4Target.scala index 704db8f24dea25fe84cc4ef9609d6a05b422f4e6..b28755ba2f5feb05913bc1b3a23df152a35baa44 100644 --- a/src/main/scala/inox/solvers/smtlib/CVC4Target.scala +++ b/src/main/scala/inox/solvers/smtlib/CVC4Target.scala @@ -12,17 +12,16 @@ import _root_.smtlib.interpreters.CVC4Interpreter import _root_.smtlib.theories.experimental.Sets import _root_.smtlib.theories.experimental.Strings -trait CVC4Target extends SMTLIBTarget { +trait CVC4Target extends SMTLIBTarget with SMTLIBDebugger { import program._ import trees._ import symbols._ def targetName = "cvc4" - override def getNewInterpreter(ctx: InoxContext) = { - val opts = interpreterOps(ctx) + protected lazy val interpreter = { + val opts = interpreterOpts ctx.reporter.debug("Invoking solver with "+opts.mkString(" ")) - new CVC4Interpreter("cvc4", opts.toArray) } diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBDebugger.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBDebugger.scala new file mode 100644 index 0000000000000000000000000000000000000000..89d23a60d78f2a8bb357d684677be78d79576bd1 --- /dev/null +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBDebugger.scala @@ -0,0 +1,50 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package solvers +package smtlib + +import _root_.smtlib.parser.Terms._ + +trait SMTLIBDebugger extends SMTLIBTarget { + import program._ + + protected def interpreterOpts: Seq[String] + + implicit val debugSection: DebugSection + + override def free(): Unit = { + super.free() + debugOut.foreach(_.close()) + } + + /* Printing VCs */ + protected lazy val debugOut: Option[java.io.FileWriter] = { + if (ctx.reporter.isDebugEnabled) { + val file = "" // TODO: real file name + val n = DebugFileNumbers.next(targetName + file) + val fileName = s"smt-sessions/$targetName-$file-$n.smt2" + + val javaFile = new java.io.File(fileName) + javaFile.getParentFile.mkdirs() + + ctx.reporter.debug(s"Outputting smt session into $fileName") + + val fw = new java.io.FileWriter(javaFile, false) + fw.write("; Options: " + interpreterOpts.mkString(" ") + "\n") + + Some(fw) + } else { + None + } + } + + override def emit(cmd: SExpr, rawOut: Boolean = false): SExpr = { + debugOut.foreach { o => + interpreter.printer.printSExpr(cmd, o) + o.write("\n") + o.flush() + } + super.emit(cmd, rawOut = rawOut) + } +} diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala index cae4b915a78665bf7860f20a682e8ac340d2e578..755ed3dc93ea5f3cd7e1dbe4b811a5e4668c2e4b 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBSolver.scala @@ -8,7 +8,7 @@ import _root_.smtlib.parser.Commands.{FunDef => SMTFunDef, _} import _root_.smtlib.parser.Terms.{Identifier => _, _} import _root_.smtlib.parser.CommandsResponses._ -trait SMTLIBSolver extends Solver with SMTLIBTarget { +trait SMTLIBSolver extends Solver with SMTLIBTarget with SMTLIBDebugger { import program._ import trees._ diff --git a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala index a299edc50bcb7ddc7d92de9227013921a850be3b..890761d963ea4bb37625985d6e34a354bdb2f471 100644 --- a/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala +++ b/src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala @@ -21,7 +21,7 @@ import _root_.smtlib.parser.Terms.{ } import _root_.smtlib.parser.CommandsResponses._ import _root_.smtlib.theories.{Constructors => SmtLibConstructors, _} -import _root_.smtlib.interpreters.ProcessInterpreter +import _root_.smtlib.Interpreter trait SMTLIBTarget extends Interruptible with ADTManagers { val program: Program @@ -31,15 +31,9 @@ trait SMTLIBTarget extends Interruptible with ADTManagers { def targetName: String - implicit val debugSection: DebugSection - - protected def interpreterOps(ctx: InoxContext): Seq[String] - - protected def getNewInterpreter(ctx: InoxContext): ProcessInterpreter - protected def unsupported(t: Tree, str: String): Nothing - protected lazy val interpreter = getNewInterpreter(ctx) + protected val interpreter: Interpreter /* Interruptible interface */ private var interrupted = false @@ -50,46 +44,18 @@ trait SMTLIBTarget extends Interruptible with ADTManagers { interrupted = true interpreter.interrupt() } + override def recoverInterrupt(): Unit = { interrupted = false } - def free() = { + def free(): Unit = { interpreter.free() ctx.interruptManager.unregisterForInterrupts(this) - debugOut foreach { _.close } - } - - /* Printing VCs */ - protected lazy val debugOut: Option[java.io.FileWriter] = { - if (ctx.reporter.isDebugEnabled) { - val file = ""//ctx.files.headOption.map(_.getName).getOrElse("NA") - val n = DebugFileNumbers.next(targetName + file) - - val fileName = s"smt-sessions/$targetName-$file-$n.smt2" - - val javaFile = new java.io.File(fileName) - javaFile.getParentFile.mkdirs() - - ctx.reporter.debug(s"Outputting smt session into $fileName") - - val fw = new java.io.FileWriter(javaFile, false) - - fw.write("; Options: " + interpreterOps(ctx).mkString(" ") + "\n") - - Some(fw) - } else { - None - } } /* Send a command to the solver */ def emit(cmd: SExpr, rawOut: Boolean = false): SExpr = { - debugOut foreach { o => - SMTPrinter.printSExpr(cmd, o) - o.write("\n") - o.flush() - } interpreter.eval(cmd) match { case err @ Error(msg) if !interrupted && !rawOut => ctx.reporter.fatalError(s"Unexpected error from $targetName solver: $msg") diff --git a/src/main/scala/inox/solvers/smtlib/Z3Target.scala b/src/main/scala/inox/solvers/smtlib/Z3Target.scala index 5d8e307038924e5944eebbddb92ce62793ed0156..e803c54a3ddfc36ab1ffce0331bb70ff56321bb4 100644 --- a/src/main/scala/inox/solvers/smtlib/Z3Target.scala +++ b/src/main/scala/inox/solvers/smtlib/Z3Target.scala @@ -10,25 +10,21 @@ import _root_.smtlib.interpreters.Z3Interpreter import _root_.smtlib.theories.Core.{Equals => SMTEquals, _} import _root_.smtlib.theories._ -trait Z3Target extends SMTLIBTarget { - +trait Z3Target extends SMTLIBTarget with SMTLIBDebugger { import program._ import trees._ import symbols._ def targetName = "z3" - def interpreterOps(ctx: InoxContext) = { - Seq( - "-in", - "-smt2" - ) - } + protected def interpreterOpts = Seq( + "-in", + "-smt2" + ) - def getNewInterpreter(ctx: InoxContext) = { - val opts = interpreterOps(ctx) + protected val interpreter = { + val opts = interpreterOpts ctx.reporter.debug("Invoking solver "+targetName+" with "+opts.mkString(" ")) - new Z3Interpreter("z3", opts.toArray) } diff --git a/src/main/scala/inox/tip/Parser.scala b/src/main/scala/inox/tip/Parser.scala new file mode 100644 index 0000000000000000000000000000000000000000..613bdf879f3df6ee6f081478441ba731f98be77d --- /dev/null +++ b/src/main/scala/inox/tip/Parser.scala @@ -0,0 +1,600 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package tip + +import smtlib.lexer._ +import smtlib.parser.Commands.{FunDef => SMTFunDef, _} +import smtlib.parser.Terms.{Let => SMTLet, Forall => SMTForall, Identifier => SMTIdentifier, _} +import smtlib.theories._ +import smtlib.theories.experimental._ +import smtlib.extensions.tip.{Parser => SMTParser, Lexer => SMTLexer} +import smtlib.extensions.tip.Terms.{Lambda => SMTLambda, Application => SMTApplication, _} +import smtlib.extensions.tip.Commands._ + +import scala.collection.BitSet +import java.io.{Reader, File, BufferedReader, FileReader} + +import utils._ + +import scala.language.implicitConversions + +class MissformedTIPException(reason: String, pos: Position) + extends Exception("Missfomed TIP source @" + pos + ":\n" + reason) + +class Parser(file: File) { + import inox.trees._ + + protected val positions = new PositionProvider(new BufferedReader(new FileReader(file)), Some(file)) + + protected implicit def smtlibPositionToPosition(pos: Option[_root_.smtlib.common.Position]): Position = { + pos.map(p => positions.get(p.line, p.col)).getOrElse(NoPosition) + } + + def parseScript: (Symbols, Expr) = { + val parser = new SMTParser(new SMTLexer(positions.reader)) + val script = parser.parseScript + + var assertions: Seq[Expr] = Seq.empty + implicit var locals: Locals = NoLocals + + for (cmd <- script.commands) { + val (newAssertions, newLocals) = extractCommand(cmd) + assertions ++= newAssertions + locals = newLocals + } + + val expr: Expr = locals.symbols.andJoin(assertions) + (locals.symbols, expr) + } + + protected class Locals ( + funs: Map[SSymbol, Identifier], + adts: Map[SSymbol, Identifier], + selectors: Map[SSymbol, Identifier], + vars: Map[SSymbol, Variable], + tps: Map[SSymbol, TypeParameter], + val symbols: Symbols) { + + def isADT(sym: SSymbol): Boolean = adts.isDefinedAt(sym) + def lookupADT(sym: SSymbol): Option[Identifier] = adts.get(sym) + def getADT(sym: SSymbol): Identifier = adts.get(sym).getOrElse { + throw new MissformedTIPException("unknown ADT " + sym, sym.optPos) + } + + def withADT(sym: SSymbol, id: Identifier): Locals = withADTs(Seq(sym -> id)) + def withADTs(seq: Seq[(SSymbol, Identifier)]): Locals = + new Locals(funs, adts ++ seq, selectors, vars, tps, symbols) + + def isADTSelector(sym: SSymbol): Boolean = selectors.isDefinedAt(sym) + def getADTSelector(sym: SSymbol): Identifier = selectors.get(sym).getOrElse { + throw new MissformedTIPException("unknown ADT selector " + sym, sym.optPos) + } + + def withADTSelectors(seq: Seq[(SSymbol, Identifier)]): Locals = + new Locals(funs, adts, selectors ++ seq, vars, tps, symbols) + + def isGeneric(sym: SSymbol): Boolean = tps.isDefinedAt(sym) + def getGeneric(sym: SSymbol): TypeParameter = tps.get(sym).getOrElse { + throw new MissformedTIPException("unknown generic type " + sym, sym.optPos) + } + + def withGeneric(sym: SSymbol, tp: TypeParameter): Locals = withGenerics(Seq(sym -> tp)) + def withGenerics(seq: Seq[(SSymbol, TypeParameter)]): Locals = + new Locals(funs, adts, selectors, vars, tps ++ seq, symbols) + + def isVariable(sym: SSymbol): Boolean = vars.isDefinedAt(sym) + def getVariable(sym: SSymbol): Variable = vars.get(sym).getOrElse { + throw new MissformedTIPException("unknown variable " + sym, sym.optPos) + } + + def withVariable(sym: SSymbol, v: Variable): Locals = withVariables(Seq(sym -> v)) + def withVariables(seq: Seq[(SSymbol, Variable)]): Locals = + new Locals(funs, adts, selectors, vars ++ seq, tps, symbols) + + def isFunction(sym: SSymbol): Boolean = funs.isDefinedAt(sym) + def getFunction(sym: SSymbol): Identifier = funs.get(sym).getOrElse { + throw new MissformedTIPException("unknown function " + sym, sym.optPos) + } + + def withFunction(sym: SSymbol, fd: FunDef): Locals = withFunctions(Seq(sym -> fd)) + def withFunctions(fds: Seq[(SSymbol, FunDef)]): Locals = + new Locals(funs ++ fds.map(p => p._1 -> p._2.id), adts, selectors, vars, tps, + symbols.withFunctions(fds.map(_._2))) + + def registerADT(adt: ADTDefinition): Locals = registerADTs(Seq(adt)) + def registerADTs(defs: Seq[ADTDefinition]): Locals = + new Locals(funs, adts, selectors, vars, tps, symbols.withADTs(defs)) + } + + protected val NoLocals: Locals = new Locals( + Map.empty, Map.empty, Map.empty, Map.empty, Map.empty, NoSymbols) + + protected def extractCommand(cmd: Command) + (implicit locals: Locals): (Option[Expr], Locals) = cmd match { + case Assert(term) => + (Some(extractTerm(term)), locals) + + case AssertPar(tps, term) => + val tpsLocals = locals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos))) + (Some(extractTerm(term)(tpsLocals)), locals) + + case DeclareConst(sym, sort) => + (None, locals.withVariable(sym, + Variable(FreshIdentifier(sym.name), extractSort(sort)).setPos(sym.optPos))) + + case DeclareConstPar(tps, sym, sort) => + val tpsLocals = locals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos))) + (None, locals.withVariable(sym, + Variable(FreshIdentifier(sym.name), extractSort(sort)(tpsLocals)).setPos(sym.optPos))) + + case DefineFun(funDef) => + val fd = extractFunction(funDef, Seq.empty) + (None, locals.withFunction(funDef.name, fd)) + + case DefineFunPar(tps, funDef) => + val tpsLocals = locals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos))) + val fd = extractFunction(funDef, tps)(tpsLocals) + (None, locals.withFunction(funDef.name, fd)) + + case DefineFunRec(funDef) => + val fdsLocals = locals.withFunction(funDef.name, extractSignature(funDef, Seq.empty)) + val fd = extractFunction(funDef, Seq.empty)(fdsLocals) + (None, locals.withFunction(funDef.name, fd)) + + case DefineFunRecPar(tps, funDef) => + val tpsLocals = locals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos))) + val fdsLocals = tpsLocals.withFunction(funDef.name, extractSignature(funDef, tps)(tpsLocals)) + val fd = extractFunction(funDef, tps)(fdsLocals) + (None, locals.withFunction(funDef.name, fd)) + + case DefineFunsRec(funDecs, bodies) => + val funDefs = for ((funDec, body) <- funDecs zip bodies) yield { + SMTFunDef(funDec.name, funDec.params, funDec.returnSort, body) + } + val bodyLocals = locals.withFunctions(for (funDef <- funDefs) yield { + funDef.name -> extractSignature(funDef, Seq.empty) + }) + (None, locals.withFunctions(for (funDef <- funDefs) yield { + funDef.name -> extractFunction(funDef, Seq.empty)(bodyLocals) + })) + + case DefineFunsRecPar(funDecs, bodies) => + val funDefs = for ((funDec, body) <- funDecs zip bodies) yield (funDec match { + case Left(funDec) => (funDec.tps, SMTFunDef(funDec.name, funDec.params, funDec.returnSort, body)) + case Right(funDec) => (Seq.empty[SSymbol], SMTFunDef(funDec.name, funDec.params, funDec.returnSort, body)) + }) + val bodyLocals = locals.withFunctions(for ((tps, funDef) <- funDefs) yield { + val tpsLocals = locals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos))) + funDef.name -> extractSignature(funDef, tps)(tpsLocals) + }) + (None, locals.withFunctions(for ((tps, funDef) <- funDefs) yield { + val tpsLocals = bodyLocals.withGenerics(tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos))) + funDef.name -> extractFunction(funDef, tps)(tpsLocals) + })) + + case DeclareDatatypesPar(tps, datatypes) => + var locs = locals.withADTs(datatypes + .flatMap { case (sym, conss) => + val tpeId = FreshIdentifier(sym.name) + val cids = if (conss.size == 1) { + Seq(conss.head.sym -> tpeId) + } else { + conss.map(c => c.sym -> FreshIdentifier(c.sym.name)) + } + (sym -> tpeId) +: cids + }) + + val generics = tps.map(s => s -> TypeParameter.fresh(s.name).setPos(s.optPos)) + for ((sym, conss) <- datatypes) { + val adtLocals = locs.withGenerics(generics) + val children = for (Constructor(sym, fields) <- conss) yield { + val id = locs.getADT(sym) + val vds = fields.map { case (s, sort) => + ValDef(FreshIdentifier(s.name), extractSort(sort)(adtLocals)).setPos(s.optPos) + } + + (id, vds) + } + + val allVds: Set[ValDef] = children.flatMap(_._2).toSet + val allTparams: Set[TypeParameter] = children.flatMap(_._2).toSet.flatMap { + (vd: ValDef) => locs.symbols.typeParamsOf(vd.tpe): Set[TypeParameter] + } + + val tparams: Seq[TypeParameterDef] = tps.flatMap { sym => + val tp = adtLocals.getGeneric(sym) + if (allTparams(tp)) Some(TypeParameterDef(tp).setPos(sym.optPos)) else None + } + + val parent = if (children.size > 1) { + val id = adtLocals.getADT(sym) + locs = locs.registerADT( + new ADTSort(id, tparams, children.map(_._1), Set.empty).setPos(sym.optPos)) + Some(id) + } else { + None + } + + locs = locs.registerADTs((conss zip children).map { case (cons, (cid, vds)) => + new ADTConstructor(cid, tparams, parent, vds, Set.empty).setPos(cons.sym.optPos) + }).withADTSelectors((conss zip children).flatMap { case (Constructor(_, fields), (_, vds)) => + (fields zip vds).map(p => p._1._1 -> p._2.id) + }) + } + + (None, locs) + + case DeclareSort(sym, arity) => + val id = FreshIdentifier(sym.name) + (None, locals.withADT(sym, id).registerADT { + val tparams = List.range(0, arity).map { + i => TypeParameterDef(TypeParameter.fresh("A" + i).setPos(sym.optPos)).setPos(sym.optPos) + } + val field = ValDef(FreshIdentifier("val"), IntegerType).setPos(sym.optPos) + + new ADTConstructor(id, tparams, None, Seq(field), Set.empty) + }) + + case CheckSat() => + // TODO: what do I do with this?? + (None, locals) + + case _ => + throw new MissformedTIPException("unknown TIP command " + cmd, cmd.optPos) + } + + private def extractSignature(fd: SMTFunDef, tps: Seq[SSymbol])(implicit locals: Locals): FunDef = { + assert(!locals.isFunction(fd.name)) + val id = FreshIdentifier(fd.name.name) + val tparams = tps.map(sym => TypeParameterDef(locals.getGeneric(sym)).setPos(sym.optPos)) + + val params = fd.params.map { case SortedVar(s, sort) => + ValDef(FreshIdentifier(s.name), extractSort(sort)).setPos(s.optPos) + } + + val returnType = extractSort(fd.returnSort) + val body = Choose(ValDef(FreshIdentifier("res"), returnType), BooleanLiteral(true)) + + new FunDef(id, tparams, params, returnType, body, Set.empty).setPos(fd.name.optPos) + } + + private def extractFunction(fd: SMTFunDef, tps: Seq[SSymbol])(implicit locals: Locals): FunDef = { + val sig = if (locals.isFunction(fd.name)) { + locals.symbols.getFunction(locals.getFunction(fd.name)) + } else { + extractSignature(fd, tps) + } + + val bodyLocals = locals + .withVariables((fd.params zip sig.params).map(p => p._1.name -> p._2.toVariable)) + .withFunctions(if (locals.isFunction(fd.name)) Seq(fd.name -> sig) else Seq.empty) + + val fullBody = extractTerm(fd.body)(bodyLocals) + + new FunDef(sig.id, sig.tparams, sig.params, sig.returnType, fullBody, Set.empty).setPos(fd.name.optPos) + } + + private def isInstanceOfSymbol(sym: SSymbol)(implicit locals: Locals): Option[Identifier] = { + if (sym.name.startsWith("is-")) { + val adtSym = SSymbol(sym.name.split("-").tail.mkString("-")) + locals.lookupADT(adtSym) + } else { + None + } + } + + private def typeADTConstructor(id: Identifier, superType: Type)(implicit locals: Locals): ADTType = { + val tcons = locals.symbols.getADT(id).typed(locals.symbols).toConstructor + val troot = tcons.root.toType + locals.symbols.canBeSupertypeOf(troot, superType) match { + case Some(tmap) => locals.symbols.instantiateType(tcons.toType, tmap).asInstanceOf[ADTType] + case None => throw new MissformedTIPException( + "cannot construct full typing for " + tcons, + superType.getPos + ) + } + } + + private def instantiateTypeParams(tps: Seq[TypeParameterDef], formals: Seq[Type], actuals: Seq[Type]) + (implicit locals: Locals): Seq[Type] = { + assert(formals.size == actuals.size) + + import locals.symbols._ + val formal = bestRealType(tupleTypeWrap(formals)) + val actual = bestRealType(tupleTypeWrap(actuals)) + + // freshen the type parameters in case we're building a substitution that includes params from `tps` + val tpSubst: Map[Type, Type] = locals.symbols.typeParamsOf(actual).map(tp => tp -> tp.freshen).toMap + val tpRSubst = tpSubst.map(_.swap) + val substActual = locals.symbols.typeOps.replace(tpSubst, actual) + + canBeSupertypeOf(formal, substActual) match { + case Some(tmap) => tps.map(tpd => tmap.get(tpd.tp).map { + tpe => locals.symbols.typeOps.replace(tpRSubst, tpe) + }.getOrElse(tpd.tp)) + + case None => throw new MissformedTIPException( + s"could not instantiate $tps in $formals given $actuals", + actuals.headOption.map(_.getPos).getOrElse(NoPosition) + ) + } + } + + private def wrapAsInstanceOf(formals: Seq[Type], exprs: Seq[Expr])(implicit locals: Locals): Seq[Expr] = { + (formals zip exprs).map { case (tpe, e) => + (tpe, e.getType(locals.symbols)) match { + case (tp1: ADTType, tp2: ADTType) if tp1 != tp2 && locals.symbols.isSubtypeOf(tp1, tp2) => + AsInstanceOf(e, tp1) + case _ => e + } + } + } + + protected def extractTerm(term: Term)(implicit locals: Locals): Expr = (term match { + case QualifiedIdentifier(SimpleIdentifier(sym), None) if locals.isVariable(sym) => + locals.getVariable(sym) + + case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("assume")), None), Seq(pred, body)) => + Assume(extractTerm(pred), extractTerm(body)) + + case SMTLet(binding, bindings, term) => + var locs = locals + val mapping = for (VarBinding(name, term) <- (binding +: bindings)) yield { + val e = extractTerm(term)(locs) + val tpe = e.getType(locs.symbols) + val vd = ValDef(FreshIdentifier(name.name), tpe).setPos(name.optPos) + locs = locs.withVariable(name, vd.toVariable) + vd -> e + } + mapping.foldRight(extractTerm(term)(locs)) { case ((vd, e), body) => Let(vd, e, body).setPos(vd.getPos) } + + case SMTApplication(caller, args) => + Application(extractTerm(caller), args.map(extractTerm)) + + case SMTLambda(svs, term) => + val (vds, bindings) = svs.map { case SortedVar(s, sort) => + val vd = ValDef(FreshIdentifier(s.name), extractSort(sort)).setPos(s.optPos) + (vd, s -> vd.toVariable) + }.unzip + Lambda(vds, extractTerm(term)(locals.withVariables(bindings))) + + case SMTForall(sv, svs, term) => + val (vds, bindings) = (sv +: svs).map { case SortedVar(s, sort) => + val vd = ValDef(FreshIdentifier(s.name), extractSort(sort)).setPos(s.optPos) + (vd, s -> vd.toVariable) + }.unzip + Forall(vds, extractTerm(term)(locals.withVariables(bindings))) + + case Exists(sv, svs, term) => + val (vds, bindings) = (sv +: svs).map { case SortedVar(s, sort) => + val vd = ValDef(FreshIdentifier(s.name), extractSort(sort)).setPos(s.optPos) + (vd, s -> vd.toVariable) + }.unzip + val body = Not(extractTerm(term)(locals.withVariables(bindings))).setPos(term.optPos) + Forall(vds, body) + + case Core.ITE(cond, thenn, elze) => + IfExpr(extractTerm(cond), extractTerm(thenn), extractTerm(elze)) + + case SNumeral(n) => + IntegerLiteral(n) + + // TODO: hexadecimal case + //case SHexadecimal(value) => BVLiteral() + + case SBinary(bs) => + BVLiteral(BitSet.empty ++ bs.reverse.zipWithIndex.collect { case (true, i) => i }, bs.size) + + case SDecimal(value) => + FractionLiteral( + value.bigDecimal.movePointRight(value.scale).toBigInteger, + BigInt(10).pow(value.scale)) + + case SString(value) => + StringLiteral(value) + + case QualifiedIdentifier(SimpleIdentifier(sym), optSort) if locals.isADT(sym) => + val cons = locals.symbols.getADT(locals.getADT(sym)).asInstanceOf[ADTConstructor] + val tpe = optSort match { + case Some(sort) => + val tps = instantiateTypeParams( + cons.tparams, + Seq(cons.typed(locals.symbols).toType), + Seq(extractSort(sort))) + cons.typed(tps)(locals.symbols).toType + case _ => + assert(cons.tparams.isEmpty) + cons.typed(locals.symbols).toType + } + ADT(tpe, Seq.empty) + + case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(sym), None), args) + if locals.isADT(sym) => + val es = args.map(extractTerm) + val cons = locals.symbols.getADT(locals.getADT(sym)).asInstanceOf[ADTConstructor] + val tps = instantiateTypeParams(cons.tparams, cons.fields.map(_.tpe), es.map(_.getType(locals.symbols))) + val tcons = cons.typed(tps)(locals.symbols) + ADT(tcons.toType, wrapAsInstanceOf(tcons.fieldsTypes, es)) + + case QualifiedIdentifier(SimpleIdentifier(sym), optSort) if locals.isFunction(sym) => + val fd = locals.symbols.getFunction(locals.getFunction(sym)) + val tfd = optSort match { + case Some(sort) => + val tpe = extractSort(sort) + val tps = instantiateTypeParams(fd.tparams, Seq(fd.returnType), Seq(tpe)) + fd.typed(tps)(locals.symbols) + + case None => + fd.typed(locals.symbols) + } + tfd.applied + + case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(sym), None), args) + if locals.isFunction(sym) => + val es = args.map(extractTerm) + val fd = locals.symbols.getFunction(locals.getFunction(sym)) + val tps = instantiateTypeParams(fd.tparams, fd.params.map(_.tpe), es.map(_.getType(locals.symbols))) + val tfd = fd.typed(tps)(locals.symbols) + tfd.applied(wrapAsInstanceOf(tfd.params.map(_.tpe), es)) + + case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(sym), None), Seq(term)) + if isInstanceOfSymbol(sym).isDefined => + val e = extractTerm(term) + val tpe = typeADTConstructor(isInstanceOfSymbol(sym).get, e.getType(locals.symbols)) + IsInstanceOf(e, tpe) + + case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(sym), None), Seq(term)) + if locals.isADTSelector(sym) => + ADTSelector(extractTerm(term), locals.getADTSelector(sym)) + + case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("distinct")), None), args) => + val es = args.map(extractTerm).toArray + val indexPairs = args.indices.flatMap(i1 => args.indices.map(i2 => (i1, i2))).filter(p => p._1 != p._2) + locals.symbols.andJoin( + indexPairs.map(p => Not(Equals(es(p._1), es(p._2)).setPos(term.optPos)).setPos(term.optPos))) + + case Core.Equals(e1, e2) => Equals(extractTerm(e1), extractTerm(e2)) + case Core.And(es @ _*) => And(es.map(extractTerm)) + case Core.Or(es @ _*) => Or(es.map(extractTerm)) + case Core.Implies(e1, e2) => Implies(extractTerm(e1), extractTerm(e2)) + case Core.Not(e) => Not(extractTerm(e)) + + case Core.True() => BooleanLiteral(true) + case Core.False() => BooleanLiteral(false) + + case Strings.Length(s) => StringLength(extractTerm(s)) + case Strings.Concat(e1, e2, es @ _*) => + es.foldLeft(StringConcat(extractTerm(e1), extractTerm(e2)).setPos(term.optPos)) { + (c,e) => StringConcat(c, extractTerm(e)).setPos(term.optPos) + } + + case Strings.Substring(e, start, end) => + SubString(extractTerm(e), extractTerm(start), extractTerm(end)) + + /* Ints extractors cover the Reals operations as well */ + + case Ints.Neg(e) => UMinus(extractTerm(e)) + case Ints.Add(e1, e2) => Plus(extractTerm(e1), extractTerm(e2)) + case Ints.Sub(e1, e2) => Minus(extractTerm(e1), extractTerm(e2)) + case Ints.Mul(e1, e2) => Times(extractTerm(e1), extractTerm(e2)) + case Ints.Div(e1, e2) => Division(extractTerm(e1), extractTerm(e2)) + case Ints.Mod(e1, e2) => Modulo(extractTerm(e1), extractTerm(e2)) + case Ints.Abs(e) => + val ie = extractTerm(e) + IfExpr( + LessThan(ie, IntegerLiteral(BigInt(0)).setPos(term.optPos)).setPos(term.optPos), + UMinus(ie).setPos(term.optPos), + ie + ) + + case Ints.LessThan(e1, e2) => LessThan(extractTerm(e1), extractTerm(e2)) + case Ints.LessEquals(e1, e2) => LessEquals(extractTerm(e1), extractTerm(e2)) + case Ints.GreaterThan(e1, e2) => GreaterThan(extractTerm(e1), extractTerm(e2)) + case Ints.GreaterEquals(e1, e2) => GreaterEquals(extractTerm(e1), extractTerm(e2)) + + case FixedSizeBitVectors.Not(e) => BVNot(extractTerm(e)) + case FixedSizeBitVectors.Neg(e) => UMinus(extractTerm(e)) + case FixedSizeBitVectors.And(e1, e2) => BVAnd(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.Or(e1, e2) => BVOr(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.XOr(e1, e2) => BVXor(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.Add(e1, e2) => Plus(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.Sub(e1, e2) => Minus(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.Mul(e1, e2) => Times(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.SDiv(e1, e2) => Division(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.SRem(e1, e2) => Remainder(extractTerm(e1), extractTerm(e2)) + + case FixedSizeBitVectors.SLessThan(e1, e2) => LessThan(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.SLessEquals(e1, e2) => LessEquals(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.SGreaterThan(e1, e2) => GreaterThan(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.SGreaterEquals(e1, e2) => GreaterEquals(extractTerm(e1), extractTerm(e2)) + + case FixedSizeBitVectors.ShiftLeft(e1, e2) => BVShiftLeft(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.AShiftRight(e1, e2) => BVAShiftRight(extractTerm(e1), extractTerm(e2)) + case FixedSizeBitVectors.LShiftRight(e1, e2) => BVLShiftRight(extractTerm(e1), extractTerm(e2)) + + case ArraysEx.Select(e1, e2) => MapApply(extractTerm(e1), extractTerm(e2)) + case ArraysEx.Store(e1, e2, e3) => MapUpdated(extractTerm(e1), extractTerm(e2), extractTerm(e3)) + case FunctionApplication(QualifiedIdentifier(SimpleIdentifier(SSymbol("const")), Some(sort)), Seq(dflt)) => + FiniteMap(Seq.empty, extractTerm(dflt), extractSort(sort)) + + case Sets.Union(e1, e2) => SetUnion(extractTerm(e1), extractTerm(e2)) + case Sets.Intersection(e1, e2) => SetIntersection(extractTerm(e1), extractTerm(e2)) + case Sets.Setminus(e1, e2) => SetDifference(extractTerm(e1), extractTerm(e2)) + case Sets.Member(e1, e2) => ElementOfSet(extractTerm(e1), extractTerm(e2)) + case Sets.Subset(e1, e2) => SubsetOf(extractTerm(e1), extractTerm(e2)) + + case Sets.Singleton(e) => + val elem = extractTerm(e) + FiniteSet(Seq(elem), locals.symbols.bestRealType(elem.getType(locals.symbols))) + + case Sets.Insert(set, es @ _*) => + es.foldLeft(extractTerm(set))((s,e) => SetAdd(s, extractTerm(e))) + + // TODO: bags + + case Match(s, cases) => + val scrut = extractTerm(s) + val matchCases: Seq[(Option[Expr], Expr)] = cases.map(cse => cse.pattern match { + case Default => + (None, extractTerm(cse.rhs)) + + case CaseObject(sym) => + val id = locals.getADT(sym) + val tpe = typeADTConstructor(id, scrut.getType(locals.symbols)) + (Some(IsInstanceOf(scrut, tpe).setPos(sym.optPos)), extractTerm(cse.rhs)) + + case CaseClass(sym, args) => + val id = locals.getADT(sym) + val tpe = typeADTConstructor(id, scrut.getType(locals.symbols)) + + val tcons = tpe.getADT(locals.symbols).toConstructor + val bindings = (tcons.fields zip args).map { case (vd, sym) => (sym, vd.id, vd.freshen) } + + val expr = extractTerm(cse.rhs)(locals.withVariables(bindings.map(p => p._1 -> p._3.toVariable))) + val fullExpr = bindings.foldRight(expr) { case ((s, id, vd), e) => + val selector = ADTSelector(AsInstanceOf(scrut, tpe).setPos(s.optPos), id).setPos(s.optPos) + Let(vd, selector, e).setPos(s.optPos) + } + (Some(IsInstanceOf(scrut, tpe).setPos(sym.optPos)), fullExpr) + }) + + val (withCond, withoutCond) = matchCases.partition(_._1.isDefined) + val (ifs, last) = if (withoutCond.size > 1) { + throw new MissformedTIPException("unexpected multiple defaults in " + term, term.optPos) + } else if (withoutCond.size == 1) { + (withCond.map(p => p._1.get -> p._2), withoutCond.head._2) + } else { + val wc = withCond.map(p => p._1.get -> p._2) + (wc.init, wc.last._2) + } + + ifs.foldRight(last) { case ((cond, body), elze) => IfExpr(cond, body, elze).setPos(cond.getPos) } + }).setPos(term.optPos) + + protected def extractSort(sort: Sort)(implicit locals: Locals): Type = (sort match { + case Sort(SMTIdentifier(SSymbol("bitvector"), Seq(SNumeral(n))), Seq()) => BVType(n.toInt) + case Sort(SimpleIdentifier(SSymbol("Bool")), Seq()) => BooleanType + case Sort(SimpleIdentifier(SSymbol("Int")), Seq()) => IntegerType + + case Sort(SimpleIdentifier(SSymbol("Array")), Seq(from, to)) => + MapType(extractSort(from), extractSort(to)) + + case Sort(SimpleIdentifier(SSymbol("Set")), Seq(base)) => + SetType(extractSort(base)) + + case Sort(SimpleIdentifier(SSymbol("Bag")), Seq(base)) => + BagType(extractSort(base)) + + case Sort(SimpleIdentifier(SSymbol("=>")), params :+ res) => + FunctionType(params.map(extractSort), extractSort(res)) + + case Sort(SimpleIdentifier(sym), Seq()) if locals.isGeneric(sym) => + locals.getGeneric(sym) + + case Sort(SimpleIdentifier(sym), tps) if locals.isADT(sym) => + ADTType(locals.getADT(sym), tps.map(extractSort)) + + case _ => throw new MissformedTIPException("unexpected sort: " + sort, sort.id.symbol.optPos) + }).setPos(sort.id.symbol.optPos) +} diff --git a/src/main/scala/inox/tip/PositionProvider.scala b/src/main/scala/inox/tip/PositionProvider.scala new file mode 100644 index 0000000000000000000000000000000000000000..87e4f6711e946fd4b005c5d49df96507dba371ca --- /dev/null +++ b/src/main/scala/inox/tip/PositionProvider.scala @@ -0,0 +1,34 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package tip + +import java.io._ + +import utils._ + +class PositionProvider(_reader: Reader, _file: Option[File]) { + val (reader, file): (Reader, File) = _file match { + case Some(file) => (_reader, file) + case None => + val file = File.createTempFile("tip-input", ".smt2") + val writer = new BufferedWriter(new FileWriter(file)) + + val buffer = new Array[Char](1024) + var count: Int = 0 + while ((count = _reader.read(buffer)) != -1) { + writer.write(buffer, 0, count) + } + + val reader = new BufferedReader(new FileReader(file)) + (reader, file) + } + + private val fileLines: List[String] = scala.io.Source.fromFile(file).getLines.toList + + def get(line: Int, col: Int): OffsetPosition = { + val point = fileLines.take(line).map(_.length).sum + col + new OffsetPosition(line, col, point, file) + } +} + diff --git a/src/main/scala/inox/tip/Printer.scala b/src/main/scala/inox/tip/Printer.scala new file mode 100644 index 0000000000000000000000000000000000000000..aff75f9443ba91152b91ada93ea21e1684588ac2 --- /dev/null +++ b/src/main/scala/inox/tip/Printer.scala @@ -0,0 +1,200 @@ +/* Copyright 2009-2016 EPFL, Lausanne */ + +package inox +package tip + +import java.io.{File, Writer} + +import smtlib.parser.Terms.{Forall => SMTForall, Identifier => SMTIdentifier, _} +import smtlib.parser.Commands.{Constructor => SMTConstructor, _} +import smtlib.extensions.tip.Terms.{Lambda => SMTLambda, Application => SMTApplication, _} +import smtlib.extensions.tip.Commands._ +import smtlib.Interpreter + +import scala.collection.mutable.{Map => MutableMap} + +class Printer(val program: InoxProgram, writer: Writer) extends solvers.smtlib.SMTLIBTarget { + import program._ + import program.trees._ + import program.symbols._ + + def targetName = "tip" + + protected def unsupported(t: Tree, str: String): Nothing = { + throw new Unsupported(t, s"(of class ${t.getClass}) is unsupported by TIP printer:\n " + str) + } + + /* Note that we are NOT relying on a "real" interpreter here. We just + * need the printer for calls to [[emit]] to function correctly. */ + protected val interpreter = new Interpreter { + // the parser should never be used + val parser: smtlib.parser.Parser = null + + object printer extends smtlib.printer.Printer { + val name: String = "tip-printer" + protected def newContext(writer: Writer) = new smtlib.printer.PrintingContext(writer) + } + + def eval(cmd: SExpr): SExpr = { + printer.printSExpr(cmd, writer) + writer.write("\n") + writer.flush() + + smtlib.parser.CommandsResponses.Success + } + + def free(): Unit = { + writer.close() + } + + def interrupt(): Unit = free() + } + + def printScript(expr: Expr): Unit = { + val tparams = exprOps.collect(e => typeParamsOf(e.getType))(expr) + val bindings = exprOps.variablesOf(expr).map(v => v.id -> (id2sym(v.id): Term)).toMap + val cmd = if (tparams.nonEmpty) { + AssertPar(tparams.map(tp => id2sym(tp.id)).toSeq, toSMT(expr)(bindings)) + } else { + Assert(toSMT(expr)(bindings)) + } + emit(cmd) + } + + protected def liftADTType(adt: ADTType): Type = adt.getADT.definition.typed.toType + + protected val tuples: MutableMap[Int, TupleType] = MutableMap.empty + + override protected def declareStructuralSort(t: Type): Sort = t match { + case adt: ADTType => + adtManager.declareADTs(liftADTType(adt), declareDatatypes) + val tpSorts = adt.tps.map(declareSort) + Sort(SMTIdentifier(id2sym(adt.id)), tpSorts) + + case TupleType(ts) => + val tpe = tuples.getOrElseUpdate(ts.size, { + TupleType(List.range(0, ts.size).map(i => TypeParameter.fresh("A" + i))) + }) + adtManager.declareADTs(tpe, declareDatatypes) + val tpSorts = ts.map(declareSort) + Sort(sorts.toB(tpe).id, tpSorts) + + case _ => super.declareStructuralSort(t) + } + + override protected def declareDatatypes(datatypes: Seq[(Type, DataType)]): Unit = { + for ((tpe, DataType(id, _)) <- datatypes) { + sorts += tpe -> Sort(SMTIdentifier(id2sym(id))) + } + + val (generics, adts) = datatypes.partition { + case (_: TypeParameter, _) => true + case _ => false + } + + val genericSyms = generics.map { case (_, DataType(id, _)) => id2sym(id) } + + if (adts.nonEmpty) { + emit(DeclareDatatypesPar(genericSyms, + (for ((tpe, DataType(sym, cases)) <- adts.toList) yield { + id2sym(sym) -> (for (c <- cases) yield { + val s = id2sym(c.sym) + + testers += c.tpe -> SSymbol("is-" + s.name) + constructors += c.tpe -> s + + SMTConstructor(s, c.fields.zipWithIndex.map { case ((cs, t), i) => + selectors += (c.tpe, i) -> id2sym(cs) + (id2sym(cs), declareSort(t)) + }) + }).toList + }).toList + )) + } + } + + override protected def declareFunction(tfd: TypedFunDef): SSymbol = { + val fd = tfd.fd + + val scc = transitiveCallees(fd).filter(fd2 => transitivelyCalls(fd2, fd)) + if (scc.size <= 1) { + val (sym, params, returnSort, body) = ( + id2sym(fd.id), + fd.params.map(vd => SortedVar(id2sym(vd.id), declareSort(vd.tpe))), + declareSort(fd.returnType), + toSMT(fd.fullBody)(fd.params.map(vd => vd.id -> (id2sym(vd.id): Term)).toMap) + ) + + val tps = fd.tparams.map(tpd => declareSort(tpd.tp).id.symbol) + + emit((scc.isEmpty, tps.isEmpty) match { + case (true, true) => DefineFun(FunDef(sym, params, returnSort, body)) + case (false, true) => DefineFunRec(FunDef(sym, params, returnSort, body)) + case (true, false) => DefineFunPar(tps, FunDef(sym, params, returnSort, body)) + case (false, false) => DefineFunRecPar(tps, FunDef(sym, params, returnSort, body)) + }) + } else { + val (decs, bodies) = (for (fd <- scc.toList) yield { + val (sym, params, returnSort) = ( + id2sym(fd.id), + fd.params.map(vd => SortedVar(id2sym(vd.id), declareSort(vd.tpe))), + declareSort(fd.returnType) + ) + + val tps = fd.tparams.map(tpd => declareSort(tpd.tp).id.symbol) + + val dec = if (tps.isEmpty) { + Right(FunDec(sym, params, returnSort)) + } else { + Left(FunDecPar(tps, sym, params, returnSort)) + } + + val body = toSMT(fd.fullBody)(fd.params.map(vd => vd.id -> (id2sym(vd.id): Term)).toMap) + (dec, body) + }).unzip + + emit(if (decs.exists(_.isLeft)) { + DefineFunsRecPar(decs, bodies) + } else { + DefineFunsRec(decs.map(_.right.get), bodies) + }) + } + + id2sym(fd.id) + } + + override protected def toSMT(e: Expr)(implicit bindings: Map[Identifier, Term]): Term = e match { + case Lambda(args, body) => + val (newBindings, params) = args.map { vd => + val sym = id2sym(vd.id) + (vd.id -> (sym: Term), SortedVar(sym, declareSort(vd.tpe))) + }.unzip + SMTLambda(params, toSMT(body)(bindings ++ newBindings)) + + case Forall(args, body) => + val (newBindings, param +: params) = args.map { vd => + val sym = id2sym(vd.id) + (vd.id -> (sym: Term), SortedVar(sym, declareSort(vd.tpe))) + }.unzip + SMTForall(param, params, toSMT(body)(bindings ++ newBindings)) + + case Not(Forall(args, body)) => + val (newBindings, param +: params) = args.map { vd => + val sym = id2sym(vd.id) + (vd.id -> (sym: Term), SortedVar(sym, declareSort(vd.tpe))) + }.unzip + Exists(param, params, toSMT(body)(bindings ++ newBindings)) + + case Application(caller, args) => + SMTApplication(toSMT(caller), args.map(toSMT)) + + case Assume(pred, body) => + FunctionApplication( + QualifiedIdentifier(SMTIdentifier(SSymbol("assume")), None), + Seq(toSMT(pred), toSMT(body)) + ) + + case _ => super.toSMT(e) + } +} +