[Mlir-commits] [mlir] [MLIR][affine] Fix for #115849 Illegal affine loop fusion with vector types (PR #117617)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 3 14:22:53 PST 2024


https://github.com/brod4910 updated https://github.com/llvm/llvm-project/pull/117617

>From c86f5e1039df798e3cd8788066427dd62d96edca Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Mon, 25 Nov 2024 12:32:24 -0700
Subject: [PATCH 1/6] Fix for provided simple reproducer of #115849

---
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 73 +++++++++++++++++--
 1 file changed, 65 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 6fefe4487ef59a..7e6085c5399e4b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/Dialect/Affine/Passes.h"
 
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
@@ -23,6 +24,8 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
@@ -177,13 +180,58 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                                 producerConsumerMemrefs);
 }
 
+static bool containedWithin(Type loadType, Type storeType) {
+  ShapedType loadShapedType = cast<ShapedType>(loadType);
+  ShapedType storeShapedType = cast<ShapedType>(storeType);
+
+  for (int i = 0; i < loadShapedType.getRank(); ++i) {
+    auto loadDim = loadShapedType.getDimSize(i);
+    auto storeDim = storeShapedType.getDimSize(i);
+
+    if (loadDim > storeDim) {
+      return false;
+    }
+  }
+  return true;
+}
+
+static bool
+verifyLoadStoreDomainContainment(unsigned srcId, unsigned dstId,
+                                 DenseSet<Value> &producerConsumerMemrefs,
+                                 MemRefDependenceGraph *mdg) {
+  SmallVector<Operation *> storeOps;
+  SmallVector<Operation *> loadOps;
+
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+
+  for (Value memref : producerConsumerMemrefs) {
+    srcNode->getStoreOpsForMemref(memref, &storeOps);
+    dstNode->getLoadOpsForMemref(memref, &loadOps);
+
+    for (Operation *storeOp : storeOps) {
+      Value storeValue = cast<AffineWriteOpInterface>(storeOp).getValueStore();
+
+      for (Operation *loadOp : loadOps) {
+        Value loadValue = cast<AffineReadOpInterface>(loadOp).getValue();
+
+        if (!containedWithin(loadValue.getType(), storeValue.getType())) {
+          return false;
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
 /// A memref escapes in the context of the fusion pass if either:
 ///   1. it (or its alias) is a block argument, or
 ///   2. created by an op not known to guarantee alias freedom,
-///   3. it (or its alias) are used by ops other than affine dereferencing ops
-///   (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
-///   terminator ops, etc.); such ops do not deference the memref in an affine
-///   way.
+///   3. it (or its alias) is used by ops other than affine dereferencing
+///   ops (e.g., by call op, memref load/store ops, alias creating ops,
+///   unknown ops, terminator ops, etc.); such ops do not dereference the
+///   memref in an affine way.
 static bool isEscapingMemref(Value memref, Block *block) {
   Operation *defOp = memref.getDefiningOp();
   // Check if 'memref' is a block argument.
@@ -858,10 +906,19 @@ struct GreedyFusion {
             }))
           continue;
 
-        // Gather memrefs in 'srcNode' that are written and escape out of the
-        // block (e.g., memref block arguments, returned memrefs,
-        // memrefs passed to function calls, etc.).
-        DenseSet<Value> srcEscapingMemRefs;
+        if (!verifyLoadStoreDomainContainment(srcId, dstId,
+                                              &producerConsumerMemrefs, mdg)) {
+          LLVM_DEBUG(llvm::dbgs() << "Can't fuse: load domain not contained "
+                                     "within store domain\n");
+          continue;
+        }
+
+        if (any_of(producerConsumerMemrefs, UnaryPredicate P))
+
+          // Gather memrefs in 'srcNode' that are written and escape out of
+          // the block (e.g., memref block arguments, returned memrefs,
+          // memrefs passed to function calls, etc.).
+          DenseSet<Value> srcEscapingMemRefs;
         gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
 
         // Skip if there are non-affine operations in between the 'srcNode'

>From 64532022186463324a6deb2be95da5e1f63b75e2 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Mon, 25 Nov 2024 12:36:33 -0700
Subject: [PATCH 2/6] clear load/store vectors

---
 mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 7e6085c5399e4b..b1414023f6afd4 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -220,6 +220,9 @@ verifyLoadStoreDomainContainment(unsigned srcId, unsigned dstId,
         }
       }
     }
+
+    storeOps.clear();
+    loadOps.clear();
   }
 
   return true;

>From 71eabce85b7a772516b3c6eaafce2582baa80304 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Tue, 26 Nov 2024 11:40:04 -0700
Subject: [PATCH 3/6] rename function and move to single function

---
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 32 +++++++------------
 1 file changed, 11 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index b1414023f6afd4..ec1d2941a5d2e2 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -180,23 +180,7 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                                 producerConsumerMemrefs);
 }
 
-static bool containedWithin(Type loadType, Type storeType) {
-  ShapedType loadShapedType = cast<ShapedType>(loadType);
-  ShapedType storeShapedType = cast<ShapedType>(storeType);
-
-  for (int i = 0; i < loadShapedType.getRank(); ++i) {
-    auto loadDim = loadShapedType.getDimSize(i);
-    auto storeDim = storeShapedType.getDimSize(i);
-
-    if (loadDim > storeDim) {
-      return false;
-    }
-  }
-  return true;
-}
-
-static bool
-verifyLoadStoreDomainContainment(unsigned srcId, unsigned dstId,
+static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
                                  DenseSet<Value> &producerConsumerMemrefs,
                                  MemRefDependenceGraph *mdg) {
   SmallVector<Operation *> storeOps;
@@ -211,12 +195,18 @@ verifyLoadStoreDomainContainment(unsigned srcId, unsigned dstId,
 
     for (Operation *storeOp : storeOps) {
       Value storeValue = cast<AffineWriteOpInterface>(storeOp).getValueStore();
+      ShapedType storeShapedType = cast<ShapedType>(storeValue.getType());
 
       for (Operation *loadOp : loadOps) {
         Value loadValue = cast<AffineReadOpInterface>(loadOp).getValue();
+        ShapedType loadShapedType = cast<ShapedType>(loadValue.getType());
+
+        for (int i = 0; i < loadShapedType.getRank(); ++i) {
+          auto loadDim = loadShapedType.getDimSize(i);
+          auto storeDim = storeShapedType.getDimSize(i);
 
-        if (!containedWithin(loadValue.getType(), storeValue.getType())) {
-          return false;
+          if (loadDim > storeDim)
+            return false;
         }
       }
     }
@@ -909,8 +899,8 @@ struct GreedyFusion {
             }))
           continue;
 
-        if (!verifyLoadStoreDomainContainment(srcId, dstId,
-                                              &producerConsumerMemrefs, mdg)) {
+        if (!checkLoadStoreShapes(srcId, dstId, &producerConsumerMemrefs,
+                                  mdg)) {
           LLVM_DEBUG(llvm::dbgs() << "Can't fuse: load domain not contained "
                                      "within store domain\n");
           continue;

>From 396dbfc87f9d7a04995dabb73020dcc42e926932 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Tue, 26 Nov 2024 12:16:06 -0700
Subject: [PATCH 4/6] Add comment to function

---
 mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index ec1d2941a5d2e2..bc0d9da9c15e07 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -180,6 +180,10 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                                 producerConsumerMemrefs);
 }
 
+/// Checks the shapes of the loads and stores of each memref in the
+/// producer/consumer chains. Returns false (fusion is not allowed) if any
+/// load shape is larger than a store shape in some dimension, because such
+/// a load depends on more values than a single store writes.
 static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
                                  DenseSet<Value> &producerConsumerMemrefs,
                                  MemRefDependenceGraph *mdg) {
@@ -194,7 +198,8 @@ static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
     dstNode->getLoadOpsForMemref(memref, &loadOps);
 
     for (Operation *storeOp : storeOps) {
-      Value storeValue = cast<AffineWriteOpInterface>(storeOp).getValueStore();
+      Value storeValue =
+          cast<AffineWriteOpInterface>(storeOp).getValueToStore();
       ShapedType storeShapedType = cast<ShapedType>(storeValue.getType());
 
       for (Operation *loadOp : loadOps) {

>From f2c251cda15820c6bf9ca3d1ca8e471abf284a41 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Thu, 28 Nov 2024 01:39:53 -0700
Subject: [PATCH 5/6] check if values are shaped types

---
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 28 +++++++++++--------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index bc0d9da9c15e07..42d209917b0385 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -200,11 +200,17 @@ static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
     for (Operation *storeOp : storeOps) {
       Value storeValue =
           cast<AffineWriteOpInterface>(storeOp).getValueToStore();
-      ShapedType storeShapedType = cast<ShapedType>(storeValue.getType());
+      auto storeShapedType = dyn_cast<ShapedType>(storeValue.getType());
+
+      if (!storeShapedType)
+        continue;
 
       for (Operation *loadOp : loadOps) {
         Value loadValue = cast<AffineReadOpInterface>(loadOp).getValue();
-        ShapedType loadShapedType = cast<ShapedType>(loadValue.getType());
+        auto loadShapedType = dyn_cast<ShapedType>(loadValue.getType());
+
+        if (!loadShapedType)
+          continue;
 
         for (int i = 0; i < loadShapedType.getRank(); ++i) {
           auto loadDim = loadShapedType.getDimSize(i);
@@ -904,19 +910,17 @@ struct GreedyFusion {
             }))
           continue;
 
-        if (!checkLoadStoreShapes(srcId, dstId, &producerConsumerMemrefs,
-                                  mdg)) {
-          LLVM_DEBUG(llvm::dbgs() << "Can't fuse: load domain not contained "
-                                     "within store domain\n");
+        if (!checkLoadStoreShapes(srcId, dstId, producerConsumerMemrefs, mdg)) {
+          LLVM_DEBUG(
+              llvm::dbgs()
+              << "Can't fuse: load dependent on a larger store region\n");
           continue;
         }
 
-        if (any_of(producerConsumerMemrefs, UnaryPredicate P))
-
-          // Gather memrefs in 'srcNode' that are written and escape out of
-          // the block (e.g., memref block arguments, returned memrefs,
-          // memrefs passed to function calls, etc.).
-          DenseSet<Value> srcEscapingMemRefs;
+        // Gather memrefs in 'srcNode' that are written and escape out of
+        // the block (e.g., memref block arguments, returned memrefs,
+        // memrefs passed to function calls, etc.).
+        DenseSet<Value> srcEscapingMemRefs;
         gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
 
         // Skip if there are non-affine operations in between the 'srcNode'

>From 4e2177c6e51a2e0be912eaf3e03d51f4f1fce555 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Tue, 3 Dec 2024 15:22:39 -0700
Subject: [PATCH 6/6] add lit test for simple reproducer

---
 mlir/test/Dialect/Affine/loop-fusion-4.mlir | 47 +++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index f46ad0f5e4c232..e3e490007a361c 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -285,3 +285,50 @@ module {
     spirv.ReturnValue %3 : !spirv.array<8192 x f32>
   }
 }
+
+// -----
+
+// Basic test for not fusing loops where a vector load depends on
+// the entire result of a previous loop (store shape < load shape).
+
+// CHECK-LABEL: func @should_not_fuse_across_memref_store_load_bounds() {
+func.func @should_not_fuse_across_memref_store_load_bounds() {
+  %a = memref.alloc() : memref<64x512xf32> 
+  %b = memref.alloc() : memref<64x512xf32>
+  %c = memref.alloc() : memref<64x512xf32> 
+  %d = memref.alloc() : memref<64x4096xf32>
+  %e = memref.alloc() : memref<64x4096xf32>
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+      %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+      affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+  }
+
+  affine.for %j = 0 to 8 {
+      %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+      %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+      %res = arith.subf %lhs, %rhs : vector<64x512xf32>
+      affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+  }
+
+  return
+}
+// CHECK-LABEL: func.func @should_not_fuse_across_memref_store_load_bounds
+// CHECK: [[a:%[0-9]+]] = memref.alloc() : memref<64x512xf32>
+// CHECK: [[b:%[0-9]+]] = memref.alloc() : memref<64x512xf32>
+// CHECK: [[c:%[0-9]+]] = memref.alloc() : memref<64x512xf32>
+// CHECK: [[d:%[0-9]+]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: [[e:%[0-9]+]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: affine.for %[[j:[a-z0-9]+]] = 0 to 8
+// CHECK: %[[lhs:[a-z0-9]+]] = affine.vector_load [[a]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[rhs:[a-z0-9]+]] = affine.vector_load [[b]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[res:[a-z0-9]+]] = arith.addf %[[lhs]], %[[rhs]] : vector<64x64xf32>
+// CHECK: affine.vector_store %[[res]], [[c]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: affine.for %[[j_2:[a-z0-9]+]] = 0 to 8
+// CHECK: %[[lhs_2:[a-z0-9]+]] = affine.vector_load [[c]][0, 0] : memref<64x512xf32>, vector<64x512xf32>
+// CHECK: %[[rhs_2:[a-z0-9]+]] = affine.vector_load [[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: %[[res_2:[a-z0-9]+]] = arith.subf %[[lhs_2]], %[[rhs_2]] : vector<64x512xf32>
+// CHECK: affine.vector_store %[[res_2]], [[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: return
\ No newline at end of file



More information about the Mlir-commits mailing list