[Mlir-commits] [mlir] [mlir][fold-memref-alias-ops] Add support for folding memref.expand_shape involving dynamic dims (PR #89093)
llvmlistbot at llvm.org
Wed Apr 17 08:59:29 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Prathamesh Tagore (meshtag)
`fold-memref-alias-ops` bails out in the presence of dynamic shapes in the `memref.expand_shape` op. Handle this case.
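As a quick illustration (adapted from the new tests in this diff; SSA value names are illustrative, and `%src`, `%i`, `%j`, and `%c0` are assumed to be defined in the surrounding function), the fold now rewrites loads through a dynamically shaped expanded view into loads on the source memref:

```mlir
// Before: the load goes through the expanded view, which has a dynamic dim.
%expand = memref.expand_shape %src [[0, 1], [2, 3]]
    : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
%v = memref.load %expand[%c0, %i, %j, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>

// After: the indices are linearized back onto the source memref.
%c1 = arith.constant 1 : index
%col = affine.apply affine_map<()[s0, s1] -> (s1 * s0)>()[%j, %c1]
%v = memref.load %src[%i, %col] : memref<16x?xf32, strided<[16, 1]>>
```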
---
Full diff: https://github.com/llvm/llvm-project/pull/89093.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Utils/IndexingUtils.h (+25)
- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+79-19)
- (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+27)
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+56-5)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 9892253df2bff1..5f0ea7ee99a85c 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -48,6 +48,31 @@ inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
return computeSuffixProduct(sizes);
}
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value>
+/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to provide valid values, which are
+/// expected to be index-typed constants or results of dimension extraction
+/// ops (e.g., memref.dim).
+///
+/// `sizes` elements are expected to be non-negative.
+///
+/// Return an empty vector if `sizes` is empty.
+SmallVector<Value> computeSuffixProduct(Location loc, OpBuilder &builder,
+ ArrayRef<Value> sizes);
+inline SmallVector<Value> computeStrides(Location loc, OpBuilder &builder,
+ ArrayRef<Value> sizes) {
+ return computeSuffixProduct(loc, builder, sizes);
+}
+
/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
///
/// Return an empty vector if `v1` and `v2` are empty.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index aa44455ada7f9a..f5b4844c7fc1a6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -63,39 +63,99 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- // The below implementation uses computeSuffixProduct method, which only
- // allows int64_t values (i.e., static shape). Bail out if it has dynamic
- // shapes.
- if (!expandShapeOp.getResultType().hasStaticShape())
- return failure();
-
+ // Record the rewriter context for constructing ops later.
MLIRContext *ctx = rewriter.getContext();
+
+ // Record the result type to get result dimensions for calculating the
+ // suffix product later.
+ ShapedType resultType = expandShapeOp.getResultType();
+
+ // Traverse all reassociation groups to determine the appropriate index
+ // corresponding to each of them after the op is folded.
for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
assert(!groups.empty() && "association indices groups cannot be empty");
+ // Flag to indicate the presence of dynamic dimensions in the current
+ // reassociation group.
+ bool hasDynamicDims = false;
int64_t groupSize = groups.size();
- // Construct the expression for the index value w.r.t to expand shape op
- // source corresponding the indices wrt to expand shape op result.
+ // Capture the expand_shape result's dimension sizes, which are used in
+ // the suffix product calculation later.
SmallVector<int64_t> sizes(groupSize);
- for (int64_t i = 0; i < groupSize; ++i)
+ for (int64_t i = 0; i < groupSize; ++i) {
sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
- SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+ if (resultType.isDynamicDim(groups[i]))
+ hasDynamicDims = true;
+ }
+
+ // Declare the affine.apply result and the affine expression variables
+ // representing dimensions in the newly constructed affine map.
+ OpFoldResult ofr;
SmallVector<AffineExpr> dims(groupSize);
bindDimsList(ctx, MutableArrayRef{dims});
- AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
- /// Apply permutation and create AffineApplyOp.
+ // Record the load index corresponding to each dimension in the
+ // reassociation group. These are later supplied as operands to the
+ // affine map used for calculating the relevant index after folding.
SmallVector<OpFoldResult> dynamicIndices(groupSize);
for (int64_t i = 0; i < groupSize; i++)
dynamicIndices[i] = indices[groups[i]];
- // Creating maximally folded and composd affine.apply composes better with
- // other transformations without interleaving canonicalization passes.
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/groupSize,
- /*numSymbols=*/0, srcIndexExpr),
- dynamicIndices);
+ if (hasDynamicDims) {
+ // Record relevant dimension sizes for each result dimension in the
+ // reassociation group.
+ SmallVector<Value> sizesVal(groupSize);
+ for (int64_t i = 0; i < groupSize; ++i) {
+ if (sizes[i] <= 0)
+ sizesVal[i] = rewriter.create<memref::DimOp>(
+ loc, expandShapeOp.getResult(), groups[i]);
+ else
+ sizesVal[i] = rewriter.create<arith::ConstantIndexOp>(loc, sizes[i]);
+ }
+
+ // Calculate suffix product of previously obtained dimension sizes.
+ auto suffixProduct = computeSuffixProduct(loc, rewriter, sizesVal);
+
+ // Create affine expression variables for symbols in the newly constructed
+ // affine map.
+ SmallVector<AffineExpr> symbols(groupSize);
+ bindSymbolsList(ctx, MutableArrayRef{symbols});
+
+ // Linearize the bound dimensions and symbols to construct the resultant
+ // affine expression for this index.
+ AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+ // Supply suffix product results followed by load op indices as operands
+ // to the map.
+ SmallVector<OpFoldResult> mapOperands;
+ llvm::append_range(mapOperands, suffixProduct);
+ llvm::append_range(mapOperands, dynamicIndices);
+
+ // Creating maximally folded and composed affine.apply composes better
+ // with other transformations without interleaving canonicalization
+ // passes.
+ ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc,
+ AffineMap::get(/*numDims=*/groupSize,
+ /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+ mapOperands);
+ } else {
+ // Calculate suffix product of static dimension sizes and linearize those
+ // values with dimension affine variables defined previously.
+ SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+ AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+
+ // Creating maximally folded and composed affine.apply composes better
+ // with other transformations without interleaving canonicalization
+ // passes.
+ ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc,
+ AffineMap::get(/*numDims=*/groupSize,
+ /*numSymbols=*/0, /*expression=*/srcIndexExpr),
+ dynamicIndices);
+ }
+ // Push the index value corresponding to this reassociation group into
+ // the source indices of the folded op.
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 4c960659d80cb7..7b9a77bfc8a0a8 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/IndexingUtils.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
@@ -29,6 +31,19 @@ SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
return strides;
}
+static SmallVector<Value> computeSuffixProductImpl(Location loc,
+ OpBuilder &builder,
+ ArrayRef<Value> sizes,
+ Value unit) {
+ if (sizes.empty())
+ return {};
+ SmallVector<Value> strides(sizes.size(), unit);
+ for (int64_t r = strides.size() - 2; r >= 0; --r)
+ strides[r] =
+ builder.create<arith::MulIOp>(loc, strides[r + 1], sizes[r + 1]);
+ return strides;
+}
+
template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
ArrayRef<ExprType> v2) {
@@ -197,6 +212,18 @@ SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}
+//===----------------------------------------------------------------------===//
+// Utils that operate on values unknown at compile time.
+//===----------------------------------------------------------------------===//
+
+SmallVector<Value> mlir::computeSuffixProduct(Location loc, OpBuilder &builder,
+ ArrayRef<Value> sizes) {
+ if (sizes.empty())
+ return {};
+ Value unit = builder.create<arith::ConstantIndexOp>(loc, 1);
+ return ::computeSuffixProductImpl(loc, builder, sizes, unit);
+}
+
//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 5b853a6cc5a37a..99ac6115558aeb 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -468,16 +468,67 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
// -----
-// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
-func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> f32
+func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
%c0 = arith.constant 0 : index
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return %0 : f32
}
-// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
-// CHECK: return %[[LOAD]]
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return %[[VAL1]] : f32
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) {
+ %c0 = arith.constant 0 : index
+ %c1f32 = arith.constant 1.0 : f32
+ %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+ memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+ return
+}
+// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index) {
+ %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+ %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+ %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+
+ affine.for %arg6 = 0 to %dim step 64 {
+ affine.for %arg7 = 0 to 16 step 16 {
+ %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+ affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+ }
+ }
+ return
+}
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
+// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
+// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
// -----
``````````
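For intuition, the IR emitted by the new Value-based `computeSuffixProduct` can be sketched by hand. The following is a hand-written sketch (not compiler output) for a rank-3 reassociation group with runtime sizes `%s0`, `%s1`, `%s2`; strides are built back-to-front from a unit constant, mirroring `computeSuffixProductImpl` above:

```mlir
func.func @suffix_product_sketch(%s0: index, %s1: index, %s2: index)
    -> (index, index, index) {
  %c1 = arith.constant 1 : index
  // strides[2] is the unit constant itself.
  %stride1 = arith.muli %c1, %s2 : index       // strides[1] = 1 * s2
  %stride0 = arith.muli %stride1, %s1 : index  // strides[0] = (1 * s2) * s1
  return %stride0, %stride1, %c1 : index, index, index
}
```

Note that `%s0` does not participate in the product: the suffix product `[s1 * s2, s2, 1]` only depends on the inner sizes, exactly as in the static `computeSuffixProduct`.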
https://github.com/llvm/llvm-project/pull/89093