[Mlir-commits] [mlir] 4fb25ca - [MLIR][Linalg] Preserve DPS when decomposing Softmax
Lorenzo Chelini
llvmlistbot at llvm.org
Fri Jul 21 09:03:31 PDT 2023
Author: Lorenzo Chelini
Date: 2023-07-21T18:03:26+02:00
New Revision: 4fb25ca51c851015d7e0081b08da342285de2cb1
URL: https://github.com/llvm/llvm-project/commit/4fb25ca51c851015d7e0081b08da342285de2cb1
DIFF: https://github.com/llvm/llvm-project/commit/4fb25ca51c851015d7e0081b08da342285de2cb1.diff
LOG: [MLIR][Linalg] Preserve DPS when decomposing Softmax
Preserve destination passing style (DPS) when decomposing
`linalg.softmax`; instead of creating a new `tensor.empty`, which may
materialize as a new buffer after bufferization, reuse the op's
destination (output) operand directly.
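For reference, the op being decomposed already carries a destination
operand. The snippet below mirrors the test input exercised further
down; it is an illustrative addition, not part of the original commit
message:

    func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
      // %dst is the DPS destination; after decomposition it is reused
      // as the outs of the subtract/exp and divide steps instead of a
      // fresh tensor.empty.
      %0 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>)
             outs(%dst : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
      return %0 : tensor<2x16x32xf32>
    }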
Reviewed By: qcolombet
Differential Revision: https://reviews.llvm.org/D155942
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/transform-op-decompose.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 464a30a85f7105..d6778ed72c7d0e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2465,31 +2465,32 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
Type elementType = inputType.getElementType();
int64_t reductionDim = getDimension();
SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
- Value outputNd = b.create<tensor::EmptyOp>(loc, dims, elementType);
+ Value output = getOutput();
dims.erase(dims.begin() + reductionDim);
// Step 1: Compute max along dim.
- Value output = b.create<tensor::EmptyOp>(loc, dims, elementType);
+ Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
Value neutralForMaxF =
arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc);
Value neutralForMaxFInit =
- b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, output).result();
+ b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
+ .result();
Value max =
reduce<arith::MaxFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
// Step 2: Subtract max from input and exponentiate.
- Value numerator =
- buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim);
+ Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
// Step 3: Compute sum along dim.
Value zero =
arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc);
- Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, output).result();
+ Value zeroInit =
+ b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
Value denominator =
reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
// Step 4: Compute softmax.
Value result =
- buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim);
+ buildDivOp(b, loc, numerator, denominator, output, reductionDim);
return SmallVector<Value>{result};
}
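After this change, only the rank-reduced reduction accumulator
(`outputReduce` above) still needs a fresh `tensor.empty`, since its
shape differs from the destination; the full-rank results of Steps 2
and 4 write into the softmax destination. A hand-written sketch of the
Step 1 fill-then-reduce pattern (illustrative only; the actual
decomposition emits `linalg.generic` ops, as the updated test below
shows):

    func.func @step1_max(%arg0: tensor<2x16x32xf32>) -> tensor<2x16xf32> {
      // Fill a rank-reduced init with the identity of the reduction
      // (-inf for max, matching arith::AtomicRMWKind::maxf above),
      // then reduce along the softmax dimension.
      %red  = tensor.empty() : tensor<2x16xf32>
      %ninf = arith.constant 0xFF800000 : f32
      %init = linalg.fill ins(%ninf : f32) outs(%red : tensor<2x16xf32>) -> tensor<2x16xf32>
      %max  = linalg.reduce ins(%arg0 : tensor<2x16x32xf32>)
                outs(%init : tensor<2x16xf32>) dimensions = [2]
                (%in: f32, %out: f32) {
                  %m = arith.maxf %in, %out : f32
                  linalg.yield %m : f32
                }
      return %max : tensor<2x16xf32>
    }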
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 30c2ab80546ce7..30a155e28a966a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -208,8 +208,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
}
// CHECK-LABEL: func.func @softmax(
-//CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
-// CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000 : f32
// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
@@ -221,7 +220,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// CHECK: } -> tensor<2x16xf32>
// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
-// CHECK-SAME: outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
// CHECK: %[[D9:.+]] = math.exp %[[D8]] : f32
@@ -237,13 +236,12 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// CHECK: } -> tensor<2x16xf32>
// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
-// CHECK-SAME: outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK-SAME: outs(%[[DST]] : tensor<2x16x32xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
// CHECK: linalg.yield %[[D8]] : f32
// CHECK: } -> tensor<2x16x32xf32>
// CHECK: return %[[D7]] : tensor<2x16x32xf32>
-// CHECK: }
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):