[Mlir-commits] [mlir] [MLIR][affine] Fix for #115849 Illegal affine loop fusion with vector types (PR #117617)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 26 11:16:22 PST 2024
https://github.com/brod4910 updated https://github.com/llvm/llvm-project/pull/117617
>From c86f5e1039df798e3cd8788066427dd62d96edca Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Mon, 25 Nov 2024 12:32:24 -0700
Subject: [PATCH 1/4] Fix for provided simple reproducer of #115849
---
.../Dialect/Affine/Transforms/LoopFusion.cpp | 73 +++++++++++++++++--
1 file changed, 65 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 6fefe4487ef59a..7e6085c5399e4b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
@@ -23,6 +24,8 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
@@ -177,13 +180,58 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
producerConsumerMemrefs);
}
+/// Returns true if a value of 'loadType' read by a consumer is contained
+/// within a value of 'storeType' written by a producer, i.e. every
+/// dimension of the loaded shape is no larger than the stored one.
+/// This is a shape-only check; access offsets are not examined here.
+/// Two scalar (non-shaped) types are always compatible. Anything that
+/// cannot be compared dimension-by-dimension (scalar vs. shaped, unranked,
+/// differing ranks, dynamic sizes) conservatively returns false.
+static bool containedWithin(Type loadType, Type storeType) {
+  auto loadShapedType = dyn_cast<ShapedType>(loadType);
+  auto storeShapedType = dyn_cast<ShapedType>(storeType);
+
+  // Plain scalar accesses (e.g. affine.load/affine.store of f32) carry no
+  // shape to compare; an unconditional cast<ShapedType> would assert on them.
+  if (!loadShapedType && !storeShapedType)
+    return true;
+  if (!loadShapedType || !storeShapedType)
+    return false;
+
+  // Indexing the store shape with the load's rank would go out of bounds
+  // when the ranks differ.
+  if (!loadShapedType.hasRank() || !storeShapedType.hasRank() ||
+      loadShapedType.getRank() != storeShapedType.getRank())
+    return false;
+
+  for (int64_t i = 0, e = loadShapedType.getRank(); i < e; ++i) {
+    int64_t loadDim = loadShapedType.getDimSize(i);
+    int64_t storeDim = storeShapedType.getDimSize(i);
+    // Dynamic sizes are encoded as a negative sentinel, so a plain '>'
+    // comparison would give a meaningless answer for them.
+    if (ShapedType::isDynamic(loadDim) || ShapedType::isDynamic(storeDim) ||
+        loadDim > storeDim)
+      return false;
+  }
+  return true;
+}
+
+/// Returns false if any load in node 'dstId' of a memref in
+/// 'producerConsumerMemrefs' reads a value whose shape is not contained
+/// within the shape of a value stored to the same memref by node 'srcId'
+/// (see containedWithin); returns true otherwise.
+static bool
+verifyLoadStoreDomainContainment(unsigned srcId, unsigned dstId,
+                                 const DenseSet<Value> &producerConsumerMemrefs,
+                                 MemRefDependenceGraph *mdg) {
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+
+  SmallVector<Operation *> storeOps;
+  SmallVector<Operation *> loadOps;
+  for (Value memref : producerConsumerMemrefs) {
+    // The getters push into the output vectors, so reset them for every
+    // memref; otherwise ops of earlier memrefs leak into later comparisons.
+    storeOps.clear();
+    loadOps.clear();
+    srcNode->getStoreOpsForMemref(memref, &storeOps);
+    dstNode->getLoadOpsForMemref(memref, &loadOps);
+
+    for (Operation *storeOp : storeOps) {
+      Type storeType =
+          cast<AffineWriteOpInterface>(storeOp).getValueToStore().getType();
+      for (Operation *loadOp : loadOps) {
+        Type loadType =
+            cast<AffineReadOpInterface>(loadOp).getValue().getType();
+        if (!containedWithin(loadType, storeType))
+          return false;
+      }
+    }
+  }
+  return true;
+}
+
/// A memref escapes in the context of the fusion pass if either:
/// 1. it (or its alias) is a block argument, or
/// 2. created by an op not known to guarantee alias freedom,
-/// 3. it (or its alias) are used by ops other than affine dereferencing ops
-/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
-/// terminator ops, etc.); such ops do not deference the memref in an affine
-/// way.
+/// 3. it (or its alias) is used by ops other than affine dereferencing
+///    ops (e.g., by call op, memref load/store ops, alias creating ops,
+///    unknown ops, terminator ops, etc.); such ops do not dereference the
+///    memref in an affine way.
static bool isEscapingMemref(Value memref, Block *block) {
Operation *defOp = memref.getDefiningOp();
// Check if 'memref' is a block argument.
@@ -858,10 +906,19 @@ struct GreedyFusion {
}))
continue;
- // Gather memrefs in 'srcNode' that are written and escape out of the
- // block (e.g., memref block arguments, returned memrefs,
- // memrefs passed to function calls, etc.).
- DenseSet<Value> srcEscapingMemRefs;
+ if (!verifyLoadStoreDomainContainment(srcId, dstId,
+ &producerConsumerMemrefs, mdg)) {
+ LLVM_DEBUG(llvm::dbgs() << "Can't fuse: load domain not contained "
+ "within store domain\n");
+ continue;
+ }
+
+ if (any_of(producerConsumerMemrefs, UnaryPredicate P))
+
+ // Gather memrefs in 'srcNode' that are written and escape out of
+ // the block (e.g., memref block arguments, returned memrefs,
+ // memrefs passed to function calls, etc.).
+ DenseSet<Value> srcEscapingMemRefs;
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
// Skip if there are non-affine operations in between the 'srcNode'
>From 64532022186463324a6deb2be95da5e1f63b75e2 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Mon, 25 Nov 2024 12:36:33 -0700
Subject: [PATCH 2/4] clear load/store vectors
---
mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 7e6085c5399e4b..b1414023f6afd4 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -220,6 +220,9 @@ verifyLoadStoreDomainContainment(unsigned srcId, unsigned dstId,
}
}
}
+
+ storeOps.clear();
+ loadOps.clear();
}
return true;
>From 71eabce85b7a772516b3c6eaafce2582baa80304 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Tue, 26 Nov 2024 11:40:04 -0700
Subject: [PATCH 3/4] rename function and move to single function
---
.../Dialect/Affine/Transforms/LoopFusion.cpp | 32 +++++++------------
1 file changed, 11 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index b1414023f6afd4..ec1d2941a5d2e2 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -180,23 +180,7 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
producerConsumerMemrefs);
}
-static bool containedWithin(Type loadType, Type storeType) {
- ShapedType loadShapedType = cast<ShapedType>(loadType);
- ShapedType storeShapedType = cast<ShapedType>(storeType);
-
- for (int i = 0; i < loadShapedType.getRank(); ++i) {
- auto loadDim = loadShapedType.getDimSize(i);
- auto storeDim = storeShapedType.getDimSize(i);
-
- if (loadDim > storeDim) {
- return false;
- }
- }
- return true;
-}
-
-static bool
-verifyLoadStoreDomainContainment(unsigned srcId, unsigned dstId,
+static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
DenseSet<Value> &producerConsumerMemrefs,
MemRefDependenceGraph *mdg) {
SmallVector<Operation *> storeOps;
@@ -211,12 +195,18 @@ verifyLoadStoreDomainContainment(unsigned srcId, unsigned dstId,
for (Operation *storeOp : storeOps) {
Value storeValue = cast<AffineWriteOpInterface>(storeOp).getValueStore();
+ ShapedType storeShapedType = cast<ShapedType>(storeValue.getType());
for (Operation *loadOp : loadOps) {
Value loadValue = cast<AffineReadOpInterface>(loadOp).getValue();
+ ShapedType loadShapedType = cast<ShapedType>(loadValue.getType());
+
+ for (int i = 0; i < loadShapedType.getRank(); ++i) {
+ auto loadDim = loadShapedType.getDimSize(i);
+ auto storeDim = storeShapedType.getDimSize(i);
- if (!containedWithin(loadValue.getType(), storeValue.getType())) {
- return false;
+ if (loadDim > storeDim)
+ return false;
}
}
}
@@ -909,8 +899,8 @@ struct GreedyFusion {
}))
continue;
- if (!verifyLoadStoreDomainContainment(srcId, dstId,
- &producerConsumerMemrefs, mdg)) {
+ if (!checkLoadStoreShapes(srcId, dstId, &producerConsumerMemrefs,
+ mdg)) {
LLVM_DEBUG(llvm::dbgs() << "Can't fuse: load domain not contained "
"within store domain\n");
continue;
>From 396dbfc87f9d7a04995dabb73020dcc42e926932 Mon Sep 17 00:00:00 2001
From: brod4910 <brod4910 at gmail.com>
Date: Tue, 26 Nov 2024 12:16:06 -0700
Subject: [PATCH 4/4] Add comment to function
---
mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index ec1d2941a5d2e2..bc0d9da9c15e07 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -180,6 +180,10 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
producerConsumerMemrefs);
}
+/// Checks the shapes of the values loaded and stored for each memref in
+/// the producer/consumer chains. If any load reads a shape larger than a
+/// producer store's shape in some dimension, the loops are not fused: that
+/// load would read elements beyond those written by a single store.
static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
DenseSet<Value> &producerConsumerMemrefs,
MemRefDependenceGraph *mdg) {
@@ -194,7 +198,8 @@ static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
dstNode->getLoadOpsForMemref(memref, &loadOps);
for (Operation *storeOp : storeOps) {
- Value storeValue = cast<AffineWriteOpInterface>(storeOp).getValueStore();
+ Value storeValue =
+ cast<AffineWriteOpInterface>(storeOp).getValueToStore();
ShapedType storeShapedType = cast<ShapedType>(storeValue.getType());
for (Operation *loadOp : loadOps) {
More information about the Mlir-commits
mailing list