[Mlir-commits] [mlir] [mlir][NFC] Simplify constant checks with isZeroIndex and isOneIndex. (PR #139340)
Han-Chung Wang
llvmlistbot at llvm.org
Fri May 16 15:52:57 PDT 2025
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/139340
>From c40ab2c51367714ec2527c2c10577dca342c2b20 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 9 May 2025 17:18:14 -0700
Subject: [PATCH 1/3] [mlir][NFC] Simplify constant checks with isZeroIndex and
isOneIndex.
This revision adds the isOneIndex helper and uses it, together with the existing
isZeroIndex helper, to simplify constant checks. It removes several lambdas, which makes the code cleaner.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../mlir/Dialect/Utils/StaticValueUtils.h | 4 ++++
.../MemRefToSPIRV/MemRefToSPIRV.cpp | 2 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 +--
.../TransformOps/LinalgTransformOps.cpp | 5 +----
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 9 +++++----
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 2 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 +++-------
.../Transforms/ExtractAddressComputations.cpp | 5 +----
.../SCF/Transforms/TileUsingInterface.cpp | 20 ++++++++-----------
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +++---
.../Transforms/SparseVectorization.cpp | 2 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 +--
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 6 +++---
.../Tensor/Transforms/ReshapePatterns.cpp | 4 ++--
.../SwapExtractSliceWithProducerPatterns.cpp | 8 ++------
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 9 ++++++---
.../Vector/Transforms/VectorLinearize.cpp | 3 +--
.../Vector/Transforms/VectorTransforms.cpp | 2 +-
18 files changed, 45 insertions(+), 58 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 2a3a2defb810d..ea1a2384f8cba 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -28,6 +28,10 @@ namespace mlir {
/// with attribute with value `0`.
bool isZeroIndex(OpFoldResult v);
+/// Return true if `v` is an IntegerAttr with value `1` of a ConstantIndexOp
+/// with attribute with value `1`.
+bool isOneIndex(OpFoldResult v);
+
/// Represents a range (offset, size, and stride) where each element of the
/// triple may be dynamic or static.
struct Range {
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 04bc62262c3d8..c9e7ae6f8bdb5 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
OpFoldResult offset =
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
.front();
- if (isConstantIntValue(offset, 0)) {
+ if (isZeroIndex(offset)) {
rewriter.replaceOp(op, src);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 96106cf7ae120..4fdeca47ed304 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4488,8 +4488,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// Return true if we have a zero-value tile.
auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
- return llvm::any_of(
- tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
+ return llvm::any_of(tiles, isZeroIndex);
};
// Verify tiles. Do not allow zero tiles.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a9370dc003830..d736fb141cb0c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3401,10 +3401,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
SmallVector<OpFoldResult> steps = loop.getMixedStep();
- if (llvm::all_of(
- lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
- llvm::all_of(
- steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+ if (llvm::all_of(lbs, isZeroIndex) && llvm::all_of(steps, isOneIndex)) {
return loop;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7c2788f16a3b6..700be3ad35705 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
SmallVector<Value> threadIds = forallOp.getInductionVars();
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
- numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+ numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
int64_t nLoops = loopRanges.size();
tiledOffsets.reserve(nLoops);
tiledSizes.reserve(nLoops);
for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
bool overflow = loopIdx >= numThreads.size();
- bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+ bool isZero = !overflow && isZeroIndex(numThreads[loopIdx]);
// Degenerate case: take the whole domain.
if (overflow || isZero) {
tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
b, loc, i + j * m - n,
{offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
- if (!isConstantIntValue(residualTileSize, 0)) {
+ if (!isZeroIndex(residualTileSize)) {
OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
b, loc, -i + m, {offsetPerThread, size});
tileSizePerThread =
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
Operation *tiledOp = nullptr;
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
- numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+ numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
SmallVector<Value> materializedNonZeroNumThreads =
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index e8d460020cf69..7485df2cd73b3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -732,7 +732,7 @@ struct PackOpTiling
// iterated or inner dims are not tiled. Otherwise, it will generate a
// sequence of non-trivial ops (for partial tiles).
for (auto offset : offsets.take_back(numTiles))
- if (!isConstantIntValue(offset, 0))
+ if (!isZeroIndex(offset))
return failure();
for (auto iter :
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..1175c57694272 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1889,9 +1889,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
// are 0.
if (auto prev = src.getDefiningOp<SubViewOp>())
- if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
- return isConstantIntValue(val, 0);
- }))
+ if (llvm::all_of(prev.getMixedOffsets(), isZeroIndex))
return prev.getSource();
return nullptr;
@@ -3285,11 +3283,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
auto srcSizes = srcSubview.getMixedSizes();
auto sizes = getMixedSizes();
auto offsets = getMixedOffsets();
- bool allOffsetsZero = llvm::all_of(
- offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+ bool allOffsetsZero = llvm::all_of(offsets, isZeroIndex);
auto strides = getMixedStrides();
- bool allStridesOne = llvm::all_of(
- strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+ bool allStridesOne = llvm::all_of(strides, isOneIndex);
bool allSizesSame = llvm::equal(sizes, srcSizes);
if (allOffsetsZero && allStridesOne && allSizesSame &&
resultMemrefType == sourceMemrefType)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index b906c727604dc..5a08900921ee5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -251,10 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
// to do.
SmallVector<OpFoldResult> indices =
getAsOpFoldResult(loadStoreLikeOp.getIndices());
- if (std::all_of(indices.begin(), indices.end(),
- [](const OpFoldResult &opFold) {
- return isConstantIntValue(opFold, 0);
- })) {
+ if (std::all_of(indices.begin(), indices.end(), isZeroIndex)) {
return rewriter.notifyMatchFailure(
loadStoreLikeOp, "no computation to extract: offsets are 0s");
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0cd7da5db9163..d7d42219bc7b6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
tileSizes.resize(numLoops, zero);
for (auto [index, range, nt] :
llvm::enumerate(iterationDomain, numThreads)) {
- if (isConstantIntValue(nt, 0))
+ if (isZeroIndex(nt))
continue;
tileSizes[index] = affine::makeComposedFoldedAffineApply(
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
// Non-tiled cases, set the offset and size to the
// `loopRange.offset/size`.
- if (isConstantIntValue(nt, 0)) {
+ if (isZeroIndex(nt)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
continue;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
{loopRange.offset, nt, tileSize, loopRange.size});
OpFoldResult size = tileSize;
- if (!isConstantIntValue(residualTileSize, 0)) {
+ if (!isZeroIndex(residualTileSize)) {
OpFoldResult sizeMinusOffsetPerThread =
affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
{offset, loopRange.size});
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
// Non-tiled cases, set the offset and size to the
// `loopRange.offset/size`.
- if (isConstantIntValue(tileSize, 0)) {
+ if (isZeroIndex(tileSize)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
continue;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
SmallVector<OpFoldResult> lbs, ubs, steps;
for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
// No loop if the tile size is 0.
- if (isConstantIntValue(tileSize, 0))
+ if (isZeroIndex(tileSize))
continue;
lbs.push_back(loopRange.offset);
ubs.push_back(loopRange.size);
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
// Prune the zero numthreads.
SmallVector<OpFoldResult> nonZeroNumThreads;
for (auto nt : numThreads) {
- if (isConstantIntValue(nt, 0))
+ if (isZeroIndex(nt))
continue;
nonZeroNumThreads.push_back(nt);
}
@@ -1290,9 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
sliceSizes = sliceOp.getMixedSizes();
// expect all strides of sliceOp being 1
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
return failure();
unsigned sliceResultNumber =
@@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
// 9. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
+ if (!llvm::all_of(strides, isOneIndex)) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "containingOp's result yield with stride");
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d9550fe18dc02..f95e38fc75c8d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
// If an `affine.apply` operation is generated for denormalization, the use
// of `origLb` in those ops must not be replaced. These arent not generated
// when `origLb == 0` and `origStep == 1`.
- if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
+ if (!isZeroIndex(origLb) || !isOneIndex(origStep)) {
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
preservedUses.insert(preservedUse);
}
@@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
}
Value denormalizedIv;
SmallPtrSet<Operation *, 2> preserve;
- bool isStepOne = isConstantIntValue(origStep, 1);
- bool isZeroBased = isConstantIntValue(origLb, 0);
+ bool isStepOne = isOneIndex(origStep);
+ bool isZeroBased = isZeroIndex(origLb);
Value scaled = normalizedIv;
if (!isStepOne) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index b2eca539194a8..649375b4c4037 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
// Check for single block, unit-stride for-loop that is generated by
// sparsifier, which means no data dependence analysis is required,
// and its loop-body is very restricted in form.
- if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
+ if (!op.getRegion().hasOneBlock() || !isOneIndex(op.getStep()) ||
!op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
return failure();
// Analyze (!codegen) and rewrite (codegen) loop-body.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6c32476d8656f..6c17ebbb85c81 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2840,8 +2840,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
return getResult();
if (auto result = foldInsertAfterExtractSlice(*this))
return result;
- if (llvm::any_of(getMixedSizes(),
- [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+ if (llvm::any_of(getMixedSizes(), isZeroIndex))
return getDest();
return OpFoldResult();
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 7778a02dbeaf4..41407064cb6d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -135,9 +135,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
for (unsigned dim = 0; dim < rank; ++dim) {
auto low = padOp.getMixedLowPad()[dim];
- bool hasLowPad = !isConstantIntValue(low, 0);
+ bool hasLowPad = !isZeroIndex(low);
auto high = padOp.getMixedHighPad()[dim];
- bool hasHighPad = !isConstantIntValue(high, 0);
+ bool hasHighPad = !isZeroIndex(high);
auto offset = offsets[dim];
auto length = sizes[dim];
// If the dim has no padding, we dont need to calculate new values for that
@@ -208,7 +208,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
// Check if newLength is zero. In that case, no SubTensorOp should be
// executed.
- if (isConstantIntValue(newLength, 0)) {
+ if (isZeroIndex(newLength)) {
hasZeroLen = true;
} else if (!hasZeroLen) {
Value check = b.create<arith::CmpIOp>(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index a3de7f9b44ae6..9978aac1ee80e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
isZeroOffsetAndFullSize =
[](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
- if (!isConstantIntValue(offset, 0))
+ if (!isZeroIndex(offset))
return false;
FailureOr<bool> maybeEqual =
ValueBoundsConstraintSet::areEqual(sliceSize, size);
@@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
// Find the first expanded dim after the first dim with non-unit extracted
// size.
for (; i < e; ++i) {
- if (!isConstantIntValue(sizes[indices[i]], 1)) {
+ if (!isOneIndex(sizes[indices[i]])) {
// +1 to skip the first non-unit size dim.
i++;
break;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 858adfc436164..36cc31e614f21 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -27,9 +27,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
return failure();
// `TilingInterface` currently only supports strides being 1.
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
return failure();
FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
@@ -49,9 +47,7 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
return failure();
// `TilingInterface` currently only supports strides being 1.
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
return failure();
FailureOr<TilingResult> tiledResult =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index fac836ebd7a36..2edfdf2508895 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -18,10 +18,13 @@ namespace mlir {
bool isZeroIndex(OpFoldResult v) {
if (!v)
return false;
- std::optional<int64_t> constint = getConstantIntValue(v);
- if (!constint)
+ return isConstantIntValue(v, 0);
+}
+
+bool isOneIndex(OpFoldResult v) {
+ if (!v)
return false;
- return *constint == 0;
+ return isConstantIntValue(v, 1);
}
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 060ce7d1d6643..dc87424df3854 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -116,8 +116,7 @@ static bool stridesAllOne(TOp op) {
std::is_same_v<TOp, vector::InsertStridedSliceOp>,
"expected vector.extract_strided_slice or vector.insert_strided_slice");
ArrayAttr strides = op.getStrides();
- return llvm::all_of(
- strides, [](auto stride) { return isConstantIntValue(stride, 1); });
+ return llvm::all_of(strides, isOneIndex);
}
/// Convert an array of attributes into a vector of integers, if possible.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c635be6e83b6a..ca15d410efc7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
ArithIndexingBuilder idxBuilderf(rewriter, loc);
for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
OpFoldResult pos = extractPos[i - rankOffset];
- if (isConstantIntValue(pos, 0))
+ if (isZeroIndex(pos))
continue;
Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
>From 6f5eb8b28bcae88ce22b0c61ada71c2306a3c410 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 16 May 2025 15:50:15 -0700
Subject: [PATCH 2/3] Update comments and implementation.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h | 6 ++----
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 12 ++----------
2 files changed, 4 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index ea1a2384f8cba..c64dd88b8f52d 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -24,12 +24,10 @@
namespace mlir {
-/// Return true if `v` is an IntegerAttr with value `0` of a ConstantIndexOp
-/// with attribute with value `0`.
+/// Return true if `v` is an IntegerAttr with value `0`.
bool isZeroIndex(OpFoldResult v);
-/// Return true if `v` is an IntegerAttr with value `1` of a ConstantIndexOp
-/// with attribute with value `1`.
+/// Return true if `v` is an IntegerAttr with value `1`.
bool isOneIndex(OpFoldResult v);
/// Represents a range (offset, size, and stride) where each element of the
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 2edfdf2508895..3ad3a43fbed0e 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -15,17 +15,9 @@
namespace mlir {
-bool isZeroIndex(OpFoldResult v) {
- if (!v)
- return false;
- return isConstantIntValue(v, 0);
-}
+bool isZeroIndex(OpFoldResult v) { return isConstantIntValue(v, 0); }
-bool isOneIndex(OpFoldResult v) {
- if (!v)
- return false;
- return isConstantIntValue(v, 1);
-}
+bool isOneIndex(OpFoldResult v) { return isConstantIntValue(v, 1); }
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
SmallVector<OpFoldResult>>
>From 1e4b65eb879c78879647f3aafbf6e7bd71a7cbe4 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 16 May 2025 15:51:54 -0700
Subject: [PATCH 3/3] Rename to isZeroInteger and isOneInteger.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../mlir/Dialect/Utils/StaticValueUtils.h | 4 ++--
.../MemRefToSPIRV/MemRefToSPIRV.cpp | 2 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +-
.../TransformOps/LinalgTransformOps.cpp | 2 +-
.../Transforms/ConvertToDestinationStyle.cpp | 4 ++--
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 8 ++++----
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 4 ++--
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 8 ++++----
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 6 +++---
.../Transforms/ExtractAddressComputations.cpp | 2 +-
.../SCF/Transforms/TileUsingInterface.cpp | 20 +++++++++----------
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +++---
.../Transforms/SparseVectorization.cpp | 2 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 6 +++---
.../BufferizableOpInterfaceImpl.cpp | 2 +-
.../Tensor/Transforms/ReshapePatterns.cpp | 4 ++--
.../SwapExtractSliceWithProducerPatterns.cpp | 4 ++--
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 4 ++--
.../Vector/Transforms/VectorLinearize.cpp | 2 +-
.../Transforms/VectorTransferOpTransforms.cpp | 2 +-
.../Vector/Transforms/VectorTransforms.cpp | 2 +-
22 files changed, 49 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index c64dd88b8f52d..b37fb55b67931 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -25,10 +25,10 @@
namespace mlir {
/// Return true if `v` is an IntegerAttr with value `0`.
-bool isZeroIndex(OpFoldResult v);
+bool isZeroInteger(OpFoldResult v);
/// Return true if `v` is an IntegerAttr with value `1`.
-bool isOneIndex(OpFoldResult v);
+bool isOneInteger(OpFoldResult v);
/// Represents a range (offset, size, and stride) where each element of the
/// triple may be dynamic or static.
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index c9e7ae6f8bdb5..fdf799a20efdd 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
OpFoldResult offset =
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
.front();
- if (isZeroIndex(offset)) {
+ if (isZeroInteger(offset)) {
rewriter.replaceOp(op, src);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4fdeca47ed304..b7f78607e6241 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4488,7 +4488,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// Return true if we have a zero-value tile.
auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
- return llvm::any_of(tiles, isZeroIndex);
+ return llvm::any_of(tiles, isZeroInteger);
};
// Verify tiles. Do not allow zero tiles.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d736fb141cb0c..1c3b621828315 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3401,7 +3401,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
SmallVector<OpFoldResult> steps = loop.getMixedStep();
- if (llvm::all_of(lbs, isZeroIndex) && llvm::all_of(steps, isOneIndex)) {
+ if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
return loop;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index b1340be04e011..a62510deefc4a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -441,8 +441,8 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
// If the `padOp` has a nofold attribute and all paddings are known to be 0,
// explicitly insert a `linalg.copy`.
if (padOp.getNofoldAttr() &&
- llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
- llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
+ llvm::all_of(padOp.getMixedLowPad(), isZeroInteger) &&
+ llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) {
using bufferization::AllocTensorOp;
Value allocated =
rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 700be3ad35705..4162aa0b71e6d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -377,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
SmallVector<Value> threadIds = forallOp.getInductionVars();
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
- numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
+ numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
int64_t nLoops = loopRanges.size();
tiledOffsets.reserve(nLoops);
tiledSizes.reserve(nLoops);
for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
bool overflow = loopIdx >= numThreads.size();
- bool isZero = !overflow && isZeroIndex(numThreads[loopIdx]);
+ bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]);
// Degenerate case: take the whole domain.
if (overflow || isZero) {
tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -414,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
b, loc, i + j * m - n,
{offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
- if (!isZeroIndex(residualTileSize)) {
+ if (!isZeroInteger(residualTileSize)) {
OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
b, loc, -i + m, {offsetPerThread, size});
tileSizePerThread =
@@ -656,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
Operation *tiledOp = nullptr;
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
- numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
+ numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
SmallVector<Value> materializedNonZeroNumThreads =
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7485df2cd73b3..7c14cc16437fe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -369,7 +369,7 @@ struct LinalgOpPartialReductionInterface
SmallVector<OpFoldResult> tiledShape;
for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
- if (isZeroIndex(tileSize)) {
+ if (isZeroInteger(tileSize)) {
tiledShape.push_back(dimSize);
} else {
tiledShape.push_back(tileSize);
@@ -732,7 +732,7 @@ struct PackOpTiling
// iterated or inner dims are not tiled. Otherwise, it will generate a
// sequence of non-trivial ops (for partial tiles).
for (auto offset : offsets.take_back(numTiles))
- if (!isZeroIndex(offset))
+ if (!isZeroInteger(offset))
return failure();
for (auto iter :
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index d3d301ca093b1..bae06c003fd97 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -59,7 +59,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
void visitDimExpr(AffineDimExpr expr) {
- isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
+ isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
}
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
visit(expr.getLHS());
@@ -741,7 +741,7 @@ SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
SmallVector<OpFoldResult> offsets;
for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
- bool isTiled = !isZeroIndex(tileSizes[idx]);
+ bool isTiled = !isZeroInteger(tileSizes[idx]);
offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
LLVM_DEBUG(llvm::dbgs()
<< "computeTileOffsets: " << offsets.back() << "\n");
@@ -754,7 +754,7 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> sizeBounds) {
SmallVector<OpFoldResult> sizes;
for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
- bool isTiled = !isZeroIndex(tileSizes[idx]);
+ bool isTiled = !isZeroInteger(tileSizes[idx]);
// Before composing, we need to make range a closed interval.
OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
AffineExpr d0 = getAffineDimExpr(0, b.getContext());
@@ -810,7 +810,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
bool omitPartialTileCheck) {
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
llvm::make_range(tileSizes.begin(), tileSizes.end()),
- [](OpFoldResult v) { return !isZeroIndex(v); })) &&
+ [](OpFoldResult v) { return !isZeroInteger(v); })) &&
"expected as many ivs as non-zero sizes");
// Construct (potentially temporary) mins and maxes on which to apply maps
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1175c57694272..95c8b72643735 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1889,7 +1889,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
// are 0.
if (auto prev = src.getDefiningOp<SubViewOp>())
- if (llvm::all_of(prev.getMixedOffsets(), isZeroIndex))
+ if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
return prev.getSource();
return nullptr;
@@ -3283,9 +3283,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
auto srcSizes = srcSubview.getMixedSizes();
auto sizes = getMixedSizes();
auto offsets = getMixedOffsets();
- bool allOffsetsZero = llvm::all_of(offsets, isZeroIndex);
+ bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
auto strides = getMixedStrides();
- bool allStridesOne = llvm::all_of(strides, isOneIndex);
+ bool allStridesOne = llvm::all_of(strides, isOneInteger);
bool allSizesSame = llvm::equal(sizes, srcSizes);
if (allOffsetsZero && allStridesOne && allSizesSame &&
resultMemrefType == sourceMemrefType)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 5a08900921ee5..9e942f10b1f16 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -251,7 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
// to do.
SmallVector<OpFoldResult> indices =
getAsOpFoldResult(loadStoreLikeOp.getIndices());
- if (std::all_of(indices.begin(), indices.end(), isZeroIndex)) {
+ if (std::all_of(indices.begin(), indices.end(), isZeroInteger)) {
return rewriter.notifyMatchFailure(
loadStoreLikeOp, "no computation to extract: offsets are 0s");
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index d7d42219bc7b6..719e2c6fa459e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
tileSizes.resize(numLoops, zero);
for (auto [index, range, nt] :
llvm::enumerate(iterationDomain, numThreads)) {
- if (isZeroIndex(nt))
+ if (isZeroInteger(nt))
continue;
tileSizes[index] = affine::makeComposedFoldedAffineApply(
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
// Non-tiled cases, set the offset and size to the
// `loopRange.offset/size`.
- if (isZeroIndex(nt)) {
+ if (isZeroInteger(nt)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
continue;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
{loopRange.offset, nt, tileSize, loopRange.size});
OpFoldResult size = tileSize;
- if (!isZeroIndex(residualTileSize)) {
+ if (!isZeroInteger(residualTileSize)) {
OpFoldResult sizeMinusOffsetPerThread =
affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
{offset, loopRange.size});
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
// Non-tiled cases, set the offset and size to the
// `loopRange.offset/size`.
- if (isZeroIndex(tileSize)) {
+ if (isZeroInteger(tileSize)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
continue;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
SmallVector<OpFoldResult> lbs, ubs, steps;
for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
// No loop if the tile size is 0.
- if (isZeroIndex(tileSize))
+ if (isZeroInteger(tileSize))
continue;
lbs.push_back(loopRange.offset);
ubs.push_back(loopRange.size);
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
// Prune the zero numthreads.
SmallVector<OpFoldResult> nonZeroNumThreads;
for (auto nt : numThreads) {
- if (isZeroIndex(nt))
+ if (isZeroInteger(nt))
continue;
nonZeroNumThreads.push_back(nt);
}
@@ -551,7 +551,7 @@ static LogicalResult generateLoopNest(
YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
// If the tile sizes are all zero, no loops are generated. Just call the
// callback function to handle untiled case.
- if (llvm::all_of(tileSizes, isZeroIndex)) {
+ if (llvm::all_of(tileSizes, isZeroInteger)) {
SmallVector<Value> tiledResults;
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
@@ -999,7 +999,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 5b. Early return cloned op if tiling is not happening. We can not
// return the original op because it could lead to `rewriter.replaceOp(op,
// op->getResults())` and users would get crash.
- if (llvm::all_of(tileSizes, isZeroIndex)) {
+ if (llvm::all_of(tileSizes, isZeroInteger)) {
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
tilingResult =
TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
@@ -1290,7 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
sliceSizes = sliceOp.getMixedSizes();
// expect all strides of sliceOp being 1
- if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
return failure();
unsigned sliceResultNumber =
@@ -2112,7 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
// 9. Check all insert stride is 1.
- if (!llvm::all_of(strides, isOneIndex)) {
+ if (!llvm::all_of(strides, isOneInteger)) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "containingOp's result yield with stride");
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index f95e38fc75c8d..8ab5bdc0c5dc5 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
// If an `affine.apply` operation is generated for denormalization, the use
// of `origLb` in those ops must not be replaced. These arent not generated
// when `origLb == 0` and `origStep == 1`.
- if (!isZeroIndex(origLb) || !isOneIndex(origStep)) {
+ if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
preservedUses.insert(preservedUse);
}
@@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
}
Value denormalizedIv;
SmallPtrSet<Operation *, 2> preserve;
- bool isStepOne = isOneIndex(origStep);
- bool isZeroBased = isZeroIndex(origLb);
+ bool isStepOne = isOneInteger(origStep);
+ bool isZeroBased = isZeroInteger(origLb);
Value scaled = normalizedIv;
if (!isStepOne) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 649375b4c4037..3d963dea2f572 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
// Check for single block, unit-stride for-loop that is generated by
// sparsifier, which means no data dependence analysis is required,
// and its loop-body is very restricted in form.
- if (!op.getRegion().hasOneBlock() || !isOneIndex(op.getStep()) ||
+ if (!op.getRegion().hasOneBlock() || !isOneInteger(op.getStep()) ||
!op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
return failure();
// Analyze (!codegen) and rewrite (codegen) loop-body.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6c17ebbb85c81..8db563fb7a25f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2840,7 +2840,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
return getResult();
if (auto result = foldInsertAfterExtractSlice(*this))
return result;
- if (llvm::any_of(getMixedSizes(), isZeroIndex))
+ if (llvm::any_of(getMixedSizes(), isZeroInteger))
return getDest();
return OpFoldResult();
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 41407064cb6d7..92540bd56ecbc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -135,9 +135,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
for (unsigned dim = 0; dim < rank; ++dim) {
auto low = padOp.getMixedLowPad()[dim];
- bool hasLowPad = !isZeroIndex(low);
+ bool hasLowPad = !isZeroInteger(low);
auto high = padOp.getMixedHighPad()[dim];
- bool hasHighPad = !isZeroIndex(high);
+ bool hasHighPad = !isZeroInteger(high);
auto offset = offsets[dim];
auto length = sizes[dim];
// If the dim has no padding, we dont need to calculate new values for that
@@ -208,7 +208,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
// Check if newLength is zero. In that case, no SubTensorOp should be
// executed.
- if (isZeroIndex(newLength)) {
+ if (isZeroInteger(newLength)) {
hasZeroLen = true;
} else if (!hasZeroLen) {
Value check = b.create<arith::CmpIOp>(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c0e697292d2a0..81a2480940742 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -646,7 +646,7 @@ static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
// Dest is not read if it is entirely overwritten. E.g.:
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
bool allOffsetsZero =
- llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
+ llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
RankedTensorType destType = insertSliceOp.getDestType();
bool sizesMatchDestSizes =
areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 9978aac1ee80e..2b229d60c691b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
isZeroOffsetAndFullSize =
[](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
- if (!isZeroIndex(offset))
+ if (!isZeroInteger(offset))
return false;
FailureOr<bool> maybeEqual =
ValueBoundsConstraintSet::areEqual(sliceSize, size);
@@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
// Find the first expanded dim after the first dim with non-unit extracted
// size.
for (; i < e; ++i) {
- if (!isOneIndex(sizes[indices[i]])) {
+ if (!isOneInteger(sizes[indices[i]])) {
// +1 to skip the first non-unit size dim.
i++;
break;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 36cc31e614f21..6f33f9b55ceb6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -27,7 +27,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
return failure();
// `TilingInterface` currently only supports strides being 1.
- if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
return failure();
FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
@@ -47,7 +47,7 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
return failure();
// `TilingInterface` currently only supports strides being 1.
- if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
return failure();
FailureOr<TilingResult> tiledResult =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3ad3a43fbed0e..29f7bd6857c27 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -15,9 +15,9 @@
namespace mlir {
-bool isZeroIndex(OpFoldResult v) { return isConstantIntValue(v, 0); }
+bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(v, 0); }
-bool isOneIndex(OpFoldResult v) { return isConstantIntValue(v, 1); }
+bool isOneInteger(OpFoldResult v) { return isConstantIntValue(v, 1); }
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
SmallVector<OpFoldResult>>
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index dc87424df3854..678a88627ca82 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -116,7 +116,7 @@ static bool stridesAllOne(TOp op) {
std::is_same_v<TOp, vector::InsertStridedSliceOp>,
"expected vector.extract_strided_slice or vector.insert_strided_slice");
ArrayAttr strides = op.getStrides();
- return llvm::all_of(strides, isOneIndex);
+ return llvm::all_of(strides, isOneInteger);
}
/// Convert an array of attributes into a vector of integers, if possible.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d4d07c7eadc77..7dbb7a334fe62 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -538,7 +538,7 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
indices.begin(), indices.begin() + firstDimToCollapse);
SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
indices.end());
- if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
+ if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
indicesAfterCollapsing.push_back(indicesToCollapse[0]);
return indicesAfterCollapsing;
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ca15d410efc7a..71c557f7eda06 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
ArithIndexingBuilder idxBuilderf(rewriter, loc);
for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
OpFoldResult pos = extractPos[i - rankOffset];
- if (isZeroIndex(pos))
+ if (isZeroInteger(pos))
continue;
Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
More information about the Mlir-commits
mailing list