[Mlir-commits] [mlir] [mlir][vector] Update docs + add tests (PR #137144)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 24 02:13:42 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
This is a small follow-on for #<!-- -->133721:
* Renamed `getRealVectorRank` as `getEffectiveVectorRankForXferOp` (to
emphasise that this method was written specifically for transfer Ops).
* Marginally tweaked the description for
`getEffectiveVectorRankForXferOp` (mostly to highlight the two edge
cases being covered).
* Added tests for cases when the element type (of the shaped type) is a
vector.
* Unified the naming (and the order) of arguments in tests with the
surrounding tests (e.g. `%vec_to_write` -> `%arg1`). Mostly for
consistency (it would be good to use self-documenting names like
`%vec_to_write` throughout).
---
Full diff: https://github.com/llvm/llvm-project/pull/137144.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+20-14)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+21-4)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3fee1e949aeed..49dd433597e8c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -151,29 +151,32 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
return false;
}
-/// Returns the number of dimensions of the `shapedType` that participate in the
-/// vector transfer, effectively the rank of the vector dimensions within the
-/// `shapedType`. This is calculated by taking the rank of the `vectorType`
-/// being transferred and subtracting the rank of the `shapedType`'s element
-/// type if it's also a vector.
+/// Returns the effective rank of the vector to read/write for Xfer Ops
///
-/// This is used to determine the number of minor dimensions for identity maps
-/// in vector transfers.
+/// When the element type of the shaped type is _a scalar_, this will simply
+/// return the rank of the vector ( the result for xfer_read or the value to
+/// store for xfer_write).
///
-/// For example, given a transfer operation involving `shapedType` and
-/// `vectorType`:
+/// When the element type of the base shaped type is _a vector_, returns the
+/// difference between the original vector type and the element type of the
+/// shaped type.
///
+/// EXAMPLE 1 (element type is _a scalar_):
/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
/// - shapedType.getElementType() = f32 (rank 0)
/// - vectorType.getRank() = 2
/// - Result = 2 - 0 = 2
///
+/// EXAMPLE 2 (element type is _a vector_):
/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
/// - shapedType.getElementType() = vector<20xf32> (rank 1)
/// - vectorType.getRank() = 1
/// - Result = 1 - 1 = 0
-static unsigned getRealVectorRank(ShapedType shapedType,
- VectorType vectorType) {
+///
+/// This is used to determine the number of minor dimensions for identity maps
+/// in vector transfer Ops.
+static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
+ VectorType vectorType) {
unsigned elementVectorRank = 0;
VectorType elementVectorType =
llvm::dyn_cast<VectorType>(shapedType.getElementType());
@@ -192,7 +195,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
/*numDims=*/0, /*numSymbols=*/0,
getAffineConstantExpr(0, shapedType.getContext()));
return AffineMap::getMinorIdentityMap(
- shapedType.getRank(), getRealVectorRank(shapedType, vectorType),
+ shapedType.getRank(),
+ getEffectiveVectorRankForXferOp(shapedType, vectorType),
shapedType.getContext());
}
@@ -4260,7 +4264,8 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
Attribute permMapAttr = result.attributes.get(permMapAttrName);
AffineMap permMap;
if (!permMapAttr) {
- if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
+ if (shapedType.getRank() <
+ getEffectiveVectorRankForXferOp(shapedType, vectorType))
return parser.emitError(typesLoc,
"expected a custom permutation_map when "
"rank(source) != rank(destination)");
@@ -4679,7 +4684,8 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
auto permMapAttr = result.attributes.get(permMapAttrName);
AffineMap permMap;
if (!permMapAttr) {
- if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
+ if (shapedType.getRank() <
+ getEffectiveVectorRankForXferOp(shapedType, vectorType))
return parser.emitError(typesLoc,
"expected a custom permutation_map when "
"rank(source) != rank(destination)");
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 19096f0e4c895..349a58d4eb4e4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -525,15 +525,24 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
// -----
-func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xindex> {
+func.func @test_vector.transfer_read(%arg0: memref<?xindex>) -> vector<3x4xindex> {
%c3 = arith.constant 3 : index
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
- %0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xindex>
+ %0 = vector.transfer_read %arg0[%c3], %c3 : memref<?xindex>, vector<3x4xindex>
return %0 : vector<3x4xindex>
}
// -----
+func.func @test_vector.transfer_write(%arg0: memref<?xvector<2xindex>>) {
+ %c3 = arith.constant 3 : index
+ // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
+ %0 = vector.transfer_read %arg0[%c3], %c3 : memref<?xvector<2xindex>>, vector<2x3x4xindex>
+ return %0 : vector<2x3x4xindex>
+}
+
+// -----
+
func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant 3.0 : f32
@@ -655,10 +664,18 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
// -----
-func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref<?xindex>) {
+func.func @test_vector.transfer_write(%arg0: memref<?xindex>, %arg1: vector<3x4xindex>) {
+ %c3 = arith.constant 3 : index
+ // expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
+ vector.transfer_write %arg1, %arg0[%c3, %c3] : vector<3x4xindex>, memref<?xindex>
+}
+
+// -----
+
+func.func @test_vector.transfer_write(%arg0: memref<?xvector<2xindex>>, %arg1: vector<2x3x4xindex>) {
%c3 = arith.constant 3 : index
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
- vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xindex>, memref<?xindex>
+ vector.transfer_write %arg1, %arg0[%c3, %c3] : vector<2x3x4xindex>, memref<?xvector<2xindex>>
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/137144
More information about the Mlir-commits
mailing list