[Mlir-commits] [mlir] 6699807 - [mlir][linalg] Add bufferization for `linalg.softmax` (#97019)
llvmlistbot at llvm.org
Sun Jun 30 04:30:27 PDT 2024
Author: Matthias Springer
Date: 2024-06-30T13:30:23+02:00
New Revision: 6699807fa7397b1ba5c2c148d04d88afd3309226
URL: https://github.com/llvm/llvm-project/commit/6699807fa7397b1ba5c2c148d04d88afd3309226
DIFF: https://github.com/llvm/llvm-project/commit/6699807fa7397b1ba5c2c148d04d88afd3309226.diff
LOG: [mlir][linalg] Add bufferization for `linalg.softmax` (#97019)
Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is
not a `LinalgOp`, so it is not covered by the "catch all" `LinalgOp`
interface implementation.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Linalg/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 81a5398dabcb7..be158af09d398 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,6 +162,35 @@ struct LinalgOpInterfaceHelper {
(Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
}
};
+
+/// Bufferization model for `linalg.softmax`.
+///
+/// `linalg.softmax` is not a `LinalgOp`, so the generic `LinalgOp` model does
+/// not cover it and it needs its own external model. Only the read analysis
+/// and the rewrite itself are specialized here; every other interface method
+/// is inherited from `DstBufferizableOpInterfaceExternalModel`.
+struct SoftmaxOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
+                                                     linalg::SoftmaxOp> {
+  /// Reports a memory read only for the `ins` operand. The `outs` operand is
+  /// never read by the op, so its prior contents need not be preserved.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    OpOperand &inputOperand = cast<linalg::SoftmaxOp>(op).getInputMutable();
+    return &inputOperand == &opOperand;
+  }
+
+  /// Rewrites the tensor op into a memref-based `linalg.softmax` that has no
+  /// results, then replaces the original op's result with the `outs` buffer.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto tensorSoftmax = cast<linalg::SoftmaxOp>(op);
+
+    // Resolve the buffer behind each tensor operand (input first, then
+    // output); propagate any failure to the caller.
+    FailureOr<Value> srcMemref =
+        getBuffer(rewriter, tensorSoftmax.getInput(), options);
+    if (failed(srcMemref))
+      return failure();
+    FailureOr<Value> dstMemref =
+        getBuffer(rewriter, tensorSoftmax.getOutput(), options);
+    if (failed(dstMemref))
+      return failure();
+
+    // Memref form: an empty result type list, since the op writes into
+    // `dstMemref` in place.
+    rewriter.create<linalg::SoftmaxOp>(
+        tensorSoftmax.getLoc(), /*result=*/TypeRange(), *srcMemref,
+        *dstMemref, tensorSoftmax.getDimension());
+    replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
+    return success();
+  }
+};
} // namespace
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -174,5 +203,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>::registerOpInterface(ctx);
+
+ SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..f416cd9fcf0b2 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,20 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
// CHECK: return %[[OUT_TENSOR]]
}
+
+// -----
+
+// Softmax only reads its `ins` operand, so the `outs` tensor gets a fresh
+// allocation without a copy, and the op must write directly into it.
+// CHECK-LABEL: func @bufferize_softmax(
+// CHECK-SAME: %[[arg0:.*]]: tensor<2x16x32xf32>, %[[arg1:.*]]: tensor<2x16x32xf32>
+// CHECK: %[[m0:.*]] = bufferization.to_memref %[[arg0]]
+// CHECK: %[[alloc:.*]] = memref.alloc()
+// CHECK-NOT: memref.copy
+// CHECK: linalg.softmax dimension(2) ins(%[[m0]] : {{.*}}) outs(%[[alloc]] : {{.*}})
+// CHECK: %[[result:.*]] = bufferization.to_tensor %[[alloc]]
+// CHECK: return %[[result]]
+func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2)
+      ins(%arg0 : tensor<2x16x32xf32>)
+      outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
More information about the Mlir-commits mailing list