[Mlir-commits] [mlir] [mlir][tensor] Fix tensor.concat reifyResultShapes for static result dims (PR #75558)
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 14 23:25:00 PST 2023
================
@@ -637,17 +637,24 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
     }
   }
-  // Take the sum of the input sizes along the concatenated dim.
-  AffineExpr sum = builder.getAffineDimExpr(0);
-  SmallVector<OpFoldResult> sizes = {
-      builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
-  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
-    sum = sum + builder.getAffineDimExpr(idx + 1);
-    sizes.push_back(
-        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+  if (getType().isDynamicDim(dim)) {
+    // Take the sum of the input sizes along the concatenated dim.
+    AffineExpr sum = builder.getAffineDimExpr(0);
+    SmallVector<OpFoldResult> sizes = {
+        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
+    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+      sum = sum + builder.getAffineDimExpr(idx + 1);
+      sizes.push_back(
+          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+    }
+    reifiedReturnShapes[0][dim] =
+        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
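For readers following the hunk: when the result dimension is dynamic, the reified size along `dim` is the sum of every input's size along that dimension, and `makeComposedFoldedAffineApply` returns an `OpFoldResult` that can fold to a constant when all of those sizes are known. The standalone C++ model below sketches that folding behavior under stated assumptions: `FoldedSize` and `foldedConcatSize` are hypothetical names, a `std::variant` stands in for `OpFoldResult`, and strings stand in for SSA values. It is an illustration, not the MLIR implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for mlir::OpFoldResult: either a known constant
// (like an IntegerAttr) or a named SSA value (like an mlir::Value).
using FoldedSize = std::variant<int64_t, std::string>;

// Sums the sizes of all inputs along the concatenated dimension, folding
// constant operands (similar in spirit to an affine.apply fold). If every
// size is a constant, the result is a constant and no value is created;
// otherwise a symbolic string stands in for a new affine.apply result.
FoldedSize foldedConcatSize(const std::vector<FoldedSize> &inputSizes) {
  int64_t constantPart = 0;
  std::string symbolicPart;
  for (const FoldedSize &size : inputSizes) {
    if (const auto *c = std::get_if<int64_t>(&size)) {
      constantPart += *c;
    } else {
      if (!symbolicPart.empty())
        symbolicPart += " + ";
      symbolicPart += std::get<std::string>(size);
    }
  }
  if (symbolicPart.empty())
    return constantPart;  // Fully folded: the result is an "attribute".
  if (constantPart != 0)
    symbolicPart += " + " + std::to_string(constantPart);
  return symbolicPart;    // Symbolic: the result is a "value".
}
```

For example, input sizes 3 and 4 fold to the constant 7 and create nothing new, while sizes 3 and `%d1` stay symbolic as `%d1 + 3`. The first case is the interesting one for a dynamic result dimension: the folded result is a constant even though the dimension is dynamic.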
----------------
matthias-springer wrote:
This must be wrapped in `getValueOrCreateConstantIndexOp` (see https://github.com/llvm/llvm-project/blob/main/mlir/lib/Interfaces/InferTypeOpInterface.cpp#L54).
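As far as I can tell, the linked check in InferTypeOpInterface.cpp requires each reified size to be an attribute for a static dimension and a `Value` for a dynamic one. Because `makeComposedFoldedAffineApply` may fold to an attribute, the dynamic case needs `getValueOrCreateConstantIndexOp` to materialize an index constant. The sketch below models that invariant and the wrapping step in plain C++; `Attr`, `Value`, `Builder`, `getValueOrCreateConstant`, and `isValidReification` are hypothetical stand-ins, not MLIR types or functions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins: an Attr is a folded constant, a Value is an SSA
// value produced by some op.
struct Attr { int64_t value; };
struct Value { std::string name; };
using OpFoldResult = std::variant<Attr, Value>;

// Records every constant that had to be materialized, standing in for the
// arith.constant ops an OpBuilder would insert.
struct Builder {
  std::vector<std::string> createdOps;
  int counter = 0;
};

// Analogue of getValueOrCreateConstantIndexOp: passes Values through
// unchanged and materializes a constant op for attributes.
Value getValueOrCreateConstant(Builder &b, const OpFoldResult &ofr) {
  if (const auto *v = std::get_if<Value>(&ofr))
    return *v;
  int64_t c = std::get<Attr>(ofr).value;
  std::string name = "%c" + std::to_string(c) + "_" + std::to_string(b.counter++);
  b.createdOps.push_back(name + " = arith.constant " + std::to_string(c) +
                         " : index");
  return Value{name};
}

// The invariant modeled here: static dims are reified as attributes,
// dynamic dims as Values, with exactly one entry per dimension.
bool isValidReification(const std::vector<bool> &isDynamicDim,
                        const std::vector<OpFoldResult> &reified) {
  if (isDynamicDim.size() != reified.size())
    return false;
  for (std::size_t i = 0; i < reified.size(); ++i)
    if (isDynamicDim[i] != std::holds_alternative<Value>(reified[i]))
      return false;
  return true;
}
```

In real MLIR code the fix would presumably look like `getValueOrCreateConstantIndexOp(builder, getLoc(), affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes))`, but that line is an assumption about the fix, not the code that was merged.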
https://github.com/llvm/llvm-project/pull/75558