[Mlir-commits] [mlir] de0094e - [mlir][tosa] Introduce accumulator type for `reduce_sum` on bf16 (#158389)
llvmlistbot at llvm.org
Mon Sep 15 02:38:36 PDT 2025
Author: Georgios Pinitas
Date: 2025-09-15T10:38:31+01:00
New Revision: de0094edf0c8596550ed58d1b43e10969631a5ab
URL: https://github.com/llvm/llvm-project/commit/de0094edf0c8596550ed58d1b43e10969631a5ab
DIFF: https://github.com/llvm/llvm-project/commit/de0094edf0c8596550ed58d1b43e10969631a5ab.diff
LOG: [mlir][tosa] Introduce accumulator type for `reduce_sum` on bf16 (#158389)
TOSA requires that `reduce_sum` operations on bf16 accumulate into fp32.
This change updates the `linalg` legalization by introducing an explicit
accumulator type to ensure compliance with the specification.
---------
Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
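
For reference, here is a minimal sketch of the new lowering (it mirrors the `reduce_bf16` test added below; SSA value names such as %init, %fill and %acc are schematic, not taken verbatim from the generated IR). A `tosa.reduce_sum` over bf16:

    %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>

now lowers, roughly, to a reduction that accumulates in f32:

    %cst = arith.constant 0.0 : f32
    %init = tensor.empty() : tensor<4xf32>
    %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
    %sum = linalg.reduce ins(%arg0 : tensor<5x4xbf16>) outs(%fill : tensor<4xf32>) dimensions = [0]
      (%in: bf16, %acc: f32) {
        // extend each bf16 element and accumulate in f32
        %ext = arith.extf %in : bf16 to f32
        %add = arith.addf %ext, %acc : f32
        linalg.yield %add : f32
      }
    // a trailing linalg.generic applies arith.truncf (f32 -> bf16) and a
    // tensor.expand_shape restores the tensor<1x4xbf16> result shape

See the CHECK lines in the added test for the authoritative output.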
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e2b31f640da2f..0a6f2477560a1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto elementTy = resultTy.getElementType();
Value input = op->getOperand(0);
+ // Figure out the accType if needed
+ bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
+ isa<FloatType>(elementTy) &&
+ cast<FloatType>(elementTy).isBF16();
+ Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
+
SmallVector<int64_t> reduceShape;
SmallVector<Value> dynDims;
for (unsigned i = 0; i < inputTy.getRank(); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
inputs.push_back(input);
// First fill the output buffer with the init value.
- auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
- resultTy.getElementType(), dynDims)
- .getResult();
+ auto emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
+ .getResult();
- auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+ auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
if (!fillValueAttr)
return rewriter.notifyMatchFailure(
op, "No initial value found for reduction operation");
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
std::array<Value, 2> binaryArgs{
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
- auto result = createLinalgBodyCalculationForReduceOp(
- op, binaryArgs, elementTy, rewriter);
+
+ // If reduction type differs then extend (applicable to reduce_sum)
+ if (binaryArgs[0].getType() != accTy)
+ binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
+ binaryArgs[0]);
+
+ auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
+ accTy, rewriter);
if (result)
didEncounterError = true;
@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// Create a tensor full of NaNs.
auto nanValueAttr = rewriter.getFloatAttr(
- elementTy,
+ accTy,
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
auto emptyNanTensor =
- tensor::EmptyOp::create(rewriter, loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
.getResult();
auto nanFilledTensor =
linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// Create an empty tensor, no need to fill this since it will be
// overwritten by the select.
auto finalEmptyTensor =
- tensor::EmptyOp::create(rewriter, loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
.getResult();
// Do a selection between the tensors akin to:
@@ -1304,9 +1314,32 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
linalgOp = linalgSelect;
}
+ // Truncate back to resultTy if needed
+ Value reducedRes = linalgOp->getResult(0);
+ if (widenAccTy) {
+ auto resEmptyOp =
+ tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
+ .getResult();
+
+ const unsigned reducedRank =
+ cast<ShapedType>(reducedRes.getType()).getRank();
+ auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+ reducedRes =
+ linalg::GenericOp::create(
+ rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
+ ValueRange{resEmptyOp},
+ ArrayRef<AffineMap>{identityMap, identityMap},
+ getNParallelLoopsAttrs(reducedRank),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
+ elementTy, args[0]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
+ })
+ .getResults()[0];
+ }
+
SmallVector<ReassociationExprs, 4> reassociationMap;
- uint64_t expandInputRank =
- cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
+ uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
reassociationMap.resize(expandInputRank);
for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1324,8 +1357,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// since here we know which dimension to expand, and `tosa::ReshapeOp` would
// not have access to such information. This matters when handling dynamically
// sized tensors.
- rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
- op, resultTy, linalgOp->getResults()[0], reassociationMap);
+ rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
+ reassociationMap);
return success();
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3fc513f823a1a..37af8b8859852 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -912,6 +912,32 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @reduce_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32>
+ // CHECK: [[CST0:%.+]] = arith.constant 0.0
+ // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
+ // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<4xf32>) dimensions = [0]
+ // CHECK: (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+ // CHECK: [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+ // CHECK: [[ACC:%.+]] = arith.addf [[EXTF]], %[[ARG2]] : f32
+ // CHECK: linalg.yield [[ACC]] : f32
+ // CHECK: }
+ // CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<4xbf16>
+ // CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<4xf32>) outs([[INIT_RES]] : tensor<4xbf16>)
+ // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
+ // CHECK: linalg.yield [[TRUNCF]] : bf16
+ // CHECK: }
+ // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
+ %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @reduce_float
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {