[Mlir-commits] [mlir] [mlir][NFC] Simplify constant checks with isZeroIndex and isOneIndex. (PR #139340)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 9 17:21:54 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: Han-Chung Wang (hanhanW)
<details>
<summary>Changes</summary>
This revision adds an isOneIndex helper and simplifies existing code by using the isZeroIndex and isOneIndex helpers. It removes several lambdas, which makes the code cleaner.
---
Full diff: https://github.com/llvm/llvm-project/pull/139340.diff
18 Files Affected:
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+4)
- (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+1-1)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+1-2)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-4)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+5-4)
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+1-1)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+3-7)
- (modified) mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp (+1-4)
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+8-12)
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+3-3)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+1-1)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1-2)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+3-3)
- (modified) mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (+2-2)
- (modified) mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp (+2-6)
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+6-3)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+1-1)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 2a3a2defb810d..ea1a2384f8cba 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -28,6 +28,10 @@ namespace mlir {
/// with attribute with value `0`.
bool isZeroIndex(OpFoldResult v);
+/// Return true if `v` is an IntegerAttr with value `1` of a ConstantIndexOp
+/// with attribute with value `1`.
+bool isOneIndex(OpFoldResult v);
+
/// Represents a range (offset, size, and stride) where each element of the
/// triple may be dynamic or static.
struct Range {
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 04bc62262c3d8..c9e7ae6f8bdb5 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
OpFoldResult offset =
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
.front();
- if (isConstantIntValue(offset, 0)) {
+ if (isZeroIndex(offset)) {
rewriter.replaceOp(op, src);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fce0751430305..a6b1e21cd3b53 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4426,8 +4426,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// Return true if we have a zero-value tile.
auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
- return llvm::any_of(
- tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
+ return llvm::any_of(tiles, isZeroIndex);
};
// Verify tiles. Do not allow zero tiles.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f6ca109b84f9e..25b0635220f3b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3315,10 +3315,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
SmallVector<OpFoldResult> steps = loop.getMixedStep();
- if (llvm::all_of(
- lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
- llvm::all_of(
- steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+ if (llvm::all_of(lbs, isZeroIndex) && llvm::all_of(steps, isOneIndex)) {
return loop;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7c2788f16a3b6..700be3ad35705 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
SmallVector<Value> threadIds = forallOp.getInductionVars();
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
- numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+ numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
int64_t nLoops = loopRanges.size();
tiledOffsets.reserve(nLoops);
tiledSizes.reserve(nLoops);
for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
bool overflow = loopIdx >= numThreads.size();
- bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+ bool isZero = !overflow && isZeroIndex(numThreads[loopIdx]);
// Degenerate case: take the whole domain.
if (overflow || isZero) {
tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
b, loc, i + j * m - n,
{offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
- if (!isConstantIntValue(residualTileSize, 0)) {
+ if (!isZeroIndex(residualTileSize)) {
OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
b, loc, -i + m, {offsetPerThread, size});
tileSizePerThread =
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
Operation *tiledOp = nullptr;
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
- numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+ numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
SmallVector<Value> materializedNonZeroNumThreads =
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 0cc840403a020..faae77a6eecb3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -732,7 +732,7 @@ struct PackOpTiling
// iterated or inner dims are not tiled. Otherwise, it will generate a
// sequence of non-trivial ops (for partial tiles).
for (auto offset : offsets.take_back(numTiles))
- if (!isConstantIntValue(offset, 0))
+ if (!isZeroIndex(offset))
return failure();
for (auto iter :
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..1175c57694272 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1889,9 +1889,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
// are 0.
if (auto prev = src.getDefiningOp<SubViewOp>())
- if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
- return isConstantIntValue(val, 0);
- }))
+ if (llvm::all_of(prev.getMixedOffsets(), isZeroIndex))
return prev.getSource();
return nullptr;
@@ -3285,11 +3283,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
auto srcSizes = srcSubview.getMixedSizes();
auto sizes = getMixedSizes();
auto offsets = getMixedOffsets();
- bool allOffsetsZero = llvm::all_of(
- offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+ bool allOffsetsZero = llvm::all_of(offsets, isZeroIndex);
auto strides = getMixedStrides();
- bool allStridesOne = llvm::all_of(
- strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+ bool allStridesOne = llvm::all_of(strides, isOneIndex);
bool allSizesSame = llvm::equal(sizes, srcSizes);
if (allOffsetsZero && allStridesOne && allSizesSame &&
resultMemrefType == sourceMemrefType)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 05ba6a3f38708..e28f7d3e4924a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -251,10 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
// to do.
SmallVector<OpFoldResult> indices =
getAsOpFoldResult(loadStoreLikeOp.getIndices());
- if (std::all_of(indices.begin(), indices.end(),
- [](const OpFoldResult &opFold) {
- return isConstantIntValue(opFold, 0);
- })) {
+ if (std::all_of(indices.begin(), indices.end(), isZeroIndex)) {
return rewriter.notifyMatchFailure(
loadStoreLikeOp, "no computation to extract: offsets are 0s");
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0cd7da5db9163..d7d42219bc7b6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
tileSizes.resize(numLoops, zero);
for (auto [index, range, nt] :
llvm::enumerate(iterationDomain, numThreads)) {
- if (isConstantIntValue(nt, 0))
+ if (isZeroIndex(nt))
continue;
tileSizes[index] = affine::makeComposedFoldedAffineApply(
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
// Non-tiled cases, set the offset and size to the
// `loopRange.offset/size`.
- if (isConstantIntValue(nt, 0)) {
+ if (isZeroIndex(nt)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
continue;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
{loopRange.offset, nt, tileSize, loopRange.size});
OpFoldResult size = tileSize;
- if (!isConstantIntValue(residualTileSize, 0)) {
+ if (!isZeroIndex(residualTileSize)) {
OpFoldResult sizeMinusOffsetPerThread =
affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
{offset, loopRange.size});
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
// Non-tiled cases, set the offset and size to the
// `loopRange.offset/size`.
- if (isConstantIntValue(tileSize, 0)) {
+ if (isZeroIndex(tileSize)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
continue;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
SmallVector<OpFoldResult> lbs, ubs, steps;
for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
// No loop if the tile size is 0.
- if (isConstantIntValue(tileSize, 0))
+ if (isZeroIndex(tileSize))
continue;
lbs.push_back(loopRange.offset);
ubs.push_back(loopRange.size);
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
// Prune the zero numthreads.
SmallVector<OpFoldResult> nonZeroNumThreads;
for (auto nt : numThreads) {
- if (isConstantIntValue(nt, 0))
+ if (isZeroIndex(nt))
continue;
nonZeroNumThreads.push_back(nt);
}
@@ -1290,9 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
sliceSizes = sliceOp.getMixedSizes();
// expect all strides of sliceOp being 1
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
return failure();
unsigned sliceResultNumber =
@@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
// 9. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
+ if (!llvm::all_of(strides, isOneIndex)) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "containingOp's result yield with stride");
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d9550fe18dc02..f95e38fc75c8d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
// If an `affine.apply` operation is generated for denormalization, the use
// of `origLb` in those ops must not be replaced. These arent not generated
// when `origLb == 0` and `origStep == 1`.
- if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
+ if (!isZeroIndex(origLb) || !isOneIndex(origStep)) {
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
preservedUses.insert(preservedUse);
}
@@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
}
Value denormalizedIv;
SmallPtrSet<Operation *, 2> preserve;
- bool isStepOne = isConstantIntValue(origStep, 1);
- bool isZeroBased = isConstantIntValue(origLb, 0);
+ bool isStepOne = isOneIndex(origStep);
+ bool isZeroBased = isZeroIndex(origLb);
Value scaled = normalizedIv;
if (!isStepOne) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index b2eca539194a8..649375b4c4037 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
// Check for single block, unit-stride for-loop that is generated by
// sparsifier, which means no data dependence analysis is required,
// and its loop-body is very restricted in form.
- if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
+ if (!op.getRegion().hasOneBlock() || !isOneIndex(op.getStep()) ||
!op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
return failure();
// Analyze (!codegen) and rewrite (codegen) loop-body.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 29da32cd1791c..717ea1d0d7618 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2738,8 +2738,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
return getResult();
if (auto result = foldInsertAfterExtractSlice(*this))
return result;
- if (llvm::any_of(getMixedSizes(),
- [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+ if (llvm::any_of(getMixedSizes(), isZeroIndex))
return getDest();
return OpFoldResult();
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 7778a02dbeaf4..41407064cb6d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -135,9 +135,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
for (unsigned dim = 0; dim < rank; ++dim) {
auto low = padOp.getMixedLowPad()[dim];
- bool hasLowPad = !isConstantIntValue(low, 0);
+ bool hasLowPad = !isZeroIndex(low);
auto high = padOp.getMixedHighPad()[dim];
- bool hasHighPad = !isConstantIntValue(high, 0);
+ bool hasHighPad = !isZeroIndex(high);
auto offset = offsets[dim];
auto length = sizes[dim];
// If the dim has no padding, we dont need to calculate new values for that
@@ -208,7 +208,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
// Check if newLength is zero. In that case, no SubTensorOp should be
// executed.
- if (isConstantIntValue(newLength, 0)) {
+ if (isZeroIndex(newLength)) {
hasZeroLen = true;
} else if (!hasZeroLen) {
Value check = b.create<arith::CmpIOp>(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index a3de7f9b44ae6..9978aac1ee80e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
isZeroOffsetAndFullSize =
[](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
- if (!isConstantIntValue(offset, 0))
+ if (!isZeroIndex(offset))
return false;
FailureOr<bool> maybeEqual =
ValueBoundsConstraintSet::areEqual(sliceSize, size);
@@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
// Find the first expanded dim after the first dim with non-unit extracted
// size.
for (; i < e; ++i) {
- if (!isConstantIntValue(sizes[indices[i]], 1)) {
+ if (!isOneIndex(sizes[indices[i]])) {
// +1 to skip the first non-unit size dim.
i++;
break;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 858adfc436164..36cc31e614f21 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -27,9 +27,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
return failure();
// `TilingInterface` currently only supports strides being 1.
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
return failure();
FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
@@ -49,9 +47,7 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
return failure();
// `TilingInterface` currently only supports strides being 1.
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
return failure();
FailureOr<TilingResult> tiledResult =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index fcb736aa031f3..51b51d8aa32e4 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -18,10 +18,13 @@ namespace mlir {
bool isZeroIndex(OpFoldResult v) {
if (!v)
return false;
- std::optional<int64_t> constint = getConstantIntValue(v);
- if (!constint)
+ return isConstantIntValue(v, 0);
+}
+
+bool isOneIndex(OpFoldResult v) {
+ if (!v)
return false;
- return *constint == 0;
+ return isConstantIntValue(v, 1);
}
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..4e5c60671b976 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -141,7 +141,7 @@ struct LinearizeVectorExtractStridedSlice final
ArrayAttr offsets = extractOp.getOffsets();
ArrayAttr sizes = extractOp.getSizes();
ArrayAttr strides = extractOp.getStrides();
- if (!isConstantIntValue(strides[0], 1))
+ if (!isOneIndex(strides[0]))
return rewriter.notifyMatchFailure(
extractOp, "Strided slice with stride != 1 is not supported.");
Value srcVector = adaptor.getVector();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b94c5fce64f83..83dc34e4b4139 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
ArithIndexingBuilder idxBuilderf(rewriter, loc);
for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
OpFoldResult pos = extractPos[i - rankOffset];
- if (isConstantIntValue(pos, 0))
+ if (isZeroIndex(pos))
continue;
Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
``````````
</details>
https://github.com/llvm/llvm-project/pull/139340
More information about the Mlir-commits mailing list