[Mlir-commits] [mlir] [mlir][emitc] Add EmitC lowering for arith.cmpi (PR #88700)
Corentin Ferry
llvmlistbot at llvm.org
Mon Apr 15 02:09:27 PDT 2024
https://github.com/cferry-AMD created https://github.com/llvm/llvm-project/pull/88700
None
>From 4b0ffa64c50f51d5d2a6748380ec2f3101000bcb Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Mon, 15 Apr 2024 10:53:55 +0200
Subject: [PATCH] Add EmitC lowering for arith.cmpi
---
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 74 +++++++++++++++++++
.../ArithToEmitC/arith-to-emitc.mlir | 48 ++++++++++++
2 files changed, 122 insertions(+)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index db493c1294ba2d..9b2544276ce474 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -39,6 +39,79 @@ class ArithConstantOpConversionPattern
}
};
+/// Lowers `arith.cmpi` to `emitc.cmp`.
+///
+/// Signed and unsigned arith predicates map to the same emitc::CmpPredicate
+/// (e.g. both `slt` and `ult` become `lt`), so signedness is carried by the
+/// operand types instead. Whenever the operand type's unsignedness disagrees
+/// with what the predicate requires, both operands are wrapped in an
+/// `emitc.cast` to an integer type of the same width with the required
+/// signedness. `eq`/`ne` are treated like signed predicates.
+class CmpIOpConversion final : public OpConversionPattern<arith::CmpIOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Returns true if `pred` compares its operands as unsigned integers.
+  bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
+    switch (pred) {
+    case arith::CmpIPredicate::eq:
+    case arith::CmpIPredicate::ne:
+    case arith::CmpIPredicate::slt:
+    case arith::CmpIPredicate::sle:
+    case arith::CmpIPredicate::sgt:
+    case arith::CmpIPredicate::sge:
+      return false;
+    case arith::CmpIPredicate::ult:
+    case arith::CmpIPredicate::ule:
+    case arith::CmpIPredicate::ugt:
+    case arith::CmpIPredicate::uge:
+      return true;
+    }
+    llvm_unreachable("unknown cmpi predicate kind");
+  }
+
+  /// Maps an arith predicate to the signedness-agnostic emitc predicate.
+  emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
+    switch (pred) {
+    case arith::CmpIPredicate::eq:
+      return emitc::CmpPredicate::eq;
+    case arith::CmpIPredicate::ne:
+      return emitc::CmpPredicate::ne;
+    case arith::CmpIPredicate::slt:
+    case arith::CmpIPredicate::ult:
+      return emitc::CmpPredicate::lt;
+    case arith::CmpIPredicate::sle:
+    case arith::CmpIPredicate::ule:
+      return emitc::CmpPredicate::le;
+    case arith::CmpIPredicate::sgt:
+    case arith::CmpIPredicate::ugt:
+      return emitc::CmpPredicate::gt;
+    case arith::CmpIPredicate::sge:
+    case arith::CmpIPredicate::uge:
+      return emitc::CmpPredicate::ge;
+    }
+    llvm_unreachable("unknown cmpi predicate kind");
+  }
+
+  LogicalResult
+  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type type = adaptor.getLhs().getType();
+    if (!isa_and_nonnull<IntegerType, IndexType>(type))
+      return rewriter.notifyMatchFailure(op, "expected integer or index type");
+
+    bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
+    emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
+
+    // Pick the type the comparison is performed in. A cast is needed only
+    // when the operand's unsignedness differs from what the predicate needs.
+    Type arithmeticType = type;
+    if (type.isUnsignedInteger() != needsUnsigned) {
+      // `index` has no fixed bit width (getIntOrFloatBitWidth() asserts on
+      // it), so there is no same-width integer type to reinterpret it as.
+      // Fail the match instead of crashing.
+      auto intType = dyn_cast<IntegerType>(type);
+      if (!intType)
+        return rewriter.notifyMatchFailure(
+            op, "cannot change signedness of index type");
+      arithmeticType = rewriter.getIntegerType(intType.getWidth(),
+                                               /*isSigned=*/!needsUnsigned);
+    }
+
+    Value lhs = adaptor.getLhs();
+    Value rhs = adaptor.getRhs();
+    if (arithmeticType != type) {
+      lhs = rewriter.create<emitc::CastOp>(op.getLoc(), arithmeticType, lhs);
+      rhs = rewriter.create<emitc::CastOp>(op.getLoc(), arithmeticType, rhs);
+    }
+
+    rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
+    return success();
+  }
+};
+
template <typename ArithOp, typename EmitCOp>
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
public:
@@ -148,6 +221,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
+ CmpIOpConversion,
SelectOpConversion
>(typeConverter, ctx);
// clang-format on
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 76ba518577ab8e..46b407177b46aa 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -93,3 +93,51 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
return
}
+
+// -----
+
+func.func @arith_cmpi_eq(%arg0: i32, %arg1: i32) -> i1 {
+  // The comparison must consume the function arguments directly, i.e. no
+  // emitc.cast is inserted for eq on signless operands.
+  // CHECK-LABEL: arith_cmpi_eq
+  // CHECK-SAME: ([[LHS:[^ ]*]]: i32, [[RHS:[^ ]*]]: i32)
+  // CHECK-DAG: [[RES:[^ ]*]] = emitc.cmp eq, [[LHS]], [[RHS]] : (i32, i32) -> i1
+  %res = arith.cmpi eq, %arg0, %arg1 : i32
+  // CHECK: return [[RES]]
+  return %res : i1
+}
+
+func.func @arith_cmpi_ult(%arg0: i32, %arg1: i32) -> i1 {
+  // An unsigned predicate on signless operands: both operands are cast to
+  // ui32 and the comparison is performed on the cast values.
+  // CHECK-LABEL: arith_cmpi_ult
+  // CHECK-SAME: ([[LHS:[^ ]*]]: i32, [[RHS:[^ ]*]]: i32)
+  // CHECK-DAG: [[LHS_U:[^ ]*]] = emitc.cast [[LHS]] : i32 to ui32
+  // CHECK-DAG: [[RHS_U:[^ ]*]] = emitc.cast [[RHS]] : i32 to ui32
+  // CHECK-DAG: [[RES:[^ ]*]] = emitc.cmp lt, [[LHS_U]], [[RHS_U]] : (ui32, ui32) -> i1
+  %res = arith.cmpi ult, %arg0, %arg1 : i32
+
+  // CHECK: return [[RES]]
+  return %res : i1
+}
+
+func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
+  // Each arith predicate maps to the matching emitc.cmp predicate. Unsigned
+  // predicates compare ui32 (cast) operands; signed predicates and eq/ne
+  // compare the original i32 operands.
+  // CHECK-LABEL: arith_cmpi_predicates
+  // CHECK: emitc.cmp lt, {{.*}} : (ui32, ui32) -> i1
+  %ult = arith.cmpi ult, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp lt, {{.*}} : (i32, i32) -> i1
+  %slt = arith.cmpi slt, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp le, {{.*}} : (ui32, ui32) -> i1
+  %ule = arith.cmpi ule, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp le, {{.*}} : (i32, i32) -> i1
+  %sle = arith.cmpi sle, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp gt, {{.*}} : (ui32, ui32) -> i1
+  %ugt = arith.cmpi ugt, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp gt, {{.*}} : (i32, i32) -> i1
+  %sgt = arith.cmpi sgt, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp ge, {{.*}} : (ui32, ui32) -> i1
+  %uge = arith.cmpi uge, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp ge, {{.*}} : (i32, i32) -> i1
+  %sge = arith.cmpi sge, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp eq, {{.*}} : (i32, i32) -> i1
+  %eq = arith.cmpi eq, %arg0, %arg1 : i32
+  // CHECK: emitc.cmp ne, {{.*}} : (i32, i32) -> i1
+  %ne = arith.cmpi ne, %arg0, %arg1 : i32
+
+  return
+}
More information about the Mlir-commits
mailing list