[Mlir-commits] [mlir] 2d9b910 - [mlir][Linalg] Add a softmax op
Quentin Colombet
llvmlistbot at llvm.org
Thu Jun 29 03:57:35 PDT 2023
Author: Quentin Colombet
Date: 2023-06-29T12:57:06+02:00
New Revision: 2d9b9103b45e42feb24058fb1f8e615fdba6ae5c
URL: https://github.com/llvm/llvm-project/commit/2d9b9103b45e42feb24058fb1f8e615fdba6ae5c
DIFF: https://github.com/llvm/llvm-project/commit/2d9b9103b45e42feb24058fb1f8e615fdba6ae5c.diff
LOG: [mlir][Linalg] Add a softmax op
This patch adds a softmax op.
For now, nothing interesting happens; the op can only round-trip through parsing and printing.
Later patches will add the tiling interface and the lowering of this op to
a sequence of simpler ops.
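For reference, the new op round-trips with the following syntax (this
mirrors the roundtrip test added below; %input stands for any value of
the matching tensor type):

  %empty = tensor.empty() : tensor<2x16x32xf32>
  %result = linalg.softmax dimension(2)
      ins(%input : tensor<2x16x32xf32>)
      outs(%empty : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>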
This graduates the linalg_ext.softmax op from IREE to LLVM.
Original implementation from Harsh Menon <harsh at nod-labs.com>
Nicolas Vasilache <nicolas.vasilache at gmail.com> co-authored this patch.
Differential Revision: https://reviews.llvm.org/D153422
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index d9c1eec9ea9593..43b86cda281e75 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -15,6 +15,7 @@
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -87,4 +88,62 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
let hasVerifier = 1;
}
+def Linalg_SoftmaxOp : Linalg_Op<"softmax",
+ [DestinationStyleOpInterface,
+ PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ let summary = "Softmax operator";
+ let description = [{
+ linalg.softmax computes a numerically stable version of softmax.
+
+  For a given input tensor and a specified dimension `d`, compute:
+    1. the max `m` along that dimension `d`
+    2. `f(x) = exp(x - m)`
+    3. the sum `l(x)` of `f(x)` along dimension `d`
+    4. the final result `f(x) / l(x)`
+
+ This is an aggregate linalg operation that further reduces to a small DAG of
+ structured operations.
+ }];
+
+ let arguments = (ins AnyShaped:$input,
+ AnyShaped:$output,
+ I64Attr:$dimension
+ );
+
+ let results = (outs Variadic<AnyRankedTensor>:$result);
+ let hasFolder = 1;
+ let assemblyFormat = [{
+ attr-dict
+ `dimension` `(` $dimension `)`
+ `ins` `(` $input `:` type($input) `)`
+ `outs` `(` $output `:` type($output) `)`
+ (`->` type($result)^)?
+ }];
+
+ let extraClassDeclaration = [{
+ ShapedType getInputOperandType() {
+ return getInput().getType().cast<ShapedType>();
+ }
+ ShapedType getOutputOperandType() {
+ return getOutput().getType().cast<ShapedType>();
+ }
+ int64_t getInputOperandRank() {
+ return getInputOperandType().getRank();
+ }
+ int64_t getOutputOperandRank() {
+ return getOutputOperandType().getRank();
+ }
+ // Method to implement DestinationStyleOpInterface.
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+ std::pair<unsigned, unsigned> outputsIndexAndLength =
+ getODSOperandIndexAndLength(1);
+ return std::make_pair<int64_t, int64_t>(
+ outputsIndexAndLength.first,
+ outputsIndexAndLength.first + outputsIndexAndLength.second);
+ }
+ }];
+ let hasVerifier = 1;
+}
+
#endif // LINALG_OPS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5af2035a4832d..03407e39946035 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2140,6 +2140,39 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
// All named ops canonicalizers and folders are auto-generated in the
// .cpp.inc.
+//===----------------------------------------------------------------------===//
+// SoftmaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SoftmaxOp::verify() {
+ ShapedType inputType = getInputOperandType();
+ ShapedType outputType = getOutputOperandType();
+
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ if (failed(verifyCompatibleShape(inputShape, outputShape)))
+ return emitOpError("incompatible output shape");
+
+ int64_t inputRank = getInputOperandRank();
+ int64_t dimension = getDimension();
+ if ((dimension < 0) || (dimension >= inputRank))
+ return emitOpError("incorrect dimension specified");
+
+ return success();
+}
+
+// cast(dynamic) -> static.
+LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+
+LogicalResult
+SoftmaxOp::reifyResultShapes(OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ return cast<LinalgOp>(getOperation())
+ .reifyResultShapes(b, reifiedReturnShapes);
+}
+
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index dbc93d56e2a9e0..88f070a3252ed3 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -733,3 +733,14 @@ func.func @missing_iterator_types() {
linalg.generic {} ins() outs()
return
}
+
+// -----
+
+func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x16xf32> {
+ %0 = tensor.empty() : tensor<2x16xf32>
+ // expected-error @+1 {{incompatible output shape}}
+ %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>)
+ outs(%0: tensor<2x16xf32>)
+ -> tensor<2x16xf32>
+ return %1 : tensor<2x16xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 8bf5e5ba418da9..9895cd255cf4fc 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -599,3 +599,17 @@ func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]
// CHECK-NEXT: return %[[REDUCED]] : tensor<16x64xf32>
+
+// -----
+
+func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+ %0 = tensor.empty() : tensor<2x16x32xf32>
+ %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+ return %1 : tensor<2x16x32xf32>
+}
+// CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[D1:.+]] = linalg.softmax dimension(2) ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK: return %[[D1]] : tensor<2x16x32xf32>
+// CHECK: }
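For readers new to the numerically stable formulation, the sketch below
walks through the four steps from the op description over a single 1-D
slice. This is an illustration only, not part of the patch; the function
name is hypothetical, and an n-D input would apply it to every 1-D slice
along the op's `dimension` attribute:

  #include <algorithm>
  #include <cmath>
  #include <vector>

  // Numerically stable softmax of one non-empty 1-D slice, following
  // the four steps in the linalg.softmax description above.
  std::vector<double> stableSoftmax(const std::vector<double> &x) {
    // 1. The max `m` along the reduced dimension.
    double m = *std::max_element(x.begin(), x.end());
    // 2. f(x) = exp(x - m); subtracting `m` keeps exp from overflowing.
    std::vector<double> f(x.size());
    for (size_t i = 0, e = x.size(); i < e; ++i)
      f[i] = std::exp(x[i] - m);
    // 3. Sum f(x) along the dimension to get l(x).
    double l = 0.0;
    for (double v : f)
      l += v;
    // 4. Final result f(x) / l(x).
    for (double &v : f)
      v /= l;
    return f;
  }

The later lowering patches mentioned above are expected to express these
same steps as the small DAG of structured operations the op description
refers to.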