[Mlir-commits] [mlir] 226896c - [mlir][linalg] Fix bug in vector transfer hoisting

Matthias Springer llvmlistbot at llvm.org
Wed Jul 12 07:29:34 PDT 2023


Author: Matthias Springer
Date: 2023-07-12T16:24:07+02:00
New Revision: 226896c3a865fc79c1ee8fe5c7e5f8c8b1a4753f

URL: https://github.com/llvm/llvm-project/commit/226896c3a865fc79c1ee8fe5c7e5f8c8b1a4753f
DIFF: https://github.com/llvm/llvm-project/commit/226896c3a865fc79c1ee8fe5c7e5f8c8b1a4753f.diff

LOG: [mlir][linalg] Fix bug in vector transfer hoisting

Do not hoist vector transfers that do not match exactly. In particular, do not hoist transfers with different vector types. This has led to invalid IR (yielded vector type is different from iter_arg type) in downstream projects.

Differential Revision: https://reviews.llvm.org/D155052

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 306762b2da5b4e..5f20ea42e49924 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -135,12 +135,14 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
                         << "\n");
 
       // Approximate aliasing by checking that:
-      //   1. indices are the same,
+      //   1. indices, vector type and permutation map are the same (i.e., the
+      //      transfer_read/transfer_write ops are matching),
       //   2. no other operations in the loop access the same memref except
       //      for transfer_read/transfer_write accessing statically disjoint
       //      slices.
-      if (transferRead.getIndices() != transferWrite.getIndices() &&
-          transferRead.getVectorType() == transferWrite.getVectorType())
+      if (transferRead.getIndices() != transferWrite.getIndices() ||
+          transferRead.getVectorType() != transferWrite.getVectorType() ||
+          transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
       // TODO: may want to memoize this information for performance but it

diff  --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index dadde430e1b22a..5b1bb3fc15e09e 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -722,3 +722,37 @@ transform.sequence failures(propagate) {
   transform.structured.hoist_redundant_vector_transfers %0
     : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+// The transfers in this test case cannot be hoisted and replaced by a vector
+// iter_arg because they do not match.
+
+// CHECK-LABEL:  func.func @non_matching_transfers(
+//       CHECK:    scf.for {{.*}} {
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+//       CHECK:    }
+func.func @non_matching_transfers(%m: memref<6x1x7x32xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1024 = arith.constant 1024 : index
+  %c128 = arith.constant 128 : index
+  %cst = arith.constant dense<5.5> : vector<6x7x32xf32>
+  %cst_0 = arith.constant 0.0 : f32
+  scf.for %iv = %c0 to %c1024 step %c128 {
+    %read = vector.transfer_read %m[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>} : memref<6x1x7x32xf32>, vector<6x7x32xf32>
+    %added = arith.addf %read, %cst : vector<6x7x32xf32>
+    %bc = vector.broadcast %added : vector<6x7x32xf32> to vector<1x6x7x32xf32>
+    %tr = vector.transpose %bc, [1, 0, 2, 3] : vector<1x6x7x32xf32> to vector<6x1x7x32xf32>
+    vector.transfer_write %tr, %m[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<6x1x7x32xf32>, memref<6x1x7x32xf32>
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!transform.any_op) -> !transform.any_op
+}


        


More information about the Mlir-commits mailing list