[Mlir-commits] [mlir] [mlir][linalg] Simplify `createWriteOrMaskedWrite` (NFC) (PR #141567)

Andrzej Warzyński llvmlistbot at llvm.org
Sat Jun 7 12:33:51 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/141567

>From edcc6048fecedd5c9cc6361ae53133a39e82f09a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 27 May 2025 09:58:34 +0100
Subject: [PATCH] [mlir][linalg] Simplify `createWriteOrMaskedWrite` (NFC)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch removes `inputVecSizesForLeadingDims` from the parameter list
of `createWriteOrMaskedWrite`. That argument is unnecessary — vector sizes
can be obtained from the `vecToStore` parameter. Since this doesn't change
behavior or test results, it's marked as NFC.

Additional cleanups:
  * Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
  * Rewrote a conditional at the end of the function to use early exit,
    improving readability:

```cpp
  // BEFORE:
  if (maskingRequired) {
    Value maskForWrite = ...;
    write = vector::maskOperation(builder, write, maskForWrite);
  }
  return write;

  // AFTER:
  if (!maskingRequired)
    return write;

  Value maskForWrite = ...;
  return vector::maskOperation(builder, write, maskForWrite);
```

This change addresses a TODO from #141244.
---
 .../Linalg/Transforms/Vectorization.cpp       | 114 ++++++------------
 1 file changed, 38 insertions(+), 76 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index afae84ea4045f..ee61a04dd3d7c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1606,63 +1606,49 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     // FIXME: This computation is too weak - it ignores the write indices.
     for (unsigned i = 0; i < vecToStoreRank; i++)
       inBoundsVal[i] =
-          (destShape[i] >= inputVecSizesForLeadingDims[i]) &&
+          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
           !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
   }
 
@@ -1678,7 +1664,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                              /*vector=*/vectorToStore,
+                                              /*vector=*/vecToStore,
                                               /*source=*/dest,
                                               /*indices=*/writeIndices,
                                               /*inBounds=*/inBoundsVal);
@@ -1687,46 +1673,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1826,9 +1791,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1966,7 +1930,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1999,9 +1962,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3043,9 +3005,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
+                               writeIndices, inputVectorSizes.empty());
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));



More information about the Mlir-commits mailing list