[Mlir-commits] [mlir] [mlir][affine]fix create affine.for bug. (PR #117721)
lonely eagle
llvmlistbot at llvm.org
Wed Nov 27 07:10:48 PST 2024
================
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
// Conservatively handle remaining BlockArguments as non-valid symbols.
// E.g. scf.for iterArgs.
- if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
- return false;
-
+ if (auto blockArgument =
+ llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
+ if (!llvm::isa<FunctionOpInterface>(
+ blockArgument.getParentRegion()->getParentOp())) {
----------------
linuxlonelyeagle wrote:
You can parse the following IR with mlir-opt to trigger the bug. I found that I can reproduce the bug using the generic IR form.
That means I can write tests for this fix too.
```
#map = affine_map<()[s0] -> (s0)>
"builtin.module"() ({
"gpu.module"() <{sym_name = "gpu"}> ({
"gpu.func"() <{function_type = (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()}> ({
^bb0(%arg3: memref<?x?xf32>, %arg4: memref<?x?xf32>, %arg5: memref<?x?xf32>):
%16 = "arith.constant"() <{value = 1 : index}> : () -> index
%17 = "memref.dim"(%arg3, %16) : (memref<?x?xf32>, index) -> index
%18 = "arith.constant"() <{value = 0 : index}> : () -> index
"affine.for"(%18, %17) <{lowerBoundMap = #map, operandSegmentSizes = array<i32: 1, 1, 0>, step = 32 : index, upperBoundMap = #map}> ({
^bb0(%arg6: index):
"affine.yield"() : () -> ()
}) : (index, index) -> ()
"gpu.return"() : () -> ()
}) {gpu.kernel, sym_name = "gemm", workgroup_attributions = 0 : i64} : () -> ()
}) : () -> ()
"func.func"() <{function_type = (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> f32, sym_name = "main"}> ({
^bb0(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>):
%0 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
%1 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
%2 = "arith.constant"() <{value = 2.000000e+00 : f32}> : () -> f32
%3 = "arith.constant"() <{value = 0 : index}> : () -> index
%4 = "memref.dim"(%arg0, %3) : (memref<?x?xf32>, index) -> index
%5 = "arith.constant"() <{value = 1 : index}> : () -> index
%6 = "memref.dim"(%arg0, %5) : (memref<?x?xf32>, index) -> index
%7 = "arith.constant"() <{value = 1 : index}> : () -> index
%8 = "memref.dim"(%arg1, %7) : (memref<?x?xf32>, index) -> index
%9 = "arith.constant"() <{value = 128 : index}> : () -> index
%10 = "arith.ceildivui"(%4, %9) : (index, index) -> index
%11 = "arith.constant"() <{value = 64 : index}> : () -> index
%12 = "arith.ceildivsi"(%6, %11) : (index, index) -> index
%13 = "arith.constant"() <{value = 256 : index}> : () -> index
%14 = "arith.constant"() <{value = 262144 : i32}> : () -> i32
%15 = "arith.constant"() <{value = 1 : index}> : () -> index
"gpu.launch_func"(%12, %10, %15, %13, %15, %15, %14, %arg0, %arg1, %arg2) <{kernel = @gpu::@gemm, operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 3, 0>}> : (index, index, index, index, index, index, i32, memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
"func.return"(%0) : (f32) -> ()
}) : () -> ()
}) {gpu.container_module} : () -> ()
```
However, this raises another question I'd like to ask, which I haven't fully thought through yet.
https://github.com/llvm/llvm-project/pull/117721
More information about the Mlir-commits
mailing list