[llvm-branch-commits] [mlir] users/banach space/vector/update create write (PR #141567)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue May 27 02:12:27 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Andrzej Warzyński (banach-space)
Changes:
- **[mlir][linalg] Refactor vectorization hooks to improve code reuse**
- **[mlir][linalg] Simplify `createWriteOrMaskedWrite` (NFC)**
---
Full diff: https://github.com/llvm/llvm-project/pull/141567.diff
1 file affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+40-78)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0113ba86a5ae3..2abb2f0ea467c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
/// Creates an optionally masked TransferWriteOp
///
/// Generates the following operation:
-/// %res = vector.transfer_write %vectorToStore into %dest
+/// %res = vector.transfer_write %vecToStore into %dest
///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
-/// %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
/// %res = vector.mask %mask {
-/// vector.transfer_write %vectorToStore into %dest
+/// vector.transfer_write %vecToStore into %dest
/// }
///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
-/// %write = vector.transfer_write %vectorToStore into %dest
+/// %write = vector.transfer_write %vecToStore into %dest
/// in_bounds_flags = (...)
/// %res = vector.transfer_write %input into %dest
/// {in_bounds = in_bounds_flags}
///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
- Value dest,
- ArrayRef<int64_t> inputVecSizesForLeadingDims,
- SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+ Value dest, SmallVector<Value> writeIndices = {},
bool useInBoundsInsteadOfMasking = false) {
ShapedType destType = cast<ShapedType>(dest.getType());
int64_t destRank = destType.getRank();
auto destShape = destType.getShape();
- VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+ VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
int64_t vecToStoreRank = vecToStoreType.getRank();
auto vecToStoreShape = vecToStoreType.getShape();
// Compute the in_bounds attribute
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
if (useInBoundsInsteadOfMasking) {
- // In this case, assume that all the required vector sizes have been
- // provided.
- assert(inputVecSizesForLeadingDims.size() ==
- static_cast<size_t>(vecToStoreType.getRank()) &&
- "Insufficient number of input vector sizes!");
- // Update the inBounds attribute.
for (unsigned i = 0; i < destRank; i++)
- inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+ inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
!ShapedType::isDynamic(destShape[i]);
}
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
// Generate the xfer_write Op
Operation *write =
builder.create<vector::TransferWriteOp>(loc,
- /*vector=*/vectorToStore,
+ /*vector=*/vecToStore,
/*source=*/dest,
/*indices=*/writeIndices,
/*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
if (useInBoundsInsteadOfMasking)
return write;
- assert(llvm::none_of(
- destShape.drop_front(inputVecSizesForLeadingDims.size()),
- [](int64_t size) { return size == ShapedType::kDynamic; }) &&
- "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
- // Check if masking is needed.
- bool needMaskForWrite =
- !llvm::equal(inputVecSizesForLeadingDims,
- destShape.take_front(destRank - vecToStoreRank +
- inputVecSizesForLeadingDims.size()));
-
- // If masking is needed, generate the mask and mask the operation.
- if (needMaskForWrite) {
- // Get the mask shape + type. Missing mask dimensions are taken from
- // `vectorToStore`.
- SmallVector<int64_t> writeMaskShape;
- writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
- inputVecSizesForLeadingDims.end());
- if (vecToStoreRank >
- static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
- writeMaskShape.append(vecToStoreShape.begin() +
- inputVecSizesForLeadingDims.size(),
- vecToStoreShape.end());
- auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
- SmallVector<OpFoldResult> destSizes =
- tensor::getMixedSizes(builder, loc, dest);
- SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
- destSizes.end());
-
- if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
- writeMaskShape))
- return write;
-
- Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
- loc, writeMaskType, maskSizes);
- write = mlir::vector::maskOperation(builder, write, maskForWrite);
- }
+ // Check if masking is needed. If not, exit.
+ if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+ return write;
+
+ // Compute the mask and mask the write Op.
+ auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+ SmallVector<OpFoldResult> destSizes =
+ tensor::getMixedSizes(builder, loc, dest);
+ SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+ destSizes.end());
+
+ if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+ vecToStoreShape))
+ return write;
- return write;
+ Value maskForWrite =
+ builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+ return mlir::vector::maskOperation(builder, write, maskForWrite);
}
/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0],
transposeOp.getResult().getType().getElementType());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, transposeOp.getResult(), dest,
- /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
- /*useInBoundsInsteadOfMasking=*/false);
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+ /*writeIndices=*/{},
+ /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
shapeCastOp.getResult().getType().getElementType());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), dest,
- /*inputVecSizesForLeadingDims=*/writeVectorSizes,
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
// Create Xfer write Op
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, maskedRead, dest,
- /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
- /*useInBoundsInsteadOfMasking=*/false);
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+ /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
// Create write
auto writeIndices =
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+ Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+ sliceOp.getDest(), writeIndices);
// 4. Finalize
newResults.push_back(write->getResult(0));
``````````
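For readers skimming the diff: the refactored `createWriteOrMaskedWrite` now derives the mask shape directly from the vector being stored, rather than from a separate `inputVecSizesForLeadingDims` argument. A minimal sketch of the IR the masked path produces, assuming a hypothetical `%vec : vector<4x8xf32>` written at offset 0 into a dynamically shaped `%dest : tensor<?x?xf32>` (all value names are illustrative, not taken from the patch):

```mlir
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// Mask sizes come from the destination's runtime extents...
%d0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
// ...while the mask shape matches the vector being stored (i1 elements).
%mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
%res = vector.mask %mask {
  vector.transfer_write %vec, %dest[%c0, %c0]
      : vector<4x8xf32>, tensor<?x?xf32>
} : vector<4x8xi1> -> tensor<?x?xf32>
```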
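When `useInBoundsInsteadOfMasking` is set, the helper instead emits a plain `vector.transfer_write` whose per-dimension `in_bounds` flags record whether the (static) destination extent matches the vector shape. Again a hedged sketch with illustrative names:

```mlir
// Dim 0 statically matches the vector size (in bounds); dim 1 is dynamic,
// so it is conservatively treated as potentially out of bounds.
%res = vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, false]}
    : vector<4x8xf32>, tensor<4x?xf32>
```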
https://github.com/llvm/llvm-project/pull/141567
More information about the llvm-branch-commits mailing list