[Mlir-commits] [mlir] [mlir][EmitC] Emit parentheses when expression ops are used as operands (PR #93691)

Simon Camphausen llvmlistbot at llvm.org
Mon Jun 3 04:56:02 PDT 2024


https://github.com/simon-camp updated https://github.com/llvm/llvm-project/pull/93691

>From f5e1d9596b09942b80effa79773f10aac9405063 Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Wed, 29 May 2024 14:10:03 +0000
Subject: [PATCH 1/3] [mlir][EmitC] Emit parentheses for users of expression
 ops

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp |  9 ++-
 mlir/test/Target/Cpp/expressions.mlir  | 84 +++++++++++++++++++++++---
 mlir/test/Target/Cpp/for.mlir          |  4 +-
 3 files changed, 85 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index f19e0f8c4c2a4..e7d80d80855a5 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1338,8 +1338,13 @@ LogicalResult CppEmitter::emitOperand(Value value) {
   }
 
   auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
-  if (expressionOp && shouldBeInlined(expressionOp))
-    return emitExpression(expressionOp);
+  if (expressionOp && shouldBeInlined(expressionOp)) {
+    os << "(";
+    if (failed(emitExpression(expressionOp)))
+      return failure();
+    os << ")";
+    return success();
+  }
 
   auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
   if (!literalOp && !hasValueInScope(value))
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index aaddd5af874a9..37e0a0ffbdeb1 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -66,11 +66,11 @@ func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 {
 }
 
 // CPP-DEFAULT:      float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DEFAULT-NEXT:   return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
+// CPP-DEFAULT-NEXT:   return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]));
 // CPP-DEFAULT-NEXT: }
 
 // CPP-DECLTOP:      float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DECLTOP-NEXT:   return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
+// CPP-DECLTOP-NEXT:   return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]));
 // CPP-DECLTOP-NEXT: }
 
 func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 {
@@ -84,11 +84,11 @@ func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) ->
 }
 
 // CPP-DEFAULT:      int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DEFAULT-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
 // CPP-DEFAULT-NEXT: }
 
 // CPP-DECLTOP:      int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DECLTOP-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
 // CPP-DECLTOP-NEXT: }
 func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
   %e = emitc.expression : i32 {
@@ -100,6 +100,74 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
   return %e : i32
 }
 
+// CPP-DEFAULT:      int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT:   int32_t v4 = 0;
+// CPP-DEFAULT-NEXT:   bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT:   int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
+// CPP-DEFAULT-NEXT:   int32_t v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4);
+// CPP-DEFAULT-NEXT:   int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
+// CPP-DEFAULT-NEXT:   int32_t v9;
+// CPP-DEFAULT-NEXT:   v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT:   return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP:      int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT:   int32_t v4;
+// CPP-DECLTOP-NEXT:   bool v5;
+// CPP-DECLTOP-NEXT:   int32_t v6;
+// CPP-DECLTOP-NEXT:   int32_t v7;
+// CPP-DECLTOP-NEXT:   int32_t v8;
+// CPP-DECLTOP-NEXT:   int32_t v9;
+// CPP-DECLTOP-NEXT:   v4 = 0;
+// CPP-DECLTOP-NEXT:   v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT:   v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
+// CPP-DECLTOP-NEXT:   v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4);
+// CPP-DECLTOP-NEXT:   v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
+// CPP-DECLTOP-NEXT:   ;
+// CPP-DECLTOP-NEXT:   v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT:   return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT: }
+func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+  %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
+  %e0 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e1 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e2 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e3 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e4 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %e5 = emitc.expression : i32 {
+      %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+      %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+      emitc.yield %1 : i32
+    }
+  %cast = emitc.cast %e0 : i32 to i1
+  %add = emitc.add %e1, %c0 : (i32, i32) -> i32
+  %call = emitc.call_opaque "bar" (%e2, %c0) : (i32, i32) -> (i32)
+  %cond = emitc.conditional %cast, %e3, %c0 : i32
+  %var = "emitc.variable"() {value = #emitc.opaque<"">} : () -> i32
+  emitc.assign %e4 : i32 to %var : i32
+  return %e5 : i32
+}
+
 // CPP-DEFAULT:      int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
 // CPP-DEFAULT-NEXT:   bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_6:v[0-9]+]];
@@ -154,7 +222,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]];
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_7:v[0-9]+]];
-// CPP-DEFAULT-NEXT:   if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
+// CPP-DEFAULT-NEXT:   if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) {
 // CPP-DEFAULT-NEXT:     [[VAL_7]] = [[VAL_1]];
 // CPP-DEFAULT-NEXT:   } else {
 // CPP-DEFAULT-NEXT:     [[VAL_7]] = [[VAL_1]];
@@ -169,7 +237,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
 // CPP-DECLTOP-NEXT:   [[VAL_5]] = [[VAL_3]] % [[VAL_4]];
 // CPP-DECLTOP-NEXT:   [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
 // CPP-DECLTOP-NEXT:   ;
-// CPP-DECLTOP-NEXT:   if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
+// CPP-DECLTOP-NEXT:   if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) {
 // CPP-DECLTOP-NEXT:     [[VAL_7]] = [[VAL_1]];
 // CPP-DECLTOP-NEXT:   } else {
 // CPP-DECLTOP-NEXT:     [[VAL_7]] = [[VAL_1]];
@@ -205,13 +273,13 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32)
 
 // CPP-DEFAULT:      bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]];
-// CPP-DEFAULT-NEXT:   return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
+// CPP-DEFAULT-NEXT:   return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]);
 // CPP-DEFAULT-NEXT: }
 
 // CPP-DECLTOP:      bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
 // CPP-DECLTOP-NEXT:   int32_t [[VAL_4:v[0-9]+]];
 // CPP-DECLTOP-NEXT:   [[VAL_4]] = [[VAL_1]] % [[VAL_2]];
-// CPP-DECLTOP-NEXT:   return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
+// CPP-DECLTOP-NEXT:   return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]);
 // CPP-DECLTOP-NEXT: }
 
 func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
index 60988bcb46556..2e41dce45f580 100644
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -20,14 +20,14 @@ func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
   return
 }
 // CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
-// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
+// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) {
 // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f();
 // CPP-DEFAULT-NEXT: }
 // CPP-DEFAULT-NEXT: return;
 
 // CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
 // CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]];
-// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
+// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) {
 // CPP-DECLTOP-NEXT: [[V4]] = f();
 // CPP-DECLTOP-NEXT: }
 // CPP-DECLTOP-NEXT: return;

>From aa4b1bfbf320886e3855fbb0b2a0a3f76cae455a Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Wed, 29 May 2024 14:15:53 +0000
Subject: [PATCH 2/3] Skip parentheses where it's safe

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp | 10 +++++++--
 mlir/test/Target/Cpp/expressions.mlir  | 28 +++++++++++++-------------
 mlir/test/Target/Cpp/for.mlir          |  4 ++--
 3 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index e7d80d80855a5..83ef2a39950f2 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1339,10 +1339,16 @@ LogicalResult CppEmitter::emitOperand(Value value) {
 
   auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
   if (expressionOp && shouldBeInlined(expressionOp)) {
-    os << "(";
+    Operation *user = *expressionOp->getUsers().begin();
+    const bool safeToSkipParentheses =
+        isa<emitc::AssignOp, emitc::CallOp, emitc::CallOpaqueOp, emitc::ForOp,
+            emitc::IfOp, emitc::ReturnOp, func::CallOp, func::ReturnOp>(user);
+    if (!safeToSkipParentheses)
+      os << "(";
     if (failed(emitExpression(expressionOp)))
       return failure();
-    os << ")";
+    if (!safeToSkipParentheses)
+      os << ")";
     return success();
   }
 
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index 37e0a0ffbdeb1..1c55b9404225d 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -66,11 +66,11 @@ func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 {
 }
 
 // CPP-DEFAULT:      float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DEFAULT-NEXT:   return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]));
+// CPP-DEFAULT-NEXT:   return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
 // CPP-DEFAULT-NEXT: }
 
 // CPP-DECLTOP:      float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DECLTOP-NEXT:   return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]));
+// CPP-DECLTOP-NEXT:   return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
 // CPP-DECLTOP-NEXT: }
 
 func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 {
@@ -84,11 +84,11 @@ func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) ->
 }
 
 // CPP-DEFAULT:      int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DEFAULT-NEXT:   return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
 // CPP-DEFAULT-NEXT: }
 
 // CPP-DECLTOP:      int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DECLTOP-NEXT:   return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
 // CPP-DECLTOP-NEXT: }
 func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
   %e = emitc.expression : i32 {
@@ -104,11 +104,11 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
 // CPP-DEFAULT-NEXT:   int32_t v4 = 0;
 // CPP-DEFAULT-NEXT:   bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
 // CPP-DEFAULT-NEXT:   int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
-// CPP-DEFAULT-NEXT:   int32_t v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4);
+// CPP-DEFAULT-NEXT:   int32_t v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
 // CPP-DEFAULT-NEXT:   int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
 // CPP-DEFAULT-NEXT:   int32_t v9;
-// CPP-DEFAULT-NEXT:   v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
-// CPP-DEFAULT-NEXT:   return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT:   v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
 // CPP-DEFAULT-NEXT: }
 
 // CPP-DECLTOP:      int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
@@ -121,11 +121,11 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
 // CPP-DECLTOP-NEXT:   v4 = 0;
 // CPP-DECLTOP-NEXT:   v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
 // CPP-DECLTOP-NEXT:   v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
-// CPP-DECLTOP-NEXT:   v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4);
+// CPP-DECLTOP-NEXT:   v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
 // CPP-DECLTOP-NEXT:   v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
 // CPP-DECLTOP-NEXT:   ;
-// CPP-DECLTOP-NEXT:   v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
-// CPP-DECLTOP-NEXT:   return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT:   v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
 // CPP-DECLTOP-NEXT: }
 func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
   %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
@@ -222,7 +222,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]];
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_7:v[0-9]+]];
-// CPP-DEFAULT-NEXT:   if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) {
+// CPP-DEFAULT-NEXT:   if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
 // CPP-DEFAULT-NEXT:     [[VAL_7]] = [[VAL_1]];
 // CPP-DEFAULT-NEXT:   } else {
 // CPP-DEFAULT-NEXT:     [[VAL_7]] = [[VAL_1]];
@@ -237,7 +237,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
 // CPP-DECLTOP-NEXT:   [[VAL_5]] = [[VAL_3]] % [[VAL_4]];
 // CPP-DECLTOP-NEXT:   [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
 // CPP-DECLTOP-NEXT:   ;
-// CPP-DECLTOP-NEXT:   if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) {
+// CPP-DECLTOP-NEXT:   if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
 // CPP-DECLTOP-NEXT:     [[VAL_7]] = [[VAL_1]];
 // CPP-DECLTOP-NEXT:   } else {
 // CPP-DECLTOP-NEXT:     [[VAL_7]] = [[VAL_1]];
@@ -273,13 +273,13 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32)
 
 // CPP-DEFAULT:      bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
 // CPP-DEFAULT-NEXT:   int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]];
-// CPP-DEFAULT-NEXT:   return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]);
+// CPP-DEFAULT-NEXT:   return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
 // CPP-DEFAULT-NEXT: }
 
 // CPP-DECLTOP:      bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
 // CPP-DECLTOP-NEXT:   int32_t [[VAL_4:v[0-9]+]];
 // CPP-DECLTOP-NEXT:   [[VAL_4]] = [[VAL_1]] % [[VAL_2]];
-// CPP-DECLTOP-NEXT:   return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]);
+// CPP-DECLTOP-NEXT:   return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
 // CPP-DECLTOP-NEXT: }
 
 func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
index 2e41dce45f580..60988bcb46556 100644
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -20,14 +20,14 @@ func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
   return
 }
 // CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
-// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) {
+// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
 // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f();
 // CPP-DEFAULT-NEXT: }
 // CPP-DEFAULT-NEXT: return;
 
 // CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
 // CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]];
-// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) {
+// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
 // CPP-DECLTOP-NEXT: [[V4]] = f();
 // CPP-DECLTOP-NEXT: }
 // CPP-DECLTOP-NEXT: return;

>From 7b34b21a7f7ad81a194f76474aff0b9c398e7bde Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Mon, 3 Jun 2024 11:55:33 +0000
Subject: [PATCH 3/3] Do not inline expressions into ops with the CExpression
 trait

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp | 18 ++++-------
 mlir/test/Target/Cpp/expressions.mlir  | 44 ++++++++++++++++----------
 2 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 83ef2a39950f2..bbc8aa7c9fe91 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -303,7 +303,11 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
 
   // Do not inline expressions used by other expressions, as any desired
   // expression folding was taken care of by transformations.
-  return !user->getParentOfType<ExpressionOp>();
+  if (user->getParentOfType<ExpressionOp>())
+    return false;
+
+  // Do not inline expressions used by ops with the CExpression trait.
+  return !user->hasTrait<OpTrait::emitc::CExpression>();
 }
 
 static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
@@ -1339,17 +1343,7 @@ LogicalResult CppEmitter::emitOperand(Value value) {
 
   auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
   if (expressionOp && shouldBeInlined(expressionOp)) {
-    Operation *user = *expressionOp->getUsers().begin();
-    const bool safeToSkipParentheses =
-        isa<emitc::AssignOp, emitc::CallOp, emitc::CallOpaqueOp, emitc::ForOp,
-            emitc::IfOp, emitc::ReturnOp, func::CallOp, func::ReturnOp>(user);
-    if (!safeToSkipParentheses)
-      os << "(";
-    if (failed(emitExpression(expressionOp)))
-      return failure();
-    if (!safeToSkipParentheses)
-      os << ")";
-    return success();
+    return emitExpression(expressionOp);
   }
 
   auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index 1c55b9404225d..810a629c71533 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -100,34 +100,46 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
   return %e : i32
 }
 
-// CPP-DEFAULT:      int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT:      int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
 // CPP-DEFAULT-NEXT:   int32_t v4 = 0;
-// CPP-DEFAULT-NEXT:   bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
-// CPP-DEFAULT-NEXT:   int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
-// CPP-DEFAULT-NEXT:   int32_t v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
-// CPP-DEFAULT-NEXT:   int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
-// CPP-DEFAULT-NEXT:   int32_t v9;
-// CPP-DEFAULT-NEXT:   v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT:   bool v9 = (bool) [[EXP_0]];
+// CPP-DEFAULT-NEXT:   int32_t v10 = [[EXP_1]] + v4;
+// CPP-DEFAULT-NEXT:   int32_t v11 = bar([[EXP_2]], v4);
+// CPP-DEFAULT-NEXT:   int32_t v12 = v9 ? [[EXP_3]] : v4;
+// CPP-DEFAULT-NEXT:   int32_t v13;
+// CPP-DEFAULT-NEXT:   v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
 // CPP-DEFAULT-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
 // CPP-DEFAULT-NEXT: }
 
-// CPP-DECLTOP:      int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP:      int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
 // CPP-DECLTOP-NEXT:   int32_t v4;
-// CPP-DECLTOP-NEXT:   bool v5;
+// CPP-DECLTOP-NEXT:   int32_t v5;
 // CPP-DECLTOP-NEXT:   int32_t v6;
 // CPP-DECLTOP-NEXT:   int32_t v7;
 // CPP-DECLTOP-NEXT:   int32_t v8;
-// CPP-DECLTOP-NEXT:   int32_t v9;
+// CPP-DECLTOP-NEXT:   bool v9;
+// CPP-DECLTOP-NEXT:   int32_t v10;
+// CPP-DECLTOP-NEXT:   int32_t v11;
+// CPP-DECLTOP-NEXT:   int32_t v12;
+// CPP-DECLTOP-NEXT:   int32_t v13;
 // CPP-DECLTOP-NEXT:   v4 = 0;
-// CPP-DECLTOP-NEXT:   v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
-// CPP-DECLTOP-NEXT:   v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
-// CPP-DECLTOP-NEXT:   v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
-// CPP-DECLTOP-NEXT:   v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
+// CPP-DECLTOP-NEXT:   v5 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   v6 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   v7 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   v8 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   v9 = (bool) v5;
+// CPP-DECLTOP-NEXT:   v10 = v6 + v4;
+// CPP-DECLTOP-NEXT:   v11 = bar(v7, v4);
+// CPP-DECLTOP-NEXT:   v12 = v9 ? v8 : v4;
 // CPP-DECLTOP-NEXT:   ;
-// CPP-DECLTOP-NEXT:   v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT:   v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
 // CPP-DECLTOP-NEXT:   return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
 // CPP-DECLTOP-NEXT: }
-func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
   %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
   %e0 = emitc.expression : i32 {
       %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32



More information about the Mlir-commits mailing list