[Mlir-commits] [mlir] 1b2a1f8 - [MLIR][Arith] Canonicalize cmpf(int to fp) to cmpi
William S. Moses
llvmlistbot at llvm.org
Wed Feb 23 11:09:28 PST 2022
Author: William S. Moses
Date: 2022-02-23T14:09:20-05:00
New Revision: 1b2a1f847354bf027a2ad1591a0b694b721d0177
URL: https://github.com/llvm/llvm-project/commit/1b2a1f847354bf027a2ad1591a0b694b721d0177
DIFF: https://github.com/llvm/llvm-project/commit/1b2a1f847354bf027a2ad1591a0b694b721d0177.diff
LOG: [MLIR][Arith] Canonicalize cmpf(int to fp) to cmpi
Given a cmpf comparing the result of a uitofp or sitofp against a constant, attempt to canonicalize it to a cmpi.
This PR adapts the equivalent LLVM InstCombine logic so that it now applies to MLIR arith.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D117257
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index f31fe2d9f0447..ea60a83998ad7 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -1153,6 +1153,7 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index b03d9ea9f575d..60c61cdd56a76 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -54,6 +54,9 @@ class FloatType : public Type {
/// Return the bitwidth of this float type.
unsigned getWidth();
+ /// Return the width of the mantissa of this type.
+ unsigned getFPMantissaWidth();
+
/// Get or create a new FloatType with bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
FloatType scaleElementBitwidth(unsigned scale);
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 8074f2c4751fb..eea91dcde9724 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1393,6 +1393,299 @@ OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
return BoolAttr::get(getContext(), val);
}
+class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
+public:
+ using OpRewritePattern<CmpFOp>::OpRewritePattern;
+
+ static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
+ bool isUnsigned) {
+ using namespace arith;
+ switch (pred) {
+ case CmpFPredicate::UEQ:
+ case CmpFPredicate::OEQ:
+ return CmpIPredicate::eq;
+ case CmpFPredicate::UGT:
+ case CmpFPredicate::OGT:
+ return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
+ case CmpFPredicate::UGE:
+ case CmpFPredicate::OGE:
+ return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
+ case CmpFPredicate::ULT:
+ case CmpFPredicate::OLT:
+ return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
+ case CmpFPredicate::ULE:
+ case CmpFPredicate::OLE:
+ return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
+ case CmpFPredicate::UNE:
+ case CmpFPredicate::ONE:
+ return CmpIPredicate::ne;
+ default:
+ llvm_unreachable("Unexpected predicate!");
+ }
+ }
+
+ LogicalResult matchAndRewrite(CmpFOp op,
+ PatternRewriter &rewriter) const override {
+ FloatAttr flt;
+ if (!matchPattern(op.getRhs(), m_Constant(&flt)))
+ return failure();
+
+ const APFloat &rhs = flt.getValue();
+
+ // Don't attempt to fold a nan.
+ if (rhs.isNaN())
+ return failure();
+
+ // Get the width of the mantissa. We don't want to hack on conversions that
+ // might lose information from the integer, e.g. "i64 -> float"
+ FloatType floatTy = op.getRhs().getType().cast<FloatType>();
+ int mantissaWidth = floatTy.getFPMantissaWidth();
+ if (mantissaWidth <= 0)
+ return failure();
+
+ bool isUnsigned;
+ Value intVal;
+
+ if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
+ isUnsigned = false;
+ intVal = si.getIn();
+ } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
+ isUnsigned = true;
+ intVal = ui.getIn();
+ } else {
+ return failure();
+ }
+
+ // Check to see that the input is converted from an integer type that is
+ // small enough that preserves all bits.
+ auto intTy = intVal.getType().cast<IntegerType>();
+ auto intWidth = intTy.getWidth();
+
+ // Number of bits representing values, as opposed to the sign
+ auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
+
+ // Following test does NOT adjust intWidth downwards for signed inputs,
+ // because the most negative value still requires all the mantissa bits
+ // to distinguish it from one less than that value.
+ if ((int)intWidth > mantissaWidth) {
+ // Conversion would lose accuracy. Check if loss can impact comparison.
+ int exponent = ilogb(rhs);
+ if (exponent == APFloat::IEK_Inf) {
+ int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
+ if (maxExponent < (int)valueBits) {
+ // Conversion could create infinity.
+ return failure();
+ }
+ } else {
+ // Note that if rhs is zero or NaN, then exponent is negative
+ // and the first condition is trivially false.
+ if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
+ // Conversion could affect comparison.
+ return failure();
+ }
+ }
+ }
+
+ // Convert to equivalent cmpi predicate
+ CmpIPredicate pred;
+ switch (op.getPredicate()) {
+ case CmpFPredicate::ORD:
+ // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ return success();
+ case CmpFPredicate::UNO:
+ // Int to fp conversion doesn't create a nan (uno checks either is a nan)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ default:
+ pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
+ break;
+ }
+
+ if (!isUnsigned) {
+ // If the rhs value is > SignedMax, fold the comparison. This handles
+ // +INF and large values.
+ APFloat signedMax(rhs.getSemantics());
+ signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
+ APFloat::rmNearestTiesToEven);
+ if (signedMax < rhs) { // smax < 13123.0
+ if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
+ pred == CmpIPredicate::sle)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ else
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ } else {
+ // If the rhs value is > UnsignedMax, fold the comparison. This handles
+ // +INF and large values.
+ APFloat unsignedMax(rhs.getSemantics());
+ unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
+ APFloat::rmNearestTiesToEven);
+ if (unsignedMax < rhs) { // umax < 13123.0
+ if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
+ pred == CmpIPredicate::ule)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ else
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ }
+
+ if (!isUnsigned) {
+ // See if the rhs value is < SignedMin.
+ APFloat signedMin(rhs.getSemantics());
+ signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
+ APFloat::rmNearestTiesToEven);
+ if (signedMin > rhs) { // smin > 12312.0
+ if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
+ pred == CmpIPredicate::sge)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ else
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ } else {
+ // See if the rhs value is < UnsignedMin.
+ APFloat unsignedMin(rhs.getSemantics());
+ unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
+ APFloat::rmNearestTiesToEven);
+ if (unsignedMin > rhs) { // umin > 12312.0
+ if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
+ pred == CmpIPredicate::uge)
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ else
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ }
+
+ // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
+ // [0, UMAX], but it may still be fractional. See if it is fractional by
+ // casting the FP value to the integer value and back, checking for
+ // equality. Don't do this for zero, because -0.0 is not fractional.
+ bool ignored;
+ APSInt rhsInt(intWidth, isUnsigned);
+ if (APFloat::opInvalidOp ==
+ rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
+ // Undefined behavior invoked - the destination type can't represent
+ // the input constant.
+ return failure();
+ }
+
+ if (!rhs.isZero()) {
+ APFloat apf(floatTy.getFloatSemantics(),
+ APInt::getZero(floatTy.getWidth()));
+ apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
+
+ bool equal = apf == rhs;
+ if (!equal) {
+ // If we had a comparison against a fractional value, we have to adjust
+ // the compare predicate and sometimes the value. rhsInt is rounded
+ // towards zero at this point.
+ switch (pred) {
+ default:
+ llvm_unreachable("Unexpected integer comparison!");
+ case CmpIPredicate::ne: // (float)int != 4.4 --> true
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ return success();
+ case CmpIPredicate::eq: // (float)int == 4.4 --> false
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ case CmpIPredicate::ule:
+ // (float)int <= 4.4 --> int <= 4
+ // (float)int <= -4.4 --> false
+ if (rhs.isNegative()) {
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ break;
+ case CmpIPredicate::sle:
+ // (float)int <= 4.4 --> int <= 4
+ // (float)int <= -4.4 --> int < -4
+ if (rhs.isNegative())
+ pred = CmpIPredicate::slt;
+ break;
+ case CmpIPredicate::ult:
+ // (float)int < -4.4 --> false
+ // (float)int < 4.4 --> int <= 4
+ if (rhs.isNegative()) {
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
+ /*width=*/1);
+ return success();
+ }
+ pred = CmpIPredicate::ule;
+ break;
+ case CmpIPredicate::slt:
+ // (float)int < -4.4 --> int < -4
+ // (float)int < 4.4 --> int <= 4
+ if (!rhs.isNegative())
+ pred = CmpIPredicate::sle;
+ break;
+ case CmpIPredicate::ugt:
+ // (float)int > 4.4 --> int > 4
+ // (float)int > -4.4 --> true
+ if (rhs.isNegative()) {
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ return success();
+ }
+ break;
+ case CmpIPredicate::sgt:
+ // (float)int > 4.4 --> int > 4
+ // (float)int > -4.4 --> int >= -4
+ if (rhs.isNegative())
+ pred = CmpIPredicate::sge;
+ break;
+ case CmpIPredicate::uge:
+ // (float)int >= -4.4 --> true
+ // (float)int >= 4.4 --> int > 4
+ if (rhs.isNegative()) {
+ rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
+ /*width=*/1);
+ return success();
+ }
+ pred = CmpIPredicate::ugt;
+ break;
+ case CmpIPredicate::sge:
+ // (float)int >= -4.4 --> int >= -4
+ // (float)int >= 4.4 --> int > 4
+ if (!rhs.isNegative())
+ pred = CmpIPredicate::sgt;
+ break;
+ }
+ }
+ }
+
+ // Lower this FP comparison into an appropriate integer version of the
+ // comparison.
+ rewriter.replaceOpWithNewOp<CmpIOp>(
+ op, pred, intVal,
+ rewriter.create<ConstantOp>(
+ op.getLoc(), intVal.getType(),
+ rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
+ return success();
+ }
+};
+
+void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.insert<CmpFIntToFPConst>(context);
+}
+
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6d3ed12cedf22..d57005237187e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -137,6 +137,10 @@ FloatType FloatType::scaleElementBitwidth(unsigned scale) {
return FloatType();
}
+unsigned FloatType::getFPMantissaWidth() {
+ return APFloat::semanticsPrecision(getFloatSemantics());
+}
+
//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index bb40f62a5e01f..4bdefac2ac5f5 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -788,3 +788,86 @@ func @constant_UItoFP() -> f32 {
%res = arith.sitofp %c0 : i32 to f32
return %res : f32
}
+
+// -----
+
+// Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll
+// When inst-combining an FCMP whose LHS comes from an arith.uitofp instruction, we
+// can lower it to unsigned ICMP instructions.
+
+// CHECK-LABEL: @test1(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test1(%arg0: i32) -> i1 {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf ole, %1, %cst : f64
+ // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+ // CHECK: arith.cmpi ule, %[[arg0]], %[[c0]] : i32
+ return %2 : i1
+}
+
+// CHECK-LABEL: @test2(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test2(%arg0: i32) -> i1 {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf olt, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+ // CHECK: arith.cmpi ult, %[[arg0]], %[[c0]] : i32
+}
+
+// CHECK-LABEL: @test3(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test3(%arg0: i32) -> i1 {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf oge, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+ // CHECK: arith.cmpi uge, %[[arg0]], %[[c0]] : i32
+}
+
+// CHECK-LABEL: @test4(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test4(%arg0: i32) -> i1 {
+ %cst = arith.constant 0.000000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf ogt, %1, %cst : f64
+ // CHECK: %[[c0:.+]] = arith.constant 0 : i32
+ // CHECK: arith.cmpi ugt, %[[arg0]], %[[c0]] : i32
+ return %2 : i1
+}
+
+// CHECK-LABEL: @test5(
+func @test5(%arg0: i32) -> i1 {
+ %cst = arith.constant -4.400000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf ogt, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[true:.+]] = arith.constant true
+ // CHECK: return %[[true]] : i1
+}
+
+// CHECK-LABEL: @test6(
+func @test6(%arg0: i32) -> i1 {
+ %cst = arith.constant -4.400000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf olt, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[false:.+]] = arith.constant false
+ // CHECK: return %[[false]] : i1
+}
+
+// Check that optimizing unsigned >= comparisons correctly distinguishes
+// positive and negative constants.
+// CHECK-LABEL: @test7(
+// CHECK-SAME: %[[arg0:.+]]:
+func @test7(%arg0: i32) -> i1 {
+ %cst = arith.constant 3.200000e+00 : f64
+ %1 = arith.uitofp %arg0: i32 to f64
+ %2 = arith.cmpf oge, %1, %cst : f64
+ return %2 : i1
+ // CHECK: %[[c3:.+]] = arith.constant 3 : i32
+ // CHECK: arith.cmpi ugt, %[[arg0]], %[[c3]] : i32
+}
More information about the Mlir-commits
mailing list