[Mlir-commits] [mlir] 6d5aeb0 - [mlir][linalg] Improve aliasing approximation for hoisting transfer read/write

Thomas Raoux llvmlistbot at llvm.org
Fri Jul 10 15:09:34 PDT 2020


Author: Thomas Raoux
Date: 2020-07-10T14:55:04-07:00
New Revision: 6d5aeb0dceebd4c4d4ac9d9d78919362db28f081

URL: https://github.com/llvm/llvm-project/commit/6d5aeb0dceebd4c4d4ac9d9d78919362db28f081
DIFF: https://github.com/llvm/llvm-project/commit/6d5aeb0dceebd4c4d4ac9d9d78919362db28f081.diff

LOG: [mlir][linalg] Improve aliasing approximation for hoisting transfer read/write

Improve the logic that decides whether it is safe to hoist a vector transfer
read/write pair out of the loop. The new logic prevents hoisting if the loop
contains any unknown access to the memref, no matter where that access is.
For other transfer read/write operations in the loop, check whether we can
prove that they access disjoint memory and, if so, ignore them.

Differential Revision: https://reviews.llvm.org/D83538
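
To make the new rule concrete, here is a small illustrative sketch (not part
of the commit; the function name and the "some_use" op are made up, in the
same style as the tests touched below). The transfer_read/transfer_write pair
at [%c0, %c0] can be hoisted even though the loop contains another write to
the same memref: that write starts at column 4 and both transfers move a
vector<4xf32>, so the accessed intervals provably do not overlap. A second
sketch after the diff shows a case where hoisting is blocked.

  func @hoistable_sketch(%m: memref<?x?xf32>, %lb: index, %ub: index,
                         %step: index) {
    %c0 = constant 0 : index
    %c4 = constant 4 : index
    %cst = constant 0.0 : f32
    scf.for %i = %lb to %ub step %step {
      // This read/write pair at [%c0, %c0] is the hoisting candidate.
      %r = vector.transfer_read %m[%c0, %c0], %cst : memref<?x?xf32>, vector<4xf32>
      %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
      vector.transfer_write %u, %m[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
      // Same memref, but |4 - 0| >= 4 (the vector size), so this access is
      // provably disjoint from the pair above and does not block hoisting.
      vector.transfer_write %u, %m[%c0, %c4] : vector<4xf32>, memref<?x?xf32>
    }
    return
  }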

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 485f5131ac01..cb7540b46cf8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -80,6 +80,42 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
   }
 }
 
+/// Return true if we can prove that the transfer operations access disjoint
+/// memory.
+template <typename TransferTypeA, typename TransferTypeB>
+static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) {
+  if (transferA.memref() != transferB.memref())
+    return false;
+  // For simplicity only look at transfer of same type.
+  if (transferA.getVectorType() != transferB.getVectorType())
+    return false;
+  unsigned rankOffset = transferA.getLeadingMemRefRank();
+  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
+    auto indexA = transferA.indices()[i].template getDefiningOp<ConstantOp>();
+    auto indexB = transferB.indices()[i].template getDefiningOp<ConstantOp>();
+    // If any of the indices are dynamic we cannot prove anything.
+    if (!indexA || !indexB)
+      continue;
+
+    if (i < rankOffset) {
+      // For dimensions used as an index, if we can prove that the indices are
+      // different we know we are accessing disjoint slices.
+      if (indexA.getValue().template cast<IntegerAttr>().getInt() !=
+          indexB.getValue().template cast<IntegerAttr>().getInt())
+        return true;
+    } else {
+      // For this dimension, we slice a part of the memref we need to make sure
+      // the intervals accessed don't overlap.
+      int64_t distance =
+          std::abs(indexA.getValue().template cast<IntegerAttr>().getInt() -
+                   indexB.getValue().template cast<IntegerAttr>().getInt());
+      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
+        return true;
+    }
+  }
+  return false;
+}
+
 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
   bool changed = true;
   while (changed) {
@@ -129,10 +165,11 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
 
       // Approximate aliasing by checking that:
       //   1. indices are the same,
-      //   2. no other use either dominates the transfer_read or is dominated
-      //   by the transfer_write (i.e. aliasing between the write and the read
-      //   across the loop).
-      if (transferRead.indices() != transferWrite.indices())
+      //   2. no other operations in the loop access the same memref except
+      //      for transfer_read/transfer_write accessing statically disjoint
+      //      slices.
+      if (transferRead.indices() != transferWrite.indices() &&
+          transferRead.getVectorType() == transferWrite.getVectorType())
         return WalkResult::advance();
 
       // TODO: may want to memoize this information for performance but it
@@ -140,11 +177,26 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
       DominanceInfo dom(loop);
       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
         return WalkResult::advance();
-      for (auto &use : transferRead.memref().getUses())
-        if (dom.properlyDominates(use.getOwner(),
-                                  transferRead.getOperation()) ||
-            dom.properlyDominates(transferWrite, use.getOwner()))
+      for (auto &use : transferRead.memref().getUses()) {
+        if (!dom.properlyDominates(loop, use.getOwner()))
+          continue;
+        if (use.getOwner() == transferRead.getOperation() ||
+            use.getOwner() == transferWrite.getOperation())
+          continue;
+        if (auto transferWriteUse =
+                dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
+          if (!isDisjoint(transferWrite, transferWriteUse))
+            return WalkResult::advance();
+        } else if (auto transferReadUse =
+                       dyn_cast<vector::TransferReadOp>(use.getOwner())) {
+          if (!isDisjoint(transferWrite, transferReadUse))
+            return WalkResult::advance();
+        } else {
+          // Unknown use, we cannot prove that it doesn't alias with the
+          // transferRead/transferWrite operations.
           return WalkResult::advance();
+        }
+      }
 
       // Hoist read before.
       if (failed(loop.moveOutOfLoop({transferRead})))

diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 54ab1dcf07c7..4a6fca554a09 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -132,18 +132,101 @@ func @hoist_vector_transfer_pairs(
       %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
       "some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
       %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
+      %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
+      "some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
       %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
       %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
       %u2 = "some_use"(%memref2) : (memref<?x?xf32>) -> vector<3xf32>
       %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
       %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
+      %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
       vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
       vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
       vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
       vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
       vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
+      vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
       "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
     }
   }
   return
 }
+
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint(
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
+//  VECTOR_TRANSFERS-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
+//  VECTOR_TRANSFERS-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
+//  VECTOR_TRANSFERS-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
+//  VECTOR_TRANSFERS-SAME:   %[[RANDOM:[a-zA-Z0-9]*]]: index,
+//  VECTOR_TRANSFERS-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
+func @hoist_vector_transfer_pairs_disjoint(
+    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
+    %memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index, %lb : index, %ub : index,
+    %step: index, %random_index : index, %cmp: i1) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %cst = constant 0.0 : f32
+
+// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
+//  VECTOR_TRANSFERS-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
+//  VECTOR_TRANSFERS-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS:     vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS:     vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS:     scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS:   }
+// VECTOR_TRANSFERS:   scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r00 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<2xf32>
+      %r01 = vector.transfer_read %memref1[%c0, %c1], %cst: memref<?x?xf32>, vector<2xf32>
+      %r20 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
+      %r21 = vector.transfer_read %memref2[%c0, %c3], %cst: memref<?x?xf32>, vector<3xf32>
+      %r30 = vector.transfer_read %memref3[%c0, %random_index], %cst: memref<?x?xf32>, vector<4xf32>
+      %r31 = vector.transfer_read %memref3[%c1, %random_index], %cst: memref<?x?xf32>, vector<4xf32>
+      %r10 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
+      %r11 = vector.transfer_read %memref0[%random_index, %random_index], %cst: memref<?x?xf32>, vector<2xf32>
+      %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+      %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+      %u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
+      %u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
+      %u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
+      %u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
+      %u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
+      %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
+      vector.transfer_write %u00, %memref1[%c0, %c0] : vector<2xf32>, memref<?x?xf32>
+      vector.transfer_write %u01, %memref1[%c0, %c1] : vector<2xf32>, memref<?x?xf32>
+      vector.transfer_write %u20, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
+      vector.transfer_write %u21, %memref2[%c0, %c3] : vector<3xf32>, memref<?x?xf32>
+      vector.transfer_write %u30, %memref3[%c0, %random_index] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u31, %memref3[%c1, %random_index] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u10, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
+      vector.transfer_write %u11, %memref0[%random_index, %random_index] : vector<2xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
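
As a complementary sketch (again not from the commit; names are made up): if
any access to the memref inside the loop uses an index that is not a constant,
isDisjoint cannot prove anything about it, so the transfer pair conservatively
stays in the loop.

  func @blocked_sketch(%m: memref<?x?xf32>, %lb: index, %ub: index,
                       %step: index) {
    %c0 = constant 0 : index
    %cst = constant 0.0 : f32
    scf.for %i = %lb to %ub step %step {
      %r = vector.transfer_read %m[%c0, %c0], %cst : memref<?x?xf32>, vector<4xf32>
      %u = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
      vector.transfer_write %u, %m[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
      // %i is not defined by a ConstantOp, so this write cannot be proven
      // disjoint from the [%c0, %c0] slice; it blocks hoisting of the pair.
      vector.transfer_write %u, %m[%c0, %i] : vector<4xf32>, memref<?x?xf32>
    }
    return
  }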
