[llvm-branch-commits] [mlir] [mlir][MemRef] Use specialized index ops to fold expand/collapse_shape (PR #138930)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed May 7 11:19:21 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Krzysztof Drewniak (krzysz00)
Changes:
This PR updates the FoldMemRefAliasOps pass to use `affine.linearize_index` and `affine.delinearize_index` to perform the index computations needed to fold a `memref.expand_shape` or `memref.collapse_shape`, respectively, into its consumers, as sketched below.
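For illustration, here is a minimal before/after sketch distilled from the updated tests in this patch (`%src`, `%src2`, `%i`, `%j`, `%k` are placeholder names, not identifiers from the patch):

```mlir
// Folding a load through memref.expand_shape: the indices of each
// reassociation group are linearized against that group's slice of
// output_shape.
%e = memref.expand_shape %src [[0, 1], [2]] output_shape [2, 6, 32]
    : memref<12x32xf32> into memref<2x6x32xf32>
%v0 = affine.load %e[%i, %j, %k] : memref<2x6x32xf32>
// ... becomes ...
%lin = affine.linearize_index disjoint [%i, %j] by (2, 6) : index
%v1 = affine.load %src[%lin, %k] : memref<12x32xf32>

// Folding a load through memref.collapse_shape: each collapsed index is
// delinearized against the source sizes of its group.
%c = memref.collapse_shape %src2 [[0, 1], [2]]
    : memref<2x6x32xf32> into memref<12x32xf32>
%v2 = affine.load %c[%i, %j] : memref<12x32xf32>
// ... becomes ...
%idx:2 = affine.delinearize_index %i into (2, 6) : index, index
%v3 = affine.load %src2[%idx#0, %idx#1, %j] : memref<2x6x32xf32>
```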
This also loosens some limitations of the pass:
1. The existing `output_shape` argument to `memref.expand_shape` is now used, eliminating the need to re-infer this shape or call `memref.dim`.
2. Because we're using `affine.delinearize_index`, the restriction that each reassociation group in a `memref.collapse_shape` may contain at most one dynamic dimension is removed; see the second sketch after this list.
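To make point 2 concrete, a sketch of the fully dynamic case (again modeled on a test added in this patch, with placeholder value names): the group's sizes are recovered at runtime with `memref.extract_strided_metadata` and fed to the delinearization as a dynamic basis.

```mlir
%c = memref.collapse_shape %src [[0, 1], [2]]
    : memref<?x?x?xf32> into memref<?x?xf32>
%v0 = affine.load %c[%i, %j] : memref<?x?xf32>
// ... becomes ...
%base, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %src
    : memref<?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index
%idx:2 = affine.delinearize_index %i into (%sizes#0, %sizes#1) : index, index
%v1 = affine.load %src[%idx#0, %idx#1, %j] : memref<?x?x?xf32>
```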
---
Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+10-4)
- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+49-120)
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+64-65)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
according to specified offsets, sizes, and strides.
```mlir
- %result1 = memref.reinterpret_cast %arg0 to
+ %result1 = memref.reinterpret_cast %arg0 to
offset: [9],
sizes: [4, 4],
strides: [16, 2]
: memref<8x8xf32, strided<[8, 1], offset: 0>> to
memref<4x4xf32, strided<[16, 2], offset: 9>>
- %result2 = memref.reinterpret_cast %result1 to
+ %result2 = memref.reinterpret_cast %result1 to
offset: [0],
sizes: [2, 2],
strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
OpBuilder &b, Location loc, MemRefType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape);
+
+ // Return a vector with all the static and dynamic values in the output shape.
+ SmallVector<OpFoldResult> getMixedOutputShape() {
+ OpBuilder builder(getContext());
+ return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+ }
}];
let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
let summary = "store operation";
let description = [{
The `store` op stores an element into a memref at the specified indices.
-
+
The number of indices must match the rank of the memref. The indices must
be in-bounds: `0 <= idx < dim_size`
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
Unlike the `reinterpret_cast`, the values are relative to the strided
memref of the input (`%result1` in this case) and not its
underlying memory.
-
+
Example 2:
```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
///
/// %2 = load %0[6 * i1 + i2, %i3] :
/// memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
- memref::ExpandShapeOp expandShapeOp,
- ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
- // Record the rewriter context for constructing ops later.
- MLIRContext *ctx = rewriter.getContext();
-
- // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
- // This is done for the purpose of inferring the output shape via
- // `inferExpandOutputShape` which will in turn be used for suffix product
- // calculation later.
- SmallVector<OpFoldResult> srcShape;
- MemRefType srcType = expandShapeOp.getSrcType();
-
- for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
- if (srcType.isDynamicDim(i)) {
- srcShape.push_back(
- rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
- .getResult());
- } else {
- srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
- }
- }
-
- auto outputShape = inferExpandShapeOutputShape(
- rewriter, loc, expandShapeOp.getResultType(),
- expandShapeOp.getReassociationIndices(), srcShape);
- if (!outputShape.has_value())
- return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+ Location loc, PatternRewriter &rewriter,
+ memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+ SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
// Traverse all reassociation groups to determine the appropriate indices
// corresponding to each one of them post op folding.
- for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- // Flag to indicate the presence of dynamic dimensions in current
- // reassociation group.
- int64_t groupSize = groups.size();
-
- // Group output dimensions utilized in this reassociation group for suffix
- // product calculation.
- SmallVector<OpFoldResult> sizesVal(groupSize);
- for (int64_t i = 0; i < groupSize; ++i) {
- sizesVal[i] = (*outputShape)[groups[i]];
+ for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+ if (groupSize == 1) {
+ sourceIndices.push_back(indices[group[0]]);
+ continue;
}
-
- // Calculate suffix product of relevant output dimension sizes.
- SmallVector<OpFoldResult> suffixProduct =
- memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
- // Create affine expression variables for dimensions and symbols in the
- // newly constructed affine map.
- SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
- bindDimsList<AffineExpr>(ctx, dims);
- bindSymbolsList<AffineExpr>(ctx, symbols);
-
- // Linearize binded dimensions and symbols to construct the resultant
- // affine expression for this indice.
- AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
- // Record the load index corresponding to each dimension in the
- // reassociation group. These are later supplied as operands to the affine
- // map used for calulating relevant index post op folding.
- SmallVector<OpFoldResult> dynamicIndices(groupSize);
- for (int64_t i = 0; i < groupSize; i++)
- dynamicIndices[i] = indices[groups[i]];
-
- // Supply suffix product results followed by load op indices as operands
- // to the map.
- SmallVector<OpFoldResult> mapOperands;
- llvm::append_range(mapOperands, suffixProduct);
- llvm::append_range(mapOperands, dynamicIndices);
-
- // Creating maximally folded and composed affine.apply composes better
- // with other transformations without interleaving canonicalization
- // passes.
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/groupSize,
- /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
- mapOperands);
-
- // Push index value in the op post folding corresponding to this
- // reassociation group.
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ SmallVector<OpFoldResult> groupBasis =
+ llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+ SmallVector<Value> groupIndices =
+ llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+ Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+ loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+ sourceIndices.push_back(collapsedIndex);
}
return success();
}
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
- int64_t cnt = 0;
- SmallVector<OpFoldResult> dynamicIndices;
- for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
- assert(!groups.empty() && "association indices groups cannot be empty");
- dynamicIndices.push_back(indices[cnt++]);
- int64_t groupSize = groups.size();
-
- // Calculate suffix product for all collapse op source dimension sizes
- // except the most major one of each group.
- // We allow the most major source dimension to be dynamic but enforce all
- // others to be known statically.
- SmallVector<int64_t> sizes(groupSize, 1);
- for (int64_t i = 1; i < groupSize; ++i) {
- sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
- if (sizes[i] == ShapedType::kDynamic)
- return failure();
- }
- SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
- // Derive the index values along all dimensions of the source corresponding
- // to the index wrt to collapsed shape op output.
- auto d0 = rewriter.getAffineDimExpr(0);
- SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
- // Construct the AffineApplyOp for each delinearizingExpr.
- for (int64_t i = 0; i < groupSize; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc,
- AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
- delinearizingExprs[i]),
- dynamicIndices);
- sourceIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ MemRefType sourceType = collapseShapeOp.getSrcType();
+ // Note: collapse_shape requires a strided memref, so we can do this.
+ auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, collapseShapeOp.getSrc());
+ SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+ for (auto [index, group] :
+ llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+ assert(!group.empty() && "association indices groups cannot be empty");
+ int64_t groupSize = group.size();
+
+ if (groupSize == 1) {
+ sourceIndices.push_back(index);
+ continue;
}
- dynamicIndices.clear();
+
+ SmallVector<OpFoldResult> basis =
+ llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+ auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+ loc, index, basis, /*hasOuterBound=*/true);
+ llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, zeroAffineMap, dynamicIndices);
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.load and affine.load guarantee that indices start in-bounds,
+ // while the vector operations don't. This affects whether our
+ // linearization is `disjoint`.
if (failed(resolveSourceIndicesExpandShape(
- loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
+ // memref.store and affine.store guarantee that indices start in-bounds,
+ // while the vector operations don't. This affects whether our
+ // linearization is `disjoint`.
if (failed(resolveSourceIndicesExpandShape(
- storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
return failure();
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
%1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
%1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
%1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
return %1 : f32
}
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
// -----
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+ %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
%1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
return %1 : f32
}
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return %0 : f32
}
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return %[[VAL1]] : f32
// -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return
}
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
// CHECK-NEXT: return
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
}
return
}
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/138930
More information about the llvm-branch-commits mailing list