[Mlir-commits] [mlir] 825da07 - [mlir] Add hoisting of transfer ops in affine loops

Javier Setoain llvmlistbot at llvm.org
Tue Dec 6 02:08:55 PST 2022


Author: Javier Setoain
Date: 2022-12-06T10:07:21Z
New Revision: 825da072a8ede585be9d23829f7ac483f2dbae78

URL: https://github.com/llvm/llvm-project/commit/825da072a8ede585be9d23829f7ac483f2dbae78
DIFF: https://github.com/llvm/llvm-project/commit/825da072a8ede585be9d23829f7ac483f2dbae78.diff

LOG: [mlir] Add hoisting of transfer ops in affine loops

With the current hoisting strategy, the only way to hoist vector transfer
ops out of affine loops is to lower Affine to SCF first, but doing so
prevents running any further passes on Affine.

Differential Revision: https://reviews.llvm.org/D137600

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 79d3e5524dc4e..7d80b0beb5d76 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -14,7 +14,9 @@
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -29,6 +31,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 using llvm::dbgs;
@@ -425,10 +428,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
 
       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                         << *transferRead.getOperation() << "\n");
-      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
+      auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                         << "\n");
-      if (!loop)
+      if (!isa_and_nonnull<scf::ForOp, AffineForOp>(loop))
         return WalkResult::advance();
 
       LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
@@ -513,18 +516,47 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
                                     ArrayRef<BlockArgument> newBBArgs) {
         return SmallVector<Value>{transferWrite.getVector()};
       };
-      auto newForOp =
-          replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn);
 
       // Transfer write has been hoisted, need to update the written vector by
       // the value yielded by the newForOp.
-      transferWrite.getVectorMutable().assign(newForOp.getResults().back());
-
-      changed = true;
-      loop.erase();
-      // Need to interrupt and restart because erasing the loop messes up the
-      // walk.
-      return WalkResult::interrupt();
+      return TypeSwitch<Operation *, WalkResult>(loop)
+          .Case<scf::ForOp>([&](scf::ForOp scfForOp) {
+            auto newForOp = replaceLoopWithNewYields(
+                b, scfForOp, transferRead.getVector(), yieldFn);
+            transferWrite.getVectorMutable().assign(
+                newForOp.getResults().back());
+            changed = true;
+            loop.erase();
+            // Need to interrupt and restart because erasing the loop messes up
+            // the walk.
+            return WalkResult::interrupt();
+          })
+          .Case<AffineForOp>([&](AffineForOp affineForOp) {
+            auto newForOp = replaceForOpWithNewYields(
+                b, affineForOp, transferRead.getVector(),
+                SmallVector<Value>{transferWrite.getVector()},
+                transferWrite.getVector());
+            ArrayRef<BlockArgument> newBBArgs =
+                newForOp.getLoopBody().getArguments().take_back(1);
+            // Replace all uses of the `transferRead` with the corresponding
+            // basic block argument.
+            for (auto it :
+                 llvm::zip(ValueRange{transferRead.getVector()}, newBBArgs)) {
+              std::get<0>(it).replaceUsesWithIf(
+                  std::get<1>(it), [&](OpOperand &use) {
+                    Operation *user = use.getOwner();
+                    return newForOp->isProperAncestor(user);
+                  });
+            }
+            transferWrite.getVectorMutable().assign(
+                newForOp.getResults().back());
+            changed = true;
+            loop.erase();
+            // Need to interrupt and restart because erasing the loop messes up
+            // the walk.
+            return WalkResult::interrupt();
+          })
+          .Default([](Operation *) { return WalkResult::interrupt(); });
     });
   }
 }

diff  --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 2b783d144b7bc..eac8ddcedf554 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -469,3 +469,39 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
+//  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+//  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+//  CHECK-SAME:   %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) {
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : i32
+//       CHECK:   affine.for %[[I:.*]] = 0 to 64 {
+//       CHECK:     affine.for %[[J:.*]] = 0 to 64 step 16 {
+//       CHECK:       %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32>
+//       CHECK:       %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) {
+//       CHECK:         %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+//       CHECK:         %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+//       CHECK:         %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32>
+//       CHECK:         %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32>
+//       CHECK:         affine.yield %[[T1]] : vector<16xi32>
+//       CHECK:       }
+//       CHECK:       vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32>
+//       CHECK:     }
+//       CHECK:   }
+func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  affine.for %arg3 = 0 to 64 {
+    affine.for %arg4 = 0 to 64 step 16 {
+      affine.for %arg5 = 0 to 64 {
+        %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
+        %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+        %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+        %3 = arith.muli %0, %1 : vector<16xi32>
+        %4 = arith.addi %2, %3 : vector<16xi32>
+        vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32>
+      }
+    }
+  }
+  return
+}


        


More information about the Mlir-commits mailing list