[Mlir-commits] [mlir] Generalize affine fusion to work at all depths and inside other region-holding ops (PR #72288)

Uday Bondhugula llvmlistbot at llvm.org
Tue Nov 14 09:04:33 PST 2023


https://github.com/bondhugula created https://github.com/llvm/llvm-project/pull/72288

Generalize affine fusion to work at any inner depth, fusing loops nested inside
other affine.for ops or even inside scf.for or scf.while nests. Apply fusion in
post order to all affine nests under the pass's top-level operation.

Fix MemRefDependenceGraph (MDG) initialization for blocks nested inside other
affine nests.

Relax the unnecessary requirement for unique vars during merge and align of
FlatLinearValueConstraints. There are several cases where
FlatLinearValueConstraints need to have duplicate Values for their dimensions:
e.g., in dependence relation systems, the source and destination accesses
could have common loop IVs. `mergeAndAlign` can be done even in the presence
of repeated Values by simply aligning them from left to right in that order.

While at it, drop outdated comments and improve some debug messages.


>From 9ee170e6eb29594004e861bb156abc1637af4bf6 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Fri, 3 Nov 2023 17:54:46 +0530
Subject: [PATCH] Generalize affine fusion to work at all depths and inside
 other region-holding ops

Generalize affine fusion to work at any inner depth, fusing loops nested inside
other affine.for ops or even inside scf.for or scf.while nests. Apply fusion in
post order to all affine nests under the pass's top-level operation.

Fix MemRefDependenceGraph (MDG) initialization for blocks nested inside other
affine nests.

Relax the unnecessary requirement for unique vars during merge and align of
FlatLinearValueConstraints. There are several cases where
FlatLinearValueConstraints need to have duplicate Values for their dimensions:
e.g., in dependence relation systems, the source and destination accesses
could have common loop IVs. `mergeAndAlign` can be done even in the presence
of repeated Values by simply aligning them from left to right in that order.

While at it, drop outdated comments and improve some debug messages.
---
 .../Analysis/FlatLinearValueConstraints.h     |   8 +-
 .../mlir/Dialect/Affine/Analysis/Utils.h      |  16 +-
 .../mlir/Dialect/Affine/LoopFusionUtils.h     |   1 -
 .../Analysis/FlatLinearValueConstraints.cpp   |  39 +--
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp    |  17 +-
 .../Dialect/Affine/Transforms/LoopFusion.cpp  |  65 ++++-
 .../Dialect/Affine/Utils/LoopFusionUtils.cpp  |   1 -
 .../Dialect/Affine/loop-fusion-inner.mlir     | 225 ++++++++++++++++++
 8 files changed, 320 insertions(+), 52 deletions(-)
 create mode 100644 mlir/test/Dialect/Affine/loop-fusion-inner.mlir

diff --git a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
index f3c46f5efd73cb1..e4de5b0661571c8 100644
--- a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
+++ b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
@@ -374,10 +374,10 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
       setValue(i, values[i - start]);
   }
 
-  /// Looks up the position of the variable with the specified Value. Returns
-  /// true if found (false otherwise). `pos` is set to the (column) position of
-  /// the variable.
-  bool findVar(Value val, unsigned *pos) const;
+  /// Looks up the position of the variable with the specified Value starting
+  /// with variables at offset `offset`. Returns true if found (false
+  /// otherwise). `pos` is set to the (column) position of the variable.
+  bool findVar(Value val, unsigned *pos, unsigned offset = 0) const;
 
   /// Returns true if a variable with the specified Value exists, false
   /// otherwise.
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index 7376d9b992a0442..3dea99cd6b3e54f 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -213,20 +213,20 @@ struct MemRefDependenceGraph {
 };
 
 /// Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered
-/// from the outermost 'affine.for' operation to the innermost one.
+/// from the outermost 'affine.for' operation to the innermost one while not
+/// traversing outside of the surrounding affine scope.
 void getAffineForIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
 
 /// Populates 'ivs' with IVs of the surrounding affine.for and affine.parallel
-/// ops ordered from the outermost one to the innermost.
+/// ops ordered from the outermost one to the innermost while not traversing
+/// outside of the surrounding affine scope.
 void getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs);
 
 /// Populates 'ops' with affine operations enclosing `op` ordered from outermost
-/// to innermost. affine.for, affine.if, or affine.parallel ops comprise such
-/// surrounding affine ops.
-/// TODO: Change this to return a list of enclosing ops up until the op that
-/// starts an `AffineScope`. In such a case, `ops` is guaranteed by design to
-/// have a successive chain of affine parent ops, and this is primarily what is
-/// needed for most analyses.
+/// to innermost while stopping at the boundary of the affine scope. affine.for,
+/// affine.if, or affine.parallel ops comprise such surrounding affine ops.
+/// `ops` is guaranteed by design to have a successive chain of affine parent
+/// ops.
 void getEnclosingAffineOps(Operation &op, SmallVectorImpl<Operation *> *ops);
 
 /// Returns the nesting depth of this operation, i.e., the number of loops
diff --git a/mlir/include/mlir/Dialect/Affine/LoopFusionUtils.h b/mlir/include/mlir/Dialect/Affine/LoopFusionUtils.h
index 284d394ddceffef..4cec777f9888ca4 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopFusionUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopFusionUtils.h
@@ -110,7 +110,6 @@ class FusionStrategy {
 /// returns a FusionResult explaining why fusion is not feasible.
 /// NOTE: This function is not feature complete and should only be used in
 /// testing.
-/// TODO: Update comments when this function is fully implemented.
 FusionResult
 canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth,
              ComputationSliceState *srcSlice,
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 382d05f3b2d4851..d76e94b7070e8f3 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -958,15 +958,15 @@ areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) {
 /// so that they have the union of all variables, with A's original
 /// variables appearing first followed by any of B's variables that didn't
 /// appear in A. Local variables in B that have the same division
-/// representation as local variables in A are merged into one.
+/// representation as local variables in A are merged into one. We allow A
+/// and B to have non-unique values for their variables; in such cases,
+/// duplicates are matched in order from left to right: the first occurrence
+/// of a value in one system is aligned with the first in the other, and so on.
 //  E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
 //        Output: both A, B have (%i, %j, %k) [%M, %N, %P]
 static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
                               FlatLinearValueConstraints *b) {
   assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());
-  // A merge/align isn't meaningful if a cst's vars aren't distinct.
-  assert(areVarsUnique(*a) && "A's values aren't unique");
-  assert(areVarsUnique(*b) && "B's values aren't unique");
 
   assert(llvm::all_of(
       llvm::drop_begin(a->getMaybeValues(), offset),
@@ -982,9 +982,12 @@ static void mergeAndAlignVars(unsigned offset, FlatLinearValueConstraints *a,
   {
     // Merge dims from A into B.
     unsigned d = offset;
-    for (auto aDimValue : aDimValues) {
+    for (Value aDimValue : aDimValues) {
       unsigned loc;
-      if (b->findVar(aDimValue, &loc)) {
+      // Find from the position `d` since we'd like to also consider the
+      // possibility of multiple variables with the same `Value`. We align with
+      // the next appearing one.
+      if (b->findVar(aDimValue, &loc, d)) {
         assert(loc >= offset && "A's dim appears in B's aligned range");
         assert(loc < b->getNumDimVars() &&
                "A's dim appears in B's non-dim position");
@@ -1017,15 +1020,12 @@ void FlatLinearValueConstraints::mergeAndAlignVarsWithOther(
 }
 
 /// Merge and align symbols of `this` and `other` such that both get union of
-/// of symbols that are unique. Symbols in `this` and `other` should be
-/// unique. Symbols with Value as `None` are considered to be inequal to all
-/// other symbols.
+/// symbols. Existing symbols need not be unique; they will be aligned from
+/// left to right with duplicates aligned in the same order. Symbols with Value
+/// as `None` are considered to be unequal to all other symbols.
 void FlatLinearValueConstraints::mergeSymbolVars(
     FlatLinearValueConstraints &other) {
 
-  assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
-  assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");
-
   SmallVector<Value, 4> aSymValues;
   getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues);
 
@@ -1034,8 +1034,9 @@ void FlatLinearValueConstraints::mergeSymbolVars(
   for (Value aSymValue : aSymValues) {
     unsigned loc;
     // If the var is a symbol in `other`, then align it, otherwise assume that
-    // it is a new symbol
-    if (other.findVar(aSymValue, &loc) && loc >= other.getNumDimVars() &&
+    // it is a new symbol. Search in `other` starting at position `s` since the
+    // left of it is aligned.
+    if (other.findVar(aSymValue, &loc, s) && loc >= other.getNumDimVars() &&
         loc < other.getNumDimAndSymbolVars())
       other.swapVar(s, loc);
     else
@@ -1051,8 +1052,6 @@ void FlatLinearValueConstraints::mergeSymbolVars(
 
   assert(getNumSymbolVars() == other.getNumSymbolVars() &&
          "expected same number of symbols");
-  assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
-  assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");
 }
 
 bool FlatLinearValueConstraints::hasConsistentState() const {
@@ -1104,9 +1103,11 @@ FlatLinearValueConstraints::computeAlignedMap(AffineMap map,
   return alignedMap;
 }
 
-bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos) const {
-  unsigned i = 0;
-  for (const auto &mayBeVar : values) {
+bool FlatLinearValueConstraints::findVar(Value val, unsigned *pos,
+                                         unsigned offset) const {
+  unsigned i = offset;
+  for (const auto &mayBeVar :
+       ArrayRef<std::optional<Value>>(values).drop_front(offset)) {
     if (mayBeVar && *mayBeVar == val) {
       *pos = i;
       return true;
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9d2998456a4b679..d9b1050c83e269b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -199,10 +199,15 @@ bool MemRefDependenceGraph::init() {
           continue;
         SmallVector<AffineForOp, 4> loops;
         getAffineForIVs(*user, &loops);
-        if (loops.empty())
+        // Find the surrounding affine.for nested immediately within the
+        // block.
+        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
+          return loop->getBlock() == █
+        });
+        if (it == loops.end())
           continue;
-        assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
-        unsigned userLoopNestId = forToNodeMap[loops[0]];
+        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
+        unsigned userLoopNestId = forToNodeMap[*it];
         addEdge(node.id, userLoopNestId, value);
       }
     }
@@ -631,8 +636,8 @@ void mlir::affine::getAffineForIVs(Operation &op,
   AffineForOp currAffineForOp;
   // Traverse up the hierarchy collecting all 'affine.for' operation while
   // skipping over 'affine.if' operations.
-  while (currOp) {
-    if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
+  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
+    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
       loops->push_back(currAffineForOp);
     currOp = currOp->getParentOp();
   }
@@ -646,7 +651,7 @@ void mlir::affine::getEnclosingAffineOps(Operation &op,
 
   // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
   // affine.parallel operations.
-  while (currOp) {
+  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
     if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
       ops->push_back(currOp);
     currOp = currOp->getParentOp();
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 2629f5a511cecb7..66d921b4889f59f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -896,6 +896,16 @@ struct GreedyFusion {
         if (fusedLoopInsPoint == nullptr)
           continue;
 
+        // It's possible this fusion is at an inner depth (i.e., there are
+        // common surrounding affine loops for the source and destination for
+        // ops). We need to get this number because the call to canFuseLoops
+        // needs to be passed the absolute depth. The max legal depth and the
+        // depths we try below are however *relative* and as such don't include
+        // the common depth.
+        SmallVector<AffineForOp, 4> surroundingLoops;
+        getAffineForIVs(*dstAffineForOp, &surroundingLoops);
+        unsigned numSurroundingLoops = surroundingLoops.size();
+
         // Compute the innermost common loop depth for dstNode
         // producer-consumer loads/stores.
         SmallVector<Operation *, 2> dstMemrefOps;
@@ -907,7 +917,8 @@ struct GreedyFusion {
           if (producerConsumerMemrefs.count(
                   cast<AffineWriteOpInterface>(op).getMemRef()))
             dstMemrefOps.push_back(op);
-        unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
+        unsigned dstLoopDepthTest =
+            getInnermostCommonLoopDepth(dstMemrefOps) - numSurroundingLoops;
 
         // Check the feasibility of fusing src loop nest into dst loop nest
         // at loop depths in range [1, dstLoopDepthTest].
@@ -916,9 +927,10 @@ struct GreedyFusion {
         depthSliceUnions.resize(dstLoopDepthTest);
         FusionStrategy strategy(FusionStrategy::ProducerConsumer);
         for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
-          FusionResult result = affine::canFuseLoops(
-              srcAffineForOp, dstAffineForOp,
-              /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
+          FusionResult result =
+              affine::canFuseLoops(srcAffineForOp, dstAffineForOp,
+                                   /*dstLoopDepth=*/i + numSurroundingLoops,
+                                   &depthSliceUnions[i - 1], strategy);
 
           if (result.value == FusionResult::Success)
             maxLegalFusionDepth = i;
@@ -1125,9 +1137,18 @@ struct GreedyFusion {
       SmallVector<Operation *, 2> dstLoadOpInsts;
       dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
 
+      // It's possible this fusion is at an inner depth (i.e., there are common
+      // surrounding affine loops for the source and destination for ops). We
+      // need to get this number because the call to canFuseLoops needs to be
+      // passed the absolute depth. The max legal depth and the depths we try
+      // below are however *relative* and as such don't include the common
+      // depth.
+      SmallVector<AffineForOp, 4> surroundingLoops;
+      getAffineForIVs(*dstAffineForOp, &surroundingLoops);
+      unsigned numSurroundingLoops = surroundingLoops.size();
       SmallVector<AffineForOp, 4> dstLoopIVs;
       getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
-      unsigned dstLoopDepthTest = dstLoopIVs.size();
+      unsigned dstLoopDepthTest = dstLoopIVs.size() - numSurroundingLoops;
       auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
 
       // Compute loop depth and slice union for fusion.
@@ -1136,14 +1157,18 @@ struct GreedyFusion {
       unsigned maxLegalFusionDepth = 0;
       FusionStrategy strategy(memref);
       for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
-        FusionResult result = affine::canFuseLoops(
-            sibAffineForOp, dstAffineForOp,
-            /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
+        FusionResult result =
+            affine::canFuseLoops(sibAffineForOp, dstAffineForOp,
+                                 /*dstLoopDepth=*/i + numSurroundingLoops,
+                                 &depthSliceUnions[i - 1], strategy);
 
         if (result.value == FusionResult::Success)
           maxLegalFusionDepth = i;
       }
 
+      LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
+                              << maxLegalFusionDepth << '\n');
+
       // Skip if fusion is not feasible at any loop depths.
       if (maxLegalFusionDepth == 0)
         continue;
@@ -1238,9 +1263,15 @@ struct GreedyFusion {
         SmallVector<AffineForOp, 4> loops;
         getAffineForIVs(*user, &loops);
         // Skip 'use' if it is not within a loop nest.
-        if (loops.empty())
+        // Find the surrounding affine.for nested immediately within the
+        // block.
+        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
+          return loop->getBlock() == &mdg->block;
+        });
+        // Skip 'use' if it is not within a loop nest in `block`.
+        if (it == loops.end())
           continue;
-        Node *sibNode = mdg->getForOpNode(loops[0]);
+        Node *sibNode = mdg->getForOpNode(*it);
         assert(sibNode != nullptr);
         // Skip 'use' if it not a sibling to 'dstNode'.
         if (sibNode->id == dstNode->id)
@@ -1373,9 +1404,17 @@ void LoopFusion::runOnBlock(Block *block) {
 }
 
 void LoopFusion::runOnOperation() {
-  for (Region &region : getOperation()->getRegions())
-    for (Block &block : region.getBlocks())
-      runOnBlock(&block);
+  // Call fusion on every op that has at least two affine.for nests (in post
+  // order).
+  getOperation()->walk([&](Operation *op) {
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        auto affineFors = block.getOps<AffineForOp>();
+        if (!affineFors.empty() && !llvm::hasSingleElement(affineFors))
+          runOnBlock(&block);
+      }
+    }
+  });
 }
 
 std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 5053b08ee0834cd..77ee1d4c7a02d97 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -243,7 +243,6 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
   return loopDepth;
 }
 
-// TODO: Prevent fusion of loop nests with side-effecting operations.
 // TODO: This pass performs some computation that is the same for all the depths
 // (e.g., getMaxLoopDepth). Implement a version of this utility that processes
 // all the depths at once or only the legal maximal depth for maximal fusion.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-inner.mlir b/mlir/test/Dialect/Affine/loop-fusion-inner.mlir
new file mode 100644
index 000000000000000..61af9a4baf46d22
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-fusion-inner.mlir
@@ -0,0 +1,225 @@
+// RUN: mlir-opt -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer fusion-maximal}))' %s | FileCheck %s
+
+// Test fusion of affine nests inside other region-holding ops (affine.for,
+// scf.for, and scf.while in the test cases below).
+
+// CHECK-LABEL: func @fusion_inner_simple
+func.func @fusion_inner_simple(%A : memref<10xf32>) {
+  %cst = arith.constant 0.0 : f32
+
+  affine.for %i = 0 to 100 {
+    %B = memref.alloc() : memref<10xf32>
+    %C = memref.alloc() : memref<10xf32>
+
+    affine.for %j = 0 to 10 {
+      %v = affine.load %A[%j] : memref<10xf32>
+      affine.store %v, %B[%j] : memref<10xf32>
+    }
+
+    affine.for %j = 0 to 10 {
+      %v = affine.load %B[%j] : memref<10xf32>
+      affine.store %v, %C[%j] : memref<10xf32>
+    }
+  }
+
+  // CHECK:      affine.for %{{.*}} = 0 to 100
+  // CHECK-NEXT:   memref.alloc
+  // CHECK-NEXT:   memref.alloc
+  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 10
+  // CHECK-NOT:    affine.for
+
+  return
+}
+
+// CHECK-LABEL: func @fusion_inner_simple_scf
+func.func @fusion_inner_simple_scf(%A : memref<10xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %cst = arith.constant 0.0 : f32
+
+  scf.for %i = %c0 to %c100 step %c1 {
+    %B = memref.alloc() : memref<10xf32>
+    %C = memref.alloc() : memref<10xf32>
+
+    affine.for %j = 0 to 10 {
+      %v = affine.load %A[%j] : memref<10xf32>
+      affine.store %v, %B[%j] : memref<10xf32>
+    }
+
+    affine.for %j = 0 to 10 {
+      %v = affine.load %B[%j] : memref<10xf32>
+      affine.store %v, %C[%j] : memref<10xf32>
+    }
+  }
+  // CHECK:      scf.for
+  // CHECK-NEXT:   memref.alloc
+  // CHECK-NEXT:   memref.alloc
+  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 10
+  // CHECK-NOT:    affine.for
+  return
+}
+
+// CHECK-LABEL: func @fusion_inner_multiple_nests
+func.func @fusion_inner_multiple_nests() {
+  %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x4xi8>
+  %alloc_10 = memref.alloc() : memref<8x4xi32>
+  affine.for %arg8 = 0 to 4 {
+    %alloc_14 = memref.alloc() : memref<4xi8>
+    %alloc_15 = memref.alloc() : memref<8x4xi8>
+    affine.for %arg9 = 0 to 4 {
+      %0 = affine.load %alloc_5[%arg9, %arg8] : memref<4x4xi8>
+      affine.store %0, %alloc_14[%arg9] : memref<4xi8>
+    }
+    %alloc_16 = memref.alloc() : memref<4xi8>
+    affine.for %arg9 = 0 to 4 {
+      %0 = affine.load %alloc_14[%arg9] : memref<4xi8>
+      affine.store %0, %alloc_16[%arg9] : memref<4xi8>
+    }
+    affine.for %arg9 = 0 to 2 {
+      %0 = affine.load %alloc_15[%arg9 * 4, 0] : memref<8x4xi8>
+      %1 = affine.load %alloc_16[0] : memref<4xi8>
+      %2 = affine.load %alloc_10[%arg9 * 4, %arg8] : memref<8x4xi32>
+      %3 = arith.muli %0, %1 : i8
+      %4 = arith.extsi %3 : i8 to i32
+      %5 = arith.addi %4, %2 : i32
+      affine.store %5, %alloc_10[%arg9 * 4 + 3, %arg8] : memref<8x4xi32>
+    }
+    memref.dealloc %alloc_16 : memref<4xi8>
+  }
+  // CHECK:      affine.for %{{.*}} = 0 to 4 {
+  // Everything inside fused into two nests (the second will be DCE'd).
+  // CHECK-NEXT:   memref.alloc() : memref<4xi8>
+  // CHECK-NEXT:   memref.alloc() : memref<1xi8>
+  // CHECK-NEXT:   memref.alloc() : memref<1xi8>
+  // CHECK-NEXT:   memref.alloc() : memref<8x4xi8>
+  // CHECK-NEXT:   memref.alloc() : memref<4xi8>
+  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 2 {
+  // CHECK:        }
+  // CHECK:        affine.for %{{.*}} = 0 to 4 {
+  // CHECK:        }
+  // CHECK-NEXT:   memref.dealloc
+  // CHECK-NEXT: }
+  // CHECK-NEXT: return
+  return
+}
+
+// CHECK-LABEL: func @fusion_inside_scf_while
+func.func @fusion_inside_scf_while(%A : memref<10xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+  %cst = arith.constant 0.0 : f32
+
+  %0 = scf.while (%arg3 = %cst) : (f32) -> (f32) {
+    %1 = arith.cmpf ult, %arg3, %cst : f32
+    scf.condition(%1) %arg3 : f32
+  } do {
+  ^bb0(%arg5: f32):
+
+    %B = memref.alloc() : memref<10xf32>
+    %C = memref.alloc() : memref<10xf32>
+
+    affine.for %j = 0 to 10 {
+      %v = affine.load %A[%j] : memref<10xf32>
+      affine.store %v, %B[%j] : memref<10xf32>
+    }
+
+    affine.for %j = 0 to 10 {
+      %v = affine.load %B[%j] : memref<10xf32>
+      affine.store %v, %C[%j] : memref<10xf32>
+    }
+    %1 = arith.mulf %arg5, %cst : f32
+    scf.yield %1 : f32
+  }
+  // CHECK:      scf.while
+  // CHECK:        affine.for %{{.*}} = 0 to 10
+  // CHECK-NOT:    affine.for
+  // CHECK:        scf.yield
+  return
+}
+
+
+memref.global "private" constant @__constant_10x2xf32 : memref<10x2xf32> = dense<0.000000e+00>
+
+// CHECK-LABEL: func @fusion_inner_long
+func.func @fusion_inner_long(%arg0: memref<10x2xf32>, %arg1: memref<10x10xf32>, %arg2: memref<10x2xf32>, %s: index) {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 1.000000e-03 : f32
+  %c9 = arith.constant 9 : index
+  %c10_i32 = arith.constant 10 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c100_i32 = arith.constant 100 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %0 = memref.get_global @__constant_10x2xf32 : memref<10x2xf32>
+  %1 = scf.for %arg3 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %arg0) -> (memref<10x2xf32>)  : i32 {
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
+    affine.for %arg5 = 0 to 10 {
+      %3 = arith.index_cast %arg5 : index to i32
+      affine.store %3, %alloc[%arg5] : memref<10xi32>
+    }
+    %2 = scf.for %arg5 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg6 = %0) -> (memref<10x2xf32>)  : i32 {
+      %alloc_5 = memref.alloc() : memref<2xf32>
+      affine.for %arg7 = 0 to 2 {
+        %16 = affine.load %arg4[%s, %arg7] : memref<10x2xf32>
+        affine.store %16, %alloc_5[%arg7] : memref<2xf32>
+      }
+      %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<1x2xf32>
+      affine.for %arg7 = 0 to 2 {
+        %16 = affine.load %alloc_5[%arg7] : memref<2xf32>
+        affine.store %16, %alloc_6[0, %arg7] : memref<1x2xf32>
+      }
+      %alloc_7 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
+      affine.for %arg7 = 0 to 10 {
+        affine.for %arg8 = 0 to 2 {
+          %16 = affine.load %alloc_6[0, %arg8] : memref<1x2xf32>
+          affine.store %16, %alloc_7[%arg7, %arg8] : memref<10x2xf32>
+        }
+      }
+      %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
+      affine.for %arg7 = 0 to 10 {
+        affine.for %arg8 = 0 to 2 {
+          %16 = affine.load %alloc_7[%arg7, %arg8] : memref<10x2xf32>
+          %17 = affine.load %arg4[%arg7, %arg8] : memref<10x2xf32>
+          %18 = arith.subf %16, %17 : f32
+          affine.store %18, %alloc_8[%arg7, %arg8] : memref<10x2xf32>
+        }
+      }
+      scf.yield %alloc_8 : memref<10x2xf32>
+      // CHECK:      scf.for
+      // CHECK:        scf.for
+      // CHECK:          affine.for %{{.*}} = 0 to 10
+      // CHECK-NEXT:       affine.for %{{.*}} = 0 to 2
+      // CHECK-NOT:      affine.for
+      // CHECK:          scf.yield
+    }
+    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
+    affine.for %arg5 = 0 to 10 {
+      affine.for %arg6 = 0 to 2 {
+        affine.store %cst_0, %alloc_2[%arg5, %arg6] : memref<10x2xf32>
+      }
+    }
+    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
+    affine.for %arg5 = 0 to 10 {
+      affine.for %arg6 = 0 to 2 {
+        %3 = affine.load %alloc_2[%arg5, %arg6] : memref<10x2xf32>
+        %4 = affine.load %2[%arg5, %arg6] : memref<10x2xf32>
+        %5 = arith.mulf %3, %4 : f32
+        affine.store %5, %alloc_3[%arg5, %arg6] : memref<10x2xf32>
+      }
+    }
+    scf.yield %alloc_3 : memref<10x2xf32>
+    // The nests above will be fused as well.
+    // CHECK:      affine.for %{{.*}} = 0 to 10
+    // CHECK-NEXT:   affine.for %{{.*}} = 0 to 2
+    // CHECK-NOT:  affine.for
+    // CHECK:      scf.yield
+  }
+  affine.for %arg3 = 0 to 10 {
+    affine.for %arg4 = 0 to 2 {
+      %2 = affine.load %1[%arg3, %arg4] : memref<10x2xf32>
+      affine.store %2, %arg2[%arg3, %arg4] : memref<10x2xf32>
+    }
+  }
+  return
+}



More information about the Mlir-commits mailing list