[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)
Petr Kurapov
llvmlistbot at llvm.org
Wed Jul 3 08:02:28 PDT 2024
https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/97582
>From 005330bd1da279cf9567db6dc79359127e820ab6 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 21 Jun 2024 11:53:36 +0000
Subject: [PATCH 1/2] [MLIR][Linalg] Add aggregate ops decomposition pass and
softmax decomposition implementation
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 10 +
.../Dialect/Linalg/IR/LinalgInterfaces.td | 2 +-
mlir/include/mlir/Dialect/Linalg/Passes.td | 5 +
.../Linalg/TransformOps/LinalgTransformOps.td | 18 +-
.../Dialect/Linalg/Transforms/Transforms.h | 15 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 211 +++++++-----------
.../TransformOps/LinalgTransformOps.cpp | 34 ++-
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../DecomposeAggregateNamedLinalgOps.cpp | 62 +++++
.../Dialect/Linalg/decompose-named-ops.mlir | 107 +++++++++
.../Linalg/transform-op-decompose.mlir | 54 ++++-
11 files changed, 345 insertions(+), 174 deletions(-)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
create mode 100644 mlir/test/Dialect/Linalg/decompose-named-ops.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014..3858075fae137 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -30,6 +30,16 @@ class IteratorTypeAttr;
class LinalgOp;
class GenericOp;
+/// Container for the result of decomposing an aggregate operation (the return
+/// type of `AggregatedOpInterface::decomposeOperation`).
+/// - `decomposedOps` contains operations created by the decomposition that are
+/// returned to the caller for further transformations (e.g. to be exposed
+/// through transform-dialect handles).
+/// - `decomposedValues` contains the values corresponding to the result of the
+/// aggregate operation; callers pass them to `replaceOp` to replace the
+/// original op.
+struct DecompositionResult {
+ SmallVector<Operation *> decomposedOps;
+ SmallVector<Value> decomposedValues;
+};
+
namespace detail {
/// Implementation of the method that check if given operands
/// can be dropped, i.e. the remaining operands can compute the loop
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..9b1ab20552628 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -862,7 +862,7 @@ def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
In other words, the returned vector can be used directly with
`RewriterBase::replaceOp(this, returnedValues)`.
}],
- /*retType=*/"FailureOr<SmallVector<Value>>",
+ /*retType=*/"FailureOr<DecompositionResult>",
/*methodName=*/"decomposeOperation",
/*args=*/(ins
"OpBuilder &":$b),
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0..98c23b97534f8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+// Declared as a plain `Pass` (not an `InterfacePass` anchored on a specific
+// op interface); its implementation greedily applies the patterns registered
+// by `populateDecomposeAggregateNamedOpsPatterns`.
+def LinalgDecomposeAggregateNamedOpsPass : Pass<"linalg-decompose-named-ops"> {
+ let summary = "Decompose complex named ops (e.g., Softmax) into a sequence of linalg named ops";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 866275cedf68b..c059196b807a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1317,25 +1317,21 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- TODO
+ Decomposes high-level named ops into a sequence of non-aggregate named ops
+ via `AggregatedOpInterface`.
+
+ The operation ignores non-decomposable ops. The return handles point to
+ a sequence of named ops produced by the decomposition.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs TransformHandleTypeInterface:$transformed);
+ let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// RewriteInDestinationPassingStyleOp.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2a58d02d7b704..2c9e201ccbd85 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1589,6 +1589,11 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` with patterns to decompose high-level aggregate named
+/// ops (e.g., softmax) into a sequence of simpler linalg named ops, defining
+/// the operation semantics.
+void populateDecomposeAggregateNamedOpsPatterns(RewritePatternSet &patterns);
+
/// Linalg decompose convolutions patterns
/// Populates patterns to decompose high-D convolution ops into low-D ones.
@@ -1736,10 +1741,12 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
/// Adds patterns that reduce the rank of named contraction ops that have
-/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
-/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example a
-/// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
-/// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`.
+/// unit dimensions in the operand(s) by converting to a sequence of
+/// `collapse_shape`,
+/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For
+/// example a `linalg.batch_matmul` with unit batch size will convert to
+/// `linalg.matmul` and a `linalg.matvec` with unit spatial dim in lhs will
+/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
} // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..383f285969ad7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2564,116 +2564,41 @@ void SoftmaxOp::getEffects(
// Helper functions for softmax decomposition.
// @{
-
-// Helper function to produce the iterator types (reduction or parallel) and
-// affine maps for the iterators used in the decomposition of softmax.
-// This method creates:
-// If allParallel == true:
-// - iterator type: {parallel, ..., parallel}
-// - affine maps:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-// where N == inputRank.
-//
-// If allParallel == false:
-// - iterator type at dim(i) == parallel for i != \p dim and
-// dim(dim) == reduction.
-// - affine map:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-// where N == inputRank.
-static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
-computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
- int64_t dim, bool allParallel = false) {
- SmallVector<utils::IteratorType> iteratorTypes(inputRank,
- utils::IteratorType::parallel);
- if (!allParallel)
- iteratorTypes[dim] = utils::IteratorType::reduction;
- MLIRContext *ctxt = builder.getContext();
- auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
- SmallVector<AffineExpr, 2> affineExprs;
- for (int i = 0; i < inputRank; i++) {
- if (i != dim)
- affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
- }
- auto reductionMap =
- AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
- SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
- return std::make_tuple(iteratorTypes, indexingMaps);
-}
-
-// Helper function to produce a linalg.generic that computes a reduction on
-// dimension \p dim with the operation type \p T.
-template <typename T>
-static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
- int64_t dim) {
- auto inputType = cast<ShapedType>(input.getType());
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputRank = inputShape.size();
- auto [iteratorTypes, indexingMaps] =
- computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
- assert(indexingMaps.size() == 2 &&
- "We should have two maps: 1 for the input, 1 for the output");
- assert(indexingMaps[0].isIdentity() && "input map should be identity");
-
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, output.getType(), input, output, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<T>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the second step of the softmax
-/// decomposition: res = exp(input - max), where \p max is the max of \p input
-/// on dimension \p dim.
-static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
- Value max, Value output, int64_t dim) {
- auto inputType = cast<ShapedType>(input.getType());
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputRank = inputShape.size();
- auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
- builder, inputRank, dim, /*allParallel=*/true);
- assert(indexingMaps.size() == 2 && "We should have one map for each input");
- assert(indexingMaps[0].isIdentity() && "input map should be identity");
- // Add the affine map for the output argument.
- indexingMaps.push_back(indexingMaps[0]);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
- iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
- Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
- Value result = b.create<math::ExpOp>(loc, diff);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the final step of the softmax
-/// decomposition.
-/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
-/// yield n / d
-/// }
-static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
- Value denominator, Value output, int64_t dim) {
- auto inputType = cast<ShapedType>(numerator.getType());
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputRank = inputShape.size();
- auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
- builder, inputRank, dim, /*allParallel=*/true);
- assert(indexingMaps.size() == 2 &&
- "We should have one map for each input (2)");
- assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
- // Add the affine map for the output tensor.
- indexingMaps.push_back(indexingMaps[0]);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, numerator.getType(), ValueRange{numerator, denominator}, output,
- indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
+/// Returns the attribute used to initialize the accumulator of a
+/// max-reduction over elements of `type`:
+/// - float types: the most negative finite value of the type;
+/// - integer types: the signed minimum value of the bit width.
+/// Returns a null attribute for any other element type.
+static TypedAttr createInitValueForReduceMaxOp(Type type, OpBuilder &b) {
+  if (auto floatType = dyn_cast<FloatType>(type)) {
+    // `APFloat::getSmallest` is the smallest *positive* denormal and is not a
+    // neutral element for max: every input below it would be replaced by it.
+    // Use the largest-magnitude negative finite value instead.
+    return b.getFloatAttr(
+        type, APFloat::getLargest(floatType.getFloatSemantics(),
+                                  /*Negative=*/true));
+  }
+  if (isa<IntegerType>(type))
+    return b.getIntegerAttr(
+        type, APInt::getSignedMinValue(type.getIntOrFloatBitWidth()));
+  return {};
+}
+
+/// Returns the attribute used to initialize the accumulator of a
+/// sum-reduction over elements of `type`: zero for float and integer types.
+/// Returns a null attribute for any other element type.
+static TypedAttr createInitValueForReduceSumOp(Type type, OpBuilder &b) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return b.getFloatAttr(type,
+                          APFloat::getZero(floatType.getFloatSemantics()));
+  if (isa<IntegerType>(type))
+    return b.getIntegerAttr(type, APInt::getZero(type.getIntOrFloatBitWidth()));
+  return {};
+}
+
+/// Creates, with builder `b` at location `loc`, the combiner of a
+/// max-reduction over `args[0]` and `args[1]`: `arith::MaxNumFOp` for float
+/// element types and `arith::MaxSIOp` for integer element types.
+/// Returns a null value for any other element type.
+///
+/// The builder is taken by reference (matching `createLinalgReduceSumBody`)
+/// so the ops are created through the caller's builder rather than a copy.
+static Value createLinalgReduceMaxBody(OpBuilder &b, Location loc,
+                                       ValueRange args, Type elementTy) {
+  if (isa<FloatType>(elementTy))
+    return b.create<arith::MaxNumFOp>(loc, args[0], args[1]);
+  if (isa<IntegerType>(elementTy))
+    return b.create<arith::MaxSIOp>(loc, args[0], args[1]);
+  return {};
+}
+
+/// Creates, with builder `b` at location `loc`, the combiner of a
+/// sum-reduction over `args[0]` and `args[1]`: `arith::AddFOp` for float
+/// element types and `arith::AddIOp` for integer element types.
+/// Returns a null value for any other element type.
+static Value createLinalgReduceSumBody(OpBuilder &b, Location loc,
+                                       ValueRange args, Type elementTy) {
+  if (isa<FloatType>(elementTy))
+    return b.create<arith::AddFOp>(loc, args[0], args[1]);
+  if (isa<IntegerType>(elementTy))
+    return b.create<arith::AddIOp>(loc, args[0], args[1]);
+  return {};
}
// @} End helper functions for softmax decomposition.
@@ -2695,7 +2620,7 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
/// 4. Divide z and l. This gives the N-dimensional softmax.
/// softmax = z / l
///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(*this);
Location loc = getLoc();
@@ -2706,32 +2631,60 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
Value output = getOutput();
dims.erase(dims.begin() + reductionDim);
+
// Step 1: Compute max along dim.
Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
- Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
- elementType, b, loc,
- /*useOnlyFiniteValue=*/true);
- Value neutralForMaxFInit =
- b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
- .result();
- Value max =
- reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+ auto maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+ auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+ auto neutralMaxInitOp = b.create<linalg::FillOp>(
+ loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+ Value neutralForMaxFInit = neutralMaxInitOp.result();
+
+ auto reduceMaxOp = b.create<linalg::ReduceOp>(
+ loc, input, neutralForMaxFInit, reductionDim,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ auto result =
+ createLinalgReduceMaxBody(b, nestedLoc, args, elementType);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+ });
// Step 2: Subtract max from input and exponentiate.
- Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+ auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+ loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+
+ auto subOp = b.create<linalg::SubOp>(
+ loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+ ValueRange{output});
+ auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+ ValueRange{output});
// Step 3: Compute sum along dim.
- Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
- b, loc, /*useOnlyFiniteValue=*/true);
- Value zeroInit =
- b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
- Value denominator =
- reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+ auto sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+ auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+ auto neutralSumInitOp = b.create<linalg::FillOp>(
+ loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+ auto sumFilledTensor = neutralSumInitOp.result();
+ auto reduceSumOp = b.create<linalg::ReduceOp>(
+ loc, expOp.getResults(), sumFilledTensor, reductionDim,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ auto result =
+ createLinalgReduceSumBody(b, nestedLoc, args, elementType);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+ });
// Step 4: Compute softmax.
- Value result =
- buildDivOp(b, loc, numerator, denominator, output, reductionDim);
- return SmallVector<Value>{result};
+ auto sumBcastOutput = b.create<tensor::EmptyOp>(
+ loc, getOutputOperandType().getShape(), elementType);
+ auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
+ loc, reduceSumOp.getResult(0), sumBcastOutput,
+ reduceSumOp.getDimensionsAttr());
+ auto divOp = b.create<linalg::DivOp>(
+ loc, ValueRange{expOp.getResult(0), sumBroadcastOp.getResults().front()},
+ ValueRange{output});
+ return DecompositionResult{{neutralMaxInitOp, reduceMaxOp, maxBroadcastOp,
+ subOp, expOp, neutralSumInitOp, reduceSumOp,
+ sumBroadcastOp, divOp},
+ {divOp.getResults().front()}};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4eb334f8bbbfa..555d28591104b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -431,27 +431,23 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
// Decompose the target operation if it implements the AggregatedOpInterface.
// Push the decomposed operations (the ones that replaces the values produced by
// \p target) in the `results`.
-DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
- transform::TransformRewriter &rewriter, Operation *target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
- if (!decomposableOp) {
- failed(rewriter.notifyMatchFailure(target,
- "payload is not a decomposable op"));
- return emitDefaultSilenceableFailure(target);
- }
+// Decomposes each payload op of `target` that implements
+// `AggregatedOpInterface`. The ops produced for the i-th payload op are
+// associated with the i-th result handle; payload ops that are not
+// decomposable are skipped and their handle is set to an empty list.
+DiagnosedSilenceableFailure
+transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       TransformResults &transformResults,
+                                       TransformState &state) {
+  for (auto [i, target] : llvm::enumerate(state.getPayloadOps(getTarget()))) {
+    // The results are variadic, so the number of handles is chosen by the
+    // user. Refuse to index past the declared results instead of relying on
+    // `getResult(i)` being in bounds.
+    if (i >= getOperation()->getNumResults())
+      return emitDefiniteFailure()
+             << "expected one result handle per payload op, but only "
+             << getOperation()->getNumResults() << " were provided";
+    auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+    if (!decomposableOp)
+      continue;
- FailureOr<SmallVector<Value>> maybeNewResults =
- decomposableOp.decomposeOperation(rewriter);
- if (failed(maybeNewResults))
- return emitDefaultSilenceableFailure(target);
+    FailureOr<DecompositionResult> maybeNewResults =
+        decomposableOp.decomposeOperation(rewriter);
+    if (failed(maybeNewResults))
+      return emitDefaultSilenceableFailure(target);
- rewriter.replaceOp(decomposableOp, *maybeNewResults);
- for (Value val : *maybeNewResults) {
- Operation *definition = val.getDefiningOp();
- if (definition)
- results.push_back(definition);
+    rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
+    transformResults.set(cast<OpResult>(getResult(i)),
+                         maybeNewResults->decomposedOps);
}
+  // Handles that were not assigned above (skipped non-decomposable payload
+  // ops, or more handles than payload ops) must still be set before success
+  // is reported.
+  transformResults.setRemainingToEmpty(
+      cast<transform::TransformOpInterface>(getOperation()));
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..68582fe6cbad2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ConvertConv2DToImg2Col.cpp
DataLayoutPropagation.cpp
DecomposeLinalgOps.cpp
+ DecomposeAggregateNamedLinalgOps.cpp
Detensorize.cpp
DropUnitDims.cpp
ElementwiseOpFusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
new file mode 100644
index 0000000000000..e8a5b96d54d34
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
@@ -0,0 +1,62 @@
+//===- DecomposeAggregateNamedLinalgOps.cpp - Decompose aggregate ops -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGDECOMPOSEAGGREGATENAMEDOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-decompose-named-ops"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Rewrite pattern that replaces a `SoftmaxOp` with the values produced by its
+/// own `decomposeOperation` implementation.
+struct DecomposeSoftmaxPattern : public OpRewritePattern<SoftmaxOp> {
+  using OpRewritePattern<SoftmaxOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SoftmaxOp softmaxOp,
+                                PatternRewriter &rewriter) const override {
+    // softmax(x) is rewritten as tmp = exp(x - max(x)); tmp / sum(tmp).
+    FailureOr<DecompositionResult> decomposition =
+        softmaxOp.decomposeOperation(rewriter);
+    if (succeeded(decomposition)) {
+      rewriter.replaceOp(softmaxOp, decomposition->decomposedValues);
+      return success();
+    }
+    return rewriter.notifyMatchFailure(softmaxOp,
+                                       "Failed to decompose SoftmaxOp");
+  }
+};
+
+} // namespace
+
+namespace {
+/// Pass wrapper around the aggregate named op decomposition patterns.
+/// Kept in an anonymous namespace: the class is only used in this file and
+/// must not be visible to (or clash with) other translation units.
+struct LinalgDecomposeAggregateNamedOpsPass
+    : public impl::LinalgDecomposeAggregateNamedOpsPassBase<
+          LinalgDecomposeAggregateNamedOpsPass> {
+  // Inherit the constructors of the TableGen-generated base class.
+  using impl::LinalgDecomposeAggregateNamedOpsPassBase<
+      LinalgDecomposeAggregateNamedOpsPass>::
+      LinalgDecomposeAggregateNamedOpsPassBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Collects the aggregate named op decomposition patterns and applies them
+/// greedily to the operation the pass runs on. The result of the greedy
+/// driver is intentionally discarded.
+void LinalgDecomposeAggregateNamedOpsPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet decompositionPatterns(ctx);
+  populateDecomposeAggregateNamedOpsPatterns(decompositionPatterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(),
+                                     std::move(decompositionPatterns));
+}
+
+/// Registers the patterns that decompose aggregate named ops (currently only
+/// `DecomposeSoftmaxPattern`) into `patterns`.
+void mlir::linalg::populateDecomposeAggregateNamedOpsPatterns(
+    RewritePatternSet &patterns) {
+  // `RewritePatternSet::insert` is soft-deprecated in favor of `add`.
+  patterns.add<DecomposeSoftmaxPattern>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
new file mode 100644
index 0000000000000..742a25b992b9e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s -split-input-file -linalg-decompose-named-ops | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -linalg-decompose-named-ops -linalg-generalize-named-ops | FileCheck %s --check-prefix=GENERALIZECHECK
+
+func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+ %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+ return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[INF:.+]] = arith.constant 1.401300e-45 : f32
+// CHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
+// CHECK-DAG: %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[RED:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<2x16x32xf32>)
+// CHECK-SAME: outs(%[[FIL]] : tensor<2x16xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[MAX:.+]] = arith.maxnumf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[MAX]] : f32
+// CHECK: %[[CST:.+]] = linalg.broadcast ins(%[[RED]] : tensor<2x16xf32>)
+// CHECK-NEXT: %[[SUB:.+]] = linalg.sub ins(%[[ARG0]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
+// CHECK-NEXT: %[[EXP:.+]] = linalg.exp ins(%[[SUB]] : tensor<2x16x32xf32>)
+// CHECK-DAG: %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[SUM:.+]] = linalg.reduce ins(%[[EXP]] : tensor<2x16x32xf32>)
+// CHECK-SAME: outs(%[[FIL]] : tensor<2x16xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-DAG: %[[CST2:.+]] = linalg.broadcast ins(%[[SUM]] : tensor<2x16xf32>)
+// CHECK-NEXT: %[[DIV:.+]] = linalg.div ins(%[[EXP]], %[[CST2]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[DST]] : tensor<2x16x32xf32>)
+// CHECK: return %[[DIV]]
+
+
+// GENERALIZECHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
+// GENERALIZECHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// GENERALIZECHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// GENERALIZECHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// GENERALIZECHECK-LABEL: func @softmax
+// GENERALIZECHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// GENERALIZECHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// GENERALIZECHECK-DAG: %[[INF:.+]] = arith.constant 1.401300e-45 : f32
+// GENERALIZECHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
+// GENERALIZECHECK-DAG: %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[INF]] : f32) outs(%[[EMP]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16xf32>
+// GENERALIZECHECK: %[[RED:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// GENERALIZECHECK-SAME: ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[FIL]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[MAX:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[MAX]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16xf32>
+// GENERALIZECHECK: %[[CST:.+]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP2]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[RED]] : tensor<2x16xf32>)
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK: %[[SUB:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP2]]]
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[ARG0]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
+// GENERALIZECHECK-SAME: outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[LHS:.+]]: f32, %[[RHS:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[SUBF:.+]] = arith.subf %[[LHS]], %[[RHS]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[SUBF]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK: %[[EXP:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[SUB]] : tensor<2x16x32xf32>)
+// GENERALIZECHECK-SAME: outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[EXPF:.+]] = math.exp %[[IN]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[EXPF]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
+// GENERALIZECHECK: %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[ZERO]] : f32) outs(%[[EMP]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16xf32>
+// GENERALIZECHECK: %[[RED:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// GENERALIZECHECK-SAME: ins(%[[EXP]] : tensor<2x16x32xf32>) outs(%[[FIL]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[ADDF:.+]] = arith.addf %[[IN]], %[[OUT]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[ADDF]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16xf32>
+// GENERALIZECHECK: %[[CST:.+]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP2]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[RED]] : tensor<2x16xf32>)
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK: %[[DIV:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP2]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[EXP]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
+// GENERALIZECHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[LHS:.+]]: f32, %[[RHS:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[DIVF:.+]] = arith.divf %[[LHS]], %[[RHS]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[DIVF]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK: return %[[DIV]] : tensor<2x16x32xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 2e211d2fa7dbe..7243616c77d4e 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -2,6 +2,8 @@
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @conv_2d_nhwc_hwcf
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
@@ -210,32 +212,62 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// CHECK-LABEL: func.func @softmax(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
-// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 1.401300e-45 : f32
+// CHECK: %[[D2:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x16xf32>
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16xf32>
-// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
-// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK: %[[BCST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[D3]] : tensor<2x16xf32>) outs(%[[DST]] : tensor<2x16x32xf32>)
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x16x32xf32>
+// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types =
+// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[BCST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
-// CHECK: %[[D9:.+]] = math.exp %[[D8]] : f32
-// CHECK: linalg.yield %[[D9]] : f32
+// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16x32xf32>
+// CHECK: %[[EXP:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[D4]] : tensor<2x16x32xf32>)
+// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: %[[EXPF:.+]] = math.exp %[[IN]] : f32
+// CHECK-NEXT: linalg.yield %[[EXPF]] : f32
+// CHECK-NEXT: } -> tensor<2x16x32xf32>
// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK: %[[D5:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x16xf32>
// CHECK: %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
+// CHECK-SAME: "parallel", "reduction"]} ins(%[[EXP]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16xf32>
-// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
-// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK: %[[EMP:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[CST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[D6]] : tensor<2x16xf32>) outs(%[[EMP]] : tensor<2x16x32xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x16x32xf32>
+// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types =
+// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[EXP]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
@@ -250,6 +282,8 @@ module attributes {transform.with_named_sequence} {
%2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
+ %4 = transform.structured.generalize %3: (!transform.any_op) -> !transform.any_op
+
transform.yield
}
}
>From 717f622b346aff974c262dad421e5273f1c205a3 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Wed, 3 Jul 2024 15:02:16 +0000
Subject: [PATCH 2/2] fixup! [MLIR][Linalg] Add aggregate ops decomposition
pass and softmax decomposition implementation
---
mlir/test/Dialect/Linalg/decompose-named-ops.mlir | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 742a25b992b9e..619956d3417a9 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -63,7 +63,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// GENERALIZECHECK: %[[SUB:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP2]]]
// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// GENERALIZECHECK-SAME: ins(%[[ARG0]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
-// GENERALIZECHECK-SAME: outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// GENERALIZECHECK-NEXT: ^bb0(%[[LHS:.+]]: f32, %[[RHS:.+]]: f32, %[[OUT:.+]]: f32):
// GENERALIZECHECK-NEXT: %[[SUBF:.+]] = arith.subf %[[LHS]], %[[RHS]] : f32
// GENERALIZECHECK-NEXT: linalg.yield %[[SUBF]] : f32
@@ -71,12 +71,11 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// GENERALIZECHECK: %[[EXP:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// GENERALIZECHECK-SAME: ins(%[[SUB]] : tensor<2x16x32xf32>)
-// GENERALIZECHECK-SAME: outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// GENERALIZECHECK-NEXT: %[[EXPF:.+]] = math.exp %[[IN]] : f32
// GENERALIZECHECK-NEXT: linalg.yield %[[EXPF]] : f32
// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
-// GENERALIZECHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
// GENERALIZECHECK: %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel"]}
// GENERALIZECHECK-SAME: ins(%[[ZERO]] : f32) outs(%[[EMP]] : tensor<2x16xf32>) {
More information about the Mlir-commits mailing list