[Mlir-commits] [mlir] 19266ca - [mlir][EmitC] Add an `emitc.conditional` operator (#84883)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 12 03:27:30 PDT 2024


Author: Marius Brehler
Date: 2024-03-12T11:27:26+01:00
New Revision: 19266ca389e3fc3bce9d24c074b836d6e69873ce

URL: https://github.com/llvm/llvm-project/commit/19266ca389e3fc3bce9d24c074b836d6e69873ce
DIFF: https://github.com/llvm/llvm-project/commit/19266ca389e3fc3bce9d24c074b836d6e69873ce.diff

LOG: [mlir][EmitC] Add an `emitc.conditional` operator (#84883)

This adds an `emitc.conditional` operation for the ternary conditional
operator. Furthermore, this adds a conversion from `arith.select` to the
new op.

Added: 
    mlir/test/Target/Cpp/conditional.mlir

Modified: 
    mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
    mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
    mlir/lib/Target/Cpp/TranslateToCpp.cpp
    mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
    mlir/test/Dialect/EmitC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ac1e38a5506da0..ec842f76628c08 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
   let hasVerifier = 1;
 }
 
+def EmitC_ConditionalOp : EmitC_Op<"conditional",
+    [AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
+  let summary = "Conditional (ternary) operation";
+  let description = [{
+    Applies the C ternary operator: yields `true_value` if `condition` is true,
+    `false_value` otherwise. Both values and the result share one type.
+
+    Example:
+
+    ```mlir
+    %0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1
+
+    %c0 = "emitc.constant"() {value = 10 : i32} : () -> i32
+    %c1 = "emitc.constant"() {value = 11 : i32} : () -> i32
+
+    %1 = emitc.conditional %0, %c0, %c1 : i32
+    ```
+    ```c++
+    // Code emitted for the operations above.
+    bool v3 = v1 > v2;
+    int32_t v4 = 10;
+    int32_t v5 = 11;
+    int32_t v6 = v3 ? v4 : v5;
+    ```
+  }];
+  let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
 def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
   let summary = "Unary minus operation";
   let description = [{

diff  --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 40dce001a3b224..3532785c31b939 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -54,6 +54,31 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
     return success();
   }
 };
+
+/// Lowers a scalar-condition `arith.select` to `emitc.conditional`.
+class SelectOpConversion final : public OpConversionPattern<arith::SelectOp> {
+public:
+  using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(selectOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
+
+    // Vector/tensor i1 masks select elementwise; C's ternary cannot.
+    if (!adaptor.getCondition().getType().isInteger(1))
+      return rewriter.notifyMatchFailure(
+          selectOp,
+          "can only be converted if condition is a scalar of type i1");
+
+    rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
+                                                      adaptor.getOperands());
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     ArithOpConversion<arith::AddFOp, emitc::AddOp>,
     ArithOpConversion<arith::DivFOp, emitc::DivOp>,
     ArithOpConversion<arith::MulFOp, emitc::MulOp>,
-    ArithOpConversion<arith::SubFOp, emitc::SubOp>
+    ArithOpConversion<arith::SubFOp, emitc::SubOp>,
+    SelectOpConversion
   >(typeConverter, ctx);
   // clang-format on
 }

diff  --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3cf137c1d07c0e..7cbb1e9265e174 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -96,6 +96,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
         }
         return op->emitError("unsupported cmp predicate");
       })
+      .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
       .Case<emitc::DivOp>([&](auto op) { return 13; })
       .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
       .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
@@ -446,6 +447,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
   return printBinaryOperation(emitter, operation, binaryOperator);
 }
 
+/// Prints `emitc.conditional` as the C ternary `<cond> ? <true> : <false>`
+/// (after the assignment prefix); returns failure if any part fails to emit.
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::ConditionalOp conditionalOp) {
+  raw_ostream &os = emitter.ostream();
+
+  // Chained with `||`: the condition is only printed once the assignment
+  // prefix has been emitted successfully.
+  if (failed(emitter.emitAssignPrefix(*conditionalOp)) ||
+      failed(emitter.emitOperand(conditionalOp.getCondition())))
+    return failure();
+
+  os << " ? ";
+  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
+    return failure();
+
+  os << " : ";
+  if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
+    return failure();
+
+  return success();
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::VerbatimOp verbatimOp) {
   raw_ostream &os = emitter.ostream();
@@ -1383,12 +1407,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::BitwiseNotOp, emitc::BitwiseOrOp,
                 emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
                 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
-                emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
-                emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
-                emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
-                emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
-                emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
-                emitc::VariableOp, emitc::VerbatimOp>(
+                emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
+                emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
+                emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
+                emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
+                emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+                emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
+                emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(

diff  --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 2886810c01e917..022530ef4db84b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
 
   return
 }
+
+// -----
+// CHECK-LABEL: arith_select
+func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
+  // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
+  %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
+  return
+}

diff  --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 122b1d9ef1059f..5f00a295ed740e 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () {
   return
 }
 
+func.func @cond(%pred: i1, %lhs: i32, %rhs: i32) {
+  %res = emitc.conditional %pred, %lhs, %rhs : i32 // %pred ? %lhs : %rhs
+  return
+}
+
 func.func @div_int(%arg0: i32, %arg1: i32) {
   %1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32
   return

diff  --git a/mlir/test/Target/Cpp/conditional.mlir b/mlir/test/Target/Cpp/conditional.mlir
new file mode 100644
index 00000000000000..2470fbeb33adae
--- /dev/null
+++ b/mlir/test/Target/Cpp/conditional.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+// CHECK-LABEL: void cond
+// CHECK-SAME:  (bool [[V0:[^ ,]*]], int32_t [[V1:[^ ,]*]], int32_t [[V2:[^ )]*]])
+// CHECK-NEXT:  int32_t [[V3:[^ ]*]] = [[V0]] ? [[V1]] : [[V2]];
+func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
+  %0 = emitc.conditional %cond, %arg0, %arg1 : i32
+  return
+}


        


More information about the Mlir-commits mailing list