[llvm-branch-commits] [mlir] [mlir][MemRef] Use specialized index ops to fold expand/collapse_shape (PR #138930)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed May 7 11:19:21 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

This PR updates the FoldMemRefAliasOps pass to use `affine.linearize_index` and `affine.delinearize_index` to perform the index computations needed to fold a `memref.expand_shape` or `memref.collapse_shape`, respectively, into its consumers.

This also loosens some limitations of the pass:
1. The existing `output_shape` argument to `memref.expand_shape` is now used, eliminating the need to re-infer this shape or call `memref.dim`.
2. Because we're using `affine.delinearize_index`, the restriction that only the outermost (most-major) dimension of each reassociation group in a `memref.collapse_shape` may be dynamic is removed.

---

Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+10-4) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+49-120) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+64-65) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
     according to specified offsets, sizes, and strides.
 
     ```mlir
-    %result1 = memref.reinterpret_cast %arg0 to 
+    %result1 = memref.reinterpret_cast %arg0 to
       offset: [9],
       sizes: [4, 4],
       strides: [16, 2]
     : memref<8x8xf32, strided<[8, 1], offset: 0>> to
       memref<4x4xf32, strided<[16, 2], offset: 9>>
 
-    %result2 = memref.reinterpret_cast %result1 to 
+    %result2 = memref.reinterpret_cast %result1 to
       offset: [0],
       sizes: [2, 2],
       strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
         OpBuilder &b, Location loc, MemRefType expandedType,
         ArrayRef<ReassociationIndices> reassociation,
         ArrayRef<OpFoldResult> inputShape);
+
+    // Return a vector with all the static and dynamic values in the output shape.
+    SmallVector<OpFoldResult> getMixedOutputShape() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+    }
   }];
 
   let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
-    
+
     The number of indices must match the rank of the memref. The indices must
     be in-bounds: `0 <= idx < dim_size`
 
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     Unlike the `reinterpret_cast`, the values are relative to the strided
     memref of the input (`%result1` in this case) and not its
     underlying memory.
-    
+
     Example 2:
 
     ```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
-                                memref::ExpandShapeOp expandShapeOp,
-                                ValueRange indices,
-                                SmallVectorImpl<Value> &sourceIndices) {
-  // Record the rewriter context for constructing ops later.
-  MLIRContext *ctx = rewriter.getContext();
-
-  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
-  // This is done for the purpose of inferring the output shape via
-  // `inferExpandOutputShape` which will in turn be used for suffix product
-  // calculation later.
-  SmallVector<OpFoldResult> srcShape;
-  MemRefType srcType = expandShapeOp.getSrcType();
-
-  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
-    if (srcType.isDynamicDim(i)) {
-      srcShape.push_back(
-          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
-              .getResult());
-    } else {
-      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
-    }
-  }
-
-  auto outputShape = inferExpandShapeOutputShape(
-      rewriter, loc, expandShapeOp.getResultType(),
-      expandShapeOp.getReassociationIndices(), srcShape);
-  if (!outputShape.has_value())
-    return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    // Flag to indicate the presence of dynamic dimensions in current
-    // reassociation group.
-    int64_t groupSize = groups.size();
-
-    // Group output dimensions utilized in this reassociation group for suffix
-    // product calculation.
-    SmallVector<OpFoldResult> sizesVal(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i) {
-      sizesVal[i] = (*outputShape)[groups[i]];
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
     }
-
-    // Calculate suffix product of relevant output dimension sizes.
-    SmallVector<OpFoldResult> suffixProduct =
-        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
-    // Create affine expression variables for dimensions and symbols in the
-    // newly constructed affine map.
-    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
-    bindDimsList<AffineExpr>(ctx, dims);
-    bindSymbolsList<AffineExpr>(ctx, symbols);
-
-    // Linearize binded dimensions and symbols to construct the resultant
-    // affine expression for this indice.
-    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
-    // Record the load index corresponding to each dimension in the
-    // reassociation group. These are later supplied as operands to the affine
-    // map used for calulating relevant index post op folding.
-    SmallVector<OpFoldResult> dynamicIndices(groupSize);
-    for (int64_t i = 0; i < groupSize; i++)
-      dynamicIndices[i] = indices[groups[i]];
-
-    // Supply suffix product results followed by load op indices as operands
-    // to the map.
-    SmallVector<OpFoldResult> mapOperands;
-    llvm::append_range(mapOperands, suffixProduct);
-    llvm::append_range(mapOperands, dynamicIndices);
-
-    // Creating maximally folded and composed affine.apply composes better
-    // with other transformations without interleaving canonicalization
-    // passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
-        mapOperands);
-
-    // Push index value in the op post folding corresponding to this
-    // reassociation group.
-    sourceIndices.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
   }
   return success();
 }
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  int64_t cnt = 0;
-  SmallVector<OpFoldResult> dynamicIndices;
-  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    dynamicIndices.push_back(indices[cnt++]);
-    int64_t groupSize = groups.size();
-
-    // Calculate suffix product for all collapse op source dimension sizes
-    // except the most major one of each group.
-    // We allow the most major source dimension to be dynamic but enforce all
-    // others to be known statically.
-    SmallVector<int64_t> sizes(groupSize, 1);
-    for (int64_t i = 1; i < groupSize; ++i) {
-      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
-      if (sizes[i] == ShapedType::kDynamic)
-        return failure();
-    }
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
-    // Derive the index values along all dimensions of the source corresponding
-    // to the index wrt to collapsed shape op output.
-    auto d0 = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
-    // Construct the AffineApplyOp for each delinearizingExpr.
-    for (int64_t i = 0; i < groupSize; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc,
-          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
-                         delinearizingExprs[i]),
-          dynamicIndices);
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  MemRefType sourceType = collapseShapeOp.getSrcType();
+  // Note: collapse_shape requires a strided memref, we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
     }
-    dynamicIndices.clear();
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.load and affine.load guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.store and affine.store guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
   %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
   %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
 // CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
   %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
   %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
   memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return
 }
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return
 
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
 // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
 func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
   }
   return
 }
+// CHECK-NEXT:   %[[C0:.*]] = arith.constant 0...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/138930


More information about the llvm-branch-commits mailing list