[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decomposition implementation (PR #97582)
Petr Kurapov
llvmlistbot at llvm.org
Fri Jul 5 03:54:37 PDT 2024
https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/97582
>From 005330bd1da279cf9567db6dc79359127e820ab6 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 21 Jun 2024 11:53:36 +0000
Subject: [PATCH 1/9] [MLIR][Linalg] Add aggregate ops decomposition pass and
softmax decomposition implementation
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 10 +
.../Dialect/Linalg/IR/LinalgInterfaces.td | 2 +-
mlir/include/mlir/Dialect/Linalg/Passes.td | 5 +
.../Linalg/TransformOps/LinalgTransformOps.td | 18 +-
.../Dialect/Linalg/Transforms/Transforms.h | 15 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 211 +++++++-----------
.../TransformOps/LinalgTransformOps.cpp | 34 ++-
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../DecomposeAggregateNamedLinalgOps.cpp | 62 +++++
.../Dialect/Linalg/decompose-named-ops.mlir | 107 +++++++++
.../Linalg/transform-op-decompose.mlir | 54 ++++-
11 files changed, 345 insertions(+), 174 deletions(-)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
create mode 100644 mlir/test/Dialect/Linalg/decompose-named-ops.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014..3858075fae137 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -30,6 +30,16 @@ class IteratorTypeAttr;
class LinalgOp;
class GenericOp;
+/// Container for result values of decomposition.
+/// - `decomposedOps` contains operations created by the decomposition that are
+/// returned to the caller for further transformations. It need not list every
+/// op that was created (e.g. softmax omits its constants and empty tensors).
+/// - `decomposedValues` contains the values corresponding to the result of the
+/// aggregate operation, in result order, so that they can be passed directly
+/// to `RewriterBase::replaceOp(aggregateOp, decomposedValues)`.
+struct DecompositionResult {
+ /// Ops exposed to the caller (e.g. to populate transform-dialect handles).
+ SmallVector<Operation *> decomposedOps;
+ /// Replacement values, one per result of the aggregate op.
+ SmallVector<Value> decomposedValues;
+};
+
namespace detail {
/// Implementation of the method that check if given operands
/// can be dropped, i.e. the remaining operands can compute the loop
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..9b1ab20552628 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -862,7 +862,7 @@ def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
In other words, the returned vector can be used directly with
`RewriterBase::replaceOp(this, returnedValues)`.
}],
- /*retType=*/"FailureOr<SmallVector<Value>>",
+ /*retType=*/"FailureOr<DecompositionResult>",
/*methodName=*/"decomposeOperation",
/*args=*/(ins
"OpBuilder &":$b),
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0..98c23b97534f8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+// Greedily applies the patterns registered by
+// `populateDecomposeAggregateNamedOpsPatterns`. No anchor op is given, so the
+// pass may be scheduled on any operation.
+def LinalgDecomposeAggregateNamedOpsPass : Pass<"linalg-decompose-named-ops"> {
+ let summary = "Decompose complex named ops (e.g., Softmax) into a sequence of linalg named ops";
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 866275cedf68b..c059196b807a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1317,25 +1317,21 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- TODO
+ Decomposes high-level named ops into a sequence of non-aggregate named ops
+ via `AggregatedOpInterface`.
+
+ The operation ignores non-decomposable ops. The return handles point to
+ a sequence of named ops produced by the decomposition.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
+ // One result handle per target payload op: handle #i holds the ops that
+ // `decomposeOperation` returned for payload op #i.
- let results = (outs TransformHandleTypeInterface:$transformed);
+ let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// RewriteInDestinationPassingStyleOp.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2a58d02d7b704..2c9e201ccbd85 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1589,6 +1589,11 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` with patterns to decompose high-level aggregate named
+/// ops (e.g., softmax) into a sequence of simpler linalg named ops, defining
+/// the operation semantics.
+/// Each pattern replaces the aggregate op with the values returned by the op's
+/// `decomposeOperation` hook.
+void populateDecomposeAggregateNamedOpsPatterns(RewritePatternSet &patterns);
+
/// Linalg decompose convolutions patterns
/// Populates patterns to decompose high-D convolution ops into low-D ones.
@@ -1736,10 +1741,12 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);
/// Adds patterns that reduce the rank of named contraction ops that have
-/// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
-/// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example a
-/// `linalg.batch_matmul` with unit batch size will convert to `linalg.matmul`
-/// and a `linalg.matvec` with with unit spatial dim in lhs will convert to a `linalg.dot`.
+/// unit dimensions in the operand(s) by converting to a sequence of
+/// `collapse_shape`, `<corresponding linalg named op>`, `expand_shape` (if on
+/// tensors). For example a `linalg.batch_matmul` with unit batch size will
+/// convert to `linalg.matmul` and a `linalg.matvec` with a unit spatial dim in
+/// lhs will convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
} // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..383f285969ad7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2564,116 +2564,41 @@ void SoftmaxOp::getEffects(
// Helper functions for softmax decomposition.
// @{
-
-// Helper function to produce the iterator types (reduction or parallel) and
-// affine maps for the iterators used in the decomposition of softmax.
-// This method creates:
-// If allParallel == true:
-// - iterator type: {parallel, ..., parallel}
-// - affine maps:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-// where N == inputRank.
-//
-// If allParallel == false:
-// - iterator type at dim(i) == parallel for i != \p dim and
-// dim(dim) == reduction.
-// - affine map:
-// -- identity with inputRank dimensions.
-// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
-// where N == inputRank.
-static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
-computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
- int64_t dim, bool allParallel = false) {
- SmallVector<utils::IteratorType> iteratorTypes(inputRank,
- utils::IteratorType::parallel);
- if (!allParallel)
- iteratorTypes[dim] = utils::IteratorType::reduction;
- MLIRContext *ctxt = builder.getContext();
- auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
- SmallVector<AffineExpr, 2> affineExprs;
- for (int i = 0; i < inputRank; i++) {
- if (i != dim)
- affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
- }
- auto reductionMap =
- AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
- SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
- return std::make_tuple(iteratorTypes, indexingMaps);
-}
-
-// Helper function to produce a linalg.generic that computes a reduction on
-// dimension \p dim with the operation type \p T.
-template <typename T>
-static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
- int64_t dim) {
- auto inputType = cast<ShapedType>(input.getType());
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputRank = inputShape.size();
- auto [iteratorTypes, indexingMaps] =
- computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
- assert(indexingMaps.size() == 2 &&
- "We should have two maps: 1 for the input, 1 for the output");
- assert(indexingMaps[0].isIdentity() && "input map should be identity");
-
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, output.getType(), input, output, indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<T>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the second step of the softmax
-/// decomposition: res = exp(input - max), where \p max is the max of \p input
-/// on dimension \p dim.
-static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
- Value max, Value output, int64_t dim) {
- auto inputType = cast<ShapedType>(input.getType());
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputRank = inputShape.size();
- auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
- builder, inputRank, dim, /*allParallel=*/true);
- assert(indexingMaps.size() == 2 && "We should have one map for each input");
- assert(indexingMaps[0].isIdentity() && "input map should be identity");
- // Add the affine map for the output argument.
- indexingMaps.push_back(indexingMaps[0]);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
- iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
- Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
- Value result = b.create<math::ExpOp>(loc, diff);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
-}
-
-/// Produce a linalg generic that computes the final step of the softmax
-/// decomposition.
-/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
-/// yield n / d
-/// }
-static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
- Value denominator, Value output, int64_t dim) {
- auto inputType = cast<ShapedType>(numerator.getType());
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputRank = inputShape.size();
- auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
- builder, inputRank, dim, /*allParallel=*/true);
- assert(indexingMaps.size() == 2 &&
- "We should have one map for each input (2)");
- assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
- // Add the affine map for the output tensor.
- indexingMaps.push_back(indexingMaps[0]);
- auto genericOp = builder.create<linalg::GenericOp>(
- loc, numerator.getType(), ValueRange{numerator, denominator}, output,
- indexingMaps, iteratorTypes,
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
- b.create<linalg::YieldOp>(loc, result);
- });
- return genericOp.getResult(0);
+/// Returns the initial (neutral) value of a max-reduction over `type`: the
+/// most negative finite value of a float type, or the signed minimum of an
+/// integer type. Returns a null attribute for any other type.
+static TypedAttr createInitValueForReduceMaxOp(Type type, OpBuilder &b) {
+  if (auto floatType = dyn_cast<FloatType>(type)) {
+    // `APFloat::getSmallest` is the smallest *positive* denormal (~1.4e-45 for
+    // f32). It is not neutral for max: the max of an all-negative slice would
+    // come out as that tiny positive value. Use the most negative finite value
+    // instead, as the previous
+    // `arith::getIdentityValue(maximumf, ..., /*useOnlyFiniteValue=*/true)`
+    // did (-3.40282347E+38 for f32).
+    return b.getFloatAttr(type,
+                          APFloat::getLargest(floatType.getFloatSemantics(),
+                                              /*Negative=*/true));
+  }
+  if (isa<IntegerType>(type))
+    return b.getIntegerAttr(
+        type, APInt::getSignedMinValue(type.getIntOrFloatBitWidth()));
+  return {};
+}
+
+/// Returns the initial (neutral) value of a sum-reduction over `type`: zero of
+/// the matching float or integer type. Returns a null attribute for any other
+/// type.
+static TypedAttr createInitValueForReduceSumOp(Type type, OpBuilder &b) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return b.getFloatAttr(type,
+                          APFloat::getZero(floatType.getFloatSemantics()));
+  if (isa<IntegerType>(type))
+    return b.getIntegerAttr(type, APInt::getZero(type.getIntOrFloatBitWidth()));
+  return {};
+}
+
+/// Emits the combiner of a max-reduction body at `b`'s insertion point:
+/// `arith.maxnumf` for floats, `arith.maxsi` for integers. Returns a null
+/// value for any other element type.
+static Value createLinalgReduceMaxBody(OpBuilder &b, Location loc,
+                                       ValueRange args, Type elementTy) {
+  if (isa<FloatType>(elementTy))
+    return b.create<arith::MaxNumFOp>(loc, args[0], args[1]);
+  if (isa<IntegerType>(elementTy))
+    return b.create<arith::MaxSIOp>(loc, args[0], args[1]);
+  return {};
+}
+
+/// Emits the combiner of a sum-reduction body at `b`'s insertion point:
+/// `arith.addf` for floats, `arith.addi` for integers. Returns a null value
+/// for any other element type.
+static Value createLinalgReduceSumBody(OpBuilder &b, Location loc,
+                                       ValueRange args, Type elementTy) {
+  if (isa<FloatType>(elementTy))
+    return b.create<arith::AddFOp>(loc, args[0], args[1]);
+  if (isa<IntegerType>(elementTy))
+    return b.create<arith::AddIOp>(loc, args[0], args[1]);
+  return {};
}
// @} End helper functions for softmax decomposition.
@@ -2695,7 +2620,7 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
/// 4. Divide z and l. This gives the N-dimensional softmax.
/// softmax = z / l
///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(*this);
Location loc = getLoc();
@@ -2706,32 +2631,60 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
Value output = getOutput();
dims.erase(dims.begin() + reductionDim);
+
+  // Compute both neutral values up front and bail out before creating any new
+  // op. The helpers return a null attribute for element types they do not
+  // support, and an arith.constant cannot be built from a null attribute.
+  TypedAttr maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+  TypedAttr sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+  if (!maxFillValAttr || !sumFillValAttr)
+    return failure();
+
// Step 1: Compute max along dim.
Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
- Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
- elementType, b, loc,
- /*useOnlyFiniteValue=*/true);
- Value neutralForMaxFInit =
- b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
- .result();
- Value max =
- reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+  auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+  auto neutralMaxInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+  Value neutralForMaxFInit = neutralMaxInitOp.result();
+
+  auto reduceMaxOp = b.create<linalg::ReduceOp>(
+      loc, input, neutralForMaxFInit, reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        // Build the combiner with the region builder, not the outer `b`.
+        Value result = createLinalgReduceMaxBody(nestedBuilder, nestedLoc,
+                                                 args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
// Step 2: Subtract max from input and exponentiate.
- Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+  // Broadcast the reduced max back along the reduction dimension so that it
+  // can feed the elementwise linalg.sub.
+  auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+      loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+
+  auto subOp = b.create<linalg::SubOp>(
+      loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+      ValueRange{output});
+  auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+                                       ValueRange{output});
// Step 3: Compute sum along dim.
- Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
- b, loc, /*useOnlyFiniteValue=*/true);
- Value zeroInit =
- b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
- Value denominator =
- reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+  auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+  auto neutralSumInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+  Value sumFilledTensor = neutralSumInitOp.result();
+  auto reduceSumOp = b.create<linalg::ReduceOp>(
+      loc, expOp.getResults(), sumFilledTensor, reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value result = createLinalgReduceSumBody(nestedBuilder, nestedLoc,
+                                                 args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
// Step 4: Compute softmax.
- Value result =
- buildDivOp(b, loc, numerator, denominator, output, reductionDim);
- return SmallVector<Value>{result};
+  // Build the broadcast destination from mixed sizes: when `output` has
+  // dynamic dimensions, tensor.empty must receive their runtime sizes, which a
+  // static shape alone cannot express.
+  Value sumBcastOutput = b.create<tensor::EmptyOp>(
+      loc, tensor::getMixedSizes(b, loc, output), elementType);
+  auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
+      loc, reduceSumOp.getResult(0), sumBcastOutput,
+      reduceSumOp.getDimensionsAttr());
+  auto divOp = b.create<linalg::DivOp>(
+      loc, ValueRange{expOp.getResult(0), sumBroadcastOp.getResults().front()},
+      ValueRange{output});
+  return DecompositionResult{{neutralMaxInitOp, reduceMaxOp, maxBroadcastOp,
+                              subOp, expOp, neutralSumInitOp, reduceSumOp,
+                              sumBroadcastOp, divOp},
+                             {divOp.getResults().front()}};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4eb334f8bbbfa..555d28591104b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -431,27 +431,23 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
-// Decompose the target operation if it implements the AggregatedOpInterface.
-// Push the decomposed operations (the ones that replaces the values produced by
-// \p target) in the `results`.
-DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
- transform::TransformRewriter &rewriter, Operation *target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
- auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
- if (!decomposableOp) {
- failed(rewriter.notifyMatchFailure(target,
- "payload is not a decomposable op"));
- return emitDefaultSilenceableFailure(target);
- }
- FailureOr<SmallVector<Value>> maybeNewResults =
- decomposableOp.decomposeOperation(rewriter);
- if (failed(maybeNewResults))
- return emitDefaultSilenceableFailure(target);
- rewriter.replaceOp(decomposableOp, *maybeNewResults);
- for (Value val : *maybeNewResults) {
- Operation *definition = val.getDefiningOp();
- if (definition)
- results.push_back(definition);
+// Decompose each target payload op that implements AggregatedOpInterface and
+// map result handle #i to the ops produced by decomposing payload op #i.
+// Payload ops that do not implement the interface are left untouched, and
+// their result handle is set to an empty list.
+DiagnosedSilenceableFailure
+transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       TransformResults &transformResults,
+                                       TransformState &state) {
+  // Snapshot the payload first. replaceOp below updates the transform state's
+  // op mapping, which must not change while that mapping is being iterated.
+  SmallVector<Operation *> targets =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  // Result #i corresponds to target #i, and every declared result must be
+  // set, so the two counts have to agree.
+  unsigned numResults = getOperation()->getNumResults();
+  if (targets.size() != numResults) {
+    return emitSilenceableError()
+           << "expected " << numResults
+           << " target payload op(s), one per result handle, but got "
+           << targets.size();
+  }
+
+  for (auto [i, target] : llvm::enumerate(targets)) {
+    OpResult resultHandle = getOperation()->getResult(i);
+    auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
+    if (!decomposableOp) {
+      transformResults.set(resultHandle, ArrayRef<Operation *>());
+      continue;
+    }
+
+    FailureOr<DecompositionResult> maybeNewResults =
+        decomposableOp.decomposeOperation(rewriter);
+    if (failed(maybeNewResults))
+      return emitDefaultSilenceableFailure(target);
+
+    rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
+    transformResults.set(resultHandle, maybeNewResults->decomposedOps);
  }
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..68582fe6cbad2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ConvertConv2DToImg2Col.cpp
DataLayoutPropagation.cpp
DecomposeLinalgOps.cpp
+ DecomposeAggregateNamedLinalgOps.cpp
Detensorize.cpp
DropUnitDims.cpp
ElementwiseOpFusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
new file mode 100644
index 0000000000000..e8a5b96d54d34
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeAggregateNamedLinalgOps.cpp
@@ -0,0 +1,62 @@
+//===- DecomposeAggregateNamedLinalgOps.cpp - Decompose aggregate ops -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGDECOMPOSEAGGREGATENAMEDOPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-decompose-named-ops"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Rewrites `linalg.softmax` into the ops produced by
+/// `SoftmaxOp::decomposeOperation`, then replaces the softmax results with
+/// the decomposition's values.
+struct DecomposeSoftmaxPattern : public OpRewritePattern<SoftmaxOp> {
+  using OpRewritePattern<SoftmaxOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SoftmaxOp op,
+                                PatternRewriter &rewriter) const override {
+    // Decompose softmax(x) into tmp = exp(x - max(x)); tmp / sum(tmp)
+    FailureOr<DecompositionResult> results = op.decomposeOperation(rewriter);
+    if (failed(results))
+      return rewriter.notifyMatchFailure(op, "Failed to decompose SoftmaxOp");
+    rewriter.replaceOp(op, results->decomposedValues);
+    return success();
+  }
+};
+
+/// Pass driver for `populateDecomposeAggregateNamedOpsPatterns`. It lives in
+/// the anonymous namespace so that, like the pattern above, the class has
+/// internal linkage.
+struct LinalgDecomposeAggregateNamedOpsPass
+    : public impl::LinalgDecomposeAggregateNamedOpsPassBase<
+          LinalgDecomposeAggregateNamedOpsPass> {
+  using impl::LinalgDecomposeAggregateNamedOpsPassBase<
+      LinalgDecomposeAggregateNamedOpsPass>::
+      LinalgDecomposeAggregateNamedOpsPassBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void LinalgDecomposeAggregateNamedOpsPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet decompositionPatterns(context);
+  populateDecomposeAggregateNamedOpsPatterns(decompositionPatterns);
+  // The greedy driver's convergence status is deliberately ignored: the
+  // decomposition is applied best-effort and never fails the pass.
+  (void)applyPatternsAndFoldGreedily(getOperation(),
+                                     std::move(decompositionPatterns));
+}
+
+/// Registers the aggregate-op decomposition patterns. This is currently only
+/// the `linalg.softmax` pattern.
+void mlir::linalg::populateDecomposeAggregateNamedOpsPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<DecomposeSoftmaxPattern>(ctx);
+}
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
new file mode 100644
index 0000000000000..742a25b992b9e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s -split-input-file -linalg-decompose-named-ops | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -linalg-decompose-named-ops -linalg-generalize-named-ops | FileCheck %s --check-prefix=GENERALIZECHECK
+
+// Softmax over dimension 2 (the innermost one) of a static 2x16x32 f32 tensor.
+func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+ %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+ return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[INF:.+]] = arith.constant 1.401300e-45 : f32
+// CHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
+// CHECK-DAG: %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[RED:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<2x16x32xf32>)
+// CHECK-SAME: outs(%[[FIL]] : tensor<2x16xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[MAX:.+]] = arith.maxnumf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[MAX]] : f32
+// CHECK: %[[CST:.+]] = linalg.broadcast ins(%[[RED]] : tensor<2x16xf32>)
+// CHECK-NEXT: %[[SUB:.+]] = linalg.sub ins(%[[ARG0]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
+// CHECK-NEXT: %[[EXP:.+]] = linalg.exp ins(%[[SUB]] : tensor<2x16x32xf32>)
+// CHECK-DAG: %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[SUM:.+]] = linalg.reduce ins(%[[EXP]] : tensor<2x16x32xf32>)
+// CHECK-SAME: outs(%[[FIL]] : tensor<2x16xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-DAG: %[[CST2:.+]] = linalg.broadcast ins(%[[SUM]] : tensor<2x16xf32>)
+// CHECK-NEXT: %[[DIV:.+]] = linalg.div ins(%[[EXP]], %[[CST2]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[DST]] : tensor<2x16x32xf32>)
+// CHECK: return %[[DIV]]
+
+
+// GENERALIZECHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
+// GENERALIZECHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// GENERALIZECHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// GENERALIZECHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// GENERALIZECHECK-LABEL: func @softmax
+// GENERALIZECHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// GENERALIZECHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// GENERALIZECHECK-DAG: %[[INF:.+]] = arith.constant 1.401300e-45 : f32
+// GENERALIZECHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
+// GENERALIZECHECK-DAG: %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[INF]] : f32) outs(%[[EMP]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16xf32>
+// GENERALIZECHECK: %[[RED:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// GENERALIZECHECK-SAME: ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[FIL]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[MAX:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[MAX]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16xf32>
+// GENERALIZECHECK: %[[CST:.+]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP2]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[RED]] : tensor<2x16xf32>)
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK: %[[SUB:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP2]]]
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[ARG0]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
+// GENERALIZECHECK-SAME: outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[LHS:.+]]: f32, %[[RHS:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[SUBF:.+]] = arith.subf %[[LHS]], %[[RHS]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[SUBF]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK: %[[EXP:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[SUB]] : tensor<2x16x32xf32>)
+// GENERALIZECHECK-SAME: outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[EXPF:.+]] = math.exp %[[IN]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[EXPF]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
+// GENERALIZECHECK: %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[ZERO]] : f32) outs(%[[EMP]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16xf32>
+// GENERALIZECHECK: %[[RED:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// GENERALIZECHECK-SAME: ins(%[[EXP]] : tensor<2x16x32xf32>) outs(%[[FIL]] : tensor<2x16xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[ADDF:.+]] = arith.addf %[[IN]], %[[OUT]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[ADDF]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16xf32>
+// GENERALIZECHECK: %[[CST:.+]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP2]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[RED]] : tensor<2x16xf32>)
+// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: linalg.yield %[[IN]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK: %[[DIV:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP2]]],
+// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// GENERALIZECHECK-SAME: ins(%[[EXP]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
+// GENERALIZECHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-NEXT: ^bb0(%[[LHS:.+]]: f32, %[[RHS:.+]]: f32, %[[OUT:.+]]: f32):
+// GENERALIZECHECK-NEXT: %[[DIVF:.+]] = arith.divf %[[LHS]], %[[RHS]] : f32
+// GENERALIZECHECK-NEXT: linalg.yield %[[DIVF]] : f32
+// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
+// GENERALIZECHECK: return %[[DIV]] : tensor<2x16x32xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 2e211d2fa7dbe..7243616c77d4e 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -2,6 +2,8 @@
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @conv_2d_nhwc_hwcf
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
@@ -210,32 +212,62 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// CHECK-LABEL: func.func @softmax(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
-// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 1.401300e-45 : f32
+// CHECK: %[[D2:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x16xf32>
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8:.+]] = arith.maxnumf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16xf32>
-// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
-// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK: %[[BCST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[D3]] : tensor<2x16xf32>) outs(%[[DST]] : tensor<2x16x32xf32>)
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x16x32xf32>
+// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types =
+// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[BCST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
-// CHECK: %[[D9:.+]] = math.exp %[[D8]] : f32
-// CHECK: linalg.yield %[[D9]] : f32
+// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16x32xf32>
+// CHECK: %[[EXP:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[D4]] : tensor<2x16x32xf32>)
+// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: %[[EXPF:.+]] = math.exp %[[IN]] : f32
+// CHECK-NEXT: linalg.yield %[[EXPF]] : f32
+// CHECK-NEXT: } -> tensor<2x16x32xf32>
// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
+// CHECK: %[[D5:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x16xf32>
// CHECK: %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
-// CHECK-SAME: "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
+// CHECK-SAME: "parallel", "reduction"]} ins(%[[EXP]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16xf32>
-// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
-// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
+// CHECK: %[[EMP:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[CST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[D6]] : tensor<2x16xf32>) outs(%[[EMP]] : tensor<2x16x32xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK-NEXT: linalg.yield %[[IN]] : f32
+// CHECK-NEXT: } -> tensor<2x16x32xf32>
+// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types =
+// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[EXP]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
@@ -250,6 +282,8 @@ module attributes {transform.with_named_sequence} {
%2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
+ %4 = transform.structured.generalize %3: (!transform.any_op) -> !transform.any_op
+
transform.yield
}
}
>From 717f622b346aff974c262dad421e5273f1c205a3 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Wed, 3 Jul 2024 15:02:16 +0000
Subject: [PATCH 2/9] fixup! [MLIR][Linalg] Add aggregate ops decomposition
pass and softmax decomposition implementation
---
mlir/test/Dialect/Linalg/decompose-named-ops.mlir | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 742a25b992b9e..619956d3417a9 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -63,7 +63,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// GENERALIZECHECK: %[[SUB:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP2]]]
// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// GENERALIZECHECK-SAME: ins(%[[ARG0]], %[[CST]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>)
-// GENERALIZECHECK-SAME: outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// GENERALIZECHECK-NEXT: ^bb0(%[[LHS:.+]]: f32, %[[RHS:.+]]: f32, %[[OUT:.+]]: f32):
// GENERALIZECHECK-NEXT: %[[SUBF:.+]] = arith.subf %[[LHS]], %[[RHS]] : f32
// GENERALIZECHECK-NEXT: linalg.yield %[[SUBF]] : f32
@@ -71,12 +71,11 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// GENERALIZECHECK: %[[EXP:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// GENERALIZECHECK-SAME: ins(%[[SUB]] : tensor<2x16x32xf32>)
-// GENERALIZECHECK-SAME: outs(%[[CST]] : tensor<2x16x32xf32>) {
+// GENERALIZECHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// GENERALIZECHECK-NEXT: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// GENERALIZECHECK-NEXT: %[[EXPF:.+]] = math.exp %[[IN]] : f32
// GENERALIZECHECK-NEXT: linalg.yield %[[EXPF]] : f32
// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
-// GENERALIZECHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
// GENERALIZECHECK: %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel"]}
// GENERALIZECHECK-SAME: ins(%[[ZERO]] : f32) outs(%[[EMP]] : tensor<2x16xf32>) {
>From 87fb531438a3aba60f4bf98e5c0bf775a0ffbf13 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Wed, 3 Jul 2024 19:30:25 +0000
Subject: [PATCH 3/9] skip temporary
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 383f285969ad7..363bf678cfa13 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2638,10 +2638,9 @@ FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
auto neutralMaxInitOp = b.create<linalg::FillOp>(
loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
- Value neutralForMaxFInit = neutralMaxInitOp.result();
auto reduceMaxOp = b.create<linalg::ReduceOp>(
- loc, input, neutralForMaxFInit, reductionDim,
+ loc, input, neutralMaxInitOp.result(), reductionDim,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
auto result =
createLinalgReduceMaxBody(b, nestedLoc, args, elementType);
>From e24372af65145e0c604ffcd71c761e519f22c28c Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Wed, 3 Jul 2024 19:54:14 +0000
Subject: [PATCH 4/9] Fix float min constant
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 ++-
mlir/test/Dialect/Linalg/decompose-named-ops.mlir | 4 ++--
mlir/test/Dialect/Linalg/transform-op-decompose.mlir | 2 +-
3 files changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 363bf678cfa13..c35356dadcc15 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2567,7 +2567,8 @@ void SoftmaxOp::getEffects(
TypedAttr createInitValueForReduceMaxOp(Type type, OpBuilder &b) {
if (isa<FloatType>(type))
return b.getFloatAttr(
- type, APFloat::getSmallest(cast<FloatType>(type).getFloatSemantics()));
+ type, APFloat::getLargest(cast<FloatType>(type).getFloatSemantics(),
+ /*Negative=*/true));
if (isa<IntegerType>(type))
return b.getIntegerAttr(
type, APInt::getSignedMinValue(type.getIntOrFloatBitWidth()));
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 619956d3417a9..9d36e194413ad 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -8,7 +8,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[INF:.+]] = arith.constant 1.401300e-45 : f32
+// CHECK-DAG: %[[INF:.+]] = arith.constant -3.40282347E+38 : f32
// CHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
// CHECK-DAG: %[[FIL:.+]] = linalg.fill
// CHECK-NEXT: %[[RED:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<2x16x32xf32>)
@@ -39,7 +39,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// GENERALIZECHECK-LABEL: func @softmax
// GENERALIZECHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// GENERALIZECHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
-// GENERALIZECHECK-DAG: %[[INF:.+]] = arith.constant 1.401300e-45 : f32
+// GENERALIZECHECK-DAG: %[[INF:.+]] = arith.constant -3.40282347E+38 : f32
// GENERALIZECHECK-DAG: %[[EMP:.+]] = tensor.empty() : tensor<2x16xf32>
// GENERALIZECHECK-DAG: %[[FIL:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
// GENERALIZECHECK-SAME: iterator_types = ["parallel", "parallel"]}
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 7243616c77d4e..09f738a876ae2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -212,7 +212,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// CHECK-LABEL: func.func @softmax(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG: %[[CST:.+]] = arith.constant 1.401300e-45 : f32
+// CHECK-DAG: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
// CHECK: %[[D2:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) {
>From a75b13595219e2056197ecea5fb9cf74c1f8144b Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 4 Jul 2024 09:58:32 +0000
Subject: [PATCH 5/9] Fail gracefully on memrefs
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 5 +++++
mlir/test/Dialect/Linalg/decompose-named-ops.mlir | 7 +++++++
2 files changed, 12 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c35356dadcc15..c7a7dec60cb47 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2622,6 +2622,11 @@ Value createLinalgReduceSumBody(OpBuilder &b, Location loc, ValueRange args,
/// softmax = z / l
///
FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+ if (!isa<RankedTensorType>(getInput().getType())) {
+ // The decomposition assumes ranked tensors as input
+ return failure();
+ }
+
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(*this);
Location loc = getLoc();
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 9d36e194413ad..148a855edb657 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -104,3 +104,10 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// GENERALIZECHECK-NEXT: linalg.yield %[[DIVF]] : f32
// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
// GENERALIZECHECK: return %[[DIV]] : tensor<2x16x32xf32>
+
+// COM: decomposition assumes tensors as inputs, this is just to make sure nothing breaks
+func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256xf32>) {
+ linalg.softmax
+ dimension(1) ins(%arg0 : memref<16x64x256xf32>) outs(%arg1 : memref<16x64x256xf32>)
+ return
+}
>From 99260403f2a9829ac306733bf014af85ef2c53a6 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 4 Jul 2024 10:17:56 +0000
Subject: [PATCH 6/9] Bind all decomposed ops to the same single result of
decompose interface op
---
.../Dialect/Linalg/TransformOps/LinalgTransformOps.td | 7 ++++---
.../Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 8 +++++---
2 files changed, 9 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c059196b807a7..cbebec366973c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1323,12 +1323,13 @@ def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface
Decomposes high-level named ops into a sequence of non-aggregate named ops
via `AggregatedOpInterface`.
- The operation ignores non-decomposable ops. The return handles point to
- a sequence of named ops produced by the decomposition.
+ The operation ignores non-decomposable ops. The return handle points to
+ a sequence of named ops produced by all decompositions (i.e. the
+ information about decomposed op origin is lost).
}];
let arguments = (ins TransformHandleTypeInterface:$target);
- let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+ let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 555d28591104b..acba9b02e3729 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -430,11 +430,13 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
// Decompose the target operation if it implements the AggregatedOpInterface.
// Push the decomposed operations (the ones that replaces the values produced by
-// \p target) in the `results`.
+// \p target) in the `results`. Decompositions for all targets bind to the same
+// single output value, thus the information about the original targets is lost.
DiagnosedSilenceableFailure
transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) {
+ SmallVector<Operation *> allDecomposedOps;
for (auto [i, target] : llvm::enumerate(state.getPayloadOps(getTarget()))) {
auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
if (!decomposableOp)
@@ -446,9 +448,9 @@ transform::DecomposeInterfaceOp::apply(transform::TransformRewriter &rewriter,
return emitDefaultSilenceableFailure(target);
rewriter.replaceOp(decomposableOp, maybeNewResults->decomposedValues);
- transformResults.set(cast<OpResult>(getResult(i)),
- maybeNewResults->decomposedOps);
+ allDecomposedOps.append(maybeNewResults->decomposedOps);
}
+ transformResults.set(cast<OpResult>(getResult()), allDecomposedOps);
return DiagnosedSilenceableFailure::success();
}
>From d3f7899042b8d42dcfe0a37128bedfb875c9182c Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 4 Jul 2024 10:35:10 +0000
Subject: [PATCH 7/9] Check for pure tensor semantics instead of input type
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c7a7dec60cb47..77b9ebcc5bf3a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2622,7 +2622,7 @@ Value createLinalgReduceSumBody(OpBuilder &b, Location loc, ValueRange args,
/// softmax = z / l
///
FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
- if (!isa<RankedTensorType>(getInput().getType())) {
+ if (!hasPureTensorSemantics()) {
// The decomposition assumes ranked tensors as input
return failure();
}
>From 593bc911dbd159fc85d73f3329acf9777d9b33c3 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 4 Jul 2024 14:21:28 +0000
Subject: [PATCH 8/9] Add basic check as a smoke test for softmax_memref
---
mlir/test/Dialect/Linalg/decompose-named-ops.mlir | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 148a855edb657..52a51eab2073a 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -105,9 +105,13 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// GENERALIZECHECK-NEXT: } -> tensor<2x16x32xf32>
// GENERALIZECHECK: return %[[DIV]] : tensor<2x16x32xf32>
+// -----
+
// COM: decomposition assumes tensors as inputs, this is just to make sure nothing breaks
func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256xf32>) {
+// CHECK-LABEL: @softmax_memref
linalg.softmax
dimension(1) ins(%arg0 : memref<16x64x256xf32>) outs(%arg1 : memref<16x64x256xf32>)
+// CHECK-NEXT: linalg.softmax {{.*}}
return
}
>From 6201b553db0adbd8bcfdd791e3fb3fda316c3ef3 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Fri, 5 Jul 2024 10:54:21 +0000
Subject: [PATCH 9/9] Add dynamic shapes test case
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 22 ++++++++--
.../Dialect/Linalg/decompose-named-ops.mlir | 40 +++++++++++++++++++
2 files changed, 58 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 77b9ebcc5bf3a..801f766b22d62 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2634,12 +2634,21 @@ FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
ShapedType inputType = getInputOperandType();
Type elementType = inputType.getElementType();
int64_t reductionDim = getDimension();
- SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
Value output = getOutput();
- dims.erase(dims.begin() + reductionDim);
+
+ SmallVector<int64_t> reduceShape;
+ SmallVector<Value> dynReduceDims;
+ for (unsigned i = 0; i < inputType.getRank(); i++) {
+ if (reductionDim != i) {
+ reduceShape.push_back(inputType.getDimSize(i));
+ if (inputType.isDynamicDim(i))
+ dynReduceDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+ }
+ }
// Step 1: Compute max along dim.
- Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
+ Value outputReduce =
+ b.create<tensor::EmptyOp>(loc, reduceShape, elementType, dynReduceDims);
auto maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
auto neutralMaxInitOp = b.create<linalg::FillOp>(
@@ -2678,8 +2687,13 @@ FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
});
// Step 4: Compute softmax.
+ SmallVector<Value> dynDims;
+ for (unsigned i = 0; i < inputType.getRank(); i++) {
+ if (inputType.isDynamicDim(i))
+ dynDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+ }
auto sumBcastOutput = b.create<tensor::EmptyOp>(
- loc, getOutputOperandType().getShape(), elementType);
+ loc, getOutputOperandType().getShape(), elementType, dynDims);
auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
loc, reduceSumOp.getResult(0), sumBcastOutput,
reduceSumOp.getDimensionsAttr());
diff --git a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
index 52a51eab2073a..8c00feea56192 100644
--- a/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-named-ops.mlir
@@ -115,3 +115,43 @@ func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256x
// CHECK-NEXT: linalg.softmax {{.*}}
return
}
+
+
+// -----
+
+func.func @softmax_dynamic_shapes(%arg0: tensor<?x?x?xf32>, %dst: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<?x?x?xf32>) outs(%dst: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+
+// CHECK: func.func @softmax_dynamic_shapes(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[INF:.+]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG: %[[CST_0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST_1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[CST_2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST_0]] : tensor<?x?x?xf32>
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST_1]] : tensor<?x?x?xf32>
+// CHECK-DAG: %[[EMP:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
+// CHECK-DAG: %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[RED:.+]] = linalg.reduce ins(%[[ARG0]] : tensor<?x?x?xf32>)
+// CHECK-SAME: outs(%[[FIL]] : tensor<?x?xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[MAX:.+]] = arith.maxnumf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[MAX]] : f32
+// CHECK: %[[CST:.+]] = linalg.broadcast ins(%[[RED]] : tensor<?x?xf32>)
+// CHECK-NEXT: %[[SUB:.+]] = linalg.sub ins(%[[ARG0]], %[[CST]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK-NEXT: %[[EXP:.+]] = linalg.exp ins(%[[SUB]] : tensor<?x?x?xf32>)
+// CHECK-DAG: %[[FIL:.+]] = linalg.fill
+// CHECK-NEXT: %[[SUM:.+]] = linalg.reduce ins(%[[EXP]] : tensor<?x?x?xf32>)
+// CHECK-SAME: outs(%[[FIL]] : tensor<?x?xf32>) dimensions = [2]
+// CHECK-NEXT: (%[[IN:.+]]: f32, %[[INIT:.+]]: f32) {
+// CHECK-NEXT: %[[ADD:.+]] = arith.addf %[[IN]], %[[INIT]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST_0]] : tensor<?x?x?xf32>
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST_1]] : tensor<?x?x?xf32>
+// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[CST_2]] : tensor<?x?x?xf32>
+// CHECK-DAG: %[[EMP:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]]) : tensor<?x?x?xf32>
+// CHECK-DAG: %[[CST2:.+]] = linalg.broadcast ins(%[[SUM]] : tensor<?x?xf32>)
+// CHECK-NEXT: %[[DIV:.+]] = linalg.div ins(%[[EXP]], %[[CST2]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[DST]] : tensor<?x?x?xf32>)
+// CHECK: return %[[DIV]]
More information about the Mlir-commits
mailing list