[Mlir-commits] [mlir] [mlir][linalg] Add support for masked vectorization of `tensor.insert_slice` (1/N) (PR #122927)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Jan 22 08:36:57 PST 2025
================
@@ -2583,113 +2626,139 @@ static Value getStaticPadVal(Operation *op) {
return {};
}
-/// Rewrite tensor.insert.slice as a vector.transfer_read +
-/// vector.transfer_write pair. The vector size is inferred from the static
-/// dims in the input and output tensors. If a dim is dynamic in both the input
-/// and output tensors, bails out.
-///
-/// Before:
-/// !t_in_type = tensor<1x2x3xf32>
-/// !t_out_type = tensor<9x8x7x1x2x3xf32>
-/// !v_type = vector<1x2x3xf32>
-/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
-/// into !t_out_type
-/// After:
-/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
-/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
-///
-/// TODO: Support masking
-struct InsertSliceVectorizePattern
- : public OpRewritePattern<tensor::InsertSliceOp> {
- using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+ ArrayRef<int64_t> inputVectorSizes,
+ SmallVectorImpl<Value> &newResults) {
+ // TODO: Introduce a parent class that will handle the insertion point update.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(sliceOp);
- LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
- PatternRewriter &rewriter) const final {
- auto sourceType = sliceOp.getSource().getType();
- if (!VectorType::isValidElementType(sourceType.getElementType()))
- return failure();
+ TypedValue<RankedTensorType> source = sliceOp.getSource();
+ auto sourceType = source.getType();
+ if (!VectorType::isValidElementType(sourceType.getElementType()))
+ return failure();
- auto resultType = sliceOp.getResultType();
-
- // 1. Get the pad value.
- // TransferReadOp requires a scalar padding value. Note that:
- // * for in-bounds access, the value is actually irrelevant.
- // There are 2 cases in which xfer.read accesses are known to be in-bounds:
- // 1. The source shape is static (output vector sizes would be based on
- // the source shape and hence all memory accesses would be in-bounds),
- // 2. Masking is used (output vector sizes would be user-provided, in which
- // case it is assumed that all memory accesses are in-bounds). This
- // remains a TODO.
- //
- // When the value is not known and not needed, use 0. Otherwise, bail out.
- Value padValue = getStaticPadVal(sliceOp);
- bool isOutOfBoundsRead = !sourceType.hasStaticShape();
-
- if (!padValue && isOutOfBoundsRead) {
- LDBG("Failed to get a pad value for out-of-bounds read access\n");
+ auto resultType = sliceOp.getResultType();
+
+ // 1. Get the pad value.
+ // TransferReadOp requires a scalar padding value. Note that:
+ // * for in-bounds access, the value is actually irrelevant.
+ // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+ // 1. The source shape is static (output vector sizes would be based on
+ // the source shape and hence all memory accesses would be in-bounds),
+ // 2. Masking is used (output vector sizes would be user-provided, in which
+ // case it is assumed that all memory accesses are in-bounds). This
+ // remains a TODO.
+ //
+ // When the value is not known and not needed, use 0. Otherwise, bail out.
+ Value padValue = getStaticPadVal(sliceOp);
+ bool isOutOfBoundsRead =
+ !sourceType.hasStaticShape() && inputVectorSizes.empty();
+
+ if (!padValue && isOutOfBoundsRead) {
+ LDBG("Failed to get a pad value for out-of-bounds read access\n");
+ return failure();
+ }
+
+ if (!padValue) {
+ auto elemType = sourceType.getElementType();
+ padValue = rewriter.create<arith::ConstantOp>(
+ sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+ }
+
+ // 2. Get the vector shape and in-bounds attributes
+ SmallVector<int64_t> vecShape;
+ SmallVector<bool> readInBounds;
+ SmallVector<bool> writeInBounds;
+ size_t rankDiff = resultType.getRank() - sourceType.getRank();
+ for (unsigned i = 0; i < sourceType.getRank(); ++i) {
----------------
banach-space wrote:
Done.
https://github.com/llvm/llvm-project/pull/122927
More information about the Mlir-commits
mailing list