[Mlir-commits] [mlir] a934ddc - [mlir][EmitC] Do not inline expressions used by ops with the CExpression trait (#93691)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 4 04:14:11 PDT 2024
Author: Simon Camphausen
Date: 2024-06-04T13:14:08+02:00
New Revision: a934ddcf7edb583e93102e2fa8b3b05ab34547f2
URL: https://github.com/llvm/llvm-project/commit/a934ddcf7edb583e93102e2fa8b3b05ab34547f2
DIFF: https://github.com/llvm/llvm-project/commit/a934ddcf7edb583e93102e2fa8b3b05ab34547f2.diff
LOG: [mlir][EmitC] Do not inline expressions used by ops with the CExpression trait (#93691)
Currently an expression is inlined without emitting enclosing
parentheses, regardless of the context of the user. This could lead to
the wrong evaluation order, depending on the precedence of both expressions.
If the inlining is intended, the user operation should be merged into
the expression op.
Fixes #93470.
Added:
Modified:
mlir/lib/Target/Cpp/TranslateToCpp.cpp
mlir/test/Target/Cpp/expressions.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index f19e0f8c4c2a4..202df89025f26 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -301,9 +301,9 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
if (isa<emitc::SubscriptOp>(user))
return false;
- // Do not inline expressions used by other expressions, as any desired
- // expression folding was taken care of by transformations.
- return !user->getParentOfType<ExpressionOp>();
+ // Do not inline expressions used by ops with the CExpression trait. If this
+ // was intended, the user could have been merged into the expression op.
+ return !user->hasTrait<OpTrait::emitc::CExpression>();
}
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index aaddd5af874a9..caa0a340d3e0a 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -100,6 +100,86 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
return %e : i32
}
+// CPP-DEFAULT: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 0;
+// CPP-DEFAULT-NEXT: int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: bool [[CAST:v[0-9]+]] = (bool) [[EXP_0]];
+// CPP-DEFAULT-NEXT: int32_t [[ADD:v[0-9]+]] = [[EXP_1]] + [[VAL_4]];
+// CPP-DEFAULT-NEXT: int32_t [[CALL:v[0-9]+]] = bar([[EXP_2]], [[VAL_4]]);
+// CPP-DEFAULT-NEXT: int32_t [[COND:v[0-9]+]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]];
+// CPP-DEFAULT-NEXT: int32_t [[VAR:v[0-9]+]];
+// CPP-DEFAULT-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[EXP_0:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[EXP_1:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[EXP_2:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[EXP_3:v[0-9]+]];
+// CPP-DECLTOP-NEXT: bool [[CAST:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[ADD:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[CALL:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[COND:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[VAR:v[0-9]+]];
+// CPP-DECLTOP-NEXT: [[VAL_4]] = 0;
+// CPP-DECLTOP-NEXT: [[EXP_0]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: [[EXP_1]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: [[EXP_2]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: [[EXP_3]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: [[CAST]] = (bool) [[EXP_0]];
+// CPP-DECLTOP-NEXT: [[ADD]] = [[EXP_1]] + [[VAL_4]];
+// CPP-DECLTOP-NEXT: [[CALL]] = bar([[EXP_2]], [[VAL_4]]);
+// CPP-DECLTOP-NEXT: [[COND]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]];
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: }
+func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+ %c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
+ %e0 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e1 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e2 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e3 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e4 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %e5 = emitc.expression : i32 {
+ %0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %1 = emitc.div %arg2, %0 : (i32, i32) -> i32
+ emitc.yield %1 : i32
+ }
+ %cast = emitc.cast %e0 : i32 to i1
+ %add = emitc.add %e1, %c0 : (i32, i32) -> i32
+ %call = emitc.call_opaque "bar" (%e2, %c0) : (i32, i32) -> (i32)
+ %cond = emitc.conditional %cast, %e3, %c0 : i32
+ %var = "emitc.variable"() {value = #emitc.opaque<"">} : () -> i32
+ emitc.assign %e4 : i32 to %var : i32
+ return %e5 : i32
+}
+
// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];
More information about the Mlir-commits
mailing list