-
Manos Koukoutos authoredManos Koukoutos authored
Expressions.scala 14.36 KiB
/* Copyright 2009-2015 EPFL, Lausanne */
package leon
package purescala
import Common._
import Types._
import TypeOps._
import Definitions._
import Extractors._
import Constructors._
import ExprOps.replaceFromIDs
/** AST definitions for Pure Scala. */
object Expressions {
/* EXPRESSIONS */
/** Root of the expression hierarchy: a `Tree` that is also `Typed`. */
abstract class Expr extends Tree with Typed
/** Marker trait whose self-type restricts it to [[Expr]] subclasses.
  *
  * In this file it is mixed into expressions that have no `Expr`-typed
  * constructor arguments.
  */
trait Terminal { self: Expr => }
/** Expression that carries nothing but a type.
  *
  * @param tpe the type reported by `getType`
  */
case class NoTree(tpe: TypeTree) extends Expr with Terminal {
  // The expression's type is exactly the one supplied at construction.
  val getType = tpe
}
/** Describes a computational error (unmatched case, taking the min of an
  * empty set, division by zero, etc.). It should always be typed according
  * to the expected type.
  *
  * @param tpe         the type reported by `getType`
  * @param description textual description of the error
  */
case class Error(tpe: TypeTree, description: String) extends Expr with Terminal {
  // The expression's type is exactly the one supplied at construction.
  val getType = tpe
}
/** Pairs a predicate with a body; the expression has the type of `body`.
  *
  * @param pred the predicate attached to the body
  * @param body the expression whose type this node takes
  */
case class Require(pred: Expr, body: Expr) extends Expr {
  val getType = body.getType
}
/** Pairs a body with a predicate `pred` over the body's value.
  *
  * The node is typed as the body only when `pred` has a function type taking
  * exactly one argument of the body's type and returning Boolean; in every
  * other case it is `Untyped`.
  *
  * @param body the expression being constrained
  * @param pred the predicate, expected to be a one-argument Boolean function
  */
case class Ensuring(body: Expr, pred: Expr) extends Expr {
  val getType = pred.getType match {
    case FunctionType(Seq(argType), BooleanType) if argType == body.getType =>
      argType
    case _ =>
      Untyped
  }

  /** Rewrites this node as a `Let` that binds the body to a fresh `res`
    * identifier, then an `Assert` applying `pred` to `res` and yielding `res`.
    */
  def toAssert: Expr = {
    val resId = FreshIdentifier("res", getType, true)
    // Build the pieces in the same order the nested constructor call would.
    val check = application(pred, Seq(Variable(resId)))
    val message = Some("Postcondition failed @" + this.getPos)
    Let(resId, body, Assert(check, message, Variable(resId)))
  }
}
/** Carries a predicate, an optional error message and a body; typed as `body`.
  *
  * @param pred  the asserted predicate
  * @param error optional message associated with the assertion
  * @param body  the expression whose type this node takes
  */
case class Assert(pred: Expr, error: Option[String], body: Expr) extends Expr {
  val getType = body.getType
}
/** Expression built from a predicate `pred`.
  *
  * When `pred` has a function type with at least one argument, the node's
  * type is the (tuple-wrapped) list of argument types; otherwise `Untyped`.
  *
  * @param pred the predicate this expression is built from
  */
case class Choose(pred: Expr) extends Expr {
  val getType = pred.getType match {
    // @mk why nonEmpty?
    case FunctionType(from, _) if from.nonEmpty => tupleTypeWrap(from)
    case _ => Untyped
  }
}
/** Like vals: binds `binder` to `value` for use in `body`.
  *
  * @param binder the identifier being bound
  * @param value  the expression bound to `binder`
  * @param body   the expression whose type this node takes
  */
case class Let(binder: Identifier, value: Expr, body: Expr) extends Expr {
  val getType = body.getType
}
/** Pairs a function definition with a body; typed as `body`.
  *
  * @param fd   the function definition carried by this node
  * @param body the expression whose type this node takes
  */
case class LetDef(fd: FunDef, body: Expr) extends Expr {
  val getType = body.getType
}
/** Invocation of a typed function definition; typed as that function's return type.
  *
  * @param tfd  the typed function being invoked
  * @param args the argument expressions
  */
case class FunctionInvocation(tfd: TypedFunDef, args: Seq[Expr]) extends Expr {
  val getType = tfd.returnType
}
/**
  * OO Trees
  *
  * Both MethodInvocation and This get removed by phase MethodLifting. Methods become functions,
  * This becomes the first argument, and MethodInvocation becomes FunctionInvocation.
  *
  * @param rec  the receiver expression
  * @param cd   the class definition whose type parameters are instantiated
  * @param tfd  the typed function being invoked
  * @param args the argument expressions
  */
case class MethodInvocation(rec: Expr, cd: ClassDef, tfd: TypedFunDef, args: Seq[Expr]) extends Expr {
  val getType = {
    // We need to instantiate the type based on the type of the function as well as the receiver.
    val declaredReturn = tfd.returnType
    // Pair the class's type parameters with the receiver's type arguments
    // when the receiver has a class type; otherwise nothing extra to map.
    val receiverMapping: Map[TypeParameterDef, TypeTree] = rec.getType match {
      case ct: ClassType => (cd.tparams zip ct.tps).toMap
      case _ => Map()
    }
    instantiateType(declaredReturn, receiverMapping)
  }
}
/** Application of a callee expression to arguments.
  *
  * Construction fails (via `require`) unless the callee's type is a
  * `FunctionType`; the node is typed as that function type's result type.
  *
  * @param callee the expression being applied
  * @param args   the argument expressions
  */
case class Application(callee: Expr, args: Seq[Expr]) extends Expr {
  require(callee.getType.isInstanceOf[FunctionType])

  val getType = {
    // The cast cannot fail here: the requirement above already checked it.
    val functionType = callee.getType.asInstanceOf[FunctionType]
    functionType.to
  }
}
/** Anonymous function with parameters `args` and a body.
  *
  * Typed as a `FunctionType` from the parameter types to the body's type,
  * passed through `unveilUntyped`.
  *
  * @param args the parameter definitions
  * @param body the function body
  */
case class Lambda(args: Seq[ValDef], body: Expr) extends Expr {
  val getType = FunctionType(args.map(_.getType), body.getType).unveilUntyped

  /** Maps each parameter identifier to the corresponding actual argument.
    * Requires as many actual arguments as there are parameters.
    */
  def paramSubst(realArgs: Seq[Expr]) = {
    require(realArgs.size == args.size)
    val paramIds = args.map(_.id)
    paramIds.zip(realArgs).toMap
  }

  /** Applies the substitution from `paramSubst(realArgs)` to `e` using `replaceFromIDs`. */
  def withParamSubst(realArgs: Seq[Expr], e: Expr) = {
    val substitution = paramSubst(realArgs)
    replaceFromIDs(substitution, e)
  }
}
/** Terminal expression typed as the given class type.
  *
  * @param ct the class type reported by `getType`
  */
case class This(ct: ClassType) extends Expr with Terminal {
  val getType = ct
}
/** Conditional expression.
  *
  * Typed as the least upper bound of the two branch types, or `Untyped`
  * when no such bound exists.
  *
  * @param cond  the condition
  * @param thenn the first branch
  * @param elze  the second branch
  */
case class IfExpr(cond: Expr, thenn: Expr, elze: Expr) extends Expr {
  val getType = {
    val branchBound = leastUpperBound(thenn.getType, elze.getType)
    branchBound.getOrElse(Untyped).unveilUntyped
  }
}
/** Tuple of at least two expressions.
  *
  * If you are not sure about the requirement you should use
  * the tupleWrap in purescala.Constructors
  *
  * @param exprs the tuple components; construction fails with fewer than two
  */
case class Tuple(exprs: Seq[Expr]) extends Expr {
  require(exprs.size > 1)
  val getType = TupleType(exprs.map(_.getType)).unveilUntyped
}
/** Selection of one component of a tuple.
  *
  * Index is 1-based, first element of tuple is 1.
  * If you are not sure that tuple has a TupleType,
  * you should use tupleSelect in purescala.Constructors
  *
  * @param tuple the expression being projected
  * @param index 1-based component index; construction fails if it is below 1,
  *              or above the arity when `tuple` has a tuple type
  */
case class TupleSelect(tuple: Expr, index: Int) extends Expr {
  require(index > 0)

  val getType = tuple.getType match {
    case TupleType(componentTypes) =>
      require(index <= componentTypes.size)
      // Convert the 1-based index to the 0-based position in the type list.
      componentTypes(index - 1)
    case _ =>
      Untyped
  }
}
/** Pattern match of a scrutinee against a non-empty list of cases.
  *
  * Typed as the least upper bound of the cases' right-hand-side types,
  * or `Untyped` when no such bound exists.
  *
  * @param scrutinee the expression being matched
  * @param cases     the match cases; construction fails if empty
  */
case class MatchExpr(scrutinee: Expr, cases: Seq[MatchCase]) extends Expr {
  require(cases.nonEmpty)

  val getType = {
    val rhsTypes = cases.map(_.rhs.getType)
    leastUpperBound(rhsTypes).getOrElse(Untyped).unveilUntyped
  }
}
/** Boolean-typed node relating `in` and `out` through a non-empty list of cases.
  *
  * @param in    the expression matched against `cases`
  * @param out   the expression compared with the match result
  * @param cases the match cases; construction fails if empty
  */
case class Passes(in: Expr, out: Expr, cases: Seq[MatchCase]) extends Expr {
  require(cases.nonEmpty)
  val getType = BooleanType

  /** Expresses this node as `out == (in match { cases; case _ => out })`. */
  def asConstraint = {
    // A trailing wildcard case yielding `out` is appended to the given cases.
    val fallback = SimpleCase(WildcardPattern(None), out)
    val matched = MatchExpr(in, cases :+ fallback)
    Equals(out, matched)
  }
}
/** One case of a match: a pattern, an optional guard and a right-hand side.
  *
  * @param pattern  the pattern of this case
  * @param optGuard the optional guard expression
  * @param rhs      the right-hand side expression
  */
case class MatchCase(pattern: Pattern, optGuard: Option[Expr], rhs: Expr) extends Tree {
  /** The guard (if present) followed by the right-hand side. */
  def expressions: Seq[Expr] = optGuard.toList ::: List(rhs)
}
/** Base class of match patterns: each has sub-patterns and an optional binder. */
sealed abstract class Pattern extends Tree {
  /** Directly nested patterns. */
  val subPatterns: Seq[Pattern]

  /** Identifier bound by this pattern itself, if any. */
  val binder: Option[Identifier]

  /** Every identifier bound by this pattern or, recursively, by its sub-patterns. */
  def binders: Set[Identifier] = {
    val nested = subPatterns.flatMap(_.binders).toSet
    nested ++ binder.toSet
  }

  /** Returns a pattern bound to `b` when this pattern has no binder yet;
    * a pattern that already has a binder is returned as is. In both cases
    * `copiedFrom(this)` is invoked on the result.
    */
  def withBinder(b: Identifier) = {
    val rebound = this match {
      case Pattern(None, subs, builder) => builder(Some(b), subs)
      case other => other
    }
    rebound.copiedFrom(this)
  }
}
/** Pattern of the form `c: Class`; it has no sub-patterns.
  *
  * @param binder optional identifier bound by the pattern
  * @param ct     the class type of the pattern
  */
case class InstanceOfPattern(binder: Option[Identifier], ct: ClassType) extends Pattern {
  val subPatterns = Seq.empty
}
/** Pattern of the form `c @ _`; it has no sub-patterns.
  *
  * @param binder optional identifier bound by the pattern
  */
case class WildcardPattern(binder: Option[Identifier]) extends Pattern {
  val subPatterns = Seq.empty
}
/** Pattern over a case class type with one sub-pattern list.
  *
  * @param binder      optional identifier bound by the pattern
  * @param ct          the case class type of the pattern
  * @param subPatterns the nested patterns
  */
case class CaseClassPattern(
  binder: Option[Identifier],
  ct: CaseClassType,
  subPatterns: Seq[Pattern]
) extends Pattern
/** Pattern made of a list of sub-patterns.
  *
  * @param binder      optional identifier bound by the pattern
  * @param subPatterns the nested patterns
  */
case class TuplePattern(
  binder: Option[Identifier],
  subPatterns: Seq[Pattern]
) extends Pattern
/** Pattern carrying a literal; it has no sub-patterns.
  *
  * @param binder optional identifier bound by the pattern
  * @param lit    the literal carried by the pattern
  */
case class LiteralPattern[+T](binder: Option[Identifier], lit: Literal[T]) extends Pattern {
  val subPatterns = Seq.empty
}
/* Propositional logic */
/** Boolean-typed conjunction node over at least two operands.
  *
  * @param exprs the operands; construction fails with fewer than two
  */
case class And(exprs: Seq[Expr]) extends Expr {
  val getType = BooleanType
  require(exprs.size > 1)
}
/** Companion adding a two-operand constructor for [[And]]. */
object And {
  /** Builds `And` from exactly two operands, in the given order. */
  def apply(a: Expr, b: Expr): Expr = {
    val operands = Seq(a, b)
    And(operands)
  }
}
/** Boolean-typed disjunction node over at least two operands.
  *
  * @param exprs the operands; construction fails with fewer than two
  */
case class Or(exprs: Seq[Expr]) extends Expr {
  val getType = BooleanType
  require(exprs.size > 1)
}
/** Companion adding a two-operand constructor for [[Or]]. */
object Or {
  /** Builds `Or` from exactly two operands, in the given order. */
  def apply(a: Expr, b: Expr): Expr = {
    val operands = Seq(a, b)
    Or(operands)
  }
}
/** Boolean-typed implication node; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class Implies(lhs: Expr, rhs: Expr) extends Expr {
  val getType = BooleanType
}
/** Boolean-typed negation node; the operand type is not checked here.
  *
  * @param expr the operand
  */
case class Not(expr: Expr) extends Expr {
  val getType = BooleanType
}
/** Boolean-typed equality node; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class Equals(lhs: Expr, rhs: Expr) extends Expr {
  val getType = BooleanType
}
/** Reference to an identifier; typed as the identifier's own type.
  *
  * @param id the identifier being referenced
  */
case class Variable(id: Identifier) extends Expr with Terminal {
  val getType = id.getType
}
/* Literals */
/** Base class of literal expressions carrying a value of type `T`. */
sealed abstract class Literal[+T] extends Expr with Terminal {
  /** The value carried by the literal. */
  val value: T
}
/** Terminal expression typed as the given type parameter.
  *
  * @param tp the type parameter reported by `getType`
  * @param id integer tag carried by this value
  */
case class GenericValue(tp: TypeParameter, id: Int) extends Expr with Terminal {
  val getType = tp
}
/** Character literal, typed `CharType`.
  *
  * @param value the character carried by the literal
  */
case class CharLiteral(value: Char) extends Literal[Char] {
  val getType = CharType
}
/** `Int` literal, typed `Int32Type`.
  *
  * @param value the integer carried by the literal
  */
case class IntLiteral(value: Int) extends Literal[Int] {
  val getType = Int32Type
}
/** `BigInt` literal, typed `IntegerType`.
  *
  * @param value the integer carried by the literal
  */
case class InfiniteIntegerLiteral(value: BigInt) extends Literal[BigInt] {
  val getType = IntegerType
}
/** Boolean literal, typed `BooleanType`.
  *
  * @param value the boolean carried by the literal
  */
case class BooleanLiteral(value: Boolean) extends Literal[Boolean] {
  val getType = BooleanType
}
/** The unit literal, typed `UnitType`; its value is always `()`. */
case class UnitLiteral() extends Literal[Unit] {
  val getType = UnitType
  val value = ()
}
/* Case classes */
/** Case class construction node; typed as the given case class type.
  *
  * @param ct   the case class type reported by `getType`
  * @param args the argument expressions
  */
case class CaseClass(ct: CaseClassType, args: Seq[Expr]) extends Expr {
  val getType = ct
}
/** Boolean-typed node pairing a class type with an expression.
  *
  * @param classType the class type being tested for
  * @param expr      the expression being tested
  */
case class IsInstanceOf(classType: ClassType, expr: Expr) extends Expr {
  val getType = BooleanType
}
/** Selection of a field from a case class value.
  *
  * @param classType the case class type declaring the field
  * @param caseClass the expression the field is selected from
  * @param selector  the identifier of the selected field
  */
case class CaseClassSelector(
  classType: CaseClassType,
  caseClass: Expr,
  selector: Identifier
) extends Expr {
  /** Index of `selector`, as given by the class definition's `selectorID2Index`. */
  val selectorIndex = classType.classDef.selectorID2Index(selector)

  // The node's type is the field type found at that index.
  val getType = classType.fieldsTypes(selectorIndex)
}
/* Arithmetic */
/** Addition node over `IntegerType` operands; typed `IntegerType`.
  *
  * Construction fails unless both operands have type `IntegerType`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class Plus(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == IntegerType))
  val getType = IntegerType
}
/** Subtraction node over `IntegerType` operands; typed `IntegerType`.
  *
  * Construction fails unless both operands have type `IntegerType`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class Minus(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == IntegerType))
  val getType = IntegerType
}
/** Unary minus node over an `IntegerType` operand; typed `IntegerType`.
  *
  * Construction fails unless the operand has type `IntegerType`.
  *
  * @param expr the operand
  */
case class UMinus(expr: Expr) extends Expr {
  require(expr.getType == IntegerType)
  val getType = IntegerType
}
/** Multiplication node over `IntegerType` operands; typed `IntegerType`.
  *
  * Construction fails unless both operands have type `IntegerType`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class Times(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == IntegerType))
  val getType = IntegerType
}
/** Division and Remainder follow Java/Scala semantics. Division corresponds
  * to the / operator on BigInt and Remainder corresponds to %. Note that in
  * Java/Scala % is called remainder, and the "mod" operator (Modulo in Leon) is also
  * defined on BigInteger and differs from Remainder. The "mod" operator
  * returns an always positive remainder, while Remainder could return
  * a negative remainder. The following must hold:
  *
  *   Division(x, y) * y + Remainder(x, y) == x
  *
  * Construction of a Division fails unless both operands have type `IntegerType`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class Division(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == IntegerType))
  val getType = IntegerType
}
/** Remainder node over `IntegerType` operands; typed `IntegerType`.
  *
  * Construction fails unless both operands have type `IntegerType`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class Remainder(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == IntegerType))
  val getType = IntegerType
}
/** Modulo node over `IntegerType` operands; typed `IntegerType`.
  *
  * Construction fails unless both operands have type `IntegerType`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class Modulo(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == IntegerType))
  val getType = IntegerType
}
/** Boolean-typed less-than comparison node; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class LessThan(lhs: Expr, rhs: Expr) extends Expr {
  val getType = BooleanType
}
/** Boolean-typed greater-than comparison node; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class GreaterThan(lhs: Expr, rhs: Expr) extends Expr {
  val getType = BooleanType
}
/** Boolean-typed less-or-equal comparison node; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class LessEquals(lhs: Expr, rhs: Expr) extends Expr {
  val getType = BooleanType
}
/** Boolean-typed greater-or-equal comparison node; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class GreaterEquals(lhs: Expr, rhs: Expr) extends Expr {
  val getType = BooleanType
}
/* Bit-vector arithmetic */
/** Bit-vector addition node over `Int32Type` operands; typed `Int32Type`.
  *
  * Construction fails unless both operands have type `Int32Type`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVPlus(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == Int32Type))
  val getType = Int32Type
}
/** Bit-vector subtraction node over `Int32Type` operands; typed `Int32Type`.
  *
  * Construction fails unless both operands have type `Int32Type`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVMinus(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == Int32Type))
  val getType = Int32Type
}
/** Bit-vector unary minus node over an `Int32Type` operand; typed `Int32Type`.
  *
  * Construction fails unless the operand has type `Int32Type`.
  *
  * @param expr the operand
  */
case class BVUMinus(expr: Expr) extends Expr {
  require(expr.getType == Int32Type)
  val getType = Int32Type
}
/** Bit-vector multiplication node over `Int32Type` operands; typed `Int32Type`.
  *
  * Construction fails unless both operands have type `Int32Type`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVTimes(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == Int32Type))
  val getType = Int32Type
}
/** Bit-vector division node over `Int32Type` operands; typed `Int32Type`.
  *
  * Construction fails unless both operands have type `Int32Type`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVDivision(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == Int32Type))
  val getType = Int32Type
}
/** Bit-vector remainder node over `Int32Type` operands; typed `Int32Type`.
  *
  * Construction fails unless both operands have type `Int32Type`.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVRemainder(lhs: Expr, rhs: Expr) extends Expr {
  require(Seq(lhs, rhs).forall(_.getType == Int32Type))
  val getType = Int32Type
}
/** Bit-vector "not" node, typed `Int32Type`; the operand type is not checked here.
  *
  * @param expr the operand
  */
case class BVNot(expr: Expr) extends Expr {
  val getType = Int32Type
}
/** Bit-vector "and" node, typed `Int32Type`; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVAnd(lhs: Expr, rhs: Expr) extends Expr {
  val getType = Int32Type
}
/** Bit-vector "or" node, typed `Int32Type`; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVOr(lhs: Expr, rhs: Expr) extends Expr {
  val getType = Int32Type
}
/** Bit-vector "xor" node, typed `Int32Type`; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVXOr(lhs: Expr, rhs: Expr) extends Expr {
  val getType = Int32Type
}
/** Bit-vector shift-left node, typed `Int32Type`; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVShiftLeft(lhs: Expr, rhs: Expr) extends Expr {
  val getType = Int32Type
}
/** Bit-vector shift-right node (presumably the arithmetic variant, going by
  * the name), typed `Int32Type`; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVAShiftRight(lhs: Expr, rhs: Expr) extends Expr {
  val getType = Int32Type
}
/** Bit-vector shift-right node (presumably the logical variant, going by
  * the name), typed `Int32Type`; operand types are not checked here.
  *
  * @param lhs the left operand
  * @param rhs the right operand
  */
case class BVLShiftRight(lhs: Expr, rhs: Expr) extends Expr {
  val getType = Int32Type
}
/* Set expressions */
/** Finite set of expressions with an explicit element type.
  *
  * Typed as `SetType(base)`, passed through `unveilUntyped`.
  *
  * @param elements the set's elements
  * @param base     the element type
  */
case class FiniteSet(elements: Set[Expr], base: TypeTree) extends Expr {
  val getType = SetType(base).unveilUntyped
}
/** Boolean-typed set membership node; operand types are not checked here.
  *
  * @param element the candidate element
  * @param set     the set expression
  */
case class ElementOfSet(element: Expr, set: Expr) extends Expr {
  val getType = BooleanType
}
/** Set cardinality node, typed `Int32Type`; the operand type is not checked here.
  *
  * @param set the set expression
  */
case class SetCardinality(set: Expr) extends Expr {
  val getType = Int32Type
}
/** Boolean-typed subset node; operand types are not checked here.
  *
  * @param set1 the first set expression
  * @param set2 the second set expression
  */
case class SubsetOf(set1: Expr, set2: Expr) extends Expr {
  val getType = BooleanType
}
/** Set intersection node.
  *
  * Typed as the least upper bound of the two operand types, or `Untyped`
  * when no such bound exists.
  *
  * @param set1 the first set expression
  * @param set2 the second set expression
  */
case class SetIntersection(set1: Expr, set2: Expr) extends Expr {
  val getType = {
    val operandTypes = Seq(set1.getType, set2.getType)
    leastUpperBound(operandTypes).getOrElse(Untyped).unveilUntyped
  }
}
/** Set union node.
  *
  * Typed as the least upper bound of the two operand types, or `Untyped`
  * when no such bound exists.
  *
  * @param set1 the first set expression
  * @param set2 the second set expression
  */
case class SetUnion(set1: Expr, set2: Expr) extends Expr {
  val getType = {
    val operandTypes = Seq(set1.getType, set2.getType)
    leastUpperBound(operandTypes).getOrElse(Untyped).unveilUntyped
  }
}
/** Set difference node.
  *
  * Typed as the least upper bound of the two operand types, or `Untyped`
  * when no such bound exists.
  *
  * @param set1 the first set expression
  * @param set2 the second set expression
  */
case class SetDifference(set1: Expr, set2: Expr) extends Expr {
  val getType = {
    val operandTypes = Seq(set1.getType, set2.getType)
    leastUpperBound(operandTypes).getOrElse(Untyped).unveilUntyped
  }
}
/* Map operations. */
/** Finite map given as key/value expression pairs with explicit key and value types.
  *
  * Typed as `MapType(keyType, valueType)`, passed through `unveilUntyped`.
  *
  * @param singletons the key/value pairs
  * @param keyType    the key type
  * @param valueType  the value type
  */
case class FiniteMap(
  singletons: Seq[(Expr, Expr)],
  keyType: TypeTree,
  valueType: TypeTree
) extends Expr {
  val getType = MapType(keyType, valueType).unveilUntyped
}
/** Map lookup node.
  *
  * Typed as the value type of `map` when `map` has a `MapType`,
  * otherwise `Untyped`.
  *
  * @param map the map expression
  * @param key the key expression
  */
case class MapGet(map: Expr, key: Expr) extends Expr {
  val getType = map.getType match {
    case MapType(_, valueType) => valueType
    case _ => Untyped
  }
}
/** Map union node.
  *
  * Typed as the least upper bound of the two operand types, or `Untyped`
  * when no such bound exists.
  *
  * @param map1 the first map expression
  * @param map2 the second map expression
  */
case class MapUnion(map1: Expr, map2: Expr) extends Expr {
  val getType = {
    val operandTypes = Seq(map1.getType, map2.getType)
    leastUpperBound(operandTypes).getOrElse(Untyped).unveilUntyped
  }
}
/** Map difference node; typed exactly as `map`.
  *
  * @param map  the map expression
  * @param keys the keys expression
  */
case class MapDifference(map: Expr, keys: Expr) extends Expr {
  val getType = map.getType
}
/** Boolean-typed node pairing a map with a key; operand types are not checked here.
  *
  * @param map the map expression
  * @param key the key expression
  */
case class MapIsDefinedAt(map: Expr, key: Expr) extends Expr {
  val getType = BooleanType
}
/* Array operations */
/** Array element selection node.
  *
  * Typed as the element type of `array` when `array` has an `ArrayType`,
  * otherwise `Untyped`.
  *
  * @param array the array expression
  * @param index the index expression
  */
case class ArraySelect(array: Expr, index: Expr) extends Expr {
  val getType = array.getType match {
    case ArrayType(elementType) => elementType
    case _ => Untyped
  }
}
/** Array update node.
  *
  * When `array` has an `ArrayType`, the node is typed as an array of the
  * least upper bound of the element type and the new value's type (or
  * `Untyped` when no such bound exists); otherwise it is `Untyped`.
  *
  * @param array    the array expression
  * @param index    the index expression
  * @param newValue the value expression placed at `index`
  */
case class ArrayUpdated(array: Expr, index: Expr, newValue: Expr) extends Expr {
  val getType = array.getType match {
    case ArrayType(elementType) =>
      val widened = leastUpperBound(elementType, newValue.getType)
      widened.map(ArrayType).getOrElse(Untyped).unveilUntyped
    case _ =>
      Untyped
  }
}
/** Array length node, typed `Int32Type`; the operand type is not checked here.
  *
  * @param array the array expression
  */
case class ArrayLength(array: Expr) extends Expr {
  val getType = Int32Type
}
/** Array given by explicitly indexed elements plus an optional pair of expressions.
  *
  * Typed as an array of the least upper bound of the types of all listed
  * elements and of the first component of `defaultLength` when present.
  *
  * @param elems         expressions keyed by integer index
  * @param defaultLength optional pair of expressions; only the first component
  *                      contributes to the type (presumably the default
  *                      element and the length, going by the name)
  */
case class NonemptyArray(elems: Map[Int, Expr], defaultLength: Option[(Expr, Expr)]) extends Expr {
  // Every expression whose type takes part in the element type.
  private val elements = elems.values.toList ++ defaultLength.map(_._1)

  val getType = {
    val elementBound = leastUpperBound(elements.map(_.getType))
    ArrayType(optionToType(elementBound)).unveilUntyped
  }
}
/** Terminal array expression with an explicit element type.
  *
  * Typed as `ArrayType(tpe)`, passed through `unveilUntyped`.
  *
  * @param tpe the element type
  */
case class EmptyArray(tpe: TypeTree) extends Expr with Terminal {
  val getType = ArrayType(tpe).unveilUntyped
}
/* Special trees */
/** Provide an oracle (synthesizable, all-seeing choose).
  *
  * Typed as `body`.
  *
  * @param oracles the oracle identifiers; construction fails if empty
  * @param body    the expression whose type this node takes
  */
case class WithOracle(oracles: List[Identifier], body: Expr) extends Expr with Extractable {
  require(oracles.nonEmpty)
  val getType = body.getType

  /** Exposes `body` as the only sub-expression, with a builder that
    * recreates the node from the first expression it is given and sets
    * its position from this node.
    */
  def extract = {
    val rebuild = (es: Seq[Expr]) => WithOracle(oracles, es.head).setPos(this)
    Some((Seq(body), rebuild))
  }
}
/** Expression with an explicit type and a list of alternative expressions.
  *
  * @param tpe  the type reported by `getType`
  * @param alts the alternative expressions
  */
case class Hole(tpe: TypeTree, alts: Seq[Expr]) extends Expr with Extractable {
  val getType = tpe

  /** Exposes `alts` as the sub-expressions, with a builder that recreates
    * the node with the same type and sets its position from this node.
    */
  def extract = {
    val rebuild = (es: Seq[Expr]) => Hole(tpe, es).setPos(this)
    Some((alts, rebuild))
  }
}
}