[Mlir-commits] [mlir] Fix assertion when tiling linalg.generic (PR #114688)
Christopher Bate
llvmlistbot at llvm.org
Wed Nov 13 21:52:22 PST 2024
christopherbate wrote:
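Simply removing the assertion does not produce correct IR. After tiling, the input slice for the reversed indexing map `(d0) -> (-d0 + 7)` is extracted with size 7 instead of the tile size 2, so the tiled `linalg.generic` fails verification: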
```
within split at test.mlir:1 offset :8:8: error: 'linalg.generic' op inferred input/output operand #0 has shape's dimension #0 to be greater than or equal to 8, but found 7
%0 = linalg.generic {
^
within split at test.mlir:1 offset :8:8: note: see current operation:
%7 = "linalg.generic"(%5, %6) <{indexing_maps = [affine_map<(d0) -> (-d0 + 7)>, affine_map<(d0) -> (d0)>], iterator_types = [#linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg4: i8, %arg5: i8):
"linalg.yield"(%arg4) : (i8) -> ()
}) : (tensor<7xi8>, tensor<2xi8>) -> tensor<2xi8>
// -----// IR Dump After InterpreterPass Failed (transform-interpreter) //----- //
#map = affine_map<(d0) -> (-d0 + 7)>
#map1 = affine_map<(d0) -> (d0)>
"builtin.module"() ({
"func.func"() <{function_type = (tensor<8xi8>, tensor<8xi8>) -> tensor<8xi8>, sym_name = "test"}> ({
^bb0(%arg1: tensor<8xi8>, %arg2: tensor<8xi8>):
%2 = "arith.constant"() <{value = 0 : index}> : () -> index
%3 = "arith.constant"() <{value = 8 : index}> : () -> index
%4 = "arith.constant"() <{value = 2 : index}> : () -> index
%5 = "scf.for"(%2, %3, %4, %arg2) ({
^bb0(%arg3: index, %arg4: tensor<8xi8>):
%6 = "affine.apply"(%arg3) <{map = #map}> : (index) -> index
%7 = "tensor.extract_slice"(%arg1, %6) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 7>, static_strides = array<i64: 1>}> : (tensor<8xi8>, index) -> tensor<7xi8>
%8 = "tensor.extract_slice"(%arg4, %arg3) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 2>, static_strides = array<i64: 1>}> : (tensor<8xi8>, index) -> tensor<2xi8>
%9 = "linalg.generic"(%7, %8) <{indexing_maps = [#map, #map1], iterator_types = [#linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg5: i8, %arg6: i8):
"linalg.yield"(%arg5) : (i8) -> ()
}) : (tensor<7xi8>, tensor<2xi8>) -> tensor<2xi8>
%10 = "tensor.insert_slice"(%9, %arg4, %arg3) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 2>, static_strides = array<i64: 1>}> : (tensor<2xi8>, tensor<8xi8>, index) -> tensor<8xi8>
"scf.yield"(%10) : (tensor<8xi8>) -> ()
}) : (index, index, index, tensor<8xi8>) -> tensor<8xi8>
"func.return"(%5) : (tensor<8xi8>) -> ()
}) : () -> ()
"builtin.module"() ({
"transform.named_sequence"() <{arg_attrs = [{transform.readonly}], function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
^bb0(%arg0: !transform.any_op):
%0 = "transform.structured.match"(%arg0) <{ops = ["linalg.generic"]}> : (!transform.any_op) -> !transform.any_op
%1:2 = "transform.structured.tile_using_for"(%0) <{scalable_sizes = array<i1: false>, static_sizes = array<i64: 2>}> : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
}) : () -> ()
```
https://github.com/llvm/llvm-project/pull/114688
More information about the Mlir-commits
mailing list