[Mlir-commits] [mlir] [mlir][arith] Add support for expanding arith.maxnumf/minnumf ops. (PR #75989)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 19 16:20:08 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Han-Chung Wang (hanhanW)
<details>
<summary>Changes</summary>
The revision also updates the function names in the lit tests to match the op names.
Take arith.maxnumf as an example:
```
func.func @maxnumf(%lhs: f32, %rhs: f32) -> f32 {
%result = arith.maxnumf %lhs, %rhs : f32
return %result : f32
}
```
will be expanded to
```
func.func @maxnumf(%lhs: f32, %rhs: f32) -> f32 {
%0 = arith.cmpf ugt, %lhs, %rhs : f32
%1 = arith.select %0, %lhs, %rhs : f32
%2 = arith.cmpf uno, %lhs, %lhs : f32
%3 = arith.select %2, %rhs, %1 : f32
return %3 : f32
}
```
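Case 1: Neither LHS nor RHS is NaN, and LHS > RHS
In this case, `%0` is true, so `%1` is LHS. `%2` is false, so `%3` has the same value as `%1`, which is LHS. (If instead LHS <= RHS, `%0` is false, so `%1` and therefore `%3` are RHS.)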
Case 2: LHS is NaN and RHS is not NaN
In this case, `%2` is true, so `%3` is always RHS.
Case 3: LHS is not NaN and RHS is NaN
In this case, `%0` is true (`ugt` is an unordered predicate, so it is true whenever either operand is NaN) and `%1` is LHS. `%2` is false, so `%3` and `%1` have the same value, which is LHS.
Case 4: Both LHS and RHS are NaN
In this case, `%1` and RHS are both NaN, so whichever of them `%3` selects, the result is still NaN.
---
Full diff: https://github.com/llvm/llvm-project/pull/75989.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+32-2)
- (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+36-6)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 42a63316b31c6b..8deb8f028ba458 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -186,6 +186,32 @@ struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
}
};
+template <typename OpTy, arith::CmpFPredicate pred>
+struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const final {
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+
+ Location loc = op.getLoc();
+ // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
+ static_assert(pred == arith::CmpFPredicate::UGT ||
+ pred == arith::CmpFPredicate::ULT,
+ "pred must be either UGT or ULT");
+ Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
+ Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
+
+ // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
+ Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
+ lhs, lhs);
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
+ return success();
+ }
+};
+
struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -319,7 +345,9 @@ struct ArithExpandOpsPass
arith::CeilDivUIOp,
arith::FloorDivSIOp,
arith::MaximumFOp,
- arith::MinimumFOp
+ arith::MinimumFOp,
+ arith::MaxNumFOp,
+ arith::MinNumFOp
>();
if (includeBf16) {
@@ -365,7 +393,9 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
- MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>
+ MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
+ MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
+ MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 2c41f098c6c15c..046e8ff64fba6d 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -176,8 +176,8 @@ func.func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
// -----
-// CHECK-LABEL: func @maxf
-func.func @maxf(%a: f32, %b: f32) -> f32 {
+// CHECK-LABEL: func @maximumf
+func.func @maximumf(%a: f32, %b: f32) -> f32 {
%result = arith.maximumf %a, %b : f32
return %result : f32
}
@@ -190,8 +190,8 @@ func.func @maxf(%a: f32, %b: f32) -> f32 {
// -----
-// CHECK-LABEL: func @maxf_vector
-func.func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
+// CHECK-LABEL: func @maximumf_vector
+func.func @maximumf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
%result = arith.maximumf %a, %b : vector<4xf16>
return %result : vector<4xf16>
}
@@ -204,8 +204,23 @@ func.func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
// -----
-// CHECK-LABEL: func @minf
-func.func @minf(%a: f32, %b: f32) -> f32 {
+// CHECK-LABEL: func @maxnumf
+func.func @maxnumf(%a: f32, %b: f32) -> f32 {
+ %result = arith.maxnumf %a, %b : f32
+ return %result : f32
+}
+
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[LHS]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: func @minimumf
+func.func @minimumf(%a: f32, %b: f32) -> f32 {
%result = arith.minimumf %a, %b : f32
return %result : f32
}
@@ -219,6 +234,21 @@ func.func @minf(%a: f32, %b: f32) -> f32 {
// -----
+// CHECK-LABEL: func @minnumf
+func.func @minnumf(%a: f32, %b: f32) -> f32 {
+ %result = arith.minnumf %a, %b : f32
+ return %result : f32
+}
+
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[LHS]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
func.func @truncf_f32(%arg0 : f32) -> bf16 {
%0 = arith.truncf %arg0 : f32 to bf16
return %0 : bf16
``````````
</details>
https://github.com/llvm/llvm-project/pull/75989
More information about the Mlir-commits
mailing list