[Mlir-commits] [mlir] [mlir][affine]fix create affine.for bug. (PR #117721)

lonely eagle llvmlistbot at llvm.org
Wed Nov 27 07:10:48 PST 2024


================
@@ -352,9 +353,13 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
 
   // Conservatively handle remaining BlockArguments as non-valid symbols.
   // E.g. scf.for iterArgs.
-  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
-    return false;
-
+  if (auto blockArgument =
+          llvm::dyn_cast<BlockArgument>(dimOp.getShapedValue())) {
+    if (!llvm::isa<FunctionOpInterface>(
+            blockArgument.getParentRegion()->getParentOp())) {
----------------
linuxlonelyeagle wrote:

You can trigger the bug by parsing the following IR with mlir-opt. I found that I can reproduce the bug using generic-form IR.
This means I can write tests for it as well.
```
#map = affine_map<()[s0] -> (s0)>
"builtin.module"() ({
  "gpu.module"() <{sym_name = "gpu"}> ({
    "gpu.func"() <{function_type = (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()}> ({
    ^bb0(%arg3: memref<?x?xf32>, %arg4: memref<?x?xf32>, %arg5: memref<?x?xf32>):
      %16 = "arith.constant"() <{value = 1 : index}> : () -> index
      %17 = "memref.dim"(%arg3, %16) : (memref<?x?xf32>, index) -> index
      %18 = "arith.constant"() <{value = 0 : index}> : () -> index
      "affine.for"(%18, %17) <{lowerBoundMap = #map, operandSegmentSizes = array<i32: 1, 1, 0>, step = 32 : index, upperBoundMap = #map}> ({
      ^bb0(%arg6: index):
        "affine.yield"() : () -> ()
      }) : (index, index) -> ()
      "gpu.return"() : () -> ()
    }) {gpu.kernel, sym_name = "gemm", workgroup_attributions = 0 : i64} : () -> ()
  }) : () -> ()
  "func.func"() <{function_type = (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> f32, sym_name = "main"}> ({
  ^bb0(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>):
    %0 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
    %1 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
    %2 = "arith.constant"() <{value = 2.000000e+00 : f32}> : () -> f32
    %3 = "arith.constant"() <{value = 0 : index}> : () -> index
    %4 = "memref.dim"(%arg0, %3) : (memref<?x?xf32>, index) -> index
    %5 = "arith.constant"() <{value = 1 : index}> : () -> index
    %6 = "memref.dim"(%arg0, %5) : (memref<?x?xf32>, index) -> index
    %7 = "arith.constant"() <{value = 1 : index}> : () -> index
    %8 = "memref.dim"(%arg1, %7) : (memref<?x?xf32>, index) -> index
    %9 = "arith.constant"() <{value = 128 : index}> : () -> index
    %10 = "arith.ceildivui"(%4, %9) : (index, index) -> index
    %11 = "arith.constant"() <{value = 64 : index}> : () -> index
    %12 = "arith.ceildivsi"(%6, %11) : (index, index) -> index
    %13 = "arith.constant"() <{value = 256 : index}> : () -> index
    %14 = "arith.constant"() <{value = 262144 : i32}> : () -> i32
    %15 = "arith.constant"() <{value = 1 : index}> : () -> index
    "gpu.launch_func"(%12, %10, %15, %13, %15, %15, %14, %arg0, %arg1, %arg2) <{kernel = @gpu::@gemm, operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 3, 0>}> : (index, index, index, index, index, index, i32, memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
    "func.return"(%0) : (f32) -> ()
  }) : () -> ()
}) {gpu.container_module} : () -> ()
```
However, this raises another question I'd like to ask, one that I haven't fully thought through yet.


https://github.com/llvm/llvm-project/pull/117721


More information about the Mlir-commits mailing list