[Mlir-commits] [mlir] [mlir][Tensor] Use output_shape for DimOp->ExpandShapeOp folding (PR #118203)
Quinn Dawkins
llvmlistbot at llvm.org
Sun Dec 1 06:48:45 PST 2024
================
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
if (!dim.has_value())
return failure();
- // Skip static dims. These are folded to constant ops.
- RankedTensorType resultType = expandShapeOp.getResultType();
- if (!resultType.isDynamicDim(*dim))
- return failure();
-
- // Find reassociation group that contains this result dimension.
- int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
- // `dim` is the only dynamic dimension in `group`. (Otherwise, the
- // ExpandShapeOp would be ambiguous.)
- int64_t product = 1;
- ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
- for (int64_t d : grp) {
- if (d != dim) {
- assert(!resultType.isDynamicDim(d) && "expected static dim");
- product *= resultType.getDimSize(d);
- }
- }
-
- // result dim size = src dim size / (product(other dims in reassoc group))
- Value srcDimSz =
- rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
- AffineExpr expr;
- bindSymbols(dimOp.getContext(), expr);
- rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
- dimOp, expr.floorDiv(product), srcDimSz);
+ SmallVector<OpFoldResult> outputShape =
+ getMixedValues(expandShapeOp.getStaticOutputShape(),
+ expandShapeOp.getOutputShape(), rewriter);
+ OpFoldResult outputDim = outputShape[dim.value()];
+ rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
+ rewriter, dimOp.getLoc(), outputDim));
----------------
qedawkins wrote:
Can we just roll this into the `DimOp` folder instead of keeping it as a separate rewrite pattern? The folder already handles the static case and has similar handling for `extract_slice`. This would have the added benefit of improving `createOrFold<DimOp>`, which is used quite frequently.
https://github.com/llvm/llvm-project/pull/118203
More information about the Mlir-commits mailing list