[Mlir-commits] [mlir] 1b2a1f8 - [MLIR][Arith] Canonicalize cmpf(int to fp) to cmpi

William S. Moses llvmlistbot at llvm.org
Wed Feb 23 11:09:28 PST 2022


Author: William S. Moses
Date: 2022-02-23T14:09:20-05:00
New Revision: 1b2a1f847354bf027a2ad1591a0b694b721d0177

URL: https://github.com/llvm/llvm-project/commit/1b2a1f847354bf027a2ad1591a0b694b721d0177
DIFF: https://github.com/llvm/llvm-project/commit/1b2a1f847354bf027a2ad1591a0b694b721d0177.diff

LOG: [MLIR][Arith] Canonicalize cmpf(int to fp) to cmpi

Given a cmpf of either uitofp or sitofp and a constant, attempt to canonicalize it to a cmpi.

This patch ports the equivalent InstCombine logic from LLVM to the MLIR arith dialect.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D117257

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index f31fe2d9f0447..ea60a83998ad7 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -1153,6 +1153,7 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index b03d9ea9f575d..60c61cdd56a76 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -54,6 +54,9 @@ class FloatType : public Type {
   /// Return the bitwidth of this float type.
   unsigned getWidth();
 
+  /// Return the width of the mantissa (the APFloat precision) of this type.
+  unsigned getFPMantissaWidth();
+
   /// Get or create a new FloatType with bitwidth scaled by `scale`.
   /// Return null if the scaled element type cannot be represented.
   FloatType scaleElementBitwidth(unsigned scale);

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 8074f2c4751fb..eea91dcde9724 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1393,6 +1393,299 @@ OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
   return BoolAttr::get(getContext(), val);
 }
 
+/// Rewrites `cmpf(sitofp(x) or uitofp(x), cst)` into `cmpi(x, int(cst))` or
+/// into a constant i1 whenever the int-to-fp conversion provably cannot affect
+/// the comparison. Ported from the equivalent fold in LLVM's InstCombine.
+class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
+public:
+  using OpRewritePattern<CmpFOp>::OpRewritePattern;
+
+  /// Maps a cmpf predicate to cmpi, assuming neither operand can be a NaN.
+  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
+                                                 bool isUnsigned) {
+    using namespace arith;
+    switch (pred) {
+    case CmpFPredicate::UEQ:
+    case CmpFPredicate::OEQ:
+      return CmpIPredicate::eq;
+    case CmpFPredicate::UGT:
+    case CmpFPredicate::OGT:
+      return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
+    case CmpFPredicate::UGE:
+    case CmpFPredicate::OGE:
+      return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
+    case CmpFPredicate::ULT:
+    case CmpFPredicate::OLT:
+      return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
+    case CmpFPredicate::ULE:
+    case CmpFPredicate::OLE:
+      return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
+    case CmpFPredicate::UNE:
+    case CmpFPredicate::ONE:
+      return CmpIPredicate::ne;
+    default:
+      llvm_unreachable("Unexpected predicate!");
+    }
+  }
+
+  LogicalResult matchAndRewrite(CmpFOp op,
+                                PatternRewriter &rewriter) const override {
+    FloatAttr flt;
+    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
+      return failure();
+
+    const APFloat &rhs = flt.getValue();
+
+    // Don't attempt to fold a nan.
+    if (rhs.isNaN())
+      return failure();
+
+    // Get the width of the mantissa.  We don't want to hack on conversions that
+    // might lose information from the integer, e.g. "i64 -> float"
+    FloatType floatTy = op.getRhs().getType().cast<FloatType>();
+    int mantissaWidth = floatTy.getFPMantissaWidth();
+    if (mantissaWidth <= 0)
+      return failure();
+
+    bool isUnsigned;
+    Value intVal;
+
+    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
+      isUnsigned = false;
+      intVal = si.getIn();
+    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
+      isUnsigned = true;
+      intVal = ui.getIn();
+    } else {
+      return failure();
+    }
+
+    // Check to see that the input is converted from an integer type that is
+    // small enough that preserves all bits.
+    auto intTy = intVal.getType().cast<IntegerType>();
+    auto intWidth = intTy.getWidth();
+
+    // Number of bits representing values, as opposed to the sign
+    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
+
+    // Following test does NOT adjust intWidth downwards for signed inputs,
+    // because the most negative value still requires all the mantissa bits
+    // to distinguish it from one less than that value.
+    if ((int)intWidth > mantissaWidth) {
+      // Conversion would lose accuracy. Check if loss can impact comparison.
+      int exponent = ilogb(rhs);
+      if (exponent == APFloat::IEK_Inf) {
+        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
+        if (maxExponent < (int)valueBits) {
+          // Conversion could create infinity.
+          return failure();
+        }
+      } else {
+        // Note that if rhs is zero or NaN, then exponent is negative
+        // and first condition is trivially false.
+        if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
+          // Conversion could affect comparison.
+          return failure();
+        }
+      }
+    }
+
+    // Map to a cmpi predicate; true/false/ord/uno fold to constant results.
+    CmpIPredicate pred;
+    switch (op.getPredicate()) {
+    case CmpFPredicate::AlwaysTrue:
+    case CmpFPredicate::ORD:
+      // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
+      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                 /*width=*/1);
+      return success();
+    case CmpFPredicate::AlwaysFalse:
+    case CmpFPredicate::UNO:
+      // Int to fp conversion doesn't create a nan (uno checks either is a nan)
+      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                 /*width=*/1);
+      return success();
+    default:
+      pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
+      break;
+    }
+
+    if (!isUnsigned) {
+      // If the rhs value is > SignedMax, fold the comparison.  This handles
+      // +INF and large values.
+      APFloat signedMax(rhs.getSemantics());
+      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
+                                 APFloat::rmNearestTiesToEven);
+      if (signedMax < rhs) { // smax < 13123.0
+        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
+            pred == CmpIPredicate::sle)
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+        else
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+        return success();
+      }
+    } else {
+      // If the rhs value is > UnsignedMax, fold the comparison. This handles
+      // +INF and large values.
+      APFloat unsignedMax(rhs.getSemantics());
+      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
+                                   APFloat::rmNearestTiesToEven);
+      if (unsignedMax < rhs) { // umax < 13123.0
+        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
+            pred == CmpIPredicate::ule)
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+        else
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+        return success();
+      }
+    }
+
+    if (!isUnsigned) {
+      // See if the rhs value is < SignedMin.
+      APFloat signedMin(rhs.getSemantics());
+      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
+                                 APFloat::rmNearestTiesToEven);
+      if (signedMin > rhs) { // smin > 12312.0
+        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
+            pred == CmpIPredicate::sge)
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+        else
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+        return success();
+      }
+    } else {
+      // See if the rhs value is < UnsignedMin.
+      APFloat unsignedMin(rhs.getSemantics());
+      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
+                                   APFloat::rmNearestTiesToEven);
+      if (unsignedMin > rhs) { // umin > 12312.0
+        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
+            pred == CmpIPredicate::uge)
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+        else
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+        return success();
+      }
+    }
+
+    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
+    // [0, UMAX], but it may still be fractional.  Truncate it to an integer and
+    // use the conversion's exactness flag to tell. Skip this for zero, because
+    // -0.0 is not fractional.
+    bool isExact;
+    APSInt rhsInt(intWidth, isUnsigned);
+    if (APFloat::opInvalidOp ==
+        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &isExact)) {
+      // The constant is not representable in the integer type; bail out.
+      return failure();
+    }
+
+    if (!rhs.isZero()) {
+      if (!isExact) {
+        // If we had a comparison against a fractional value, we have to adjust
+        // the compare predicate and sometimes the value.  rhsInt is rounded
+        // towards zero at this point.
+        switch (pred) {
+        default:
+          llvm_unreachable("Unexpected integer comparison!");
+        case CmpIPredicate::ne: // (float)int != 4.4   --> true
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                     /*width=*/1);
+          return success();
+        case CmpIPredicate::eq: // (float)int == 4.4   --> false
+          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                     /*width=*/1);
+          return success();
+        case CmpIPredicate::ule:
+          // (float)int <= 4.4   --> int <= 4
+          // (float)int <= -4.4  --> false
+          if (rhs.isNegative()) {
+            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                       /*width=*/1);
+            return success();
+          }
+          break;
+        case CmpIPredicate::sle:
+          // (float)int <= 4.4   --> int <= 4
+          // (float)int <= -4.4  --> int < -4
+          if (rhs.isNegative())
+            pred = CmpIPredicate::slt;
+          break;
+        case CmpIPredicate::ult:
+          // (float)int < -4.4   --> false
+          // (float)int < 4.4    --> int <= 4
+          if (rhs.isNegative()) {
+            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+                                                       /*width=*/1);
+            return success();
+          }
+          pred = CmpIPredicate::ule;
+          break;
+        case CmpIPredicate::slt:
+          // (float)int < -4.4   --> int < -4
+          // (float)int < 4.4    --> int <= 4
+          if (!rhs.isNegative())
+            pred = CmpIPredicate::sle;
+          break;
+        case CmpIPredicate::ugt:
+          // (float)int > 4.4    --> int > 4
+          // (float)int > -4.4   --> true
+          if (rhs.isNegative()) {
+            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                       /*width=*/1);
+            return success();
+          }
+          break;
+        case CmpIPredicate::sgt:
+          // (float)int > 4.4    --> int > 4
+          // (float)int > -4.4   --> int >= -4
+          if (rhs.isNegative())
+            pred = CmpIPredicate::sge;
+          break;
+        case CmpIPredicate::uge:
+          // (float)int >= -4.4   --> true
+          // (float)int >= 4.4    --> int > 4
+          if (rhs.isNegative()) {
+            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+                                                       /*width=*/1);
+            return success();
+          }
+          pred = CmpIPredicate::ugt;
+          break;
+        case CmpIPredicate::sge:
+          // (float)int >= -4.4   --> int >= -4
+          // (float)int >= 4.4    --> int > 4
+          if (!rhs.isNegative())
+            pred = CmpIPredicate::sgt;
+          break;
+        }
+      }
+    }
+
+    // Lower this FP comparison into an appropriate integer version of the
+    // comparison.
+    rewriter.replaceOpWithNewOp<CmpIOp>(
+        op, pred, intVal,
+        rewriter.create<ConstantOp>(
+            op.getLoc(), intVal.getType(),
+            rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
+    return success();
+  }
+};
+
+void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<CmpFIntToFPConst>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6d3ed12cedf22..d57005237187e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -137,6 +137,10 @@ FloatType FloatType::scaleElementBitwidth(unsigned scale) {
   return FloatType();
 }
 
+unsigned FloatType::getFPMantissaWidth() {
+  return llvm::APFloat::semanticsPrecision(this->getFloatSemantics());
+}
+
 //===----------------------------------------------------------------------===//
 // FunctionType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index bb40f62a5e01f..4bdefac2ac5f5 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -788,3 +788,86 @@ func @constant_UItoFP() -> f32 {
   %res = arith.sitofp %c0 : i32 to f32
   return %res : f32
 }
+
+// -----
+
+// Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll
+// When canonicalizing an arith.cmpf whose LHS comes from an arith.uitofp, we
+// can lower it to an unsigned arith.cmpi.
+
+// CHECK-LABEL: @test1(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test1(%arg0: i32) -> i1 {
+  %zero = arith.constant 0.0 : f64
+  %fp = arith.uitofp %arg0 : i32 to f64
+  %cmp = arith.cmpf ole, %fp, %zero : f64
+  // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+  // CHECK: arith.cmpi ule, %[[arg0]], %[[c0]] : i32
+  return %cmp : i1
+}
+
+// CHECK-LABEL: @test2(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test2(%arg0: i32) -> i1 {
+  %zero = arith.constant 0.0 : f64
+  %fp = arith.uitofp %arg0 : i32 to f64
+  %cmp = arith.cmpf olt, %fp, %zero : f64
+  // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+  // CHECK: arith.cmpi ult, %[[arg0]], %[[c0]] : i32
+  return %cmp : i1
+}
+
+// CHECK-LABEL: @test3(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test3(%arg0: i32) -> i1 {
+  %zero = arith.constant 0.0 : f64
+  %fp = arith.uitofp %arg0 : i32 to f64
+  %cmp = arith.cmpf oge, %fp, %zero : f64
+  // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+  // CHECK: arith.cmpi uge, %[[arg0]], %[[c0]] : i32
+  return %cmp : i1
+}
+
+// CHECK-LABEL: @test4(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test4(%arg0: i32) -> i1 {
+  %zero = arith.constant 0.0 : f64
+  %fp = arith.uitofp %arg0 : i32 to f64
+  %cmp = arith.cmpf ogt, %fp, %zero : f64
+  // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+  // CHECK: arith.cmpi ugt, %[[arg0]], %[[c0]] : i32
+  return %cmp : i1
+}
+
+// CHECK-LABEL: @test5(
+func @test5(%arg0: i32) -> i1 {
+  %neg = arith.constant -4.4 : f64
+  %fp = arith.uitofp %arg0 : i32 to f64
+  %cmp = arith.cmpf ogt, %fp, %neg : f64
+  // CHECK: %[[true:.+]] = arith.constant true
+  // CHECK: return %[[true]] : i1
+  return %cmp : i1
+}
+
+// CHECK-LABEL: @test6(
+func @test6(%arg0: i32) -> i1 {
+  %neg = arith.constant -4.4 : f64
+  %fp = arith.uitofp %arg0 : i32 to f64
+  %cmp = arith.cmpf olt, %fp, %neg : f64
+  // CHECK: %[[false:.+]] = arith.constant false
+  // CHECK: return %[[false]] : i1
+  return %cmp : i1
+}
+
+// Check that optimizing unsigned >= comparisons correctly distinguishes
+// positive and negative constants.
+// CHECK-LABEL: @test7(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test7(%arg0: i32) -> i1 {
+  %frac = arith.constant 3.2 : f64
+  %fp = arith.uitofp %arg0 : i32 to f64
+  %cmp = arith.cmpf oge, %fp, %frac : f64
+  // CHECK: %[[c3:.+]] = arith.constant 3 : i32
+  // CHECK: arith.cmpi ugt, %[[arg0]], %[[c3]] : i32
+  return %cmp : i1
+}


        


More information about the Mlir-commits mailing list