[Mlir-commits] [mlir] 30badf9 - [MLIR][Arith] expand-ops: Support mini/maxi (#90575)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 30 10:02:36 PDT 2024


Author: Matthias Gehre
Date: 2024-04-30T19:02:32+02:00
New Revision: 30badf96bbaa5ddfd8049442e573fd270a89ddc8

URL: https://github.com/llvm/llvm-project/commit/30badf96bbaa5ddfd8049442e573fd270a89ddc8
DIFF: https://github.com/llvm/llvm-project/commit/30badf96bbaa5ddfd8049442e573fd270a89ddc8.diff

LOG: [MLIR][Arith] expand-ops: Support mini/maxi (#90575)

Expand `arith.minsi`, `arith.minui`, `arith.maxsi`, `arith.maxui` into
`arith.cmpi` and `arith.select`.

---------

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Arith/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index dd04a599655894..54be644a710113 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -152,6 +152,22 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
   }
 };
 
+/// Lowers an integer min/max op to `select(cmpi(pred, lhs, rhs), lhs, rhs)`.
+template <typename OpTy, arith::CmpIPredicate pred>
+struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const final {
+    // Select the lhs when `pred(lhs, rhs)` holds and the rhs otherwise.
+    Value a = op.getLhs();
+    Value b = op.getRhs();
+    Value takeA = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, a, b);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, takeA, a, b);
+    return success();
+  }
+};
+
 template <typename OpTy, arith::CmpFPredicate pred>
 struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
 public:
@@ -335,6 +351,10 @@ struct ArithExpandOpsPass
       arith::CeilDivSIOp,
       arith::CeilDivUIOp,
       arith::FloorDivSIOp,
+      arith::MaxSIOp,
+      arith::MaxUIOp,
+      arith::MinSIOp,
+      arith::MinUIOp,
       arith::MaximumFOp,
       arith::MinimumFOp,
       arith::MaxNumFOp,
@@ -383,6 +403,10 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off
   patterns.add<
+    MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
+    MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
+    MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
+    MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
     MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
     MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
     MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,

diff  --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 6bed93e4c969db..174eb468cc0041 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -262,3 +262,51 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
 
 // CHECK-LABEL: @truncf_vector_f32
 // CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @maxsi(%lhs: i32, %rhs: i32) -> i32 {
+  %max = arith.maxsi %lhs, %rhs : i32
+  return %max : i32
+}
+// CHECK-LABEL: func @maxsi
+// CHECK-SAME: %[[A:.*]]: i32, %[[B:.*]]: i32
+// CHECK-NEXT: %[[SGT:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
+// CHECK-NEXT: %[[MAX:.*]] = arith.select %[[SGT]], %[[A]], %[[B]] : i32
+// CHECK-NEXT: return %[[MAX]] : i32
+
+// -----
+
+func.func @minsi(%lhs: i32, %rhs: i32) -> i32 {
+  %min = arith.minsi %lhs, %rhs : i32
+  return %min : i32
+}
+// CHECK-LABEL: func @minsi
+// CHECK-SAME: %[[A:.*]]: i32, %[[B:.*]]: i32
+// CHECK-NEXT: %[[SLT:.*]] = arith.cmpi slt, %[[A]], %[[B]] : i32
+// CHECK-NEXT: %[[MIN:.*]] = arith.select %[[SLT]], %[[A]], %[[B]] : i32
+// CHECK-NEXT: return %[[MIN]] : i32
+
+// -----
+
+func.func @maxui(%lhs: i32, %rhs: i32) -> i32 {
+  %max = arith.maxui %lhs, %rhs : i32
+  return %max : i32
+}
+// CHECK-LABEL: func @maxui
+// CHECK-SAME: %[[A:.*]]: i32, %[[B:.*]]: i32
+// CHECK-NEXT: %[[UGT:.*]] = arith.cmpi ugt, %[[A]], %[[B]] : i32
+// CHECK-NEXT: %[[MAX:.*]] = arith.select %[[UGT]], %[[A]], %[[B]] : i32
+// CHECK-NEXT: return %[[MAX]] : i32
+
+// -----
+
+func.func @minui(%lhs: i32, %rhs: i32) -> i32 {
+  %min = arith.minui %lhs, %rhs : i32
+  return %min : i32
+}
+// CHECK-LABEL: func @minui
+// CHECK-SAME: %[[A:.*]]: i32, %[[B:.*]]: i32
+// CHECK-NEXT: %[[ULT:.*]] = arith.cmpi ult, %[[A]], %[[B]] : i32
+// CHECK-NEXT: %[[MIN:.*]] = arith.select %[[ULT]], %[[A]], %[[B]] : i32
+// CHECK-NEXT: return %[[MIN]] : i32


        


More information about the Mlir-commits mailing list