[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decomposition implementation (PR #97582)

Petr Kurapov llvmlistbot at llvm.org
Fri Jul 5 03:54:37 PDT 2024


https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/97582

>From 005330bd1da279cf9567db6dc79359127e820ab6 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 21 Jun 2024 11:53:36 +0000
Subject: [PATCH 1/9] [MLIR][Linalg] Add aggregate ops decomposition pass and
 softmax decomposition implementation

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  10 +
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |   2 +-
 mlir/include/mlir/Dialect/Linalg/Passes.td    |   5 +
 .../Linalg/TransformOps/LinalgTransformOps.td |  18 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  15 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 211 +++++++-----------
 .../TransformOps/LinalgTransformOps.cpp       |  34 ++-
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../DecomposeAggregateNamedLinalgOps.cpp      |  62 +++++
 .../Dialect/Linalg/decompose-named-ops.mlir   | 107 +++++++++
 .../Linalg/transform-op-decompose.mlir        |  54 ++++-
 11 files changed, 345 insertions(+), 174 deletions(-)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
 create mode 100644 mlir/test/Dialect/Linalg/decompose-named-ops.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014..3858075fae137 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -30,6 +30,16 @@ class IteratorTypeAttr;
 class LinalgOp;
 class GenericOp;
 
+/// Result of decomposing an aggregate op through `AggregatedOpInterface`.
+/// - `decomposedOps` holds the operations created by the decomposition; they
+///   are returned to the caller for further transformations.
+/// - `decomposedValues` holds the values that replace the results of the
+///   aggregate operation, e.g. via `RewriterBase::replaceOp`.
+struct DecompositionResult {
+  SmallVector<Operation *> decomposedOps;
+  SmallVector<Value> decomposedValues;
+};
+
 namespace detail {
 /// Implementation of the method that check if given operands
 /// can be dropped, i.e. the remaining operands can compute the loop
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..9b1ab20552628 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -862,7 +862,7 @@ def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
-          In other words, the returned vector can be used directly with
-          `RewriterBase::replaceOp(this, returnedValues)`.
+          In other words, the returned `decomposedValues` can be used directly
+          with `RewriterBase::replaceOp(this, result.decomposedValues)`.
         }],
-        /*retType=*/"FailureOr<SmallVector<Value>>",
+        /*retType=*/"FailureOr<DecompositionResult>",
         /*methodName=*/"decomposeOperation",
         /*args=*/(ins
             "OpBuilder &":$b),
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0..98c23b97534f8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgDecomposeAggregateNamedOpsPass : Pass<"linalg-decompose-named-ops"> {
+  let summary = "Decompose complex named ops (e.g., Softmax) into a sequence of linalg named ops";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 866275cedf68b..c059196b807a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1317,25 +1317,21 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
 def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    TODO
+    Decomposes high-level named ops into a sequence of non-aggregate named ops
+    via `AggregatedOpInterface`.
+
+    The operation ignores non-decomposable ops. The i-th result handle points
+    to the named ops produced by decomposing the i-th payload op of `target`.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs TransformHandleTypeInterface:$transformed);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
   let assemblyFormat =
       "$target attr-dict `:` functional-type(operands, results)";
 
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 //===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2a58d02d7b704..2c9e201ccbd85 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1589,6 +1589,11 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 void populateLinalgGenericOpsSpecializationPatterns(
     RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns to decompose high-level aggregate named
+/// ops (e.g., softmax) into a sequence of simpler linalg named ops that
+/// together implement the original op's semantics.
+void populateDecomposeAggregateNamedOpsPatterns(RewritePatternSet &patterns);
+
 /// Linalg decompose convolutions patterns
 
 /// Populates patterns to decompose high-D convolution ops into low-D ones.
@@ -1736,10 +1741,11 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
 /// Adds patterns that reduce the rank of named contraction ops that have
-/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
-/// `<corresponding linalg named op>`, `expand_shape` (if on tensors).  For example a
-/// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
-/// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`.
+/// unit dimensions in the operand(s) by converting to a sequence of
+/// `collapse_shape`, `<corresponding linalg named op>`, `expand_shape` (if on
+/// tensors). For example, a `linalg.batch_matmul` with unit batch size will
+/// convert to `linalg.matmul` and a `linalg.matvec` with a unit spatial dim in
+/// lhs will convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
 } // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..383f285969ad7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2564,116 +2564,51 @@ void SoftmaxOp::getEffects(
 
 // Helper functions for softmax decomposition.
 // @{
-
-// Helper function to produce the iterator types (reduction or parallel) and
-// affine maps for the iterators used in the decomposition of softmax.
-// This method creates:
-// If allParallel == true:
-// - iterator type: {parallel, ..., parallel}
-// - affine maps:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-//    where N == inputRank.
-//
-// If allParallel == false:
-// - iterator type at dim(i) == parallel for i != \p dim and
-//   dim(dim) == reduction.
-// - affine map:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-//    where N == inputRank.
-static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
-computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
-                                    int64_t dim, bool allParallel = false) {
-  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
-                                                 utils::IteratorType::parallel);
-  if (!allParallel)
-    iteratorTypes[dim] = utils::IteratorType::reduction;
-  MLIRContext *ctxt = builder.getContext();
-  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
-  SmallVector<AffineExpr, 2> affineExprs;
-  for (int i = 0; i < inputRank; i++) {
-    if (i != dim)
-      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
-  }
-  auto reductionMap =
-      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
-  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
-  return std::make_tuple(iteratorTypes, indexingMaps);
-}
-
-// Helper function to produce a linalg.generic that computes a reduction on
-// dimension \p dim with the operation type \p T.
-template <typename T>
-static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
-                    int64_t dim) {
-  auto inputType = cast<ShapedType>(input.getType());
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputRank = inputShape.size();
-  auto [iteratorTypes, indexingMaps] =
-      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
-  assert(indexingMaps.size() == 2 &&
-         "We should have two maps: 1 for the input, 1 for the output");
-  assert(indexingMaps[0].isIdentity() && "input map should be identity");
-
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<T>(loc, args[0], args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the second step of the softmax
-/// decomposition: res = exp(input - max), where \p max is the max of \p input
-/// on dimension \p dim.
-static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
-                              Value max, Value output, int64_t dim) {
-  auto inputType = cast<ShapedType>(input.getType());
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputRank = inputShape.size();
-  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
-      builder, inputRank, dim, /*allParallel=*/true);
-  assert(indexingMaps.size() == 2 && "We should have one map for each input");
-  assert(indexingMaps[0].isIdentity() && "input map should be identity");
-  // Add the affine map for the output argument.
-  indexingMaps.push_back(indexingMaps[0]);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
-      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
-        Value result = b.create<math::ExpOp>(loc, diff);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the final step of the softmax
-/// decomposition.
-/// \returns  linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
-///   yield  n / d
-/// }
-static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
-                        Value denominator, Value output, int64_t dim) {
-  auto inputType = cast<ShapedType>(numerator.getType());
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputRank = inputShape.size();
-  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
-      builder, inputRank, dim, /*allParallel=*/true);
-  assert(indexingMaps.size() == 2 &&
-         "We should have one map for each input (2)");
-  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
-  // Add the affine map for the output tensor.
-  indexingMaps.push_back(indexingMaps[0]);
-  auto genericOp = builder.create<linalg::GenericOp>(
-      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
-      indexingMaps, iteratorTypes,
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
-        b.create<linalg::YieldOp>(loc, result);
-      });
-  return genericOp.getResult(0);
+/// Returns the initial value of a max-reduction over `type`: the lowest finite
+/// value for floats (not the smallest positive denormal, which would clamp
+/// all-negative rows) and the signed minimum for integers.
+/// Returns a null attribute for any other type.
+static TypedAttr createInitValueForReduceMaxOp(Type type, OpBuilder &b) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return b.getFloatAttr(type, APFloat::getLargest(floatType.getFloatSemantics(),
+                                                    /*Negative=*/true));
+  if (isa<IntegerType>(type))
+    return b.getIntegerAttr(
+        type, APInt::getSignedMinValue(type.getIntOrFloatBitWidth()));
+  return {};
+}
+
+/// Returns the initial value of a sum-reduction over `type` (zero), or a null
+/// attribute for types other than float and integer.
+static TypedAttr createInitValueForReduceSumOp(Type type, OpBuilder &b) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return b.getFloatAttr(type,
+                          APFloat::getZero(floatType.getFloatSemantics()));
+  if (isa<IntegerType>(type))
+    return b.getIntegerAttr(type, APInt::getZero(type.getIntOrFloatBitWidth()));
+  return {};
+}
+
+/// Builds the combiner of a max-reduction body (`arith.maxnumf` for floats,
+/// `arith.maxsi` for integers). Returns a null value for any other type.
+static Value createLinalgReduceMaxBody(OpBuilder &b, Location loc,
+                                       ValueRange args, Type elementTy) {
+  if (isa<FloatType>(elementTy))
+    return b.create<arith::MaxNumFOp>(loc, args[0], args[1]);
+  if (isa<IntegerType>(elementTy))
+    return b.create<arith::MaxSIOp>(loc, args[0], args[1]);
+  return {};
+}
+
+/// Builds the combiner of a sum-reduction body (`arith.addf` for floats,
+/// `arith.addi` for integers). Returns a null value for any other type.
+static Value createLinalgReduceSumBody(OpBuilder &b, Location loc,
+                                       ValueRange args, Type elementTy) {
+  if (isa<FloatType>(elementTy))
+    return b.create<arith::AddFOp>(loc, args[0], args[1]);
+  if (isa<IntegerType>(elementTy))
+    return b.create<arith::AddIOp>(loc, args[0], args[1]);
+  return {};
 }
 // @} End helper functions for softmax decomposition.
 
@@ -2695,7 +2630,7 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
 /// 4. Divide z and l. This gives the N-dimensional softmax.
 ///    softmax = z / l
 ///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(*this);
   Location loc = getLoc();
@@ -2706,32 +2641,65 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  // Compute the reduction init values first so that unsupported element
+  // types are rejected before any IR is created.
+  TypedAttr maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+  TypedAttr sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+  if (!maxFillValAttr || !sumFillValAttr)
+    return failure();
   SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
   Value output = getOutput();
   dims.erase(dims.begin() + reductionDim);
+
   // Step 1: Compute max along dim.
   Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
-                                                 elementType, b, loc,
-                                                 /*useOnlyFiniteValue=*/true);
-  Value neutralForMaxFInit =
-      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
-          .result();
-  Value max =
-      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+  auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+  auto neutralMaxInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+  Value neutralForMaxFInit = neutralMaxInitOp.result();
+
+  auto reduceMaxOp = b.create<linalg::ReduceOp>(
+      loc, input, neutralForMaxFInit, reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value result = createLinalgReduceMaxBody(nestedBuilder, nestedLoc,
+                                                 args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
 
   // Step 2: Subtract max from input and exponentiate.
-  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+  auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+      loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+
+  auto subOp = b.create<linalg::SubOp>(
+      loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+      ValueRange{output});
+  auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+                                       ValueRange{output});
 
   // Step 3: Compute sum along dim.
-  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
-                                       b, loc, /*useOnlyFiniteValue=*/true);
-  Value zeroInit =
-      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
-  Value denominator =
-      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+  auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+  auto neutralSumInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+  Value sumFilledTensor = neutralSumInitOp.result();
+  auto reduceSumOp = b.create<linalg::ReduceOp>(
+      loc, expOp.getResults(), sumFilledTensor, reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value result = createLinalgReduceSumBody(nestedBuilder, nestedLoc,
+                                                 args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
 
   // Step 4: Compute softmax.
-  Value result =
-      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
-  return SmallVector<Value>{result};
+  // Size the broadcast destination from the output's (possibly dynamic)
+  // runtime sizes rather than its static shape.
+  auto sumBcastOutput = b.create<tensor::EmptyOp>(
+      loc, tensor::getMixedSizes(b, loc, output), elementType);
+  auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
+      loc, reduceSumOp.getResult(0), sumBcastOutput,
+      reduceSumOp.getDimensionsAttr());
+  auto divOp = b.create<linalg::DivOp>(
+      loc, ValueRange{expOp.getResult(0), sumBroadcastOp.getResults().front()},
+      ValueRange{output});
+  return DecompositionResult{{neutralMaxInitOp, reduceMaxOp, maxBroadcastOp,
+                              subOp, expOp, neutralSumInitOp, reduceSumOp,
+                              sumBroadcastOp, divOp},
+                             {divOp.getResults().front()}};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4eb334f8bbbfa..555d28591104b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -431,27 +431,33 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
-// Decompose the target operation if it implements the AggregatedOpInterface.
-// Push the decomposed operations (the ones that replaces the values produced by
-// \p target) in the `results`.
-DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
-  if (!decomposableOp) {
-    failed(rewriter.notifyMatchFailure(target,
-                                       "payload is not a decomposable op"));
-    return emitDefaultSilenceableFailure(target);
-  }
+// Decompose each payload op of `target` that implements AggregatedOpInterface.
+// The i-th result handle is mapped to the ops produced by decomposing the i-th
+// payload op, or to no ops when that payload op is not decomposable.
+DiagnosedSilenceableFailure
+transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       TransformResults &transformResults,
+                                       TransformState &state) {
+  auto targets = llvm::to_vector(state.getPayloadOps(getTarget()));
+  if (targets.size() != getNumResults()) {
+    return emitDefiniteFailure()
+           << "expected one result handle per payload op (" << targets.size()
+           << "), got " << getNumResults();
+  }
+  for (auto [i, target] : llvm::enumerate(targets)) {
+    auto resultHandle = cast<OpResult>(getResult(i));
+    auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+    if (!decomposableOp) {
+      // Non-decomposable ops are skipped; still map their handle to an empty
+      // list so that every result handle is set.
+      transformResults.set(resultHandle, ArrayRef<Operation *>{});
+      continue;
+    }
 
-  FailureOr<SmallVector<Value>> maybeNewResults =
-      decomposableOp.decomposeOperation(rewriter);
-  if (failed(maybeNewResults))
-    return emitDefaultSilenceableFailure(target);
+    FailureOr<DecompositionResult> maybeNewResults =
+        decomposableOp.decomposeOperation(rewriter);
+    if (failed(maybeNewResults))
+      return emitDefaultSilenceableFailure(target);
 
-  rewriter.replaceOp(decomposableOp, *maybeNewResults);
-  for (Value val : *maybeNewResults) {
-    Operation *definition = val.getDefiningOp();
-    if (definition)
-      results.push_back(definition);
+    rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
+    transformResults.set(resultHandle, maybeNewResults->decomposedOps);
   }
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..68582fe6cbad2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   ConvertConv2DToImg2Col.cpp
   DataLayoutPropagation.cpp
+  DecomposeAggregateNamedLinalgOps.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
new file mode 100644
index 0000000000000..e8a5b96d54d34
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
@@ -0,0 +1,65 @@
+//===- DecomposeAggregateNamedLinalgOps.cpp - Decompose aggregate ops -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGDECOMPOSEAGGREGATENAMEDOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-decompose-named-ops"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Replaces a `linalg.softmax` with the values produced by its
+/// `decomposeOperation` implementation.
+struct DecomposeSoftmaxPattern : public OpRewritePattern<SoftmaxOp> {
+  using OpRewritePattern<SoftmaxOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SoftmaxOp op,
+                                PatternRewriter &rewriter) const override {
+    // Decompose softmax(x) into tmp = exp(x - max(x)); tmp / sum(tmp)
+    FailureOr<DecompositionResult> results = op.decomposeOperation(rewriter);
+    if (failed(results))
+      return rewriter.notifyMatchFailure(op, "Failed to decompose SoftmaxOp");
+    rewriter.replaceOp(op, results->decomposedValues);
+    return success();
+  }
+};
+
+/// Greedily applies the aggregate named-op decomposition patterns to the
+/// operation the pass runs on.
+struct LinalgDecomposeAggregateNamedOpsPass
+    : public impl::LinalgDecomposeAggregateNamedOpsPassBase<
+          LinalgDecomposeAggregateNamedOpsPass> {
+  using impl::LinalgDecomposeAggregateNamedOpsPassBase<
+      LinalgDecomposeAggregateNamedOpsPass>::
+      LinalgDecomposeAggregateNamedOpsPassBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void LinalgDecomposeAggregateNamedOpsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  populateDecomposeAggregateNamedOpsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+void mlir::linalg::populateDecomposeAggregateNamedOpsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<DecomposeSoftmaxPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
new file mode 100644
index 0000000000000..742a25b992b9e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s -split-input-file -linalg-decompose-named-ops | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -linalg-decompose-named-ops -linalg-generalize-named-ops | FileCheck %s --check-prefix=GENERALIZECHECK
+
+func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK:      func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-DAG:  %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:  %[[INF:.+]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG:  %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
+// CHECK-DAG:  %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[RED:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<2x16x32xf32>)
+// CHECK-SAME:  outs(%[[FIL]] : tensor<2x16xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[MAX:.+]] = arith.maxnumf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[MAX]] : f32
+// CHECK:      %[[CST:.+]] = linalg.broadcast ins(%[[RED]] : tensor<2x16xf32>)
+// CHECK-NEXT: %[[SUB:.+]] = linalg.sub ins(%[[ARG0]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
+// CHECK-NEXT: %[[EXP:.+]] = linalg.exp ins(%[[SUB]] : tensor<2x16x32xf32>)
+// CHECK-DAG:  %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[SUM:.+]] = linalg.reduce ins(%[[EXP]] : tensor<2x16x32xf32>)
+// CHECK-SAME:  outs(%[[FIL]] : tensor<2x16xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-DAG:  %[[EMP:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-DAG:  %[[CST2:.+]] = linalg.broadcast ins(%[[SUM]] : tensor<2x16xf32>)
+// CHECK-NEXT: %[[DIV:.+]] = linalg.div ins(%[[EXP]], %[[CST2]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[DST]] : tensor<2x16x32xf32>)
+// CHECK: return %[[DIV]]
+
+
+// GENERALIZECHECK-DAG:    #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
+// GENERALIZECHECK-DAG:    #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// GENERALIZECHECK-DAG:    #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// GENERALIZECHECK-DAG:    #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// GENERALIZECHECK-LABEL: func @softmax
+// GENERALIZECHECK-SAME:     (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// GENERALIZECHECK-DAG:     %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// GENERALIZECHECK-DAG:     %[[INF:.+]] = arith.constant -3.40282347E+38 : f32
+// GENERALIZECHECK-DAG:     %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
+// GENERALIZECHECK-DAG:     %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel"]}
+// GENERALIZECHECK-SAME:      ins(%[[INF]] : f32) outs(%[[EMP]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT:      linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT:    } -> tensor<2x16xf32>
+// GENERALIZECHECK:         %[[RED:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel", "reduction"]}
+// GENERALIZECHECK-SAME:      ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[FIL]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT:    ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT:      %[[MAX:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
+// GENERALIZECHECK-NEXT:      linalg.yield %[[MAX]] : f32
+// GENERALIZECHECK-NEXT:    } -> tensor<2x16xf32>
+// GENERALIZECHECK:         %[[CST:.+]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP2]]],
+// GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME:      ins(%[[RED]] : tensor<2x16xf32>)
+// GENERALIZECHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT:        linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT:    } -> tensor<2x16x32xf32>
+// GENERALIZECHECK:         %[[SUB:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP2]]]
+// GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME:      ins(%[[ARG0]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
+// GENERALIZECHECK-SAME:      outs(%[[DST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-NEXT:      ^bb0(%[[LHS:.+]]: f32, %[[RHS:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT:      %[[SUBF:.+]] = arith.subf %[[LHS]], %[[RHS]] : f32
+// GENERALIZECHECK-NEXT:      linalg.yield %[[SUBF]] : f32
+// GENERALIZECHECK-NEXT:    } -> tensor<2x16x32xf32>
+// GENERALIZECHECK:         %[[EXP:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
+// GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME:      ins(%[[SUB]] : tensor<2x16x32xf32>)
+// GENERALIZECHECK-SAME:      outs(%[[DST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT:      %[[EXPF:.+]] = math.exp %[[IN]] : f32
+// GENERALIZECHECK-NEXT:      linalg.yield %[[EXPF]] : f32
+// GENERALIZECHECK-NEXT:    } -> tensor<2x16x32xf32>
+// GENERALIZECHECK-DAG:     %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
+// GENERALIZECHECK:         %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel"]}
+// GENERALIZECHECK-SAME:      ins(%[[ZERO]] : f32) outs(%[[EMP]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT:      linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT:    } -> tensor<2x16xf32>
+// GENERALIZECHECK:         %[[RED:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel", "reduction"]}
+// GENERALIZECHECK-SAME:      ins(%[[EXP]] : tensor<2x16x32xf32>) outs(%[[FIL]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT:    ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT:      %[[ADDF:.+]] = arith.addf %[[IN]], %[[OUT]] : f32
+// GENERALIZECHECK-NEXT:      linalg.yield %[[ADDF]] : f32
+// GENERALIZECHECK-NEXT:    } -> tensor<2x16xf32>
+// GENERALIZECHECK:         %[[CST:.+]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP2]]],
+// GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME:      ins(%[[RED]] : tensor<2x16xf32>)
+// GENERALIZECHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT:        linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT:    } -> tensor<2x16x32xf32>
+// GENERALIZECHECK:         %[[DIV:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP2]]],
+// GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME:      ins(%[[EXP]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
+// GENERALIZECHECK-SAME:      outs(%[[DST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-NEXT:      ^bb0(%[[LHS:.+]]: f32, %[[RHS:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT:      %[[DIVF:.+]] = arith.divf %[[LHS]], %[[RHS]] : f32
+// GENERALIZECHECK-NEXT:      linalg.yield %[[DIVF]] : f32
+// GENERALIZECHECK-NEXT:    } -> tensor<2x16x32xf32>
+// GENERALIZECHECK:         return %[[DIV]] : tensor<2x16x32xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 2e211d2fa7dbe..7243616c77d4e 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -2,6 +2,8 @@
 
 // CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG:  #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: @conv_2d_nhwc_hwcf
 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
@@ -210,32 +212,62 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // CHECK-LABEL:      func.func @softmax(
 // CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG:        %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
-// CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 1.401300e-45 : f32
+// CHECK:        %[[D2:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME:      ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) {
+// CHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:      linalg.yield %[[IN]] : f32
+// CHECK-NEXT:    } -> tensor<2x16xf32>
 // CHECK:        %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
 // CHECK-SAME:     "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
 // CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:          %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
 // CHECK:          linalg.yield %[[D8]] : f32
 // CHECK:        } -> tensor<2x16xf32>
-// CHECK:        %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
-// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK:        %[[BCST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:      ins(%[[D3]] : tensor<2x16xf32>) outs(%[[DST]] : tensor<2x16x32xf32>)
+// CHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:        linalg.yield %[[IN]] : f32
+// CHECK-NEXT:    } -> tensor<2x16x32xf32>
+// CHECK:        %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[BCST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
 // CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
 // CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:          %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
-// CHECK:          %[[D9:.+]] = math.exp %[[D8]] : f32
-// CHECK:          linalg.yield %[[D9]] : f32
+// CHECK:          linalg.yield %[[D8]] : f32
 // CHECK:        } -> tensor<2x16x32xf32>
+// CHECK:         %[[EXP:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:      ins(%[[D4]] : tensor<2x16x32xf32>)
+// CHECK-SAME:      outs(%[[DST]] : tensor<2x16x32xf32>) {
+// CHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:      %[[EXPF:.+]] = math.exp %[[IN]] : f32
+// CHECK-NEXT:      linalg.yield %[[EXPF]] : f32
+// CHECK-NEXT:    } -> tensor<2x16x32xf32>
 // CHECK:        %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK:        %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK:        %[[D5:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME:      ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) {
+// CHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:      linalg.yield %[[IN]] : f32
+// CHECK-NEXT:    } -> tensor<2x16xf32>
 // CHECK:        %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME:     "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
+// CHECK-SAME:     "parallel", "reduction"]} ins(%[[EXP]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
 // CHECK:        ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:          %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
 // CHECK:          linalg.yield %[[D8]] : f32
 // CHECK:        } -> tensor<2x16xf32>
-// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
-// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK:        %[[EMP:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK:        %[[CST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]],
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:      ins(%[[D6]] : tensor<2x16xf32>) outs(%[[EMP]] : tensor<2x16x32xf32>) {
+// CHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT:        linalg.yield %[[IN]] : f32
+// CHECK-NEXT:    } -> tensor<2x16x32xf32>
+// CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types =
+// CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[EXP]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
 // CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
 // CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:          %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
@@ -250,6 +282,8 @@ module attributes {transform.with_named_sequence} {
 
     %2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
+    %4 = transform.structured.generalize %3: (!transform.any_op) -> !transform.any_op
+    
     transform.yield
   }
 }

>From 717f622b346aff974c262dad421e5273f1c205a3 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Wed, 3 Jul 2024 15:02:16 +0000
Subject: [PATCH 2/9] fixup! [MLIR][Linalg] Add aggregate ops decomposition
 pass and softmax decomposition implementation

---
 mlir/test/Dialect/Linalg/decompose-named-ops.mlir | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 742a25b992b9e..619956d3417a9 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -63,7 +63,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // GENERALIZECHECK:         %[[SUB:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP2]]]
 // GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
 // GENERALIZECHECK-SAME:      ins(%[[ARG0]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
-// GENERALIZECHECK-SAME:      outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-SAME:      outs(%[[DST]] : tensor<2x16x32xf32>) {
 // GENERALIZECHECK-NEXT:      ^bb0(%[[LHS:.+]]: f32, %[[RHS:.+]]: f32, %[[OUT:.+]]: f32):
 // GENERALIZECHECK-NEXT:      %[[SUBF:.+]] = arith.subf %[[LHS]], %[[RHS]] : f32
 // GENERALIZECHECK-NEXT:      linalg.yield %[[SUBF]] : f32
@@ -71,12 +71,11 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // GENERALIZECHECK:         %[[EXP:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
 // GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]}
 // GENERALIZECHECK-SAME:      ins(%[[SUB]] : tensor<2x16x32xf32>)
-// GENERALIZECHECK-SAME:      outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-SAME:      outs(%[[DST]] : tensor<2x16x32xf32>) {
 // GENERALIZECHECK-NEXT:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
 // GENERALIZECHECK-NEXT:      %[[EXPF:.+]] = math.exp %[[IN]] : f32
 // GENERALIZECHECK-NEXT:      linalg.yield %[[EXPF]] : f32
 // GENERALIZECHECK-NEXT:    } -> tensor<2x16x32xf32>
-// GENERALIZECHECK-DAG:     %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
 // GENERALIZECHECK:         %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
 // GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel"]}
 // GENERALIZECHECK-SAME:      ins(%[[ZERO]] : f32) outs(%[[EMP]] : tensor<2x16xf32>) {

>From 87fb531438a3aba60f4bf98e5c0bf775a0ffbf13 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Wed, 3 Jul 2024 19:30:25 +0000
Subject: [PATCH 3/9] Remove single-use temporary for max reduction init value

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 383f285969ad7..363bf678cfa13 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2638,10 +2638,9 @@ FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
   auto neutralMaxInitOp = b.create<linalg::FillOp>(
       loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
-  Value neutralForMaxFInit = neutralMaxInitOp.result();
 
   auto reduceMaxOp = b.create<linalg::ReduceOp>(
-      loc, input, neutralForMaxFInit, reductionDim,
+      loc, input, neutralMaxInitOp.result(), reductionDim,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
         auto result =
             createLinalgReduceMaxBody(b, nestedLoc, args, elementType);

>From e24372af65145e0c604ffcd71c761e519f22c28c Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Wed, 3 Jul 2024 19:54:14 +0000
Subject: [PATCH 4/9] Fix float min constant

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp             | 3 ++-
 mlir/test/Dialect/Linalg/decompose-named-ops.mlir    | 4 ++--
 mlir/test/Dialect/Linalg/transform-op-decompose.mlir | 2 +-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 363bf678cfa13..c35356dadcc15 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2567,7 +2567,8 @@ void SoftmaxOp::getEffects(
 TypedAttr createInitValueForReduceMaxOp(Type type, OpBuilder &b) {
   if (isa<FloatType>(type))
     return b.getFloatAttr(
-        type, APFloat::getSmallest(cast<FloatType>(type).getFloatSemantics()));
+        type, APFloat::getLargest(cast<FloatType>(type).getFloatSemantics(),
+                                  /*Negative=*/true));
   if (isa<IntegerType>(type))
     return b.getIntegerAttr(
         type, APInt::getSignedMinValue(type.getIntOrFloatBitWidth()));
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 619956d3417a9..9d36e194413ad 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -8,7 +8,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 
 // CHECK:      func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-DAG:  %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:  %[[INF:.+]] = arith.constant 1.401300e-45 : f32
+// CHECK-DAG:  %[[INF:.+]] = arith.constant -3.40282347E+38 : f32
 // CHECK-DAG:  %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
 // CHECK-DAG:  %[[FIL:.+]] = linalg.fill
 // CHECK-NEXT: %[[RED:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<2x16x32xf32>)
@@ -39,7 +39,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // GENERALIZECHECK-LABEL: func @softmax
 // GENERALIZECHECK-SAME:     (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // GENERALIZECHECK-DAG:     %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
-// GENERALIZECHECK-DAG:     %[[INF:.+]] = arith.constant 1.401300e-45 : f32
+// GENERALIZECHECK-DAG:     %[[INF:.+]] = arith.constant -3.40282347E+38 : f32
 // GENERALIZECHECK-DAG:     %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
 // GENERALIZECHECK-DAG:     %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
 // GENERALIZECHECK-SAME:      iterator_types = ["parallel", "parallel"]}
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 7243616c77d4e..09f738a876ae2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -212,7 +212,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // CHECK-LABEL:      func.func @softmax(
 // CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG:        %[[CST:.+]] = arith.constant 1.401300e-45 : f32
+// CHECK-DAG:        %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
 // CHECK:        %[[D2:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
 // CHECK-SAME:      iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME:      ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) {

>From a75b13595219e2056197ecea5fb9cf74c1f8144b Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 4 Jul 2024 09:58:32 +0000
Subject: [PATCH 5/9] Fail gracefully on memrefs

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp          | 5 +++++
 mlir/test/Dialect/Linalg/decompose-named-ops.mlir | 7 +++++++
 2 files changed, 12 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c35356dadcc15..c7a7dec60cb47 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2622,6 +2622,11 @@ Value createLinalgReduceSumBody(OpBuilder &b, Location loc, ValueRange args,
 ///    softmax = z / l
 ///
 FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  if (!isa<RankedTensorType>(getInput().getType())) {
+    // The decomposition assumes ranked tensors as input
+    return failure();
+  }
+
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(*this);
   Location loc = getLoc();
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 9d36e194413ad..148a855edb657 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -104,3 +104,10 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // GENERALIZECHECK-NEXT:      linalg.yield %[[DIVF]] : f32
 // GENERALIZECHECK-NEXT:    } -> tensor<2x16x32xf32>
 // GENERALIZECHECK:         return %[[DIV]] : tensor<2x16x32xf32>
+
+// COM: decomposition assumes tensors as inputs, this is just to make sure nothing breaks
+func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256xf32>) {
+  linalg.softmax
+    dimension(1) ins(%arg0 : memref<16x64x256xf32>) outs(%arg1 : memref<16x64x256xf32>)
+  return
+}

>From 99260403f2a9829ac306733bf014af85ef2c53a6 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 4 Jul 2024 10:17:56 +0000
Subject: [PATCH 6/9] Bind all decomposed ops to the same single result of
 decompose interface op

---
 .../Dialect/Linalg/TransformOps/LinalgTransformOps.td     | 7 ++++---
 .../Dialect/Linalg/TransformOps/LinalgTransformOps.cpp    | 8 +++++---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c059196b807a7..cbebec366973c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1323,12 +1323,13 @@ def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface
     Decomposes high-level named ops into a sequence of non-aggregate named ops
     via `AggregatedOpInterface`.
 
-    The operation ignores non-decomposable ops. The return handles point to
-    a sequence of named ops produced by the decomposition.
+    The operation ignores non-decomposable ops. The return handle points to
+    a sequence of named ops produced by all decompositions (i.e. the
+    information about which target each decomposed op originated from is lost).
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
-  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+  let results = (outs TransformHandleTypeInterface:$transformed);
   let assemblyFormat =
       "$target attr-dict `:` functional-type(operands, results)";
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 555d28591104b..acba9b02e3729 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -430,11 +430,13 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
 
 // Decompose the target operation if it implements the AggregatedOpInterface.
 // Push the decomposed operations (the ones that replaces the values produced by
-// \p target) in the `results`.
+// \p target) in the `results`. Decompositions for all targets bind to the same
+// single output value, thus the information about the original targets is lost.
 DiagnosedSilenceableFailure
 transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {
+  SmallVector<Operation *> allDecomposedOps;
   for (auto [i, target] : llvm::enumerate(state.getPayloadOps(getTarget()))) {
     auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
     if (!decomposableOp)
@@ -446,9 +448,9 @@ transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
       return emitDefaultSilenceableFailure(target);
 
     rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
-    transformResults.set(cast<OpResult>(getResult(i)),
-                         maybeNewResults->decomposedOps);
+    allDecomposedOps.append(maybeNewResults->decomposedOps);
   }
+  transformResults.set(cast<OpResult>(getResult()), allDecomposedOps);
   return DiagnosedSilenceableFailure::success();
 }
 

>From d3f7899042b8d42dcfe0a37128bedfb875c9182c Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 4 Jul 2024 10:35:10 +0000
Subject: [PATCH 7/9] Check for pure tensor semantics instead of input type

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c7a7dec60cb47..77b9ebcc5bf3a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2622,7 +2622,7 @@ Value createLinalgReduceSumBody(OpBuilder &b, Location loc, ValueRange args,
 ///    softmax = z / l
 ///
 FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
-  if (!isa<RankedTensorType>(getInput().getType())) {
+  if (!hasPureTensorSemantics()) {
     // The decomposition assumes ranked tensors as input
     return failure();
   }

>From 593bc911dbd159fc85d73f3329acf9777d9b33c3 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 4 Jul 2024 14:21:28 +0000
Subject: [PATCH 8/9] Add basic check as a smoke test for softmax_memref

---
 mlir/test/Dialect/Linalg/decompose-named-ops.mlir | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 148a855edb657..52a51eab2073a 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -105,9 +105,13 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // GENERALIZECHECK-NEXT:    } -> tensor<2x16x32xf32>
 // GENERALIZECHECK:         return %[[DIV]] : tensor<2x16x32xf32>
 
+// -----
+
 // COM: decomposition assumes tensors as inputs, this is just to make sure nothing breaks
 func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256xf32>) {
+// CHECK-LABEL: @softmax_memref
   linalg.softmax
     dimension(1) ins(%arg0 : memref<16x64x256xf32>) outs(%arg1 : memref<16x64x256xf32>)
+// CHECK-NEXT: linalg.softmax {{.*}}
   return
 }

>From 6201b553db0adbd8bcfdd791e3fb3fda316c3ef3 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 5 Jul 2024 10:54:21 +0000
Subject: [PATCH 9/9] Add dynamic shapes test case

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 22 ++++++++--
 .../Dialect/Linalg/decompose-named-ops.mlir   | 40 +++++++++++++++++++
 2 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 77b9ebcc5bf3a..801f766b22d62 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2634,12 +2634,21 @@ FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   ShapedType inputType = getInputOperandType();
   Type elementType = inputType.getElementType();
   int64_t reductionDim = getDimension();
-  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
   Value output = getOutput();
-  dims.erase(dims.begin() + reductionDim);
+
+  SmallVector<int64_t> reduceShape;
+  SmallVector<Value> dynReduceDims;
+  for (unsigned i = 0; i < inputType.getRank(); i++) {
+    if (reductionDim != i) {
+      reduceShape.push_back(inputType.getDimSize(i));
+      if (inputType.isDynamicDim(i))
+        dynReduceDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+    }
+  }
 
   // Step 1: Compute max along dim.
-  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value outputReduce =
+      b.create<tensor::EmptyOp>(loc, reduceShape, elementType, dynReduceDims);
   auto maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
   auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
   auto neutralMaxInitOp = b.create<linalg::FillOp>(
@@ -2678,8 +2687,13 @@ FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
       });
 
   // Step 4: Compute softmax.
+  SmallVector<Value> dynDims;
+  for (unsigned i = 0; i < inputType.getRank(); i++) {
+    if (inputType.isDynamicDim(i))
+      dynDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+  }
   auto sumBcastOutput = b.create<tensor::EmptyOp>(
-      loc, getOutputOperandType().getShape(), elementType);
+      loc, getOutputOperandType().getShape(), elementType, dynDims);
   auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
       loc, reduceSumOp.getResult(0), sumBcastOutput,
       reduceSumOp.getDimensionsAttr());
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 52a51eab2073a..8c00feea56192 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -115,3 +115,43 @@ func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256x
 // CHECK-NEXT: linalg.softmax {{.*}}
   return
 }
+
+
+// -----
+
+func.func @softmax_dynamic_shapes(%arg0: tensor<?x?x?xf32>, %dst: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<?x?x?xf32>) outs(%dst: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+
+// CHECK:      func.func @softmax_dynamic_shapes(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-DAG:  %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:  %[[INF:.+]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG:  %[[CST_0:.+]] = arith.constant 0 : index
+// CHECK-DAG:  %[[CST_1:.+]] = arith.constant 1 : index
+// CHECK-DAG:  %[[CST_2:.+]] = arith.constant 2 : index
+// CHECK-DAG:  %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST_0]] : tensor<?x?x?xf32>
+// CHECK-DAG:  %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST_1]] : tensor<?x?x?xf32>
+// CHECK-DAG:  %[[EMP:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
+// CHECK-DAG:  %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[RED:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<?x?x?xf32>)
+// CHECK-SAME:  outs(%[[FIL]] : tensor<?x?xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[MAX:.+]] = arith.maxnumf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[MAX]] : f32
+// CHECK:      %[[CST:.+]] = linalg.broadcast ins(%[[RED]] : tensor<?x?xf32>)
+// CHECK-NEXT: %[[SUB:.+]] = linalg.sub ins(%[[ARG0]], %[[CST]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-NEXT: %[[EXP:.+]] = linalg.exp ins(%[[SUB]] : tensor<?x?x?xf32>)
+// CHECK-DAG:  %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[SUM:.+]] = linalg.reduce ins(%[[EXP]] : tensor<?x?x?xf32>)
+// CHECK-SAME:  outs(%[[FIL]] : tensor<?x?xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-DAG:  %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST_0]] : tensor<?x?x?xf32>
+// CHECK-DAG:  %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST_1]] : tensor<?x?x?xf32>
+// CHECK-DAG:  %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[CST_2]] : tensor<?x?x?xf32>
+// CHECK-DAG:  %[[EMP:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]]) : tensor<?x?x?xf32>
+// CHECK-DAG:  %[[CST2:.+]] = linalg.broadcast ins(%[[SUM]] : tensor<?x?xf32>)
+// CHECK-NEXT: %[[DIV:.+]] = linalg.div ins(%[[EXP]], %[[CST2]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[DST]] : tensor<?x?x?xf32>)
+// CHECK: return %[[DIV]]



More information about the Mlir-commits mailing list