[Mlir-commits] [mlir] [mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims (PR #89093)

Prathamesh Tagore llvmlistbot at llvm.org
Wed Apr 17 08:58:59 PDT 2024


https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/89093

`fold-memref-alias-ops` bails out in the presence of dynamic shapes in the `memref.expand_shape` op. Handle this case.
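
For illustration, here is a sketch distilled from the new test cases below
(SSA names are illustrative). A load through a dynamically shaped
expand_shape such as

  %expand = memref.expand_shape %arg0 [[0, 1], [2, 3]]
      : memref<16x?xf32, strided<[16, 1]>>
        into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
  %0 = memref.load %expand[%c0, %arg1, %arg2, %c0]
      : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>

can now be folded into a load on the expand_shape source, with the source
index recovered by an affine.apply over the runtime strides:

  %c1 = arith.constant 1 : index
  %idx = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%arg2, %c1]
  %0 = memref.load %arg0[%arg1, %idx] : memref<16x?xf32, strided<[16, 1]>>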

>From 224262a2d921da44c484146d8f5209a80a3306a5 Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathamesh+1 at polymagelabs.com>
Date: Wed, 17 Apr 2024 21:11:29 +0530
Subject: [PATCH] [mlir][fold-memref-alias-ops] Add support for folding
 memref.expand_shape involving dynamic dims

---
 .../mlir/Dialect/Utils/IndexingUtils.h        | 25 +++++
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  | 98 +++++++++++++++----
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      | 27 +++++
 .../Dialect/MemRef/fold-memref-alias-ops.mlir | 61 +++++++++++-
 4 files changed, 187 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 9892253df2bff1..5f0ea7ee99a85c 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -48,6 +48,31 @@ inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
   return computeSuffixProduct(sizes);
 }
 
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value>
+///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to provide valid values, which are
+/// expected to be index-typed constants or results of dimension extraction
+/// ops (e.g. the memref.dim op).
+///
+/// `sizes` elements are expected to be non-negative.
+///
+/// Return an empty vector if `sizes` is empty.
+SmallVector<Value> computeSuffixProduct(Location loc, OpBuilder &builder,
+                                        ArrayRef<Value> sizes);
+inline SmallVector<Value> computeStrides(Location loc, OpBuilder &builder,
+                                         ArrayRef<Value> sizes) {
+  return computeSuffixProduct(loc, builder, sizes);
+}
+
 /// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
 ///
 /// Return an empty vector if `v1` and `v2` are empty.
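
For intuition: with `sizes = [s0, s1, s2]` the returned suffix product is
`[s1 * s2, s2, 1]`. For concrete sizes (2, 3, 4) this yields strides
(12, 4, 1), i.e. a step of one along dimension 0 skips 12 linear elements.
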
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index aa44455ada7f9a..f5b4844c7fc1a6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -63,39 +63,99 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
-    return failure();
-
+  // Record the rewriter context for constructing ops later.
   MLIRContext *ctx = rewriter.getContext();
+
+  // Record the result type to get result dimensions for calculating the
+  // suffix product later.
+  ShapedType resultType = expandShapeOp.getResultType();
+
+  // Traverse all reassociation groups to determine the appropriate index
+  // corresponding to each one of them once the op is folded.
   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
+    // Flag to indicate the presence of dynamic dimensions in the current
+    // reassociation group.
+    bool hasDynamicDims = false;
     int64_t groupSize = groups.size();
 
-    // Construct the expression for the index value w.r.t to expand shape op
-    // source corresponding the indices wrt to expand shape op result.
+    // Capture the dimensions of expand_shape's result memref, which are used
+    // in the suffix product calculation later.
     SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
+    for (int64_t i = 0; i < groupSize; ++i) {
       sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+      if (resultType.isDynamicDim(groups[i]))
+        hasDynamicDims = true;
+    }
+
+    // Declare the affine.apply result and the affine expression variables
+    // representing dimensions in the newly constructed affine map.
+    OpFoldResult ofr;
     SmallVector<AffineExpr> dims(groupSize);
     bindDimsList(ctx, MutableArrayRef{dims});
-    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
 
-    /// Apply permutation and create AffineApplyOp.
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calculating the source index after folding.
     SmallVector<OpFoldResult> dynamicIndices(groupSize);
     for (int64_t i = 0; i < groupSize; i++)
       dynamicIndices[i] = indices[groups[i]];
 
-    // Creating maximally folded and composd affine.apply composes better with
-    // other transformations without interleaving canonicalization passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/0, srcIndexExpr),
-        dynamicIndices);
+    if (hasDynamicDims) {
+      // Record relevant dimension sizes for each result dimension in the
+      // reassociation group.
+      SmallVector<Value> sizesVal(groupSize);
+      for (int64_t i = 0; i < groupSize; ++i) {
+        if (resultType.isDynamicDim(groups[i]))
+          sizesVal[i] = rewriter.create<memref::DimOp>(
+              loc, expandShapeOp.getResult(), groups[i]);
+        else
+          sizesVal[i] = rewriter.create<arith::ConstantIndexOp>(loc, sizes[i]);
+      }
+
+      // Calculate suffix product of previously obtained dimension sizes.
+      auto suffixProduct = computeSuffixProduct(loc, rewriter, sizesVal);
+
+      // Create affine expression variables for symbols in the newly constructed
+      // affine map.
+      SmallVector<AffineExpr> symbols(groupSize);
+      bindSymbolsList(ctx, MutableArrayRef{symbols});
+
+      // Linearize the bound dimensions and symbols to construct the affine
+      // expression for this index.
+      AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+      // Supply the suffix product results (bound to the dims) followed by the
+      // load op indices (bound to the symbols) as operands to the map.
+      SmallVector<OpFoldResult> mapOperands;
+      llvm::append_range(mapOperands, suffixProduct);
+      llvm::append_range(mapOperands, dynamicIndices);
+
+      // Creating maximally folded and composed affine.apply composes better
+      // with other transformations without interleaving canonicalization
+      // passes.
+      ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc,
+          AffineMap::get(/*numDims=*/groupSize,
+                         /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+          mapOperands);
+    } else {
+      // Calculate the suffix product of the static dimension sizes and
+      // linearize it with the affine dimension variables defined previously.
+      SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+      AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+
+      // Creating maximally folded and composed affine.apply composes better
+      // with other transformations without interleaving canonicalization
+      // passes.
+      ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc,
+          AffineMap::get(/*numDims=*/groupSize,
+                         /*numSymbols=*/0, /*expression=*/srcIndexExpr),
+          dynamicIndices);
+    }
+    // Record the source index corresponding to this reassociation group in
+    // the folded op.
     sourceIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
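
Note on the dynamic path above: the source index for a reassociation group is
the dot product of the group's load indices with the runtime suffix product,
i.e. src_index = sum_i index_i * stride_i. Before composition, the
materialized apply looks roughly like the following sketch (illustrative SSA
names; the suffix products bind to the dims, the load indices to the
symbols):

  %idx = affine.apply
      affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
      (%stride0, %stride1)[%i0, %i1]

makeComposedFoldedAffineApply then folds this maximally, which is why the
tests below check for simplified maps such as ()[s0, s1] -> (s1 * s0).
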
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 4c960659d80cb7..7b9a77bfc8a0a8 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
@@ -29,6 +31,19 @@ SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
   return strides;
 }
 
+static SmallVector<Value> computeSuffixProductImpl(Location loc,
+                                                   OpBuilder &builder,
+                                                   ArrayRef<Value> sizes,
+                                                   Value unit) {
+  if (sizes.empty())
+    return {};
+  SmallVector<Value> strides(sizes.size(), unit);
+  for (int64_t r = strides.size() - 2; r >= 0; --r)
+    strides[r] =
+        builder.create<arith::MulIOp>(loc, strides[r + 1], sizes[r + 1]);
+  return strides;
+}
+
 template <typename ExprType>
 SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                 ArrayRef<ExprType> v2) {
@@ -197,6 +212,18 @@ SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
   return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
 }
 
+//===----------------------------------------------------------------------===//
+// Utils that operate on values unknown at compile time.
+//===----------------------------------------------------------------------===//
+
+SmallVector<Value> mlir::computeSuffixProduct(Location loc, OpBuilder &builder,
+                                              ArrayRef<Value> sizes) {
+  if (sizes.empty())
+    return {};
+  Value unit = builder.create<arith::ConstantIndexOp>(loc, 1);
+  return ::computeSuffixProductImpl(loc, builder, sizes, unit);
+}
+
 //===----------------------------------------------------------------------===//
 // Permutation utils.
 //===----------------------------------------------------------------------===//
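
As a sketch of the IR the new overload builds (illustrative SSA names), for
sizes (%s0, %s1, %s2) it emits

  %c1 = arith.constant 1 : index
  %p1 = arith.muli %c1, %s2 : index   // stride of dim 1: s2
  %p0 = arith.muli %p1, %s1 : index   // stride of dim 0: s1 * s2

and returns [%p0, %p1, %c1].
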
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 5b853a6cc5a37a..99ac6115558aeb 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -468,16 +468,67 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 
 // -----
 
-// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
-func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> f32
+func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
-// CHECK: return %[[LOAD]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return %[[VAL1]] : f32
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) {
+  %c0 = arith.constant 0 : index
+  %c1f32 = arith.constant 1.0 : f32
+  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  return
+}
+// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index) {
+  %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+
+  affine.for %arg6 = 0 to %dim step 64 {
+    affine.for %arg7 = 0 to 16 step 16 {
+      %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+      affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+    }
+  }
+  return
+}
+// CHECK-NEXT:   memref.subview
+// CHECK-NEXT:   %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT:   affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT:   %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT:   %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
+// CHECK-NEXT:   %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
+// CHECK-NEXT:   %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT:   affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
 
 // -----
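
A quick sanity check on the last test: the multiplier 3 in the
(d0) -> (d0 * 3) map falls out of linearizing the second reassociation group.
Its result dims (1, 8, 2) have suffix product (16, 2, 1), and the load
indices (0, %arg7, %arg7) give 0 * 16 + %arg7 * 2 + %arg7 * 1 = %arg7 * 3.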
 


