[Mlir-commits] [mlir] [mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims (PR #89093)
llvmlistbot at llvm.org
Wed Apr 17 18:31:24 PDT 2024
================
@@ -63,39 +63,99 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- // The below implementation uses computeSuffixProduct method, which only
- // allows int64_t values (i.e., static shape). Bail out if it has dynamic
- // shapes.
- if (!expandShapeOp.getResultType().hasStaticShape())
- return failure();
-
+ // Record the rewriter context for constructing ops later.
MLIRContext *ctx = rewriter.getContext();
+
+ // Record result type to get result dimensions for calulating suffix product
+ // later.
+ ShapedType resultType = expandShapeOp.getResultType();
+
+ // Traverse all reassociation groups to determine the appropriate indice
+ // corresponding to each one of them post op folding.
for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
assert(!groups.empty() && "association indices groups cannot be empty");
+ // Flag to indicate the presence of dynamic dimensions in current
+ // reassociation group.
+ bool hasDynamicDims = false;
int64_t groupSize = groups.size();
- // Construct the expression for the index value w.r.t to expand shape op
- // source corresponding the indices wrt to expand shape op result.
+ // Capture expand_shape's resultant memref dimensions which are to be used
+ // in suffix product calculation later.
SmallVector<int64_t> sizes(groupSize);
- for (int64_t i = 0; i < groupSize; ++i)
+ for (int64_t i = 0; i < groupSize; ++i) {
sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
- SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+ if (resultType.isDynamicDim(groups[i]))
----------------
prathameshpml wrote:
We also use the static shape value directly ([here](https://github.com/meshtag/llvm-project/blob/224262a2d921da44c484146d8f5209a80a3306a5/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp#L113)) whenever it is available. Only when a dimension is dynamic do we fall back to a `memref.dim` op.
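
To illustrate the static-first, dynamic-fallback idea outside of MLIR, here is a minimal standalone sketch. It is not the code from this PR: `kDynamicSize`, `resolveDimSize`, `linearizeGroup`, and the `queryRuntimeDim` callback are hypothetical names, and the callback merely stands in for emitting a `memref.dim` op. The real pattern builds IR values and affine expressions rather than computing integers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical sentinel for a dimension whose size is unknown statically.
constexpr int64_t kDynamicSize = -1;

// Use the static size when it is known; otherwise ask the runtime query
// (an assumed stand-in for emitting `memref.dim`).
int64_t resolveDimSize(const std::vector<int64_t> &staticShape, int64_t dim,
                       const std::function<int64_t(int64_t)> &queryRuntimeDim) {
  int64_t size = staticShape[dim];
  return size == kDynamicSize ? queryRuntimeDim(dim) : size;
}

// Suffix product: result[i] = sizes[i + 1] * ... * sizes[n - 1],
// and result[n - 1] = 1.
std::vector<int64_t> suffixProduct(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * sizes[i + 1];
  return strides;
}

// Collapse the result indices of one reassociation group into the single
// source index they map to.
int64_t linearizeGroup(const std::vector<int64_t> &staticShape,
                       const std::vector<int64_t> &group,
                       const std::vector<int64_t> &indices,
                       const std::function<int64_t(int64_t)> &queryRuntimeDim) {
  assert(!group.empty() && group.size() == indices.size());
  std::vector<int64_t> sizes;
  for (int64_t dim : group)
    sizes.push_back(resolveDimSize(staticShape, dim, queryRuntimeDim));
  std::vector<int64_t> strides = suffixProduct(sizes);
  int64_t linear = 0;
  for (std::size_t i = 0; i < indices.size(); ++i)
    linear += indices[i] * strides[i];
  return linear;
}
```

In this sketch the runtime query is invoked only for dimensions marked dynamic, so a fully static group never pays for it, which mirrors the behaviour described above.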
https://github.com/llvm/llvm-project/pull/89093