[Mlir-commits] [mlir] be1aeb8 - Remove NaN constant from arith.minf, arith.maxf expansion

Christian Sigg llvmlistbot at llvm.org
Wed Jan 12 11:56:48 PST 2022


Author: Christian Sigg
Date: 2022-01-12T20:56:40+01:00
New Revision: be1aeb818cd9d4f329428a035604bebdd0c2f6e1

URL: https://github.com/llvm/llvm-project/commit/be1aeb818cd9d4f329428a035604bebdd0c2f6e1
DIFF: https://github.com/llvm/llvm-project/commit/be1aeb818cd9d4f329428a035604bebdd0c2f6e1.diff

LOG: Remove NaN constant from arith.minf, arith.maxf expansion

If either operand is NaN, return that operand instead of creating a new NaN constant.

When the rhs operand is a constant, the second pair of ops (arith.cmpf + select) will be folded away.

https://reviews.llvm.org/D117010 marks the two ops commutative, which will place the constant on the rhs.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D117011

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Arithmetic/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index d06c3043664dd..d836ae5c84f5a 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -156,19 +156,16 @@ struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
     Value rhs = op.getRhs();
 
     Location loc = op.getLoc();
+    // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
+    static_assert(pred == arith::CmpFPredicate::UGT ||
+                  pred == arith::CmpFPredicate::ULT);
     Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
     Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
 
-    auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+    // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
     Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
-                                                 lhs, rhs);
-
-    Value nan = rewriter.create<arith::ConstantFloatOp>(
-        loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
-    if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
-      nan = rewriter.create<SplatOp>(loc, vectorType, nan);
-
-    rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
+                                                 rhs, rhs);
+    rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, rhs, select);
     return success();
   }
 };
@@ -226,8 +223,8 @@ void mlir::arith::populateArithmeticExpandOpsPatterns(
     CeilDivSIOpConverter,
     CeilDivUIOpConverter,
     FloorDivSIOpConverter,
-    MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::OGT>,
-    MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::OLT>,
+    MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
+    MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>,
     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
     MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
     MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,

diff  --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
index 2f14178e88f2c..f4a557a02b205 100644
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -154,11 +154,10 @@ func @maxf(%a: f32, %b: f32) -> f32 {
   return %result : f32
 }
 // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32
 // CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
-// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
@@ -169,12 +168,10 @@ func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
   return %result : vector<4xf16>
 }
 // CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16>
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : vector<4xf16>
 // CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]]
-// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16>
-// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7E00 : f16
-// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16>
-// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]]
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : vector<4xf16>
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]]
 // CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
 
 // -----
@@ -185,11 +182,10 @@ func @minf(%a: f32, %b: f32) -> f32 {
   return %result : f32
 }
 // CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32
 // CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
-// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 


        


More information about the Mlir-commits mailing list