[Mlir-commits] [mlir] 3529ce0 - [mlir][Vector] Make createWriteOrMaskedWrite utility (#190967)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 9 03:52:09 PDT 2026


Author: Lukas Sommer
Date: 2026-04-09T12:52:04+02:00
New Revision: 3529ce05e9a96760ae3b9435ca264cc707bb4dac

URL: https://github.com/llvm/llvm-project/commit/3529ce05e9a96760ae3b9435ca264cc707bb4dac
DIFF: https://github.com/llvm/llvm-project/commit/3529ce05e9a96760ae3b9435ca264cc707bb4dac.diff

LOG: [mlir][Vector] Make createWriteOrMaskedWrite utility (#190967)

Analogous to https://github.com/llvm/llvm-project/pull/89119, make
`createWriteOrMaskedWrite` a vector utility, exposing it for reuse by
downstream users.

This PR mostly moves code and updates documentation, but it also
addresses the `TODO` to reuse `isMaskTriviallyFoldable` in
`createReadOrMaskedRead`.

No new tests were added because the functionality is covered by existing tests.

---------

Signed-off-by: Lukas Sommer <lukas.sommer at amd.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 45626aa280946..773b27bc6bfff 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -236,6 +236,24 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              bool useInBoundsInsteadOfMasking = false,
                              ArrayRef<bool> inputScalableVecDims = {});
 
+/// Create a TransferWriteOp of `vecToStore` into `dest`.
+///
+/// If the shape of the vector to write differs from the destination shape,
+/// masking is used to avoid out-of-bounds accesses. Set
+/// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
+/// instead of explicit masks; note that the "in_bounds" flags are derived
+/// from the static shapes only and do not take `writeIndices` into account.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0; otherwise it must contain one index per dimension of `dest`.
+/// `dest` can be either a tensor or a memref.
+///
+/// Returns the created `vector.transfer_write` or, when a mask is required,
+/// the operation that masks it (see `maskOperation`).
+Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+                                    Value vecToStore, Value dest,
+                                    SmallVector<Value> writeIndices = {},
+                                    bool useInBoundsInsteadOfMasking = false);
+
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
 ///   1. The numbers of elements in both array are equal.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a7acdd2c018ea..b57e66a1c3580 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1574,211 +1574,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
-/// Determines whether a mask for xfer_write is trivially "all true"
-///
-/// Given all the inputs required to generate a mask (mask sizes and shapes),
-/// and an xfer_write operation (write indices and the destination tensor
-/// shape), determines whether the corresponding mask would be trivially
-/// foldable (i.e., trivially "all true").
-///
-/// Use this method to avoid generating spurious masks and relaying on
-/// vectorization post-processing to remove them.
-///
-/// Pre-conditions for a mask to be trivially foldable:
-///   * All involved shapes (mask + destination tensor) are static.
-///   * All write indices are constant.
-///   * All mask sizes are constant (including `arith.constant`).
-///
-/// If the pre-conditions are met, the method checks for each destination
-/// dimension `d`:
-///   (1) destDimSize[rankDiff + d] <= maskShape[d]
-///   (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
-///
-/// rankDiff = rank(dest) - rank(mask).
-///
-/// This method takes a conservative view: it may return false even if the mask
-/// is technically foldable.
-///
-/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
-/// of the dest tensor):
-///   %c0 = arith.constant 0 : index
-///   %mask = vector.create_mask 5, 1
-///   vector.mask %mask {
-///     vector.transfer_write %vecToStore_1, %dest{[%c0, %c0]
-///       {in_bounds = [true, true]}
-///     : vector<5x1xi32>, tensor<5x1xi32>
-///   }
-///
-/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
-/// mask is required to avoid out-of-bounds write):
-///   %c0 = arith.constant 0 : index
-///   %mask = vector.create_mask 5, 1
-///   vector.mask %mask {
-///     vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
-///      {in_bounds = [true, true]}
-///     : vector<8x1xi32>, tensor<5x1xi32>
-///   }
-///
-/// TODO: Re-use in createReadOrMaskedRead
-static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
-                                    SmallVector<Value> &writeIdxs,
-                                    ArrayRef<int64_t> destShape,
-                                    ArrayRef<int64_t> maskShape) {
-  // Masking is unavoidable in the case of dynamic tensors.
-  if (ShapedType::isDynamicShape(destShape))
-    return false;
-
-  // Collect all constant mask sizes.
-  SmallVector<int64_t, 4> cstMaskSizes;
-  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
-    if (auto intSize = getConstantIntValue(dimSize)) {
-      cstMaskSizes.push_back(*intSize);
-    }
-  }
-
-  // If any of the mask sizes is non-constant, bail out.
-  if (cstMaskSizes.size() != maskShape.size())
-    return false;
-
-  // Collect all constant write indices.
-  SmallVector<int64_t, 4> cstWriteIdxs;
-  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
-    APSInt intVal;
-    if (matchPattern(idx, m_ConstantInt(&intVal))) {
-      cstWriteIdxs.push_back(intVal.getSExtValue());
-    }
-  }
-
-  // If any of the write indices is non-constant, bail out.
-  if (cstWriteIdxs.size() != destShape.size())
-    return false;
-
-  // Go over all destination dims and check (1) and (2). Take into account that:
-  //  * The number of mask sizes will match the rank of the vector to store.
-  //    This could be lower than the rank of the destination tensor.
-  //  * Mask sizes could be larger than the corresponding mask shape (hence
-  //  `clamp`).
-  // TODO: The 2nd item should be rejected by the verifier.
-  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
-  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
-    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
-        /*(2)*/ destShape[rankDiff + i] <
-            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
-             cstWriteIdxs[i]))
-      return false;
-  }
-
-  return true;
-}
-
-/// Creates an optionally masked TransferWriteOp
-///
-/// Generates the following operation:
-///   %res = vector.transfer_write %vecToStore into %dest
-///
-/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
-///
-///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
-///   %res = vector.mask %mask {
-///     vector.transfer_write %vecToStore into %dest
-///   }
-///
-/// The mask shape is identical to `vecToStore` (with the element type ==
-/// i1), and the mask values are based on the shape of the `dest` tensor.
-///
-/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
-/// is used instead of masking:
-///
-///   %write = vector.transfer_write %vecToStore into %dest
-///   in_bounds_flags = (...)
-///   %res = vector.transfer_write %input into %dest
-///       {in_bounds = in_bounds_flags}
-///
-/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
-/// are set to 0.
-static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
-                         Value dest, SmallVector<Value> writeIndices = {},
-                         bool useInBoundsInsteadOfMasking = false) {
-
-  ShapedType destType = cast<ShapedType>(dest.getType());
-  int64_t destRank = destType.getRank();
-  auto destShape = destType.getShape();
-
-  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
-  int64_t vecToStoreRank = vecToStoreType.getRank();
-  auto vecToStoreShape = vecToStoreType.getShape();
-
-  // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
-  if (useInBoundsInsteadOfMasking) {
-    // Update the inBounds attribute.
-    // FIXME: This computation is too weak - it ignores the write indices.
-    for (unsigned i = 0; i < vecToStoreRank; i++)
-      inBoundsVal[i] =
-          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
-          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
-  }
-
-  // If missing, initialize the write indices to 0.
-  bool useDefaultWriteIdxs = writeIndices.empty();
-  assert((useDefaultWriteIdxs ||
-          writeIndices.size() == static_cast<size_t>(destRank)) &&
-         "Invalid number of write indices!");
-  if (writeIndices.empty()) {
-    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
-    writeIndices.assign(destRank, zero);
-  }
-
-  // Generate the xfer_write Op
-  Operation *write = vector::TransferWriteOp::create(builder, loc,
-                                                     /*vector=*/vecToStore,
-                                                     /*source=*/dest,
-                                                     /*indices=*/writeIndices,
-                                                     /*inBounds=*/inBoundsVal);
-
-  // If masking is disabled, exit.
-  if (useInBoundsInsteadOfMasking)
-    return write;
-
-  // Check if masking is needed. If not, exit.
-  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
-    return write;
-
-  // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
-                                       vecToStoreType.getScalableDims());
-
-  SmallVector<OpFoldResult> destSizes =
-      isa<MemRefType>(dest.getType())
-          ? memref::getMixedSizes(builder, loc, dest)
-          : tensor::getMixedSizes(builder, loc, dest);
-
-  // Compute sizes for write-mask
-  SmallVector<OpFoldResult> maskSizes;
-  if (useDefaultWriteIdxs) {
-    maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
-                                          destSizes.end());
-  } else {
-    size_t diff = destShape.size() - vecToStoreRank;
-    for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
-      auto value =
-          getValueOrCreateConstantIndexOp(builder, loc, destSizes[diff + idx]);
-      auto neg =
-          builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
-      maskSizes.push_back(OpFoldResult(neg));
-    }
-  }
-
-  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                              vecToStoreShape))
-    return write;
-
-  Value maskForWrite =
-      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
-  return mlir::vector::maskOperation(builder, write, maskForWrite);
-}
-
 /// Given the re-associations, "collapses" the input Vector type
 ///
 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1929,7 +1724,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
       rewriter, loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
+  Operation *write = vector::createWriteOrMaskedWrite(
       rewriter, loc, transposeOp.getResult(), packOp.getDest());
   newResults.push_back(write->getResult(0));
   return success();
@@ -2025,7 +1820,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       rewriter, loc, collapsedVecType, transposeOp->getResult(0));
 
   // -- Generate the write operation --
-  Operation *write = createWriteOrMaskedWrite(
+  Operation *write = vector::createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
 
@@ -2061,7 +1856,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
                                        padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
+  Operation *write =
+      vector::createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2279,7 +2075,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
   contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
 
   // Store result.
-  Operation *write = createWriteOrMaskedWrite(
+  Operation *write = vector::createWriteOrMaskedWrite(
       rewriter, loc, contractOp->getResult(0), outOperand->get());
 
   // Finalize.
@@ -3207,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
   Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
-                               writeIndices, inputVectorSizes.empty());
+      vector::createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
+                                       writeIndices, inputVectorSizes.empty());
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));

diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 9585f5a1d774a..576023dbc9de1 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -22,6 +23,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
@@ -309,6 +311,104 @@ bool vector::isLinearizableVector(VectorType type) {
   return (type.getRank() > 1) && (type.getNumScalableDims() <= 1);
 }
 
+/// Determines whether a mask for xfer_read/write is trivially "all true" and
+/// can therefore be dropped.
+///
+/// Given all the inputs required to generate a mask (mask sizes and shapes),
+/// and an xfer_read/write operation (indices and the source/destination shape,
+/// called the "base" below), determines whether the corresponding mask would
+/// be trivially foldable (i.e., trivially "all true") and whether the access
+/// would still be in bounds without it.
+///
+/// Use this method to avoid generating spurious masks and relying on
+/// vectorization post-processing to remove them.
+///
+/// Pre-conditions for a mask to be trivially foldable:
+///   * All involved shapes (mask + base) are static.
+///   * The indices of all masked (trailing) dimensions are constant.
+///   * All mask sizes are constant (including `arith.constant`).
+///
+/// If the pre-conditions are met, the method checks for each mask dimension
+/// `d`:
+///   (1) 0 <= index[rankDiff + d] and
+///       index[rankDiff + d] + maskShape[d] <= baseDimSize[rankDiff + d]
+///   (2) maskShape[d] <= maskSize[d]
+///
+/// rankDiff = rank(base) - rank(mask), i.e. the mask covers the trailing
+/// dimensions of the base. (1) ensures that the access stays in bounds without
+/// the mask and (2) ensures that the mask is "all true".
+///
+/// `maskShape` is taken to be the exact vector shape: scalable dimensions are
+/// not modelled, so callers must not use this check for scalable vectors.
+///
+/// This method takes a conservative view: it may return false even if the mask
+/// is technically foldable.
+///
+/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
+/// of the dest tensor):
+///   %c0 = arith.constant 0 : index
+///   %mask = vector.create_mask 5, 1
+///   vector.mask %mask {
+///     vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
+///       {in_bounds = [true, true]}
+///     : vector<5x1xi32>, tensor<5x1xi32>
+///   }
+///
+/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
+/// mask is required to avoid out-of-bounds write):
+///   %c0 = arith.constant 0 : index
+///   %mask = vector.create_mask 5, 1
+///   vector.mask %mask {
+///     vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
+///      {in_bounds = [true, true]}
+///     : vector<8x1xi32>, tensor<5x1xi32>
+///   }
+///
+/// EXAMPLE 3 (not trivially foldable - the vector fits into the tensor, but not
+/// at the given offset, so a mask is required to avoid out-of-bounds write):
+///   %c5 = arith.constant 5 : index
+///   %mask = vector.create_mask 5
+///   vector.mask %mask {
+///     vector.transfer_write %vecToStore_3, %dest[%c5] {in_bounds = [true]}
+///     : vector<8xi32>, tensor<10xi32>
+///   }
+static bool isMaskTriviallyFoldable(ArrayRef<OpFoldResult> maskSizes,
+                                    ArrayRef<Value> indices,
+                                    ArrayRef<int64_t> baseShape,
+                                    ArrayRef<int64_t> maskShape) {
+  // Masking is unavoidable in the case of dynamic shapes.
+  if (ShapedType::isDynamicShape(baseShape))
+    return false;
+
+  // Expect one mask size per mask dim, one index per base dim and a mask
+  // rank that does not exceed the base rank.
+  if (maskShape.size() > baseShape.size() ||
+      maskSizes.size() != maskShape.size() ||
+      indices.size() != baseShape.size())
+    return false;
+
+  int64_t rankDiff = baseShape.size() - maskShape.size();
+  for (auto [d, maskSize] : llvm::enumerate(maskSizes)) {
+    // If the mask size or the matching index is non-constant, bail out.
+    std::optional<int64_t> cstMaskSize = getConstantIntValue(maskSize);
+    APSInt cstIndex;
+    if (!cstMaskSize ||
+        !matchPattern(indices[rankDiff + d], m_ConstantInt(&cstIndex)))
+      return false;
+
+    int64_t index = cstIndex.getSExtValue();
+    int64_t baseDimSize = baseShape[rankDiff + d];
+    // (1) Without the mask, the access must stay in bounds.
+    if (index < 0 || index + maskShape[d] > baseDimSize)
+      return false;
+    // (2) The mask must be "all true".
+    if (*cstMaskSize < maskShape[d])
+      return false;
+  }
+
+  return true;
+}
+
 Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
@@ -353,22 +450,30 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
       inBoundsVal[i] = (sourceShape[i] == vecToReadShape[i]) &&
                        ShapedType::isStatic(sourceShape[i]);
   }
-  auto transferReadOp = vector::TransferReadOp::create(
-      builder, loc,
-      /*vectorType=*/vecToReadTy,
-      /*source=*/source,
-      /*indices=*/Repeated<Value>(vecToReadRank, zero),
-      /*padding=*/padValue,
-      /*inBounds=*/inBoundsVal);
-
-  if (llvm::equal(vecToReadTy.getShape(), sourceShape) ||
-      useInBoundsInsteadOfMasking)
+  SmallVector<Value> indices(vecToReadRank, zero);
+  auto transferReadOp =
+      vector::TransferReadOp::create(builder, loc,
+                                     /*vectorType=*/vecToReadTy,
+                                     /*source=*/source,
+                                     /*indices=*/indices,
+                                     /*padding=*/padValue,
+                                     /*inBounds=*/inBoundsVal);
+
+  if (useInBoundsInsteadOfMasking)
     return transferReadOp;
+
   SmallVector<OpFoldResult> mixedSourceDims =
       isa<MemRefType>(source.getType())
           ? memref::getMixedSizes(builder, loc, source)
           : tensor::getMixedSizes(builder, loc, source);
 
+  // `isMaskTriviallyFoldable` treats the vector shape as fixed. A scalable
+  // dim may exceed the source size at runtime, so always mask in that case.
+  if (!vecToReadTy.isScalable() &&
+      isMaskTriviallyFoldable(mixedSourceDims, indices, sourceShape,
+                              vecToReadShape))
+    return transferReadOp;
+
   auto maskType = vecToReadTy.cloneWith(/*shape=*/{}, builder.getI1Type());
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
@@ -376,6 +478,105 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
       ->getResult(0);
 }
 
+Operation *vector::createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
+                                            Value vecToStore, Value dest,
+                                            SmallVector<Value> writeIndices,
+                                            bool useInBoundsInsteadOfMasking) {
+
+  ShapedType destType = cast<ShapedType>(dest.getType());
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
+
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
+
+  // `vecToStore` is aligned with the trailing `vecToStoreRank` dims of `dest`.
+  assert(vecToStoreRank <= destRank &&
+         "Cannot store a vector with a higher rank than the destination!");
+  int64_t rankDiff = destRank - vecToStoreRank;
+
+  // Compute the in_bounds attribute
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
+  if (useInBoundsInsteadOfMasking) {
+    // Update the inBounds attribute.
+    // FIXME: This computation is too weak - it ignores the write indices.
+    for (int64_t i = 0; i < vecToStoreRank; ++i)
+      inBoundsVal[i] = (destShape[rankDiff + i] >= vecToStoreShape[i]) &&
+                       ShapedType::isStatic(destShape[rankDiff + i]);
+  }
+
+  // If missing, initialize the write indices to 0.
+  bool useDefaultWriteIdxs = writeIndices.empty();
+  assert((useDefaultWriteIdxs ||
+          writeIndices.size() == static_cast<size_t>(destRank)) &&
+         "Invalid number of write indices!");
+  if (useDefaultWriteIdxs) {
+    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
+    writeIndices.assign(destRank, zero);
+  }
+
+  // Generate the xfer_write Op
+  Operation *write = vector::TransferWriteOp::create(builder, loc,
+                                                     /*vector=*/vecToStore,
+                                                     /*dest=*/dest,
+                                                     /*indices=*/writeIndices,
+                                                     /*inBounds=*/inBoundsVal);
+
+  // If masking is disabled, exit.
+  if (useInBoundsInsteadOfMasking)
+    return write;
+
+  // Check if masking is needed. If not, exit. Matching static shapes alone
+  // only rule out out-of-bounds accesses when the trailing write indices are
+  // all 0 and no vector dim is scalable (a scalable dim may exceed its static
+  // size at runtime).
+  bool isScalable = vecToStoreType.isScalable();
+  bool writesAtOrigin = llvm::all_of(
+      ArrayRef<Value>(writeIndices).take_back(vecToStoreRank),
+      [](Value idx) { return matchPattern(idx, m_Zero()); });
+  if (!isScalable && writesAtOrigin &&
+      llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
+
+  SmallVector<OpFoldResult> destSizes =
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
+
+  // Compute sizes for write-mask: for each trailing dest dim, the number of
+  // elements from the corresponding write index to the end of that dim.
+  SmallVector<OpFoldResult> maskSizes;
+  if (useDefaultWriteIdxs) {
+    maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
+                                          destSizes.end());
+  } else {
+    for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
+      // Pair dest dim `rankDiff + idx` with the write index of that same dim.
+      Value destDimSize = getValueOrCreateConstantIndexOp(
+          builder, loc, destSizes[rankDiff + idx]);
+      Value remaining = builder.createOrFold<arith::SubIOp>(
+          loc, destDimSize, writeIndices[rankDiff + idx]);
+      maskSizes.push_back(OpFoldResult(remaining));
+    }
+  }
+
+  // `isMaskTriviallyFoldable` treats the vector shape as fixed, so it cannot
+  // be used to drop the mask of a scalable vector.
+  if (!isScalable &&
+      isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
+
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
+}
+
 LogicalResult
 vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
                                  ArrayRef<int64_t> inputVectorSizes) {


        


More information about the Mlir-commits mailing list