[Mlir-commits] [mlir] 2d9b910 - [mlir][Linalg] Add a softmax op

Quentin Colombet llvmlistbot at llvm.org
Thu Jun 29 03:57:35 PDT 2023


Author: Quentin Colombet
Date: 2023-06-29T12:57:06+02:00
New Revision: 2d9b9103b45e42feb24058fb1f8e615fdba6ae5c

URL: https://github.com/llvm/llvm-project/commit/2d9b9103b45e42feb24058fb1f8e615fdba6ae5c
DIFF: https://github.com/llvm/llvm-project/commit/2d9b9103b45e42feb24058fb1f8e615fdba6ae5c.diff

LOG: [mlir][Linalg] Add a softmax op

This patch adds a softmax op.
For now, nothing interesting happens; the op can only be round-tripped.
Later patches will add the tiling interface and the lowering of this op to
a sequence of simpler ops.
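
For reference, the op implements the standard numerically stable
softmax along a chosen dimension d (a sketch of the math that the op
description in the diff below spells out step by step):

    m            = max_k x_k                          (max along d)
    softmax(x)_i = exp(x_i - m) / sum_k exp(x_k - m)  (sum along d)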

This graduates the linalg_ext.softmax op from IREE to LLVM.

Original implementation from Harsh Menon <harsh at nod-labs.com>
Nicolas Vasilache <nicolas.vasilache at gmail.com> co-authored this patch.

Differential Revision: https://reviews.llvm.org/D153422

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index d9c1eec9ea9593..43b86cda281e75 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -87,4 +88,62 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
   let hasVerifier = 1;
 }
 
+def Linalg_SoftmaxOp : Linalg_Op<"softmax",
+    [DestinationStyleOpInterface,
+     PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+  let summary = "Softmax operator";
+  let description = [{
+    linalg.softmax computes a numerically stable version of softmax.
+
+    For a given input tensor and a specified dimension `d`, compute:
+      1. the max `m` along that dimension `d`
+      2. f(x) = exp(x - m)
+      3. sum f(x) along dimension `d` to get l(x)
+      4. compute the final result f(x) / l(x)
+
+    This is an aggregate linalg operation that further reduces to a small DAG of
+    structured operations.
+  }];
+
+  let arguments = (ins AnyShaped:$input,
+                       AnyShaped:$output,
+                       I64Attr:$dimension
+  );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result);
+  let hasFolder = 1;
+  let assemblyFormat = [{
+    attr-dict
+    `dimension` `(` $dimension `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    (`->` type($result)^)?
+  }];
+
+  let extraClassDeclaration = [{
+    ShapedType getInputOperandType() {
+      return getInput().getType().cast<ShapedType>();
+    }
+    ShapedType getOutputOperandType() {
+      return getOutput().getType().cast<ShapedType>();
+    }
+    int64_t getInputOperandRank() {
+      return getInputOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+    // Method to implement DestinationStyleOpInterface.
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      std::pair<unsigned, unsigned> outputsIndexAndLength =
+        getODSOperandIndexAndLength(1);
+      return std::make_pair<int64_t, int64_t>(
+          outputsIndexAndLength.first,
+          outputsIndexAndLength.first + outputsIndexAndLength.second);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
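
Based on the declared assembly format, the op prints along these lines
(illustrative shapes; the form matches the roundtrip test added later in
this patch):

    %1 = linalg.softmax dimension(1)
           ins(%arg0 : tensor<4x8xf32>)
           outs(%0 : tensor<4x8xf32>) -> tensor<4x8xf32>

Since the operands are AnyShaped and the result list is variadic, the
memref form of the op would carry no result and simply omit the trailing
`->` type.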

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5af2035a4832d..03407e39946035 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2140,6 +2140,39 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
 // All named ops canonicalizers and folders are auto-generated in the
 // .cpp.inc.
 
+//===----------------------------------------------------------------------===//
+// SoftmaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SoftmaxOp::verify() {
+  ShapedType inputType = getInputOperandType();
+  ShapedType outputType = getOutputOperandType();
+
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (failed(verifyCompatibleShape(inputShape, outputShape)))
+    return emitOpError("incompatible output shape");
+
+  int64_t inputRank = getInputOperandRank();
+  int64_t dimension = getDimension();
+  if ((dimension < 0) || (dimension >= inputRank))
+    return emitOpError("incorrect dimension specified");
+
+  return success();
+}
+
+// cast(dynamic) -> static.
+LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+LogicalResult
+SoftmaxOp::reifyResultShapes(OpBuilder &b,
+                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  return cast<LinalgOp>(getOperation())
+      .reifyResultShapes(b, reifiedReturnShapes);
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
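
The SoftmaxOp::fold hook above delegates to memref::foldMemRefCast,
which replaces any operand produced by a memref.cast with the cast's
source, typically trading a dynamically shaped operand for a statically
shaped one. A hypothetical before/after sketch:

    %c = memref.cast %buf : memref<4x8xf32> to memref<?x?xf32>
    linalg.softmax dimension(1) ins(%c : memref<?x?xf32>)
                                outs(%out : memref<?x?xf32>)

    // After folding, the op consumes the static source directly:
    linalg.softmax dimension(1) ins(%buf : memref<4x8xf32>)
                                outs(%out : memref<?x?xf32>)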

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index dbc93d56e2a9e0..88f070a3252ed3 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -733,3 +733,14 @@ func.func @missing_iterator_types() {
   linalg.generic {} ins() outs()
   return
 }
+
+// -----
+
+func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x16xf32> {
+  %0 = tensor.empty() : tensor<2x16xf32>
+  // expected-error @+1 {{incompatible output shape}}
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>)
+                                   outs(%0: tensor<2x16xf32>)
+    -> tensor<2x16xf32>
+  return %1 : tensor<2x16xf32>
+}
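
Besides shape compatibility, the verifier also rejects an out-of-range
`dimension` attribute (the `dimension < 0 || dimension >= inputRank`
check above). A hypothetical test for that diagnostic, not part of this
patch, would look like:

    // expected-error @+1 {{incorrect dimension specified}}
    %1 = linalg.softmax dimension(3) ins(%arg0 : tensor<2x16x32xf32>)
                                     outs(%0 : tensor<2x16x32xf32>)
      -> tensor<2x16x32xf32>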

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 8bf5e5ba418da9..9895cd255cf4fc 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -599,3 +599,17 @@ func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
 // CHECK-SAME:    outs
 // CHECK-SAME:    dimensions = [1]
 // CHECK-NEXT:    return %[[REDUCED]] : tensor<16x64xf32>
+
+// -----
+
+func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %0 = tensor.empty() : tensor<2x16x32xf32>
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
+// CHECK:      func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK:        %[[D1:.+]] = linalg.softmax dimension(2) ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D0]] :
+// CHECK-SAME:     tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK:        return %[[D1]] : tensor<2x16x32xf32>
+// CHECK:      }
