[Mlir-commits] [mlir] 1594ceb - [mlir][EmitC] Fix evaluation order of expressions (#93549)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 29 02:42:09 PDT 2024
Author: Simon Camphausen
Date: 2024-05-29T11:42:06+02:00
New Revision: 1594cebedd60a08f408e3fa975116ef4db86bf9b
URL: https://github.com/llvm/llvm-project/commit/1594cebedd60a08f408e3fa975116ef4db86bf9b
DIFF: https://github.com/llvm/llvm-project/commit/1594cebedd60a08f408e3fa975116ef4db86bf9b.diff
LOG: [mlir][EmitC] Fix evaluation order of expressions (#93549)
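Sub-expressions with the same precedence as their enclosing operator were not
parenthesized and could therefore be evaluated in the wrong order, depending on
the shape of the expression tree. For example, `c / (a * b)` was emitted as
`c / a * b`, which C++ evaluates as `(c / a) * b`.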
---------
Co-authored-by: Matthias Gehre <matthias.gehre at amd.com>
Co-authored-by: Corentin Ferry <corentin.ferry at amd.com>
Added:
Modified:
mlir/lib/Target/Cpp/TranslateToCpp.cpp
mlir/test/Target/Cpp/expressions.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 7db7163bac4ab..f19e0f8c4c2a4 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1316,7 +1316,11 @@ LogicalResult CppEmitter::emitOperand(Value value) {
FailureOr<int> precedence = getOperatorPrecedence(def);
if (failed(precedence))
return failure();
- bool encloseInParenthesis = precedence.value() < getExpressionPrecedence();
+
+ // Sub-expressions with equal or lower precedence need to be parenthesized,
+ // as they might be evaluated in the wrong order depending on the shape of
+ // the expression tree.
+ bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence();
if (encloseInParenthesis) {
os << "(";
pushExpressionPrecedence(lowestPrecedence());
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index 2eda58902cb1d..aaddd5af874a9 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -65,15 +65,15 @@ func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 {
return %e : i32
}
-// CPP-DEFAULT: float paranthesis_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
// CPP-DEFAULT-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
// CPP-DEFAULT-NEXT: }
-// CPP-DECLTOP: float paranthesis_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
// CPP-DECLTOP-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
// CPP-DECLTOP-NEXT: }
-func.func @paranthesis_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 {
+func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 {
%e = emitc.expression : f32 {
%a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.mul %a, %arg2 : (i32, i32) -> i32
@@ -83,6 +83,23 @@ func.func @paranthesis_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) ->
return %e : f32
}
+// CPP-DEFAULT: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: int32_t parentheses_for_same_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: }
+func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+ %e = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+
+ return %e : i32
+}
+
// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];
More information about the Mlir-commits mailing list