[Mlir-commits] [mlir] [mlir][NFC] Simplify constant checks with isZeroIndex and isOneIndex. (PR #139340)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 9 17:21:54 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>

The revision adds an isOneIndex helper and simplifies the existing code with the two methods. It removes some lambdas, which makes the code cleaner.

---
Full diff: https://github.com/llvm/llvm-project/pull/139340.diff


18 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+4) 
- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+1-2) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-4) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+5-4) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+1-1) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+3-7) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp (+1-4) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+8-12) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+3-3) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1-2) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp (+2-6) 
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+6-3) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 2a3a2defb810d..ea1a2384f8cba 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -28,6 +28,10 @@ namespace mlir {
 /// with attribute with value `0`.
 bool isZeroIndex(OpFoldResult v);
 
+/// Return true if `v` is an IntegerAttr with value `1` or a ConstantIndexOp
+/// with attribute with value `1`.
+bool isOneIndex(OpFoldResult v);
+
 /// Represents a range (offset, size, and stride) where each element of the
 /// triple may be dynamic or static.
 struct Range {
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 04bc62262c3d8..c9e7ae6f8bdb5 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
   OpFoldResult offset =
       getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
           .front();
-  if (isConstantIntValue(offset, 0)) {
+  if (isZeroIndex(offset)) {
     rewriter.replaceOp(op, src);
     return success();
   }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fce0751430305..a6b1e21cd3b53 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4426,8 +4426,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
 
   // Return true if we have a zero-value tile.
   auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
-    return llvm::any_of(
-        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
+    return llvm::any_of(tiles, isZeroIndex);
   };
 
   // Verify tiles. Do not allow zero tiles.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f6ca109b84f9e..25b0635220f3b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3315,10 +3315,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
   SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
   SmallVector<OpFoldResult> steps = loop.getMixedStep();
 
-  if (llvm::all_of(
-          lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
-      llvm::all_of(
-          steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+  if (llvm::all_of(lbs, isZeroIndex) && llvm::all_of(steps, isOneIndex)) {
     return loop;
   }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7c2788f16a3b6..700be3ad35705 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
 
   SmallVector<Value> threadIds = forallOp.getInductionVars();
   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
-      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+      numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
   int64_t nLoops = loopRanges.size();
   tiledOffsets.reserve(nLoops);
   tiledSizes.reserve(nLoops);
   for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
     bool overflow = loopIdx >= numThreads.size();
-    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+    bool isZero = !overflow && isZeroIndex(numThreads[loopIdx]);
     // Degenerate case: take the whole domain.
     if (overflow || isZero) {
       tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
     OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
         b, loc, i + j * m - n,
         {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
-    if (!isConstantIntValue(residualTileSize, 0)) {
+    if (!isZeroIndex(residualTileSize)) {
       OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
           b, loc, -i + m, {offsetPerThread, size});
       tileSizePerThread =
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   Operation *tiledOp = nullptr;
 
   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
-      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+      numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
   SmallVector<Value> materializedNonZeroNumThreads =
       getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 0cc840403a020..faae77a6eecb3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -732,7 +732,7 @@ struct PackOpTiling
     // iterated or inner dims are not tiled. Otherwise, it will generate a
     // sequence of non-trivial ops (for partial tiles).
     for (auto offset : offsets.take_back(numTiles))
-      if (!isConstantIntValue(offset, 0))
+      if (!isZeroIndex(offset))
         return failure();
 
     for (auto iter :
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..1175c57694272 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1889,9 +1889,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
     // are 0.
     if (auto prev = src.getDefiningOp<SubViewOp>())
-      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
-            return isConstantIntValue(val, 0);
-          }))
+      if (llvm::all_of(prev.getMixedOffsets(), isZeroIndex))
         return prev.getSource();
 
     return nullptr;
@@ -3285,11 +3283,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
     auto srcSizes = srcSubview.getMixedSizes();
     auto sizes = getMixedSizes();
     auto offsets = getMixedOffsets();
-    bool allOffsetsZero = llvm::all_of(
-        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+    bool allOffsetsZero = llvm::all_of(offsets, isZeroIndex);
     auto strides = getMixedStrides();
-    bool allStridesOne = llvm::all_of(
-        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+    bool allStridesOne = llvm::all_of(strides, isOneIndex);
     bool allSizesSame = llvm::equal(sizes, srcSizes);
     if (allOffsetsZero && allStridesOne && allSizesSame &&
         resultMemrefType == sourceMemrefType)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 05ba6a3f38708..e28f7d3e4924a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -251,10 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
     // to do.
     SmallVector<OpFoldResult> indices =
         getAsOpFoldResult(loadStoreLikeOp.getIndices());
-    if (std::all_of(indices.begin(), indices.end(),
-                    [](const OpFoldResult &opFold) {
-                      return isConstantIntValue(opFold, 0);
-                    })) {
+    if (std::all_of(indices.begin(), indices.end(), isZeroIndex)) {
       return rewriter.notifyMatchFailure(
           loadStoreLikeOp, "no computation to extract: offsets are 0s");
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0cd7da5db9163..d7d42219bc7b6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
     tileSizes.resize(numLoops, zero);
     for (auto [index, range, nt] :
          llvm::enumerate(iterationDomain, numThreads)) {
-      if (isConstantIntValue(nt, 0))
+      if (isZeroIndex(nt))
         continue;
 
       tileSizes[index] = affine::makeComposedFoldedAffineApply(
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
 
       // Non-tiled cases, set the offset and size to the
       // `loopRange.offset/size`.
-      if (isConstantIntValue(nt, 0)) {
+      if (isZeroIndex(nt)) {
         offsets.push_back(loopRange.offset);
         sizes.push_back(loopRange.size);
         continue;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
           {loopRange.offset, nt, tileSize, loopRange.size});
 
       OpFoldResult size = tileSize;
-      if (!isConstantIntValue(residualTileSize, 0)) {
+      if (!isZeroIndex(residualTileSize)) {
         OpFoldResult sizeMinusOffsetPerThread =
             affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
                                                   {offset, loopRange.size});
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
 
       // Non-tiled cases, set the offset and size to the
       // `loopRange.offset/size`.
-      if (isConstantIntValue(tileSize, 0)) {
+      if (isZeroIndex(tileSize)) {
         offsets.push_back(loopRange.offset);
         sizes.push_back(loopRange.size);
         continue;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
   SmallVector<OpFoldResult> lbs, ubs, steps;
   for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
     // No loop if the tile size is 0.
-    if (isConstantIntValue(tileSize, 0))
+    if (isZeroIndex(tileSize))
       continue;
     lbs.push_back(loopRange.offset);
     ubs.push_back(loopRange.size);
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
     // Prune the zero numthreads.
     SmallVector<OpFoldResult> nonZeroNumThreads;
     for (auto nt : numThreads) {
-      if (isConstantIntValue(nt, 0))
+      if (isZeroIndex(nt))
         continue;
       nonZeroNumThreads.push_back(nt);
     }
@@ -1290,9 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
                               sliceSizes = sliceOp.getMixedSizes();
 
     // expect all strides of sliceOp being 1
-    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-          return !isConstantIntValue(ofr, 1);
-        }))
+    if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
       return failure();
 
     unsigned sliceResultNumber =
@@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
     SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
 
     // 9. Check all insert stride is 1.
-    if (llvm::any_of(strides, [](OpFoldResult stride) {
-          return !isConstantIntValue(stride, 1);
-        })) {
+    if (!llvm::all_of(strides, isOneIndex)) {
       return rewriter.notifyMatchFailure(
           candidateSliceOp, "containingOp's result yield with stride");
     }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d9550fe18dc02..f95e38fc75c8d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
   // If an `affine.apply` operation is generated for denormalization, the use
   // of `origLb` in those ops must not be replaced. These arent not generated
   // when `origLb == 0` and `origStep == 1`.
-  if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
+  if (!isZeroIndex(origLb) || !isOneIndex(origStep)) {
     if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
       preservedUses.insert(preservedUse);
     }
@@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
   }
   Value denormalizedIv;
   SmallPtrSet<Operation *, 2> preserve;
-  bool isStepOne = isConstantIntValue(origStep, 1);
-  bool isZeroBased = isConstantIntValue(origLb, 0);
+  bool isStepOne = isOneIndex(origStep);
+  bool isZeroBased = isZeroIndex(origLb);
 
   Value scaled = normalizedIv;
   if (!isStepOne) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index b2eca539194a8..649375b4c4037 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
     // Check for single block, unit-stride for-loop that is generated by
     // sparsifier, which means no data dependence analysis is required,
     // and its loop-body is very restricted in form.
-    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
+    if (!op.getRegion().hasOneBlock() || !isOneIndex(op.getStep()) ||
         !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
       return failure();
     // Analyze (!codegen) and rewrite (codegen) loop-body.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 29da32cd1791c..717ea1d0d7618 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2738,8 +2738,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
     return getResult();
   if (auto result = foldInsertAfterExtractSlice(*this))
     return result;
-  if (llvm::any_of(getMixedSizes(),
-                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+  if (llvm::any_of(getMixedSizes(), isZeroIndex))
     return getDest();
   return OpFoldResult();
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 7778a02dbeaf4..41407064cb6d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -135,9 +135,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
   for (unsigned dim = 0; dim < rank; ++dim) {
     auto low = padOp.getMixedLowPad()[dim];
-    bool hasLowPad = !isConstantIntValue(low, 0);
+    bool hasLowPad = !isZeroIndex(low);
     auto high = padOp.getMixedHighPad()[dim];
-    bool hasHighPad = !isConstantIntValue(high, 0);
+    bool hasHighPad = !isZeroIndex(high);
     auto offset = offsets[dim];
     auto length = sizes[dim];
     // If the dim has no padding, we dont need to calculate new values for that
@@ -208,7 +208,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
 
     // Check if newLength is zero. In that case, no SubTensorOp should be
     // executed.
-    if (isConstantIntValue(newLength, 0)) {
+    if (isZeroIndex(newLength)) {
       hasZeroLen = true;
     } else if (!hasZeroLen) {
       Value check = b.create<arith::CmpIOp>(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index a3de7f9b44ae6..9978aac1ee80e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
     std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
         isZeroOffsetAndFullSize =
             [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
-              if (!isConstantIntValue(offset, 0))
+              if (!isZeroIndex(offset))
                 return false;
               FailureOr<bool> maybeEqual =
                   ValueBoundsConstraintSet::areEqual(sliceSize, size);
@@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
       // Find the first expanded dim after the first dim with non-unit extracted
       // size.
       for (; i < e; ++i) {
-        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+        if (!isOneIndex(sizes[indices[i]])) {
           // +1 to skip the first non-unit size dim.
           i++;
           break;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 858adfc436164..36cc31e614f21 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -27,9 +27,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
     return failure();
 
   // `TilingInterface` currently only supports strides being 1.
-  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 1);
-      }))
+  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
     return failure();
 
   FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
@@ -49,9 +47,7 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
     return failure();
 
   // `TilingInterface` currently only supports strides being 1.
-  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 1);
-      }))
+  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
     return failure();
 
   FailureOr<TilingResult> tiledResult =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index fcb736aa031f3..51b51d8aa32e4 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -18,10 +18,13 @@ namespace mlir {
 bool isZeroIndex(OpFoldResult v) {
   if (!v)
     return false;
-  std::optional<int64_t> constint = getConstantIntValue(v);
-  if (!constint)
+  return isConstantIntValue(v, 0);
+}
+
+bool isOneIndex(OpFoldResult v) {
+  if (!v)
     return false;
-  return *constint == 0;
+  return isConstantIntValue(v, 1);
 }
 
 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..4e5c60671b976 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -141,7 +141,7 @@ struct LinearizeVectorExtractStridedSlice final
     ArrayAttr offsets = extractOp.getOffsets();
     ArrayAttr sizes = extractOp.getSizes();
     ArrayAttr strides = extractOp.getStrides();
-    if (!isConstantIntValue(strides[0], 1))
+    if (!isOneIndex(strides[0]))
       return rewriter.notifyMatchFailure(
           extractOp, "Strided slice with stride != 1 is not supported.");
     Value srcVector = adaptor.getVector();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b94c5fce64f83..83dc34e4b4139 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
     ArithIndexingBuilder idxBuilderf(rewriter, loc);
     for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
       OpFoldResult pos = extractPos[i - rankOffset];
-      if (isConstantIntValue(pos, 0))
+      if (isZeroIndex(pos))
         continue;
 
       Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);

``````````

</details>


https://github.com/llvm/llvm-project/pull/139340


More information about the Mlir-commits mailing list