[Mlir-commits] [mlir] [mlir][affine]fix create affine.for bug. (PR #117721)
lonely eagle
llvmlistbot at llvm.org
Tue Nov 26 07:13:41 PST 2024
https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/117721
I ran into the following error in a pass I wrote:
```
matmul.mlir:5:3: error: 'affine.for' op operand cannot be used as a symbol
linalg.matmul
^
matmul.mlir:5:3: note: see current operation:
"affine.for"(%9, %8) <{lowerBoundMap = affine_map<()[s0] -> (s0)>, operandSegmentSizes = array<i32: 1, 1, 0>, step = 32 : index, upperBoundMap = affine_map<()[s0] -> (s0)>}> ({
^bb0(%arg3: index):
"affine.yield"() : () -> ()
}) : (index, index) -> ()
make: *** [makefile:56: gemm-opt-matmul-lower] Error 1
```
The error occurs when a pass creates an `affine.for` whose lower or upper bound is an affine map with a symbol operand produced by `memref.dim`, where the memref passed to `memref.dim` is a function argument. In that case the `affine.for` verifier rejects the operand because it is not considered a valid symbol.
The IR below shows the shape of such a loop, although it is not meaningful on its own. The point is that creating an `affine.for` like this from within a pass triggers the error above. Notably, the same IR written by hand verifies without problems, so I didn't add a test.
```
#map = affine_map<()[s0] -> (s0)>
func.func @func(%A : memref<32x128xf32>) {
  %0 = arith.constant 0 : index
  %1 = arith.constant 1 : index
  %dim_0 = memref.dim %A, %0 : memref<32x128xf32>
  %dim_1 = memref.dim %A, %1 : memref<32x128xf32>
  affine.for %it = #map()[%dim_0] to #map()[%dim_1] {
  }
  return
}
```
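
To make the failure mode concrete, here is a toy C++ model of the pre-patch control flow in `isDimOpValidSymbol`. It is a sketch, not MLIR code: `ShapedValueInfo` and `isDimOpValidSymbolBeforePatch` are hypothetical names, and whatever checks run before the patched hunk (the ones the word "remaining" in the comment refers to) are reduced to a single boolean. The model only shows that a function argument not accepted by those earlier checks is rejected exactly like any other block argument, such as an `scf.for` iter_arg.

```cpp
#include <cassert>

// Hypothetical summary of the value passed to memref.dim; not an MLIR type.
struct ShapedValueInfo {
  bool acceptedByEarlierChecks; // outcome of the checks before the hunk (assumed)
  bool isBlockArgument;         // block argument rather than an op result
  bool ownerIsFunction;         // block's parent op is a function (block args only)
  bool viewSizeIsValidSymbol;   // outcome of the view/subview analysis (op results only)
};

// Toy model of the pre-patch control flow of isDimOpValidSymbol.
bool isDimOpValidSymbolBeforePatch(const ShapedValueInfo &v) {
  // Only block arguments can be owned by a function's region.
  assert(v.isBlockArgument || !v.ownerIsFunction);
  if (v.acceptedByEarlierChecks)
    return true;
  // "Conservatively handle remaining BlockArguments as non-valid symbols."
  // Function arguments are not distinguished from e.g. scf.for iter_args.
  if (v.isBlockArgument)
    return false;
  // Op results go on to the view/subview analysis, modeled as a flag.
  return v.viewSizeIsValidSymbol;
}
```

For example, `ShapedValueInfo{false, true, true, false}` (a function argument that the earlier checks did not accept) yields `false`, which corresponds to the "operand cannot be used as a symbol" error above.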
From c069ddd70519fb8016d6b9d60ed80860552c721c Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 26 Nov 2024 23:02:13 +0800
Subject: [PATCH] fix create affine.for bug.
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 16 +++++++++++++---
1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..0d24e434328419 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
// Conservatively handle remaining BlockArguments as non-valid symbols.
// E.g. scf.for iterArgs.
- if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
- return false;
-
+ if (auto blockArgument =
+ llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
+ if (!llvm::isa<FunctionOpInterface>(
+ blockArgument.getParentRegion()->getParentOp())) {
+ return false;
+ }
+ }
// The dim op is also okay if its operand memref is a view/subview whose
// corresponding size is a valid symbol.
std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
@@ -365,6 +370,11 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
// Skip over all memref.cast ops (if any).
Operation *op = dimOp.getShapedValue().getDefiningOp();
+
+ // the ShapedValue of the dim is the function block argument.
+ if (!op)
+ return true;
+
while (auto castOp = dyn_cast<memref::CastOp>(op)) {
// Bail on unranked memrefs.
if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
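
The patched decision can be sketched with the same kind of toy model. Again, `ShapedValueInfo` and `isDimOpValidSymbolAfterPatch` are hypothetical names rather than MLIR API, the checks before the first hunk are reduced to a boolean, and the lines elided between the two hunks (the handling of the dimension index obtained via `getConstantIntValue`) are not modeled. The sketch mirrors both hunks: block arguments are rejected only when their region's parent op is not a function, and a function argument, which has no defining op, returns `true` before the `memref.cast` loop is reached.

```cpp
#include <cassert>

// Hypothetical summary of the value passed to memref.dim; not an MLIR type.
struct ShapedValueInfo {
  bool acceptedByEarlierChecks; // outcome of the checks before the hunk (assumed)
  bool isBlockArgument;         // block argument rather than an op result
  bool ownerIsFunction;         // block's parent op is a function (block args only)
  bool viewSizeIsValidSymbol;   // outcome of the view/subview analysis (op results only)
};

// Toy model of the patched control flow of isDimOpValidSymbol.
bool isDimOpValidSymbolAfterPatch(const ShapedValueInfo &v) {
  // Only block arguments can be owned by a function's region.
  assert(v.isBlockArgument || !v.ownerIsFunction);
  if (v.acceptedByEarlierChecks)
    return true;
  // First hunk: reject block arguments unless the parent op of their
  // region is a FunctionOpInterface; function arguments fall through.
  if (v.isBlockArgument && !v.ownerIsFunction)
    return false;
  // Second hunk: a block argument has no defining op, so `op` would be
  // null; the patch returns true instead of entering the cast loop.
  if (v.isBlockArgument)
    return true;
  // Op results go on to the view/subview analysis, modeled as a flag.
  return v.viewSizeIsValidSymbol;
}
```

The second hunk is needed because `llvm::dyn_cast` must not be given a null pointer (`llvm::dyn_cast_or_null` is the null-tolerant variant), so without the `if (!op)` guard a function argument that now passes the first hunk would reach `dyn_cast<memref::CastOp>(op)` with a null `op`.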