[Mlir-commits] [mlir] [mlir][EmitC] Do not inline expressions used by ops with the CExpression trait (PR #93691)
Simon Camphausen
llvmlistbot at llvm.org
Tue Jun 4 01:04:27 PDT 2024
https://github.com/simon-camp updated https://github.com/llvm/llvm-project/pull/93691
>From f5e1d9596b09942b80effa79773f10aac9405063 Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Wed, 29 May 2024 14:10:03 +0000
Subject: [PATCH 1/4] [mlir][EmitC] Emit parentheses for users of expression
ops
---
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 9 ++-
mlir/test/Target/Cpp/expressions.mlir | 84 +++++++++++++++++++++++---
mlir/test/Target/Cpp/for.mlir | 4 +-
3 files changed, 85 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index f19e0f8c4c2a4..e7d80d80855a5 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1338,8 +1338,13 @@ LogicalResult CppEmitter::emitOperand(Value value) {
}
auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
- if (expressionOp && shouldBeInlined(expressionOp))
- return emitExpression(expressionOp);
+ if (expressionOp && shouldBeInlined(expressionOp)) {
+ os << "(";
+ if (failed(emitExpression(expressionOp)))
+ return failure();
+ os << ")";
+ return success();
+ }
auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
if (!literalOp && !hasValueInScope(value))
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index aaddd5af874a9..37e0a0ffbdeb1 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -66,11 +66,11 @@ func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 {
}
// CPP-DEFAULT: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DEFAULT-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
+// CPP-DEFAULT-NEXT: return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]));
// CPP-DEFAULT-NEXT: }
// CPP-DECLTOP: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DECLTOP-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
+// CPP-DECLTOP-NEXT: return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]));
// CPP-DECLTOP-NEXT: }
func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 {
@@ -84,11 +84,11 @@ func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) ->
}
// CPP-DEFAULT: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
// CPP-DEFAULT-NEXT: }
// CPP-DECLTOP: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
// CPP-DECLTOP-NEXT: }
func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
%e = emitc.expression : i32 {
@@ -100,6 +100,74 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
return %e : i32
}
+// CPP-DEFAULT: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: int32_t v4 = 0;
+// CPP-DEFAULT-NEXT: bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT: int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
+// CPP-DEFAULT-NEXT: int32_t v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4);
+// CPP-DEFAULT-NEXT: int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
+// CPP-DEFAULT-NEXT: int32_t v9;
+// CPP-DEFAULT-NEXT: v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: int32_t v4;
+// CPP-DECLTOP-NEXT: bool v5;
+// CPP-DECLTOP-NEXT: int32_t v6;
+// CPP-DECLTOP-NEXT: int32_t v7;
+// CPP-DECLTOP-NEXT: int32_t v8;
+// CPP-DECLTOP-NEXT: int32_t v9;
+// CPP-DECLTOP-NEXT: v4 = 0;
+// CPP-DECLTOP-NEXT: v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT: v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
+// CPP-DECLTOP-NEXT: v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4);
+// CPP-DECLTOP-NEXT: v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT: }
+func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+ %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
+ %e0 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e1 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e2 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e3 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e4 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e5 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %cast = emitc.cast %e0 : i32 to i1
+ %add = emitc.add %e1, %c0 : (i32, i32) -> i32
+ %call = emitc.call_opaque "bar" (%e2, %c0) : (i32, i32) -> (i32)
+ %cond = emitc.conditional %cast, %e3, %c0 : i32
+ %var = "emitc.variable"() {value = #emitc.opaque<"">} : () -> i32
+ emitc.assign %e4 : i32 to %var : i32
+ return %e5 : i32
+}
+
// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];
@@ -154,7 +222,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]];
// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]];
-// CPP-DEFAULT-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
+// CPP-DEFAULT-NEXT: if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) {
// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]];
// CPP-DEFAULT-NEXT: } else {
// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]];
@@ -169,7 +237,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_3]] % [[VAL_4]];
// CPP-DECLTOP-NEXT: [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: ;
-// CPP-DECLTOP-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
+// CPP-DECLTOP-NEXT: if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) {
// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]];
// CPP-DECLTOP-NEXT: } else {
// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]];
@@ -205,13 +273,13 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32)
// CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]];
-// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
+// CPP-DEFAULT-NEXT: return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]);
// CPP-DEFAULT-NEXT: }
// CPP-DECLTOP: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]];
// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]] % [[VAL_2]];
-// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
+// CPP-DECLTOP-NEXT: return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]);
// CPP-DECLTOP-NEXT: }
func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
index 60988bcb46556..2e41dce45f580 100644
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -20,14 +20,14 @@ func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
return
}
// CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
-// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
+// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) {
// CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f();
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
// CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
// CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]];
-// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
+// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) {
// CPP-DECLTOP-NEXT: [[V4]] = f();
// CPP-DECLTOP-NEXT: }
// CPP-DECLTOP-NEXT: return;
>From aa4b1bfbf320886e3855fbb0b2a0a3f76cae455a Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Wed, 29 May 2024 14:15:53 +0000
Subject: [PATCH 2/4] Skip parentheses where it's safe
---
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 10 +++++++--
mlir/test/Target/Cpp/expressions.mlir | 28 +++++++++++++-------------
mlir/test/Target/Cpp/for.mlir | 4 ++--
3 files changed, 24 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index e7d80d80855a5..83ef2a39950f2 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1339,10 +1339,16 @@ LogicalResult CppEmitter::emitOperand(Value value) {
auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
if (expressionOp && shouldBeInlined(expressionOp)) {
- os << "(";
+ Operation *user = *expressionOp->getUsers().begin();
+ const bool safeToSkipParentheses =
+ isa<emitc::AssignOp, emitc::CallOp, emitc::CallOpaqueOp, emitc::ForOp,
+ emitc::IfOp, emitc::ReturnOp, func::CallOp, func::ReturnOp>(user);
+ if (!safeToSkipParentheses)
+ os << "(";
if (failed(emitExpression(expressionOp)))
return failure();
- os << ")";
+ if (!safeToSkipParentheses)
+ os << ")";
return success();
}
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index 37e0a0ffbdeb1..1c55b9404225d 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -66,11 +66,11 @@ func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 {
}
// CPP-DEFAULT: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DEFAULT-NEXT: return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]));
+// CPP-DEFAULT-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
// CPP-DEFAULT-NEXT: }
// CPP-DECLTOP: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DECLTOP-NEXT: return ((float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]));
+// CPP-DECLTOP-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
// CPP-DECLTOP-NEXT: }
func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 {
@@ -84,11 +84,11 @@ func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) ->
}
// CPP-DEFAULT: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DEFAULT-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: }
// CPP-DECLTOP: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DECLTOP-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: }
func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
%e = emitc.expression : i32 {
@@ -104,11 +104,11 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
// CPP-DEFAULT-NEXT: int32_t v4 = 0;
// CPP-DEFAULT-NEXT: bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
// CPP-DEFAULT-NEXT: int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
-// CPP-DEFAULT-NEXT: int32_t v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4);
+// CPP-DEFAULT-NEXT: int32_t v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
// CPP-DEFAULT-NEXT: int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
// CPP-DEFAULT-NEXT: int32_t v9;
-// CPP-DEFAULT-NEXT: v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
-// CPP-DEFAULT-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DEFAULT-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: }
// CPP-DECLTOP: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
@@ -121,11 +121,11 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
// CPP-DECLTOP-NEXT: v4 = 0;
// CPP-DECLTOP-NEXT: v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
// CPP-DECLTOP-NEXT: v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
-// CPP-DECLTOP-NEXT: v7 = bar(([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])), v4);
+// CPP-DECLTOP-NEXT: v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
// CPP-DECLTOP-NEXT: v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
// CPP-DECLTOP-NEXT: ;
-// CPP-DECLTOP-NEXT: v9 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
-// CPP-DECLTOP-NEXT: return ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
+// CPP-DECLTOP-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: }
func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
%c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
@@ -222,7 +222,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]];
// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]];
-// CPP-DEFAULT-NEXT: if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) {
+// CPP-DEFAULT-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]];
// CPP-DEFAULT-NEXT: } else {
// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]];
@@ -237,7 +237,7 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_3]] % [[VAL_4]];
// CPP-DECLTOP-NEXT: [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: ;
-// CPP-DECLTOP-NEXT: if (([[VAL_6]] - [[VAL_4]] < [[VAL_2]])) {
+// CPP-DECLTOP-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]];
// CPP-DECLTOP-NEXT: } else {
// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]];
@@ -273,13 +273,13 @@ func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32)
// CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]];
-// CPP-DEFAULT-NEXT: return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]);
+// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
// CPP-DEFAULT-NEXT: }
// CPP-DECLTOP: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3]]) {
// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]];
// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]] % [[VAL_2]];
-// CPP-DECLTOP-NEXT: return (&[[VAL_4]] - [[VAL_2]] < [[VAL_3]]);
+// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
// CPP-DECLTOP-NEXT: }
func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
index 2e41dce45f580..60988bcb46556 100644
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -20,14 +20,14 @@ func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
return
}
// CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
-// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) {
+// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
// CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f();
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
// CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
// CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]];
-// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = ([[V1]] + [[V2]]); [[ITER]] < (([[V2]] * [[V3]])); [[ITER]] += ([[V1]] / [[V3]])) {
+// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
// CPP-DECLTOP-NEXT: [[V4]] = f();
// CPP-DECLTOP-NEXT: }
// CPP-DECLTOP-NEXT: return;
>From c6817f88c0ce4cb7e0a86bb6137242f2d13742cc Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Mon, 3 Jun 2024 11:55:33 +0000
Subject: [PATCH 3/4] Do not inline expressions into ops with the CExpression
trait
---
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 19 ++++-------
mlir/test/Target/Cpp/expressions.mlir | 44 ++++++++++++++++----------
2 files changed, 35 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 83ef2a39950f2..01648ba693180 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -303,7 +303,12 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
// Do not inline expressions used by other expressions, as any desired
// expression folding was taken care of by transformations.
- return !user->getParentOfType<ExpressionOp>();
+ if (user->getParentOfType<ExpressionOp>())
+ return false;
+
+ // Do not inline expressions used by ops with the CExpression trait. If this
+ // was intended, the user could have been merged into the expression op.
+ return !user->hasTrait<OpTrait::emitc::CExpression>();
}
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
@@ -1339,17 +1344,7 @@ LogicalResult CppEmitter::emitOperand(Value value) {
auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
if (expressionOp && shouldBeInlined(expressionOp)) {
- Operation *user = *expressionOp->getUsers().begin();
- const bool safeToSkipParentheses =
- isa<emitc::AssignOp, emitc::CallOp, emitc::CallOpaqueOp, emitc::ForOp,
- emitc::IfOp, emitc::ReturnOp, func::CallOp, func::ReturnOp>(user);
- if (!safeToSkipParentheses)
- os << "(";
- if (failed(emitExpression(expressionOp)))
- return failure();
- if (!safeToSkipParentheses)
- os << ")";
- return success();
+ return emitExpression(expressionOp);
}
auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index 1c55b9404225d..810a629c71533 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -100,34 +100,46 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
return %e : i32
}
-// CPP-DEFAULT: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
// CPP-DEFAULT-NEXT: int32_t v4 = 0;
-// CPP-DEFAULT-NEXT: bool v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
-// CPP-DEFAULT-NEXT: int32_t v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
-// CPP-DEFAULT-NEXT: int32_t v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
-// CPP-DEFAULT-NEXT: int32_t v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
-// CPP-DEFAULT-NEXT: int32_t v9;
-// CPP-DEFAULT-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: bool v9 = (bool) [[EXP_0]];
+// CPP-DEFAULT-NEXT: int32_t v10 = [[EXP_1]] + v4;
+// CPP-DEFAULT-NEXT: int32_t v11 = bar([[EXP_2]], v4);
+// CPP-DEFAULT-NEXT: int32_t v12 = v9 ? [[EXP_3]] : v4;
+// CPP-DEFAULT-NEXT: int32_t v13;
+// CPP-DEFAULT-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: }
-// CPP-DECLTOP: int32_t parentheses_for_expression_users(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
// CPP-DECLTOP-NEXT: int32_t v4;
-// CPP-DECLTOP-NEXT: bool v5;
+// CPP-DECLTOP-NEXT: int32_t v5;
// CPP-DECLTOP-NEXT: int32_t v6;
// CPP-DECLTOP-NEXT: int32_t v7;
// CPP-DECLTOP-NEXT: int32_t v8;
-// CPP-DECLTOP-NEXT: int32_t v9;
+// CPP-DECLTOP-NEXT: bool v9;
+// CPP-DECLTOP-NEXT: int32_t v10;
+// CPP-DECLTOP-NEXT: int32_t v11;
+// CPP-DECLTOP-NEXT: int32_t v12;
+// CPP-DECLTOP-NEXT: int32_t v13;
// CPP-DECLTOP-NEXT: v4 = 0;
-// CPP-DECLTOP-NEXT: v5 = (bool) ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]));
-// CPP-DECLTOP-NEXT: v6 = ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) + v4;
-// CPP-DECLTOP-NEXT: v7 = bar([[VAL_3]] / ([[VAL_1]] * [[VAL_2]]), v4);
-// CPP-DECLTOP-NEXT: v8 = v5 ? ([[VAL_3]] / ([[VAL_1]] * [[VAL_2]])) : v4;
+// CPP-DECLTOP-NEXT: v5 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: v6 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: v7 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: v8 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: v9 = (bool) v5;
+// CPP-DECLTOP-NEXT: v10 = v6 + v4;
+// CPP-DECLTOP-NEXT: v11 = bar(v7, v4);
+// CPP-DECLTOP-NEXT: v12 = v9 ? v8 : v4;
// CPP-DECLTOP-NEXT: ;
-// CPP-DECLTOP-NEXT: v9 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: }
-func.func @parentheses_for_expression_users(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
%c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
%e0 = emitc.expression : i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
>From fd5b962dffdfb071209004027fba6d9d8633ec0e Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Tue, 4 Jun 2024 08:03:53 +0000
Subject: [PATCH 4/4] Address review comments
---
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 3 +-
mlir/test/Target/Cpp/expressions.mlir | 54 +++++++++++++-------------
2 files changed, 28 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 01648ba693180..6cfe846a785dd 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1343,9 +1343,8 @@ LogicalResult CppEmitter::emitOperand(Value value) {
}
auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
- if (expressionOp && shouldBeInlined(expressionOp)) {
+ if (expressionOp && shouldBeInlined(expressionOp))
return emitExpression(expressionOp);
- }
auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
if (!literalOp && !hasValueInScope(value))
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index 810a629c71533..caa0a340d3e0a 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -101,42 +101,42 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
}
// CPP-DEFAULT: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DEFAULT-NEXT: int32_t v4 = 0;
+// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 0;
// CPP-DEFAULT-NEXT: int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
-// CPP-DEFAULT-NEXT: bool v9 = (bool) [[EXP_0]];
-// CPP-DEFAULT-NEXT: int32_t v10 = [[EXP_1]] + v4;
-// CPP-DEFAULT-NEXT: int32_t v11 = bar([[EXP_2]], v4);
-// CPP-DEFAULT-NEXT: int32_t v12 = v9 ? [[EXP_3]] : v4;
-// CPP-DEFAULT-NEXT: int32_t v13;
-// CPP-DEFAULT-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: bool [[CAST:v[0-9]+]] = (bool) [[EXP_0]];
+// CPP-DEFAULT-NEXT: int32_t [[ADD:v[0-9]+]] = [[EXP_1]] + [[VAL_4]];
+// CPP-DEFAULT-NEXT: int32_t [[CALL:v[0-9]+]] = bar([[EXP_2]], [[VAL_4]]);
+// CPP-DEFAULT-NEXT: int32_t [[COND:v[0-9]+]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]];
+// CPP-DEFAULT-NEXT: int32_t [[VAR:v[0-9]+]];
+// CPP-DEFAULT-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: }
// CPP-DECLTOP: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
-// CPP-DECLTOP-NEXT: int32_t v4;
-// CPP-DECLTOP-NEXT: int32_t v5;
-// CPP-DECLTOP-NEXT: int32_t v6;
-// CPP-DECLTOP-NEXT: int32_t v7;
-// CPP-DECLTOP-NEXT: int32_t v8;
-// CPP-DECLTOP-NEXT: bool v9;
-// CPP-DECLTOP-NEXT: int32_t v10;
-// CPP-DECLTOP-NEXT: int32_t v11;
-// CPP-DECLTOP-NEXT: int32_t v12;
-// CPP-DECLTOP-NEXT: int32_t v13;
-// CPP-DECLTOP-NEXT: v4 = 0;
-// CPP-DECLTOP-NEXT: v5 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
-// CPP-DECLTOP-NEXT: v6 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
-// CPP-DECLTOP-NEXT: v7 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
-// CPP-DECLTOP-NEXT: v8 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
-// CPP-DECLTOP-NEXT: v9 = (bool) v5;
-// CPP-DECLTOP-NEXT: v10 = v6 + v4;
-// CPP-DECLTOP-NEXT: v11 = bar(v7, v4);
-// CPP-DECLTOP-NEXT: v12 = v9 ? v8 : v4;
+// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[EXP_0:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[EXP_1:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[EXP_2:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[EXP_3:v[0-9]+]];
+// CPP-DECLTOP-NEXT: bool [[CAST:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[ADD:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[CALL:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[COND:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[VAR:v[0-9]+]];
+// CPP-DECLTOP-NEXT: [[VAL_4]] = 0;
+// CPP-DECLTOP-NEXT: [[EXP_0]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: [[EXP_1]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: [[EXP_2]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: [[EXP_3]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: [[CAST]] = (bool) [[EXP_0]];
+// CPP-DECLTOP-NEXT: [[ADD]] = [[EXP_1]] + [[VAL_4]];
+// CPP-DECLTOP-NEXT: [[CALL]] = bar([[EXP_2]], [[VAL_4]]);
+// CPP-DECLTOP-NEXT: [[COND]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]];
// CPP-DECLTOP-NEXT: ;
-// CPP-DECLTOP-NEXT: v13 = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: }
func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
More information about the Mlir-commits
mailing list