[Mlir-commits] [mlir] da291ba - [mlir] Add hoisting of transfer ops in affine loops

Javier Setoain llvmlistbot at llvm.org
Wed Dec 7 12:10:21 PST 2022


Author: Javier Setoain
Date: 2022-12-07T20:08:07Z
New Revision: da291bab81200c93dffaa809a894168b7dedffd8

URL: https://github.com/llvm/llvm-project/commit/da291bab81200c93dffaa809a894168b7dedffd8
DIFF: https://github.com/llvm/llvm-project/commit/da291bab81200c93dffaa809a894168b7dedffd8.diff

LOG: [mlir] Add hoisting of transfer ops in affine loops

With the existing hoisting strategy, the only way to hoist transfer ops out
of affine loops is to lower Affine to SCF first, which rules out running
further Affine passes afterwards.

Differential Revision: https://reviews.llvm.org/D137600

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 79d3e5524dc4..f51b4ffe9999 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -14,7 +14,9 @@
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -29,6 +31,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 using llvm::dbgs;
@@ -425,10 +428,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
 
       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                         << *transferRead.getOperation() << "\n");
-      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
+      auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                         << "\n");
-      if (!loop)
+      if (!isa_and_nonnull<scf::ForOp, AffineForOp>(loop))
         return WalkResult::advance();
 
       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
@@ -513,18 +516,43 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
                                     ArrayRef<BlockArgument> newBBArgs) {
         return SmallVector<Value>{transferWrite.getVector()};
       };
-      auto newForOp =
-          replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn);
 
       // Transfer write has been hoisted, need to update the written vector by
       // the value yielded by the newForOp.
-      transferWrite.getVectorMutable().assign(newForOp.getResults().back());
-
-      changed = true;
-      loop.erase();
-      // Need to interrupt and restart because erasing the loop messes up the
-      // walk.
-      return WalkResult::interrupt();
+      return TypeSwitch<Operation *, WalkResult>(loop)
+          .Case<scf::ForOp>([&](scf::ForOp scfForOp) {
+            auto newForOp = replaceLoopWithNewYields(
+                b, scfForOp, transferRead.getVector(), yieldFn);
+            transferWrite.getVectorMutable().assign(
+                newForOp.getResults().back());
+            changed = true;
+            loop.erase();
+            // Need to interrupt and restart because erasing the loop messes up
+            // the walk.
+            return WalkResult::interrupt();
+          })
+          .Case<AffineForOp>([&](AffineForOp affineForOp) {
+            auto newForOp = replaceForOpWithNewYields(
+                b, affineForOp, transferRead.getVector(),
+                SmallVector<Value>{transferWrite.getVector()},
+                transferWrite.getVector());
+            // Replace all uses of the `transferRead` with the corresponding
+            // basic block argument.
+            transferRead.getVector().replaceUsesWithIf(
+                newForOp.getLoopBody().getArguments().back(),
+                [&](OpOperand &use) {
+                  Operation *user = use.getOwner();
+                  return newForOp->isProperAncestor(user);
+                });
+            transferWrite.getVectorMutable().assign(
+                newForOp.getResults().back());
+            changed = true;
+            loop.erase();
+            // Need to interrupt and restart because erasing the loop messes up
+            // the walk.
+            return WalkResult::interrupt();
+          })
+          .Default([](Operation *) { return WalkResult::interrupt(); });
     });
   }
 }

diff  --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 2b783d144b7b..eac8ddcedf55 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -469,3 +469,39 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
+//  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+//  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+//  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) {
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : i32
+//       CHECK:   affine.for %[[I:.*]] = 0 to 64 {
+//       CHECK:     affine.for %[[J:.*]] = 0 to 64 step 16 {
+//       CHECK:       %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32>
+//       CHECK:       %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) {
+//       CHECK:         %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+//       CHECK:         %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+//       CHECK:         %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32>
+//       CHECK:         %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32>
+//       CHECK:         affine.yield %[[T1]] : vector<16xi32>
+//       CHECK:       }
+//       CHECK:       vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32>
+//       CHECK:     }
+//       CHECK:   }
+func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  affine.for %arg3 = 0 to 64 {
+    affine.for %arg4 = 0 to 64 step 16 {
+      affine.for %arg5 = 0 to 64 {
+        %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
+        %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+        %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+        %3 = arith.muli %0, %1 : vector<16xi32>
+        %4 = arith.addi %2, %3 : vector<16xi32>
+        vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32>
+      }
+    }
+  }
+  return
+}


        


More information about the Mlir-commits mailing list