[Mlir-commits] [mlir] [mlir][emitc] Add EmitC lowering for arith.cmpi (PR #88700)

Corentin Ferry llvmlistbot at llvm.org
Mon Apr 15 02:09:27 PDT 2024


https://github.com/cferry-AMD created https://github.com/llvm/llvm-project/pull/88700

None

>From 4b0ffa64c50f51d5d2a6748380ec2f3101000bcb Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 15 Apr 2024 10:53:55 +0200
Subject: [PATCH] Add EmitC lowering for arith.cmpi

---
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 74 +++++++++++++++++++
 .../ArithToEmitC/arith-to-emitc.mlir          | 48 ++++++++++++
 2 files changed, 122 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index db493c1294ba2d..9b2544276ce474 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -39,6 +39,98 @@ class ArithConstantOpConversionPattern
   }
 };
 
+class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Returns true if `pred` compares its operands as unsigned integers.
+  bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
+    switch (pred) {
+    case arith::CmpIPredicate::eq:
+    case arith::CmpIPredicate::ne:
+    case arith::CmpIPredicate::slt:
+    case arith::CmpIPredicate::sle:
+    case arith::CmpIPredicate::sgt:
+    case arith::CmpIPredicate::sge:
+      return false;
+    case arith::CmpIPredicate::ult:
+    case arith::CmpIPredicate::ule:
+    case arith::CmpIPredicate::ugt:
+    case arith::CmpIPredicate::uge:
+      return true;
+    }
+    llvm_unreachable("unknown cmpi predicate kind");
+  }
+
+  /// Maps an arith.cmpi predicate to the emitc.cmp predicate with the same
+  /// relational operator. Signed and unsigned variants map to the same
+  /// EmitC predicate; signedness is expressed through the operand types
+  /// instead (see matchAndRewrite).
+  emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
+    switch (pred) {
+    case arith::CmpIPredicate::eq:
+      return emitc::CmpPredicate::eq;
+    case arith::CmpIPredicate::ne:
+      return emitc::CmpPredicate::ne;
+    case arith::CmpIPredicate::slt:
+    case arith::CmpIPredicate::ult:
+      return emitc::CmpPredicate::lt;
+    case arith::CmpIPredicate::sle:
+    case arith::CmpIPredicate::ule:
+      return emitc::CmpPredicate::le;
+    case arith::CmpIPredicate::sgt:
+    case arith::CmpIPredicate::ugt:
+      return emitc::CmpPredicate::gt;
+    case arith::CmpIPredicate::sge:
+    case arith::CmpIPredicate::uge:
+      return emitc::CmpPredicate::ge;
+    }
+    llvm_unreachable("unknown cmpi predicate kind");
+  }
+
+  /// Lowers arith.cmpi to emitc.cmp. Equality predicates are emitted as-is.
+  /// For relational predicates, if the operand type's signedness differs
+  /// from the one the predicate requires, both operands are first cast to
+  /// the same-width integer type of the required signedness.
+  LogicalResult
+  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type type = adaptor.getLhs().getType();
+    if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
+      return rewriter.notifyMatchFailure(op, "expected integer or index type");
+    }
+
+    arith::CmpIPredicate arithPred = op.getPredicate();
+    emitc::CmpPredicate pred = toEmitCPred(arithPred);
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+
+    // eq/ne do not depend on signedness, so their operands are never cast.
+    bool isEquality = arithPred == arith::CmpIPredicate::eq ||
+                      arithPred == arith::CmpIPredicate::ne;
+    if (!isEquality) {
+      // Signedness is enforced by casting to a same-width integer type, but
+      // no such type can be derived from `index`: getIntOrFloatBitWidth()
+      // only accepts integer and float types. Reject relational predicates
+      // on index rather than emit a comparison of unknown signedness.
+      if (isa<IndexType>(type)) {
+        return rewriter.notifyMatchFailure(
+            op, "relational comparison of index operands is not supported");
+      }
+      bool needsUnsigned = needsUnsignedCmp(arithPred);
+      if (type.isUnsignedInteger() != needsUnsigned) {
+        Type arithmeticType = rewriter.getIntegerType(
+            type.getIntOrFloatBitWidth(), /*isSigned=*/!needsUnsigned);
+        lhs = rewriter.create<emitc::CastOp>(op.getLoc(), arithmeticType, lhs);
+        rhs = rewriter.create<emitc::CastOp>(op.getLoc(), arithmeticType, rhs);
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
+    return success();
+  }
+};
+
 template <typename ArithOp, typename EmitCOp>
 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
 public:
@@ -148,6 +221,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
     IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
     IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
+    CmpIOpConversion, // arith.cmpi -> emitc.cmp (with signedness casts)
     SelectOpConversion
   >(typeConverter, ctx);
   // clang-format on
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 76ba518577ab8e..46b407177b46aa 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -93,3 +93,60 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -
   %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
   return
 }
+
+// -----
+
+func.func @arith_cmpi_eq(%arg0: i32, %arg1: i32) -> i1 {
+  // CHECK-LABEL: arith_cmpi_eq
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: i32, [[Arg1:[^ ]*]]: i32)
+  // CHECK-DAG: [[EQ:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg1]] : (i32, i32) -> i1
+  %eq = arith.cmpi eq, %arg0, %arg1 : i32
+  // CHECK: return [[EQ]]
+  return %eq: i1
+}
+
+func.func @arith_cmpi_ult(%arg0: i32, %arg1: i32) -> i1 {
+  // CHECK-LABEL: arith_cmpi_ult
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: i32, [[Arg1:[^ ]*]]: i32)
+  // CHECK-DAG: [[CastArg0:[^ ]*]] = emitc.cast [[Arg0]] : i32 to ui32
+  // CHECK-DAG: [[CastArg1:[^ ]*]] = emitc.cast [[Arg1]] : i32 to ui32
+  // CHECK-DAG: [[ULT:[^ ]*]] = emitc.cmp lt, [[CastArg0]], [[CastArg1]] : (ui32, ui32) -> i1
+  %ult = arith.cmpi ult, %arg0, %arg1 : i32
+
+  // CHECK: return [[ULT]]
+  return %ult: i1
+}
+
+func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
+  // CHECK: emitc.cmp lt, {{.*}} : (ui32, ui32) -> i1
+  %ult = arith.cmpi ult, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp lt, {{.*}} : (i32, i32) -> i1
+  %slt = arith.cmpi slt, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp le, {{.*}} : (ui32, ui32) -> i1
+  %ule = arith.cmpi ule, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp le, {{.*}} : (i32, i32) -> i1
+  %sle = arith.cmpi sle, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp gt, {{.*}} : (ui32, ui32) -> i1
+  %ugt = arith.cmpi ugt, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp gt, {{.*}} : (i32, i32) -> i1
+  %sgt = arith.cmpi sgt, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp ge, {{.*}} : (ui32, ui32) -> i1
+  %uge = arith.cmpi uge, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp ge, {{.*}} : (i32, i32) -> i1
+  %sge = arith.cmpi sge, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp eq, {{.*}} : (i32, i32) -> i1
+  %eq = arith.cmpi eq, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp ne, {{.*}} : (i32, i32) -> i1
+  %ne = arith.cmpi ne, %arg0, %arg1 : i32
+
+  return
+}
+
+func.func @arith_cmpi_index_eq(%arg0: index, %arg1: index) -> i1 {
+  // CHECK-LABEL: arith_cmpi_index_eq
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: index, [[Arg1:[^ ]*]]: index)
+  // CHECK: [[EQ:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg1]] : (index, index) -> i1
+  %eq = arith.cmpi eq, %arg0, %arg1 : index
+  // CHECK: return [[EQ]]
+  return %eq: i1
+}



More information about the Mlir-commits mailing list