[Mlir-commits] [mlir] [mlir][linalg] Add bufferization for `linalg.softmax` (PR #97019)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 28 00:56:03 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Implement the `BufferizableOpInterface` for `linalg.softmax`. The op is not a `LinalgOp`, so it is not covered by the "catch-all" `LinalgOp` interface implementation.

---
Full diff: https://github.com/llvm/llvm-project/pull/97019.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+31) 
- (modified) mlir/test/Dialect/Linalg/bufferize.mlir (+17) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 81a5398dabcb7..be158af09d398 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,6 +162,35 @@ struct LinalgOpInterfaceHelper {
     (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
   }
 };
+
+struct SoftmaxOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
+                                                     linalg::SoftmaxOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Only the input operand is read; the output (init) operand is not.
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    return &opOperand == &softmaxOp.getInputMutable();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
+    FailureOr<Value> inputBuffer =
+        getBuffer(rewriter, softmaxOp.getInput(), options);
+    if (failed(inputBuffer))
+      return failure();
+    FailureOr<Value> outputBuffer =
+        getBuffer(rewriter, softmaxOp.getOutput(), options);
+    if (failed(outputBuffer))
+      return failure();
+    rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
+                                       /*result=*/TypeRange(), *inputBuffer,
+                                       *outputBuffer, softmaxOp.getDimension());
+    replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
@@ -174,5 +203,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
         >::registerOpInterface(ctx);
+
+    SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index e8ab1184b1fd2..f416cd9fcf0b2 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -189,3 +189,20 @@ func.func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> {
   // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32>
   // CHECK: return %[[OUT_TENSOR]]
 }
+
+// -----
+
+// CHECK-LABEL: func @bufferize_softmax(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<2x16x32xf32>, %[[arg1:.*]]: tensor<2x16x32xf32>
+//       CHECK:   %[[m0:.*]] = bufferization.to_memref %[[arg0]]
+//       CHECK:   %[[alloc:.*]] = memref.alloc()
+//   CHECK-NOT:   memref.copy
+//       CHECK:   linalg.softmax dimension(2) ins(%[[m0]] : {{.*}}) outs(%[[alloc]] : {{.*}})
+//       CHECK:   %[[result:.*]] = bufferization.to_tensor %[[alloc]]
+//       CHECK:   return %[[result]]
+func.func @bufferize_softmax(%arg0: tensor<2x16x32xf32>, %arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %1 = linalg.softmax dimension(2)
+      ins(%arg0 : tensor<2x16x32xf32>)
+      outs(%arg1: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/97019


More information about the Mlir-commits mailing list