[Mlir-commits] [mlir] 6699807 - [mlir][linalg] Add bufferization for `linalg.softmax` (#97019)

llvmlistbot at llvm.org
Sun Jun 30 04:30:27 PDT 2024


Author: Matthias Springer
Date: 2024-06-30T13:30:23+02:00
New Revision: 6699807fa7397b1ba5c2c148d04d88afd3309226

URL: https://github.com/llvm/llvm-project/commit/6699807fa7397b1ba5c2c148d04d88afd3309226
DIFF: https://github.com/llvm/llvm-project/commit/6699807fa7397b1ba5c2c148d04d88afd3309226.diff

LOG: [mlir][linalg] Add bufferization for `linalg.softmax` (#97019)

Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is
not a `LinalgOp`, so it is not covered by the catch-all `LinalgOp`
interface implementation.
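
With this, One-Shot Bufferize can rewrite a tensor-level softmax into
its memref form. A minimal sketch, mirroring the test case added below
(the buffer SSA names are illustrative):

  // Tensor semantics, before bufferization.
  %1 = linalg.softmax dimension(2)
      ins(%arg0 : tensor<2x16x32xf32>)
      outs(%arg1 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>

  // Buffer semantics, after bufferization: the memref variant of the
  // op has no results and writes into a newly allocated output buffer.
  // No copy of %arg1 is inserted because the output operand is not read.
  %m0 = bufferization.to_memref %arg0 : memref<2x16x32xf32>
  %alloc = memref.alloc() : memref<2x16x32xf32>
  linalg.softmax dimension(2)
      ins(%m0 : memref<2x16x32xf32>)
      outs(%alloc : memref<2x16x32xf32>)
  %result = bufferization.to_tensor %alloc : memref<2x16x32xf32>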

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 81a5398dabcb7..be158af09d398 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,6 +162,35 @@ struct LinalgOpInterfaceHelper {
     (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
   }
 };
+
+struct SoftmaxOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
+                                                     linalg::SoftmaxOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Output operand is not read.
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    return &opOperand == &softmaxOp.getInputMutable();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    FailureOr<Value> inputBuffer =
+        getBuffer(rewriter, softmaxOp.getInput(), options);
+    if (failed(inputBuffer))
+      return failure();
+    FailureOr<Value> outputBuffer =
+        getBuffer(rewriter, softmaxOp.getOutput(), options);
+    if (failed(outputBuffer))
+      return failure();
+    rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
+                                       /*result=*/TypeRange(), *inputBuffer,
+                                       *outputBuffer, softmaxOp.getDimension());
+    replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -174,5 +203,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
         >::registerOpInterface(ctx);
+
+    SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
   });
 }

diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..f416cd9fcf0b2 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,20 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
   // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
   // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// CHECK-LABEL: func @bufferize_softmax(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<2x16x32xf32>, %[[arg1:.*]]: tensor<2x16x32xf32>
+//       CHECK:   %[[m0:.*]] = bufferization.to_memref %[[arg0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc()
+//   CHECK-NOT:   memref.copy
+//       CHECK:   linalg.softmax dimension(2) ins(%[[m0]] : {{.*}}) outs(%[[alloc:.*]] : {{.*}})
+//       CHECK:   %[[result:.*]] = bufferization.to_tensor %[[alloc]]
+//       CHECK:   return %[[result]]
+func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2)
+      ins(%arg0 : tensor<2x16x32xf32>)
+      outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
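
For reference, the new test case runs through the file's existing RUN
line. The exact flags below are an assumption (one-shot bufferization
restricted to the linalg and bufferization dialects); consult the RUN
line at the top of bufferize.mlir:

  // Hypothetical invocation mirroring the test's RUN line:
  // mlir-opt -one-shot-bufferize="dialect-filter=linalg,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -split-input-file bufferize.mlir | FileCheck bufferize.mlir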


        

