[Mlir-commits] [mlir] [mlir][vector] Prevent incorrect vector.transfer_{read|write} hoisting (PR #66930)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Sep 22 00:33:07 PDT 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/66930
>From 9aae59ad4c4feebb9f78fd5d8efc6e63e31484e2 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 20 Sep 2023 17:15:09 +0000
Subject: [PATCH] [mlir][vector] Prevent incorrect vector.transfer_{read|write}
hoisting
At the moment, `hoistRedundantVectorTransfers` would hoist the
`vector.transfer_read`/`vector.transfer_write` pair in this function [1]:
```mlir
func.func @no_hoisting_write_to_memref(%rhs: i32, %arg1: vector<1xi32>) {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c20 = arith.constant 20 : index
%alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32>
%cast = memref.cast %alloca : memref<1x1x2xi32> to memref<1x1x2xi32>
%collapsed_1 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
scf.for %_ = %c0 to %c20 step %c4 {
%collapsed_2 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
%lhs = vector.transfer_read %collapsed_1[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
%acc = vector.transfer_read %collapsed_2[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
%op = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<1xi32>, i32
vector.transfer_write %op, %collapsed_1[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32>
}
return
}
```
as follows:
```mlir
func.func @no_hoisting_write_to_memref(%arg0: i32, %arg1: vector<1xi32>) {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c20 = arith.constant 20 : index
%alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32>
%collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
%collapse_shape_0 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
%0 = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
%1 = vector.transfer_read %collapse_shape_0[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
%2 = scf.for %arg2 = %c0 to %c20 step %c4 iter_args(%arg3 = %0) -> (vector<1xi32>) {
%3 = vector.outerproduct %arg3, %arg0, %1 {kind = #vector.kind<add>} : vector<1xi32>, i32
scf.yield %3 : vector<1xi32>
}
vector.transfer_write %2, %collapse_shape[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32>
return
}
```
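This is not safe. While one argument of `vector.outerproduct` (`%lhs`
from the original loop) is correctly forwarded via `iter_args`,
the other one (`%acc` from the original loop) is not: it is now read only
once, before the loop, even though `%collapsed_2` aliases `%collapsed_1`,
which the loop writes to on every iteration.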
This patch disables hoisting when the source of the "candidate"
`vector.transfer_read` is defined by `memref.collapse_shape` (and may
therefore alias some other `memref`). A more generic approach would be to
make sure that all values are correctly forwarded via `iter_args`, but that
would require alias analysis.
[1] This example is based on https://github.com/openxla/iree/issues/14994.
---
.../Dialect/Linalg/Transforms/Hoisting.cpp | 8 +++
mlir/test/Dialect/Linalg/hoisting.mlir | 53 +++++++++++++++++--
2 files changed, 57 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 7c6639304d97c58..31cb9c010e00a93 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -152,6 +152,14 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
+ // When the source of transfer_read aliases, the following dominance
+ // analysis might not be sufficient.
+ // TODO: There might be other, similar cases missing here (i.e. other
+ // Memref Ops).
+ auto source = transferRead.getSource();
+ if (source.getDefiningOp<memref::CollapseShapeOp>())
+ return WalkResult::advance();
+
// TODO: may want to memoize this information for performance but it
// likely gets invalidated often.
DominanceInfo dom(loop);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index e25914620726b9b..7d0c3648c344b1d 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -765,10 +765,10 @@ transform.sequence failures(propagate) {
// CHECK-LABEL: func.func @no_hoisting_collapse_shape
// CHECK: scf.for {{.*}} {
-// CHECK: vector.transfer_write
-// CHECK: vector.transfer_read
-// CHECK: vector.transfer_write
-// CHECK: }
+// CHECK: vector.transfer_write {{.*}} : vector<4xi32>, memref<4xi32>
+// CHECK-NEXT: vector.transfer_read {{.*}} : memref<1x4x1xi32>, vector<1x4x1xi32>
+// CHECK-NEXT: vector.transfer_write {{.*}} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
+// CHECK-NEXT: }
func.func @no_hoisting_collapse_shape(%in_0: memref<1x20x1xi32>, %1: memref<9x1xi32>, %vec: vector<4xi32>) {
%c0_i32 = arith.constant 0 : i32
@@ -827,3 +827,48 @@ transform.sequence failures(propagate) {
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
}
+
+// -----
+
+// Regression test - hoisting the following `vector.transfer_{read|write}` pair
+// would not be safe:
+// %lhs = vector.transfer_read %collapsed_1[%c0]
+// vector.transfer_write %op, %collapsed_1[%c0]
+// That's because the following `vector.transfer_read` reads from the same
+// memory (i.e. `%collapsed_1` and `%collapsed_2` alias):
+// %acc = vector.transfer_read %collapsed_2[%c0]
+
+// CHECK-LABEL: func.func @no_hoisting_write_to_memref
+// CHECK: scf.for {{.*}} {
+// CHECK: vector.transfer_read {{.*}} : memref<2xi32>, vector<1xi32>
+// CHECK-NEXT: vector.transfer_read {{.*}} : memref<2xi32>, vector<1xi32>
+// CHECK-NEXT: vector.outerproduct {{.*}} : vector<1xi32>, i32
+// CHECK-NEXT: vector.transfer_write {{.*}} : vector<1xi32>, memref<2xi32>
+// CHECK-NEXT: }
+
+func.func @no_hoisting_write_to_memref(%rhs: i32, %arg1: vector<1xi32>) {
+ %c0_i32 = arith.constant 0 : i32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c20 = arith.constant 20 : index
+ %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32>
+ %cast = memref.cast %alloca : memref<1x1x2xi32> to memref<1x1x2xi32>
+ %collapsed_1 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
+ scf.for %_ = %c0 to %c20 step %c4 {
+ %collapsed_2 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
+ %lhs = vector.transfer_read %collapsed_1[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
+ %acc = vector.transfer_read %collapsed_2[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
+ %op = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<1xi32>, i32
+ vector.transfer_write %op, %collapsed_1[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32>
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+}
More information about the Mlir-commits
mailing list