[clang] [mlir][memref]: Add expand/collapse rewrite pattern to MemRef::CopyOp (PR #67808)

Aviad Cohen via cfe-commits cfe-commits at lists.llvm.org
Sun Oct 8 00:13:14 PDT 2023


https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/67808

>From a760c42f0c8b75361f822e1efcbdae30151b2180 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviadcohen7 at gmail.com>
Date: Fri, 29 Sep 2023 15:32:18 +0300
Subject: [PATCH] [mlir][memref]: Add expand/collapse rewrite pattern to
 MemRef::CopyOp

This pattern is useful to adjust the memref copy ranks.
---
 .../MemRef/Transforms/ExpandCollapseCopyOps.h |  45 ++++
 .../Dialect/MemRef/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/ExpandCollapseCopyOps.cpp      | 238 ++++++++++++++++++
 .../Transforms/expand-collapse-copy-ops.mlir  | 141 +++++++++++
 mlir/test/lib/Dialect/MemRef/CMakeLists.txt   |   1 +
 .../MemRef/TestExpandCollapseCopyOps.cpp      |  66 +++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 7 files changed, 494 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
 create mode 100644 mlir/test/Transforms/expand-collapse-copy-ops.mlir
 create mode 100644 mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
new file mode 100644
index 000000000000000..27a69ab93e42c74
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
@@ -0,0 +1,45 @@
+//===-- ExpandCollapseCopyOps.h - Expand/Collapse MemRef copy ranks -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Patterns to expand/collapse the rank of MemRef copies.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+#include <functional>
+
+namespace mlir {
+class MLIRContext;
+class RewritePatternSet;
+
+namespace memref {
+
+/// Predicate selecting which `memref::CopyOp`s the expand/collapse pattern may
+/// rewrite. Returning `false` leaves the op untouched.
+using ExpandCollapseFuncCB = std::function<bool(memref::CopyOp)>;
+
+/// Default `ExpandCollapseFuncCB`: accepts every `memref::CopyOp`.
+inline bool expandCollapseAny([[maybe_unused]] memref::CopyOp copyOp) {
+  return true;
+}
+
+/// Populates `patterns` with a rewrite pattern that reshapes the source and
+/// target of each selected `memref::CopyOp` whose rank lies outside the
+/// inclusive range [`minRank`, `maxRank`] (`minRank <= maxRank` is required):
+///   * copies of rank < `minRank` are expanded to `minRank` by prepending
+///     dims of size 1;
+///   * copies of rank > `maxRank` are collapsed to `maxRank` by merging
+///     adjacent dims, trying the innermost dims first.
+/// `funcCB` selects which copies may be rewritten. In some cases (e.g. when
+/// the copied memrefs are strided in every dim) no group of dims can be
+/// collapsed and the copy is left unchanged.
+void populateExpandCollapseCopyOpsPatterns(
+    RewritePatternSet &patterns, unsigned minRank = 1, unsigned maxRank = 1,
+    ExpandCollapseFuncCB funcCB = expandCollapseAny);
+
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index b16c281c93640ea..924feca4cad3012 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   AllocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   ComposeSubView.cpp
+  ExpandCollapseCopyOps.cpp
   ExpandOps.cpp
   ExpandRealloc.cpp
   ExpandStridedMetadata.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
new file mode 100644
index 000000000000000..7905254e71e19fc
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
@@ -0,0 +1,238 @@
+//===- ExpandCollapseCopyOps.cpp - Expand/Collapse MemRef copy ranks ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains rewrite patterns (transformations) to expand/collapse
+// the rank of MemRef copies. This is useful on architectures that have
+// limitations on the dimensions of copy operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "expand-collapse-copy-ops"
+
+using namespace mlir;
+
+#ifndef NDEBUG
+// Debug-only helper that renders a shape as an `x`-separated string (for
+// LLVM_DEBUG output); it is defined further down in this file.
+static inline std::string shape_to_string(ArrayRef<int64_t> shape);
+#endif // NDEBUG
+
+namespace {
+/// Rewrites a `memref::CopyOp` whose rank lies outside the inclusive range
+/// [`minRank`, `maxRank`] so that its source and target get a rank inside
+/// that range:
+///   * ranks below `minRank` are expanded to `minRank` by prepending unit dims;
+///   * ranks above `maxRank` are collapsed to `maxRank` by merging adjacent
+///     dims, trying the innermost dims first.
+/// A selective callback (`funcCB`) chooses which copies may be rewritten.
+/// Copies with unranked operands are skipped, and collapsing fails when no
+/// candidate group of dims can be collapsed (e.g. the source/target are
+/// strided in every dim).
+struct ExpandCollapseCopyOpConverter : public OpRewritePattern<memref::CopyOp> {
+  // NOTE: the inherited `OpRewritePattern(MLIRContext *, ...)` constructor is
+  // intentionally not exposed: it would leave the members below uninitialized.
+  ExpandCollapseCopyOpConverter(MLIRContext *context, unsigned minRank,
+                                unsigned maxRank,
+                                memref::ExpandCollapseFuncCB funcCB)
+      : OpRewritePattern<memref::CopyOp>(context, /*benefit=*/1),
+        minRank(minRank), maxRank(maxRank), funcCB(std::move(funcCB)) {
+    assert(minRank <= maxRank && "invalid ranks range");
+  }
+
+  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
+                                PatternRewriter &rewriter) const final {
+    // `memref.copy` operands may be unranked memrefs, which cannot be
+    // expanded/collapsed; bail out instead of crashing in `cast`.
+    auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+    auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+    if (!sourceType || !targetType) {
+      LLVM_DEBUG(llvm::dbgs() << "Skip rewriting " << copyOp
+                              << ", unranked memrefs are not supported\n");
+      return failure();
+    }
+
+    // An empty callback behaves like `expandCollapseAny`.
+    if (funcCB && !funcCB(copyOp)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp << ", filtered by funcCB\n");
+      return failure();
+    }
+
+    unsigned rank = sourceType.getRank();
+    if (rank >= minRank && rank <= maxRank) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp
+                 << ", operation does not need to expand/collapse\n");
+      return failure();
+    }
+
+    if (rank > maxRank)
+      return collapseCopyOpRank(copyOp, maxRank, rewriter);
+
+    assert(rank < minRank && "expected rank below the allowed range");
+    expandCopyOpRank(copyOp, minRank, rewriter);
+    // Expanding never fails.
+    return success();
+  }
+
+private:
+  /// Lower bound of the accepted rank range.
+  unsigned minRank;
+  /// Upper bound of the accepted rank range.
+  unsigned maxRank;
+  /// Selects which `memref::CopyOp`s may be expanded/collapsed.
+  memref::ExpandCollapseFuncCB funcCB;
+
+  /// Expands the source/target of `copyOp` to rank `newRank` by prepending
+  /// dims of size 1, then replaces `copyOp` with a copy between the expanded
+  /// views.
+  void expandCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                        PatternRewriter &rewriter) const;
+  /// Collapses the source/target of `copyOp` to rank `newRank`, trying the
+  /// innermost dims first. Returns failure if no group of dims can be
+  /// collapsed.
+  LogicalResult collapseCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                                   PatternRewriter &rewriter) const;
+  /// Fills `collapsedShape` with a rank-`newRank` shape obtained by merging
+  /// adjacent dims of `memrefToCollapse`, trying the innermost dims first.
+  /// Returns failure if no group of dims can be collapsed.
+  LogicalResult getCollapsedShape(MemRefType memrefToCollapse, unsigned newRank,
+                                  SmallVector<int64_t> &collapsedShape) const;
+};
+
+} // namespace
+
+/// Expands the source and target of `copyOp` to rank `newRank` by prepending
+/// unit dims, and replaces `copyOp` with a copy between the expanded views.
+void ExpandCollapseCopyOpConverter::expandCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  MemRefType memRefType = cast<MemRefType>(copyOp.getSource().getType());
+  ArrayRef<int64_t> shape = memRefType.getShape();
+  int64_t rank = memRefType.getRank();
+  int64_t numNewDims = static_cast<int64_t>(newRank) - rank;
+  assert(numNewDims > 0 && "expected to increase the rank");
+
+  // New outermost dims are 1s; the remaining dims keep the original sizes.
+  SmallVector<int64_t> newShape(numNewDims, 1);
+  newShape.append(shape.begin(), shape.end());
+
+  LLVM_DEBUG(llvm::dbgs() << "Expanding shape " << shape_to_string(shape)
+                          << " to " << shape_to_string(newShape) << "\n");
+
+  // Build the reassociation directly: the first source dim expands into all
+  // new leading unit dims plus itself, every other source dim maps 1:1.
+  // Inferring it via getReassociationIndicesForCollapse rejects shapes with
+  // dynamic dims (e.g. [1, ?] vs. [?]). A 0-d source gets an empty
+  // reassociation.
+  SmallVector<ReassociationIndices> reassociation;
+  if (rank > 0) {
+    ReassociationIndices leadingGroup;
+    for (int64_t dim = 0; dim <= numNewDims; ++dim)
+      leadingGroup.push_back(dim);
+    reassociation.push_back(std::move(leadingGroup));
+    for (int64_t dim = numNewDims + 1; dim < static_cast<int64_t>(newRank);
+         ++dim)
+      reassociation.push_back(ReassociationIndices{dim});
+  }
+
+  rewriter.setInsertionPoint(copyOp);
+  Value expandShapeSrc = rewriter.create<memref::ExpandShapeOp>(
+      copyOp.getLoc(), newShape, copyOp.getSource(), reassociation);
+  Value expandShapeTarget = rewriter.create<memref::ExpandShapeOp>(
+      copyOp.getLoc(), newShape, copyOp.getTarget(), reassociation);
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, expandShapeSrc,
+                                              expandShapeTarget);
+}
+
+/// Collapses the source and target of `copyOp` to rank `newRank` and replaces
+/// `copyOp` with a copy between the collapsed views. The collapsed shape is
+/// chosen from the source layout; if the target layout cannot be collapsed
+/// with the same reassociation, failure is returned before any IR is changed.
+LogicalResult ExpandCollapseCopyOpConverter::collapseCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  MemRefType sourceType = cast<MemRefType>(copyOp.getSource().getType());
+  MemRefType targetType = cast<MemRefType>(copyOp.getTarget().getType());
+
+  SmallVector<int64_t> collapsedShape;
+  if (failed(getCollapsedShape(sourceType, newRank, collapsedShape)))
+    return failure();
+
+  std::optional<SmallVector<ReassociationIndices>> reassociation =
+      getReassociationIndicesForCollapse(sourceType.getShape(), collapsedShape);
+  assert(reassociation && "expected reassociation to be valid for collapse");
+
+  // The target may use a different layout than the source (e.g. a strided
+  // subview); collapsing a non-contiguous group of its dims is not valid.
+  if (!memref::CollapseShapeOp::isGuaranteedCollapsible(targetType,
+                                                        *reassociation)) {
+    LLVM_DEBUG(llvm::dbgs() << "Skip rewriting " << copyOp
+                            << ", target layout is not collapsible\n");
+    return failure();
+  }
+
+  rewriter.setInsertionPoint(copyOp);
+  Value collapseShapeSrc = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getSource(), *reassociation);
+  Value collapseShapeTarget = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getTarget(), *reassociation);
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, collapseShapeSrc,
+                                              collapseShapeTarget);
+
+  return success();
+}
+
+/// Computes a rank-`newRank` shape for `memrefToCollapse` by merging one group
+/// of adjacent dims into a single dim. Candidate groups are tried from the
+/// innermost position outwards; the first one whose reassociation
+/// `memref::CollapseShapeOp::isGuaranteedCollapsible` accepts is stored in
+/// `collapsedShape`. Groups containing dynamic dims are skipped because their
+/// merged size cannot be computed statically. Returns failure if no candidate
+/// qualifies.
+LogicalResult ExpandCollapseCopyOpConverter::getCollapsedShape(
+    MemRefType memrefToCollapse, unsigned newRank,
+    SmallVector<int64_t> &collapsedShape) const {
+  ArrayRef<int64_t> shape = memrefToCollapse.getShape();
+  int64_t rank = memrefToCollapse.getRank();
+  int64_t dimsToCollapse = rank - static_cast<int64_t>(newRank);
+  assert(dimsToCollapse > 0 && "expected to reduce the rank");
+
+  // A candidate merges the `dimsToCollapse + 1` dims
+  // [firstDimToCollapse, firstDimToCollapse + dimsToCollapse] into one dim.
+  // The innermost placement starts at `rank - dimsToCollapse - 1`.
+  for (int64_t firstDimToCollapse = rank - dimsToCollapse - 1;
+       firstDimToCollapse >= 0; --firstDimToCollapse) {
+    int64_t lastDimToCollapse = firstDimToCollapse + dimsToCollapse;
+    ArrayRef<int64_t> group =
+        shape.slice(firstDimToCollapse, dimsToCollapse + 1);
+    if (llvm::any_of(group,
+                     [](int64_t dim) { return ShapedType::isDynamic(dim); }))
+      continue;
+
+    // Use 64-bit products: 32-bit arithmetic can overflow for large shapes.
+    int64_t collapsedSize = std::accumulate(
+        group.begin(), group.end(), int64_t{1}, std::multiplies<int64_t>());
+
+    SmallVector<int64_t> newShape;
+    newShape.reserve(newRank);
+    newShape.append(shape.begin(), shape.begin() + firstDimToCollapse);
+    newShape.push_back(collapsedSize);
+    newShape.append(shape.begin() + lastDimToCollapse + 1, shape.end());
+    assert(newShape.size() == newRank && "unexpected collapsed rank");
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "trying to collapse shape " << shape_to_string(shape)
+               << " to " << shape_to_string(newShape) << "\n");
+
+    // Skip (rather than dereference an empty optional) if no reassociation
+    // maps `shape` onto `newShape`.
+    std::optional<SmallVector<ReassociationIndices>> reassociation =
+        getReassociationIndicesForCollapse(shape, newShape);
+    if (!reassociation)
+      continue;
+    if (memref::CollapseShapeOp::isGuaranteedCollapsible(memrefToCollapse,
+                                                         *reassociation)) {
+      collapsedShape = std::move(newShape);
+      return success();
+    }
+  }
+
+  return failure();
+}
+
+#ifndef NDEBUG
+/// Renders `shape` as an `x`-separated string (e.g. {1, 5, 24} -> "1x5x24")
+/// for debug output. LLVM's raw_ostream is used instead of
+/// `std::ostringstream`, so no `<sstream>` include is needed.
+static inline std::string shape_to_string(ArrayRef<int64_t> shape) {
+  std::string shapeStr;
+  llvm::raw_string_ostream shapeStream(shapeStr);
+  llvm::interleave(shape, shapeStream, "x");
+  return shapeStream.str();
+}
+#endif // NDEBUG
+
+/// Adds the `memref.copy` expand/collapse pattern to `patterns`, created in
+/// the same context as `patterns` itself.
+void memref::populateExpandCollapseCopyOpsPatterns(
+    RewritePatternSet &patterns, unsigned minRank, unsigned maxRank,
+    memref::ExpandCollapseFuncCB funcCB) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ExpandCollapseCopyOpConverter>(context, minRank, maxRank,
+                                              funcCB);
+}
diff --git a/mlir/test/Transforms/expand-collapse-copy-ops.mlir b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
new file mode 100644
index 000000000000000..b3cd187424e084b
--- /dev/null
+++ b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
@@ -0,0 +1,141 @@
+// RUN: mlir-opt -test-expand-collapse-copy-ops="minRank=2 maxRank=3" %s -split-input-file | FileCheck %s
+
+// A function without memref.copy ops is left unchanged.
+// CHECK-LABEL:   func.func @empty() {
+// CHECK:           return
+// CHECK:         }
+func.func @empty() -> () {
+  return
+}
+
+// -----
+
+// A rank-1 copy is below minRank=2: source and target both get a leading
+// unit dim via memref.expand_shape.
+// CHECK-LABEL:   func.func @memref_copy_to_expand(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: memref<6xi32>) {
+// CHECK:           %[[VAL_1:.*]] = memref.alloc() : memref<6xi32>
+// CHECK:           %[[VAL_2:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK:           %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK:           memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x6xi32> to memref<1x6xi32>
+// CHECK:           return
+// CHECK:         }
+func.func @memref_copy_to_expand(%arg0: memref<6xi32>) {
+  %0 = memref.alloc() : memref<6xi32>
+  memref.copy %arg0, %0 : memref<6xi32> to memref<6xi32>
+  return
+}
+
+// -----
+
+// A rank-4 copy is above maxRank=3: the two innermost dims (24x48) are
+// collapsed into one dim on both operands.
+// CHECK-LABEL:   func.func @memref_copy_to_collapse(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME:                                       %[[VAL_1:.*]]: memref<1x5x24x48xi32>) {
+// CHECK:           %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK:           memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x5x1152xi32> to memref<1x5x1152xi32>
+// CHECK:           return
+// CHECK:         }
+func.func @memref_copy_to_collapse(%arg0: memref<1x5x24x48xi32>, %arg1: memref<1x5x24x48xi32>) {
+  memref.copy %arg0, %arg1 : memref<1x5x24x48xi32> to memref<1x5x24x48xi32>
+  return
+}
+
+// -----
+
+// The rank-1 memref.copy ops inside the loop fall below minRank=2 and are each
+// expanded to rank 2, including copies whose operands are strided subviews.
+// CHECK-LABEL:   func.func @memref_copy_collapse_expand_in_loop(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: memref<1x5x24x48xf32>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 5760 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_10:.*]] = memref.subview %[[VAL_6]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK:             %[[VAL_11:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK:             %[[VAL_12:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK:             %[[VAL_13:.*]] = memref.alloc() : memref<16xf32>
+// CHECK:             %[[VAL_14:.*]] = memref.expand_shape %[[VAL_10]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:             %[[VAL_15:.*]] = memref.expand_shape %[[VAL_13]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK:             memref.copy %[[VAL_14]], %[[VAL_15]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK:             %[[VAL_16:.*]] = memref.alloc() : memref<16xf32>
+// CHECK:             %[[VAL_17:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:             %[[VAL_18:.*]] = memref.expand_shape %[[VAL_16]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK:             memref.copy %[[VAL_17]], %[[VAL_18]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK:             %[[VAL_19:.*]] = memref.alloc() : memref<16xf32>
+// CHECK:             linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%[[VAL_13]], %[[VAL_16]] : memref<16xf32>, memref<16xf32>) outs(%[[VAL_19]] : memref<16xf32>) {
+// CHECK:             ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK:               %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK:               linalg.yield %[[VAL_23]] : f32
+// CHECK:             }
+// CHECK:             %[[VAL_24:.*]] = memref.expand_shape %[[VAL_19]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK:             %[[VAL_25:.*]] = memref.expand_shape %[[VAL_12]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:             memref.copy %[[VAL_24]], %[[VAL_25]] : memref<1x16xf32> to memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:           }
+// CHECK:           return %[[VAL_5]] : memref<1x5x24x48xf32>
+// CHECK:         }
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @memref_copy_collapse_expand_in_loop(%arg0: memref<1x5x24x48xf32>, %arg1: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+    %c5760 = arith.constant 5760 : index
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+    %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+    %collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+    %collapse_shape_1 = memref.collapse_shape %alloc [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+    scf.for %arg2 = %c0 to %c5760 step %c16 {
+      %subview = memref.subview %collapse_shape[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+      %subview_2 = memref.subview %collapse_shape_0[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+      %subview_3 = memref.subview %collapse_shape_1[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+      %alloc_4 = memref.alloc() : memref<16xf32>
+      memref.copy %subview, %alloc_4 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+      %alloc_5 = memref.alloc() : memref<16xf32>
+      memref.copy %subview_2, %alloc_5 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+      %alloc_6 = memref.alloc() : memref<16xf32>
+      linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%alloc_4, %alloc_5 : memref<16xf32>, memref<16xf32>) outs(%alloc_6 : memref<16xf32>) {
+      ^bb0(%in: f32, %in_7: f32, %out: f32):
+        %0 = arith.addf %in, %in_7 : f32
+        linalg.yield %0 : f32
+      }
+      memref.copy %alloc_6, %subview_3 : memref<16xf32> to memref<16xf32, strided<[1], offset: ?>>
+    }
+    return %alloc : memref<1x5x24x48xf32>
+  }
+}
+
+// -----
+
+// The innermost dims of the strided subviews are not contiguous (stride 48
+// vs. inner size 24), so dims 1 and 2 are collapsed instead.
+// CHECK-LABEL:   func.func @memref_copy_strided_to_collapse(
+// CHECK-SAME:                                               %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME:                                               %[[VAL_1:.*]]: memref<1x5x24x48xi32>) {
+// CHECK:           %[[VAL_2:.*]] = memref.subview %[[VAL_0]][0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>>
+// CHECK:           %[[VAL_3:.*]] = memref.subview %[[VAL_1]][0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>>
+// CHECK:           %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2], [3]] : memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>> into memref<1x120x24xi32, strided<[5760, 48, 1]>>
+// CHECK:           %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0], [1, 2], [3]] : memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>> into memref<1x120x24xi32, strided<[5760, 48, 1]>>
+// CHECK:           memref.copy %[[VAL_4]], %[[VAL_5]] : memref<1x120x24xi32, strided<[5760, 48, 1]>> to memref<1x120x24xi32, strided<[5760, 48, 1]>>
+// CHECK:           return
+// CHECK:         }
+func.func @memref_copy_strided_to_collapse(%arg0: memref<1x5x24x48xi32>, %arg1: memref<1x5x24x48xi32>) {
+  %subview = memref.subview %arg0[0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>>
+  %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>>
+  memref.copy %subview, %subview0 : memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>>
+  return
+}
+
+// -----
+
+// No pair of adjacent dims is contiguous in these strided subviews (e.g.
+// stride 48 vs. inner size 24), so the rank-4 copy is expected to remain.
+// CHECK-LABEL:   func.func @memref_copy_strided_cant_collapse(
+// CHECK-SAME:                                                 %[[VAL_0:.*]]: memref<2x6x24x48xi32>,
+// CHECK-SAME:                                                 %[[VAL_1:.*]]: memref<2x6x24x48xi32>) {
+// CHECK:           %[[VAL_2:.*]] = memref.subview %[[VAL_0]][0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>>
+// CHECK:           %[[VAL_3:.*]] = memref.subview %[[VAL_1]][0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>>
+// CHECK:           memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>>
+// CHECK:           return
+// CHECK:         }
+func.func @memref_copy_strided_cant_collapse(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) {
+  %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  memref.copy %subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  return
+}
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
index 0498de3eb93178b..d665620b42a57b8 100644
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRMemRefTestPasses
   TestComposeSubView.cpp
   TestEmulateNarrowType.cpp
   TestMultiBuffer.cpp
+  TestExpandCollapseCopyOps.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp b/mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp
new file mode 100644
index 000000000000000..446a70b538cdc9d
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp
@@ -0,0 +1,66 @@
+//===- TestExpandCollapseCopyOps.cpp - Test expand collapse copies --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the expand collapse copies patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test-only pass that applies the `memref.copy` expand/collapse patterns to
+/// the operation it runs on, with the rank bounds taken from the `minRank`
+/// and `maxRank` pass options.
+struct TestExpandCollapseCopyOpsPass
+    : public PassWrapper<TestExpandCollapseCopyOpsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandCollapseCopyOpsPass)
+
+  TestExpandCollapseCopyOpsPass() = default;
+  TestExpandCollapseCopyOpsPass(const TestExpandCollapseCopyOpsPass &other)
+      : PassWrapper(other) {}
+
+  StringRef getArgument() const final {
+    return "test-expand-collapse-copy-ops";
+  }
+
+  StringRef getDescription() const final {
+    return "Test expand collapse copies";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The patterns create memref.expand_shape / memref.collapse_shape ops.
+    registry.insert<memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    memref::populateExpandCollapseCopyOpsPatterns(patterns, minRank, maxRank);
+    // The driver's result is deliberately ignored by this test pass.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+
+  Option<unsigned> minRank{
+      *this, "minRank",
+      llvm::cl::desc("Minimum rank allowed for a MemRef Copy."),
+      llvm::cl::init(2)};
+  Option<unsigned> maxRank{
+      *this, "maxRank",
+      llvm::cl::desc("Maximum rank allowed for a MemRef Copy."),
+      llvm::cl::init(3)};
+};
+} // namespace
+
+namespace mlir::test {
+/// Registers the test pass so that `mlir-opt` can run it as
+/// `-test-expand-collapse-copy-ops`.
+void registerTestExpandCollapseCopyOps() {
+  PassRegistration<TestExpandCollapseCopyOpsPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index b7647d7de78a10e..ec2ba8838fd68d2 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -92,6 +92,7 @@ void registerTestEmulateNarrowTypePass();
 void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
+void registerTestExpandCollapseCopyOps();
 void registerTestMultiBuffering();
 void registerTestIntRangeInference();
 void registerTestIRVisitorsPass();
@@ -214,6 +215,7 @@ void registerTestPasses() {
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();
+  mlir::test::registerTestExpandCollapseCopyOps();
   mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestIntRangeInference();
   mlir::test::registerTestIRVisitorsPass();



More information about the cfe-commits mailing list