[Mlir-commits] [mlir] [mlir][vector] Fix parser of vector.transfer_read (PR #133721)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 15 06:56:04 PDT 2025


https://github.com/douyixuan updated https://github.com/llvm/llvm-project/pull/133721

>From 8675a216b18efd159a220cacb10e1a2331228fc4 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Mon, 31 Mar 2025 14:14:54 +0800
Subject: [PATCH 1/9] [mlir][vector] Fix parser of vector.transfer_read

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 ++++++++-
 mlir/test/Dialect/Vector/invalid.mlir    | 9 +++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..d99c8ef0680f4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -151,7 +151,7 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
 }
 
 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
-                                                    VectorType vectorType) {
+                                               VectorType vectorType) {
   int64_t elementVectorRank = 0;
   VectorType elementVectorType =
       llvm::dyn_cast<VectorType>(shapedType.getElementType());
@@ -164,6 +164,9 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
     return AffineMap::get(
         /*numDims=*/0, /*numSymbols=*/0,
         getAffineConstantExpr(0, shapedType.getContext()));
+  if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
+    return AffineMap::get(shapedType.getContext());
+  }
   return AffineMap::getMinorIdentityMap(
       shapedType.getRank(), vectorType.getRank() - elementVectorRank,
       shapedType.getContext());
@@ -4260,6 +4263,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   AffineMap permMap;
   if (!permMapAttr) {
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
+    if (permMap.isEmpty()) {
+      return parser.emitError(typesLoc,
+                              "failed to create minor identity permutation map");
+    }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
     permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..667e1615212f4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -525,6 +525,15 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 
 // -----
 
+func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
+  %c3_i32 = arith.constant 3 : i32
+  // expected-error at +1 {{failed to create minor identity permutation map}}
+  %0 = vector.transfer_read %arg1[%c3_i32, %c3_i32], %c3_i32 : memref<?xindex>, vector<3x4xi32>
+  return %0 : vector<3x4xi32>
+}
+
+// -----
+
 func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = arith.constant 3 : index
   %cst = arith.constant 3.0 : f32

>From 38f8b62b37ac447c460ed6d789f2b9efc1ad39f3 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Tue, 1 Apr 2025 09:33:59 +0800
Subject: [PATCH 2/9] [mlir][vector] Fix parser of vector.transfer_read

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d99c8ef0680f4..fc0e17be52f70 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -151,7 +151,7 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
 }
 
 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
-                                               VectorType vectorType) {
+                                                    VectorType vectorType) {
   int64_t elementVectorRank = 0;
   VectorType elementVectorType =
       llvm::dyn_cast<VectorType>(shapedType.getElementType());
@@ -165,7 +165,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
         /*numDims=*/0, /*numSymbols=*/0,
         getAffineConstantExpr(0, shapedType.getContext()));
   if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
-    return AffineMap::get(shapedType.getContext());
+    return AffineMap(); // Not enough dimensions in the shaped type to form a
+                        // minor identity map.
   }
   return AffineMap::getMinorIdentityMap(
       shapedType.getRank(), vectorType.getRank() - elementVectorRank,
@@ -4263,9 +4264,9 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   AffineMap permMap;
   if (!permMapAttr) {
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
-    if (permMap.isEmpty()) {
-      return parser.emitError(typesLoc,
-                              "failed to create minor identity permutation map");
+    if (!permMap) {
+      return parser.emitError(
+          typesLoc, "failed to create minor identity permutation map");
     }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {

>From 6755d1dce840822f22eb3c95aa06c7439072bf57 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Tue, 1 Apr 2025 21:30:10 +0800
Subject: [PATCH 3/9] [mlir][vector] Fix parser of vector.transfer_read

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++++++++--
 mlir/test/Dialect/Vector/invalid.mlir    | 10 +++++++++-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fc0e17be52f70..b3ebd190a80d0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4265,8 +4265,9 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   if (!permMapAttr) {
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     if (!permMap) {
-      return parser.emitError(
-          typesLoc, "failed to create minor identity permutation map");
+      return parser.emitError(typesLoc,
+                              "expected the same rank for the vector and the "
+                              "results of the permutation map");
     }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
@@ -4676,6 +4677,11 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   AffineMap permMap;
   if (!permMapAttr) {
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
+    if (!permMap) {
+      return parser.emitError(typesLoc,
+                              "expected the same rank for the vector and the "
+                              "results of the permutation map");
+    }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
     permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 667e1615212f4..e0d440ea0f4b6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -527,7 +527,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 
 func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
   %c3_i32 = arith.constant 3 : i32
-  // expected-error at +1 {{failed to create minor identity permutation map}}
+  // expected-error at +1 {{expected the same rank for the vector and the results of the permutation map}}
   %0 = vector.transfer_read %arg1[%c3_i32, %c3_i32], %c3_i32 : memref<?xindex>, vector<3x4xi32>
   return %0 : vector<3x4xi32>
 }
@@ -655,6 +655,14 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
 
 // -----
 
+func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
+  %c3_idx = arith.constant 3 : index
+  // expected-error at +1 {{expected the same rank for the vector and the results of the permutation map}}
+  vector.transfer_write %vec_to_write, %output_memref[%c3_idx, %c3_idx] : vector<3x4xi32>, memref<?xindex>
+}
+
+// -----
+
 func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
   // expected-error at +1 {{expected offsets of same size as destination vector rank}}
   %1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>

>From ba2f392c9c50cf2a3aa042ae39352a20213badb1 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Wed, 2 Apr 2025 19:44:57 +0800
Subject: [PATCH 4/9] [mlir][vector] Fix parser of vector.transfer_read

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++--------
 mlir/test/Dialect/Vector/invalid.mlir    |  4 ++--
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b3ebd190a80d0..e6e2fd72aa747 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -165,8 +165,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
         /*numDims=*/0, /*numSymbols=*/0,
         getAffineConstantExpr(0, shapedType.getContext()));
   if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
-    return AffineMap(); // Not enough dimensions in the shaped type to form a
-                        // minor identity map.
+    // Not enough dimensions in the shaped type to form a minor identity map.
+    return AffineMap();
   }
   return AffineMap::getMinorIdentityMap(
       shapedType.getRank(), vectorType.getRank() - elementVectorRank,
@@ -4265,9 +4265,8 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   if (!permMapAttr) {
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     if (!permMap) {
-      return parser.emitError(typesLoc,
-                              "expected the same rank for the vector and the "
-                              "results of the permutation map");
+      return parser.emitError(
+          typesLoc, "source rank is less than required for vector rank");
     }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
@@ -4678,9 +4677,8 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   if (!permMapAttr) {
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     if (!permMap) {
-      return parser.emitError(typesLoc,
-                              "expected the same rank for the vector and the "
-                              "results of the permutation map");
+      return parser.emitError(
+          typesLoc, "result rank is less than required for vector rank");
     }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e0d440ea0f4b6..3e269a247c448 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -527,7 +527,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 
 func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
   %c3_i32 = arith.constant 3 : i32
-  // expected-error at +1 {{expected the same rank for the vector and the results of the permutation map}}
+  // expected-error at +1 {{source rank is less than required for vector rank}}
   %0 = vector.transfer_read %arg1[%c3_i32, %c3_i32], %c3_i32 : memref<?xindex>, vector<3x4xi32>
   return %0 : vector<3x4xi32>
 }
@@ -657,7 +657,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
 
 func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
   %c3_idx = arith.constant 3 : index
-  // expected-error at +1 {{expected the same rank for the vector and the results of the permutation map}}
+  // expected-error at +1 {{result rank is less than required for vector rank}}
   vector.transfer_write %vec_to_write, %output_memref[%c3_idx, %c3_idx] : vector<3x4xi32>, memref<?xindex>
 }
 

>From 3a8d23c7ead058a0f63eeb5b6f4b1d7e0bc78862 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Thu, 3 Apr 2025 00:27:45 +0800
Subject: [PATCH 5/9] [mlir][vector] Fix parser of vector.transfer_read

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 +++++++-------
 mlir/test/Dialect/Vector/invalid.mlir    |  8 ++++----
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e6e2fd72aa747..8126bf0c52d75 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -152,11 +152,6 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
 
 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                     VectorType vectorType) {
-  int64_t elementVectorRank = 0;
-  VectorType elementVectorType =
-      llvm::dyn_cast<VectorType>(shapedType.getElementType());
-  if (elementVectorType)
-    elementVectorRank += elementVectorType.getRank();
   // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
   // TODO: replace once we have 0-d vectors.
   if (shapedType.getRank() == 0 &&
@@ -164,6 +159,11 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
     return AffineMap::get(
         /*numDims=*/0, /*numSymbols=*/0,
         getAffineConstantExpr(0, shapedType.getContext()));
+  int64_t elementVectorRank = 0;
+  VectorType elementVectorType =
+      llvm::dyn_cast<VectorType>(shapedType.getElementType());
+  if (elementVectorType)
+    elementVectorRank += elementVectorType.getRank();
   if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
     // Not enough dimensions in the shaped type to form a minor identity map.
     return AffineMap();
@@ -4266,7 +4266,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     if (!permMap) {
       return parser.emitError(
-          typesLoc, "source rank is less than required for vector rank");
+          typesLoc, "failed to create a minor identity map, source rank is less than required for vector rank");
     }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
@@ -4678,7 +4678,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     if (!permMap) {
       return parser.emitError(
-          typesLoc, "result rank is less than required for vector rank");
+          typesLoc, "failed to create a minor identity map, result rank is less than required for vector rank");
     }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3e269a247c448..980c4287c9ac6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -526,9 +526,9 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 // -----
 
 func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
-  %c3_i32 = arith.constant 3 : i32
-  // expected-error at +1 {{source rank is less than required for vector rank}}
-  %0 = vector.transfer_read %arg1[%c3_i32, %c3_i32], %c3_i32 : memref<?xindex>, vector<3x4xi32>
+  %c3 = arith.constant 3 : index
+  // expected-error at +1 {{failed to create a minor identity map, source rank is less than required for vector rank}}
+  %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
   return %0 : vector<3x4xi32>
 }
 
@@ -657,7 +657,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
 
 func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
   %c3_idx = arith.constant 3 : index
-  // expected-error at +1 {{result rank is less than required for vector rank}}
+  // expected-error at +1 {{failed to create a minor identity map, result rank is less than required for vector rank}}
   vector.transfer_write %vec_to_write, %output_memref[%c3_idx, %c3_idx] : vector<3x4xi32>, memref<?xindex>
 }
 

>From f671fe9bde677151a87c5f62a8cd87a419c8de04 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Fri, 4 Apr 2025 00:01:30 +0800
Subject: [PATCH 6/9] [mlir][vector] Fix parser of vector.transfer_read

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 22 ++++++++++++++++++----
 mlir/test/Dialect/Vector/invalid.mlir    |  8 ++++----
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8126bf0c52d75..d6221f3488632 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4265,8 +4265,15 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   if (!permMapAttr) {
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     if (!permMap) {
-      return parser.emitError(
-          typesLoc, "failed to create a minor identity map, source rank is less than required for vector rank");
+      int64_t elementVectorRank = 0;
+      VectorType elementVectorType =
+          llvm::dyn_cast<VectorType>(shapedType.getElementType());
+      if (elementVectorType)
+        elementVectorRank += elementVectorType.getRank();
+      if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+        return parser.emitError(typesLoc,
+                                "expected a custom permutation_map when source "
+                                "rank is less than required for vector rank");
     }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
@@ -4677,8 +4684,15 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   if (!permMapAttr) {
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
     if (!permMap) {
-      return parser.emitError(
-          typesLoc, "failed to create a minor identity map, result rank is less than required for vector rank");
+      int64_t elementVectorRank = 0;
+      VectorType elementVectorType =
+          llvm::dyn_cast<VectorType>(shapedType.getElementType());
+      if (elementVectorType)
+        elementVectorRank += elementVectorType.getRank();
+      if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+        return parser.emitError(typesLoc,
+                                "expected a custom permutation_map when result "
+                                "rank is less than required for vector rank");
     }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 980c4287c9ac6..cb2d97bd17e04 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -527,7 +527,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 
 func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
   %c3 = arith.constant 3 : index
-  // expected-error at +1 {{failed to create a minor identity map, source rank is less than required for vector rank}}
+  // expected-error at +1 {{expected a custom permutation_map when source rank is less than required for vector rank}}
   %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
   return %0 : vector<3x4xi32>
 }
@@ -656,9 +656,9 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
 // -----
 
 func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
-  %c3_idx = arith.constant 3 : index
-  // expected-error at +1 {{failed to create a minor identity map, result rank is less than required for vector rank}}
-  vector.transfer_write %vec_to_write, %output_memref[%c3_idx, %c3_idx] : vector<3x4xi32>, memref<?xindex>
+  %c3 = arith.constant 3 : index
+  // expected-error at +1 {{expected a custom permutation_map when result rank is less than required for vector rank}}
+  vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xi32>, memref<?xindex>
 }
 
 // -----

>From bbe35d61a9f7ea2f53c06305a411ba7a52e18eac Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Tue, 8 Apr 2025 19:38:22 +0800
Subject: [PATCH 7/9] [mlir][vector] Fix parser of vector.transfer_read

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 54 ++++++++++--------------
 mlir/test/Dialect/Vector/invalid.mlir    | 21 ++++++++-
 2 files changed, 42 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d6221f3488632..daf5b0d70d345 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -152,6 +152,11 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
 
 AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                     VectorType vectorType) {
+  int64_t elementVectorRank = 0;
+  VectorType elementVectorType =
+      llvm::dyn_cast<VectorType>(shapedType.getElementType());
+  if (elementVectorType)
+    elementVectorRank += elementVectorType.getRank();
   // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
   // TODO: replace once we have 0-d vectors.
   if (shapedType.getRank() == 0 &&
@@ -159,15 +164,6 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
     return AffineMap::get(
         /*numDims=*/0, /*numSymbols=*/0,
         getAffineConstantExpr(0, shapedType.getContext()));
-  int64_t elementVectorRank = 0;
-  VectorType elementVectorType =
-      llvm::dyn_cast<VectorType>(shapedType.getElementType());
-  if (elementVectorType)
-    elementVectorRank += elementVectorType.getRank();
-  if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
-    // Not enough dimensions in the shaped type to form a minor identity map.
-    return AffineMap();
-  }
   return AffineMap::getMinorIdentityMap(
       shapedType.getRank(), vectorType.getRank() - elementVectorRank,
       shapedType.getContext());
@@ -4263,18 +4259,16 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   Attribute permMapAttr = result.attributes.get(permMapAttrName);
   AffineMap permMap;
   if (!permMapAttr) {
+    int64_t elementVectorRank = 0;
+    VectorType elementVectorType =
+        llvm::dyn_cast<VectorType>(shapedType.getElementType());
+    if (elementVectorType)
+      elementVectorRank += elementVectorType.getRank();
+    if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+      return parser.emitError(typesLoc,
+                              "expected a custom permutation_map when "
+                              "rank(source) != rank(destination)");
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
-    if (!permMap) {
-      int64_t elementVectorRank = 0;
-      VectorType elementVectorType =
-          llvm::dyn_cast<VectorType>(shapedType.getElementType());
-      if (elementVectorType)
-        elementVectorRank += elementVectorType.getRank();
-      if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
-        return parser.emitError(typesLoc,
-                                "expected a custom permutation_map when source "
-                                "rank is less than required for vector rank");
-    }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
     permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
@@ -4682,18 +4676,16 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   auto permMapAttr = result.attributes.get(permMapAttrName);
   AffineMap permMap;
   if (!permMapAttr) {
+    int64_t elementVectorRank = 0;
+    VectorType elementVectorType =
+        llvm::dyn_cast<VectorType>(shapedType.getElementType());
+    if (elementVectorType)
+      elementVectorRank += elementVectorType.getRank();
+    if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+      return parser.emitError(typesLoc,
+                              "expected a custom permutation_map when "
+                              "rank(source) != rank(destination)");
     permMap = getTransferMinorIdentityMap(shapedType, vectorType);
-    if (!permMap) {
-      int64_t elementVectorRank = 0;
-      VectorType elementVectorType =
-          llvm::dyn_cast<VectorType>(shapedType.getElementType());
-      if (elementVectorType)
-        elementVectorRank += elementVectorType.getRank();
-      if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
-        return parser.emitError(typesLoc,
-                                "expected a custom permutation_map when result "
-                                "rank is less than required for vector rank");
-    }
     result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
   } else {
     permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index cb2d97bd17e04..05a8ea4d6c71e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -527,13 +527,22 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 
 func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
   %c3 = arith.constant 3 : index
-  // expected-error at +1 {{expected a custom permutation_map when source rank is less than required for vector rank}}
+  // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
   %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
   return %0 : vector<3x4xi32>
 }
 
 // -----
 
+func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xindex> {
+  %c3 = arith.constant 3 : index
+  // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
+  %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xindex>
+  return %0 : vector<3x4xindex>
+}
+
+// -----
+
 func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = arith.constant 3 : index
   %cst = arith.constant 3.0 : f32
@@ -657,12 +666,20 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
 
 func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
   %c3 = arith.constant 3 : index
-  // expected-error at +1 {{expected a custom permutation_map when result rank is less than required for vector rank}}
+  // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
   vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xi32>, memref<?xindex>
 }
 
 // -----
 
+func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref<?xindex>) {
+  %c3 = arith.constant 3 : index
+  // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
+  vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xindex>, memref<?xindex>
+}
+
+// -----
+
 func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
   // expected-error at +1 {{expected offsets of same size as destination vector rank}}
   %1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>

>From 2a920e0f9abcf3dd917635d2d97f704a1c6b4624 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Wed, 9 Apr 2025 21:26:43 +0800
Subject: [PATCH 8/9] [mlir][vector] Fix parser of vector.transfer_read

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 27 ++++++++++--------------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index daf5b0d70d345..e969a7e02ba74 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -150,13 +150,18 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
   return false;
 }
 
-AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
-                                                    VectorType vectorType) {
-  int64_t elementVectorRank = 0;
+static unsigned getRealVectorRank(ShapedType shapedType,
+                                  VectorType vectorType) {
+  unsigned elementVectorRank = 0;
   VectorType elementVectorType =
       llvm::dyn_cast<VectorType>(shapedType.getElementType());
   if (elementVectorType)
     elementVectorRank += elementVectorType.getRank();
+  return vectorType.getRank() - elementVectorRank;
+}
+
+AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
+                                                    VectorType vectorType) {
   // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
   // TODO: replace once we have 0-d vectors.
   if (shapedType.getRank() == 0 &&
@@ -165,7 +170,7 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
         /*numDims=*/0, /*numSymbols=*/0,
         getAffineConstantExpr(0, shapedType.getContext()));
   return AffineMap::getMinorIdentityMap(
-      shapedType.getRank(), vectorType.getRank() - elementVectorRank,
+      shapedType.getRank(), getRealVectorRank(shapedType, vectorType),
       shapedType.getContext());
 }
 
@@ -4259,12 +4264,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   Attribute permMapAttr = result.attributes.get(permMapAttrName);
   AffineMap permMap;
   if (!permMapAttr) {
-    int64_t elementVectorRank = 0;
-    VectorType elementVectorType =
-        llvm::dyn_cast<VectorType>(shapedType.getElementType());
-    if (elementVectorType)
-      elementVectorRank += elementVectorType.getRank();
-    if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+    if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
       return parser.emitError(typesLoc,
                               "expected a custom permutation_map when "
                               "rank(source) != rank(destination)");
@@ -4676,12 +4676,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   auto permMapAttr = result.attributes.get(permMapAttrName);
   AffineMap permMap;
   if (!permMapAttr) {
-    int64_t elementVectorRank = 0;
-    VectorType elementVectorType =
-        llvm::dyn_cast<VectorType>(shapedType.getElementType());
-    if (elementVectorType)
-      elementVectorRank += elementVectorType.getRank();
-    if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+    if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
       return parser.emitError(typesLoc,
                               "expected a custom permutation_map when "
                               "rank(source) != rank(destination)");

>From 34f89f7db139c40d85ae6639237a1d81e11f581b Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Tue, 15 Apr 2025 21:52:47 +0800
Subject: [PATCH 9/9] [mlir][vector] Fix parser of vector.transfer_read

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 21 +++++++++++++++++++++
 mlir/test/Dialect/Vector/invalid.mlir    | 17 -----------------
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 206fb343f913d..74222cb56d412 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -151,6 +151,27 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
   return false;
 }
 
+/// Returns the number of dimensions of `vectorType` that are not covered by
+/// `shapedType`'s element type: the rank of `vectorType` minus the rank of
+/// `shapedType`'s element type when that element type is itself a vector (a
+/// scalar element type contributes 0). The subtraction is not range-checked:
+/// if the element vector rank exceeded the rank of `vectorType`, the
+/// `unsigned` result would wrap around.
+///
+/// The value is used as the number of results (minor dimensions) of the
+/// default minor identity permutation map of a vector transfer.
+///
+/// Examples:
+///
+///   - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
+///     - shapedType.getElementType() = f32 (scalar, contributes 0)
+///     - vectorType.getRank() = 2
+///     - Result = 2 - 0 = 2
+///
+///   - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
+///     - shapedType.getElementType() = vector<20xf32> (rank 1)
+///     - vectorType.getRank() = 1
+///     - Result = 1 - 1 = 0
 static unsigned getRealVectorRank(ShapedType shapedType,
                                   VectorType vectorType) {
   unsigned elementVectorRank = 0;
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 4fa2703683da2..63f8667ce6b9e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -525,15 +525,6 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
 
 // -----
 
-func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
-  %c3 = arith.constant 3 : index
-  // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
-  %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
-  return %0 : vector<3x4xi32>
-}
-
-// -----
-
 func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xindex> {
   %c3 = arith.constant 3 : index
   // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
@@ -664,14 +655,6 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
 
 // -----
 
-func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
-  %c3 = arith.constant 3 : index
-  // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
-  vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xi32>, memref<?xindex>
-}
-
-// -----
-
 func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref<?xindex>) {
   %c3 = arith.constant 3 : index
   // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}



More information about the Mlir-commits mailing list