diff --git a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala index 3ac83225d73b7dfacaccd183eed03bb37825e818..c252880f59e6f1cba8c6248095f9448e21174574 100644 --- a/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala +++ b/src/main/scala/leon/synthesis/utils/ExpressionGrammar.scala @@ -85,13 +85,10 @@ object ExpressionGrammars { Generator(List(BooleanType), { case Seq(a) => not(a) }), Generator(List(BooleanType, BooleanType), { case Seq(a, b) => and(a, b) }), Generator(List(BooleanType, BooleanType), { case Seq(a, b) => or(a, b) }), - Generator(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), - Generator(List(Int32Type, Int32Type), { case Seq(a, b) => LessEquals(a, b) }), - Generator(List(Int32Type, Int32Type ), { case Seq(a, b) => equality(a, b) }), + Generator(List(Int32Type, Int32Type), { case Seq(a, b) => LessThan(a, b) }), + Generator(List(Int32Type, Int32Type), { case Seq(a, b) => LessEquals(a, b) }), Generator(List(IntegerType, IntegerType), { case Seq(a, b) => LessThan(a, b) }), - Generator(List(IntegerType, IntegerType), { case Seq(a, b) => LessEquals(a, b) }), - Generator(List(IntegerType, IntegerType), { case Seq(a, b) => equality(a, b) }), - Generator(List(BooleanType, BooleanType), { case Seq(a, b) => equality(a, b) }) + Generator(List(IntegerType, IntegerType), { case Seq(a, b) => LessEquals(a, b) }) ) case Int32Type => List( @@ -140,6 +137,17 @@ object ExpressionGrammars { } } + case class EqualityGrammar(types: Set[TypeTree]) extends ExpressionGrammar[TypeTree] { + override def computeProductions(t: TypeTree): Seq[Gen] = t match { + case BooleanType => + types.toList map { tp => + Generator[TypeTree, Expr](List(tp, tp), { case Seq(a, b) => equality(a, b) }) + } + + case _ => Nil + } + } + case object ValueGrammar extends ExpressionGrammar[TypeTree] { def computeProductions(t: TypeTree): Seq[Gen] = t match { case BooleanType => @@ -211,6 +219,7 @@ object ExpressionGrammars { val normalGrammar = BoundedGrammar(EmbeddedGrammar( BaseGrammar || + EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ terminals.map { _.getType }) || OneOf(terminals.toSeq :+ e) || FunctionCalls(sctx.program, sctx.functionContext, p.as.map(_.getType), excludeFCalls) || SafeRecCalls(sctx.program, p.ws, p.pc), @@ -285,13 +294,20 @@ object ExpressionGrammars { normalGrammar.getProductions(gl).map(gl -> _) } - def intVariations(gl: L, e : Expr): Seq[(L, Gen)] = { + def int32Variations(gl: L, e : Expr): Seq[(L, Gen)] = { Seq( gl -> Generator(Nil, { _ => BVMinus(e, IntLiteral(1))} ), gl -> Generator(Nil, { _ => BVPlus (e, IntLiteral(1))} ) ) } + def intVariations(gl: L, e : Expr): Seq[(L, Gen)] = { + Seq( + gl -> Generator(Nil, { _ => Minus(e, InfiniteIntegerLiteral(1))} ), + gl -> Generator(Nil, { _ => Plus (e, InfiniteIntegerLiteral(1))} ) + ) + } + // Find neighbor case classes that are compatible with the arguments: // Turns And(e1, e2) into Or(e1, e2)... def ccVariations(gl: L, cc: CaseClass): Seq[(L, Gen)] = { @@ -305,7 +321,7 @@ object ExpressionGrammars { } val funFilter = (fd: FunDef) => fd.isSynthetic || (excludeFCalls contains fd) - val subs: Seq[(L, Gen)] = (e match { + val subs: Seq[(L, Gen)] = e match { case _: Terminal | _: Let | _: LetDef | _: MatchExpr => gens(e, gl, Nil, { _ => e }) ++ cegis(gl) @@ -324,14 +340,20 @@ object ExpressionGrammars { case NAryOperator(subs, builder) => gens(e, gl, subs, { case ss => builder(ss) }) - }) ++ (if (e.getType == Int32Type ) intVariations(gl, e) else Nil) + } val terminalsMatching = terminals.collect { case IsTyped(term, tpe) if tpe == gl.getType && term != e => gl -> Generator[L, Expr](Nil, { _ => term }) } - subs ++ terminalsMatching + val variations = e.getType match { + case IntegerType => intVariations(gl, e) + case Int32Type => int32Variations(gl, e) + case _ => Nil + } + + subs ++ terminalsMatching ++ variations } val gl = getLabel(e.getType) @@ -482,6 +504,7 @@ object ExpressionGrammars { def default(prog: Program, inputs: Seq[Expr], currentFunction: FunDef, exclude: Set[FunDef], ws: Expr, pc: Expr): ExpressionGrammar[TypeTree] = { BaseGrammar || + EqualityGrammar(Set(IntegerType, Int32Type, BooleanType) ++ inputs.map { _.getType }) || OneOf(inputs) || FunctionCalls(prog, currentFunction, inputs.map(_.getType), exclude) || SafeRecCalls(prog, ws, pc)