diff --git a/src/main/scala/leon/Evaluator.scala b/src/main/scala/leon/Evaluator.scala index 47337af61e94395275fa76f349958f08856e00b9..8767458c2251b0d8cdeb8e2a87af9c6f47306a8f 100644 --- a/src/main/scala/leon/Evaluator.scala +++ b/src/main/scala/leon/Evaluator.scala @@ -247,6 +247,7 @@ object Evaluator { case e @ EmptySet(_) => e case i @ IntLiteral(_) => i case b @ BooleanLiteral(_) => b + case u @ UnitLiteral => u case f @ ArrayMake(default) => { val rDefault = rec(ctx, default) diff --git a/src/main/scala/leon/FairZ3Solver.scala b/src/main/scala/leon/FairZ3Solver.scala index 3ab173b7d855c76e1d614200cbab8dc3b9078eaf..ea2681cb4dd76c4f487b12fee6b7d94b83df3644 100644 --- a/src/main/scala/leon/FairZ3Solver.scala +++ b/src/main/scala/leon/FairZ3Solver.scala @@ -100,6 +100,9 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S private var mapSorts: Map[TypeTree, Z3Sort] = Map.empty private var arraySorts: Map[TypeTree, Z3Sort] = Map.empty + private var unitSort: Z3Sort = null + private var unitValue: Z3AST = null + protected[leon] var funSorts: Map[TypeTree, Z3Sort] = Map.empty protected[leon] var funDomainConstructors: Map[TypeTree, Z3FuncDecl] = Map.empty protected[leon] var funDomainSelectors: Map[TypeTree, Seq[Z3FuncDecl]] = Map.empty @@ -191,6 +194,27 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S setSorts = Map.empty setCardFuns = Map.empty + //unitSort = z3.mkUninterpretedSort("unit") + //unitValue = z3.mkFreshConst("Unit", unitSort) + //val bound = z3.mkBound(0, unitSort) + //val eq = z3.mkEq(bound, unitValue) + //val decls = Seq((z3.mkFreshStringSymbol("u"), unitSort)) + //val unitAxiom = z3.mkForAll(0, Seq(), decls, eq) + //println(unitAxiom) + //println(unitValue) + //z3.assertCnstr(unitAxiom) + val Seq((us, Seq(unitCons), Seq(unitTester), _)) = z3.mkADTSorts( + Seq( + ( + "Unit", + Seq("Unit"), + Seq(Seq()) + ) + ) + ) + unitSort = us + unitValue = unitCons() + val intSetSort = typeToSort(SetType(Int32Type)) intSetMinFun = z3.mkFreshFuncDecl("setMin", Seq(intSetSort), intSort) intSetMaxFun = z3.mkFreshFuncDecl("setMax", Seq(intSetSort), intSort) @@ -360,6 +384,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S def typeToSort(tt: TypeTree): Z3Sort = tt match { case Int32Type => intSort case BooleanType => boolSort + case UnitType => unitSort case AbstractClassType(cd) => adtSorts(cd) case CaseClassType(cd) => { if (cd.hasParent) { @@ -967,6 +992,7 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S case Not(e) => z3.mkNot(rec(e)) case IntLiteral(v) => z3.mkInt(v, intSort) case BooleanLiteral(v) => if (v) z3.mkTrue() else z3.mkFalse() + case UnitLiteral => unitValue case Equals(l, r) => z3.mkEq(rec(l), rec(r)) case Plus(l, r) => if(USEBV) z3.mkBVAdd(rec(l), rec(r)) else z3.mkAdd(rec(l), rec(r)) case Minus(l, r) => if(USEBV) z3.mkBVSub(rec(l), rec(r)) else z3.mkSub(rec(l), rec(r)) @@ -1131,7 +1157,9 @@ class FairZ3Solver(reporter: Reporter) extends Solver(reporter) with AbstractZ3S Tuple(rargs) } case other => - z3.getASTKind(t) match { + if(t == unitValue) + UnitLiteral + else z3.getASTKind(t) match { case Z3AppAST(decl, args) => { val argsSize = args.size if(argsSize == 0 && z3IdToExpr.isDefinedAt(t)) { diff --git a/src/main/scala/leon/Main.scala b/src/main/scala/leon/Main.scala index b3b3c17b410bc1c0565c1b4fab5964edb914aed7..68161e9e4e93600f2df7be787b9db8f4af796e3c 100644 --- a/src/main/scala/leon/Main.scala +++ b/src/main/scala/leon/Main.scala @@ -32,7 +32,7 @@ object Main { private def defaultAction(program: Program, reporter: Reporter) : Unit = { Logger.debug("Default action on program: " + program, 3, "main") - val passManager = new PassManager(Seq(ArrayTransformation, EpsilonElimination, ImperativeCodeElimination, UnitElimination, FunctionClosure, FunctionHoisting, Simplificator)) + val passManager = new PassManager(Seq(ArrayTransformation, EpsilonElimination, ImperativeCodeElimination, /*UnitElimination,*/ FunctionClosure, FunctionHoisting, Simplificator)) val program2 = passManager.run(program) val analysis = new Analysis(program2, reporter) analysis.analyse diff --git a/src/main/scala/leon/plugin/CodeExtraction.scala b/src/main/scala/leon/plugin/CodeExtraction.scala index 6ac28303d874fa3cdbd35a928531b93be4c2af8d..6b9ef14e5d419082c8901fd2bef62121ec150a76 100644 --- a/src/main/scala/leon/plugin/CodeExtraction.scala +++ b/src/main/scala/leon/plugin/CodeExtraction.scala @@ -935,6 +935,7 @@ trait CodeExtraction extends Extractors { def rec(tr: Type): purescala.TypeTrees.TypeTree = tr match { case tpe if tpe == IntClass.tpe => Int32Type case tpe if tpe == BooleanClass.tpe => BooleanType + case tpe if tpe == UnitClass.tpe => UnitType case TypeRef(_, sym, btt :: Nil) if isSetTraitSym(sym) => SetType(rec(btt)) case TypeRef(_, sym, btt :: Nil) if isMultisetTraitSym(sym) => MultisetType(rec(btt)) case TypeRef(_, sym, btt :: Nil) if isOptionClassSym(sym) => OptionType(rec(btt)) diff --git a/testcases/regression/invalid/Unit1.scala b/testcases/regression/invalid/Unit1.scala new file mode 100644 index 0000000000000000000000000000000000000000..789a8f058cd8145a0dd3b7bb0d72d44fb92ee94e --- /dev/null +++ b/testcases/regression/invalid/Unit1.scala @@ -0,0 +1,7 @@ +object Unit1 { + + def foo(u: Unit): Unit = ({ + u + }) ensuring(_ != ()) + +} diff --git a/testcases/regression/valid/Unit1.scala b/testcases/regression/valid/Unit1.scala new file mode 100644 index 0000000000000000000000000000000000000000..a7b890d762648cba480c817506c7eee22259860d --- /dev/null +++ b/testcases/regression/valid/Unit1.scala @@ -0,0 +1,7 @@ +object Unit1 { + + def foo(): Unit = ({ + () + }) ensuring(_ == ()) + +} diff --git a/testcases/regression/valid/Unit2.scala b/testcases/regression/valid/Unit2.scala new file mode 100644 index 0000000000000000000000000000000000000000..ac659589af503a8a79f261f9eb05be095e2e0943 --- /dev/null +++ b/testcases/regression/valid/Unit2.scala @@ -0,0 +1,7 @@ +object Unit2 { + + def foo(u: Unit): Unit = { + u + } ensuring(_ == ()) + +}