diff --git a/src/funcheck/CodeExtraction.scala b/src/funcheck/CodeExtraction.scala index 143832488d656d9da9e9c0e10bf15f1d22124645..f1f2420cb1c9e4bdfe3adbabe989059b237bdc82 100644 --- a/src/funcheck/CodeExtraction.scala +++ b/src/funcheck/CodeExtraction.scala @@ -275,6 +275,30 @@ trait CodeExtraction extends Extractors { Program(programName, topLevelObjDef) } + def extractPredicate(unit: CompilationUnit, params: Seq[ValDef], body: Tree) : FunDef = { + def s2ps(tree: Tree): Expr = toPureScala(unit)(tree) match { + case Some(ex) => ex + case None => stopIfErrors; scala.Predef.error("unreachable error.") + } + + def st2ps(tree: Type): purescala.TypeTrees.TypeTree = toPureScalaType(unit)(tree) match { + case Some(tt) => tt + case None => stopIfErrors; scala.Predef.error("unreachable error.") + } + + val newParams = params.map(p => { + val ptpe = st2ps(p.tpt.tpe) + val newID = FreshIdentifier(p.name.toString).setType(ptpe) + varSubsts(p.symbol) = (() => Variable(newID)) + VarDecl(newID, ptpe) + }) + val fd = new FunDef(FreshIdentifier("predicate"), BooleanType, newParams) + + val bodyAttempt = toPureScala(unit)(body) + fd.body = bodyAttempt + fd + } + /** An exception thrown when non-purescala compatible code is encountered. */ sealed case class ImpureCodeEncounteredException(tree: Tree) extends Exception