[Mlir-commits] [mlir] [mlir][Tensor] Use output_shape for DimOp->ExpandShapeOp folding (PR #118203)

Quinn Dawkins llvmlistbot at llvm.org
Sun Dec 1 06:48:45 PST 2024


================
@@ -1971,32 +1971,12 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
     if (!dim.has_value())
       return failure();
 
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = expandShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Find reassociation group that contains this result dimension.
-    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
-    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
-    // ExpandShapeOp would be ambiguous.)
-    int64_t product = 1;
-    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
-    for (int64_t d : grp) {
-      if (d != dim) {
-        assert(!resultType.isDynamicDim(d) && "expected static dim");
-        product *= resultType.getDimSize(d);
-      }
-    }
-
-    // result dim size = src dim size / (product(other dims in reassoc group))
-    Value srcDimSz =
-        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
-    AffineExpr expr;
-    bindSymbols(dimOp.getContext(), expr);
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
-        dimOp, expr.floorDiv(product), srcDimSz);
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+    OpFoldResult outputDim = outputShape[dim.value()];
+    rewriter.replaceOp(dimOp, getValueOrCreateConstantIndexOp(
+                                  rewriter, dimOp.getLoc(), outputDim));
----------------
qedawkins wrote:

Can we just roll this into the folder for DimOp? The folder already handles the static case and has similar handling for `extract_slice`. This would have the added benefit of improving `createOrFold<DimOp>`, which is used quite frequently.

https://github.com/llvm/llvm-project/pull/118203


More information about the Mlir-commits mailing list