[Mlir-commits] [mlir] [mlir][vector] Update docs + add tests (PR #137144)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 24 02:13:42 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

This is a small follow-up to #133721:
* Renamed `getRealVectorRank` to `getEffectiveVectorRankForXferOp` (to
  emphasise that this function was written specifically for transfer Ops).
* Slightly tweaked the description of `getEffectiveVectorRankForXferOp`
  (mostly to highlight the two cases it covers: a scalar vs. a vector
  element type).
* Added tests for the case where the element type of the shaped type is a
  vector.
* Unified the naming (and the order) of test arguments with the surrounding
  tests (e.g. `%vec_to_write` -> `%arg1`), mostly for consistency (ideally,
  self-documenting names like `%vec_to_write` would be used throughout).


---
Full diff: https://github.com/llvm/llvm-project/pull/137144.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+20-14) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+21-4) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3fee1e949aeed..49dd433597e8c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -151,29 +151,32 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
   return false;
 }
 
-/// Returns the number of dimensions of the `shapedType` that participate in the
-/// vector transfer, effectively the rank of the vector dimensions within the
-/// `shapedType`. This is calculated by taking the rank of the `vectorType`
-/// being transferred and subtracting the rank of the `shapedType`'s element
-/// type if it's also a vector.
+/// Returns the effective rank of the vector to read/write for Xfer Ops.
 ///
-/// This is used to determine the number of minor dimensions for identity maps
-/// in vector transfers.
+/// When the element type of the shaped type is _a scalar_, this simply
+/// returns the rank of the vector (the result for xfer_read or the value to
+/// store for xfer_write).
 ///
-/// For example, given a transfer operation involving `shapedType` and
-/// `vectorType`:
+/// When the element type of the shaped type is _a vector_, returns the
+/// difference between the rank of the vector type and the rank of the
+/// element type of the shaped type.
 ///
+/// EXAMPLE 1 (element type is _a scalar_):
 ///   - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
 ///     - shapedType.getElementType() = f32 (rank 0)
 ///     - vectorType.getRank() = 2
 ///     - Result = 2 - 0 = 2
 ///
+/// EXAMPLE 2 (element type is _a vector_):
 ///   - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
 ///     - shapedType.getElementType() = vector<20xf32> (rank 1)
 ///     - vectorType.getRank() = 1
 ///     - Result = 1 - 1 = 0
-static unsigned getRealVectorRank(ShapedType shapedType,
-                                  VectorType vectorType) {
+///
+/// This is used to determine the number of minor dimensions for identity maps
+/// in vector transfer Ops.
+static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
+                                                VectorType vectorType) {
   unsigned elementVectorRank = 0;
   VectorType elementVectorType =
       llvm::dyn_cast<VectorType>(shapedType.getElementType());
@@ -192,7 +195,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
         /*numDims=*/0, /*numSymbols=*/0,
         getAffineConstantExpr(0, shapedType.getContext()));
   return AffineMap::getMinorIdentityMap(
-      shapedType.getRank(), getRealVectorRank(shapedType, vectorType),
+      shapedType.getRank(),
+      getEffectiveVectorRankForXferOp(shapedType, vectorType),
       shapedType.getContext());
 }
 
@@ -4260,7 +4264,8 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   Attribute permMapAttr = result.attributes.get(permMapAttrName);
   AffineMap permMap;
   if (!permMapAttr) {
-    if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
+    if (shapedType.getRank() <
+        getEffectiveVectorRankForXferOp(shapedType, vectorType))
       return parser.emitError(typesLoc,
                               "expected a custom permutation_map when "
                               "rank(source) != rank(destination)");
@@ -4679,7 +4684,8 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   auto permMapAttr = result.attributes.get(permMapAttrName);
   AffineMap permMap;
   if (!permMapAttr) {
-    if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
+    if (shapedType.getRank() <
+        getEffectiveVectorRankForXferOp(shapedType, vectorType))
       return parser.emitError(typesLoc,
                               "expected a custom permutation_map when "
                               "rank(source) != rank(destination)");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 19096f0e4c895..349a58d4eb4e4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -525,15 +525,24 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 
 // -----
 
-func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xindex> {
+func.func @test_vector.transfer_read(%arg0: memref<?xindex>) -> vector<3x4xindex> {
   %c3 = arith.constant 3 : index
   // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
-  %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xindex>
+  %0 = vector.transfer_read %arg0[%c3], %c3 : memref<?xindex>, vector<3x4xindex>
   return %0 : vector<3x4xindex>
 }
 
 // -----
 
+func.func @test_vector.transfer_read(%arg0: memref<?xvector<2xindex>>) -> vector<2x3x4xindex> {
+  %c3 = arith.constant 3 : index
+  // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
+  %0 = vector.transfer_read %arg0[%c3], %c3 : memref<?xvector<2xindex>>, vector<2x3x4xindex>
+  return %0 : vector<2x3x4xindex>
+}
+
+// -----
+
 func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = arith.constant 3 : index
   %cst = arith.constant 3.0 : f32
@@ -655,10 +664,18 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
 
 // -----
 
-func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref<?xindex>) {
+func.func @test_vector.transfer_write(%arg0: memref<?xindex>, %arg1: vector<3x4xindex>) {
+  %c3 = arith.constant 3 : index
+  // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
+  vector.transfer_write %arg1, %arg0[%c3] : vector<3x4xindex>, memref<?xindex>
+}
+
+// -----
+
+func.func @test_vector.transfer_write(%arg0: memref<?xvector<2xindex>>, %arg1: vector<2x3x4xindex>) {
   %c3 = arith.constant 3 : index
   // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
-  vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xindex>, memref<?xindex>
+  vector.transfer_write %arg1, %arg0[%c3] : vector<2x3x4xindex>, memref<?xvector<2xindex>>
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/137144


More information about the Mlir-commits mailing list