[Mlir-commits] [mlir] [mlir][vector] Fix scalability issues in drop innermost transfer_read unit dims (PR #92402)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri May 17 08:03:19 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/92402
>From 48806226061c81293e65658b983d4e29d6dd28bb Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 16 May 2024 14:03:24 +0000
Subject: [PATCH 1/3] [mlir][vector] Fix scalability issues in drop innermost
transfer_read unit dims
Previously, this rewrite would drop scalable dimensions and treat `[1]`
(scalable one dim) as a unit dimension. This patch propagates scalable
dimensions and ensures `[1]` is not treated as a unit dimension.
---
.../Vector/Transforms/VectorTransforms.cpp | 10 +++++--
...tor-transfer-collapse-inner-most-dims.mlir | 30 +++++++++++++++++++
2 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 69c497264fd1e..720e638a74b55 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1237,6 +1237,10 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
return failure();
+ auto isUnitDim = [](VectorType type, int dim) {
+ return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
+ };
+
// According to vector.transfer_read/write semantics, the vector can be a
// slice. Thus, we have to offset the check index with `rankDiff` in
// `srcStrides` and source dim sizes.
@@ -1247,8 +1251,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
// It can be folded only if they are 1 and the stride is 1.
int dim = vectorType.getRank() - i - 1;
if (srcStrides[dim + rankDiff] != 1 ||
- srcType.getDimSize(dim + rankDiff) != 1 ||
- vectorType.getDimSize(dim) != 1)
+ srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
break;
result++;
}
@@ -1292,7 +1295,8 @@ class DropInnerMostUnitDimsTransferRead
auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
- targetType.getElementType());
+ targetType.getElementType(),
+ targetType.getScalableDims().drop_back(dimsToDrop));
auto loc = readOp.getLoc();
SmallVector<OpFoldResult> sizes =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 477755b66c020..ddfae5590e4c4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -174,3 +174,33 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
// The inner most unit dims can not be dropped if the strides are not ones.
// CHECK: func.func @non_unit_strides
// CHECK-NOT: memref.subview
+
+// -----
+
+func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<[4]x1xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
+ return %4 : vector<[4]x1xf32>
+}
+// CHECK: func.func @leading_scalable_dimension_transfer_read
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]][%[[IDX]]], %{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32>
+// CHECK: return %[[CAST]]
+
+// -----
+
+// Negative test: [1] (scalable 1) is _not_ a unit dimension.
+func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<4x[1]xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
+ return %4 : vector<4x[1]xf32>
+}
+// CHECK: func.func @trailing_scalable_one_dim_transfer_read
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32>
+// CHECK-NOT: vector.shape_cast
>From 26255c1ccee57aa15e76bf453e9189298efd83f8 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 17 May 2024 14:48:46 +0000
Subject: [PATCH 2/3] Fixups
---
.../Vector/Transforms/VectorTransforms.cpp | 3 +-
...tor-transfer-collapse-inner-most-dims.mlir | 40 ++++++++++++++++---
2 files changed, 36 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 720e638a74b55..f29eba90c3ceb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1382,7 +1382,8 @@ class DropInnerMostUnitDimsTransferWrite
auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
- targetType.getElementType());
+ targetType.getElementType(),
+ targetType.getScalableDims().drop_back(dimsToDrop));
Location loc = writeOp.getLoc();
SmallVector<OpFoldResult> sizes =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index ddfae5590e4c4..e5a609f238803 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -179,9 +179,9 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<[4]x1xf32> {
%c0 = arith.constant 0 : index
- %cst_0 = arith.constant 0.000000e+00 : f32
- %4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
- return %4 : vector<[4]x1xf32>
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %dest[%index, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
+ return %0 : vector<[4]x1xf32>
}
// CHECK: func.func @leading_scalable_dimension_transfer_read
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
@@ -196,11 +196,39 @@ func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>, %i
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<4x[1]xf32> {
%c0 = arith.constant 0 : index
- %cst_0 = arith.constant 0.000000e+00 : f32
- %4 = vector.transfer_read %dest[%index, %c0], %cst_0 {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
- return %4 : vector<4x[1]xf32>
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %dest[%index, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
+ return %0 : vector<4x[1]xf32>
}
// CHECK: func.func @trailing_scalable_one_dim_transfer_read
// CHECK-NOT: vector.shape_cast
// CHECK: vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32>
// CHECK-NOT: vector.shape_cast
+
+// -----
+
+func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>, %index: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32>
+ return
+}
+// CHECK: func.func @leading_scalable_dimension_transfer_write
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
+// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]][%[[IDX]]] {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
+
+// -----
+
+// Negative test: [1] (scalable 1) is _not_ a unit dimension.
+func.func @trailing_scalable_one_dim_transfer_write(%dest : memref<24x1xf32>, %vec: vector<4x[1]xf32>, %index: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<4x[1]xf32>, memref<24x1xf32>
+ return
+}
+// CHECK: func.func @trailing_scalable_one_dim_transfer_write
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
+// CHECK-NOT: vector.shape_cast
>From bc8622e2d6ab6df448dacd06a765e4e05877377e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 17 May 2024 15:02:29 +0000
Subject: [PATCH 3/3] Fixups
---
...ctor-transfer-collapse-inner-most-dims.mlir | 18 ++++++++----------
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index e5a609f238803..b4cb640108bae 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -177,27 +177,26 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
// -----
-func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<[4]x1xf32> {
+func.func @leading_scalable_dimension_transfer_read(%dest : memref<24x1xf32>) -> vector<[4]x1xf32> {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
- %0 = vector.transfer_read %dest[%index, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
+ %0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<[4]x1xf32>
return %0 : vector<[4]x1xf32>
}
// CHECK: func.func @leading_scalable_dimension_transfer_read
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]][%[[IDX]]], %{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32>
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32>
// CHECK: return %[[CAST]]
// -----
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
-func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>, %index: index) -> vector<4x[1]xf32> {
+func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>) -> vector<4x[1]xf32> {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
- %0 = vector.transfer_read %dest[%index, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
+ %0 = vector.transfer_read %dest[%c0, %c0], %pad {in_bounds = [true, true]} : memref<24x1xf32>, vector<4x[1]xf32>
return %0 : vector<4x[1]xf32>
}
// CHECK: func.func @trailing_scalable_one_dim_transfer_read
@@ -207,18 +206,17 @@ func.func @trailing_scalable_one_dim_transfer_read(%dest : memref<24x1xf32>, %in
// -----
-func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>, %index: index) {
+func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>) {
%c0 = arith.constant 0 : index
- vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32>
+ vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32>
return
}
// CHECK: func.func @leading_scalable_dimension_transfer_write
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
-// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]][%[[IDX]]] {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
+// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
// -----
More information about the Mlir-commits
mailing list