[Mlir-commits] [mlir] [MLIR][affine] Fix for #115849 Illegal affine loop fusion with vector types (PR #117617)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 3 16:10:05 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-affine
Author: None (brod4910)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/117617.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+65-6)
- (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+47)
``````````diff
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 6fefe4487ef59a..42d209917b0385 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
@@ -23,6 +24,8 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
@@ -177,13 +180,62 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
producerConsumerMemrefs);
}
+/// Checks the shapes of the loads and stores of each memref in
+/// the producer/consumer chains. If any load shape is larger
+/// than the corresponding store shape, the loops cannot be
+/// fused: the load would depend on values outside the region
+/// that was stored.
+static bool checkLoadStoreShapes(unsigned srcId, unsigned dstId,
+ DenseSet<Value> &producerConsumerMemrefs,
+ MemRefDependenceGraph *mdg) {
+ SmallVector<Operation *> storeOps;
+ SmallVector<Operation *> loadOps;
+
+ auto *srcNode = mdg->getNode(srcId);
+ auto *dstNode = mdg->getNode(dstId);
+
+ for (Value memref : producerConsumerMemrefs) {
+ srcNode->getStoreOpsForMemref(memref, &storeOps);
+ dstNode->getLoadOpsForMemref(memref, &loadOps);
+
+ for (Operation *storeOp : storeOps) {
+ Value storeValue =
+ cast<AffineWriteOpInterface>(storeOp).getValueToStore();
+ auto storeShapedType = dyn_cast<ShapedType>(storeValue.getType());
+
+ if (!storeShapedType)
+ continue;
+
+ for (Operation *loadOp : loadOps) {
+ Value loadValue = cast<AffineReadOpInterface>(loadOp).getValue();
+ auto loadShapedType = dyn_cast<ShapedType>(loadValue.getType());
+
+ if (!loadShapedType)
+ continue;
+
+ for (int i = 0; i < loadShapedType.getRank(); ++i) {
+ auto loadDim = loadShapedType.getDimSize(i);
+ auto storeDim = storeShapedType.getDimSize(i);
+
+ if (loadDim > storeDim)
+ return false;
+ }
+ }
+ }
+
+ storeOps.clear();
+ loadOps.clear();
+ }
+
+ return true;
+}
+
/// A memref escapes in the context of the fusion pass if either:
/// 1. it (or its alias) is a block argument, or
/// 2. created by an op not known to guarantee alias freedom,
-/// 3. it (or its alias) are used by ops other than affine dereferencing ops
-/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
-/// terminator ops, etc.); such ops do not deference the memref in an affine
-/// way.
+/// 3. it (or its alias) are used by ops other than affine dereferencing
+/// ops (e.g., by call op, memref load/store ops, alias creating ops,
+/// unknown ops, terminator ops, etc.); such ops do not dereference the
+/// memref in an affine way.
static bool isEscapingMemref(Value memref, Block *block) {
Operation *defOp = memref.getDefiningOp();
// Check if 'memref' is a block argument.
@@ -858,8 +910,15 @@ struct GreedyFusion {
}))
continue;
- // Gather memrefs in 'srcNode' that are written and escape out of the
- // block (e.g., memref block arguments, returned memrefs,
+ if (!checkLoadStoreShapes(srcId, dstId, producerConsumerMemrefs, mdg)) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "Can't fuse: load dependent on a larger store region\n");
+ continue;
+ }
+
+ // Gather memrefs in 'srcNode' that are written and escape out of
+ // the block (e.g., memref block arguments, returned memrefs,
// memrefs passed to function calls, etc.).
DenseSet<Value> srcEscapingMemRefs;
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index f46ad0f5e4c232..e3e490007a361c 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -285,3 +285,50 @@ module {
spirv.ReturnValue %3 : !spirv.array<8192 x f32>
}
}
+
+// -----
+
+// Basic test for not fusing loops where a vector load depends on
+// the entire result of a previous loop. store shape < load shape
+
+// CHECK-LABEL: func @should_not_fuse_across_memref_store_load_bounds() {
+func.func @should_not_fuse_across_memref_store_load_bounds() {
+ %a = memref.alloc() : memref<64x512xf32>
+ %b = memref.alloc() : memref<64x512xf32>
+ %c = memref.alloc() : memref<64x512xf32>
+ %d = memref.alloc() : memref<64x4096xf32>
+ %e = memref.alloc() : memref<64x4096xf32>
+
+ affine.for %j = 0 to 8 {
+ %lhs = affine.vector_load %a[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %rhs = affine.vector_load %b[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ %res = arith.addf %lhs, %rhs : vector<64x64xf32>
+ affine.vector_store %res, %c[0, %j * 64] : memref<64x512xf32>, vector<64x64xf32>
+ }
+
+ affine.for %j = 0 to 8 {
+ %lhs = affine.vector_load %c[0, 0] : memref<64x512xf32>, vector<64x512xf32>
+ %rhs = affine.vector_load %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+ %res = arith.subf %lhs, %rhs : vector<64x512xf32>
+ affine.vector_store %res, %d[0, %j * 512] : memref<64x4096xf32>, vector<64x512xf32>
+ }
+
+ return
+}
+// CHECK-LABEL: func.func @should_not_fuse_across_memref_store_load_bounds
+// CHECK: [[a:%[0-9]+]] = memref.alloc() : memref<64x512xf32>
+// CHECK: [[b:%[0-9]+]] = memref.alloc() : memref<64x512xf32>
+// CHECK: [[c:%[0-9]+]] = memref.alloc() : memref<64x512xf32>
+// CHECK: [[d:%[0-9]+]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: [[e:%[0-9]+]] = memref.alloc() : memref<64x4096xf32>
+// CHECK: affine.for %[[j:[a-z0-9]+]] = 0 to 8
+// CHECK: %[[lhs:[a-z0-9]+]] = affine.vector_load [[a]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[rhs:[a-z0-9]+]] = affine.vector_load [[b]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: %[[res:[a-z0-9]+]] = arith.addf %[[lhs]], %[[rhs]] : vector<64x64xf32>
+// CHECK: affine.vector_store %[[res]], [[c]][0, %[[j]] * 64] : memref<64x512xf32>, vector<64x64xf32>
+// CHECK: affine.for %[[j_2:[a-z0-9]+]] = 0 to 8
+// CHECK: %[[lhs_2:[a-z0-9]+]] = affine.vector_load [[c]][0, 0] : memref<64x512xf32>, vector<64x512xf32>
+// CHECK: %[[rhs_2:[a-z0-9]+]] = affine.vector_load [[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: %[[res_2:[a-z0-9]+]] = arith.subf %[[lhs_2]], %[[rhs_2]] : vector<64x512xf32>
+// CHECK: affine.vector_store %[[res_2]], [[d]][0, %[[j_2]] * 512] : memref<64x4096xf32>, vector<64x512xf32>
+// CHECK: return
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/117617
More information about the Mlir-commits
mailing list