[Mlir-commits] [mlir] [mlir][memref]: Add expand/collapse rewrite pattern to MemRef::CopyOp (PR #67808)
Aviad Cohen
llvmlistbot at llvm.org
Sun Oct 8 00:13:12 PDT 2023
https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/67808
>From a760c42f0c8b75361f822e1efcbdae30151b2180 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviadcohen7 at gmail.com>
Date: Fri, 29 Sep 2023 15:32:18 +0300
Subject: [PATCH] [mlir][memref]: Add expand/collapse rewrite pattern to
MemRef::CopyOp
This pattern is useful to adjust the memref copy ranks.
---
.../MemRef/Transforms/ExpandCollapseCopyOps.h | 45 ++++
.../Dialect/MemRef/Transforms/CMakeLists.txt | 1 +
.../Transforms/ExpandCollapseCopyOps.cpp | 238 ++++++++++++++++++
.../Transforms/expand-collapse-copy-ops.mlir | 141 +++++++++++
mlir/test/lib/Dialect/MemRef/CMakeLists.txt | 1 +
.../MemRef/TestExpandCollapseCopyOps.cpp | 66 +++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
7 files changed, 494 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
create mode 100644 mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
create mode 100644 mlir/test/Transforms/expand-collapse-copy-ops.mlir
create mode 100644 mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
new file mode 100644
index 000000000000000..27a69ab93e42c74
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
@@ -0,0 +1,45 @@
+//===- ExpandCollapseCopyOps.h - Expand/Collapse copy ranks -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Patterns to expand or collapse the rank of MemRef copy operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+#include <functional>
+
+namespace mlir {
+class MLIRContext;
+class RewritePatternSet;
+
+namespace memref {
+
+/// Callback deciding whether a given `memref::CopyOp` may be rewritten by the
+/// expand/collapse pattern. Returning false leaves the op untouched.
+using ExpandCollapseFuncCB = std::function<bool(memref::CopyOp)>;
+
+/// Default callback for `populateExpandCollapseCopyOpsPatterns`: accepts
+/// every `memref::CopyOp`.
+inline bool expandCollapseAny([[maybe_unused]] memref::CopyOp copyOp) {
+  return true;
+}
+
+/// Adds to `patterns` a rewrite pattern that expands or collapses the rank of
+/// `memref::CopyOp`s so that it lies within the inclusive range
+/// [`minRank`, `maxRank`] (requires `minRank <= maxRank`).
+/// Copies of lower rank are expanded by prepending unit (size-1) dims; copies
+/// of higher rank are collapsed by merging adjacent dims, innermost first.
+/// `funcCB` selects which copy ops are eligible for rewriting.
+/// In some cases (e.g. when the source/target are strided in every dim) it is
+/// not possible to collapse the `memref::CopyOp`; such ops are left unchanged.
+void populateExpandCollapseCopyOpsPatterns(
+    RewritePatternSet &patterns, unsigned minRank = 1, unsigned maxRank = 1,
+    ExpandCollapseFuncCB funcCB = expandCollapseAny);
+
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index b16c281c93640ea..924feca4cad3012 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
AllocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
ComposeSubView.cpp
+ ExpandCollapseCopyOps.cpp
ExpandOps.cpp
ExpandRealloc.cpp
ExpandStridedMetadata.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
new file mode 100644
index 000000000000000..7905254e71e19fc
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
@@ -0,0 +1,238 @@
+//===- ExpandCollapseCopyOps.cpp - Expand/Collapse copy ranks -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains rewrite patterns (transformations) that expand or
+// collapse the rank of MemRef copies. This is useful for architectures that
+// limit the number of dimensions a copy operation may have.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <functional>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
+
+// Tag for the LLVM_DEBUG output emitted by this file.
+#define DEBUG_TYPE "expand-collapse-copy-ops"
+
+using namespace mlir;
+
+#ifndef NDEBUG
+// Debug-only helper that formats a shape for LLVM_DEBUG messages; it is
+// defined near the end of this file.
+static inline std::string shape_to_string(ArrayRef<int64_t> shape);
+#endif // NDEBUG
+
+namespace {
+/// ExpandCollapseCopyOpConverter is a rewrite pattern that expands or
+/// collapses the rank of a `memref::CopyOp` so that it lies within the
+/// inclusive range [`minRank`, `maxRank`]. A selective callback may be
+/// provided to choose which operations should be rewritten.
+/// Only copies whose source and target are ranked memrefs of identical shape
+/// are rewritten. In some cases (e.g. when the source/target are strided in
+/// every dim) collapsing is impossible; the pattern then fails to match and
+/// leaves the `memref::CopyOp` unchanged.
+struct ExpandCollapseCopyOpConverter : public OpRewritePattern<memref::CopyOp> {
+  // Constructors are deliberately not inherited from OpRewritePattern: they
+  // would leave `minRank`, `maxRank` and `funcCB` uninitialized.
+  ExpandCollapseCopyOpConverter(MLIRContext *context, unsigned minRank,
+                                unsigned maxRank,
+                                memref::ExpandCollapseFuncCB funcCB)
+      : OpRewritePattern<memref::CopyOp>(context, /*benefit=*/1),
+        minRank(minRank), maxRank(maxRank), funcCB(std::move(funcCB)) {
+    assert(minRank <= maxRank && "invalid ranks range");
+  }
+
+  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
+                                PatternRewriter &rewriter) const final {
+    // The rewrite below requires ranked operands with identical shapes.
+    // NOTE(review): memref.copy's operand types appear to also admit unranked
+    // or merely shape-compatible memrefs; bail out on those instead of
+    // asserting in cast<>.
+    auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+    auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+    if (!sourceType || !targetType ||
+        sourceType.getShape() != targetType.getShape()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp
+                 << ", source/target are not identically-shaped ranked "
+                    "memrefs\n");
+      return failure();
+    }
+
+    if (!funcCB(copyOp)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp << ", filtered by funcCB\n");
+      return failure();
+    }
+
+    unsigned rank = sourceType.getRank();
+    if (rank >= minRank && rank <= maxRank) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp
+                 << ", operation does not need to expand/collapse\n");
+      return failure();
+    }
+
+    if (rank > maxRank)
+      return collapseCopyOpRank(copyOp, maxRank, rewriter);
+
+    assert(rank < minRank && "rank must be below the allowed range");
+    // Expanding (prepending unit dims) cannot fail.
+    expandCopyOpRank(copyOp, minRank, rewriter);
+    return success();
+  }
+
+private:
+  /// Lower bound (inclusive) of the allowed copy rank.
+  unsigned minRank;
+  /// Upper bound (inclusive) of the allowed copy rank.
+  unsigned maxRank;
+  /// Selects which `memref::CopyOp`s may be expanded/collapsed.
+  memref::ExpandCollapseFuncCB funcCB;
+
+  /// Expands the `copyOp` source/target to `newRank` dims by prepending unit
+  /// (size-1) dims, then replaces `copyOp` with a copy between the expanded
+  /// views.
+  void expandCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                        PatternRewriter &rewriter) const;
+  /// Collapses the `copyOp` source/target to `newRank` dims, preferring to
+  /// merge the innermost dims. Returns failure if no suitable window of
+  /// adjacent dims can be collapsed for both operands.
+  LogicalResult collapseCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                                   PatternRewriter &rewriter) const;
+  /// Fills `collapsedShape` with a `newRank`-sized shape obtained by merging
+  /// adjacent dims of `memrefToCollapse`, trying the innermost dims first.
+  /// Returns failure if no window of adjacent dims can be collapsed.
+  LogicalResult getCollapsedShape(MemRefType memrefToCollapse, unsigned newRank,
+                                  SmallVector<int64_t> &collapsedShape) const;
+};
+
+} // namespace
+
+/// Expands the source and target of `copyOp` to `newRank` dims by prepending
+/// unit dims, and replaces `copyOp` with a copy between the expanded views.
+void ExpandCollapseCopyOpConverter::expandCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  MemRefType memRefType = cast<MemRefType>(copyOp.getSource().getType());
+  ArrayRef<int64_t> shape = memRefType.getShape();
+  int64_t rank = memRefType.getRank();
+  int64_t expandedRank = newRank;
+  assert(expandedRank > rank && "expand must increase the rank");
+  int64_t numNewDims = expandedRank - rank;
+
+  // New outermost dims are 1s; the remaining dims keep the original shape.
+  SmallVector<int64_t> newShape(numNewDims, 1);
+  newShape.insert(newShape.end(), shape.begin(), shape.end());
+
+#ifndef NDEBUG
+  LLVM_DEBUG(llvm::dbgs() << "Expanding shape " << shape_to_string(shape)
+                          << " to " << shape_to_string(newShape) << "\n");
+#endif // NDEBUG
+
+  // Build the reassociation directly rather than with
+  // getReassociationIndicesForCollapse(newShape, shape). That helper appears
+  // to return std::nullopt when the new static unit dims must be grouped with
+  // a dynamic outermost dim (e.g. `?` -> `1x?`), which made the previous
+  // assertion fire. NOTE(review): confirm against ReshapeOpsUtils.
+  // Here the new unit dims join the group of the original outermost dim and
+  // every other dim maps one-to-one. A rank-0 source gets an empty
+  // reassociation.
+  SmallVector<ReassociationIndices> reassociation;
+  if (rank > 0) {
+    ReassociationIndices outerGroup;
+    for (int64_t dim = 0; dim <= numNewDims; ++dim)
+      outerGroup.push_back(dim);
+    reassociation.push_back(std::move(outerGroup));
+    for (int64_t dim = numNewDims + 1; dim < expandedRank; ++dim)
+      reassociation.push_back(ReassociationIndices{dim});
+  }
+
+  rewriter.setInsertionPoint(copyOp);
+  Value expandShapeSrc = rewriter.create<memref::ExpandShapeOp>(
+      copyOp.getLoc(), newShape, copyOp.getSource(), reassociation);
+  Value expandShapeTarget = rewriter.create<memref::ExpandShapeOp>(
+      copyOp.getLoc(), newShape, copyOp.getTarget(), reassociation);
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, expandShapeSrc,
+                                              expandShapeTarget);
+}
+
+/// Collapses the source and target of `copyOp` to `newRank` dims and replaces
+/// `copyOp` with a copy between the collapsed views. Returns failure (leaving
+/// the IR untouched) when no window of dims can be collapsed for BOTH the
+/// source and the target.
+LogicalResult ExpandCollapseCopyOpConverter::collapseCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+  auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+  if (!sourceType || !targetType)
+    return failure();
+  ArrayRef<int64_t> shape = sourceType.getShape();
+
+  // Choose the dims to merge based on the source layout.
+  SmallVector<int64_t> collapsedShape;
+  if (failed(getCollapsedShape(sourceType, newRank, collapsedShape)))
+    return failure();
+
+  std::optional<SmallVector<ReassociationIndices>> reassociation =
+      getReassociationIndicesForCollapse(shape, collapsedShape);
+  if (!reassociation)
+    return failure();
+
+  // The grouping was validated against the source layout only. The target may
+  // be laid out differently (e.g. a strided subview), so it must also be
+  // collapsible with the same reassociation; otherwise the target
+  // collapse_shape would presumably fail verification.
+  if (!memref::CollapseShapeOp::isGuaranteedCollapsible(targetType,
+                                                         *reassociation)) {
+    LLVM_DEBUG(llvm::dbgs() << "Skip rewriting " << copyOp
+                            << ", target cannot be collapsed\n");
+    return failure();
+  }
+
+  rewriter.setInsertionPoint(copyOp);
+  Value collapseShapeSrc = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getSource(), *reassociation);
+  Value collapseShapeTarget = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getTarget(), *reassociation);
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, collapseShapeSrc,
+                                              collapseShapeTarget);
+  return success();
+}
+
+/// Searches for a window of `rank - newRank + 1` adjacent dims of
+/// `memrefToCollapse` that can be merged into one dim, starting with the
+/// innermost window and moving outwards. On success, `collapsedShape` holds
+/// the resulting `newRank`-sized shape.
+LogicalResult ExpandCollapseCopyOpConverter::getCollapsedShape(
+    MemRefType memrefToCollapse, unsigned newRank,
+    SmallVector<int64_t> &collapsedShape) const {
+  ArrayRef<int64_t> shape = memrefToCollapse.getShape();
+  int64_t rank = memrefToCollapse.getRank();
+  int64_t dimsToCollapse = rank - static_cast<int64_t>(newRank);
+  assert(dimsToCollapse > 0 && "collapse must decrease the rank");
+
+  for (int64_t firstDimToCollapse = rank - dimsToCollapse - 1;
+       firstDimToCollapse >= 0; --firstDimToCollapse) {
+    int64_t lastDimToCollapse = firstDimToCollapse + dimsToCollapse;
+
+    // Size of the merged dim. Use 64-bit arithmetic (the previous code
+    // accumulated into `int` via `unsigned`), and propagate dynamic sizes
+    // instead of multiplying the kDynamic sentinel.
+    int64_t collapsedSize = 1;
+    for (int64_t dim = firstDimToCollapse; dim <= lastDimToCollapse; ++dim) {
+      if (ShapedType::isDynamic(shape[dim])) {
+        collapsedSize = ShapedType::kDynamic;
+        break;
+      }
+      collapsedSize *= shape[dim];
+    }
+
+    // Dims outside the window are kept; the window becomes one dim.
+    SmallVector<int64_t> newShape(shape.begin(),
+                                  shape.begin() + firstDimToCollapse);
+    newShape.push_back(collapsedSize);
+    newShape.append(shape.begin() + lastDimToCollapse + 1, shape.end());
+    assert(newShape.size() == newRank && "unexpected collapsed rank");
+
+#ifndef NDEBUG
+    LLVM_DEBUG(llvm::dbgs()
+               << "trying to collapse shape " << shape_to_string(shape)
+               << " to " << shape_to_string(newShape) << "\n");
+#endif // NDEBUG
+
+    // Some shapes admit no reassociation (e.g. a dynamic dim merged with
+    // non-unit static dims); skip such windows instead of asserting.
+    std::optional<SmallVector<ReassociationIndices>> reassociation =
+        getReassociationIndicesForCollapse(shape, newShape);
+    if (!reassociation)
+      continue;
+
+    if (memref::CollapseShapeOp::isGuaranteedCollapsible(memrefToCollapse,
+                                                          *reassociation)) {
+      collapsedShape = std::move(newShape);
+      return success();
+    }
+  }
+
+  return failure();
+}
+
+#ifndef NDEBUG
+/// Formats `shape` as its dims joined by 'x' (e.g. "1x5x1152"); an empty shape
+/// yields an empty string. Debug-only helper for LLVM_DEBUG output.
+static inline std::string shape_to_string(ArrayRef<int64_t> shape) {
+  // Use LLVM's raw_string_ostream rather than std::ostringstream: this file
+  // never included <sstream>, and LLVM code prefers raw_ostream over std
+  // streams.
+  std::string shapeStr;
+  llvm::raw_string_ostream os(shapeStr);
+  llvm::interleave(shape, os, "x");
+  return os.str();
+}
+#endif // NDEBUG
+
+/// Registers the rank expand/collapse pattern for `memref::CopyOp`, limited
+/// to copies accepted by `funcCB`, with the allowed rank range
+/// [`minRank`, `maxRank`].
+void memref::populateExpandCollapseCopyOpsPatterns(
+    RewritePatternSet &patterns, unsigned minRank, unsigned maxRank,
+    memref::ExpandCollapseFuncCB funcCB) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ExpandCollapseCopyOpConverter>(ctx, minRank, maxRank,
+                                              std::move(funcCB));
+}
diff --git a/mlir/test/Transforms/expand-collapse-copy-ops.mlir b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
new file mode 100644
index 000000000000000..b3cd187424e084b
--- /dev/null
+++ b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
@@ -0,0 +1,141 @@
+// RUN: mlir-opt -test-expand-collapse-copy-ops="minRank=2 maxRank=3" %s -split-input-file | FileCheck %s
+
+// Copies of rank < 2 are expanded to rank 2 and copies of rank > 3 are
+// collapsed to rank 3 where the layout allows it.
+// A function without copies is left untouched.
+// CHECK-LABEL: func.func @empty() {
+// CHECK: return
+// CHECK: }
+func.func @empty() -> () {
+ return
+}
+
+// -----
+
+// A rank-1 copy is below minRank=2, so both operands are expanded with a
+// leading unit dim before copying.
+// CHECK-LABEL: func.func @memref_copy_to_expand(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<6xi32>) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<6xi32>
+// CHECK: %[[VAL_2:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x6xi32> to memref<1x6xi32>
+// CHECK: return
+// CHECK: }
+func.func @memref_copy_to_expand(%arg0: memref<6xi32>) {
+ %0 = memref.alloc() : memref<6xi32>
+ memref.copy %arg0, %0 : memref<6xi32> to memref<6xi32>
+ return
+}
+
+// -----
+
+// A contiguous rank-4 copy exceeds maxRank=3, so its two innermost dims
+// (24 x 48) are merged into one dim of 1152.
+// CHECK-LABEL: func.func @memref_copy_to_collapse(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x5x24x48xi32>) {
+// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x5x1152xi32> to memref<1x5x1152xi32>
+// CHECK: return
+// CHECK: }
+func.func @memref_copy_to_collapse(%arg0: memref<1x5x24x48xi32>, %arg1: memref<1x5x24x48xi32>) {
+ memref.copy %arg0, %arg1 : memref<1x5x24x48xi32> to memref<1x5x24x48xi32>
+ return
+}
+
+// -----
+
+// Only the rank-1 memref.copy ops inside the loop are rewritten: each one is
+// expanded to rank 2 (minRank). The pre-existing collapse_shape ops, the
+// subviews and the linalg.generic are not copies and are kept as they are.
+// CHECK-LABEL: func.func @memref_copy_collapse_expand_in_loop(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x5x24x48xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 5760 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] {
+// CHECK: %[[VAL_10:.*]] = memref.subview %[[VAL_6]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK: %[[VAL_11:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<16xf32>
+// CHECK: %[[VAL_14:.*]] = memref.expand_shape %[[VAL_10]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK: %[[VAL_15:.*]] = memref.expand_shape %[[VAL_13]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<16xf32>
+// CHECK: %[[VAL_17:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK: %[[VAL_18:.*]] = memref.expand_shape %[[VAL_16]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK: memref.copy %[[VAL_17]], %[[VAL_18]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<16xf32>
+// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%[[VAL_13]], %[[VAL_16]] : memref<16xf32>, memref<16xf32>) outs(%[[VAL_19]] : memref<16xf32>) {
+// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK: linalg.yield %[[VAL_23]] : f32
+// CHECK: }
+// CHECK: %[[VAL_24:.*]] = memref.expand_shape %[[VAL_19]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK: %[[VAL_25:.*]] = memref.expand_shape %[[VAL_12]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK: memref.copy %[[VAL_24]], %[[VAL_25]] : memref<1x16xf32> to memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK: }
+// CHECK: return %[[VAL_5]] : memref<1x5x24x48xf32>
+// CHECK: }
+#map = affine_map<(d0) -> (d0)>
+module {
+ func.func @memref_copy_collapse_expand_in_loop(%arg0: memref<1x5x24x48xf32>, %arg1: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+ %c5760 = arith.constant 5760 : index
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+ %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+ %collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+ %collapse_shape_1 = memref.collapse_shape %alloc [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+ scf.for %arg2 = %c0 to %c5760 step %c16 {
+ %subview = memref.subview %collapse_shape[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+ %subview_2 = memref.subview %collapse_shape_0[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+ %subview_3 = memref.subview %collapse_shape_1[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+ %alloc_4 = memref.alloc() : memref<16xf32>
+ memref.copy %subview, %alloc_4 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+ %alloc_5 = memref.alloc() : memref<16xf32>
+ memref.copy %subview_2, %alloc_5 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+ %alloc_6 = memref.alloc() : memref<16xf32>
+ linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%alloc_4, %alloc_5 : memref<16xf32>, memref<16xf32>) outs(%alloc_6 : memref<16xf32>) {
+ ^bb0(%in: f32, %in_7: f32, %out: f32):
+ %0 = arith.addf %in, %in_7 : f32
+ linalg.yield %0 : f32
+ }
+ memref.copy %alloc_6, %subview_3 : memref<16xf32> to memref<16xf32, strided<[1], offset: ?>>
+ }
+ return %alloc : memref<1x5x24x48xf32>
+ }
+}
+
+// -----
+
+// The subview's two innermost dims are not contiguous: dim 2 has stride 48,
+// but dim 3 spans only 24 elements. The innermost merge is therefore
+// rejected, and dims 1 and 2 (stride 1152 = 48 * 24) are merged instead.
+// CHECK-LABEL: func.func @memref_copy_strided_to_collapse(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x5x24x48xi32>) {
+// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]][0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>>
+// CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_1]][0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>>
+// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2], [3]] : memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>> into memref<1x120x24xi32, strided<[5760, 48, 1]>>
+// CHECK: %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0], [1, 2], [3]] : memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1]>> into memref<1x120x24xi32, strided<[5760, 48, 1]>>
+// CHECK: memref.copy %[[VAL_4]], %[[VAL_5]] : memref<1x120x24xi32, strided<[5760, 48, 1]>> to memref<1x120x24xi32, strided<[5760, 48, 1]>>
+// CHECK: return
+// CHECK: }
+func.func @memref_copy_strided_to_collapse(%arg0: memref<1x5x24x48xi32>, %arg1: memref<1x5x24x48xi32>) {
+ %subview = memref.subview %arg0[0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>>
+ %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 5, 24, 24] [1, 1, 1, 1] : memref<1x5x24x48xi32> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>>
+ memref.copy %subview, %subview0 : memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>> to memref<1x5x24x24xi32, strided<[5760, 1152, 48, 1], offset: 0>>
+ return
+}
+
+// -----
+
+// No adjacent pair of dims in this subview is contiguous (48 != 1 * 24,
+// 1152 != 48 * 12, 6912 != 1152 * 3). The rank-4 copy therefore cannot be
+// collapsed to maxRank=3 and is left unchanged.
+// CHECK-LABEL: func.func @memref_copy_strided_cant_collapse(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<2x6x24x48xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x6x24x48xi32>) {
+// CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]][0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>>
+// CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_1]][0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>>
+// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>>
+// CHECK: return
+// CHECK: }
+func.func @memref_copy_strided_cant_collapse(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) {
+ %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ memref.copy %subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ return
+}
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
index 0498de3eb93178b..d665620b42a57b8 100644
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRMemRefTestPasses
TestComposeSubView.cpp
TestEmulateNarrowType.cpp
TestMultiBuffer.cpp
+ TestExpandCollapseCopyOps.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp b/mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp
new file mode 100644
index 000000000000000..446a70b538cdc9d
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp
@@ -0,0 +1,66 @@
+//===- TestExpandCollapseCopyOps.cpp - Test expand collapse copies --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the patterns that expand/collapse the
+// rank of MemRef copies.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that applies the MemRef copy expand/collapse pattern greedily to
+/// the current operation, with the rank range taken from the pass options.
+struct TestExpandCollapseCopyOpsPass
+    : public PassWrapper<TestExpandCollapseCopyOpsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandCollapseCopyOpsPass)
+
+  TestExpandCollapseCopyOpsPass() = default;
+  TestExpandCollapseCopyOpsPass(const TestExpandCollapseCopyOpsPass &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const final {
+    return "test-expand-collapse-copy-ops";
+  }
+  StringRef getDescription() const final {
+    return "Test expand collapse copies";
+  }
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override;
+
+  Option<unsigned> minRank{
+      *this, "minRank",
+      llvm::cl::desc("Minimum rank allowed for a MemRef Copy."),
+      llvm::cl::init(2)};
+  Option<unsigned> maxRank{
+      *this, "maxRank",
+      llvm::cl::desc("Maximum rank allowed for a MemRef Copy."),
+      llvm::cl::init(3)};
+};
+
+// The pattern creates memref.expand_shape / memref.collapse_shape ops.
+void TestExpandCollapseCopyOpsPass::getDependentDialects(
+    DialectRegistry &registry) const {
+  registry.insert<memref::MemRefDialect>();
+}
+
+void TestExpandCollapseCopyOpsPass::runOnOperation() {
+  unsigned lowerBound = minRank;
+  unsigned upperBound = maxRank;
+  // The pattern asserts `minRank <= maxRank`; report an invalid option
+  // combination as a pass failure instead of crashing.
+  if (lowerBound > upperBound) {
+    getOperation()->emitError()
+        << "test-expand-collapse-copy-ops: minRank (" << lowerBound
+        << ") must not exceed maxRank (" << upperBound << ")";
+    return signalPassFailure();
+  }
+
+  RewritePatternSet patterns(&getContext());
+  memref::populateExpandCollapseCopyOpsPatterns(patterns, lowerBound,
+                                                upperBound);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+} // namespace
+
+namespace mlir::test {
+/// Makes the `test-expand-collapse-copy-ops` pass available by registering it
+/// with the pass registry; invoked from mlir-opt's test pass registration.
+void registerTestExpandCollapseCopyOps() {
+  PassRegistration<TestExpandCollapseCopyOpsPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index b7647d7de78a10e..ec2ba8838fd68d2 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -92,6 +92,7 @@ void registerTestEmulateNarrowTypePass();
void registerTestExpandMathPass();
void registerTestFooAnalysisPass();
void registerTestComposeSubView();
+void registerTestExpandCollapseCopyOps();
void registerTestMultiBuffering();
void registerTestIntRangeInference();
void registerTestIRVisitorsPass();
@@ -214,6 +215,7 @@ void registerTestPasses() {
mlir::test::registerTestExpandMathPass();
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();
+ mlir::test::registerTestExpandCollapseCopyOps();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIntRangeInference();
mlir::test::registerTestIRVisitorsPass();
More information about the Mlir-commits
mailing list