[Mlir-commits] [mlir] [mlir][memref]: Add expand/collapse rewrite pattern to MemRef::CopyOp (PR #67808)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 29 07:18:38 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

<details>
<summary>Changes</summary>

This pattern is useful to adjust the memref copy ranks.

---

Patch is 26.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67808.diff


7 Files Affected:

- (added) mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h (+45) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp (+238) 
- (added) mlir/test/Transforms/expand-collapse-copy-ops.mlir (+141) 
- (modified) mlir/test/lib/Dialect/MemRef/CMakeLists.txt (+1) 
- (added) mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp (+66) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
new file mode 100644
index 000000000000000..27a69ab93e42c74
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
@@ -0,0 +1,45 @@
+//===- ExpandCollapseCopyOps.h - Expand/collapse copy ranks -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Patterns for expanding/collapsing the rank of MemRef copies.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+#include <functional>
+
+namespace mlir {
+class MLIRContext;
+class RewritePatternSet;
+
+namespace memref {
+
+/// Predicate deciding whether a given `memref::CopyOp` may be rewritten by
+/// the expand/collapse pattern.
+using ExpandCollapseFuncCB = std::function<bool(memref::CopyOp)>;
+
+/// Default predicate: every `memref::CopyOp` is eligible.
+inline bool expandCollapseAny(memref::CopyOp /*copyOp*/) { return true; }
+
+/// Adds a rewrite pattern that adjusts the rank of `memref::CopyOp`s so that
+/// it lies within [`minRank`, `maxRank`]: the operands of copies with a lower
+/// rank are expanded, the operands of copies with a higher rank are collapsed.
+/// A selective callback `funcCB` may be provided to choose which operations
+/// are rewritten; by default every copy is considered.
+/// In some cases (e.g. the source/target are strided in whole dims), it is
+/// not possible to collapse the `memref::CopyOp`; such ops are not rewritten.
+void populateExpandCollapseCopyOpsPatterns(
+    RewritePatternSet &patterns, unsigned minRank = 1, unsigned maxRank = 1,
+    ExpandCollapseFuncCB funcCB = expandCollapseAny);
+
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index b16c281c93640ea..924feca4cad3012 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   AllocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   ComposeSubView.cpp
+  ExpandCollapseCopyOps.cpp
   ExpandOps.cpp
   ExpandRealloc.cpp
   ExpandStridedMetadata.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
new file mode 100644
index 000000000000000..7905254e71e19fc
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
@@ -0,0 +1,238 @@
+//===- ExpandCollapseCopyOps.cpp - Expand/collapse MemRef copy ranks ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------------===//
+//
+// This file contains rewrite patterns (transformations) to expand/collapse
+// the rank of MemRef copies. This is useful for architectures that limit the
+// number of dimensions a copy operation may have.
+//
+//===--------------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <functional>
+#include <numeric>
+#include <sstream>
+#include <string>
+
+#define DEBUG_TYPE "expand-collapse-copy-ops"
+
+using namespace mlir;
+
+#ifndef NDEBUG
+// Debug-only helper that renders a shape as a string for the LLVM_DEBUG
+// messages below. It is forward-declared here because its definition comes
+// after its users, further down in this file.
+static inline std::string shape_to_string(ArrayRef<int64_t> shape);
+#endif // NDEBUG
+
+namespace {
+/// Rewrite pattern that brings the rank of a `memref::CopyOp` into the
+/// inclusive range [`minRank`, `maxRank`]:
+///   * rank < minRank: source and target are expanded to `minRank`;
+///   * rank > maxRank: source and target are collapsed to `maxRank`.
+/// A selective callback may be provided to distinguish which operations
+/// should be expanded/collapsed. In some cases (e.g. the source/target are
+/// strided in each dim) it is not possible to collapse the `memref::CopyOp`,
+/// and the pattern then fails to match.
+struct ExpandCollapseCopyOpConverter : public OpRewritePattern<memref::CopyOp> {
+public:
+  // Inherited constructors are kept for compatibility; the default member
+  // initializers below give them a well-defined configuration.
+  using OpRewritePattern::OpRewritePattern;
+
+  /// \param context  context the pattern is registered in.
+  /// \param minRank  smallest rank a copy may have after rewriting.
+  /// \param maxRank  largest rank a copy may have after rewriting; must be
+  ///                 >= `minRank`.
+  /// \param funcCB   predicate selecting the copies to rewrite.
+  ExpandCollapseCopyOpConverter(MLIRContext *context, unsigned minRank,
+                                unsigned maxRank,
+                                memref::ExpandCollapseFuncCB funcCB)
+      : OpRewritePattern<memref::CopyOp>(context, /*benefit=*/1),
+        minRank(minRank), maxRank(maxRank), funcCB(std::move(funcCB)) {
+    assert(minRank <= maxRank && "invalid ranks range");
+  }
+
+  /// Returns success if `copyOp` was replaced by a copy of adjusted rank, and
+  /// failure if it was filtered out, is already in range, has an unranked
+  /// operand, or cannot be collapsed.
+  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
+                                PatternRewriter &rewriter) const final {
+    // `memref.copy` also accepts unranked memrefs, for which there is no
+    // static rank to adjust. Use `dyn_cast` so such ops are skipped instead
+    // of tripping the assertion inside `cast`.
+    auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+    auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+    if (!sourceType || !targetType) {
+      LLVM_DEBUG(llvm::dbgs() << "Skip rewriting " << copyOp
+                              << ", unranked memrefs are not supported\n");
+      return failure();
+    }
+
+    // An empty callback is treated as "accept everything" rather than being
+    // invoked (which would throw std::bad_function_call).
+    if (funcCB && !funcCB(copyOp)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp << ", filtered by funcCB\n");
+      return failure();
+    }
+
+    const unsigned rank = static_cast<unsigned>(sourceType.getRank());
+    if (rank >= minRank && rank <= maxRank) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp
+                 << ", operation does not need to expand/collapse\n");
+      return failure();
+    }
+
+    if (rank > maxRank)
+      return collapseCopyOpRank(copyOp, maxRank, rewriter);
+
+    assert(rank < minRank);
+    // Expand is always successful.
+    expandCopyOpRank(copyOp, minRank, rewriter);
+    return success();
+  }
+
+private:
+  // Defaults mirror those of `populateExpandCollapseCopyOpsPatterns`, so a
+  // pattern built through an inherited constructor never reads
+  // uninitialized members.
+  unsigned minRank = 1;
+  unsigned maxRank = 1;
+  // Accept callback to select which `memref::CopyOp` to collapse/expand.
+  memref::ExpandCollapseFuncCB funcCB = memref::expandCollapseAny;
+
+  // Expand the `copyOp` source/target dims to newRank by
+  // adding new dims in size of `1`.
+  void expandCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                        PatternRewriter &rewriter) const;
+  // Collapse the `copyOp` source/target dims to newRank.
+  // The function tries to collapse starting from the most inner dims
+  // to the most outer dims.
+  // This function returns failure if there are no dims to collapse.
+  LogicalResult collapseCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                                   PatternRewriter &rewriter) const;
+  // Fill `collapsedShape` with a shape in size of `newRank`.
+  // The function tries to collapse starting from the most inner dims
+  // to the most outer dims of `memrefToCollapse`.
+  // This function returns failure if there are no dims to collapse.
+  LogicalResult getCollapsedShape(MemRefType memrefToCollapse, unsigned newRank,
+                                  SmallVector<int64_t> &collapsedShape) const;
+};
+
+} // namespace
+
+/// Replaces `copyOp` with a copy between `memref.expand_shape` views of its
+/// source and target. Both views have rank `newRank`, obtained by prepending
+/// unit dims to the original shape. `newRank` must be greater than the
+/// current rank.
+void ExpandCollapseCopyOpConverter::expandCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  MemRefType sourceType = cast<MemRefType>(copyOp.getSource().getType());
+  const int64_t rank = sourceType.getRank();
+  assert(rank < static_cast<int64_t>(newRank) && "expected a rank increase");
+  const int64_t numNewDims = static_cast<int64_t>(newRank) - rank;
+
+  // Build the reassociation explicitly rather than inferring it from the
+  // shapes: the new leading unit dims are grouped with the original outermost
+  // dim and every other original dim maps to itself. This stays valid when
+  // the outermost dim is dynamic, a case shape-based inference may reject.
+  // A rank-0 memref has no groups at all.
+  SmallVector<ReassociationIndices> reassociation(rank);
+  if (rank > 0) {
+    int64_t expandedDim = 0;
+    for (; expandedDim <= numNewDims; ++expandedDim)
+      reassociation.front().push_back(expandedDim);
+    for (int64_t dim = 1; dim < rank; ++dim)
+      reassociation[dim].push_back(expandedDim++);
+  }
+
+  // Each operand is expanded from its own shape, since the two operand types
+  // are separate and need not agree on which dims are static.
+  auto expandOperand = [&](Value operand) -> Value {
+    ArrayRef<int64_t> shape = cast<MemRefType>(operand.getType()).getShape();
+    // New outer most dims will be 1s, rest dims are same as original shape.
+    SmallVector<int64_t> newShape(numNewDims, 1);
+    newShape.append(shape.begin(), shape.end());
+
+#ifndef NDEBUG
+    LLVM_DEBUG(llvm::dbgs() << "Expanding shape " << shape_to_string(shape)
+                            << " to " << shape_to_string(newShape) << "\n");
+#endif // NDEBUG
+
+    return rewriter.create<memref::ExpandShapeOp>(copyOp.getLoc(), newShape,
+                                                  operand, reassociation);
+  };
+
+  rewriter.setInsertionPoint(copyOp);
+  Value expandShapeSrc = expandOperand(copyOp.getSource());
+  Value expandShapeTarget = expandOperand(copyOp.getTarget());
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, expandShapeSrc,
+                                              expandShapeTarget);
+}
+
+/// Replaces `copyOp` with a copy between `memref.collapse_shape` views of its
+/// source and target, both of rank `newRank`.
+/// Returns failure, leaving the IR untouched, if an operand is unranked or if
+/// no grouping of dims is found that is collapsible for both operands.
+LogicalResult ExpandCollapseCopyOpConverter::collapseCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+  auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+  if (!sourceType || !targetType)
+    return failure();
+
+  SmallVector<int64_t> collapsedShape;
+  if (failed(getCollapsedShape(sourceType, newRank, collapsedShape)))
+    return failure();
+
+  // The reassociation is optional; do not rely on an assert alone, since
+  // dereferencing an empty optional in a release build is undefined behavior.
+  std::optional<SmallVector<ReassociationIndices>> reassociation =
+      getReassociationIndicesForCollapse(sourceType.getShape(), collapsedShape);
+  if (!reassociation)
+    return failure();
+
+  // `getCollapsedShape` only inspected the source layout. The target may have
+  // a different (e.g. strided) layout, so the same grouping must be checked
+  // against it before a collapse is created on the target.
+  // NOTE(review): this conservatively gives up instead of looking for another
+  // grouping that would be collapsible for both operands.
+  if (targetType.getRank() != sourceType.getRank() ||
+      !memref::CollapseShapeOp::isGuaranteedCollapsible(targetType,
+                                                        *reassociation))
+    return rewriter.notifyMatchFailure(
+        copyOp, "target memref is not collapsible with the source grouping");
+
+  rewriter.setInsertionPoint(copyOp);
+  Value collapseShapeSrc = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getSource(), *reassociation);
+  Value collapseShapeTarget = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getTarget(), *reassociation);
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, collapseShapeSrc,
+                                              collapseShapeTarget);
+
+  return success();
+}
+
+/// Computes a shape of rank `newRank` for `memrefToCollapse` by merging one
+/// run of `rank - newRank + 1` adjacent dims into a single dim.
+/// Candidate runs are tried from the innermost position towards the outermost
+/// one; the first run that is guaranteed collapsible for the memref's layout
+/// is stored in `collapsedShape`.
+/// Returns failure (leaving `collapsedShape` untouched) if no run qualifies.
+LogicalResult ExpandCollapseCopyOpConverter::getCollapsedShape(
+    MemRefType memrefToCollapse, unsigned newRank,
+    SmallVector<int64_t> &collapsedShape) const {
+  ArrayRef<int64_t> shape = memrefToCollapse.getShape();
+  const int64_t rank = memrefToCollapse.getRank();
+  const int64_t dimsToCollapse = rank - static_cast<int64_t>(newRank);
+  assert(dimsToCollapse > 0 && "expected a rank reduction");
+
+  // Try to find `dimsToCollapse` dims we can collapse, starting with most inner
+  // dim to collapse.
+  for (int64_t firstDim = rank - dimsToCollapse - 1; firstDim >= 0;
+       --firstDim) {
+    const int64_t lastDim = firstDim + dimsToCollapse;
+    ArrayRef<int64_t> group = shape.slice(firstDim, dimsToCollapse + 1);
+
+    // Size of the merged dim. Dims are 64-bit, so accumulate in int64_t to
+    // avoid truncation. If any merged dim is dynamic the result is dynamic
+    // too; the dynamic sentinel must never be multiplied.
+    int64_t collapsedDim = ShapedType::kDynamic;
+    if (llvm::none_of(group,
+                      [](int64_t dim) { return ShapedType::isDynamic(dim); }))
+      collapsedDim = std::accumulate(group.begin(), group.end(), int64_t{1},
+                                     std::multiplies<int64_t>());
+
+    // Generate the new shape in `newRank` size: dims before the run, the
+    // merged dim, then dims after the run.
+    SmallVector<int64_t> newShape(shape.begin(), shape.begin() + firstDim);
+    newShape.push_back(collapsedDim);
+    newShape.append(shape.begin() + lastDim + 1, shape.end());
+    assert(newShape.size() == newRank);
+
+#ifndef NDEBUG
+    LLVM_DEBUG(llvm::dbgs()
+               << "trying to collapse shape " << shape_to_string(shape)
+               << " to " << shape_to_string(newShape) << "\n");
+#endif // NDEBUG
+
+    // The reassociation may not be inferable from the shapes (for instance
+    // with dynamic dims); skip this candidate instead of dereferencing an
+    // empty optional.
+    std::optional<SmallVector<ReassociationIndices>> reassociation =
+        getReassociationIndicesForCollapse(shape, newShape);
+    if (!reassociation)
+      continue;
+
+    if (memref::CollapseShapeOp::isGuaranteedCollapsible(memrefToCollapse,
+                                                         *reassociation)) {
+      collapsedShape = std::move(newShape);
+      return success();
+    }
+  }
+
+  return failure();
+}
+
+#ifndef NDEBUG
+/// Debug-only helper: renders `shape` as its dims joined by 'x', e.g.
+/// "1x5x24x48". Dynamic dims are printed as '?'. An empty shape yields an
+/// empty string.
+static inline std::string shape_to_string(ArrayRef<int64_t> shape) {
+  std::string shapeStr;
+  llvm::raw_string_ostream shapeStream(shapeStr);
+
+  // `interleave` only emits the separator between elements, so there is no
+  // trailing 'x' to strip afterwards.
+  llvm::interleave(
+      shape, shapeStream,
+      [&](int64_t dim) {
+        if (ShapedType::isDynamic(dim))
+          shapeStream << '?';
+        else
+          shapeStream << dim;
+      },
+      "x");
+
+  return shapeStream.str();
+}
+#endif // NDEBUG
+
+/// Registers `ExpandCollapseCopyOpConverter` in `patterns`, configured with
+/// the given rank range and selection callback.
+void memref::populateExpandCollapseCopyOpsPatterns(
+    RewritePatternSet &patterns, unsigned minRank, unsigned maxRank,
+    memref::ExpandCollapseFuncCB funcCB) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ExpandCollapseCopyOpConverter>(context, minRank, maxRank,
+                                              funcCB);
+}
diff --git a/mlir/test/Transforms/expand-collapse-copy-ops.mlir b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
new file mode 100644
index 000000000000000..b3cd187424e084b
--- /dev/null
+++ b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
@@ -0,0 +1,141 @@
+// RUN: mlir-opt -test-expand-collapse-copy-ops="minRank=2 maxRank=3" %s -split-input-file | FileCheck %s
+
+// Sanity test: a function that contains no memref.copy is left as is.
+// CHECK-LABEL:   func.func @empty() {
+// CHECK:           return
+// CHECK:         }
+func.func @empty() -> () {
+  return
+}
+
+// -----
+
+// A rank-1 copy is below minRank=2: source and target are both expanded to
+// 1x6 and the copy is performed on the expanded views.
+// CHECK-LABEL:   func.func @memref_copy_to_expand(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: memref<6xi32>) {
+// CHECK:           %[[VAL_1:.*]] = memref.alloc() : memref<6xi32>
+// CHECK:           %[[VAL_2:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK:           %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK:           memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x6xi32> to memref<1x6xi32>
+// CHECK:           return
+// CHECK:         }
+func.func @memref_copy_to_expand(%arg0: memref<6xi32>) {
+  %0 = memref.alloc() : memref<6xi32>
+  memref.copy %arg0, %0 : memref<6xi32> to memref<6xi32>
+  return
+}
+
+// -----
+
+// A rank-4 copy is above maxRank=3: the two innermost dims of source and
+// target are collapsed (24x48 -> 1152) and the copy is performed at rank 3.
+// CHECK-LABEL:   func.func @memref_copy_to_collapse(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME:                                       %[[VAL_1:.*]]: memref<1x5x24x48xi32>) {
+// CHECK:           %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK:           memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x5x1152xi32> to memref<1x5x1152xi32>
+// CHECK:           return
+// CHECK:         }
+func.func @memref_copy_to_collapse(%arg0: memref<1x5x24x48xi32>, %arg1: memref<1x5x24x48xi32>) {
+  memref.copy %arg0, %arg1 : memref<1x5x24x48xi32> to memref<1x5x24x48xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @memref_copy_collapse_expand_in_loop(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: memref<1x5x24x48xf32>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 5760 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_10:.*]] = memref.subview %[[VAL_6]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK:             %[[VAL_11:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK:             %[[VAL_12:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK:             %[[VAL_13:.*]] = memref.alloc() : memref<16xf32>
+// CHECK:             %[[VAL_14:.*]] = memref.expand_shape %[[VAL_10]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:             %[[VAL_15:.*]] = memref.expand_shape %[[VAL_13]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK:             memref.copy %[[VAL_14]], %[[VAL_15]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK:             %[[VAL_16:.*]] = memref.alloc() : memref<16xf32>
+// CHECK:             %[[VAL_17:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:             %[[VAL_18:.*]] = memref.expand_shape %[[VAL_16]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK:             memref.copy %[[VAL_17]], %[[VAL_18]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK:             %[[VAL_19:.*]] = memref.alloc() : memref<16xf32>
+// CHECK:             linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%[[VAL_13]], %[[VAL_16]] : memref<16xf32>, memref<16xf32>) outs(%[[VAL_19]] : memref<16xf32>) {
+// CHECK:             ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK:               %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK:               linalg.yield %[[VAL_23]] : f32
+// CHECK:             }
+// CHECK:             %[[VAL_24:.*]] = memref.expand_shape %[[VAL_19]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK:             %[[VAL_25:.*]] = memref.expand_shape %[[VAL_12]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:             memref.copy %[[VAL_24]], %[[VAL_25]] : memref<1x16xf32> to memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:           }
+// CHECK:           return %[[VAL_5]] : memref<1x5x24x48xf32>
+// CHECK:         }
+// Input for the loop test: the three rank-1 copies inside the scf.for body are
+// below minRank=2 and are expected to be expanded to rank 2, while the
+// surrounding collapse_shape/subview/linalg ops are expected to be preserved.
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @memref_copy_collapse_expand_in_loop(%arg0: memref<1x5x24x48xf32>, %arg1: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+    %c5760 = arith.constant 5760 : index
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+    %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+    %collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+    %collapse_shape_1 = memref.collapse_shape %alloc [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+    scf.for %arg2 = %c0 to %c5760 step %c16 {
+      %subview = memref.subview %collapse_shape[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+      %subview_2 = memref.subview %collapse_shape_0[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+      %subview_3 = memref.subview %collapse_shape_1[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+      %alloc_4 = memref.alloc() : memref<16xf32>
+      memref.copy %subview, %alloc_4 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+      %alloc_5 = memref.alloc() : memref<16xf32>
+      memref.copy %subview_2, %alloc_5 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+      %alloc_6 = memref.alloc() : memref<16xf32>
+      linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%alloc_4, %alloc_5 : memref<16xf32>, memref<16xf32>) outs(%alloc_6 : memref<16xf32>) {
+      ^bb0(%in: f32, %in_7: f32, %out: f32):
+        %0 = arith.addf %in, %in_7 : f32
+        linalg.yield %0 : f32
+      }
+      memref.copy %alloc_6, %subview_3 : memref<16xf32> to memref<16xf32, strided<[1], offset: ?>>
+    }
+    return %alloc : memref<1x5x24x48xf32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @memref_copy_strided_to_collapse(
+// CHECK-SAME:                                               %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME:                 ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/67808


More information about the Mlir-commits mailing list