[Mlir-commits] [mlir] [mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims (PR #89093)

Prathamesh Tagore llvmlistbot at llvm.org
Tue Apr 23 05:05:57 PDT 2024


================
@@ -63,39 +63,99 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
-    return failure();
-
+  // Record the rewriter context for constructing ops later.
   MLIRContext *ctx = rewriter.getContext();
+
+  // Record result type to get result dimensions for calulating suffix product
+  // later.
+  ShapedType resultType = expandShapeOp.getResultType();
+
+  // Traverse all reassociation groups to determine the appropriate indice
+  // corresponding to each one of them post op folding.
   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
+    // Flag to indicate the presence of dynamic dimensions in current
+    // reassociation group.
+    bool hasDynamicDims = false;
     int64_t groupSize = groups.size();
 
-    // Construct the expression for the index value w.r.t to expand shape op
-    // source corresponding the indices wrt to expand shape op result.
+    // Capture expand_shape's resultant memref dimensions which are to be used
+    // in suffix product calculation later.
     SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
+    for (int64_t i = 0; i < groupSize; ++i) {
       sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+      if (resultType.isDynamicDim(groups[i]))
+        hasDynamicDims = true;
+    }
+
+    // Declare resultant affine apply result and affine expression variables to
+    // represent dimensions in the newly constructed affine map.
+    OpFoldResult ofr;
     SmallVector<AffineExpr> dims(groupSize);
     bindDimsList(ctx, MutableArrayRef{dims});
-    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
 
-    /// Apply permutation and create AffineApplyOp.
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calulating relevant index post op folding.
     SmallVector<OpFoldResult> dynamicIndices(groupSize);
     for (int64_t i = 0; i < groupSize; i++)
       dynamicIndices[i] = indices[groups[i]];
 
-    // Creating maximally folded and composd affine.apply composes better with
-    // other transformations without interleaving canonicalization passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/0, srcIndexExpr),
-        dynamicIndices);
+    if (hasDynamicDims) {
+      // Record relevant dimension sizes for each result dimension in the
+      // reassociation group.
+      SmallVector<Value> sizesVal(groupSize);
+      for (int64_t i = 0; i < groupSize; ++i) {
+        if (sizes[i] <= 0)
+          sizesVal[i] = rewriter.create<memref::DimOp>(
+              loc, expandShapeOp.getResult(), groups[i]);
+        else
+          sizesVal[i] = rewriter.create<arith::ConstantIndexOp>(loc, sizes[i]);
+      }
+
+      // Calculate suffix product of previously obtained dimension sizes.
+      auto suffixProduct = computeSuffixProduct(loc, rewriter, sizesVal);
+
+      // Create affine expression variables for symbols in the newly constructed
+      // affine map.
+      SmallVector<AffineExpr> symbols(groupSize);
+      bindSymbolsList(ctx, MutableArrayRef{symbols});
----------------
meshtag wrote:

Without the explicit `MutableArrayRef{...}` wrapper, I got this compile error:
```
error: no matching function for call to ‘bindSymbolsList(mlir::MLIRContext*&, llvm::SmallVector<mlir::AffineExpr>&)’
  122 |       bindSymbolsList(ctx, symbols);
```

https://github.com/llvm/llvm-project/pull/89093


More information about the Mlir-commits mailing list