[Mlir-commits] [mlir] [mlir] Handle arith.const expr in dispatchIndexOpFoldResult func (PR #122432)
llvmlistbot at llvm.org
Tue Jan 14 07:02:45 PST 2025
rutkoor wrote:
> What I am wondering: Where is this example op coming from?
>
> ```
> %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
> output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
> ```
>
> I'd like to understand why the `output_shape` is not static.
Without the changes from this PR, this test case is invalid and fails with the following error:
```
within split at mlir/test/Dialect/Tensor/bubble-reshapes.mlir:21 offset :7:13: error: 'tensor.expand_shape' op expected dimension 3 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
^
within split at mlir/test/Dialect/Tensor/bubble-reshapes.mlir:21 offset :7:13: note: see current operation: %2 = "tensor.expand_shape"(%arg0, %arg1, %0, %1) <{reassociation = [[0], [1], [2], [3, 4]], static_output_shape = array<i64: -9223372036854775808, 2, 2, -9223372036854775808, -9223372036854775808>}> : (tensor<?x2x2x6xf32>, index, index, index) -> tensor<?x2x2x?x?xf32>
```
The `BubbleUpExpandThroughParallelCollapse` rewrite pattern creates a `tensor::ExpandShapeOp`, passing it the `resultType` and the other arguments. The relevant code from `BubbleUpExpandThroughParallelCollapse` is below.
```
// Swap reshape order.
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticSizes;
dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
auto expandResultType = expandOp.getResultType().clone(staticSizes);
auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
    loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
    newExpandSizes);
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
    expandOp, newExpand.getResult(), newCollapseReInds);
```
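The `output_shape` values are part of `newExpandSizes`, which is passed to the `dispatchIndexOpFoldResults` function. The `staticSizes` it fills in are then used to build `expandResultType`, so any size it classifies as dynamic becomes a `?` dimension in the result type.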
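To make the failure mode concrete, here is a rough standalone sketch of the dispatch step. It does not use MLIR: `Value`, `OpFoldResult`, `Dispatched`, and `dispatchSizes` are hypothetical stand-ins, and the `foldConstants` branch is only my reading of what the PR title describes, not the actual patch. `kDynamic` mirrors the `-9223372036854775808` sentinel visible in the error above.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <variant>
#include <vector>

// Stand-in for the dynamic-size sentinel; prints as -9223372036854775808.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for an SSA value. `constant` is set when the value
// is produced by a constant op (e.g. arith.constant).
struct Value {
  int id;
  std::optional<int64_t> constant;
};

// Hypothetical stand-in for OpFoldResult: a static integer or an SSA value.
using OpFoldResult = std::variant<int64_t, Value>;

struct Dispatched {
  std::vector<Value> dynamicSizes;
  std::vector<int64_t> staticSizes;
};

// Splits mixed sizes into dynamic values and static integers.
// foldConstants == false: every SSA value is recorded as dynamic.
// foldConstants == true:  SSA values with a known constant are recorded
//                         as static instead (assumed PR behaviour).
Dispatched dispatchSizes(const std::vector<OpFoldResult> &sizes,
                         bool foldConstants) {
  Dispatched out;
  for (const OpFoldResult &ofr : sizes) {
    // Plain integer: already static.
    if (const int64_t *attr = std::get_if<int64_t>(&ofr)) {
      out.staticSizes.push_back(*attr);
      continue;
    }
    const Value &v = std::get<Value>(ofr);
    // SSA value backed by a constant: use the constant as a static size.
    if (foldConstants && v.constant) {
      out.staticSizes.push_back(*v.constant);
      continue;
    }
    // Anything else stays dynamic and gets the sentinel in staticSizes.
    out.dynamicSizes.push_back(v);
    out.staticSizes.push_back(kDynamic);
  }
  return out;
}
```

With sizes modelled on `[%s0, %s1, %c2, %c3]`, calling `dispatchSizes(sizes, false)` yields four sentinel entries, which corresponds to a result type with `?` where the source type has a static dimension. Calling it with `true` keeps `2` and `3` static.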
https://github.com/llvm/llvm-project/pull/122432