[Mlir-commits] [mlir] [mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims (PR #89093)

Prathamesh Tagore llvmlistbot at llvm.org
Mon May 6 21:26:46 PDT 2024


================
@@ -63,39 +64,85 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
+  // Record the rewriter context for constructing ops later.
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
+  // This is done for the purpose of inferring the output shape via
+  // `inferExpandShapeOutputShape`, which will in turn be used for the suffix
+  // product calculation later.
+  SmallVector<OpFoldResult> srcShape;
+  MemRefType srcType = expandShapeOp.getSrcType();
+
+  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+    if (srcType.isDynamicDim(i)) {
+      srcShape.push_back(
+          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
+              .getResult());
+    } else {
+      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
+    }
+  }
+
+  auto outputShape = inferExpandShapeOutputShape(
----------------
meshtag wrote:

The `for` loop above populates `srcShape`, which the `inferExpandShapeOutputShape` call requires. Are you referring to a different `for` loop?

https://github.com/llvm/llvm-project/pull/89093

