[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 3 07:32:34 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Petr Kurapov (kurapov-peter)

Changes:

…position implementation

* Add a decomposition pass that handles complex aggregate ops (e.g., softmax), replacing them with a sequence of non-aggregate linalg named ops; see the IR sketch after this list. The softmax implementation follows the lowering semantics of popular frameworks such as [PyTorch](https://github.com/intel/graph-compiler/issues/10#issuecomment-2161145033), [TensorFlow](https://github.com/intel/graph-compiler/issues/10#issuecomment-2162722181), and [others](https://github.com/intel/graph-compiler/issues/10#issuecomment-2161153179).
* Make the `AggregatedOpInterface` return a `DecompositionResult`, similar to the tiling interface. This conveys the full decomposition sequence to callers (useful, e.g., for the transform dialect; see below).
* Rework `DecomposeInterfaceOp` to return variadic results and to use the new decomposition; a usage sketch appears after the list of affected files below. This removes the code duplication between the generalization pass and the decomposition implementation: aggregate ops are now decomposed first and then generalized.
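
A minimal IR sketch of the intended effect, assuming a 3-D `f32` softmax reduced along dimension 1 (the shapes, SSA names, and fill constants below are illustrative and not taken from the patch or its tests):

```mlir
// Before: the aggregate op.
func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %out = tensor.empty() : tensor<2x16x32xf32>
  %res = linalg.softmax dimension(1)
           ins(%arg0 : tensor<2x16x32xf32>)
           outs(%out : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  return %res : tensor<2x16x32xf32>
}

// After decomposition: roughly the sequence of named ops the pass produces
// (fill + reduce for the max, broadcast, sub, exp, fill + reduce for the sum,
// broadcast, div).
func.func @softmax_decomposed(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %out = tensor.empty() : tensor<2x16x32xf32>
  %red = tensor.empty() : tensor<2x32xf32>

  // Step 1: max along dim 1 (the neutral fill constant depends on the element type).
  %neutral_max = arith.constant 0xFF800000 : f32
  %max_init = linalg.fill ins(%neutral_max : f32) outs(%red : tensor<2x32xf32>) -> tensor<2x32xf32>
  %max = linalg.reduce { arith.maxnumf }
           ins(%arg0 : tensor<2x16x32xf32>)
           outs(%max_init : tensor<2x32xf32>) dimensions = [1]

  // Step 2: numerator = exp(input - broadcast(max)).
  %max_bcast = linalg.broadcast
                 ins(%max : tensor<2x32xf32>)
                 outs(%out : tensor<2x16x32xf32>) dimensions = [1]
  %sub = linalg.sub ins(%arg0, %max_bcast : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
                    outs(%out : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  %exp = linalg.exp ins(%sub : tensor<2x16x32xf32>)
                    outs(%out : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>

  // Step 3: denominator = sum of the numerator along dim 1.
  %zero = arith.constant 0.0 : f32
  %sum_init = linalg.fill ins(%zero : f32) outs(%red : tensor<2x32xf32>) -> tensor<2x32xf32>
  %sum = linalg.reduce { arith.addf }
           ins(%exp : tensor<2x16x32xf32>)
           outs(%sum_init : tensor<2x32xf32>) dimensions = [1]

  // Step 4: softmax = numerator / broadcast(denominator).
  %sum_bcast = linalg.broadcast
                 ins(%sum : tensor<2x32xf32>)
                 outs(%out : tensor<2x16x32xf32>) dimensions = [1]
  %div = linalg.div ins(%exp, %sum_bcast : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
                    outs(%out : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
  return %div : tensor<2x16x32xf32>
}
```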

---

Patch is 35.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97582.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (+10) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+5) 
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+7-11) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+82-129) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+15-19) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp (+62) 
- (added) mlir/test/Dialect/Linalg/decompose-named-ops.mlir (+107) 
- (modified) mlir/test/Dialect/Linalg/transform-op-decompose.mlir (+44-10) 
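
For the reworked `DecomposeInterfaceOp`, a hypothetical transform-dialect usage sketch (the handle names, the matched op, and the single result handle are illustrative only; with variadic results, the exact number of returned handles depends on how the patch maps decomposed ops to results):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    // Match the aggregate op to decompose.
    %softmax = transform.structured.match ops{["linalg.softmax"]} in %root
      : (!transform.any_op) -> !transform.any_op
    // Decompose it; the returned handle(s) point to the named ops produced by
    // the decomposition and can be targeted by further transformations.
    %decomposed = transform.structured.decompose_interface %softmax
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```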


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014..3858075fae137 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -30,6 +30,16 @@ class IteratorTypeAttr;
 class LinalgOp;
 class GenericOp;
 
+/// Container for result values of decomposition.
+/// - `decomposedOps` contains operations created by the decomposition that are
+/// returned to the caller for further transformations.
+/// - `decomposedValues` contains the values corresponding to the result of the
+/// aggregate operation.
+struct DecompositionResult {
+  SmallVector<Operation *> decomposedOps;
+  SmallVector<Value> decomposedValues;
+};
+
 namespace detail {
 /// Implementation of the method that check if given operands
 /// can be dropped, i.e. the remaining operands can compute the loop
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..9b1ab20552628 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -862,7 +862,7 @@ def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
           In other words, the returned vector can be used directly with
           `RewriterBase::replaceOp(this, returnedValues)`.
         }],
-        /*retType=*/"FailureOr<SmallVector<Value>>",
+        /*retType=*/"FailureOr<DecompositionResult>",
         /*methodName=*/"decomposeOperation",
         /*args=*/(ins
             "OpBuilder &":$b),
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0621a9f33ba1e..3031126e582f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -94,6 +94,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgDecomposeAggregateNamedOpsPass : Pass<"linalg-decompose-named-ops"> {
+  let summary = "Decompose complex named ops (e.g., Softmax) into a sequence of linalg named ops";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..2e8e294aa2e4c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1317,25 +1317,21 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
 def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    TODO
+    Decomposes high-level named ops into a sequence of non-aggregate named ops
+    via `AggregatedOpInterface`.
+
+    The operation ignores non-decomposable ops. The return handles point to
+    a sequence of named ops produced by the decomposition.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs TransformHandleTypeInterface:$transformed);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
   let assemblyFormat =
       "$target attr-dict `:` functional-type(operands, results)";
 
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 //===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..b0eeb274f71bb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1546,6 +1546,11 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 /// linalg.generic ops.
 void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns to decompose high-level aggregate named
+/// ops (e.g., softmax) into a sequence of simpler linalg named ops, defining
+/// the operation semantics.
+void populateDecomposeAggregateNamedOpsPatterns(RewritePatternSet &patterns);
+
 /// Linalg decompose convolutions patterns
 
 /// Populates patterns to decompose high-D convolution ops into low-D ones.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..383f285969ad7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2564,116 +2564,41 @@ void SoftmaxOp::getEffects(
 
 // Helper functions for softmax decomposition.
 // @{
-
-// Helper function to produce the iterator types (reduction or parallel) and
-// affine maps for the iterators used in the decomposition of softmax.
-// This method creates:
-// If allParallel == true:
-// - iterator type: {parallel, ..., parallel}
-// - affine maps:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-//    where N == inputRank.
-//
-// If allParallel == false:
-// - iterator type at dim(i) == parallel for i != \p dim and
-//   dim(dim) == reduction.
-// - affine map:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-//    where N == inputRank.
-static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
-computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
-                                    int64_t dim, bool allParallel = false) {
-  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
-                                                 utils::IteratorType::parallel);
-  if (!allParallel)
-    iteratorTypes[dim] = utils::IteratorType::reduction;
-  MLIRContext *ctxt = builder.getContext();
-  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
-  SmallVector<AffineExpr, 2> affineExprs;
-  for (int i = 0; i < inputRank; i++) {
-    if (i != dim)
-      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
-  }
-  auto reductionMap =
-      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
-  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
-  return std::make_tuple(iteratorTypes, indexingMaps);
-}
-
-// Helper function to produce a linalg.generic that computes a reduction on
-// dimension \p dim with the operation type \p T.
-template <typename T>
-static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
-                    int64_t dim) {
-  auto inputType = cast<ShapedType>(input.getType());
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputRank = inputShape.size();
-  auto [iteratorTypes, indexingMaps] =
-      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
-  assert(indexingMaps.size() == 2 &&
-         "We should have two maps: 1 for the input, 1 for the output");
-  assert(indexingMaps[0].isIdentity() && "input map should be identity");
-
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<T>(loc, args[0], args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the second step of the softmax
-/// decomposition: res = exp(input - max), where \p max is the max of \p input
-/// on dimension \p dim.
-static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
-                              Value max, Value output, int64_t dim) {
-  auto inputType = cast<ShapedType>(input.getType());
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputRank = inputShape.size();
-  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
-      builder, inputRank, dim, /*allParallel=*/true);
-  assert(indexingMaps.size() == 2 && "We should have one map for each input");
-  assert(indexingMaps[0].isIdentity() && "input map should be identity");
-  // Add the affine map for the output argument.
-  indexingMaps.push_back(indexingMaps[0]);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
-      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
-        Value result = b.create<math::ExpOp>(loc, diff);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the final step of the softmax
-/// decomposition.
-/// \returns  linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
-///   yield  n / d
-/// }
-static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
-                        Value denominator, Value output, int64_t dim) {
-  auto inputType = cast<ShapedType>(numerator.getType());
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputRank = inputShape.size();
-  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
-      builder, inputRank, dim, /*allParallel=*/true);
-  assert(indexingMaps.size() == 2 &&
-         "We should have one map for each input (2)");
-  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
-  // Add the affine map for the output tensor.
-  indexingMaps.push_back(indexingMaps[0]);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
-      indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  return genericOp.getResult(0);
+TypedAttr createInitValueForReduceMaxOp(Type type, OpBuilder &b) {
+  if (isa<FloatType>(type))
+    return b.getFloatAttr(
+        type, APFloat::getSmallest(cast<FloatType>(type).getFloatSemantics()));
+  if (isa<IntegerType>(type))
+    return b.getIntegerAttr(
+        type, APInt::getSignedMinValue(type.getIntOrFloatBitWidth()));
+  return {};
+}
+
+TypedAttr createInitValueForReduceSumOp(Type type, OpBuilder &b) {
+  if (isa<FloatType>(type))
+    return b.getFloatAttr(
+        type, APFloat::getZero(cast<FloatType>(type).getFloatSemantics()));
+  if (isa<IntegerType>(type))
+    return b.getIntegerAttr(type, APInt::getZero(type.getIntOrFloatBitWidth()));
+  return {};
+}
+
+Value createLinalgReduceMaxBody(OpBuilder b, Location loc, ValueRange args,
+                                Type elementTy) {
+  if (isa<FloatType>(elementTy))
+    return b.create<arith::MaxNumFOp>(loc, args[0], args[1]);
+  if (isa<IntegerType>(elementTy))
+    return b.create<arith::MaxSIOp>(loc, args[0], args[1]);
+  return {};
+}
+
+Value createLinalgReduceSumBody(OpBuilder &b, Location loc, ValueRange args,
+                                Type elementTy) {
+  if (isa<FloatType>(elementTy))
+    return b.create<arith::AddFOp>(loc, args[0], args[1]);
+  if (isa<IntegerType>(elementTy))
+    return b.create<arith::AddIOp>(loc, args[0], args[1]);
+  return {};
 }
 // @} End helper functions for softmax decomposition.
 
@@ -2695,7 +2620,7 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
 /// 4. Divide z and l. This gives the N-dimensional softmax.
 ///    softmax = z / l
 ///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(*this);
   Location loc = getLoc();
@@ -2706,32 +2631,60 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
   Value output = getOutput();
   dims.erase(dims.begin() + reductionDim);
+
   // Step 1: Compute max along dim.
   Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
-                                                 elementType, b, loc,
-                                                 /*useOnlyFiniteValue=*/true);
-  Value neutralForMaxFInit =
-      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
-          .result();
-  Value max =
-      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+  auto maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+  auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+  auto neutralMaxInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+  Value neutralForMaxFInit = neutralMaxInitOp.result();
+
+  auto reduceMaxOp = b.create<linalg::ReduceOp>(
+      loc, input, neutralForMaxFInit, reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        auto result =
+            createLinalgReduceMaxBody(b, nestedLoc, args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
 
   // Step 2: Subtract max from input and exponentiate.
-  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+  auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+      loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+
+  auto subOp = b.create<linalg::SubOp>(
+      loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+      ValueRange{output});
+  auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+                                       ValueRange{output});
 
   // Step 3: Compute sum along dim.
-  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
-                                       b, loc, /*useOnlyFiniteValue=*/true);
-  Value zeroInit =
-      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
-  Value denominator =
-      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+  auto sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+  auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+  auto neutralSumInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+  auto sumFilledTensor = neutralSumInitOp.result();
+  auto reduceSumOp = b.create<linalg::ReduceOp>(
+      loc, expOp.getResults(), sumFilledTensor, reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        auto result =
+            createLinalgReduceSumBody(b, nestedLoc, args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
 
   // Step 4: Compute softmax.
-  Value result =
-      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
-  return SmallVector<Value>{result};
+  auto sumBcastOutput = b.create<tensor::EmptyOp>(
+      loc, getOutputOperandType().getShape(), elementType);
+  auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
+      loc, reduceSumOp.getResult(0), sumBcastOutput,
+      reduceSumOp.getDimensionsAttr());
+  auto divOp = b.create<linalg::DivOp>(
+      loc, ValueRange{expOp.getResult(0), sumBroadcastOp.getResults().front()},
+      ValueRange{output});
+  return DecompositionResult{{neutralMaxInitOp, reduceMaxOp, maxBroadcastOp,
+                              subOp, expOp, neutralSumInitOp, reduceSumOp,
+                              sumBroadcastOp, divOp},
+                             {divOp.getResults().front()}};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..e3f0a18a5ec2c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -431,27 +431,23 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
 // Decompose the target operation if it implements the AggregatedOpInterface.
 // Push the decomposed operations (the ones that replaces the values produced by
 // \p target) in the `results`.
-DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
-  if (!decomposableOp) {
-    failed(rewriter.notifyMatchFailure(target,
-                                       "payload is not a decomposable op"));
-    return emitDefaultSilenceableFailure(target);
-  }
+DiagnosedSilenceableFailure
+transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       TransformResults &transformResults,
+                                       TransformState &state) {
+  for (auto [i, target] : llvm::enumerate(state.getPayloadOps(getTarget()))) {
+    auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+    if (!decomposableOp)
+      continue;
 
-  FailureOr<SmallVector<Value>> maybeNewResults =
-      decomposableOp.decomposeOperation(rewriter);
-  if (failed(maybeNewResults))
-    return emitDefaultSilenceableFailure(target);
+    FailureOr<DecompositionResult> maybeNewResults =
+        decomposableOp.decomposeOperation(rewriter);
+    if (failed(maybeNewResults))
+      return emitDefaultSilenceableFailure(target);
 
-  rewriter.replaceOp(decomposableOp, *maybeNewResults);
-  for (Value val : *maybeNewResults) {
-    Operation *definition = val.getDefiningOp();
-    if (definition)
-      results.push_back(definition);
+    rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
+    transformResults.set(cast<OpResult>(getResult(i)),
+                         maybeNewResults->decomposedOps);
   }
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..68582fe6cbad2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   ConvertConv2DToImg2Col.cpp
   DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
+  DecomposeAggregateNamedLinalgOps.cpp
   Detensorize.cpp
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
new file mode 100644
index 0000000000000..e8a5b96d54d34
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
@@ -0,0 +1,62 @@
+//===- DecomposeNamedLinalgOps.cpp - Patterns to break up complex ops -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/Gre...
[truncated]

``````````



https://github.com/llvm/llvm-project/pull/97582

