[Mlir-commits] [mlir] 8f27a1f - [mlir] Relax transfer op hoisting on tensor
Thomas Raoux
llvmlistbot at llvm.org
Mon Sep 26 11:22:36 PDT 2022
Author: Thomas Raoux
Date: 2022-09-26T18:22:19Z
New Revision: 8f27a1f8655924fe1d5679d5bd0a4d09c0683728
URL: https://github.com/llvm/llvm-project/commit/8f27a1f8655924fe1d5679d5bd0a4d09c0683728
DIFF: https://github.com/llvm/llvm-project/commit/8f27a1f8655924fe1d5679d5bd0a4d09c0683728.diff
LOG: [mlir] Relax transfer op hoisting on tensor
Improve the hoisting logic to support cases where the read being hoisted
reads the result of a transfer_write whose indices are disjoint from those
of the write being hoisted.
Differential Revision: https://reviews.llvm.org/D134624
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/test/Dialect/Linalg/hoisting.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 6ab8fcd6e28d4..29fe94ea9d149 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -106,8 +106,10 @@ static HoistableRead findMatchingTransferRead(HoistableWrite write,
if (write.insertSliceOp)
LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: "
<< *write.insertSliceOp.getOperation() << "\n");
-
- for (Operation *user : srcTensor.getUsers()) {
+ SmallVector<Operation *> users(srcTensor.getUsers().begin(),
+ srcTensor.getUsers().end());
+ while (!users.empty()) {
+ Operation *user = users.pop_back_val();
LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
<< "\n");
@@ -153,6 +155,16 @@ static HoistableRead findMatchingTransferRead(HoistableWrite write,
if (read && read.getIndices() == write.transferWriteOp.getIndices() &&
read.getVectorType() == write.transferWriteOp.getVectorType())
return HoistableRead{read, sliceOp};
+
+ if (isa<vector::TransferWriteOp>(user)) {
+ // If we find a write with disjoint indices recurse through its uses.
+ if (vector::isDisjointTransferIndices(
+ cast<VectorTransferOpInterface>(user),
+ cast<VectorTransferOpInterface>(
+ write.transferWriteOp.getOperation()))) {
+ users.append(user->getUsers().begin(), user->getUsers().end());
+ }
+ }
}
return HoistableRead();
}
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 1b1c6a8d2be2f..2b783d144b7bc 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -431,3 +431,41 @@ func.func @hoist_vector_transfer_pairs_tensor_and_slices(
}
return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+// CHECK-SAME: %[[T:.*]]: tensor<?x?xf32>,
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK-DAG: %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// CHECK: %[[F:.*]]:2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[R3:.*]] = %[[R1:.*]], %[[R2:.*]] = %[[R0]]) -> (vector<2xf32>, vector<2xf32>) {
+// CHECK: %[[R4:.*]] = "some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: %[[R5:.*]] = "some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32>
+// CHECK: scf.yield %[[R5]], %[[R4]] : vector<2xf32>, vector<2xf32>
+// CHECK: }
+// CHECK: %[[W0:.*]] = vector.transfer_write %[[F]]#1, %[[T]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor<?x?xf32>
+// CHECK: %[[W1:.*]] = vector.transfer_write %[[F]]#0, %[[W0]][%[[C0]], %[[C3]]] : vector<2xf32>, tensor<?x?xf32>
+// CHECK: return %[[W1]] : tensor<?x?xf32>
+func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
+ %tensor: tensor<?x?xf32>,
+ %val: index, %lb : index, %ub : index, %step: index) ->
+ (tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.0 : f32
+ %1 = scf.for %j = %lb to %ub step %step iter_args(%arg5 = %tensor)
+ -> (tensor<?x?xf32>) {
+ %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+ %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+ %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+ %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
+ %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+ %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
+ scf.yield %w11 : tensor<?x?xf32>
+ }
+ return %1 : tensor<?x?xf32>
+}
+
More information about the Mlir-commits
mailing list