[Mlir-commits] [mlir] [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (PR #68526)

Aviad Cohen llvmlistbot at llvm.org
Sun Oct 8 06:30:27 PDT 2023


https://github.com/AviadCo created https://github.com/llvm/llvm-project/pull/68526

- [mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations
- [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp


>From 5e59fd1725a54a3ab5ab6f8d1cb46f3bb06fb8ce Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Fri, 6 Oct 2023 15:07:37 +0300
Subject: [PATCH 1/2] [mlir][linalg] Enable CollapseLinalgDimensions to
 collapse memref based operations

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 42 ++++++++++++++---
 mlir/test/Dialect/Linalg/collapse-dim.mlir    | 46 +++++++++++++++++++
 2 files changed, 81 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 069c613cc246d6a..6f4b0ff60ca97c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1388,9 +1388,15 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
     return operand;
 
   // Insert a reshape to collapse the dimensions.
-  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
-      loc, operand, operandReassociation);
-  return reshapeOp.getResult();
+  if (isa<MemRefType>(operand.getType())) {
+    return builder
+        .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+        .getResult();
+  } else {
+    return builder
+        .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+        .getResult();
+  }
 }
 
 /// Modify the `linalg.index` operations in the original generic op, to its
@@ -1444,6 +1450,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
       }))
     return failure();
 
+  bool hasBufferSemantics = genericOp.hasBufferSemantics();
+  if (hasBufferSemantics &&
+      !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
+        MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+        if (!memRefToCollapse)
+          return true;
+
+        return memref::CollapseShapeOp::isGuaranteedCollapsible(
+            memRefToCollapse, foldedIterationDims);
+      }))
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "memref is not guaranteed collapsible");
+
   CollapsingInfo collapsingInfo;
   if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
                                        foldedIterationDims))) {
@@ -1499,7 +1518,10 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
     Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
                                             collapsingInfo, rewriter);
     outputOperands.push_back(newOutput);
-    resultTypes.push_back(newOutput.getType());
+    // If the op has "buffer semantics", then the init operands are ranked
+    // memrefs and the op has no results.
+    if (!hasBufferSemantics)
+      resultTypes.push_back(newOutput.getType());
   }
 
   // Create the generic op.
@@ -1538,9 +1560,15 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
           genericOp.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
-      Value result = rewriter.create<tensor::ExpandShapeOp>(
-          loc, originalResultType, collapsedOpResult, reassociation);
-      results.push_back(result);
+      if (isa<MemRefType>(collapsedOpResult.getType())) {
+        Value result = rewriter.create<memref::ExpandShapeOp>(
+            loc, originalResultType, collapsedOpResult, reassociation);
+        results.push_back(result);
+      } else {
+        Value result = rewriter.create<tensor::ExpandShapeOp>(
+            loc, originalResultType, collapsedOpResult, reassociation);
+        results.push_back(result);
+      }
     } else {
       results.push_back(collapsedOpResult);
     }
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 6737a6e15da5afe..106154ba3a553bd 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -70,3 +70,49 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
 // CHECK-LABEL: func @uncollapsable(
 //       CHECK:   linalg.generic
 //  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+// -----
+
+// CHECK-LABEL:   func.func private @collapsable_memref(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
+// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
+// CHECK:           %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+// CHECK:             %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK:             linalg.yield %[[VAL_9]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_2]] : memref<1x24x32x8xf32>
+// CHECK:         }
+
+func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return %alloc : memref<1x24x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncollapsable_strided_memref(
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
+  %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %0 = arith.addi %in, %in_0 : i32
+    linalg.yield %0 : i32
+  }
+  return %alloc : memref<2x6x24x48xi32>
+}

>From 17609b475479f2ac955ac92c6f8292e13d312eb8 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 8 Oct 2023 16:02:32 +0300
Subject: [PATCH 2/2] [mlir][linalg] Enable CollapseLinalgDimensions to
 collapse linalg::CopyOp

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  18 +--
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 130 ++++++++++--------
 mlir/test/Dialect/Linalg/collapse-dim.mlir    |  37 +++++
 .../Linalg/TestLinalgElementwiseFusion.cpp    |   2 +-
 4 files changed, 120 insertions(+), 67 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07a192f7b8606d3..0b0be116ce1c1d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1047,16 +1047,18 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
 bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                               ArrayRef<ReassociationIndices> dimSequences);
 
-/// Collapses dimensions of linalg.generic operation. A precondition to
-/// calling this method is that for each list in `foldedIterationDim`, the
+/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
+/// to calling this method is that for each list in `foldedIterationDims`, the
 /// sequence of dimensions is contiguous in domains of all `indexing_maps` of
-/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
+/// the `op`. This can be checked using `areDimSequencesPreserved` method.
 /// When valid, the method also collapses the operands of the op. Returns
-/// replacement values of the results of the original `genericOp` by inserting
+/// replacement values of the results of the original `op` by inserting
 /// reshapes to get back values of compatible types.
-FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
-    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
-    RewriterBase &rewriter);
+template <typename LinalgType>
+FailureOr<SmallVector<Value>>
+collapseOpIterationDims(LinalgType op,
+                        ArrayRef<ReassociationIndices> foldedIterationDims,
+                        RewriterBase &rewriter);
 
 struct LowerPackResult {
   tensor::PadOp padOp;
@@ -1507,7 +1509,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
 /// to return an array of `ReassociationIndices` representing dimensions that
 /// should be merged.
 using GetCollapsableDimensionsFn =
-    std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+    std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;
 
 /// Pattern to collapse dimensions in a linalg.generic op. This will collapse
 /// tensor operands when needed and expand back the result tensors.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6f4b0ff60ca97c6..3e5f0ec24ffde99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap,
 }
 
 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
-static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
+static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                    OpOperand *opOperand,
                                    const CollapsingInfo &collapsingInfo,
                                    OpBuilder &builder) {
-  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
   SmallVector<ReassociationIndices> operandReassociation =
       getOperandReassociation(indexingMap, collapsingInfo);
 
-  // If the number of entries in the reassocation for the operand is same as the
-  // number of results of the indexing map, then nothing to do for this operand.
+  // If the number of entries in the reassociation for the operand is same as
+  // the number of results of the indexing map, then nothing to do for this
+  // operand.
   Value operand = opOperand->get();
   if (operandReassociation.size() == indexingMap.getNumResults())
     return operand;
@@ -1440,19 +1441,23 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
 }
 
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
-    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+template <typename LinalgType>
+FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
+    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter) {
+  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
+                "unsupported linalg op type to collapse");
+
   // Bail on trivial no-op cases.
-  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
         return foldedDims.size() <= 1;
       }))
     return failure();
 
-  bool hasBufferSemantics = genericOp.hasBufferSemantics();
+  bool hasBufferSemantics = op.hasBufferSemantics();
   if (hasBufferSemantics &&
-      !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
+      !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
         MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
         if (!memRefToCollapse)
           return true;
@@ -1460,20 +1465,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
         return memref::CollapseShapeOp::isGuaranteedCollapsible(
             memRefToCollapse, foldedIterationDims);
       }))
-    return rewriter.notifyMatchFailure(genericOp,
+    return rewriter.notifyMatchFailure(op,
                                        "memref is not guaranteed collapsible");
 
   CollapsingInfo collapsingInfo;
-  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
-                                       foldedIterationDims))) {
+  if (failed(
+          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
     return rewriter.notifyMatchFailure(
-        genericOp, "illegal to collapse specified dimensions");
+        op, "illegal to collapse specified dimensions");
   }
 
   // Bail on non-canonical ranges.
   SmallVector<Range> loopRanges =
-      cast<LinalgOp>(genericOp.getOperation())
-          .createLoopRanges(rewriter, genericOp.getLoc());
+      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
       return cast<IntegerAttr>(attr).getInt() == value;
@@ -1486,37 +1490,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
                opFoldIsConstantValue(range.stride, 1);
       })) {
     return rewriter.notifyMatchFailure(
-        genericOp,
-        "expected all loop ranges to have zero start and unit stride");
+        op, "expected all loop ranges to have zero start and unit stride");
   }
 
   // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
-      genericOp.getIteratorTypesArray(), collapsingInfo);
+  SmallVector<utils::IteratorType> iteratorTypes =
+      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
 
   // Get the indexing maps.
   auto indexingMaps = llvm::to_vector(
-      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
+      llvm::map_range(op.getIndexingMapsArray(), [&](AffineMap map) {
         return getCollapsedOpIndexingMap(map, collapsingInfo);
       }));
 
-  Location loc = genericOp->getLoc();
+  Location loc = op->getLoc();
 
   // Get the input operands.
-  auto inputOperands = llvm::to_vector(llvm::map_range(
-      genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
-        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
+  auto inputOperands = llvm::to_vector(
+      llvm::map_range(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
+        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                      rewriter);
       }));
 
   // Get the output operands and result types.
   SmallVector<Type> resultTypes;
   SmallVector<Value> outputOperands;
-  resultTypes.reserve(genericOp.getNumDpsInits());
-  outputOperands.reserve(genericOp.getNumDpsInits());
-  for (OpOperand &output : genericOp.getDpsInitsMutable()) {
-    Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
-                                            collapsingInfo, rewriter);
+  resultTypes.reserve(op.getNumDpsInits());
+  outputOperands.reserve(op.getNumDpsInits());
+  for (OpOperand &output : op.getDpsInitsMutable()) {
+    Value newOutput =
+        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
     outputOperands.push_back(newOutput);
     // If the op has "buffer semantics", then the init operands are ranked
     // memrefs and the op has no results.
@@ -1525,39 +1528,48 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
   }
 
   // Create the generic op.
-  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
-      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &genericOp->getRegion(0).front();
-  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
-  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
-                       collapsedOpBlock->getArguments());
-
-  if (collapsedGenericOp.hasIndexSemantics()) {
+  Operation *collapsedOp;
+  if (isa<linalg::GenericOp>(op)) {
+    collapsedOp = rewriter.create<linalg::GenericOp>(
+        loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+        iteratorTypes,
+        [](OpBuilder &builder, Location loc, ValueRange args) {});
+    Block *origOpBlock = &op->getRegion(0).front();
+    Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
+    rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
+                         collapsedOpBlock->getArguments());
+  } else {
+    assert(isa<linalg::CopyOp>(op));
+    collapsedOp = rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
+                                                  outputOperands[0]);
+  }
+  LinalgType collapsedLinalgOp = cast<LinalgType>(collapsedOp);
+
+  if (collapsedLinalgOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(collapsedGenericOp);
+    rewriter.setInsertionPoint(collapsedLinalgOp);
     SmallVector<Value> loopBound =
         llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
           return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
         }));
     generateCollapsedIndexingRegion(loc,
-                                    &collapsedGenericOp->getRegion(0).front(),
+                                    &collapsedLinalgOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
 
   // Insert expanding reshape for the result to get back the original result
   // type.
   SmallVector<Value> results;
-  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
+  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
     Value collapsedOpResult =
-        collapsedGenericOp->getResult(originalResult.index());
+        collapsedLinalgOp->getResult(originalResult.index());
     auto originalResultType =
         cast<ShapedType>(originalResult.value().getType());
     auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
       AffineMap indexingMap =
-          genericOp.getIndexingMapMatchingResult(originalResult.value());
+          op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
       if (isa<MemRefType>(collapsedOpResult.getType())) {
@@ -1606,8 +1618,8 @@ class FoldWithProducerReshapeOpByCollapsing
       }
 
       std::optional<SmallVector<Value>> replacements =
-          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
-                                         rewriter);
+          collapseOpIterationDims<linalg::GenericOp>(
+              genericOp, collapsableIterationDims, rewriter);
       if (!replacements) {
         return rewriter.notifyMatchFailure(
             genericOp, "failed to do the fusion by collapsing transformation");
@@ -1624,36 +1636,36 @@ class FoldWithProducerReshapeOpByCollapsing
 };
 
 /// Pattern to collapse dimensions.
-class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+template <typename LinalgType>
+class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
 public:
   CollapseLinalgDimensions(MLIRContext *context,
                            GetCollapsableDimensionsFn collapseDimensions,
                            PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpRewritePattern<LinalgType>(context, benefit),
         controlCollapseDimension(std::move(collapseDimensions)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgType op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<ReassociationIndices> collapsableIterationDims =
-        controlCollapseDimension(genericOp);
+        controlCollapseDimension(op);
     if (collapsableIterationDims.empty())
       return failure();
 
     // Check if the specified list of dimensions to collapse is a valid list.
-    if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
+    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                   collapsableIterationDims)) {
       return rewriter.notifyMatchFailure(
-          genericOp, "specified dimensions cannot be collapsed");
+          op, "specified dimensions cannot be collapsed");
     }
 
     std::optional<SmallVector<Value>> replacements =
-        collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
-                                       rewriter);
+        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
+                                            rewriter);
     if (!replacements) {
-      return rewriter.notifyMatchFailure(genericOp,
-                                         "failed to collapse dimensions");
+      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
     }
-    rewriter.replaceOp(genericOp, *replacements);
+    rewriter.replaceOp(op, *replacements);
     return success();
   }
 
@@ -1884,8 +1896,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
 void mlir::linalg::populateCollapseDimensions(
     RewritePatternSet &patterns,
     const GetCollapsableDimensionsFn &controlCollapseDimensions) {
-  patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
-                                         controlCollapseDimensions);
+  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>>(
+      patterns.getContext(), controlCollapseDimensions);
+  patterns.add<CollapseLinalgDimensions<linalg::CopyOp>>(
+      patterns.getContext(), controlCollapseDimensions);
 }
 
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 106154ba3a553bd..547320f53387477 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -116,3 +116,40 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
   }
   return %alloc : memref<2x6x24x48xi32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @linalg_copy(
+// CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
+// CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
+// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
+// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
+// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:         }
+
+func.func @linalg_copy(
+    %arg0: tensor<1x2x3x4x5xf32, 1>, %arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> {
+  %0 = linalg.copy ins(%arg0: tensor<1x2x3x4x5xf32, 1>) outs(%arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3>
+  return %0 : tensor<1x2x3x4x5xf32, 3>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func private @memref_linalg_copy(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>,
+// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) {
+// CHECK:           %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
+// CHECK:           linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>)
+// CHECK:           return
+// CHECK:         }
+
+func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: memref<1x24x32x8xf32, 1>) {
+  linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
+  return
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index e41481a9e51364e..7f68f4aec3a10c3 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -258,7 +258,7 @@ struct TestLinalgElementwiseFusion
       SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
                                    collapseDimensions.end());
       linalg::GetCollapsableDimensionsFn collapseFn =
-          [&dims](linalg::GenericOp op) {
+          [&dims](linalg::LinalgOp op) {
             SmallVector<ReassociationIndices> reassociations;
             reassociations.emplace_back(dims);
             return reassociations;



More information about the Mlir-commits mailing list