diff --git a/src/main/scala/leon/genc/CAST.scala b/src/main/scala/leon/genc/CAST.scala index 253354acf67063b90188aac5b839821684c8fb65..07f4b558d097920c20fd2d0943973fcd60fbd0be 100644 --- a/src/main/scala/leon/genc/CAST.scala +++ b/src/main/scala/leon/genc/CAST.scala @@ -52,9 +52,7 @@ object CAST { // C Abstract Syntax Tree else name } - case class Var(id: Id, typ: Type) extends Def { - def access = AccessVar(id) - } + case class Var(id: Id, typ: Type) extends Def /* ----------------------------------------------------------- Stmts ----- */ abstract class Stmt extends Tree diff --git a/src/main/scala/leon/genc/CConverter.scala b/src/main/scala/leon/genc/CConverter.scala index c4c606f37780dbe1fc9407d40cf9100bff89d16a..05a7f2c38fbc057694c7fc8bfa46573d02f6ec0c 100644 --- a/src/main/scala/leon/genc/CConverter.scala +++ b/src/main/scala/leon/genc/CConverter.scala @@ -319,23 +319,44 @@ class CConverter(val ctx: LeonContext, val prog: Program) { fs.bodies ~~ CAST.ArrayInitWithValues(typ, fs.values) case ArrayUpdate(array1, index1, newValue1) => - val arrayType = convertToType(array1.getType) - val indexType = CAST.Int32 - val valueType = convertToType(newValue1.getType) - val values = array1 :: index1 :: newValue1 :: Nil - val types = arrayType :: indexType :: valueType :: Nil + val array2 = convertToStmt(array1) + val index2 = convertToStmt(index1) + val newValue2 = convertToStmt(newValue1) + val values = array2 :: index2 :: newValue2 :: Nil - val fs = convertAndNormaliseExecution(values, types) + val arePure = values forall { _.isPure } + val areValues = array2.isValue && index2.isValue // no newValue here - val array = fs.values(0) - val index = fs.values(1) - val newValue = fs.values(2) + newValue2 match { + case CAST.IfElse(cond, thn, elze) if arePure && areValues => + val array = array2 + val index = index2 + val ptr = CAST.AccessField(array, CAST.Array.dataId) + val select = CAST.SubscriptOp(ptr, index) - val ptr = CAST.AccessField(array, CAST.Array.dataId) - val select = CAST.SubscriptOp(ptr, index) - val assign = CAST.Assign(select, newValue) + val ifelse = buildIfElse(cond, injectAssign(select, thn), + injectAssign(select, elze)) + + ifelse - fs.bodies ~~ assign + case _ => + val arrayType = convertToType(array1.getType) + val indexType = CAST.Int32 + val valueType = convertToType(newValue1.getType) + val types = arrayType :: indexType :: valueType :: Nil + + val fs = normaliseExecution(values, types) + + val array = fs.values(0) + val index = fs.values(1) + val newValue = fs.values(2) + + val ptr = CAST.AccessField(array, CAST.Array.dataId) + val select = CAST.SubscriptOp(ptr, index) + val assign = CAST.Assign(select, newValue) + + fs.bodies ~~ assign + } case CaseClass(typ, args1) => val struct = convertToStruct(typ) @@ -581,6 +602,10 @@ class CConverter(val ctx: LeonContext, val prog: Program) { } private def injectAssign(x: CAST.Var, stmt: CAST.Stmt): CAST.Stmt = { + injectAssign(CAST.AccessVar(x.id), stmt) + } + + private def injectAssign(x: CAST.Stmt, stmt: CAST.Stmt): CAST.Stmt = { val f = flatten(stmt) f.value match { @@ -588,7 +613,7 @@ class CConverter(val ctx: LeonContext, val prog: Program) { f.body ~~ CAST.IfElse(cond, injectAssign(x, thn), injectAssign(x, elze)) case _ => - f.body ~~ CAST.Assign(x.access, f.value) + f.body ~~ CAST.Assign(x, f.value) } }