[Mlir-commits] [mlir] 4fb25ca - [MLIR][Linalg] Preserve DPS when decomposing Softmax

Lorenzo Chelini llvmlistbot at llvm.org
Fri Jul 21 09:03:31 PDT 2023

Author: Lorenzo Chelini
Date: 2023-07-21T18:03:26+02:00
New Revision: 4fb25ca51c851015d7e0081b08da342285de2cb1

URL: https://github.com/llvm/llvm-project/commit/4fb25ca51c851015d7e0081b08da342285de2cb1
DIFF: https://github.com/llvm/llvm-project/commit/4fb25ca51c851015d7e0081b08da342285de2cb1.diff

LOG: [MLIR][Linalg] Preserve DPS when decomposing Softmax

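Preserve destination-passing style (DPS) when decomposing
`linalg.Softmax`: instead of creating a new `tensor.empty` for the
full-size output, which may materialize as a new buffer after
bufferization, use the op's output (destination) operand directly.
The reduced-shape `tensor.empty` used to initialize the max and sum
reductions is still created.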

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D155942




diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 464a30a85f7105..d6778ed72c7d0e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2465,31 +2465,32 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   Type elementType = inputType.getElementType();
   int64_t reductionDim = getDimension();
   SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
-  Value outputNd = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value output = getOutput();
   dims.erase(dims.begin() + reductionDim);
   // Step 1: Compute max along dim.
-  Value output = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
   Value neutralForMaxF =
       arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc);
   Value neutralForMaxFInit =
-      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, output).result();
+      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
+          .result();
   Value max =
       reduce<arith::MaxFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
   // Step 2: Subtract max from input and exponentiate.
-  Value numerator =
-      buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim);
+  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
   // Step 3: Compute sum along dim.
   Value zero =
       arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc);
-  Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, output).result();
+  Value zeroInit =
+      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
   Value denominator =
       reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
   // Step 4: Compute softmax.
   Value result =
-      buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim);
+      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
   return SmallVector<Value>{result};

diff  --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 30c2ab80546ce7..30a155e28a966a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -208,8 +208,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // CHECK-LABEL:      func.func @softmax(
-//CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
-// CHECK-DAG:        %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
 // CHECK-DAG:        %[[CST:.+]] = arith.constant 0xFF800000 : f32
 // CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
@@ -221,7 +220,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // CHECK:        } -> tensor<2x16xf32>
 // CHECK:        %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
 // CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
-// CHECK-SAME:     outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
 // CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:          %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
 // CHECK:          %[[D9:.+]] = math.exp %[[D8]] : f32
@@ -237,13 +236,12 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // CHECK:        } -> tensor<2x16xf32>
 // CHECK:        %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
 // CHECK-SAME:     ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
-// CHECK-SAME:     outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK-SAME:     outs(%[[DST]] : tensor<2x16x32xf32>) {
 // CHECK:        ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:          %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
 // CHECK:          linalg.yield %[[D8]] : f32
 // CHECK:        } -> tensor<2x16x32xf32>
 // CHECK:        return %[[D7]] : tensor<2x16x32xf32>
-// CHECK:      }
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):


More information about the Mlir-commits mailing list