[Mlir-commits] [mlir] c39915f - [mlir][NFC] Simplify constant checks with isOneInteger and renamed isZeroInteger. (#139340)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 20 14:53:05 PDT 2025


Author: Han-Chung Wang
Date: 2025-05-20T14:53:02-07:00
New Revision: c39915fa2ea2f8cc7f51d3692d17cbc8a968714c

URL: https://github.com/llvm/llvm-project/commit/c39915fa2ea2f8cc7f51d3692d17cbc8a968714c
DIFF: https://github.com/llvm/llvm-project/commit/c39915fa2ea2f8cc7f51d3692d17cbc8a968714c.diff

LOG: [mlir][NFC] Simplify constant checks with isOneInteger and renamed isZeroInteger. (#139340)

This revision adds an `isOneInteger` helper, renames `isZeroIndex` to
`isZeroInteger`, and simplifies the existing code using the two helpers.
It also removes several lambdas, which makes the code cleaner.

Downstream users can update their code with the script below (in bash,
enable `shopt -s globstar` first so that `**` matches recursively).

```bash
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

---------

Signed-off-by: hanhanW <hanhan0912 at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
    mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 2a3a2defb810d..b37fb55b67931 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -24,9 +24,11 @@
 
 namespace mlir {
 
-/// Return true if `v` is an IntegerAttr with value `0` of a ConstantIndexOp
-/// with attribute with value `0`.
-bool isZeroIndex(OpFoldResult v);
+/// Return true if `v` is an IntegerAttr with value `0`.
+bool isZeroInteger(OpFoldResult v);
+
+/// Return true if `v` is an IntegerAttr with value `1`.
+bool isOneInteger(OpFoldResult v);
 
 /// Represents a range (offset, size, and stride) where each element of the
 /// triple may be dynamic or static.

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 04bc62262c3d8..fdf799a20efdd 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
   OpFoldResult offset =
       getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
           .front();
-  if (isConstantIntValue(offset, 0)) {
+  if (isZeroInteger(offset)) {
     rewriter.replaceOp(op, src);
     return success();
   }

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8a8bd5e232c40..5fc3ace5d6aab 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4488,8 +4488,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
 
   // Return true if we have a zero-value tile.
   auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
-    return llvm::any_of(
-        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
+    return llvm::any_of(tiles, isZeroInteger);
   };
 
   // Verify tiles. Do not allow zero tiles.

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a9370dc003830..1c3b621828315 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3401,10 +3401,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
   SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
   SmallVector<OpFoldResult> steps = loop.getMixedStep();
 
-  if (llvm::all_of(
-          lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
-      llvm::all_of(
-          steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+  if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
     return loop;
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index b1340be04e011..a62510deefc4a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -441,8 +441,8 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
   // If the `padOp` has a nofold attribute and all paddings are known to be 0,
   // explicitly insert a `linalg.copy`.
   if (padOp.getNofoldAttr() &&
-      llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
-      llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
+      llvm::all_of(padOp.getMixedLowPad(), isZeroInteger) &&
+      llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) {
     using bufferization::AllocTensorOp;
     Value allocated =
         rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7c2788f16a3b6..4162aa0b71e6d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
 
   SmallVector<Value> threadIds = forallOp.getInductionVars();
   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
-      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+      numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
   int64_t nLoops = loopRanges.size();
   tiledOffsets.reserve(nLoops);
   tiledSizes.reserve(nLoops);
   for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
     bool overflow = loopIdx >= numThreads.size();
-    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+    bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]);
     // Degenerate case: take the whole domain.
     if (overflow || isZero) {
       tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
     OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
         b, loc, i + j * m - n,
         {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
-    if (!isConstantIntValue(residualTileSize, 0)) {
+    if (!isZeroInteger(residualTileSize)) {
       OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
           b, loc, -i + m, {offsetPerThread, size});
       tileSizePerThread =
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   Operation *tiledOp = nullptr;
 
   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
-      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+      numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
   SmallVector<Value> materializedNonZeroNumThreads =
       getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index e8d460020cf69..7c14cc16437fe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -369,7 +369,7 @@ struct LinalgOpPartialReductionInterface
 
     SmallVector<OpFoldResult> tiledShape;
     for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
-      if (isZeroIndex(tileSize)) {
+      if (isZeroInteger(tileSize)) {
         tiledShape.push_back(dimSize);
       } else {
         tiledShape.push_back(tileSize);
@@ -732,7 +732,7 @@ struct PackOpTiling
     // iterated or inner dims are not tiled. Otherwise, it will generate a
     // sequence of non-trivial ops (for partial tiles).
     for (auto offset : offsets.take_back(numTiles))
-      if (!isConstantIntValue(offset, 0))
+      if (!isZeroInteger(offset))
         return failure();
 
     for (auto iter :

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index d3d301ca093b1..bae06c003fd97 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -59,7 +59,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
   TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
 
   void visitDimExpr(AffineDimExpr expr) {
-    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
+    isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
   }
   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
     visit(expr.getLHS());
@@ -741,7 +741,7 @@ SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
   SmallVector<OpFoldResult> offsets;
   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
-    bool isTiled = !isZeroIndex(tileSizes[idx]);
+    bool isTiled = !isZeroInteger(tileSizes[idx]);
     offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
     LLVM_DEBUG(llvm::dbgs()
                << "computeTileOffsets: " << offsets.back() << "\n");
@@ -754,7 +754,7 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                            ArrayRef<OpFoldResult> sizeBounds) {
   SmallVector<OpFoldResult> sizes;
   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
-    bool isTiled = !isZeroIndex(tileSizes[idx]);
+    bool isTiled = !isZeroInteger(tileSizes[idx]);
     // Before composing, we need to make range a closed interval.
     OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
     AffineExpr d0 = getAffineDimExpr(0, b.getContext());
@@ -810,7 +810,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                           bool omitPartialTileCheck) {
   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
-                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
+                           [](OpFoldResult v) { return !isZeroInteger(v); })) &&
          "expected as many ivs as non-zero sizes");
 
   // Construct (potentially temporary) mins and maxes on which to apply maps

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 82702789c2913..cab0ab8d15d5d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1894,9 +1894,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
     // are 0.
     if (auto prev = src.getDefiningOp<SubViewOp>())
-      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
-            return isConstantIntValue(val, 0);
-          }))
+      if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
         return prev.getSource();
 
     return nullptr;
@@ -3290,11 +3288,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
     auto srcSizes = srcSubview.getMixedSizes();
     auto sizes = getMixedSizes();
     auto offsets = getMixedOffsets();
-    bool allOffsetsZero = llvm::all_of(
-        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+    bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
     auto strides = getMixedStrides();
-    bool allStridesOne = llvm::all_of(
-        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+    bool allStridesOne = llvm::all_of(strides, isOneInteger);
     bool allSizesSame = llvm::equal(sizes, srcSizes);
     if (allOffsetsZero && allStridesOne && allSizesSame &&
         resultMemrefType == sourceMemrefType)

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 723b4d01186f9..2f5c9436fb8c7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -251,9 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
     // to do.
     SmallVector<OpFoldResult> indices =
         getAsOpFoldResult(loadStoreLikeOp.getIndices());
-    if (llvm::all_of(indices, [](const OpFoldResult &opFold) {
-          return isConstantIntValue(opFold, 0);
-        })) {
+    if (llvm::all_of(indices, isZeroInteger)) {
       return rewriter.notifyMatchFailure(
           loadStoreLikeOp, "no computation to extract: offsets are 0s");
     }

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0cd7da5db9163..719e2c6fa459e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
     tileSizes.resize(numLoops, zero);
     for (auto [index, range, nt] :
          llvm::enumerate(iterationDomain, numThreads)) {
-      if (isConstantIntValue(nt, 0))
+      if (isZeroInteger(nt))
         continue;
 
       tileSizes[index] = affine::makeComposedFoldedAffineApply(
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
 
       // Non-tiled cases, set the offset and size to the
       // `loopRange.offset/size`.
-      if (isConstantIntValue(nt, 0)) {
+      if (isZeroInteger(nt)) {
         offsets.push_back(loopRange.offset);
         sizes.push_back(loopRange.size);
         continue;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
           {loopRange.offset, nt, tileSize, loopRange.size});
 
       OpFoldResult size = tileSize;
-      if (!isConstantIntValue(residualTileSize, 0)) {
+      if (!isZeroInteger(residualTileSize)) {
         OpFoldResult sizeMinusOffsetPerThread =
             affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
                                                   {offset, loopRange.size});
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
 
       // Non-tiled cases, set the offset and size to the
       // `loopRange.offset/size`.
-      if (isConstantIntValue(tileSize, 0)) {
+      if (isZeroInteger(tileSize)) {
         offsets.push_back(loopRange.offset);
         sizes.push_back(loopRange.size);
         continue;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
   SmallVector<OpFoldResult> lbs, ubs, steps;
   for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
     // No loop if the tile size is 0.
-    if (isConstantIntValue(tileSize, 0))
+    if (isZeroInteger(tileSize))
       continue;
     lbs.push_back(loopRange.offset);
     ubs.push_back(loopRange.size);
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
     // Prune the zero numthreads.
     SmallVector<OpFoldResult> nonZeroNumThreads;
     for (auto nt : numThreads) {
-      if (isConstantIntValue(nt, 0))
+      if (isZeroInteger(nt))
         continue;
       nonZeroNumThreads.push_back(nt);
     }
@@ -551,7 +551,7 @@ static LogicalResult generateLoopNest(
     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   // If the tile sizes are all zero, no loops are generated. Just call the
   // callback function to handle untiled case.
-  if (llvm::all_of(tileSizes, isZeroIndex)) {
+  if (llvm::all_of(tileSizes, isZeroInteger)) {
     SmallVector<Value> tiledResults;
     SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
     return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
@@ -999,7 +999,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     // 5b. Early return cloned op if tiling is not happening. We can not
     // return the original op because it could lead to `rewriter.replaceOp(op,
     // op->getResults())` and users would get crash.
-    if (llvm::all_of(tileSizes, isZeroIndex)) {
+    if (llvm::all_of(tileSizes, isZeroInteger)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
           TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
@@ -1290,9 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
                               sliceSizes = sliceOp.getMixedSizes();
 
     // expect all strides of sliceOp being 1
-    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-          return !isConstantIntValue(ofr, 1);
-        }))
+    if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
       return failure();
 
     unsigned sliceResultNumber =
@@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
     SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
 
     // 9. Check all insert stride is 1.
-    if (llvm::any_of(strides, [](OpFoldResult stride) {
-          return !isConstantIntValue(stride, 1);
-        })) {
+    if (!llvm::all_of(strides, isOneInteger)) {
       return rewriter.notifyMatchFailure(
           candidateSliceOp, "containingOp's result yield with stride");
     }

diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d9550fe18dc02..8ab5bdc0c5dc5 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
   // If an `affine.apply` operation is generated for denormalization, the use
   // of `origLb` in those ops must not be replaced. These arent not generated
   // when `origLb == 0` and `origStep == 1`.
-  if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
+  if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
     if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
       preservedUses.insert(preservedUse);
     }
@@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
   }
   Value denormalizedIv;
   SmallPtrSet<Operation *, 2> preserve;
-  bool isStepOne = isConstantIntValue(origStep, 1);
-  bool isZeroBased = isConstantIntValue(origLb, 0);
+  bool isStepOne = isOneInteger(origStep);
+  bool isZeroBased = isZeroInteger(origLb);
 
   Value scaled = normalizedIv;
   if (!isStepOne) {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index b2eca539194a8..3d963dea2f572 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
     // Check for single block, unit-stride for-loop that is generated by
     // sparsifier, which means no data dependence analysis is required,
     // and its loop-body is very restricted in form.
-    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
+    if (!op.getRegion().hasOneBlock() || !isOneInteger(op.getStep()) ||
         !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
       return failure();
     // Analyze (!codegen) and rewrite (codegen) loop-body.

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9a0d5d7e16960..30ca20fc0d883 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2839,8 +2839,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
     return getResult();
   if (auto result = foldInsertAfterExtractSlice(*this))
     return result;
-  if (llvm::any_of(getMixedSizes(),
-                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+  if (llvm::any_of(getMixedSizes(), isZeroInteger))
     return getDest();
   return OpFoldResult();
 }

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 7778a02dbeaf4..92540bd56ecbc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -135,9 +135,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
   for (unsigned dim = 0; dim < rank; ++dim) {
     auto low = padOp.getMixedLowPad()[dim];
-    bool hasLowPad = !isConstantIntValue(low, 0);
+    bool hasLowPad = !isZeroInteger(low);
     auto high = padOp.getMixedHighPad()[dim];
-    bool hasHighPad = !isConstantIntValue(high, 0);
+    bool hasHighPad = !isZeroInteger(high);
     auto offset = offsets[dim];
     auto length = sizes[dim];
     // If the dim has no padding, we dont need to calculate new values for that
@@ -208,7 +208,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
 
     // Check if newLength is zero. In that case, no SubTensorOp should be
     // executed.
-    if (isConstantIntValue(newLength, 0)) {
+    if (isZeroInteger(newLength)) {
       hasZeroLen = true;
     } else if (!hasZeroLen) {
       Value check = b.create<arith::CmpIOp>(

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 6525e58d002a2..b6843e560a899 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -646,7 +646,7 @@ static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
   // Dest is not read if it is entirely overwritten. E.g.:
   // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
   bool allOffsetsZero =
-      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
+      llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
   RankedTensorType destType = insertSliceOp.getDestType();
   bool sizesMatchDestSizes =
       areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index a3de7f9b44ae6..2b229d60c691b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
     std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
         isZeroOffsetAndFullSize =
             [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
-              if (!isConstantIntValue(offset, 0))
+              if (!isZeroInteger(offset))
                 return false;
               FailureOr<bool> maybeEqual =
                   ValueBoundsConstraintSet::areEqual(sliceSize, size);
@@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
       // Find the first expanded dim after the first dim with non-unit extracted
       // size.
       for (; i < e; ++i) {
-        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+        if (!isOneInteger(sizes[indices[i]])) {
           // +1 to skip the first non-unit size dim.
           i++;
           break;

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 858adfc436164..6f33f9b55ceb6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -27,9 +27,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
     return failure();
 
   // `TilingInterface` currently only supports strides being 1.
-  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 1);
-      }))
+  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
     return failure();
 
   FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
@@ -49,9 +47,7 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
     return failure();
 
   // `TilingInterface` currently only supports strides being 1.
-  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 1);
-      }))
+  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
     return failure();
 
   FailureOr<TilingResult> tiledResult =

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index fac836ebd7a36..29f7bd6857c27 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -15,14 +15,9 @@
 
 namespace mlir {
 
-bool isZeroIndex(OpFoldResult v) {
-  if (!v)
-    return false;
-  std::optional<int64_t> constint = getConstantIntValue(v);
-  if (!constint)
-    return false;
-  return *constint == 0;
-}
+bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(v, 0); }
+
+bool isOneInteger(OpFoldResult v) { return isConstantIntValue(v, 1); }
 
 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
            SmallVector<OpFoldResult>>

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 060ce7d1d6643..678a88627ca82 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -116,8 +116,7 @@ static bool stridesAllOne(TOp op) {
           std::is_same_v<TOp, vector::InsertStridedSliceOp>,
       "expected vector.extract_strided_slice or vector.insert_strided_slice");
   ArrayAttr strides = op.getStrides();
-  return llvm::all_of(
-      strides, [](auto stride) { return isConstantIntValue(stride, 1); });
+  return llvm::all_of(strides, isOneInteger);
 }
 
 /// Convert an array of attributes into a vector of integers, if possible.

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d4d07c7eadc77..7dbb7a334fe62 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -538,7 +538,7 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
       indices.begin(), indices.begin() + firstDimToCollapse);
   SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                        indices.end());
-  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
+  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
     indicesAfterCollapsing.push_back(indicesToCollapse[0]);
     return indicesAfterCollapsing;
   }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c635be6e83b6a..71c557f7eda06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
     ArithIndexingBuilder idxBuilderf(rewriter, loc);
     for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
       OpFoldResult pos = extractPos[i - rankOffset];
-      if (isConstantIntValue(pos, 0))
+      if (isZeroInteger(pos))
         continue;
 
       Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);


        


More information about the Mlir-commits mailing list