[Mlir-commits] [mlir] [mlir][vector] Fix parser of vector.transfer_read (PR #133721)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 14 06:11:44 PDT 2025
https://github.com/douyixuan updated https://github.com/llvm/llvm-project/pull/133721
>From 8675a216b18efd159a220cacb10e1a2331228fc4 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Mon, 31 Mar 2025 14:14:54 +0800
Subject: [PATCH 1/8] [mlir][vector] Fix parser of vector.transfer_read
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 ++++++++-
mlir/test/Dialect/Vector/invalid.mlir | 9 +++++++++
2 files changed, 17 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..d99c8ef0680f4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -151,7 +151,7 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
}
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
- VectorType vectorType) {
+ VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
llvm::dyn_cast<VectorType>(shapedType.getElementType());
@@ -164,6 +164,9 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
return AffineMap::get(
/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, shapedType.getContext()));
+ if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
+ return AffineMap::get(shapedType.getContext());
+ }
return AffineMap::getMinorIdentityMap(
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
shapedType.getContext());
@@ -4260,6 +4263,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
AffineMap permMap;
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
+ if (permMap.isEmpty()) {
+ return parser.emitError(typesLoc,
+ "failed to create minor identity permutation map");
+ }
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index ea6d0021391fb..667e1615212f4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -525,6 +525,15 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
// -----
+func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
+ %c3_i32 = arith.constant 3 : i32
+ // expected-error at +1 {{failed to create minor identity permutation map}}
+ %0 = vector.transfer_read %arg1[%c3_i32, %c3_i32], %c3_i32 : memref<?xindex>, vector<3x4xi32>
+ return %0 : vector<3x4xi32>
+}
+
+// -----
+
func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant 3.0 : f32
>From 38f8b62b37ac447c460ed6d789f2b9efc1ad39f3 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Tue, 1 Apr 2025 09:33:59 +0800
Subject: [PATCH 2/8] [mlir][vector] Fix parser of vector.transfer_read
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d99c8ef0680f4..fc0e17be52f70 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -151,7 +151,7 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
}
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
- VectorType vectorType) {
+ VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
llvm::dyn_cast<VectorType>(shapedType.getElementType());
@@ -165,7 +165,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, shapedType.getContext()));
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
- return AffineMap::get(shapedType.getContext());
+ return AffineMap(); // Not enough dimensions in the shaped type to form a
+ // minor identity map.
}
return AffineMap::getMinorIdentityMap(
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
@@ -4263,9 +4264,9 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
AffineMap permMap;
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
- if (permMap.isEmpty()) {
- return parser.emitError(typesLoc,
- "failed to create minor identity permutation map");
+ if (!permMap) {
+ return parser.emitError(
+ typesLoc, "failed to create minor identity permutation map");
}
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
>From 6755d1dce840822f22eb3c95aa06c7439072bf57 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Tue, 1 Apr 2025 21:30:10 +0800
Subject: [PATCH 3/8] [mlir][vector] Fix parser of vector.transfer_read
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 ++++++++--
mlir/test/Dialect/Vector/invalid.mlir | 10 +++++++++-
2 files changed, 17 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fc0e17be52f70..b3ebd190a80d0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4265,8 +4265,9 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
if (!permMap) {
- return parser.emitError(
- typesLoc, "failed to create minor identity permutation map");
+ return parser.emitError(typesLoc,
+ "expected the same rank for the vector and the "
+ "results of the permutation map");
}
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
@@ -4676,6 +4677,11 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
AffineMap permMap;
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
+ if (!permMap) {
+ return parser.emitError(typesLoc,
+ "expected the same rank for the vector and the "
+ "results of the permutation map");
+ }
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 667e1615212f4..e0d440ea0f4b6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -527,7 +527,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
%c3_i32 = arith.constant 3 : i32
- // expected-error at +1 {{failed to create minor identity permutation map}}
+ // expected-error at +1 {{expected the same rank for the vector and the results of the permutation map}}
%0 = vector.transfer_read %arg1[%c3_i32, %c3_i32], %c3_i32 : memref<?xindex>, vector<3x4xi32>
return %0 : vector<3x4xi32>
}
@@ -655,6 +655,14 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
// -----
+func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
+ %c3_idx = arith.constant 3 : index
+ // expected-error at +1 {{expected the same rank for the vector and the results of the permutation map}}
+ vector.transfer_write %vec_to_write, %output_memref[%c3_idx, %c3_idx] : vector<3x4xi32>, memref<?xindex>
+}
+
+// -----
+
func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
// expected-error at +1 {{expected offsets of same size as destination vector rank}}
%1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
>From ba2f392c9c50cf2a3aa042ae39352a20213badb1 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Wed, 2 Apr 2025 19:44:57 +0800
Subject: [PATCH 4/8] [mlir][vector] Fix parser of vector.transfer_read
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++--------
mlir/test/Dialect/Vector/invalid.mlir | 4 ++--
2 files changed, 8 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b3ebd190a80d0..e6e2fd72aa747 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -165,8 +165,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, shapedType.getContext()));
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
- return AffineMap(); // Not enough dimensions in the shaped type to form a
- // minor identity map.
+ // Not enough dimensions in the shaped type to form a minor identity map.
+ return AffineMap();
}
return AffineMap::getMinorIdentityMap(
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
@@ -4265,9 +4265,8 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
if (!permMap) {
- return parser.emitError(typesLoc,
- "expected the same rank for the vector and the "
- "results of the permutation map");
+ return parser.emitError(
+ typesLoc, "source rank is less than required for vector rank");
}
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
@@ -4678,9 +4677,8 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
if (!permMap) {
- return parser.emitError(typesLoc,
- "expected the same rank for the vector and the "
- "results of the permutation map");
+ return parser.emitError(
+ typesLoc, "result rank is less than required for vector rank");
}
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e0d440ea0f4b6..3e269a247c448 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -527,7 +527,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
%c3_i32 = arith.constant 3 : i32
- // expected-error at +1 {{expected the same rank for the vector and the results of the permutation map}}
+ // expected-error at +1 {{source rank is less than required for vector rank}}
%0 = vector.transfer_read %arg1[%c3_i32, %c3_i32], %c3_i32 : memref<?xindex>, vector<3x4xi32>
return %0 : vector<3x4xi32>
}
@@ -657,7 +657,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
%c3_idx = arith.constant 3 : index
- // expected-error at +1 {{expected the same rank for the vector and the results of the permutation map}}
+ // expected-error at +1 {{result rank is less than required for vector rank}}
vector.transfer_write %vec_to_write, %output_memref[%c3_idx, %c3_idx] : vector<3x4xi32>, memref<?xindex>
}
>From 3a8d23c7ead058a0f63eeb5b6f4b1d7e0bc78862 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Thu, 3 Apr 2025 00:27:45 +0800
Subject: [PATCH 5/8] [mlir][vector] Fix parser of vector.transfer_read
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 +++++++-------
mlir/test/Dialect/Vector/invalid.mlir | 8 ++++----
2 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e6e2fd72aa747..8126bf0c52d75 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -152,11 +152,6 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType) {
- int64_t elementVectorRank = 0;
- VectorType elementVectorType =
- llvm::dyn_cast<VectorType>(shapedType.getElementType());
- if (elementVectorType)
- elementVectorRank += elementVectorType.getRank();
// 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
// TODO: replace once we have 0-d vectors.
if (shapedType.getRank() == 0 &&
@@ -164,6 +159,11 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
return AffineMap::get(
/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, shapedType.getContext()));
+ int64_t elementVectorRank = 0;
+ VectorType elementVectorType =
+ llvm::dyn_cast<VectorType>(shapedType.getElementType());
+ if (elementVectorType)
+ elementVectorRank += elementVectorType.getRank();
if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
// Not enough dimensions in the shaped type to form a minor identity map.
return AffineMap();
@@ -4266,7 +4266,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
if (!permMap) {
return parser.emitError(
- typesLoc, "source rank is less than required for vector rank");
+ typesLoc, "failed to create a minor identity map, source rank is less than required for vector rank");
}
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
@@ -4678,7 +4678,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
if (!permMap) {
return parser.emitError(
- typesLoc, "result rank is less than required for vector rank");
+ typesLoc, "failed to create a minor identity map, result rank is less than required for vector rank");
}
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3e269a247c448..980c4287c9ac6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -526,9 +526,9 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
// -----
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
- %c3_i32 = arith.constant 3 : i32
- // expected-error at +1 {{source rank is less than required for vector rank}}
- %0 = vector.transfer_read %arg1[%c3_i32, %c3_i32], %c3_i32 : memref<?xindex>, vector<3x4xi32>
+ %c3 = arith.constant 3 : index
+ // expected-error at +1 {{failed to create a minor identity map, source rank is less than required for vector rank}}
+ %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
return %0 : vector<3x4xi32>
}
@@ -657,7 +657,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
%c3_idx = arith.constant 3 : index
- // expected-error at +1 {{result rank is less than required for vector rank}}
+ // expected-error at +1 {{failed to create a minor identity map, result rank is less than required for vector rank}}
vector.transfer_write %vec_to_write, %output_memref[%c3_idx, %c3_idx] : vector<3x4xi32>, memref<?xindex>
}
>From f671fe9bde677151a87c5f62a8cd87a419c8de04 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Fri, 4 Apr 2025 00:01:30 +0800
Subject: [PATCH 6/8] [mlir][vector] Fix parser of vector.transfer_read
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 22 ++++++++++++++++++----
mlir/test/Dialect/Vector/invalid.mlir | 8 ++++----
2 files changed, 22 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8126bf0c52d75..d6221f3488632 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4265,8 +4265,15 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
if (!permMap) {
- return parser.emitError(
- typesLoc, "failed to create a minor identity map, source rank is less than required for vector rank");
+ int64_t elementVectorRank = 0;
+ VectorType elementVectorType =
+ llvm::dyn_cast<VectorType>(shapedType.getElementType());
+ if (elementVectorType)
+ elementVectorRank += elementVectorType.getRank();
+ if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+ return parser.emitError(typesLoc,
+ "expected a custom permutation_map when source "
+ "rank is less than required for vector rank");
}
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
@@ -4677,8 +4684,15 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
if (!permMapAttr) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
if (!permMap) {
- return parser.emitError(
- typesLoc, "failed to create a minor identity map, result rank is less than required for vector rank");
+ int64_t elementVectorRank = 0;
+ VectorType elementVectorType =
+ llvm::dyn_cast<VectorType>(shapedType.getElementType());
+ if (elementVectorType)
+ elementVectorRank += elementVectorType.getRank();
+ if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+ return parser.emitError(typesLoc,
+ "expected a custom permutation_map when result "
+ "rank is less than required for vector rank");
}
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 980c4287c9ac6..cb2d97bd17e04 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -527,7 +527,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
%c3 = arith.constant 3 : index
- // expected-error at +1 {{failed to create a minor identity map, source rank is less than required for vector rank}}
+ // expected-error at +1 {{expected a custom permutation_map when source rank is less than required for vector rank}}
%0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
return %0 : vector<3x4xi32>
}
@@ -656,9 +656,9 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
// -----
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
- %c3_idx = arith.constant 3 : index
- // expected-error at +1 {{failed to create a minor identity map, result rank is less than required for vector rank}}
- vector.transfer_write %vec_to_write, %output_memref[%c3_idx, %c3_idx] : vector<3x4xi32>, memref<?xindex>
+ %c3 = arith.constant 3 : index
+ // expected-error at +1 {{expected a custom permutation_map when result rank is less than required for vector rank}}
+ vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xi32>, memref<?xindex>
}
// -----
>From bbe35d61a9f7ea2f53c06305a411ba7a52e18eac Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Tue, 8 Apr 2025 19:38:22 +0800
Subject: [PATCH 7/8] [mlir][vector] Fix parser of vector.transfer_read
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 54 ++++++++++--------------
mlir/test/Dialect/Vector/invalid.mlir | 21 ++++++++-
2 files changed, 42 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d6221f3488632..daf5b0d70d345 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -152,6 +152,11 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType) {
+ int64_t elementVectorRank = 0;
+ VectorType elementVectorType =
+ llvm::dyn_cast<VectorType>(shapedType.getElementType());
+ if (elementVectorType)
+ elementVectorRank += elementVectorType.getRank();
// 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
// TODO: replace once we have 0-d vectors.
if (shapedType.getRank() == 0 &&
@@ -159,15 +164,6 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
return AffineMap::get(
/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, shapedType.getContext()));
- int64_t elementVectorRank = 0;
- VectorType elementVectorType =
- llvm::dyn_cast<VectorType>(shapedType.getElementType());
- if (elementVectorType)
- elementVectorRank += elementVectorType.getRank();
- if (shapedType.getRank() < vectorType.getRank() - elementVectorRank) {
- // Not enough dimensions in the shaped type to form a minor identity map.
- return AffineMap();
- }
return AffineMap::getMinorIdentityMap(
shapedType.getRank(), vectorType.getRank() - elementVectorRank,
shapedType.getContext());
@@ -4263,18 +4259,16 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
Attribute permMapAttr = result.attributes.get(permMapAttrName);
AffineMap permMap;
if (!permMapAttr) {
+ int64_t elementVectorRank = 0;
+ VectorType elementVectorType =
+ llvm::dyn_cast<VectorType>(shapedType.getElementType());
+ if (elementVectorType)
+ elementVectorRank += elementVectorType.getRank();
+ if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+ return parser.emitError(typesLoc,
+ "expected a custom permutation_map when "
+ "rank(source) != rank(destination)");
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
- if (!permMap) {
- int64_t elementVectorRank = 0;
- VectorType elementVectorType =
- llvm::dyn_cast<VectorType>(shapedType.getElementType());
- if (elementVectorType)
- elementVectorRank += elementVectorType.getRank();
- if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
- return parser.emitError(typesLoc,
- "expected a custom permutation_map when source "
- "rank is less than required for vector rank");
- }
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
@@ -4682,18 +4676,16 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
auto permMapAttr = result.attributes.get(permMapAttrName);
AffineMap permMap;
if (!permMapAttr) {
+ int64_t elementVectorRank = 0;
+ VectorType elementVectorType =
+ llvm::dyn_cast<VectorType>(shapedType.getElementType());
+ if (elementVectorType)
+ elementVectorRank += elementVectorType.getRank();
+ if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+ return parser.emitError(typesLoc,
+ "expected a custom permutation_map when "
+ "rank(source) != rank(destination)");
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
- if (!permMap) {
- int64_t elementVectorRank = 0;
- VectorType elementVectorType =
- llvm::dyn_cast<VectorType>(shapedType.getElementType());
- if (elementVectorType)
- elementVectorRank += elementVectorType.getRank();
- if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
- return parser.emitError(typesLoc,
- "expected a custom permutation_map when result "
- "rank is less than required for vector rank");
- }
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index cb2d97bd17e04..05a8ea4d6c71e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -527,13 +527,22 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xi32> {
%c3 = arith.constant 3 : index
- // expected-error at +1 {{expected a custom permutation_map when source rank is less than required for vector rank}}
+ // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
%0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xi32>
return %0 : vector<3x4xi32>
}
// -----
+func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xindex> {
+ %c3 = arith.constant 3 : index
+ // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
+ %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xindex>
+ return %0 : vector<3x4xindex>
+}
+
+// -----
+
func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant 3.0 : f32
@@ -657,12 +666,20 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xi32>, %output_memref: memref<?xindex>) {
%c3 = arith.constant 3 : index
- // expected-error at +1 {{expected a custom permutation_map when result rank is less than required for vector rank}}
+ // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xi32>, memref<?xindex>
}
// -----
+func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref<?xindex>) {
+ %c3 = arith.constant 3 : index
+ // expected-error at +1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
+ vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xindex>, memref<?xindex>
+}
+
+// -----
+
func.func @insert_strided_slice(%a: vector<4x4xf32>, %b: vector<4x8x16xf32>) {
// expected-error at +1 {{expected offsets of same size as destination vector rank}}
%1 = vector.insert_strided_slice %a, %b {offsets = [100], strides = [1, 1]} : vector<4x4xf32> into vector<4x8x16xf32>
>From 2a920e0f9abcf3dd917635d2d97f704a1c6b4624 Mon Sep 17 00:00:00 2001
From: Cedric Meng <14017092+douyixuan at users.noreply.github.com>
Date: Wed, 9 Apr 2025 21:26:43 +0800
Subject: [PATCH 8/8] [mlir][vector] Fix parser of vector.transfer_read
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 27 ++++++++++--------------
1 file changed, 11 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index daf5b0d70d345..e969a7e02ba74 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -150,13 +150,18 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
return false;
}
-AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
- VectorType vectorType) {
- int64_t elementVectorRank = 0;
+static unsigned getRealVectorRank(ShapedType shapedType,
+ VectorType vectorType) {
+ unsigned elementVectorRank = 0;
VectorType elementVectorType =
llvm::dyn_cast<VectorType>(shapedType.getElementType());
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
+ return vectorType.getRank() - elementVectorRank;
+}
+
+AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
+ VectorType vectorType) {
// 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
// TODO: replace once we have 0-d vectors.
if (shapedType.getRank() == 0 &&
@@ -165,7 +170,7 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, shapedType.getContext()));
return AffineMap::getMinorIdentityMap(
- shapedType.getRank(), vectorType.getRank() - elementVectorRank,
+ shapedType.getRank(), getRealVectorRank(shapedType, vectorType),
shapedType.getContext());
}
@@ -4259,12 +4264,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
Attribute permMapAttr = result.attributes.get(permMapAttrName);
AffineMap permMap;
if (!permMapAttr) {
- int64_t elementVectorRank = 0;
- VectorType elementVectorType =
- llvm::dyn_cast<VectorType>(shapedType.getElementType());
- if (elementVectorType)
- elementVectorRank += elementVectorType.getRank();
- if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+ if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
return parser.emitError(typesLoc,
"expected a custom permutation_map when "
"rank(source) != rank(destination)");
@@ -4676,12 +4676,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
auto permMapAttr = result.attributes.get(permMapAttrName);
AffineMap permMap;
if (!permMapAttr) {
- int64_t elementVectorRank = 0;
- VectorType elementVectorType =
- llvm::dyn_cast<VectorType>(shapedType.getElementType());
- if (elementVectorType)
- elementVectorRank += elementVectorType.getRank();
- if (shapedType.getRank() < vectorType.getRank() - elementVectorRank)
+ if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
return parser.emitError(typesLoc,
"expected a custom permutation_map when "
"rank(source) != rank(destination)");
More information about the Mlir-commits
mailing list