[Mlir-commits] [mlir] [mlir][linalg] Simplify `createWriteOrMaskedWrite` (NFC) (PR #141567)
Andrzej Warzyński
llvmlistbot at llvm.org
Sat Jun 7 12:33:51 PDT 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/141567
From edcc6048fecedd5c9cc6361ae53133a39e82f09a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 27 May 2025 09:58:34 +0100
Subject: [PATCH] [mlir][linalg] Simplify `createWriteOrMaskedWrite` (NFC)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This patch removes `inputVecSizesForLeadingDims` from the parameter list
of `createWriteOrMaskedWrite`. That argument is redundant: the vector
sizes are already available through the `vecToStore` parameter. Since this
changes neither behavior nor test results, the patch is marked as NFC.
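For illustration, a minimal sketch of how the helper now recovers those
sizes internally (mirroring the corresponding lines in the diff below):
```cpp
// The vector sizes travel with the SSA value itself, so passing them
// in separately was redundant.
VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
ArrayRef<int64_t> vecToStoreShape = vecToStoreType.getShape();
int64_t vecToStoreRank = vecToStoreType.getRank();
```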
Additional cleanups:
* Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
* Rewrote a conditional at the end of the function to use early exit,
improving readability:
```cpp
// BEFORE:
if (maskingRequired) {
  Value maskForWrite = ...;
  write = vector::maskOperation(builder, write, maskForWrite);
}
return write;

// AFTER:
if (!maskingRequired)
  return write;

Value maskForWrite = ...;
return vector::maskOperation(builder, write, maskForWrite);
```
This change addresses a TODO from #141244.
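For reference, the new masking path condenses to the following (taken from
the diff below, with the `isMaskTriviallyFoldable` early exit elided;
`builder`, `loc`, `write`, `dest` and `vecToStore` are as in the patch):
```cpp
// Masking is only needed when the vector shape disagrees with the
// trailing dims of the destination.
if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
  return write;

// Build an i1 mask shaped like the vector, sized by the destination.
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
SmallVector<OpFoldResult> destSizes =
    tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                    destSizes.end());
Value maskForWrite =
    builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
return mlir::vector::maskOperation(builder, write, maskForWrite);
```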
---
.../Linalg/Transforms/Vectorization.cpp | 114 ++++++------------
1 file changed, 38 insertions(+), 76 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index afae84ea4045f..ee61a04dd3d7c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1606,63 +1606,49 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
/// Creates an optionally masked TransferWriteOp
///
/// Generates the following operation:
-/// %res = vector.transfer_write %vectorToStore into %dest
+/// %res = vector.transfer_write %vecToStore into %dest
///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
-/// %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
/// %res = vector.mask %mask {
-/// vector.transfer_write %vectorToStore into %dest
+/// vector.transfer_write %vecToStore into %dest
/// }
///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
-/// %write = vector.transfer_write %vectorToStore into %dest
+/// %write = vector.transfer_write %vecToStore into %dest
/// in_bounds_flags = (...)
/// %res = vector.transfer_write %input into %dest
/// {in_bounds = in_bounds_flags}
///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
- Value dest,
- ArrayRef<int64_t> inputVecSizesForLeadingDims,
- SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+ Value dest, SmallVector<Value> writeIndices = {},
bool useInBoundsInsteadOfMasking = false) {
ShapedType destType = cast<ShapedType>(dest.getType());
int64_t destRank = destType.getRank();
auto destShape = destType.getShape();
- VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+ VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
int64_t vecToStoreRank = vecToStoreType.getRank();
auto vecToStoreShape = vecToStoreType.getShape();
// Compute the in_bounds attribute
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
if (useInBoundsInsteadOfMasking) {
- // In this case, assume that all the required vector sizes have been
- // provided.
- assert(inputVecSizesForLeadingDims.size() ==
- static_cast<size_t>(vecToStoreType.getRank()) &&
- "Insufficient number of input vector sizes!");
// Update the inBounds attribute.
// FIXME: This computation is too weak - it ignores the write indices.
for (unsigned i = 0; i < vecToStoreRank; i++)
inBoundsVal[i] =
- (destShape[i] >= inputVecSizesForLeadingDims[i]) &&
+ (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
!ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
}
@@ -1678,7 +1664,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
// Generate the xfer_write Op
Operation *write =
builder.create<vector::TransferWriteOp>(loc,
- /*vector=*/vectorToStore,
+ /*vector=*/vecToStore,
/*source=*/dest,
/*indices=*/writeIndices,
/*inBounds=*/inBoundsVal);
@@ -1687,46 +1673,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
if (useInBoundsInsteadOfMasking)
return write;
- assert(llvm::none_of(
- destShape.drop_front(inputVecSizesForLeadingDims.size()),
- [](int64_t size) { return size == ShapedType::kDynamic; }) &&
- "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
- // Check if masking is needed.
- bool needMaskForWrite =
- !llvm::equal(inputVecSizesForLeadingDims,
- destShape.take_front(destRank - vecToStoreRank +
- inputVecSizesForLeadingDims.size()));
-
- // If masking is needed, generate the mask and mask the operation.
- if (needMaskForWrite) {
- // Get the mask shape + type. Missing mask dimensions are taken from
- // `vectorToStore`.
- SmallVector<int64_t> writeMaskShape;
- writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
- inputVecSizesForLeadingDims.end());
- if (vecToStoreRank >
- static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
- writeMaskShape.append(vecToStoreShape.begin() +
- inputVecSizesForLeadingDims.size(),
- vecToStoreShape.end());
- auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
- SmallVector<OpFoldResult> destSizes =
- tensor::getMixedSizes(builder, loc, dest);
- SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
- destSizes.end());
-
- if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
- writeMaskShape))
- return write;
-
- Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
- loc, writeMaskType, maskSizes);
- write = mlir::vector::maskOperation(builder, write, maskForWrite);
- }
+ // Check if masking is needed. If not, exit.
+ if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+ return write;
+
+ // Compute the mask and mask the write Op.
+ auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+ SmallVector<OpFoldResult> destSizes =
+ tensor::getMixedSizes(builder, loc, dest);
+ SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+ destSizes.end());
+
+ if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+ vecToStoreShape))
+ return write;
- return write;
+ Value maskForWrite =
+ builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+ return mlir::vector::maskOperation(builder, write, maskForWrite);
}
/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1826,9 +1791,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0],
transposeOp.getResult().getType().getElementType());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, transposeOp.getResult(), dest,
- /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1966,7 +1930,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
shapeCastOp.getResult().getType().getElementType());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), dest,
- /*inputVecSizesForLeadingDims=*/writeVectorSizes,
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
@@ -1999,9 +1962,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
// Create Xfer write Op
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, maskedRead, dest,
- /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
newResults.push_back(write->getResult(0));
return success();
}
@@ -3043,9 +3005,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
// Create write
auto writeIndices =
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
- /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
+ writeIndices, inputVectorSizes.empty());
// 4. Finalize
newResults.push_back(write->getResult(0));