[Mlir-commits] [mlir] [mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims (PR #89093)

Prathamesh Tagore llvmlistbot at llvm.org
Fri Apr 19 00:03:48 PDT 2024


https://github.com/meshtag updated https://github.com/llvm/llvm-project/pull/89093

>From 063406cb50beeb3f42a82b1cbf9e6a2b8f8a5411 Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathamesh+1 at polymagelabs.com>
Date: Fri, 19 Apr 2024 12:28:47 +0530
Subject: [PATCH] [mlir][fold-memref-alias-ops] Add support for folding
 memref.expand_shape involving dynamic dims

The fold-memref-alias-ops pass bails out in the presence of dynamic shapes, which leads to unwanted propagation of alias types during other transformations. This can percolate further down and lead to errors that should not have occurred in the first place.
---
 .../mlir/Dialect/Utils/IndexingUtils.h        | 25 +++++
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  | 98 +++++++++++++++----
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      | 27 +++++
 .../Dialect/MemRef/fold-memref-alias-ops.mlir | 61 +++++++++++-
 4 files changed, 187 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 9892253df2bff1..5f0ea7ee99a85c 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -48,6 +48,31 @@ inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
   return computeSuffixProduct(sizes);
 }
 
+/// Given sizes as SSA values, emit ops computing their suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value>
+///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to provide valid values which are expected
+/// to be constants with index type or results of dimension extraction ops
+/// (e.g. memref.dim).
+///
+/// No check is performed on `sizes`; products are emitted as arith::MulIOp ops.
+///
+/// Return an empty vector, without creating any op, if `sizes` is empty.
+SmallVector<Value> computeSuffixProduct(Location loc, OpBuilder &builder,
+                                        ArrayRef<Value> sizes);
+inline SmallVector<Value> computeStrides(Location loc, OpBuilder &builder,
+                                         ArrayRef<Value> sizes) {
+  return computeSuffixProduct(loc, builder, sizes);
+}
+
 /// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
 ///
 /// Return an empty vector if `v1` and `v2` are empty.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index aa44455ada7f9a..f5b4844c7fc1a6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -63,39 +63,99 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
-    return failure();
-
+  // Record the rewriter context for constructing ops later.
   MLIRContext *ctx = rewriter.getContext();
+
+  // Record result type to get result dimensions for calculating suffix product
+  // later.
+  ShapedType resultType = expandShapeOp.getResultType();
+
+  // Traverse all reassociation groups to determine the appropriate index
+  // corresponding to each one of them post op folding.
   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
+    // Flag to indicate the presence of dynamic dimensions in current
+    // reassociation group.
+    bool hasDynamicDims = false;
     int64_t groupSize = groups.size();
 
-    // Construct the expression for the index value w.r.t to expand shape op
-    // source corresponding the indices wrt to expand shape op result.
+    // Capture expand_shape's resultant memref dimensions which are to be used
+    // in suffix product calculation later.
     SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
+    for (int64_t i = 0; i < groupSize; ++i) {
       sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+      if (resultType.isDynamicDim(groups[i]))
+        hasDynamicDims = true;
+    }
+
+    // Declare resultant affine apply result and affine expression variables to
+    // represent dimensions in the newly constructed affine map.
+    OpFoldResult ofr;
     SmallVector<AffineExpr> dims(groupSize);
     bindDimsList(ctx, MutableArrayRef{dims});
-    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
 
-    /// Apply permutation and create AffineApplyOp.
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calculating relevant index post op folding.
     SmallVector<OpFoldResult> dynamicIndices(groupSize);
     for (int64_t i = 0; i < groupSize; i++)
       dynamicIndices[i] = indices[groups[i]];
 
-    // Creating maximally folded and composd affine.apply composes better with
-    // other transformations without interleaving canonicalization passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/0, srcIndexExpr),
-        dynamicIndices);
+    if (hasDynamicDims) {
+      // Record relevant dimension sizes for each result dimension in the
+      // reassociation group.
+      SmallVector<Value> sizesVal(groupSize);
+      for (int64_t i = 0; i < groupSize; ++i) {
+        if (sizes[i] <= 0)
+          sizesVal[i] = rewriter.create<memref::DimOp>(
+              loc, expandShapeOp.getResult(), groups[i]);
+        else
+          sizesVal[i] = rewriter.create<arith::ConstantIndexOp>(loc, sizes[i]);
+      }
+
+      // Calculate suffix product of previously obtained dimension sizes.
+      auto suffixProduct = computeSuffixProduct(loc, rewriter, sizesVal);
+
+      // Create affine expression variables for symbols in the newly constructed
+      // affine map.
+      SmallVector<AffineExpr> symbols(groupSize);
+      bindSymbolsList(ctx, MutableArrayRef{symbols});
+
+      // Linearize the bound dimensions and symbols to construct the resultant
+      // affine expression for this index.
+      AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+      // Supply suffix product results followed by load op indices as operands
+      // to the map.
+      SmallVector<OpFoldResult> mapOperands;
+      llvm::append_range(mapOperands, suffixProduct);
+      llvm::append_range(mapOperands, dynamicIndices);
+
+      // Creating maximally folded and composed affine.apply composes better
+      // with other transformations without interleaving canonicalization
+      // passes.
+      ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc,
+          AffineMap::get(/*numDims=*/groupSize,
+                         /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+          mapOperands);
+    } else {
+      // Calculate suffix product of static dimension sizes and linearize those
+      // values with dimension affine variables defined previously.
+      SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+      AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+
+      // Creating maximally folded and composed affine.apply composes better
+      // with other transformations without interleaving canonicalization
+      // passes.
+      ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc,
+          AffineMap::get(/*numDims=*/groupSize,
+                         /*numSymbols=*/0, /*expression=*/srcIndexExpr),
+          dynamicIndices);
+    }
+    // Push index value in the op post folding corresponding to this
+    // reassociation group.
     sourceIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 4c960659d80cb7..7b9a77bfc8a0a8 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
@@ -29,6 +31,19 @@ SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
   return strides;
 }
 
+static SmallVector<Value> computeSuffixProductImpl(Location loc,
+                                                   OpBuilder &builder,
+                                                   ArrayRef<Value> sizes,
+                                                   Value unit) {
+  // All strides start as `unit`; empty `sizes` gives an empty result, no ops.
+  SmallVector<Value> result(sizes.size(), unit);
+  // Innermost to outermost: result[i - 2] = result[i - 1] * sizes[i - 1].
+  for (size_t i = sizes.size(); i > 1; --i)
+    result[i - 2] =
+        builder.create<arith::MulIOp>(loc, result[i - 1], sizes[i - 1]);
+  return result;
+}
+
 template <typename ExprType>
 SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                 ArrayRef<ExprType> v2) {
@@ -197,6 +212,18 @@ SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
   return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
 }
 
+//===----------------------------------------------------------------------===//
+// Utils that operate on compile time unknown values.
+//===----------------------------------------------------------------------===//
+
+SmallVector<Value> mlir::computeSuffixProduct(Location loc, OpBuilder &builder,
+                                              ArrayRef<Value> sizes) {
+  if (sizes.empty()) // Nothing to scale: do not emit the unit constant.
+    return {};
+  return ::computeSuffixProductImpl(
+      loc, builder, sizes, builder.create<arith::ConstantIndexOp>(loc, 1));
+}
+
 //===----------------------------------------------------------------------===//
 // Permutation utils.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 5b853a6cc5a37a..99ac6115558aeb 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -468,16 +468,67 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 
 // -----
 
-// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
-func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> f32
+func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
-// CHECK: return %[[LOAD]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return %[[VAL1]] : f32
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) {
+  %c0 = arith.constant 0 : index
+  %c1f32 = arith.constant 1.0 : f32
+  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  return
+}
+// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %size: index, %offset: index, %dim_idx: index) {
+  %view = memref.subview %alloc[%offset, 0] [%size, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+  %expanded = memref.expand_shape %view [[0], [1, 2, 3]] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  %ub = memref.dim %expanded, %dim_idx : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  // %size, %offset and %dim_idx are function arguments, i.e. not constants.
+  affine.for %i = 0 to %ub step 64 {
+    affine.for %j = 0 to 16 step 16 {
+      %val = affine.load %expanded[%i, 0, %j, %j] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+      affine.store %val, %view[%i, %j] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+    }
+  }
+  return
+}
+// CHECK-NEXT:   memref.subview
+// CHECK-NEXT:   %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT:   affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT:   %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT:   %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
+// CHECK-NEXT:   %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
+// CHECK-NEXT:   %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT:   affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
 
 // -----
 



More information about the Mlir-commits mailing list