[Mlir-commits] [mlir] 226896c - [mlir][linalg] Fix bug in vector transfer hoisting
Matthias Springer
llvmlistbot at llvm.org
Wed Jul 12 07:29:34 PDT 2023
Author: Matthias Springer
Date: 2023-07-12T16:24:07+02:00
New Revision: 226896c3a865fc79c1ee8fe5c7e5f8c8b1a4753f
URL: https://github.com/llvm/llvm-project/commit/226896c3a865fc79c1ee8fe5c7e5f8c8b1a4753f
DIFF: https://github.com/llvm/llvm-project/commit/226896c3a865fc79c1ee8fe5c7e5f8c8b1a4753f.diff
LOG: [mlir][linalg] Fix bug in vector transfer hoisting
Do not hoist vector transfers that do not match exactly. In particular, do not hoist transfers with different vector types. This has led to invalid IR (yielded vector type is different from iter_arg type) in downstream projects.
Differential Revision: https://reviews.llvm.org/D155052
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/test/Dialect/Linalg/hoisting.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 306762b2da5b4e..5f20ea42e49924 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -135,12 +135,14 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
<< "\n");
// Approximate aliasing by checking that:
- // 1. indices are the same,
+ // 1. indices, vector type and permutation map are the same (i.e., the
+ // transfer_read/transfer_write ops are matching),
// 2. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
- if (transferRead.getIndices() != transferWrite.getIndices() &&
- transferRead.getVectorType() == transferWrite.getVectorType())
+ if (transferRead.getIndices() != transferWrite.getIndices() ||
+ transferRead.getVectorType() != transferWrite.getVectorType() ||
+ transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
// TODO: may want to memoize this information for performance but it
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index dadde430e1b22a..5b1bb3fc15e09e 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -722,3 +722,37 @@ transform.sequence failures(propagate) {
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
}
+
+// -----
+
+// The transfers in this test case cannot be hoisted and replaced by a vector
+// iter_arg because they do not match.
+
+// CHECK-LABEL: func.func @non_matching_transfers(
+// CHECK: scf.for {{.*}} {
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+// CHECK: }
+func.func @non_matching_transfers(%m: memref<6x1x7x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1024 = arith.constant 1024 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant dense<5.5> : vector<6x7x32xf32>
+ %cst_0 = arith.constant 0.0 : f32
+ scf.for %iv = %c0 to %c1024 step %c128 {
+ %read = vector.transfer_read %m[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>} : memref<6x1x7x32xf32>, vector<6x7x32xf32>
+ %added = arith.addf %read, %cst : vector<6x7x32xf32>
+ %bc = vector.broadcast %added : vector<6x7x32xf32> to vector<1x6x7x32xf32>
+ %tr = vector.transpose %bc, [1, 0, 2, 3] : vector<1x6x7x32xf32> to vector<6x1x7x32xf32>
+ vector.transfer_write %tr, %m[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<6x1x7x32xf32>, memref<6x1x7x32xf32>
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+}
More information about the Mlir-commits
mailing list