[Mlir-commits] [mlir] [mlir][vector] Refactor `createWriteOrMaskedWrite` (PR #138137)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 1 07:07:32 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

This patch updates `createWriteOrMaskedWrite` to make it consistent with
`createReadOrMaskedRead`.

Before diving into the details: note that these utilities are currently
implemented in different files — "VectorUtils.cpp" (Vector) and
"Vectorization.cpp" (Linalg). In a subsequent patch, I plan to move
`createWriteOrMaskedWrite` into "VectorUtils.cpp".

SUMMARY OF CHANGES:

The main change is to remove the logic that creates the destination
tensor, which previously looked like:
```cpp
  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                               inputType.getElementType());
```

With this patch, `createWriteOrMaskedWrite` now simply generates:
```mlir
  %res = vector.transfer_write %vectorToStore into %dest
```

This replaces the previous form:
```mlir
  %dest = tensor.empty(%destSizes)
  %res = vector.transfer_write %vectorToStore into %dest
```

In other words, the destination value `%dest` is now passed as an input
parameter. This makes `createWriteOrMaskedWrite` re-usable in contexts
where the destination tensor is already known — for example, in
`vectorizeAsInsertSliceOp`, which I will update in a follow-up patch.

OTHER CHANGES:

* Added comments and clarified TODOs.

* Updated tests: since destination sizes are now computed independently
  inside `createWriteOrMaskedWrite`, some additional `tensor.dim` ops
  appear. These will be cleaned up by CSE + canonicalization.


---
Full diff: https://github.com/llvm/llvm-project/pull/138137.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+55-43) 
- (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+6-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a477c2fb3f8cb..12ecdf9494bef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1506,72 +1506,68 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
-/// Creates a TransferWriteOp to write `input` into a newly initialized
-/// output tensor.
+/// Creates an optionally masked TransferWriteOp
 ///
-/// Given:
-/// - an input vector to write,
-/// - the mixed destination sizes for the output tensor,
-/// - and the vector sizes used for vectorization (i.e., the leading N dims,
-///   for some value of N),
-///
-/// this function generates the following sequence of ops:
-///
-///   %dest = tensor.empty(%destSizes)
-///   %res = vector.transfer_write %input into %dest
+/// Generates the following operation:
+///   %res = vector.transfer_write %vectorToStore into %dest
 ///
 /// If the leading N dimensions of the destination tensor do not match
-/// `inputVecSizesForLeadingDims` (where N =
-/// rank(`inputVecSizesForLeadingDims`)), masking is applied to ensure
-/// correctness:
+/// `inputVecSizesForLeadingDims`, where:
+///   * N = rank(`inputVecSizesForLeadingDims`),
+/// masking is applied to ensure correctness:
 ///
-///   %dest = tensor.empty(%destSizes)
-///   %write = vector.transfer_write %input into %dest
-///   %mask = vector.create_mask(%destSizes)
+///   %write = vector.transfer_write %vectorToStore into %dest
+///   %mask = vector.create_mask(%destShape)
 ///   %res = vector.mask %mask { %write }
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %dest = tensor.empty(%destSizes)
+///   %write = vector.transfer_write %vectorToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: all write offsets are set to 0.
+/// NOTE: All write offsets are set to 0.
+/// TODO: Allow specifying write offsets.
 /// NOTE: When N < rank(input), the missing vector sizes are effectively
 /// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static. Supporting dynamic sizes will require the user to specify
-/// the remaining vector sizes. This is left as a TODO.
+/// must be static.
+/// TODO: Support cases where an arbitrary dim is dynamic - this will require
+/// specifying all the vector sizes.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
-                         SmallVector<OpFoldResult> destSizes,
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
+                         Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
                          bool useInBoundsInsteadOfMasking = false) {
 
-  auto inputType = cast<VectorType>(input.getType());
-  assert(inputType.getRank() == static_cast<int64_t>(destSizes.size()) &&
+  ShapedType destType = cast<ShapedType>(dest.getType());
+  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
+             static_cast<int64_t>(destType.getRank()) &&
          "Rank mismatch!");
 
-  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
-                                               inputType.getElementType());
   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   auto destShape = cast<ShapedType>(dest.getType()).getShape();
+
+  // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(rank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
-    assert(inputVecSizesForLeadingDims.size() == destSizes.size() &&
+    assert(inputVecSizesForLeadingDims.size() ==
+               static_cast<size_t>(destType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     for (unsigned i = 0; i < rank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
+
+  // Generate the xfer_write Op
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   Operation *write = builder.create<vector::TransferWriteOp>(
       loc,
-      /*vector=*/input,
+      /*vector=*/vectorToStore,
       /*source=*/dest,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/inBoundsVal);
@@ -1579,11 +1575,17 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
              destShape.drop_front(inputVecSizesForLeadingDims.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
          "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
+  // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
+
+  // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
                    destShape.take_front(inputVecSizesForLeadingDims.size()));
+
+  // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
@@ -1592,10 +1594,11 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
                               inputVecSizesForLeadingDims.size(),
                           destShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite =
-        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+    Value maskForWrite = builder.create<vector::CreateMaskOp>(
+        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
+
   return write;
 }
 
@@ -1693,9 +1696,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
       loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedReturnShapes[0],
+      transposeOp.getResult().getType().getElementType());
   Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
-                               /*destSizes=*/reifiedReturnShapes[0],
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
                                /*inputVecSizesForLeadingDims=*/inputVectorSizes,
                                /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
@@ -1830,10 +1835,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       unpackOp.getDestType().hasStaticShape()
           ? vectorSizes
           : shapeCastOp.getResultVectorType().getShape());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, shapeCastOp.getResult(), /*destSizes=*/reifiedRetShapes[0],
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-      useInBoundsInsteadOfMasking);
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedRetShapes[0],
+      shapeCastOp.getResult().getType().getElementType());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
+                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+                               useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1861,10 +1869,14 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, reifiedReturnShapes[0],
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-      /*useInBoundsInsteadOfMasking=*/false);
+
+  // Create Xfer write Op
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
+                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 299be1296aa66..6b760a15afd56 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -641,7 +641,9 @@ func.func @test_masked_vectorize_dynamic_pad(
   // CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
   //  CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[res_d0]], %[[res_d1]]) : tensor<?x?xf32>
   //  CHECK-DAG: %[[c0_3:.*]] = arith.constant 0 : index
-  //      CHECK: %[[mask_2:.*]] = vector.create_mask %[[res_d0]], %[[res_d1]] : vector<2x4xi1>
+  //  CHECK-DAG: %[[d2:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?xf32>
+  //  CHECK-DAG: %[[d3:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?xf32>
+  //      CHECK: %[[mask_2:.*]] = vector.create_mask %[[d2]], %[[d3]] : vector<2x4xi1>
   //      CHECK: %[[masked_write:.*]] = vector.mask %[[mask_2]] {
   // CHECK-SAME: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0_3]], %[[c0_3]]]
   // CHECK-SAME:   {in_bounds = [true, true]} : vector<2x4xf32>, tensor<?x?xf32>
@@ -800,7 +802,9 @@ func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?
 //  CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
 //  CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
 //  CHECK-DAG: %[[empty:.*]] = tensor.empty(%[[d0]], %[[d1]]) : tensor<?x?x16x2xf32>
-//      CHECK: %[[mask_0:.*]] = vector.create_mask %[[d0]], %[[d1]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
+//  CHECK-DAG: %[[d2:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?x16x2xf32>
+//  CHECK-DAG: %[[d3:.*]] = tensor.dim %[[empty]], {{.*}} : tensor<?x?x16x2xf32>
+//      CHECK: %[[mask_0:.*]] = vector.create_mask %[[d2]], %[[d3]], %[[c16]], %[[c2]] : vector<4x1x16x2xi1>
 //      CHECK: %[[masked_write:.*]] = vector.mask %[[mask_0]] {
 // CHECK-SAME:   vector.transfer_write %[[transpose]], %[[empty]][%[[c0_2]], %[[c0_2]], %[[c0_2]], %[[c0_2]]]
 // CHECK-SAME:   {in_bounds = [true, true, true, true]} : vector<4x1x16x2xf32>, tensor<?x?x16x2xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/138137


More information about the Mlir-commits mailing list