[Mlir-commits] [mlir] [mlir][affine]fix create affine.for bug. (PR #117721)
llvmlistbot at llvm.org
Tue Nov 26 07:14:24 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir
Author: lonely eagle (linuxlonelyeagle)
<details>
<summary>Changes</summary>
I encountered this error in a pass I wrote:
```
matmul.mlir:5:3: error: 'affine.for' op operand cannot be used as a symbol
linalg.matmul
^
matmul.mlir:5:3: note: see current operation:
"affine.for"(%9, %8) <{lowerBoundMap = affine_map<()[s0] -> (s0)>, operandSegmentSizes = array<i32: 1, 1, 0>, step = 32 : index, upperBoundMap = affine_map<()[s0] -> (s0)>}> ({
^bb0(%arg3: index):
"affine.yield"() : () -> ()
}) : (index, index) -> ()
make: *** [makefile:56: gemm-opt-matmul-lower] Error 1
```
This happens when an affine.for is created with an affine map in its lower or upper bound whose symbol operand comes from a memref.dim on a memref that is a function argument: the affine.for verifier then fails with the error above.
The IR below shows the pattern, although it does nothing meaningful. The point is that my pass created an affine.for like this one and hit the error above. It is worth mentioning, however, that the same IR written by hand verifies without problems, so I did not add a test.
```
#map = affine_map<()[s0] -> (s0)>
func.func @func(%A : memref<32x128xf32>) {
  %0 = arith.constant 0 : index
  %1 = arith.constant 1 : index
  %dim_0 = memref.dim %A, %0 : memref<32x128xf32>
  %dim_1 = memref.dim %A, %1 : memref<32x128xf32>
  affine.for %it = #map()[%dim_0] to #map()[%dim_1] {
  }
  return
}
```
---
Full diff: https://github.com/llvm/llvm-project/pull/117721.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+13-3)
``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..0d24e434328419 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
 
   // Conservatively handle remaining BlockArguments as non-valid symbols.
   // E.g. scf.for iterArgs.
-  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
-    return false;
-
+  if (auto blockArgument =
+          llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
+    if (!llvm::isa<FunctionOpInterface>(
+            blockArgument.getParentRegion()->getParentOp())) {
+      return false;
+    }
+  }
   // The dim op is also okay if its operand memref is a view/subview whose
   // corresponding size is a valid symbol.
   std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
@@ -365,6 +370,11 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
 
   // Skip over all memref.cast ops (if any).
   Operation *op = dimOp.getShapedValue().getDefiningOp();
+
+  // the ShapedValue of the dim is the function block argument.
+  if (!op)
+    return true;
+
   while (auto castOp = dyn_cast<memref::CastOp>(op)) {
     // Bail on unranked memrefs.
     if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
``````````
</details>
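To make the decision order of the patched isDimOpValidSymbol easier to follow, here is a toy, standalone C++ model of only the two branches this diff touches. It does not use MLIR at all: ParentKind, ShapedValueModel, Outcome, and patchedDimCheckModel are hypothetical names introduced for illustration. The model also assumes that the checks before the first hunk and the code between the two hunks (not shown in the diff) do not return first.

```cpp
// Hypothetical stand-ins, NOT MLIR types: they only record where the memref
// operand of a memref.dim comes from.
enum class ParentKind { FunctionLike, Other };

struct ShapedValueModel {
  bool isBlockArgument; // true for a block argument, false for an op result
  ParentKind parentOp;  // kind of op owning the argument's region (args only)
};

// What the patched branches decide: a final answer, or "fall through to the
// existing memref.cast/view/subview handling", which this model omits.
enum class Outcome { ValidSymbol, NotValidSymbol, ContinueWithOpChecks };

// Toy model of the control flow added by PR #117721, assuming the checks
// before the first hunk and between the two hunks do not return early.
constexpr Outcome patchedDimCheckModel(const ShapedValueModel &v) {
  if (v.isBlockArgument) {
    // Block arguments of non-function ops (e.g. scf.for iter_args) are still
    // conservatively rejected, as before the patch.
    if (v.parentOp != ParentKind::FunctionLike)
      return Outcome::NotValidSymbol;
    // A function argument has no defining op, so the new `if (!op)` branch
    // accepts it.
    return Outcome::ValidSymbol;
  }
  // Op results reach the unchanged memref.cast/view/subview logic.
  return Outcome::ContinueWithOpChecks;
}

// Compile-time checks of the three cases the model distinguishes.
static_assert(patchedDimCheckModel({true, ParentKind::FunctionLike}) ==
                  Outcome::ValidSymbol,
              "function argument: accepted");
static_assert(patchedDimCheckModel({true, ParentKind::Other}) ==
                  Outcome::NotValidSymbol,
              "other block argument: rejected");
static_assert(patchedDimCheckModel({false, ParentKind::Other}) ==
                  Outcome::ContinueWithOpChecks,
              "op result: unchanged path");
```

The only behavioral change the model captures is the function-argument case, which before the patch fell into the blanket block-argument rejection.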
https://github.com/llvm/llvm-project/pull/117721