[Mlir-commits] [mlir] [mlir][arith] Add support for expanding arith.maxnumf/minnumf ops. (PR #75989)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 19 16:20:08 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>

This revision also renames the functions in the lit tests to match the op names.

Take `arith.maxnumf` as an example:

```
func.func @maxnumf(%lhs: f32, %rhs: f32) -> f32 {
  %result = arith.maxnumf %lhs, %rhs : f32
  return %result : f32
}
```

will be expanded to

```
func.func @maxnumf(%lhs: f32, %rhs: f32) -> f32 {
  %0 = arith.cmpf ugt, %lhs, %rhs : f32
  %1 = arith.select %0, %lhs, %rhs : f32
  %2 = arith.cmpf uno, %lhs, %lhs : f32
  %3 = arith.select %2, %rhs, %1 : f32
  return %3 : f32
}
```

Case 1: Neither LHS nor RHS is NaN, and LHS > RHS

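In this case, `%0` is true, so `%1` is LHS. `%2` is false, so `%3` and `%1` have the same value, which is LHS. (If neither operand is NaN and LHS <= RHS, `%0` is false, so `%1` and `%3` are both RHS.)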

Case 2: LHS is NaN and RHS is not NaN

In this case, `%2` is true, so `%3` is always RHS.

Case 3: LHS is not NaN and RHS is NaN

In this case, `%0` is true and `%1` is LHS. `%2` is false, so `%3` and `%1` have the same value, which is LHS.

Case 4: Both LHS and RHS are NaN

In this case, `%1` and RHS are both NaN, so `%3` is NaN regardless of which operand the final select picks.

---
Full diff: https://github.com/llvm/llvm-project/pull/75989.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+32-2) 
- (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+36-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 42a63316b31c6b..8deb8f028ba458 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -186,6 +186,32 @@ struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
   }
 };
 
+template <typename OpTy, arith::CmpFPredicate pred>
+struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const final {
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+
+    Location loc = op.getLoc();
+    // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
+    static_assert(pred == arith::CmpFPredicate::UGT ||
+                      pred == arith::CmpFPredicate::ULT,
+                  "pred must be either UGT or ULT");
+    Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
+    Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+
+    // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
+    Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
+                                                 lhs, lhs);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
+    return success();
+  }
+};
+
 struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -319,7 +345,9 @@ struct ArithExpandOpsPass
       arith::CeilDivUIOp,
       arith::FloorDivSIOp,
       arith::MaximumFOp,
-      arith::MinimumFOp
+      arith::MinimumFOp,
+      arith::MaxNumFOp,
+      arith::MinNumFOp
     >();
 
     if (includeBf16) {
@@ -365,7 +393,9 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   // clang-format off
   patterns.add<
     MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
-    MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>
+    MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
+    MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
+    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
    >(patterns.getContext());
   // clang-format on
 }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 2c41f098c6c15c..046e8ff64fba6d 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -176,8 +176,8 @@ func.func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
 
 // -----
 
-// CHECK-LABEL: func @maxf
-func.func @maxf(%a: f32, %b: f32) -> f32 {
+// CHECK-LABEL: func @maximumf
+func.func @maximumf(%a: f32, %b: f32) -> f32 {
   %result = arith.maximumf %a, %b : f32
   return %result : f32
 }
@@ -190,8 +190,8 @@ func.func @maxf(%a: f32, %b: f32) -> f32 {
 
 // -----
 
-// CHECK-LABEL: func @maxf_vector
-func.func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
+// CHECK-LABEL: func @maximumf_vector
+func.func @maximumf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
   %result = arith.maximumf %a, %b : vector<4xf16>
   return %result : vector<4xf16>
 }
@@ -204,8 +204,23 @@ func.func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
 
 // -----
 
-// CHECK-LABEL: func @minf
-func.func @minf(%a: f32, %b: f32) -> f32 {
+// CHECK-LABEL: func @maxnumf
+func.func @maxnumf(%a: f32, %b: f32) -> f32 {
+  %result = arith.maxnumf %a, %b : f32
+  return %result : f32
+}
+
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[LHS]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: func @minimumf
+func.func @minimumf(%a: f32, %b: f32) -> f32 {
   %result = arith.minimumf %a, %b : f32
   return %result : f32
 }
@@ -219,6 +234,21 @@ func.func @minf(%a: f32, %b: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @minnumf
+func.func @minnumf(%a: f32, %b: f32) -> f32 {
+  %result = arith.minnumf %a, %b : f32
+  return %result : f32
+}
+
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[LHS]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
 func.func @truncf_f32(%arg0 : f32) -> bf16 {
     %0 = arith.truncf %arg0 : f32 to bf16
     return %0 : bf16

``````````

</details>


https://github.com/llvm/llvm-project/pull/75989


More information about the Mlir-commits mailing list