[Mlir-commits] [mlir] [mlir][EmitC] Add an `emitc.conditional` operator (PR #84883)
Marius Brehler
llvmlistbot at llvm.org
Tue Mar 12 02:17:16 PDT 2024
https://github.com/marbre updated https://github.com/llvm/llvm-project/pull/84883
>From a3cdad058354b522c93747b144fc05a97751970f Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Tue, 12 Mar 2024 08:01:11 +0000
Subject: [PATCH 1/2] [mlir][EmitC] Add an `emitc.conditional` operator
This adds an `emitc.conditional` operation for the ternary conditional
operator. Furthermore, this adds a conversion from `arith.select` to the
new op.
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 30 +++++++++++++++
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 23 +++++++++++-
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 37 ++++++++++++++++---
.../ArithToEmitC/arith-to-emitc.mlir | 8 ++++
mlir/test/Dialect/EmitC/ops.mlir | 5 +++
mlir/test/Target/Cpp/conditional.mlir | 9 +++++
6 files changed, 105 insertions(+), 7 deletions(-)
create mode 100644 mlir/test/Target/Cpp/conditional.mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ac1e38a5506da0..ec842f76628c08 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
let hasVerifier = 1;
}
+def EmitC_ConditionalOp : EmitC_Op<"conditional",
+ [AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
+ let summary = "Conditional (ternary) operation";
+ let description = [{
+ The `conditional` operation applies the ternary conditional operator
+ `condition ? true_value : false_value`: the result is `true_value` if the
+ `i1` condition is true and `false_value` otherwise. Both values must have
+ the same type as the result.
+
+ Example:
+
+ ```mlir
+ %0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1
+
+ %c0 = "emitc.constant"() {value = 10 : i32} : () -> i32
+ %c1 = "emitc.constant"() {value = 11 : i32} : () -> i32
+
+ %1 = emitc.conditional %0, %c0, %c1 : i32
+ ```
+ ```c++
+ // Code emitted for the operations above.
+ bool v3 = v1 > v2;
+ int32_t v4 = 10;
+ int32_t v5 = 11;
+ int32_t v6 = v3 ? v4 : v5;
+ ```
+ }];
+ let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
let summary = "Unary minus operation";
let description = [{
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 40dce001a3b224..afef4ac1bfc5f9 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -54,6 +54,26 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
return success();
}
};
+
+/// Lowers `arith.select` to `emitc.conditional`, which the C++ emitter
+/// prints as the ternary `condition ? true_value : false_value`.
+class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
+public:
+ using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!selectOp.getCondition().getType().isInteger())
+ return rewriter.notifyMatchFailure(
+ selectOp, "can only converted if condition is a scalar of type i1");
+
+ rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(
+ selectOp, selectOp.getType(), adaptor.getOperands());
+
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -70,7 +90,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
- ArithOpConversion<arith::SubFOp, emitc::SubOp>
+ ArithOpConversion<arith::SubFOp, emitc::SubOp>,
+ SelectOpConversion
>(typeConverter, ctx);
// clang-format on
}
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3cf137c1d07c0e..7cbb1e9265e174 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -96,6 +96,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
}
return op->emitError("unsupported cmp predicate");
})
+ .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
.Case<emitc::DivOp>([&](auto op) { return 13; })
.Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
.Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
@@ -446,6 +447,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
return printBinaryOperation(emitter, operation, binaryOperator);
}
+/// Prints an `emitc.conditional` as `<result prefix><condition> ? <true_value>
+/// : <false_value>`, e.g. `int32_t v6 = v3 ? v4 : v5` in the emitted code.
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::ConditionalOp conditionalOp) {
+  raw_ostream &out = emitter.ostream();
+
+  // The result prefix comes first, then the condition; stop on the first
+  // failure (`||` short-circuits, preserving the emission order).
+  if (failed(emitter.emitAssignPrefix(*conditionalOp)) ||
+      failed(emitter.emitOperand(conditionalOp.getCondition())))
+    return failure();
+
+  out << " ? ";
+  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
+    return failure();
+
+  out << " : ";
+  // The last operand's status is the status of the whole print.
+  return emitter.emitOperand(conditionalOp.getFalseValue());
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
emitc::VerbatimOp verbatimOp) {
raw_ostream &os = emitter.ostream();
@@ -1383,12 +1407,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
- emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
- emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
- emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
- emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
- emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
- emitc::VariableOp, emitc::VerbatimOp>(
+ emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
+ emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
+ emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
+ emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
+ emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+ emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
+ emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 2886810c01e917..022530ef4db84b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
return
}
+
+// -----
+
+// Lowering of `arith.select` with a scalar `i1` condition to the ternary op.
+func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
+ // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
+ %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
+ return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 122b1d9ef1059f..5f00a295ed740e 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () {
return
}
+// `emitc.conditional` written in its custom assembly format.
+func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
+ %0 = emitc.conditional %cond, %arg0, %arg1 : i32
+ return
+}
+
func.func @div_int(%arg0: i32, %arg1: i32) {
%1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32
return
diff --git a/mlir/test/Target/Cpp/conditional.mlir b/mlir/test/Target/Cpp/conditional.mlir
new file mode 100644
index 00000000000000..2470fbeb33adae
--- /dev/null
+++ b/mlir/test/Target/Cpp/conditional.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
+  %0 = emitc.conditional %cond, %arg0, %arg1 : i32
+  return
+}
+// Bind the argument names from the signature so the operand order is verified.
+// CHECK:      void cond(bool [[V0:[^ ,]*]], int32_t [[V1:[^ ,]*]], int32_t [[V2:[^ )]*]])
+// CHECK-NEXT:   int32_t [[V3:[^ ]*]] = [[V0]] ? [[V1]] : [[V2]];
>From 807dcb5de1952d0e0a1ff8fd0537103d7475ffc8 Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Tue, 12 Mar 2024 09:17:00 +0000
Subject: [PATCH 2/2] Address review feedback and improve error message
---
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index afef4ac1bfc5f9..752227fe142731 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -63,12 +63,17 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (!selectOp.getCondition().getType().isInteger())
+ auto dstType = getTypeConverter()->convertType(selectOp.getType());
+ if (!dstType)
+ return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
+
+ if (!adaptor.getCondition().getType().isInteger(1))
return rewriter.notifyMatchFailure(
- selectOp, "can only converted if condition is a scalar of type i1");
+ selectOp,
+ "can only be converted if condition is a scalar of type i1");
- rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(
- selectOp, selectOp.getType(), adaptor.getOperands());
+ rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
+ adaptor.getOperands());
return success();
}
More information about the Mlir-commits
mailing list