[Mlir-commits] [mlir] 962f71b - [mlir][index] Fold `cmp(max/min(x, cstA), cstB)`
Jeff Niu
llvmlistbot at llvm.org
Mon Jun 26 11:49:23 PDT 2023
Author: Jeff Niu
Date: 2023-06-26T11:49:09-07:00
New Revision: 962f71be5a2b65d75e1f42c36bb134448f6e3a0d
URL: https://github.com/llvm/llvm-project/commit/962f71be5a2b65d75e1f42c36bb134448f6e3a0d
DIFF: https://github.com/llvm/llvm-project/commit/962f71be5a2b65d75e1f42c36bb134448f6e3a0d.diff
LOG: [mlir][index] Fold `cmp(max/min(x, cstA), cstB)`
This is a case that is not picked up by integer range inference and
suggests a weakness with integer range inference on the index dialect.
The problem is that when `[1, SMAX_64]` is truncated to 32 bits, the
resulting range could be `[SMIN_32, SMAX_32]`, making the subsequent
comparison worthless. This is because integer range inference doesn't
know that the result of the max/min inference also changes based on the
bitwidth, and doing the truncation locally at the input of the
comparison op loses that information.
This pattern also showed up frequently in our code, so adding it as a
folder allows dead code to be pruned more often.
Depends on D153731
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D153732
Added:
Modified:
mlir/lib/Dialect/Index/IR/IndexOps.cpp
mlir/test/Dialect/Index/index-canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 3218933c84afc..fb6f891300225 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -10,9 +10,11 @@
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/ADT/SmallString.h"
-#include <optional>
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::index;
@@ -313,9 +315,10 @@ OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
- return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
- return lhs.slt(rhs) ? lhs : rhs;
- });
+ return foldBinaryOpChecked(adaptor.getOperands(),
+ [](const APInt &lhs, const APInt &rhs) {
+ return lhs.slt(rhs) ? lhs : rhs; // signed compare: pick the smaller operand
+ });
}
//===----------------------------------------------------------------------===//
@@ -323,9 +326,10 @@ OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
- return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
- return lhs.ult(rhs) ? lhs : rhs;
- });
+ return foldBinaryOpChecked(adaptor.getOperands(),
+ [](const APInt &lhs, const APInt &rhs) {
+ return lhs.ult(rhs) ? lhs : rhs; // unsigned compare: pick the smaller operand
+ });
}
//===----------------------------------------------------------------------===//
@@ -455,19 +459,64 @@ bool compareIndices(const APInt &lhs, const APInt &rhs,
llvm_unreachable("unhandled IndexCmpPredicate predicate");
}
+/// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
+/// values of `cstA` and `cstB`, the max or min operation, and the comparison
+/// predicate. This evaluates the fold at one bit `width`; the caller runs it
+/// for both 64 and 32 bits and folds only if the two results agree.
+static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
+ const APInt &cstA,
+ const APInt &cstB, unsigned width,
+ IndexCmpPredicate pred) {
+ ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp)
+ .Case([&](MinSOp op) { // signed min: result lies in [SMIN, cstA]
+ return ConstantIntRanges::fromSigned(
+ APInt::getSignedMinValue(width), cstA);
+ })
+ .Case([&](MinUOp op) { // unsigned min: result lies in [UMIN, cstA]
+ return ConstantIntRanges::fromUnsigned(
+ APInt::getMinValue(width), cstA);
+ })
+ .Case([&](MaxSOp op) { // signed max: result lies in [cstA, SMAX]
+ return ConstantIntRanges::fromSigned(
+ cstA, APInt::getSignedMaxValue(width));
+ })
+ .Case([&](MaxUOp op) { // unsigned max: result lies in [cstA, UMAX]
+ return ConstantIntRanges::fromUnsigned(
+ cstA, APInt::getMaxValue(width));
+ }); // no Default case: lhsOp must be one of the four ops above
+ return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred), // NOTE(review): cast assumes both predicate enums share values - confirm
+ lhsRange, ConstantIntRanges::constant(cstB)); // presumably empty when the range does not decide the predicate
+}
+
OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
+ // Attempt to fold if both inputs are constant.
auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
- if (!lhs || !rhs)
- return {};
+ if (lhs && rhs) {
+ // Perform the comparison in 64-bit and 32-bit.
+ bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
+ bool result32 = compareIndices(lhs.getValue().trunc(32),
+ rhs.getValue().trunc(32), getPred());
+ if (result64 == result32) // on disagreement, fall through to the max/min fold below
+ return BoolAttr::get(getContext(), result64);
+ }
- // Perform the comparison in 64-bit and 32-bit.
- bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
- bool result32 = compareIndices(lhs.getValue().trunc(32),
- rhs.getValue().trunc(32), getPred());
- if (result64 != result32)
- return {};
- return BoolAttr::get(getContext(), result64);
+ // Fold `cmp(max/min(x, cstA), cstB)`; only operand 1 of the max/min is matched as `cstA`.
+ Operation *lhsOp = getLhs().getDefiningOp(); // may be null, hence isa_and_nonnull below
+ IntegerAttr cstA;
+ if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
+ matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) { // `rhs` supplies cstB and must be constant
+ std::optional<bool> result64 = foldCmpOfMaxOrMin(
+ lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
+ std::optional<bool> result32 =
+ foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32), // both constants truncated for the 32-bit evaluation
+ rhs.getValue().trunc(32), 32, getPred());
+ // Fold if the 32-bit and 64-bit results are the same.
+ if (result64 && result32 && *result64 == *result32)
+ return BoolAttr::get(getContext(), *result64);
+ }
+
+ return {};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index a9b060bbd6a09..56a3cb4c6031b 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -510,3 +510,14 @@ func.func @cmp_edge() -> i1 {
// CHECK: return %[[TRUE]]
return %0 : i1
}
+
+// CHECK-LABEL: @cmp_maxs
+func.func @cmp_maxs(%arg0: index) -> (i1, i1) {
+ %idx0 = index.constant 0
+ %idx1 = index.constant 1
+ %0 = index.maxs %arg0, %idx1 // signed max with 1, so %0 >= 1
+ %1 = index.cmp sgt(%0, %idx0) // expected to fold to true
+ %2 = index.cmp eq(%0, %idx0) // expected to fold to false
+ // CHECK: return %true, %false
+ return %1, %2 : i1, i1
+}
More information about the Mlir-commits
mailing list