[Mlir-commits] [mlir] [mlir][linalg] Add bufferization for `linalg.softmax` (PR #97019)

Matthias Springer llvmlistbot at llvm.org
Fri Jun 28 00:55:35 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/97019

Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is not a `LinalgOp`, so it is not covered by the "catch-all" `LinalgOp` interface implementation.

>From 09d512645c522900a48952c31ba5d80848443d53 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 28 Jun 2024 09:54:12 +0200
Subject: [PATCH] [mlir][linalg] Add bufferization for `linalg.softmax`

Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is not a `LinalgOp`, so it is not covered by the "catch-all" `LinalgOp` interface implementation.
---
 .../BufferizableOpInterfaceImpl.cpp           | 31 +++++++++++++++++++
 mlir/test/Dialect/Linalg/bufferize.mlir       | 17 ++++++++++
 2 files changed, 48 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 81a5398dabcb7..be158af09d398 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,6 +162,35 @@ struct LinalgOpInterfaceHelper {
     (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
   }
 };
+
+/// Bufferization model for `linalg.softmax`, which is not a `LinalgOp`.
+struct SoftmaxOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
+                                                     linalg::SoftmaxOp> {
+  /// Only the `ins` operand is read; the `outs` contents are never read.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return &cast<linalg::SoftmaxOp>(op).getInputMutable() == &opOperand;
+  }
+
+  /// Recreates the op on memrefs; the output buffer replaces the result.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto softmax = cast<linalg::SoftmaxOp>(op);
+    FailureOr<Value> srcBuf = getBuffer(rewriter, softmax.getInput(), options);
+    if (failed(srcBuf))
+      return failure();
+    FailureOr<Value> dstBuf = getBuffer(rewriter, softmax.getOutput(), options);
+    if (failed(dstBuf))
+      return failure();
+    // The new op has no results; it writes directly into `dstBuf`.
+    rewriter.create<linalg::SoftmaxOp>(softmax.getLoc(), TypeRange(),
+                                       *srcBuf, *dstBuf,
+                                       softmax.getDimension());
+    replaceOpWithBufferizedValues(rewriter, op, *dstBuf);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -174,5 +203,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
         >::registerOpInterface(ctx);
+
+    SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..f416cd9fcf0b2 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,20 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
   // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
   // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// CHECK-LABEL: func @bufferize_softmax(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<2x16x32xf32>, %[[arg1:.*]]: tensor<2x16x32xf32>
+//       CHECK:   %[[m0:.*]] = bufferization.to_memref %[[arg0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc()
+//   CHECK-NOT:   memref.copy
+//       CHECK:   linalg.softmax dimension(2) ins(%[[m0]] : {{.*}}) outs(%[[alloc]] : {{.*}})
+//       CHECK:   %[[result:.*]] = bufferization.to_tensor %[[alloc]]
+//       CHECK:   return %[[result]]
+func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2)
+      ins(%arg0 : tensor<2x16x32xf32>)
+      outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}



More information about the Mlir-commits mailing list