[Mlir-commits] [mlir] [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp (PR #68526)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 8 06:31:30 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Changes:

- [mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations
- [mlir][linalg] Enable CollapseLinalgDimensions to collapse linalg::CopyOp
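
As a rough illustration of what the two changes above enable (a hand-written sketch, not taken from the patch: the shapes and the `[2, 3]` collapse mirror the collapse-dim.mlir tests, and the function names are made up), a buffer-semantics `linalg.copy` can now be collapsed by reshaping its memref operands with `memref.collapse_shape` and re-creating the copy on the collapsed operands:

```mlir
// Before: a rank-4 copy on memrefs; suppose the control function requests
// collapsing dimensions [2, 3].
func.func @copy_memref(%src: memref<1x24x32x8xf32>, %dst: memref<1x24x32x8xf32>) {
  linalg.copy ins(%src : memref<1x24x32x8xf32>)
              outs(%dst : memref<1x24x32x8xf32>)
  return
}

// After (sketch): both operands are collapsed with memref.collapse_shape and
// the copy is rewritten on the resulting rank-3 memrefs.
func.func @copy_memref_collapsed(%src: memref<1x24x32x8xf32>, %dst: memref<1x24x32x8xf32>) {
  %0 = memref.collapse_shape %src [[0], [1], [2, 3]]
      : memref<1x24x32x8xf32> into memref<1x24x256xf32>
  %1 = memref.collapse_shape %dst [[0], [1], [2, 3]]
      : memref<1x24x32x8xf32> into memref<1x24x256xf32>
  linalg.copy ins(%0 : memref<1x24x256xf32>)
              outs(%1 : memref<1x24x256xf32>)
  return
}
```

This only applies when every memref operand is statically known to be collapsible (`memref::CollapseShapeOp::isGuaranteedCollapsible`); otherwise the pattern bails, as exercised by the strided-subview test in the diff below.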


---

Patch is 23.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68526.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+10-8) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+104-62) 
- (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+83) 
- (modified) mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 07a192f7b8606d3..0b0be116ce1c1d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1047,16 +1047,18 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
 bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                               ArrayRef<ReassociationIndices> dimSequences);
 
-/// Collapses dimensions of linalg.generic operation. A precondition to
-/// calling this method is that for each list in `foldedIterationDim`, the
+/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
+/// to calling this method is that for each list in `foldedIterationDim`, the
 /// sequence of dimensions is contiguous in domains of all `indexing_maps` of
-/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
+/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
 /// When valid, the method also collapses the operands of the op. Returns
-/// replacement values of the results of the original `genericOp` by inserting
+/// replacement values of the results of the original `linalgOp` by inserting
 /// reshapes to get back values of compatible types.
-FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
-    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
-    RewriterBase &rewriter);
+template <typename LinalgType>
+FailureOr<SmallVector<Value>>
+collapseOpIterationDims(LinalgType op,
+                        ArrayRef<ReassociationIndices> foldedIterationDims,
+                        RewriterBase &rewriter);
 
 struct LowerPackResult {
   tensor::PadOp padOp;
@@ -1507,7 +1509,7 @@ void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
 /// to return an array of `ReassociationIndices` representing dimensions that
 /// should be merged.
 using GetCollapsableDimensionsFn =
-    std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+    std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>;
 
 /// Pattern to collapse dimensions in a linalg.generic op. This will collapse
 /// tensor operands when needed and expand back the result tensors.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 069c613cc246d6a..3e5f0ec24ffde99 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1373,24 +1373,31 @@ getOperandReassociation(AffineMap indexingMap,
 }
 
 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
-static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
+static Value getCollapsedOpOperand(Location loc, LinalgOp op,
                                    OpOperand *opOperand,
                                    const CollapsingInfo &collapsingInfo,
                                    OpBuilder &builder) {
-  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
   SmallVector<ReassociationIndices> operandReassociation =
       getOperandReassociation(indexingMap, collapsingInfo);
 
-  // If the number of entries in the reassocation for the operand is same as the
-  // number of results of the indexing map, then nothing to do for this operand.
+  // If the number of entries in the reassociation for the operand is same as
+  // the number of results of the indexing map, then nothing to do for this
+  // operand.
   Value operand = opOperand->get();
   if (operandReassociation.size() == indexingMap.getNumResults())
     return operand;
 
   // Insert a reshape to collapse the dimensions.
-  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
-      loc, operand, operandReassociation);
-  return reshapeOp.getResult();
+  if (isa<MemRefType>(operand.getType())) {
+    return builder
+        .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+        .getResult();
+  } else {
+    return builder
+        .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+        .getResult();
+  }
 }
 
 /// Modify the `linalg.index` operations in the original generic op, to its
@@ -1434,27 +1441,43 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
 }
 
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
-    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+template <typename LinalgType>
+FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
+    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter) {
+  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
+                "unsupported linalg op type to collapse");
+
   // Bail on trivial no-op cases.
-  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
         return foldedDims.size() <= 1;
       }))
     return failure();
 
+  bool hasBufferSemantics = op.hasBufferSemantics();
+  if (hasBufferSemantics &&
+      !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
+        MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+        if (!memRefToCollapse)
+          return true;
+
+        return memref::CollapseShapeOp::isGuaranteedCollapsible(
+            memRefToCollapse, foldedIterationDims);
+      }))
+    return rewriter.notifyMatchFailure(op,
+                                       "memref is not guaranteed collapsible");
+
   CollapsingInfo collapsingInfo;
-  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
-                                       foldedIterationDims))) {
+  if (failed(
+          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
     return rewriter.notifyMatchFailure(
-        genericOp, "illegal to collapse specified dimensions");
+        op, "illegal to collapse specified dimensions");
   }
 
   // Bail on non-canonical ranges.
   SmallVector<Range> loopRanges =
-      cast<LinalgOp>(genericOp.getOperation())
-          .createLoopRanges(rewriter, genericOp.getLoc());
+      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
       return cast<IntegerAttr>(attr).getInt() == value;
@@ -1467,80 +1490,97 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
                opFoldIsConstantValue(range.stride, 1);
       })) {
     return rewriter.notifyMatchFailure(
-        genericOp,
-        "expected all loop ranges to have zero start and unit stride");
+        op, "expected all loop ranges to have zero start and unit stride");
   }
 
   // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes = getCollapsedOpIteratorTypes(
-      genericOp.getIteratorTypesArray(), collapsingInfo);
+  SmallVector<utils::IteratorType> iteratorTypes =
+      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
 
   // Get the indexing maps.
   auto indexingMaps = llvm::to_vector(
-      llvm::map_range(genericOp.getIndexingMapsArray(), [&](AffineMap map) {
+      llvm::map_range(op.getIndexingMapsArray(), [&](AffineMap map) {
         return getCollapsedOpIndexingMap(map, collapsingInfo);
       }));
 
-  Location loc = genericOp->getLoc();
+  Location loc = op->getLoc();
 
   // Get the input operands.
-  auto inputOperands = llvm::to_vector(llvm::map_range(
-      genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) {
-        return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
+  auto inputOperands = llvm::to_vector(
+      llvm::map_range(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
+        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                      rewriter);
       }));
 
   // Get the output operands and result types.
   SmallVector<Type> resultTypes;
   SmallVector<Value> outputOperands;
-  resultTypes.reserve(genericOp.getNumDpsInits());
-  outputOperands.reserve(genericOp.getNumDpsInits());
-  for (OpOperand &output : genericOp.getDpsInitsMutable()) {
-    Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
-                                            collapsingInfo, rewriter);
+  resultTypes.reserve(op.getNumDpsInits());
+  outputOperands.reserve(op.getNumDpsInits());
+  for (OpOperand &output : op.getDpsInitsMutable()) {
+    Value newOutput =
+        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
     outputOperands.push_back(newOutput);
-    resultTypes.push_back(newOutput.getType());
+    // If the op has "buffer semantics", then the init operands are ranked
+    // memrefs and the op has no results.
+    if (!hasBufferSemantics)
+      resultTypes.push_back(newOutput.getType());
   }
 
   // Create the generic op.
-  auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
-      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &genericOp->getRegion(0).front();
-  Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
-  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
-                       collapsedOpBlock->getArguments());
-
-  if (collapsedGenericOp.hasIndexSemantics()) {
+  Operation *collapsedOp;
+  if (isa<linalg::GenericOp>(op)) {
+    collapsedOp = rewriter.create<linalg::GenericOp>(
+        loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+        iteratorTypes,
+        [](OpBuilder &builder, Location loc, ValueRange args) {});
+    Block *origOpBlock = &op->getRegion(0).front();
+    Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
+    rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
+                         collapsedOpBlock->getArguments());
+  } else {
+    assert(isa<linalg::CopyOp>(op));
+    collapsedOp = rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
+                                                  outputOperands[0]);
+  }
+  LinalgType collapsedLinalgOp = cast<LinalgType>(collapsedOp);
+
+  if (collapsedLinalgOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(collapsedGenericOp);
+    rewriter.setInsertionPoint(collapsedLinalgOp);
     SmallVector<Value> loopBound =
         llvm::to_vector(llvm::map_range(loopRanges, [&](Range range) {
           return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
         }));
     generateCollapsedIndexingRegion(loc,
-                                    &collapsedGenericOp->getRegion(0).front(),
+                                    &collapsedLinalgOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
 
   // Insert expanding reshape for the result to get back the original result
   // type.
   SmallVector<Value> results;
-  for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
+  for (const auto &originalResult : llvm::enumerate(op->getResults())) {
     Value collapsedOpResult =
-        collapsedGenericOp->getResult(originalResult.index());
+        collapsedLinalgOp->getResult(originalResult.index());
     auto originalResultType =
         cast<ShapedType>(originalResult.value().getType());
     auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
       AffineMap indexingMap =
-          genericOp.getIndexingMapMatchingResult(originalResult.value());
+          op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
-      Value result = rewriter.create<tensor::ExpandShapeOp>(
-          loc, originalResultType, collapsedOpResult, reassociation);
-      results.push_back(result);
+      if (isa<MemRefType>(collapsedOpResult.getType())) {
+        Value result = rewriter.create<memref::ExpandShapeOp>(
+            loc, originalResultType, collapsedOpResult, reassociation);
+        results.push_back(result);
+      } else {
+        Value result = rewriter.create<tensor::ExpandShapeOp>(
+            loc, originalResultType, collapsedOpResult, reassociation);
+        results.push_back(result);
+      }
     } else {
       results.push_back(collapsedOpResult);
     }
@@ -1578,8 +1618,8 @@ class FoldWithProducerReshapeOpByCollapsing
       }
 
       std::optional<SmallVector<Value>> replacements =
-          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
-                                         rewriter);
+          collapseOpIterationDims<linalg::GenericOp>(
+              genericOp, collapsableIterationDims, rewriter);
       if (!replacements) {
         return rewriter.notifyMatchFailure(
             genericOp, "failed to do the fusion by collapsing transformation");
@@ -1596,36 +1636,36 @@ class FoldWithProducerReshapeOpByCollapsing
 };
 
 /// Pattern to collapse dimensions.
-class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+template <typename LinalgType>
+class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
 public:
   CollapseLinalgDimensions(MLIRContext *context,
                            GetCollapsableDimensionsFn collapseDimensions,
                            PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpRewritePattern<LinalgType>(context, benefit),
         controlCollapseDimension(std::move(collapseDimensions)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgType op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<ReassociationIndices> collapsableIterationDims =
-        controlCollapseDimension(genericOp);
+        controlCollapseDimension(op);
     if (collapsableIterationDims.empty())
       return failure();
 
     // Check if the specified list of dimensions to collapse is a valid list.
-    if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
+    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
                                   collapsableIterationDims)) {
       return rewriter.notifyMatchFailure(
-          genericOp, "specified dimensions cannot be collapsed");
+          op, "specified dimensions cannot be collapsed");
     }
 
     std::optional<SmallVector<Value>> replacements =
-        collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
-                                       rewriter);
+        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
+                                            rewriter);
     if (!replacements) {
-      return rewriter.notifyMatchFailure(genericOp,
-                                         "failed to collapse dimensions");
+      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
     }
-    rewriter.replaceOp(genericOp, *replacements);
+    rewriter.replaceOp(op, *replacements);
     return success();
   }
 
@@ -1856,8 +1896,10 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
 void mlir::linalg::populateCollapseDimensions(
     RewritePatternSet &patterns,
     const GetCollapsableDimensionsFn &controlCollapseDimensions) {
-  patterns.add<CollapseLinalgDimensions>(patterns.getContext(),
-                                         controlCollapseDimensions);
+  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>>(
+      patterns.getContext(), controlCollapseDimensions);
+  patterns.add<CollapseLinalgDimensions<linalg::CopyOp>>(
+      patterns.getContext(), controlCollapseDimensions);
 }
 
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 6737a6e15da5afe..547320f53387477 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -70,3 +70,86 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
 // CHECK-LABEL: func @uncollapsable(
 //       CHECK:   linalg.generic
 //  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+// -----
+
+// CHECK-LABEL:   func.func private @collapsable_memref(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
+// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
+// CHECK:           %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+// CHECK:             %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK:             linalg.yield %[[VAL_9]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_2]] : memref<1x24x32x8xf32>
+// CHECK:         }
+
+func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return %alloc : memref<1x24x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncollapsable_strided_memref(
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
+  %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x...
[truncated]

``````````
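
For ops with tensor semantics, the collapsed result is expanded back to the original type (the `tensor::ExpandShapeOp` path in the diff above). A minimal hand-written sketch of the same `[2, 3]` collapse applied to a tensor `linalg.copy`, with made-up value names:

```mlir
// Before: rank-4 copy on tensors.
%copy = linalg.copy ins(%src : tensor<1x24x32x8xf32>)
                    outs(%dst : tensor<1x24x32x8xf32>) -> tensor<1x24x32x8xf32>

// After (sketch): operands are collapsed, the copy runs on rank-3 tensors,
// and the result is expanded back to the original rank-4 type.
%c0 = tensor.collapse_shape %src [[0], [1], [2, 3]]
    : tensor<1x24x32x8xf32> into tensor<1x24x256xf32>
%c1 = tensor.collapse_shape %dst [[0], [1], [2, 3]]
    : tensor<1x24x32x8xf32> into tensor<1x24x256xf32>
%collapsed = linalg.copy ins(%c0 : tensor<1x24x256xf32>)
                         outs(%c1 : tensor<1x24x256xf32>) -> tensor<1x24x256xf32>
%expanded = tensor.expand_shape %collapsed [[0], [1], [2, 3]]
    : tensor<1x24x256xf32> into tensor<1x24x32x8xf32>
```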



https://github.com/llvm/llvm-project/pull/68526


More information about the Mlir-commits mailing list