[Mlir-commits] [mlir] Fix assertion when tiling linalg.generic (PR #114688)

Christopher Bate llvmlistbot at llvm.org
Wed Nov 13 21:52:22 PST 2024


christopherbate wrote:

Removing the assertion does not produce a correct result. The tiling goes through, but the resulting IR fails verification:


```
within split at test.mlir:1 offset :8:8: error: 'linalg.generic' op inferred input/output operand #0 has shape's dimension #0 to be greater than or equal to 8, but found 7
  %0 = linalg.generic {
       ^
within split at test.mlir:1 offset :8:8: note: see current operation: 
%7 = "linalg.generic"(%5, %6) <{indexing_maps = [affine_map<(d0) -> (-d0 + 7)>, affine_map<(d0) -> (d0)>], iterator_types = [#linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg4: i8, %arg5: i8):
  "linalg.yield"(%arg4) : (i8) -> ()
}) : (tensor<7xi8>, tensor<2xi8>) -> tensor<2xi8>
// -----// IR Dump After InterpreterPass Failed (transform-interpreter) //----- //
#map = affine_map<(d0) -> (-d0 + 7)>
#map1 = affine_map<(d0) -> (d0)>
"builtin.module"() ({
  "func.func"() <{function_type = (tensor<8xi8>, tensor<8xi8>) -> tensor<8xi8>, sym_name = "test"}> ({
  ^bb0(%arg1: tensor<8xi8>, %arg2: tensor<8xi8>):
    %2 = "arith.constant"() <{value = 0 : index}> : () -> index
    %3 = "arith.constant"() <{value = 8 : index}> : () -> index
    %4 = "arith.constant"() <{value = 2 : index}> : () -> index
    %5 = "scf.for"(%2, %3, %4, %arg2) ({
    ^bb0(%arg3: index, %arg4: tensor<8xi8>):
      %6 = "affine.apply"(%arg3) <{map = #map}> : (index) -> index
      %7 = "tensor.extract_slice"(%arg1, %6) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 7>, static_strides = array<i64: 1>}> : (tensor<8xi8>, index) -> tensor<7xi8>
      %8 = "tensor.extract_slice"(%arg4, %arg3) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 2>, static_strides = array<i64: 1>}> : (tensor<8xi8>, index) -> tensor<2xi8>
      %9 = "linalg.generic"(%7, %8) <{indexing_maps = [#map, #map1], iterator_types = [#linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
      ^bb0(%arg5: i8, %arg6: i8):
        "linalg.yield"(%arg5) : (i8) -> ()
      }) : (tensor<7xi8>, tensor<2xi8>) -> tensor<2xi8>
      %10 = "tensor.insert_slice"(%9, %arg4, %arg3) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 2>, static_strides = array<i64: 1>}> : (tensor<2xi8>, tensor<8xi8>, index) -> tensor<8xi8>
      "scf.yield"(%10) : (tensor<8xi8>) -> ()
    }) : (index, index, index, tensor<8xi8>) -> tensor<8xi8>
    "func.return"(%5) : (tensor<8xi8>) -> ()
  }) : () -> ()
  "builtin.module"() ({
    "transform.named_sequence"() <{arg_attrs = [{transform.readonly}], function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
    ^bb0(%arg0: !transform.any_op):
      %0 = "transform.structured.match"(%arg0) <{ops = ["linalg.generic"]}> : (!transform.any_op) -> !transform.any_op
      %1:2 = "transform.structured.tile_using_for"(%0) <{scalable_sizes = array<i1: false>, static_sizes = array<i64: 2>}> : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      "transform.yield"() : () -> ()
    }) : () -> ()
  }) {transform.with_named_sequence} : () -> ()
}) : () -> ()
```

https://github.com/llvm/llvm-project/pull/114688


More information about the Mlir-commits mailing list