[Mlir-commits] [mlir] [MLIR][Arith] expand-ops: Support mini/maxi (PR #90575)
Matthias Gehre
llvmlistbot at llvm.org
Tue Apr 30 08:12:03 PDT 2024
https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/90575
>From 7d3ff77c75395ce52fdf9d8397426a20129ba644 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 29 Apr 2024 16:02:02 +0200
Subject: [PATCH 1/2] [MLIR][Arith] expand-ops: Support mini/maxi
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 25 ++++++++++
mlir/test/Dialect/Arith/expand-ops.mlir | 48 +++++++++++++++++++
2 files changed, 73 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index dd04a599655894..676747ff01d09d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -152,6 +152,23 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
}
};
+template <typename OpTy, arith::CmpIPredicate pred>
+struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+ /// Expands op to select(cmpi(pred, lhs, rhs), lhs, rhs): lhs iff pred holds.
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const final {
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+
+ Location loc = op.getLoc();
+ Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
+ return success();
+ }
+};
+
template <typename OpTy, arith::CmpFPredicate pred>
struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
public:
@@ -335,6 +352,10 @@ struct ArithExpandOpsPass
arith::CeilDivSIOp,
arith::CeilDivUIOp,
arith::FloorDivSIOp,
+ arith::MaxSIOp,
+ arith::MaxUIOp,
+ arith::MinSIOp,
+ arith::MinUIOp,
arith::MaximumFOp,
arith::MinimumFOp,
arith::MaxNumFOp,
@@ -383,6 +404,10 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
// clang-format off
patterns.add<
+ MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
+ MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
+ MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
+ MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 6bed93e4c969db..174eb468cc0041 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -262,3 +262,51 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
// CHECK-LABEL: @truncf_vector_f32
// CHECK-NOT: arith.truncf
+
+// -----
+// arith.maxsi becomes arith.cmpi sgt feeding arith.select of the operands.
+func.func @maxsi(%a: i32, %b: i32) -> i32 {
+ %result = arith.maxsi %a, %b : i32
+ return %result : i32
+}
+// CHECK-LABEL: func @maxsi
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+// arith.minsi becomes arith.cmpi slt feeding arith.select of the operands.
+func.func @minsi(%a: i32, %b: i32) -> i32 {
+ %result = arith.minsi %a, %b : i32
+ return %result : i32
+}
+// CHECK-LABEL: func @minsi
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+// arith.maxui becomes arith.cmpi ugt feeding arith.select of the operands.
+func.func @maxui(%a: i32, %b: i32) -> i32 {
+ %result = arith.maxui %a, %b : i32
+ return %result : i32
+}
+// CHECK-LABEL: func @maxui
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+// arith.minui becomes arith.cmpi ult feeding arith.select of the operands.
+func.func @minui(%a: i32, %b: i32) -> i32 {
+ %result = arith.minui %a, %b : i32
+ return %result : i32
+}
+// CHECK-LABEL: func @minui
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RESULT]] : i32
>From 7dca154695c8b99ca43de51eaf954a9ede063181 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Tue, 30 Apr 2024 17:11:55 +0200
Subject: [PATCH 2/2] Inline loc
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 676747ff01d09d..54be644a710113 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -162,8 +162,7 @@ struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
- Location loc = op.getLoc();
- Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
+ Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
return success();
}
More information about the Mlir-commits mailing list