[Mlir-commits] [mlir] b6f4dd9 - [mlir][transform] Implement `FlattenElementwiseLinalgOp` transform op (#81431)

llvmlistbot at llvm.org
Wed Feb 28 09:19:10 PST 2024


Author: srcarroll
Date: 2024-02-28T11:19:06-06:00
New Revision: b6f4dd9ee8cd4aa2c41622d75790df9341d9e043

URL: https://github.com/llvm/llvm-project/commit/b6f4dd9ee8cd4aa2c41622d75790df9341d9e043
DIFF: https://github.com/llvm/llvm-project/commit/b6f4dd9ee8cd4aa2c41622d75790df9341d9e043.diff

LOG: [mlir][transform] Implement `FlattenElementwiseLinalgOp` transform op (#81431)

This implements a `transform.structured.flatten_elementwise` transform op
that flattens the iteration space and (where applicable) the operands/results
of elementwise linalg ops to a single dimension.
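
As an example of how the new op is driven (a minimal sketch taken directly
from the test cases added below; the payload is assumed to contain elementwise
linalg ops such as a linalg.fill on a memref<32x7xf32>), a named transform
sequence matches LinalgOp instances and flattens them to a single dimension:

    module attributes {transform.with_named_sequence} {
      transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
        // Match all ops implementing LinalgOp in the payload.
        %0 = transform.structured.match interface{LinalgOp} in %arg1
          : (!transform.any_op) -> !transform.any_op
        // Flatten the iteration space (and operands) to a single dimension.
        %flattened = transform.structured.flatten_elementwise %0
          : (!transform.any_op) -> !transform.any_op
        transform.yield
      }
    }

On the fill example in the new test file, this rewrites the op to operate on a
memref.collapse_shape'd memref<224xf32> and returns a handle to the flattened op.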

Added: 
    mlir/test/Dialect/Linalg/flatten-elementwise.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 309573a562872f..53ed31877c6f24 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2295,6 +2295,49 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// FlattenElementwiseLinalgOp
+//===----------------------------------------------------------------------===//
+
+def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
+    "structured.flatten_elementwise",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Flattens the iteration space and (applicable) operands of elementwise
+    linalg ops to a single dimension.
+
+    Returns one handle:
+    - Flattened linalg operation.
+
+    #### Return modes:
+
+    Returns a definite failure if target is not isolated from above.
+    Returns a silenceable failure if the pattern application failed.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Transpose Conv2D
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a848d12fbbb50e..65cf19e7a4fcd6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1074,6 +1074,11 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
 bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                               ArrayRef<ReassociationIndices> dimSequences);
 
+struct CollapseResult {
+  SmallVector<Value> results;
+  LinalgOp collapsedOp;
+};
+
 /// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
 /// to calling this method is that for each list in `foldedIterationDim`, the
 /// sequence of dimensions is contiguous in domains of all `indexing_maps` of
@@ -1081,9 +1086,8 @@ bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
 /// When valid, the method also collapses the operands of the op. Returns
 /// replacement values of the results of the original `linalgOp` by inserting
 /// reshapes to get back values of compatible types.
-template <typename LinalgType>
-FailureOr<SmallVector<Value>>
-collapseOpIterationDims(LinalgType op,
+FailureOr<CollapseResult>
+collapseOpIterationDims(LinalgOp op,
                         ArrayRef<ReassociationIndices> foldedIterationDims,
                         RewriterBase &rewriter);
 

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 299965bcfc3ab3..ef9cd5561665fc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3244,6 +3244,31 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// FlattenElementwiseLinalgOp.
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  if (target.getNumLoops() <= 1)
+    return DiagnosedSilenceableFailure::success();
+  ReassociationIndices reassociation(target.getNumLoops());
+  std::iota(reassociation.begin(), reassociation.end(), 0);
+  auto maybeFlattened =
+      (isElementwise(target))
+          ? collapseOpIterationDims(target, reassociation, rewriter)
+          : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
+                target, "only elementwise flattening is supported"));
+  if (failed(maybeFlattened))
+    return emitDefaultSilenceableFailure(target);
+  results.push_back(maybeFlattened->collapsedOp);
+  rewriter.replaceOp(target, maybeFlattened->results);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeConv2DOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4977940cfbd797..4797bfb2267d7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1446,24 +1446,20 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
   }
 }
 
-template <typename LinalgType>
-Operation *createCollapsedOp(LinalgType op,
-                             const CollapsingInfo &collapsingInfo,
-                             RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to create");
+void collapseOperandsAndResults(LinalgOp op,
+                                const CollapsingInfo &collapsingInfo,
+                                RewriterBase &rewriter,
+                                SmallVectorImpl<Value> &inputOperands,
+                                SmallVectorImpl<Value> &outputOperands,
+                                SmallVectorImpl<Type> &resultTypes) {
   Location loc = op->getLoc();
-
-  // Get the input operands.
-  SmallVector<Value> inputOperands =
+  inputOperands =
       llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
         return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                      rewriter);
       });
 
   // Get the output operands and result types.
-  SmallVector<Type> resultTypes;
-  SmallVector<Value> outputOperands;
   resultTypes.reserve(op.getNumDpsInits());
   outputOperands.reserve(op.getNumDpsInits());
   for (OpOperand &output : op.getDpsInitsMutable()) {
@@ -1475,41 +1471,69 @@ Operation *createCollapsedOp(LinalgType op,
     if (!op.hasPureBufferSemantics())
       resultTypes.push_back(newOutput.getType());
   }
+}
 
-  if (isa<linalg::CopyOp>(op)) {
-    return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
-                                           outputOperands[0]);
-  }
+/// Clone a `LinalgOp` to a collapsed version of same name
+template <typename OpTy>
+OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
+                        const CollapsingInfo &collapsingInfo) {
+  return nullptr;
+}
 
-  // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes =
-      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
+/// Collapse any `LinalgOp` that does not require any specialization such as
+/// indexing_maps, iterator_types, etc.
+template <>
+LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
+                                      const CollapsingInfo &collapsingInfo) {
+  SmallVector<Value> inputOperands, outputOperands;
+  SmallVector<Type> resultTypes;
+  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
+                             outputOperands, resultTypes);
+  return cast<LinalgOp>(clone(
+      rewriter, origOp, resultTypes,
+      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands))));
+}
 
-  // Get the indexing maps.
-  auto indexingMaps =
-      llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
+/// Collapse a `GenericOp`
+template <>
+GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
+                                        GenericOp origOp,
+                                        const CollapsingInfo &collapsingInfo) {
+  SmallVector<Value> inputOperands, outputOperands;
+  SmallVector<Type> resultTypes;
+  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
+                             outputOperands, resultTypes);
+  SmallVector<AffineMap> indexingMaps(
+      llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
         return getCollapsedOpIndexingMap(map, collapsingInfo);
-      });
+      }));
+
+  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
+      origOp.getIteratorTypesArray(), collapsingInfo));
 
-  Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+  GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
+      origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
       iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &op->getRegion(0).front();
+  Block *origOpBlock = &origOp->getRegion(0).front();
   Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
   rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                        collapsedOpBlock->getArguments());
-
   return collapsedOp;
 }
 
+LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
+                           RewriterBase &rewriter) {
+  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
+    return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
+  } else {
+    return cloneToCollapsedOp(rewriter, op, collapsingInfo);
+  }
+}
+
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-template <typename LinalgType>
-FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
-    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
+FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
+    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to collapse");
-
   // Bail on trivial no-op cases.
   if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
@@ -1538,8 +1562,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
   }
 
   // Bail on non-canonical ranges.
-  SmallVector<Range> loopRanges =
-      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
+  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
       return cast<IntegerAttr>(attr).getInt() == value;
@@ -1555,8 +1578,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
         op, "expected all loop ranges to have zero start and unit stride");
   }
 
-  LinalgType collapsedOp = cast<LinalgType>(
-      createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
+  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
   if (collapsedOp.hasIndexSemantics()) {
@@ -1597,7 +1619,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
       results.push_back(collapsedOpResult);
     }
   }
-  return results;
+  return CollapseResult{results, collapsedOp};
 }
 
 namespace {
@@ -1629,15 +1651,14 @@ class FoldWithProducerReshapeOpByCollapsing
         continue;
       }
 
-      std::optional<SmallVector<Value>> replacements =
-          collapseOpIterationDims<linalg::GenericOp>(
-              genericOp, collapsableIterationDims, rewriter);
-      if (!replacements) {
+      std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
+          genericOp, collapsableIterationDims, rewriter);
+      if (!collapseResult) {
         return rewriter.notifyMatchFailure(
             genericOp, "failed to do the fusion by collapsing transformation");
       }
 
-      rewriter.replaceOp(genericOp, *replacements);
+      rewriter.replaceOp(genericOp, collapseResult->results);
       return success();
     }
     return failure();
@@ -1671,13 +1692,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
           op, "specified dimensions cannot be collapsed");
     }
 
-    std::optional<SmallVector<Value>> replacements =
-        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
-                                            rewriter);
-    if (!replacements) {
+    std::optional<CollapseResult> collapseResult =
+        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
+    if (!collapseResult) {
       return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
     }
-    rewriter.replaceOp(op, *replacements);
+    rewriter.replaceOp(op, collapseResult->results);
     return success();
   }
 

diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
new file mode 100644
index 00000000000000..858c133dd536ca
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -0,0 +1,99 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @fill(
+// CHECK-SAME:                  %[[ARG0:.*]]: f32,
+// CHECK-SAME:                  %[[ARG1:.*]]: memref<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>)
+func.func @fill(%cst: f32, %arg: memref<32x7xf32>) {
+    linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fill_tensor(
+// CHECK-SAME:                         %[[ARG0:.*]]: f32,
+// CHECK-SAME:                         %[[ARG1:.*]]: tensor<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>)
+// CHECK-NEXT:    %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]]
+func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> {
+    %0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) ->  tensor<32x7xf32>
+    return %0 :  tensor<32x7xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @map(
+// CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:                 %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
+// CHECK-NEXT:    linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
+func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
+    linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @generic
+// CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:                 %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
+// CHECK-NEXT:    linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
+// CHECK-NEXT:       ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32)
+// CHECK-NEXT:         %[[SUM:.*]] = arith.addf %[[A]], %[[B]]
+// CHECK-NEXT:         linalg.yield %[[SUM]]
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @generic( %arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
+    linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) {
+        ^bb0(%a: f32, %b: f32, %c: f32):
+            %0 = arith.addf %a, %b : f32
+            linalg.yield %0 : f32
+    }
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}


        

