[Mlir-commits] [mlir] [mlir][linalg] Add bufferization for `linalg.softmax` (PR #97019)
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 28 00:55:35 PDT 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/97019
Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is not a `LinalgOp`, so it is not covered by the "catch-all" `LinalgOp` interface implementation.
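For reference, a minimal before/after sketch of the effect, adapted from the test case added below (SSA names are illustrative; the exact conversion ops and layout attributes depend on the bufferization options):

    // Before: destination-style softmax on tensors.
    %r = linalg.softmax dimension(2)
        ins(%arg0 : tensor<2x16x32xf32>)
        outs(%arg1 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>

    // After one-shot bufferization: the same op on memrefs. No memref.copy
    // into %alloc is needed because the output operand is never read.
    %m0 = bufferization.to_memref %arg0 : memref<2x16x32xf32>
    %alloc = memref.alloc() : memref<2x16x32xf32>
    linalg.softmax dimension(2)
        ins(%m0 : memref<2x16x32xf32>)
        outs(%alloc : memref<2x16x32xf32>)
    %r = bufferization.to_tensor %alloc : memref<2x16x32xf32>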
From 09d512645c522900a48952c31ba5d80848443d53 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 28 Jun 2024 09:54:12 +0200
Subject: [PATCH] [mlir][linalg] Add bufferization for `linalg.softmax`
Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is not a `LinalgOp`, so it is not covered by the "catch-all" `LinalgOp` interface implementation.
---
 .../BufferizableOpInterfaceImpl.cpp     | 31 +++++++++++++++++++
 mlir/test/Dialect/Linalg/bufferize.mlir | 17 ++++++++++
 2 files changed, 48 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 81a5398dabcb7..be158af09d398 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,6 +162,35 @@ struct LinalgOpInterfaceHelper {
     (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
   }
 };
+
+struct SoftmaxOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
+                                                     linalg::SoftmaxOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Output operand is not read.
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    return &opOperand == &softmaxOp.getInputMutable();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    FailureOr<Value> inputBuffer =
+        getBuffer(rewriter, softmaxOp.getInput(), options);
+    if (failed(inputBuffer))
+      return failure();
+    FailureOr<Value> outputBuffer =
+        getBuffer(rewriter, softmaxOp.getOutput(), options);
+    if (failed(outputBuffer))
+      return failure();
+    rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
+                                       /*result=*/TypeRange(), *inputBuffer,
+                                       *outputBuffer,
+                                       softmaxOp.getDimension());
+    replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
+    return success();
+  }
+};
 } // namespace

 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -174,5 +203,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
         >::registerOpInterface(ctx);
+
+    SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..f416cd9fcf0b2 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,20 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
 // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
 // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// CHECK-LABEL: func @bufferize_softmax(
+// CHECK-SAME: %[[arg0:.*]]: tensor<2x16x32xf32>, %[[arg1:.*]]: tensor<2x16x32xf32>
+// CHECK: %[[m0:.*]] = bufferization.to_memref %[[arg0]]
+// CHECK: %[[alloc:.*]] = memref.alloc()
+// CHECK-NOT: memref.copy
+// CHECK: linalg.softmax dimension(2) ins(%[[m0]] : {{.*}}) outs(%[[alloc]] : {{.*}})
+// CHECK: %[[result:.*]] = bufferization.to_tensor %[[alloc]]
+// CHECK: return %[[result]]
+func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2)
+      ins(%arg0 : tensor<2x16x32xf32>)
+      outs(%arg1 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}