[Mlir-commits] [mlir] [mlir][linalg] Refactor vectorization hooks to improve code reuse (PR #141244)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 29 10:36:10 PDT 2025


================
@@ -1528,75 +1628,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: All write offsets are set to 0.
-/// TODO: Allow specyfying write offsets.
-/// NOTE: When N < rank(input), the missing vector sizes are effectively
-/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static.
-/// TODO: Support cases where an arbitrary dim is dynamic - this will require
-/// specifying all the vector sizes.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+///
+/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
+/// `vectorToStore`.
+/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
+/// already provided in `vectorToStore`.
 static Operation *
 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                          Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
-  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
-             static_cast<int64_t>(destType.getRank()) &&
-         "Rank mismatch!");
-  (void)destType;
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
 
-  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(rank, true);
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
     assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(destType.getRank()) &&
+               static_cast<size_t>(vecToStoreType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < rank; i++)
+    for (unsigned i = 0; i < destRank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
+  // If missing, initialize the write indices to 0.
+  assert(writeIndices.empty() ||
+         writeIndices.size() == static_cast<size_t>(destRank) &&
+             "Invalid number of write indices!");
+  if (writeIndices.empty()) {
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    writeIndices = SmallVector<Value>(destRank, zero);
----------------
Max191 wrote:

nit: I think you can use `writeIndices.assign(destRank, zero);` here instead of constructing a temporary `SmallVector<Value>` and assigning it to `writeIndices`.

https://github.com/llvm/llvm-project/pull/141244


More information about the Mlir-commits mailing list