[Mlir-commits] [mlir] [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (PR #68526)

Aviad Cohen llvmlistbot at llvm.org
Sun Oct 22 08:07:38 PDT 2023


https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/68526

>From 41c5c80b69b730d4f20120d62780a2849de99d99 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Sun, 8 Oct 2023 16:02:32 +0300
Subject: [PATCH] [mlir][linalg] Enable CollapseLinalgDimensions to collapse
 linalg::CopyOp

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  18 +-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 186 ++++++++++--------
 mlir/test/Dialect/Linalg/collapse-dim.mlir    |  37 ++++
 .../Linalg/TestLinalgElementwiseFusion.cpp    |   2 +-
 4 files changed, 150 insertions(+), 93 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3597209d7f90c25..fbe2923c710aabb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1047,16 +1047,18 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
 bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                               ArrayRef<ReassociationIndices> dimSequences);
 
-/// Collapses dimensions of linalg.generic operation. A precondition to
-/// calling this method is that for each list in `foldedIterationDim`, the
-/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
-/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
-/// When valid, the method also collapses the operands of the op. Returns
-/// replacement values of the results of the original `genericOp` by inserting
-/// reshapes to get back values of compatible types.
+/// Collapses dimensions of a linalg.generic/linalg.copy op (`LinalgType` must
+/// be `GenericOp` or `CopyOp`). A precondition to calling this method is that
+/// for each list in `foldedIterationDims`, the sequence of dimensions is
+/// contiguous in domains of all `indexing_maps` of `op`. This can be checked
+/// using the `areDimSequencesPreserved` method. When valid, the method also
+/// collapses the operands of the op. Returns replacement values of the results
+/// of the original `op` by inserting reshapes to get back compatible types.
-FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
-    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
-    RewriterBase &rewriter);
+template <typename LinalgType>
+FailureOr<SmallVector<Value>>
+collapseOpIterationDims(LinalgType op,
+                        ArrayRef<ReassociationIndices> foldedIterationDims,
+                        RewriterBase &rewriter);
 
 struct LowerPackResult {
   tensor::PadOp padOp;
@@ -1515,7 +1517,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
 /// to return an array of `ReassociationIndices` representing dimensions that
 /// should be merged.
 using GetCollapsableDimensionsFn =
-    std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+    std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;
 
-/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
+/// Pattern to collapse dims of a linalg.generic/copy op. This will collapse
 /// tensor operands when needed and expand back the result tensors.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6f4b0ff60ca97c6..35d7d86fd8f1d7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1373,16 +1373,17 @@ getOperandReassociation(AffineMap indexingMap,
 }
 
 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
-static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
+static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                    OpOperand *opOperand,
                                    const CollapsingInfo &collapsingInfo,
                                    OpBuilder &builder) {
-  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
   SmallVector<ReassociationIndices> operandReassociation =
       getOperandReassociation(indexingMap, collapsingInfo);
 
-  // If the number of entries in the reassocation for the operand is same as the
-  // number of results of the indexing map, then nothing to do for this operand.
+  // If the number of entries in the reassociation for the operand is same as
+  // the number of results of the indexing map, then nothing to do for this
+  // operand.
   Value operand = opOperand->get();
   if (operandReassociation.size() == indexingMap.getNumResults())
     return operand;
@@ -1439,20 +1440,80 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
   }
 }
 
+/// Creates the collapsed counterpart of `op` as described by `collapsingInfo`,
+/// moving the payload of a `GenericOp` into the new op. Returns the new op.
+template <typename LinalgType>
+static Operation *createCollapsedOp(LinalgType op,
+                                    const CollapsingInfo &collapsingInfo,
+                                    RewriterBase &rewriter) {
+  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
+                "unsupported linalg op type to create");
+  Location loc = op->getLoc();
+
+  // Get the input operands.
+  SmallVector<Value> inputOperands =
+      llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
+        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
+                                     rewriter);
+      });
+
+  // Get the output operands and result types.
+  SmallVector<Type> resultTypes;
+  SmallVector<Value> outputOperands;
+  resultTypes.reserve(op.getNumDpsInits());
+  outputOperands.reserve(op.getNumDpsInits());
+  for (OpOperand &output : op.getDpsInitsMutable()) {
+    Value newOutput =
+        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
+    outputOperands.push_back(newOutput);
+    // If the op has "buffer semantics", then the init operands are ranked
+    // memrefs and the op has no results.
+    if (!op.hasBufferSemantics())
+      resultTypes.push_back(newOutput.getType());
+  }
+
+  // linalg.copy is recreated from its collapsed operands alone; `if constexpr`
+  // keeps the GenericOp-only path below out of the CopyOp instantiation.
+  if constexpr (std::is_same_v<LinalgType, CopyOp>) {
+    return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
+                                           outputOperands[0]);
+  } else {
+    SmallVector<utils::IteratorType> iteratorTypes =
+        getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
+    auto indexingMaps =
+        llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
+          return getCollapsedOpIndexingMap(map, collapsingInfo);
+        });
+    Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
+        loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+        iteratorTypes, [](OpBuilder &, Location, ValueRange) {});
+    // Move the original payload into the (empty) body of the new generic op.
+    Block *origOpBlock = &op->getRegion(0).front();
+    Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
+    rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
+                         collapsedOpBlock->getArguments());
+    return collapsedOp;
+  }
+}
+
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
-    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+template <typename LinalgType>
+FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
+    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter) {
+  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
+                "unsupported linalg op type to collapse");
+
   // Bail on trivial no-op cases.
-  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
         return foldedDims.size() <= 1;
       }))
     return failure();
 
-  bool hasBufferSemantics = genericOp.hasBufferSemantics();
+  bool hasBufferSemantics = op.hasBufferSemantics();
   if (hasBufferSemantics &&
-      !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
+      !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
         MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
         if (!memRefToCollapse)
           return true;
@@ -1460,20 +1521,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
         return memref::CollapseShapeOp::isGuaranteedCollapsible(
             memRefToCollapse, foldedIterationDims);
       }))
-    return rewriter.notifyMatchFailure(genericOp,
+    return rewriter.notifyMatchFailure(op,
                                        "memref is not guaranteed collapsible");
 
   CollapsingInfo collapsingInfo;
-  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
-                                       foldedIterationDims))) {
+  if (failed(
+          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
     return rewriter.notifyMatchFailure(
-        genericOp, "illegal to collapse specified dimensions");
+        op, "illegal to collapse specified dimensions");
   }
 
   // Bail on non-canonical ranges.
   SmallVector<Range> loopRanges =
-      cast<LinalgOp>(genericOp.getOperation())
-          .createLoopRanges(rewriter, genericOp.getLoc());
+      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
       return cast<IntegerAttr>(attr).getInt() == value;
@@ -1486,78 +1546,36 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
                opFoldIsConstantValue(range.stride, 1);
       })) {
     return rewriter.notifyMatchFailure(
-        genericOp,
-        "expected all loop ranges to have zero start and unit stride");
+        op, "expected all loop ranges to have zero start and unit stride");
   }
 
-  // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
-      genericOp.getIteratorTypesArray(), collapsingInfo);
-
-  // Get the indexing maps.
-  auto indexingMaps = llvm::to_vector(
-      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
-        return getCollapsedOpIndexingMap(map, collapsingInfo);
-      }));
-
-  Location loc = genericOp->getLoc();
-
-  // Get the input operands.
-  auto inputOperands = llvm::to_vector(llvm::map_range(
-      genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
-        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
-                                     rewriter);
-      }));
-
-  // Get the output operands and result types.
-  SmallVector<Type> resultTypes;
-  SmallVector<Value> outputOperands;
-  resultTypes.reserve(genericOp.getNumDpsInits());
-  outputOperands.reserve(genericOp.getNumDpsInits());
-  for (OpOperand &output : genericOp.getDpsInitsMutable()) {
-    Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
-                                            collapsingInfo, rewriter);
-    outputOperands.push_back(newOutput);
-    // If the op has "buffer semantics", then the init operands are ranked
-    // memrefs and the op has no results.
-    if (!hasBufferSemantics)
-      resultTypes.push_back(newOutput.getType());
-  }
-
-  // Create the generic op.
-  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
-      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &genericOp->getRegion(0).front();
-  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
-  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
-                       collapsedOpBlock->getArguments());
+  LinalgType collapsedOp = cast<LinalgType>(
+      createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
 
-  if (collapsedGenericOp.hasIndexSemantics()) {
+  Location loc = op->getLoc();
+  if (collapsedOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(collapsedGenericOp);
+    rewriter.setInsertionPoint(collapsedOp);
     SmallVector<Value> loopBound =
-        llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
+        llvm::map_to_vector(loopRanges, [&](Range range) {
           return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
-        }));
-    generateCollapsedIndexingRegion(loc,
-                                    &collapsedGenericOp->getRegion(0).front(),
+        });
+    generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
 
   // Insert expanding reshape for the result to get back the original result
   // type.
   SmallVector<Value> results;
-  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
-    Value collapsedOpResult =
-        collapsedGenericOp->getResult(originalResult.index());
+  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
+    Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
     auto originalResultType =
         cast<ShapedType>(originalResult.value().getType());
     auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
       AffineMap indexingMap =
-          genericOp.getIndexingMapMatchingResult(originalResult.value());
+          op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
       if (isa<MemRefType>(collapsedOpResult.getType())) {
@@ -1606,8 +1624,8 @@ class FoldWithProducerReshapeOpByCollapsing
       }
 
       std::optional<SmallVector<Value>> replacements =
-          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
-                                         rewriter);
+          collapseOpIterationDims<linalg::GenericOp>(
+              genericOp, collapsableIterationDims, rewriter);
       if (!replacements) {
         return rewriter.notifyMatchFailure(
             genericOp, "failed to do the fusion by collapsing transformation");
@@ -1624,36 +1642,36 @@ class FoldWithProducerReshapeOpByCollapsing
 };
 
 /// Pattern to collapse dimensions.
-class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+template <typename LinalgType>
+class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
 public:
   CollapseLinalgDimensions(MLIRContext *context,
                            GetCollapsableDimensionsFn collapseDimensions,
                            PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpRewritePattern<LinalgType>(context, benefit),
         controlCollapseDimension(std::move(collapseDimensions)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgType op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<ReassociationIndices> collapsableIterationDims =
-        controlCollapseDimension(genericOp);
+        controlCollapseDimension(op);
     if (collapsableIterationDims.empty())
       return failure();
 
     // Check if the specified list of dimensions to collapse is a valid list.
-    if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
+    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                   collapsableIterationDims)) {
       return rewriter.notifyMatchFailure(
-          genericOp, "specified dimensions cannot be collapsed");
+          op, "specified dimensions cannot be collapsed");
     }
 
     std::optional<SmallVector<Value>> replacements =
-        collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
-                                       rewriter);
+        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
+                                            rewriter);
     if (!replacements) {
-      return rewriter.notifyMatchFailure(genericOp,
-                                         "failed to collapse dimensions");
+      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
     }
-    rewriter.replaceOp(genericOp, *replacements);
+    rewriter.replaceOp(op, *replacements);
     return success();
   }
 
@@ -1884,8 +1902,19 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
 void mlir::linalg::populateCollapseDimensions(
     RewritePatternSet &patterns,
     const GetCollapsableDimensionsFn &controlCollapseDimensions) {
-  patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
-                                         controlCollapseDimensions);
+  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
+               CollapseLinalgDimensions<linalg::CopyOp>>(
+      patterns.getContext(), controlCollapseDimensions);
 }
+
+// `collapseOpIterationDims` is declared in Transforms.h but defined only in
+// this file; explicitly instantiate it for the op types its static_assert
+// accepts so that callers in other translation units can link against it.
+template FailureOr<SmallVector<Value>>
+mlir::linalg::collapseOpIterationDims<GenericOp>(
+    GenericOp, ArrayRef<ReassociationIndices>, RewriterBase &);
+template FailureOr<SmallVector<Value>>
+mlir::linalg::collapseOpIterationDims<CopyOp>(
+    CopyOp, ArrayRef<ReassociationIndices>, RewriterBase &);
 
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 106154ba3a553bd..547320f53387477 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -116,3 +116,44 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
   }
   return %alloc : memref<2x6x24x48xi32>
 }
+
+// -----
+
+// Tensor linalg.copy: dims [2, 3] are expected to be collapsed twice
+// (5-D -> 4-D -> 3-D), with reshapes inserted around the collapsed copy.
+// CHECK-LABEL:   func.func @linalg_copy(
+// CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
+// CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
+// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
+// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
+// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:         }
+
+func.func @linalg_copy(
+    %arg0: tensor<1x2x3x4x5xf32, 1>, %arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> {
+  %0 = linalg.copy ins(%arg0: tensor<1x2x3x4x5xf32, 1>) outs(%arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3>
+  return %0 : tensor<1x2x3x4x5xf32, 3>
+}
+
+// -----
+
+// Memref linalg.copy: a single collapse (4-D -> 3-D) is expected; the copy has
+// no results, so there is nothing to expand back.
+// CHECK-LABEL:   func.func private @memref_linalg_copy(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>,
+// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) {
+// CHECK:           %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
+// CHECK:           linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>)
+// CHECK:           return
+// CHECK:         }
+
+func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: memref<1x24x32x8xf32, 1>) {
+  linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
+  return
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index e41481a9e51364e..7f68f4aec3a10c3 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -258,7 +258,7 @@ struct TestLinalgElementwiseFusion
       SmallVector<int64_t, 2> dims(collapseDimensions.begin(),
                                    collapseDimensions.end());
       linalg::GetCollapsableDimensionsFn collapseFn =
-          [&dims](linalg::GenericOp op) {
+          [&dims](linalg::LinalgOp op) {
             SmallVector<ReassociationIndices> reassociations;
             reassociations.emplace_back(dims);
             return reassociations;



More information about the Mlir-commits mailing list