[Mlir-commits] [mlir] [mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims (PR #89093)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 17 18:31:24 PDT 2024


================
@@ -63,39 +63,99 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
-    return failure();
-
+  // Record the rewriter context for constructing ops later.
   MLIRContext *ctx = rewriter.getContext();
+
+  // Record result type to get result dimensions for calculating suffix product
+  // later.
+  ShapedType resultType = expandShapeOp.getResultType();
+
+  // Traverse all reassociation groups to determine the appropriate index
+  // corresponding to each one of them post op folding.
   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
+    // Flag to indicate the presence of dynamic dimensions in current
+    // reassociation group.
+    bool hasDynamicDims = false;
     int64_t groupSize = groups.size();
 
-    // Construct the expression for the index value w.r.t to expand shape op
-    // source corresponding the indices wrt to expand shape op result.
+    // Capture expand_shape's resultant memref dimensions which are to be used
+    // in suffix product calculation later.
     SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
+    for (int64_t i = 0; i < groupSize; ++i) {
       sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+      if (resultType.isDynamicDim(groups[i]))
----------------
prathameshpml wrote:

We also use the static shape value directly ([here](https://github.com/meshtag/llvm-project/blob/224262a2d921da44c484146d8f5209a80a3306a5/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp#L113)) whenever it is available. Only when a dimension's size is not statically known do we fall back to a `memref.dim` op.
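
To make the static-versus-dynamic choice concrete, here is a rough standalone sketch in plain C++. It is not the PR's code: the real pattern builds MLIR ops and affine expressions, whereas this version works on plain integers. The names `kDynamicDim`, `DimQuery`, `resolveGroupSizes`, `suffixProduct`, and `linearize` are hypothetical, and the `DimQuery` callback only stands in for what a `memref.dim` op would return at runtime.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <vector>

// Hypothetical sentinel for a dimension whose size is unknown at compile time.
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Stand-in for a `memref.dim` query: returns the runtime size of `dim`.
using DimQuery = std::function<int64_t(int64_t dim)>;

// For each dimension in a reassociation group, use the static size when it is
// known and fall back to the runtime query only for dynamic dimensions.
std::vector<int64_t> resolveGroupSizes(const std::vector<int64_t> &staticShape,
                                       const std::vector<int64_t> &group,
                                       const DimQuery &queryDim) {
  std::vector<int64_t> sizes;
  sizes.reserve(group.size());
  for (int64_t dim : group) {
    int64_t size = staticShape[dim];
    sizes.push_back(size == kDynamicDim ? queryDim(dim) : size);
  }
  return sizes;
}

// Suffix product: strides[i] is the product of sizes[i+1..end]; the last
// stride is 1.
std::vector<int64_t> suffixProduct(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * sizes[i + 1];
  return strides;
}

// Collapse the per-dimension indices of one group into a single source index.
int64_t linearize(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &strides) {
  assert(indices.size() == strides.size() && "rank mismatch");
  int64_t linear = 0;
  for (std::size_t i = 0; i < indices.size(); ++i)
    linear += indices[i] * strides[i];
  return linear;
}
```

For a result shape of `2 x ? x 4` whose dynamic dimension turns out to be 3 at runtime, the resolved sizes are `{2, 3, 4}`, the strides are `{12, 4, 1}`, and the index `(1, 2, 3)` maps to source index 23. Only the middle dimension needs the runtime query.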

https://github.com/llvm/llvm-project/pull/89093

