[Mlir-commits] [mlir] 962f71b - [mlir][index] Fold `cmp(max/min(x, cstA), cstB)`

Jeff Niu llvmlistbot at llvm.org
Mon Jun 26 11:49:23 PDT 2023


Author: Jeff Niu
Date: 2023-06-26T11:49:09-07:00
New Revision: 962f71be5a2b65d75e1f42c36bb134448f6e3a0d

URL: https://github.com/llvm/llvm-project/commit/962f71be5a2b65d75e1f42c36bb134448f6e3a0d
DIFF: https://github.com/llvm/llvm-project/commit/962f71be5a2b65d75e1f42c36bb134448f6e3a0d.diff

LOG: [mlir][index] Fold `cmp(max/min(x, cstA), cstB)`

This case is not caught by integer range inference, which points to a
weakness of integer range inference on the index dialect.
The problem is that when `[1, SMAX_64]` is truncated to 32 bits, the
resulting range becomes `[SMIN_32, SMAX_32]`, which makes the subsequent
comparison useless. This is because integer range inference doesn't
know that the result of max/min also depends on the bitwidth (e.g.
`maxs(x, 1)` is at least 1 at both 32 and 64 bits), and truncating the
range locally at the inputs of the comparison op loses that information.

This pattern also shows up frequently in our code, so handling it in a
folder allows more dead code to be pruned.

Depends on D153731

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D153732

Added: 
    

Modified: 
    mlir/lib/Dialect/Index/IR/IndexOps.cpp
    mlir/test/Dialect/Index/index-canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 3218933c84afc..fb6f891300225 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -10,9 +10,11 @@
 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "llvm/ADT/SmallString.h"
-#include <optional>
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::index;
@@ -313,9 +315,10 @@ OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
-  return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
-    return lhs.slt(rhs) ? lhs : rhs;
-  });
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.slt(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
@@ -323,9 +326,10 @@ OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
-  return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
-    return lhs.ult(rhs) ? lhs : rhs;
-  });
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.ult(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
@@ -455,19 +459,64 @@ bool compareIndices(const APInt &lhs, const APInt &rhs,
   llvm_unreachable("unhandled IndexCmpPredicate predicate");
 }
 
+/// Evaluate `cmp(max/min(x, cstA), cstB)` at a single bit `width`: the result
+/// of max/min lies in a range bounded by `cstA`, which `pred` compares against
+/// `cstB`; `pred` is cast to `intrange::CmpPredicate`, assuming matching enums.
+/// CmpOp::fold folds only when the 64- and 32-bit results both exist and agree.
+static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
+                                             const APInt &cstA,
+                                             const APInt &cstB, unsigned width,
+                                             IndexCmpPredicate pred) {
+  ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp)
+                                   .Case([&](MinSOp op) {
+                                     return ConstantIntRanges::fromSigned(
+                                         APInt::getSignedMinValue(width), cstA);
+                                   })
+                                   .Case([&](MinUOp op) {
+                                     return ConstantIntRanges::fromUnsigned(
+                                         APInt::getMinValue(width), cstA);
+                                   })
+                                   .Case([&](MaxSOp op) {
+                                     return ConstantIntRanges::fromSigned(
+                                         cstA, APInt::getSignedMaxValue(width));
+                                   })
+                                   .Case([&](MaxUOp op) {
+                                     return ConstantIntRanges::fromUnsigned(
+                                         cstA, APInt::getMaxValue(width));
+                                   });
+  return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred),
+                                lhsRange, ConstantIntRanges::constant(cstB));
+}
+
 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
+  // Attempt to fold if both inputs are constant.
   auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
   auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
-  if (!lhs || !rhs)
-    return {};
+  if (lhs && rhs) {
+    // Perform the comparison in 64-bit and 32-bit.
+    bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
+    bool result32 = compareIndices(lhs.getValue().trunc(32),
+                                   rhs.getValue().trunc(32), getPred());
+    if (result64 == result32)
+      return BoolAttr::get(getContext(), result64);
+  }
 
-  // Perform the comparison in 64-bit and 32-bit.
-  bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
-  bool result32 = compareIndices(lhs.getValue().trunc(32),
-                                 rhs.getValue().trunc(32), getPred());
-  if (result64 != result32)
-    return {};
-  return BoolAttr::get(getContext(), result64);
+  // Fold `cmp(max/min(x, cstA), cstB)`.
+  Operation *lhsOp = getLhs().getDefiningOp();
+  IntegerAttr cstA;
+  if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
+      matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) {
+    std::optional<bool> result64 = foldCmpOfMaxOrMin(
+        lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
+    std::optional<bool> result32 =
+        foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
+                          rhs.getValue().trunc(32), 32, getPred());
+    // Fold if the 32-bit and 64-bit results are the same.
+    if (result64 && result32 && *result64 == *result32)
+      return BoolAttr::get(getContext(), *result64);
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index a9b060bbd6a09..56a3cb4c6031b 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -510,3 +510,14 @@ func.func @cmp_edge() -> i1 {
   // CHECK: return %[[TRUE]]
   return %0 : i1
 }
+
+// CHECK-LABEL: @cmp_maxs
+func.func @cmp_maxs(%arg0: index) -> (i1, i1) {
+  %idx0 = index.constant 0
+  %idx1 = index.constant 1
+  %0 = index.maxs %arg0, %idx1  // %0 >= 1 (signed) at both 32 and 64 bits.
+  %1 = index.cmp sgt(%0, %idx0) // Always true: folds to %true.
+  %2 = index.cmp eq(%0, %idx0)  // Always false: folds to %false.
+  // CHECK: return %true, %false
+  return %1, %2 : i1, i1
+}


        


More information about the Mlir-commits mailing list