[Mlir-commits] [mlir] 6953cf6 - [mlir][Linalg] Add a hoistRedundantVectorTransfers helper function

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jun 5 03:54:46 PDT 2020


Author: Nicolas Vasilache
Date: 2020-06-05T06:50:24-04:00
New Revision: 6953cf65024395508c464dc78c90b158b3241a3a

URL: https://github.com/llvm/llvm-project/commit/6953cf65024395508c464dc78c90b158b3241a3a
DIFF: https://github.com/llvm/llvm-project/commit/6953cf65024395508c464dc78c90b158b3241a3a.diff

LOG: [mlir][Linalg] Add a hoistRedundantVectorTransfers helper function

This revision adds a helper function to hoist vector.transfer_read /
vector.transfer_write pairs out of immediately enclosing scf::ForOp
iteratively, if the following conditions are true:
   1. The two ops access the same memref with the same indices.
   2. All operands of the transfer_read are defined outside the enclosing
   scf::ForOp.
   3. No use of the memref either dominates the transfer_read or is
   dominated by the transfer_write (i.e. no aliasing between the write and
   the read across the loop).

To improve hoisting opportunities, the `moveLoopInvariantCode` helper is first
run on the candidate loop above which to hoist. Hoisting the transfers makes
the scf::ForOp yield the value that originally transited through memory.

This revision additionally exposes `moveLoopInvariantCode` as a helper in
LoopUtils.h and updates SliceAnalysis to also follow the uses of scf::ForOp
results, which allows hoisting across multiple scf::ForOps.

Differential Revision: https://reviews.llvm.org/D81199

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
    mlir/include/mlir/Transforms/LoopUtils.h
    mlir/lib/Analysis/SliceAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir
    mlir/test/lib/Transforms/TestLinalgHoisting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 283e68e8a76b..32693555ff40 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -16,11 +16,25 @@ namespace linalg {
 
 /// Hoist alloc/dealloc pairs and alloca op out of immediately enclosing
 /// scf::ForOp if both conditions are true:
-///   1. all operands are defined outside the loop.
-///   2. all uses are ViewLikeOp or DeallocOp.
+///   1. All operands are defined outside the loop.
+///   2. All uses are ViewLikeOp or DeallocOp.
 // TODO: generalize on a per-need basis.
 void hoistViewAllocOps(FuncOp func);
 
+/// Hoist vector.transfer_read/vector.transfer_write pairs out of immediately
+/// enclosing scf::ForOp, iterating until no more pairs can be hoisted. A pair
+/// is considered when the write (transitively) consumes the read's value and:
+///   1. The two ops access the same memref with the same indices.
+///   2. All operands of the transfer_read are defined outside the scf::ForOp.
+///   3. No use of the memref dominates the transfer_read or is dominated by
+///   the transfer_write (i.e. no aliasing between the write and the read
+///   across the loop).
+/// To expose more hoisting opportunities, loop-invariant code is first moved
+/// out of the candidate loop with `moveLoopInvariantCode`. Hoisting a pair
+/// makes the scf::ForOp yield the value that originally transited via memory.
+// TODO: generalize on a per-need basis.
+void hoistRedundantVectorTransfers(FuncOp func);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
index b3f23aaa4397..5a0d46f5ba57 100644
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -22,10 +22,11 @@
 namespace mlir {
 class AffineForOp;
 class FuncOp;
+class LoopLikeOpInterface;
+struct MemRefRegion;
 class OpBuilder;
 class Value;
 class ValueRange;
-struct MemRefRegion;
 
 namespace scf {
 class ForOp;
@@ -294,6 +295,9 @@ LogicalResult
 separateFullTiles(MutableArrayRef<AffineForOp> nest,
                   SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
 
+/// Hoist loop-invariant operations out of the body of `looplike`.
+LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOP_UTILS_H

diff  --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 5b630be5e9a5..e0c828fb55c1 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -41,20 +41,24 @@ static void getForwardSliceImpl(Operation *op,
   }
 
   if (auto forOp = dyn_cast<AffineForOp>(op)) {
-    for (auto *ownerInst : forOp.getInductionVar().getUsers())
-      if (forwardSlice->count(ownerInst) == 0)
-        getForwardSliceImpl(ownerInst, forwardSlice, filter);
+    for (auto *ownerOp : forOp.getInductionVar().getUsers())
+      if (forwardSlice->count(ownerOp) == 0)
+        getForwardSliceImpl(ownerOp, forwardSlice, filter);
   } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
-    for (auto *ownerInst : forOp.getInductionVar().getUsers())
-      if (forwardSlice->count(ownerInst) == 0)
-        getForwardSliceImpl(ownerInst, forwardSlice, filter);
+    for (auto *ownerOp : forOp.getInductionVar().getUsers())
+      if (forwardSlice->count(ownerOp) == 0)
+        getForwardSliceImpl(ownerOp, forwardSlice, filter);
+    for (auto result : forOp.getResults())
+      for (auto *ownerOp : result.getUsers())
+        if (forwardSlice->count(ownerOp) == 0)
+          getForwardSliceImpl(ownerOp, forwardSlice, filter);
   } else {
     assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
     assert(op->getNumResults() <= 1 && "unexpected multiple results");
     if (op->getNumResults() > 0) {
-      for (auto *ownerInst : op->getResult(0).getUsers())
-        if (forwardSlice->count(ownerInst) == 0)
-          getForwardSliceImpl(ownerInst, forwardSlice, filter);
+      for (auto *ownerOp : op->getResult(0).getUsers())
+        if (forwardSlice->count(ownerOp) == 0)
+          getForwardSliceImpl(ownerOp, forwardSlice, filter);
     }
   }
 
@@ -139,15 +143,15 @@ SetVector<Operation *> mlir::getSlice(Operation *op,
   SetVector<Operation *> backwardSlice;
   SetVector<Operation *> forwardSlice;
   while (currentIndex != slice.size()) {
-    auto *currentInst = (slice)[currentIndex];
-    // Compute and insert the backwardSlice starting from currentInst.
+    auto *currentOp = (slice)[currentIndex];
+    // Append the backward slice of currentOp to the `slice` worklist.
     backwardSlice.clear();
-    getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
+    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
     slice.insert(backwardSlice.begin(), backwardSlice.end());
 
-    // Compute and insert the forwardSlice starting from currentInst.
+    // Append the forward slice of currentOp to the `slice` worklist.
     forwardSlice.clear();
-    getForwardSlice(currentInst, &forwardSlice, forwardFilter);
+    getForwardSlice(currentOp, &forwardSlice, forwardFilter);
     slice.insert(forwardSlice.begin(), forwardSlice.end());
     ++currentIndex;
   }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 737e69021bc8..eeabdeb815dd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -12,10 +12,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Function.h"
+#include "mlir/Transforms/LoopUtils.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
 
@@ -75,3 +80,105 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
     });
   }
 }
+
+void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+
+    // Move loop-invariant ops out of every scf::ForOp first, in a walk of
+    // its own. Hoisting ops out of a loop from within the transfer_read walk
+    // below is unsafe: the walk has already captured the op following the
+    // visited transfer_read and that op may itself be hoisted, derailing
+    // the iteration. Walking loops post-order is safe because by the time a
+    // loop is visited, the walk has moved past it and its body.
+    func.walk([&](scf::ForOp loop) {
+      if (failed(moveLoopInvariantCode(
+              cast<LoopLikeOpInterface>(loop.getOperation()))))
+        llvm_unreachable(
+            "Unexpected failure to move invariant code out of loop");
+    });
+
+    func.walk([&](vector::TransferReadOp transferRead) {
+      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+                        << *transferRead.getOperation() << "\n");
+      auto loop = dyn_cast<scf::ForOp>(transferRead.getParentOp());
+      LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead.getParentOp()
+                        << "\n");
+      if (!loop)
+        return WalkResult::advance();
+
+      LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
+                        << "\n");
+
+      // All operands of the TransferRead must be defined outside of the loop.
+      // Check this before computing the forward slice, which is costlier.
+      for (auto operand : transferRead.getOperands())
+        if (!loop.isDefinedOutsideOfLoop(operand))
+          return WalkResult::advance();
+
+      // Look for the first TransferWriteOp, in forward-slice order, that
+      // operates on the same memref as `transferRead`.
+      llvm::SetVector<Operation *> forwardSlice;
+      getForwardSlice(transferRead, &forwardSlice);
+      vector::TransferWriteOp transferWrite;
+      for (auto *sliceOp : forwardSlice) {
+        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
+        if (!candidateWrite || candidateWrite.memref() != transferRead.memref())
+          continue;
+        transferWrite = candidateWrite;
+        break;
+      }
+
+      // Only hoist transfer_read / transfer_write pairs for now.
+      if (!transferWrite)
+        return WalkResult::advance();
+
+      LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
+                        << "\n");
+
+      // Approximate aliasing by checking that:
+      //   1. indices are the same,
+      //   2. no other use either dominates the transfer_read or is dominated
+      //   by the transfer_write (i.e. aliasing between the write and the read
+      //   across the loop).
+      if (transferRead.indices() != transferWrite.indices())
+        return WalkResult::advance();
+
+      // TODO: may want to memoize this information for performance but it
+      // likely gets invalidated often.
+      DominanceInfo dom(loop);
+      if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
+        return WalkResult::advance();
+      for (auto &use : transferRead.memref().getUses())
+        if (dom.properlyDominates(use.getOwner(),
+                                  transferRead.getOperation()) ||
+            dom.properlyDominates(transferWrite, use.getOwner()))
+          return WalkResult::advance();
+
+      // Hoist read before.
+      if (failed(loop.moveOutOfLoop({transferRead})))
+        llvm_unreachable(
+            "Unexpected failure to move transfer read out of loop");
+
+      // Hoist write after.
+      transferWrite.getOperation()->moveAfter(loop);
+
+      // Rewrite `loop` with new yields by cloning and erase the original loop.
+      OpBuilder b(transferRead);
+      auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
+                                         transferWrite.vector());
+
+      // Transfer write has been hoisted, need to update the written value to
+      // the value yielded by the newForOp.
+      transferWrite.vector().replaceAllUsesWith(
+          newForOp.getResults().take_back()[0]);
+
+      changed = true;
+      loop.erase();
+      // Need to interrupt and restart because erasing the loop messes up the
+      // walk.
+      return WalkResult::interrupt();
+    });
+  }
+}

diff  --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index 5fee7d62ee10..0a78cb4526f6 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Function.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/LoopUtils.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -73,7 +74,8 @@ static bool canBeHoisted(Operation *op,
   return true;
 }
 
-static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike) {
+/// Hoist loop-invariant ops out of `looplike`; declared in LoopUtils.h.
+LogicalResult mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) {
   auto &loopBody = looplike.getLoopBody();
 
   // We use two collections here as we need to preserve the order for insertion

diff  --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index db686655d64d..54ab1dcf07c7 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,12 +1,13 @@
-// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect | FileCheck %s --check-prefix=VECTOR_TRANSFERS
 
-// CHECK-LABEL: func @hoist(
+// CHECK-LABEL: func @hoist_allocs(
 //  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
 //  CHECK-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
-func @hoist(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
+func @hoist_allocs(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
 //   CHECK-DAG:   alloca(%[[VAL]]) : memref<?xi8>
 //   CHECK-DAG: %[[A0:.*]] = alloc(%[[VAL]]) : memref<?xi8>
 //       CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
@@ -80,3 +81,75 @@ func @hoist(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
 //       CHECK: dealloc %[[A0]] : memref<?xi8>
   return
 }
+
+// Expected outcome per memref in @hoist_vector_transfer_pairs:
+//   %memref1: read/write at constant indices -> hoisted out of both loops.
+//   %memref0: indices depend on %i -> hoisted out of the inner loop only.
+//   %memref2: written value does not derive from the read -> not hoisted.
+//   %memref3: the memref is used after the write -> not hoisted.
+//   %memref4: the memref is used before the read -> not hoisted.
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs(
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
+//  VECTOR_TRANSFERS-SAME:   %[[LB:[a-zA-Z0-9]*]]: index,
+//  VECTOR_TRANSFERS-SAME:   %[[UB:[a-zA-Z0-9]*]]: index,
+//  VECTOR_TRANSFERS-SAME:   %[[STEP:[a-zA-Z0-9]*]]: index,
+//  VECTOR_TRANSFERS-SAME:   %[[CMP:[a-zA-Z0-9]*]]: i1
+func @hoist_vector_transfer_pairs(
+    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
+    %memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
+// VECTOR_TRANSFERS:   vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS:   scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
+// VECTOR_TRANSFERS:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS:     "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
+// VECTOR_TRANSFERS:     vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%[[MEMREF2]]) : (memref<?x?xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS:     "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
+// VECTOR_TRANSFERS:     scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
+// VECTOR_TRANSFERS:   }
+// VECTOR_TRANSFERS:   vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
+// VECTOR_TRANSFERS:   scf.yield {{.*}} : vector<1xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+    scf.for %j = %lb to %ub step %step {
+      %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
+      %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
+      %r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
+      %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
+      "some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
+      %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+      %u2 = "some_use"(%memref2) : (memref<?x?xf32>) -> vector<3xf32>
+      %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
+      %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
+      vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+      vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
+      vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
+      vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
+      vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
+      "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
+    }
+  }
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
index 2fb5c091aed4..d1e478fec3bc 100644
--- a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
@@ -29,6 +29,10 @@ struct TestLinalgHoisting
       *this, "test-hoist-view-allocs",
       llvm::cl::desc("Test hoisting alloc used by view"),
       llvm::cl::init(false)};
+  Option<bool> testHoistRedundantTransfers{
+      *this, "test-hoist-redundant-transfers",
+      llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -37,6 +41,11 @@ void TestLinalgHoisting::runOnFunction() {
     hoistViewAllocOps(getFunction());
     return;
   }
+  // Like the option above: run the requested transform, then return early.
+  if (testHoistRedundantTransfers) {
+    hoistRedundantVectorTransfers(getFunction());
+    return;
+  }
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list