[Mlir-commits] [mlir] 8f27a1f - [mlir] Relax transfer op hoisting on tensor

Thomas Raoux llvmlistbot at llvm.org
Mon Sep 26 11:22:36 PDT 2022


Author: Thomas Raoux
Date: 2022-09-26T18:22:19Z
New Revision: 8f27a1f8655924fe1d5679d5bd0a4d09c0683728

URL: https://github.com/llvm/llvm-project/commit/8f27a1f8655924fe1d5679d5bd0a4d09c0683728
DIFF: https://github.com/llvm/llvm-project/commit/8f27a1f8655924fe1d5679d5bd0a4d09c0683728.diff

LOG: [mlir] Relax transfer op hoisting on tensor

Improve hoisting logic to support cases where the read being hoisted
reads the result of another transfer_write whose indices are disjoint
from those of the write being hoisted.

Differential Revision: https://reviews.llvm.org/D134624

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 6ab8fcd6e28d4..29fe94ea9d149 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -106,8 +106,10 @@ static HoistableRead findMatchingTransferRead(HoistableWrite write,
   if (write.insertSliceOp)
     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: "
                       << *write.insertSliceOp.getOperation() << "\n");
-
-  for (Operation *user : srcTensor.getUsers()) {
+  SmallVector<Operation *> users(srcTensor.getUsers().begin(),
+                                 srcTensor.getUsers().end());
+  while (!users.empty()) {
+    Operation *user = users.pop_back_val();
     LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                       << "\n");
 
@@ -153,6 +155,16 @@ static HoistableRead findMatchingTransferRead(HoistableWrite write,
     if (read && read.getIndices() == write.transferWriteOp.getIndices() &&
         read.getVectorType() == write.transferWriteOp.getVectorType())
       return HoistableRead{read, sliceOp};
+
+    if (isa<vector::TransferWriteOp>(user)) {
+      // If we find a write with disjoint indices recurse through its uses.
+      if (vector::isDisjointTransferIndices(
+              cast<VectorTransferOpInterface>(user),
+              cast<VectorTransferOpInterface>(
+                  write.transferWriteOp.getOperation()))) {
+        users.append(user->getUsers().begin(), user->getUsers().end());
+      }
+    }
   }
   return HoistableRead();
 }

diff  --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 1b1c6a8d2be2f..2b783d144b7bc 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -431,3 +431,43 @@ func.func @hoist_vector_transfer_pairs_tensor_and_slices(
   }
   return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
 }
+
+// -----
+
+// Two read/write pairs on disjoint indices ([%c0, %c0] and [%c0, %c3]) are
+// chained through %w10 inside the loop; both pairs must be hoisted out of it.
+// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+//  CHECK-SAME:   %[[T:.*]]: tensor<?x?xf32>,
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+//   CHECK-DAG:   %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+//       CHECK:   %[[F:.*]]:2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[R3:.*]] = %[[R1]], %[[R2:.*]] = %[[R0]]) -> (vector<2xf32>, vector<2xf32>) {
+//       CHECK:     %[[R4:.*]] = "some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32>
+//       CHECK:     %[[R5:.*]] = "some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32>
+//       CHECK:     scf.yield %[[R5]], %[[R4]] : vector<2xf32>, vector<2xf32>
+//       CHECK:   }
+//       CHECK:   %[[W0:.*]] = vector.transfer_write %[[F]]#1, %[[T]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor<?x?xf32>
+//       CHECK:   %[[W1:.*]] = vector.transfer_write %[[F]]#0, %[[W0]][%[[C0]], %[[C3]]] : vector<2xf32>, tensor<?x?xf32>
+//       CHECK:  return %[[W1]] : tensor<?x?xf32>
+func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+    %tensor: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index) ->
+    (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.0 : f32
+  %1 = scf.for %j = %lb to %ub step %step iter_args(%arg5 = %tensor)
+    -> (tensor<?x?xf32>) {
+    %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+    %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+    %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+    %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
+    %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+    %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
+    scf.yield %w11 : tensor<?x?xf32>
+  }
+  return %1 : tensor<?x?xf32>
+}
+


        


More information about the Mlir-commits mailing list