Skip to content
Snippets Groups Projects
Commit 4ef6b4c5 authored by Ali Sinan Köksal's avatar Ali Sinan Köksal
Browse files

We can now combine constraints and invoke `solve' on them to get a Scala

value as solution
parent feca614a
Branches
Tags
No related merge requests found
import cp.Definitions._
import cp.Constraints._
object FirstClassConstraints {
def oneOf(lst : List[Int]) : Constraint1[Int] = lst match {
case Nil => (x : Int) => false
case c::cs => ((x : Int) => x == c) || oneOf(cs)
}
@spec object Specs {
abstract class A
case class B() extends A
case class C() extends A
}
def main(args: Array[String]) : Unit = {
val outer: Int = 42
val pred1 : Constraint1[Int] = (x : Int) => x > outer
val pred2 : Constraint1[Int] = (y : Int) => y == outer
val orPred = pred1 || pred2
val solution: Int = orPred.solve
println(solution)
}
}
......@@ -102,7 +102,7 @@ object RedBlackTree {
println("Fixing size of trees to " + (bound))
val sw = new Stopwatch("Fixed-size enumeration", false)
sw.start
for (tree <- findAll((t : Tree) => isRedBlackTree(t) && boundValues(t, bound - 1))) {
for (tree <- findAll((t : Tree) => isRedBlackTree(t) && boundValues(t, bound - 1) && size(t) == bound)) {
solutionSet += tree
}
sw.stop
......
......@@ -7,6 +7,7 @@ import purescala.Definitions._
import purescala.Trees._
import Serialization._
import Constraints._
trait CallTransformation
extends TypingTransformers
......@@ -36,13 +37,15 @@ trait CallTransformation
signatures.toSet
}
/** extract predicates beforehand so the stored last used ID value is valid */
def predicateMap(unit: CompilationUnit) : Map[Position,(FunDef,Option[Expr],Option[Expr])] = {
val extracted = scala.collection.mutable.Map[Position,(FunDef,Option[Expr],Option[Expr])]()
def extractPredicates(tree: Tree) = tree match {
case Apply(TypeApply(Select(s: Select, n), typeTreeList), List(predicate: Function)) if (cpDefinitionsModule == s.symbol &&
(n.toString == "choose" || n.toString == "find" || n.toString == "findAll")) =>
case Apply(TypeApply(Select(Select(cpIdent, definitionsName), pred2cons1Name), typeTreeList), List(predicate: Function)) if
(definitionsName.toString == "Definitions" && pred2cons1Name.toString.matches("pred2cons\\d")) => {
val Function(funValDefs, funBody) = predicate
extracted += (predicate.pos -> extractPredicate(unit, funValDefs, funBody))
}
case _ =>
}
new ForeachTreeTraverser(extractPredicates).traverse(unit.body)
......@@ -63,23 +66,27 @@ trait CallTransformation
override def transform(tree: Tree) : Tree = {
tree match {
case a @ Apply(TypeApply(Select(s: Select, n), typeTreeList), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => {
/** Transform implicit conversions to Constraint into instantiation of Constraints */
case Apply(TypeApply(Select(Select(cpIdent, definitionsName), pred2cons1Name), typeTreeList), List(predicate: Function)) if
(definitionsName.toString == "Definitions" && pred2cons1Name.toString.matches("pred2cons\\d")) => {
println("i'm in conversion from pred to constraint!")
val Function(funValDefs, funBody) = predicate
val (fd, minExpr, maxExpr) = extractedPredicates(predicate.pos)
val outputVars : Seq[Identifier] = fd.args.map(_.id)
purescalaReporter.info("Considering predicate:")
purescalaReporter.info(fd)
val codeGen = new CodeGenerator(unit, currentOwner, tree.pos)
fd.body match {
case None => purescalaReporter.error("Could not extract `choose' predicate: " + funBody); super.transform(tree)
case None => purescalaReporter.error("Could not extract predicate: " + funBody); super.transform(tree)
case Some(b) =>
// serialize expression
val serializedExpr = serialize(b)
// compute input variables
val inputVars : Seq[Identifier] = (variablesOf(b) ++ (minExpr match {
case Some(e) => variablesOf(e)
......@@ -98,35 +105,30 @@ trait CallTransformation
// serialize outputVars sequence
val serializedOutputVars = serialize(outputVars)
// input constraints
val inputConstraints : Seq[Tree] = (for (iv <- inputVars) yield {
codeGen.inputEquality(serializedInputVarList, iv, scalaToExprSym)
})
// sequence of input values
val inputVarValues : Tree = codeGen.inputVarValues(serializedInputVarList, inputVars, scalaToExprSym)
val inputConstraintsConjunction = if (inputVars.isEmpty) codeGen.trueLiteral else codeGen.andExpr(inputConstraints)
val exprSeqTree = (minExpr, maxExpr) match {
case (None, None) => {
codeGen.chooseExecCode(serializedProg, serializedExpr, serializedOutputVars, inputConstraintsConjunction)
}
case (Some(minE), None) => {
val serializedMinExpr = serialize(minE)
codeGen.chooseMinimizingExecCode(serializedProg, serializedExpr, serializedOutputVars, serializedMinExpr, inputConstraintsConjunction)
}
case (None, Some(maxE)) => {
val serializedMaxExpr = serialize(maxE)
codeGen.chooseMaximizingExecCode(serializedProg, serializedExpr, serializedOutputVars, serializedMaxExpr, inputConstraintsConjunction)
}
case _ =>
scala.Predef.error("Unreachable case")
}
// create constraint instance
val code = codeGen.newConstraint(exprToScalaSym, serializedProg, serializedInputVarList, serializedOutputVars, serializedExpr, inputVarValues, outputVars.size)
typer.typed(atOwner(currentOwner) {
exprSeqToScalaSyms(typeTreeList) APPLY exprSeqTree
code
})
}
}
case a @ Apply(TypeApply(Select(s: Select, n), typeTreeList), rhs @ List(constraint: Constraint)) if (cpDefinitionsModule == s.symbol && n.toString == "choose") => {
val codeGen = new CodeGenerator(unit, currentOwner, tree.pos)
val serializedConstraint = serialize(constraint)
val exprSeqTree = codeGen.chooseExecCode(serializedProg, serializedConstraint)
typer.typed(atOwner(currentOwner) {
exprSeqToScalaSyms(typeTreeList) APPLY exprSeqTree
})
}
case a @ Apply(TypeApply(Select(s: Select, n), typeTreeList), rhs @ List(predicate: Function)) if (cpDefinitionsModule == s.symbol && n.toString == "find") => {
val Function(funValDefs, funBody) = predicate
......
......@@ -44,6 +44,13 @@ trait CodeGeneration {
private lazy val skipCounterFunction = definitions.getMember(runtimeMethodsModule, "skipCounter")
private lazy val copySettingsFunction = definitions.getMember(runtimeMethodsModule, "copySettings")
private lazy val baseConstraintClasses = List(
definitions.getClass("cp.Constraints.BaseConstraint1"),
definitions.getClass("cp.Constraints.BaseConstraint2")
)
private lazy val converterClass = definitions.getClass("cp.Converter")
private lazy val serializationModule = definitions.getModule("cp.Serialization")
private lazy val getProgramFunction = definitions.getMember(serializationModule, "getProgram")
private lazy val getInputVarListFunction = definitions.getMember(serializationModule, "getInputVarList")
......@@ -90,8 +97,9 @@ trait CodeGeneration {
(newSerialized(serializedProg), newSerialized(serializedExpr), newSerialized(serializedOutputVars), newSerialized(serializedOptExpr), inputConstraints)
}
def chooseExecCode(serializedProg : Serialized, serializedExpr : Serialized, serializedOutputVars : Serialized, inputConstraints : Tree) : Tree = {
execCode(chooseExecFunction, serializedProg, serializedExpr, serializedOutputVars, inputConstraints)
def chooseExecCode(serializedProg : Serialized, serializedConstraint : Serialized) : Tree = {
(cpPackage DOT runtimeMethodsModule DOT chooseExecFunction) APPLY
(newSerialized(serializedProg), newSerialized(serializedConstraint))
}
def chooseMinimizingExecCode(serializedProg : Serialized, serializedExpr : Serialized, serializedOutputVars : Serialized, serializedMinExpr : Serialized, inputConstraints : Tree) : Tree = {
......@@ -307,6 +315,38 @@ trait CodeGeneration {
(DEF(methodSym) === matchExpr, methodSym)
}
def inputVarValues(serializedInputVarList : Serialized, inputVars : Seq[Identifier], scalaToExprSym : Symbol) : Tree = {
val inputVarTrees = inputVars.map((iv: Identifier) => scalaToExprSym APPLY variablesToTrees(Variable(iv))).toList
(scalaPackage DOT collectionModule DOT immutableModule DOT definitions.ListModule DOT listModuleApplyFunction) APPLY (inputVarTrees)
}
def newConstraint(exprToScalaSym : Symbol, serializedProg : Serialized, serializedInputVarList : Serialized, serializedOutputVars : Serialized, serializedExpr : Serialized, inputVarValues : Tree, arity : Int) : Tree = {
NEW(
ID(baseConstraintClasses(arity-1)),
newConverter(exprToScalaSym),
newSerialized(serializedProg),
newSerialized(serializedInputVarList),
newSerialized(serializedOutputVars),
newSerialized(serializedExpr),
inputVarValues
)
}
def newConverter(exprToScalaSym : Symbol) : Tree = {
val anonFunSym = owner.newValue(NoPosition, nme.ANON_FUN_NAME) setInfo (exprToScalaSym.tpe)
val argValue = anonFunSym.newValue(NoPosition, unit.fresh.newName(NoPosition, "x")) setInfo (exprClass.tpe)
val anonFun = Function(
List(ValDef(argValue, EmptyTree)),
exprToScalaSym APPLY ID(argValue)
) setSymbol anonFunSym
NEW(
ID(converterClass),
anonFun
)
}
def inputEquality(serializedInputVarList : Serialized, varId : Identifier, scalaToExprSym : Symbol) : Tree = {
NEW(
ID(equalsClass),
......
package cp
import Serialization._
import purescala.Definitions._
import purescala.Trees._
import purescala.TypeTrees._
import purescala.Common._
import purescala.{QuietReporter,DefaultReporter}
import purescala.FairZ3Solver
import Definitions.{UnsatisfiableConstraintException,UnknownConstraintException}
object Constraints {
final class NotImplementedException extends Exception
private val silent = true
private def newReporter() = if (silent) new QuietReporter() else new DefaultReporter()
private def newSolver() = new FairZ3Solver(newReporter())
sealed trait Constraint
private def modelValue(varId: Identifier, model: Map[Identifier, Expr]) : Expr = model.get(varId) match {
case Some(value) => value
case None => simplestValue(varId.getType)
}
def programOf(constraint : Constraint) : Program = constraint match {
case bc : BaseConstraint => bc.program
case NAryConstraint(cs) => programOf(cs.head)
}
def exprOf(constraint : Constraint) : Expr = constraint match {
case bc : BaseConstraint => bc.exprWithIndices
case OrConstraint(cs) => Or(cs map exprOf)
}
def typesOf(constraint : Constraint) : Seq[TypeTree] = constraint match {
case bc : BaseConstraint => bc.outputVars.map(_.getType)
case NAryConstraint(cs) => typesOf(cs.head)
}
def envOf(constraint : Constraint) : Map[Variable,Expr] = constraint match {
case bc : BaseConstraint => bc.env
case NAryConstraint(cs) =>
val allEnvs = cs map (envOf(_))
val distinctEnvs = allEnvs.distinct
if (distinctEnvs.size > 1) {
throw new Exception("Environments differ in constraint: \n" + distinctEnvs.mkString("\n"))
}
allEnvs(0)
}
def converterOf(constraint : Constraint) : Converter = constraint match {
case bc : BaseConstraint => bc.converter
case NAryConstraint(cs) => converterOf(cs.head)
}
private def exprSeqSolution(constraint : Constraint) : Seq[Expr] = {
val solver = newSolver()
val program = programOf(constraint)
solver.setProgram(program)
// println("My program is")
// println(program)
val expr = exprOf(constraint)
// println("My expr is")
// println(expr)
val outputVarTypes = typesOf(constraint)
val freshOutputIDs = outputVarTypes.zipWithIndex.map{ case (t, idx) => FreshIdentifier("x" + idx).setType(t) }
val deBruijnIndices = outputVarTypes.zipWithIndex.map{ case (t, idx) => DeBruijnIndex(idx).setType(t) }
val exprWithFreshIDs = replace((deBruijnIndices zip (freshOutputIDs map (Variable(_)))).toMap, expr)
// println("Expr with fresh IDs")
// println(exprWithFreshIDs)
val env = envOf(constraint)
// println("Environment")
// println(env)
val inputConstraints = if (env.isEmpty) BooleanLiteral(true) else And(env.map{ case (v, e) => Equals(v, e) }.toSeq)
val (outcome, model) = solver.restartAndDecideWithModel(And(exprWithFreshIDs, inputConstraints), false)
val exprSeq = outcome match {
case Some(false) =>
freshOutputIDs.map(id => modelValue(id, model))
case Some(true) =>
throw new UnsatisfiableConstraintException()
case None =>
throw new UnknownConstraintException()
}
// println("Solution!")
// println(exprSeq)
exprSeq
}
sealed trait Constraint1[A] extends Constraint {
def solve : A = {
val convertingFunction = converterOf(this).exprSeq2scala1[A] _
convertingFunction(exprSeqSolution(this))
}
def ||(other : Constraint1[A]) : Constraint1[A] = OrConstraint1[A](this, other)
}
sealed trait Constraint2[A,B] extends Constraint {
def ||(other : Constraint2[A,B]) : Constraint2[A,B] = OrConstraint2[A,B](this, other)
}
abstract class BaseConstraint(conv : Converter, serializedProg : Serialized, serializedInputVars : Serialized, serializedOutputVars : Serialized, serializedExpr : Serialized, inputVarValues : Seq[Expr])
extends Constraint {
val converter : Converter = conv
lazy val program : Program = deserialize[Program](serializedProg)
lazy val inputVars : Seq[Variable] = deserialize[Seq[Variable]](serializedInputVars)
lazy val outputVars : Seq[Identifier] = deserialize[Seq[Identifier]](serializedOutputVars)
lazy val expr : Expr = deserialize[Expr](serializedExpr)
lazy val env : Map[Variable,Expr] = (inputVars zip inputVarValues).toMap
lazy val deBruijnIndices: Seq[DeBruijnIndex] = outputVars.zipWithIndex.map{ case (v, idx) => DeBruijnIndex(idx).setType(v.getType) }
lazy val exprWithIndices: Expr = replace(((outputVars map (Variable(_))) zip deBruijnIndices).toMap, expr)
}
case class BaseConstraint1[A](conv : Converter, serializedProg : Serialized, serializedInputVars : Serialized, serializedOutputVars : Serialized, serializedExpr : Serialized, inputVarValues : Seq[Expr])
extends BaseConstraint(conv, serializedProg, serializedInputVars, serializedOutputVars, serializedExpr, inputVarValues) with Constraint1[A]
case class BaseConstraint2[A,B](conv : Converter, serializedProg : Serialized, serializedInputVars : Serialized, serializedOutputVars : Serialized, serializedExpr : Serialized, inputVarValues : Seq[Expr])
extends BaseConstraint(conv, serializedProg, serializedInputVars, serializedOutputVars, serializedExpr, inputVarValues) with Constraint2[A,B]
class OrConstraint1[A](val exprs : Seq[Constraint1[A]]) extends Constraint1[A]
object OrConstraint1 {
def apply[A](l : Constraint1[A], r : Constraint1[A]) : Constraint1[A] = {
new OrConstraint1[A](Seq(l,r))
}
def unapply[A](or : OrConstraint1[A]) : Option[Seq[Constraint1[A]]] =
if (or == null) None else Some(or.exprs)
}
class OrConstraint2[A,B](val exprs : Seq[Constraint2[A,B]]) extends Constraint2[A,B]
object OrConstraint2 {
def apply[A,B](l : Constraint2[A,B], r : Constraint2[A,B]) : Constraint2[A,B] = {
new OrConstraint2[A,B](Seq(l,r))
}
def unapply[A,B](or : OrConstraint2[A,B]) : Option[Seq[Constraint2[A,B]]] =
if (or == null) None else Some(or.exprs)
}
/** Extractor for or constraints of any type signature */
object OrConstraint {
def unapply(constraint : Constraint) : Option[Seq[Constraint]] = constraint match {
case OrConstraint1(exprs) => Some(exprs)
case OrConstraint2(exprs) => Some(exprs)
case _ => None
}
}
/** Extractor for NAry constraints of any type signature */
object NAryConstraint {
def unapply(constraint : Constraint) : Option[Seq[Constraint]] = constraint match {
case OrConstraint(exprs) => Some(exprs)
case _ => None
}
}
}
package cp
import purescala.Trees._
class Converter(expr2scala : (Expr => Any)) {
def exprSeq2scala1[A](exprs: Seq[Expr]) : A =
expr2scala(exprs(0)).asInstanceOf[A]
def exprSeq2scala2[A,B](exprs: Seq[Expr]) : (A,B) =
(expr2scala(exprs(0)).asInstanceOf[A], expr2scala(exprs(1)).asInstanceOf[B])
}
package cp
object Definitions {
import Trees._
import Constraints._
class spec extends StaticAnnotation
......@@ -22,13 +22,14 @@ object Definitions {
implicit def any2Optimizable(x : Boolean) : Optimizable = new Optimizable(x)
implicit def pred2cons1[A](pred: A => Boolean) : Constraint1[A] = throw new NotImplementedException
implicit def pred2cons1[A](pred : A => Boolean) : Constraint1[A] = throw new NotImplementedException
implicit def pred2cons2[A,B](pred : (A,B) => Boolean) : Constraint2[A,B] = throw new NotImplementedException
def choose[A](constraint: Constraint1[A]) : A = {
def choose[A](constraint : Constraint1[A]) : A = {
throw new NotImplementedException()
}
def choose[A,B](pred : (A,B) => Boolean) : (A,B) = {
def choose[A,B](constraint : Constraint2[A,B]) : (A,B) = {
throw new NotImplementedException()
}
......@@ -145,7 +146,7 @@ object Definitions {
}
def distinct[A](args: A*) : Boolean = {
throw new NotImplementedException()
args.toList.distinct.size == args.size
}
}
......@@ -3,6 +3,7 @@ package cp
/** A collection of methods that are called on runtime */
object RuntimeMethods {
import Serialization._
import Constraints._
import Definitions.UnsatisfiableConstraintException
import Definitions.UnknownConstraintException
import purescala.Definitions._
......@@ -17,18 +18,19 @@ object RuntimeMethods {
private def newReporter() = if (silent) new QuietReporter() else new DefaultReporter()
private def newSolver() = new FairZ3Solver(newReporter())
def chooseExec(serializedProg : Serialized, serializedExpr : Serialized, serializedOutputVars : Serialized, inputConstraints : Expr) : Seq[Expr] = {
def chooseExec(serializedProg : Serialized, serializedConstraint : Serialized) : Seq[Expr] = {
val program = deserialize[Program](serializedProg)
val expr = deserialize[Expr](serializedExpr)
val outputVars = deserialize[Seq[Identifier]](serializedOutputVars)
val constraint = deserialize[Constraint](serializedConstraint)
chooseExec(program, expr, outputVars, inputConstraints)
chooseExec(program, constraint)
}
private def chooseExec(program : Program, expr : Expr, outputVars : Seq[Identifier], inputConstraints : Expr) : Seq[Expr] = {
private def chooseExec(program : Program, constraint : Constraint) : Seq[Expr] = {
val solver = newSolver()
solver.setProgram(program)
throw new Exception("not implemented")
/*
val toCheck = expr :: inputConstraints :: Nil
val (outcome, model) = solver.restartAndDecideWithModel(And(toCheck), false)
......@@ -40,6 +42,7 @@ object RuntimeMethods {
case None =>
throw new UnknownConstraintException()
}
*/
}
def chooseMinimizingExec(serializedProg : Serialized, serializedExpr : Serialized, serializedOutputVars : Serialized, serializedMinExpr : Serialized, inputConstraints : Expr) : Seq[Expr] = {
......@@ -210,7 +213,10 @@ object RuntimeMethods {
def findExec(serializedProg : Serialized, serializedExpr : Serialized, serializedOutputVars : Serialized, inputConstraints : Expr) : Option[Seq[Expr]] = {
try {
/*
Some(chooseExec(serializedProg, serializedExpr, serializedOutputVars, inputConstraints))
*/
throw new Exception("not implemented")
} catch {
case e: UnsatisfiableConstraintException => None
case e: UnknownConstraintException => None
......
package cp
import Serialization._
import purescala.Trees._
object Trees {
final class NotImplementedException extends Exception
abstract class Constraint1[A] {
def ||(other: Constraint1[A]): Constraint1[A] = OrConstraint1[A](this, other)
}
case class BaseConstraint1[A](serializedInputVars : Serialized, serializedOutputVars : Serialized, serializedExpr : Serialized, inputVarValues : Seq[Expr])
object OrConstraint1 {
def apply[A](l: Constraint1[A], r: Constraint1[A]): Constraint1[A] = {
new OrConstraint1[A](Seq(l,r))
}
}
class OrConstraint1[A](exprs: Seq[Constraint1[A]]) extends Constraint1[A]
}
......@@ -67,6 +67,7 @@ object PrettyPrinter {
private def pp(tree: Expr, sb: StringBuffer, lvl: Int): StringBuffer = tree match {
case Variable(id) => sb.append(id)
case DeBruijnIndex(idx) => sb.append("_" + idx)
case Let(b,d,e) => {
pp(e, pp(d, sb.append("(let (" + b + " := "), lvl).append(") in "), lvl).append(")")
}
......
......@@ -738,9 +738,19 @@ object Trees {
case i @ IfExpr(a1,a2,a3) => allIdentifiers(a1) ++ allIdentifiers(a2) ++ allIdentifiers(a3)
case m @ MatchExpr(scrut, cses) =>
(cses map (_.allIdentifiers)).foldLeft(Set[Identifier]())((a, b) => a ++ b) ++ allIdentifiers(scrut)
case Variable(id) => Set(id)
case t: Terminal => Set.empty
}
def allDeBruijnIndices(expr: Expr) : Set[DeBruijnIndex] = {
def convert(t: Expr) : Set[DeBruijnIndex] = t match {
case i @ DeBruijnIndex(idx) => Set(i)
case _ => Set.empty
}
def combine(s1: Set[DeBruijnIndex], s2: Set[DeBruijnIndex]) = s1 ++ s2
treeCatamorphism(convert, combine, expr)
}
/* Simplifies let expressions:
* - removes lets when expression never occurs
* - simplifies when expressions occurs exactly once
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment