[Mlir-commits] [mlir] [MLIR] Fix arbitrary checks in affine LICM (PR #116469)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 15 21:46:26 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-affine

Author: Uday Bondhugula (bondhugula)

<details>
<summary>Changes</summary>

Fix arbitrary checks and hard-coding/special-casing in affine LICM. Drop
unnecessary (excessive) debug logging.

This pass is still unsound because it does not account for memref
aliasing; that will have to be addressed later.


---
Full diff: https://github.com/llvm/llvm-project/pull/116469.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp (+50-92) 
- (modified) mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir (+3-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
index fc31931da06073..e3f316443161f6 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -12,25 +12,9 @@
 
 #include "mlir/Dialect/Affine/Passes.h"
 
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
-#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
-#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/LoopUtils.h"
-#include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/SmallPtrSet.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -41,7 +25,7 @@ namespace affine {
 } // namespace affine
 } // namespace mlir
 
-#define DEBUG_TYPE "licm"
+#define DEBUG_TYPE "affine-licm"
 
 using namespace mlir;
 using namespace mlir::affine;
@@ -49,9 +33,13 @@ using namespace mlir::affine;
 namespace {
 
 /// Affine loop invariant code motion (LICM) pass.
-/// TODO: The pass is missing zero-trip tests.
-/// TODO: This code should be removed once the new LICM pass can handle its
-///       uses.
+/// TODO: The pass is missing zero tripcount tests.
+/// TODO: When compared to the other standard LICM pass, this pass
+/// has some special handling for affine read/write ops but such handling
+/// requires aliasing to be sound, and as such this pass is unsound. In
+/// addition, this handling is nothing particular to affine memory ops but would
+/// apply to any memory read/write effect ops. Either aliasing should be handled
+/// or this pass can be removed and the standard LICM can be used.
 struct LoopInvariantCodeMotion
     : public affine::impl::AffineLoopInvariantCodeMotionBase<
           LoopInvariantCodeMotion> {
@@ -61,100 +49,80 @@ struct LoopInvariantCodeMotion
 } // namespace
 
 static bool
-checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
+checkInvarianceOfNestedIfOps(AffineIfOp ifOp, AffineForOp loop,
                              SmallPtrSetImpl<Operation *> &opsWithUsers,
                              SmallPtrSetImpl<Operation *> &opsToHoist);
-static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
+static bool isOpLoopInvariant(Operation &op, AffineForOp loop,
                               SmallPtrSetImpl<Operation *> &opsWithUsers,
                               SmallPtrSetImpl<Operation *> &opsToHoist);
 
 static bool
-areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
-                                 ValueRange iterArgs,
+areAllOpsInTheBlockListInvariant(Region &blockList, AffineForOp loop,
                                  SmallPtrSetImpl<Operation *> &opsWithUsers,
                                  SmallPtrSetImpl<Operation *> &opsToHoist);
 
-// Returns true if the individual op is loop invariant.
-static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
+/// Returns true if `op` is invariant on `loop`.
+static bool isOpLoopInvariant(Operation &op, AffineForOp loop,
                               SmallPtrSetImpl<Operation *> &opsWithUsers,
                               SmallPtrSetImpl<Operation *> &opsToHoist) {
-  LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;);
+  Value iv = loop.getInductionVar();
 
   if (auto ifOp = dyn_cast<AffineIfOp>(op)) {
-    if (!checkInvarianceOfNestedIfOps(ifOp, indVar, iterArgs, opsWithUsers,
-                                      opsToHoist))
+    if (!checkInvarianceOfNestedIfOps(ifOp, loop, opsWithUsers, opsToHoist))
       return false;
   } else if (auto forOp = dyn_cast<AffineForOp>(op)) {
-    if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), indVar, iterArgs,
-                                          opsWithUsers, opsToHoist))
+    if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), loop, opsWithUsers,
+                                          opsToHoist))
       return false;
   } else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
-    if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), indVar, iterArgs,
-                                          opsWithUsers, opsToHoist))
+    if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), loop, opsWithUsers,
+                                          opsToHoist))
       return false;
   } else if (!isMemoryEffectFree(&op) &&
-             !isa<AffineReadOpInterface, AffineWriteOpInterface,
-                  AffinePrefetchOp>(&op)) {
+             !isa<AffineReadOpInterface, AffineWriteOpInterface>(&op)) {
     // Check for side-effecting ops. Affine read/write ops are handled
     // separately below.
     return false;
-  } else if (!matchPattern(&op, m_Constant())) {
+  } else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
     // Register op in the set of ops that have users.
     opsWithUsers.insert(&op);
-    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
-      auto read = dyn_cast<AffineReadOpInterface>(op);
-      Value memref = read ? read.getMemRef()
-                          : cast<AffineWriteOpInterface>(op).getMemRef();
-      for (auto *user : memref.getUsers()) {
-        // If this memref has a user that is a DMA, give up because these
-        // operations write to this memref.
-        if (isa<AffineDmaStartOp, AffineDmaWaitOp>(user))
+    SmallVector<AffineForOp, 8> userIVs;
+    auto read = dyn_cast<AffineReadOpInterface>(op);
+    Value memref =
+        read ? read.getMemRef() : cast<AffineWriteOpInterface>(op).getMemRef();
+    for (auto *user : memref.getUsers()) {
+      // If the memref used by the load/store is used in a store elsewhere in
+      // the loop nest, we do not hoist. Similarly, if the memref used in a
+      // load is also being stored too, we do not hoist the load.
+      // FIXME: This is missing checking aliases.
+      if (&op == user)
+        continue;
+      if (hasEffect<MemoryEffects::Write>(user, memref) ||
+          (hasEffect<MemoryEffects::Read>(user, memref) &&
+           isa<AffineWriteOpInterface>(op))) {
+        userIVs.clear();
+        getAffineForIVs(*user, &userIVs);
+        // Check that userIVs don't contain the for loop around the op.
+        if (llvm::is_contained(userIVs, loop))
           return false;
-        // If the memref used by the load/store is used in a store elsewhere in
-        // the loop nest, we do not hoist. Similarly, if the memref used in a
-        // load is also being stored too, we do not hoist the load.
-        if (isa<AffineWriteOpInterface>(user) ||
-            (isa<AffineReadOpInterface>(user) &&
-             isa<AffineWriteOpInterface>(op))) {
-          if (&op != user) {
-            SmallVector<AffineForOp, 8> userIVs;
-            getAffineForIVs(*user, &userIVs);
-            // Check that userIVs don't contain the for loop around the op.
-            if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar)))
-              return false;
-          }
-        }
       }
     }
-
-    if (op.getNumOperands() == 0 && !isa<AffineYieldOp>(op)) {
-      LLVM_DEBUG(llvm::dbgs() << "Non-constant op with 0 operands\n");
-      return false;
-    }
   }
 
   // Check operands.
+  ValueRange iterArgs = loop.getRegionIterArgs();
   for (unsigned int i = 0; i < op.getNumOperands(); ++i) {
     auto *operandSrc = op.getOperand(i).getDefiningOp();
 
-    LLVM_DEBUG(
-        op.getOperand(i).print(llvm::dbgs() << "Iterating on operand\n"));
-
     // If the loop IV is the operand, this op isn't loop invariant.
-    if (indVar == op.getOperand(i)) {
-      LLVM_DEBUG(llvm::dbgs() << "Loop IV is the operand\n");
+    if (iv == op.getOperand(i))
       return false;
-    }
 
     // If the one of the iter_args is the operand, this op isn't loop invariant.
-    if (llvm::is_contained(iterArgs, op.getOperand(i))) {
-      LLVM_DEBUG(llvm::dbgs() << "One of the iter_args is the operand\n");
+    if (llvm::is_contained(iterArgs, op.getOperand(i)))
       return false;
-    }
 
     if (operandSrc) {
-      LLVM_DEBUG(llvm::dbgs() << *operandSrc << "Iterating on operand src\n");
-
       // If the value was defined in the loop (outside of the if/else region),
       // and that operation itself wasn't meant to be hoisted, then mark this
       // operation loop dependent.
@@ -170,14 +138,13 @@ static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
 
 // Checks if all ops in a region (i.e. list of blocks) are loop invariant.
 static bool
-areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
-                                 ValueRange iterArgs,
+areAllOpsInTheBlockListInvariant(Region &blockList, AffineForOp loop,
                                  SmallPtrSetImpl<Operation *> &opsWithUsers,
                                  SmallPtrSetImpl<Operation *> &opsToHoist) {
 
   for (auto &b : blockList) {
     for (auto &op : b) {
-      if (!isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist))
+      if (!isOpLoopInvariant(op, loop, opsWithUsers, opsToHoist))
         return false;
     }
   }
@@ -187,14 +154,14 @@ areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
 
 // Returns true if the affine.if op can be hoisted.
 static bool
-checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
+checkInvarianceOfNestedIfOps(AffineIfOp ifOp, AffineForOp loop,
                              SmallPtrSetImpl<Operation *> &opsWithUsers,
                              SmallPtrSetImpl<Operation *> &opsToHoist) {
-  if (!areAllOpsInTheBlockListInvariant(ifOp.getThenRegion(), indVar, iterArgs,
+  if (!areAllOpsInTheBlockListInvariant(ifOp.getThenRegion(), loop,
                                         opsWithUsers, opsToHoist))
     return false;
 
-  if (!areAllOpsInTheBlockListInvariant(ifOp.getElseRegion(), indVar, iterArgs,
+  if (!areAllOpsInTheBlockListInvariant(ifOp.getElseRegion(), loop,
                                         opsWithUsers, opsToHoist))
     return false;
 
@@ -202,10 +169,6 @@ checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
 }
 
 void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
-  auto *loopBody = forOp.getBody();
-  auto indVar = forOp.getInductionVar();
-  ValueRange iterArgs = forOp.getRegionIterArgs();
-
   // This is the place where hoisted instructions would reside.
   OpBuilder b(forOp.getOperation());
 
@@ -213,14 +176,14 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
   SmallVector<Operation *, 8> opsToMove;
   SmallPtrSet<Operation *, 8> opsWithUsers;
 
-  for (auto &op : *loopBody) {
+  for (Operation &op : *forOp.getBody()) {
     // Register op in the set of ops that have users. This set is used
     // to prevent hoisting ops that depend on these ops that are
     // not being hoisted.
     if (!op.use_empty())
       opsWithUsers.insert(&op);
     if (!isa<AffineYieldOp>(op)) {
-      if (isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist)) {
+      if (isOpLoopInvariant(op, forOp, opsWithUsers, opsToHoist)) {
         opsToMove.push_back(&op);
       }
     }
@@ -231,18 +194,13 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
   for (auto *op : opsToMove) {
     op->moveBefore(forOp);
   }
-
-  LLVM_DEBUG(forOp->print(llvm::dbgs() << "Modified loop\n"));
 }
 
 void LoopInvariantCodeMotion::runOnOperation() {
   // Walk through all loops in a function in innermost-loop-first order.  This
   // way, we first LICM from the inner loop, and place the ops in
   // the outer loop, which in turn can be further LICM'ed.
-  getOperation().walk([&](AffineForOp op) {
-    LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n"));
-    runOnAffineForOp(op);
-  });
+  getOperation().walk([&](AffineForOp op) { runOnAffineForOp(op); });
 }
 
 std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
index c04d7d2053866c..858b7d3ddf9f11 100644
--- a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
+++ b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
@@ -855,15 +855,16 @@ func.func @affine_prefetch_invariant() {
   affine.for %i0 = 0 to 10 {
     affine.for %i1 = 0 to 10 {
       %1 = affine.load %0[%i0, %i1] : memref<10x10xf32>
+      // A prefetch shouldn't be hoisted.
       affine.prefetch %0[%i0, %i0], write, locality<0>, data : memref<10x10xf32>
     }
   }
 
   // CHECK:      memref.alloc() : memref<10x10xf32>
   // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:   affine.prefetch
   // CHECK-NEXT:   affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:     %{{.*}}  = affine.load %{{.*}}[%{{.*}}  : memref<10x10xf32>
+  // CHECK-NEXT:     affine.load %{{.*}}[%{{.*}}  : memref<10x10xf32>
+  // CHECK-NEXT:     affine.prefetch
   // CHECK-NEXT:   }
   // CHECK-NEXT: }
   return

``````````

</details>


https://github.com/llvm/llvm-project/pull/116469


More information about the Mlir-commits mailing list