[Mlir-commits] [mlir] [mlir][EmitC] Add an `emitc.conditional` operator (PR #84883)

Marius Brehler llvmlistbot at llvm.org
Tue Mar 12 02:17:16 PDT 2024


https://github.com/marbre updated https://github.com/llvm/llvm-project/pull/84883

>From a3cdad058354b522c93747b144fc05a97751970f Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Tue, 12 Mar 2024 08:01:11 +0000
Subject: [PATCH 1/2] [mlir][EmitC] Add an `emitc.conditional` operator

This adds an `emitc.conditional` operation for the ternary conditional
operator. Furthermore, this adds a conversion from `arith.select` to the
new op.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 30 +++++++++++++++
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 23 +++++++++++-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 37 ++++++++++++++++---
 .../ArithToEmitC/arith-to-emitc.mlir          |  8 ++++
 mlir/test/Dialect/EmitC/ops.mlir              |  5 +++
 mlir/test/Target/Cpp/conditional.mlir         |  9 +++++
 6 files changed, 105 insertions(+), 7 deletions(-)
 create mode 100644 mlir/test/Target/Cpp/conditional.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ac1e38a5506da0..ec842f76628c08 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
   let hasVerifier = 1;
 }
 
+def EmitC_ConditionalOp : EmitC_Op<"conditional",
+    [AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
+  let summary = "Conditional (ternary) operation";
+  let description = [{
+    With the `conditional` operation the ternary conditional operator can
+    be applied.
+
+    Example:
+
+    ```mlir
+    %0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1
+
+    %c0 = "emitc.constant"() {value = 10 : i32} : () -> i32
+    %c1 = "emitc.constant"() {value = 11 : i32} : () -> i32
+
+    %1 = emitc.conditional %0, %c0, %c1 : i32
+    ```
+    ```c++
+    // Code emitted for the operations above.
+    bool v3 = v1 > v2;
+    int32_t v4 = 10;
+    int32_t v5 = 11;
+    int32_t v6 = v3 ? v4 : v5;
+    ```
+  }];
+  let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
 def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
   let summary = "Unary minus operation";
   let description = [{
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 40dce001a3b224..afef4ac1bfc5f9 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -54,6 +54,26 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
     return success();
   }
 };
+
+class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
+public:
+  using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!selectOp.getCondition().getType().isInteger())
+      return rewriter.notifyMatchFailure(
+          selectOp, "can only converted if condition is a scalar of type i1");
+
+    rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(
+        selectOp, selectOp.getType(), adaptor.getOperands());
+
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -70,7 +90,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     ArithOpConversion<arith::AddFOp, emitc::AddOp>,
     ArithOpConversion<arith::DivFOp, emitc::DivOp>,
     ArithOpConversion<arith::MulFOp, emitc::MulOp>,
-    ArithOpConversion<arith::SubFOp, emitc::SubOp>
+    ArithOpConversion<arith::SubFOp, emitc::SubOp>,
+    SelectOpConversion
   >(typeConverter, ctx);
   // clang-format on
 }
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3cf137c1d07c0e..7cbb1e9265e174 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -96,6 +96,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
         }
         return op->emitError("unsupported cmp predicate");
       })
+      .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
       .Case<emitc::DivOp>([&](auto op) { return 13; })
       .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
       .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
@@ -446,6 +447,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
   return printBinaryOperation(emitter, operation, binaryOperator);
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::ConditionalOp conditionalOp) {
+  raw_ostream &os = emitter.ostream();
+
+  if (failed(emitter.emitAssignPrefix(*conditionalOp)))
+    return failure();
+
+  if (failed(emitter.emitOperand(conditionalOp.getCondition())))
+    return failure();
+
+  os << " ? ";
+
+  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
+    return failure();
+
+  os << " : ";
+
+  if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
+    return failure();
+
+  return success();
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::VerbatimOp verbatimOp) {
   raw_ostream &os = emitter.ostream();
@@ -1383,12 +1407,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::BitwiseNotOp, emitc::BitwiseOrOp,
                 emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
                 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
-                emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
-                emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
-                emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
-                emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
-                emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
-                emitc::VariableOp, emitc::VerbatimOp>(
+                emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
+                emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
+                emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
+                emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
+                emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+                emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
+                emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 2886810c01e917..022530ef4db84b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
 
   return
 }
+
+// -----
+
+func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
+  // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
+  %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
+  return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 122b1d9ef1059f..5f00a295ed740e 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () {
   return
 }
 
+func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
+  %0 = emitc.conditional %cond, %arg0, %arg1 : i32
+  return
+}
+
 func.func @div_int(%arg0: i32, %arg1: i32) {
   %1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32
   return
diff --git a/mlir/test/Target/Cpp/conditional.mlir b/mlir/test/Target/Cpp/conditional.mlir
new file mode 100644
index 00000000000000..2470fbeb33adae
--- /dev/null
+++ b/mlir/test/Target/Cpp/conditional.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
+  %0 = emitc.conditional %cond, %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: void cond
+// CHECK-NEXT:  int32_t [[V3:[^ ]*]] = [[V0:[^ ]*]] ? [[V1:[^ ]*]] : [[V2:[^ ]*]];

>From 807dcb5de1952d0e0a1ff8fd0537103d7475ffc8 Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Tue, 12 Mar 2024 09:17:00 +0000
Subject: [PATCH 2/2] Address review feedback and improve error message

---
 mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index afef4ac1bfc5f9..752227fe142731 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -63,12 +63,17 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
   matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    if (!selectOp.getCondition().getType().isInteger())
+    auto dstType = getTypeConverter()->convertType(selectOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
+
+    if (!adaptor.getCondition().getType().isInteger(1))
       return rewriter.notifyMatchFailure(
-          selectOp, "can only converted if condition is a scalar of type i1");
+          selectOp,
+          "can only be converted if condition is a scalar of type i1");
 
-    rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(
-        selectOp, selectOp.getType(), adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
+                                                      adaptor.getOperands());
 
     return success();
   }



More information about the Mlir-commits mailing list