[Mlir-commits] [mlir] [mlir][tosa] Fix bf16 reduction accumulator widening (PR #192045)
Georgios Pinitas
llvmlistbot at llvm.org
Tue Apr 14 05:42:14 PDT 2026
https://github.com/GeorgeARM created https://github.com/llvm/llvm-project/pull/192045
Use an f32 accumulator when lowering bf16 `tosa.reduce_product` in `TosaToLinalg`, as is already done for `tosa.reduce_sum`, and then truncate the result back to bf16.
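To illustrate why the accumulator is widened, here is a minimal host-side C++ sketch. It is not part of the patch. It emulates bf16 by rounding f32 values to the nearest bf16 value (round-to-nearest-even) and compares a product accumulated in bf16 with one accumulated in f32 and rounded once at the end. The helper names are hypothetical. A bf16 multiply is modeled as an f32 multiply followed by rounding, and NaN handling is omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Round an f32 value to the nearest bf16 value (round-to-nearest-even) and
// return it widened back to f32. NaN inputs are not handled in this sketch.
float roundToBF16(float x) {
  assert(x == x && "NaN handling omitted in this sketch");
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  const uint32_t lsb = (bits >> 16) & 1u; // lowest mantissa bit bf16 keeps
  bits += 0x7FFFu + lsb;                  // round to nearest, ties to even
  bits &= 0xFFFF0000u;                    // drop the 16 low mantissa bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// Accumulator kept in bf16: every partial product is rounded back to bf16.
float productWithBF16Acc(const std::vector<float> &xs) {
  float acc = 1.0f;
  for (float x : xs)
    acc = roundToBF16(acc * x);
  return acc;
}

// Accumulator widened to f32: the result is rounded to bf16 only once.
float productWithF32Acc(const std::vector<float> &xs) {
  float acc = 1.0f;
  for (float x : xs)
    acc *= x;
  return roundToBF16(acc);
}
```

Consider 64 copies of 1.0078125, which is 1 + 2^-7 and exactly representable in bf16. With the bf16 accumulator the result is 1.5. At each step the extra term k·2^-14 is below half a bf16 ulp near 1, so it is rounded away. With the f32 accumulator the result is 1.6484375, the bf16 value nearest the true product (about 1.6455).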
From c6c26cdf43a313e7a04ab7ea9fcf72b570103504 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Tue, 14 Apr 2026 13:23:49 +0100
Subject: [PATCH] [mlir][tosa] Fix bf16 reduction accumulator widening
Use an f32 accumulator when lowering bf16 `tosa.reduce_product` in
`TosaToLinalg`, as is already done for `tosa.reduce_sum`, and then
truncate the result back to bf16.
Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 8 +++---
.../TosaToLinalg/tosa-to-linalg.mlir | 25 +++++++++++++++++++
2 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..50663ddd27346 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1172,9 +1172,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
Value input = op->getOperand(0);
// Figure out the accType if needed
- bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
- isa<FloatType>(elementTy) &&
- cast<FloatType>(elementTy).isBF16();
+ const bool needsFp32AccTy =
+ isa<FloatType>(elementTy) && cast<FloatType>(elementTy).isBF16();
+ const bool widenAccTy = (std::is_same_v<OpTy, tosa::ReduceSumOp> ||
+ std::is_same_v<OpTy, tosa::ReduceProductOp>) &&
+ needsFp32AccTy;
Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
SmallVector<int64_t> reduceShape;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index e6bd800a0cf0a..20c93c671f48c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -983,6 +983,31 @@ func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
// -----
+// CHECK-LABEL: @reduce_product_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_product_bf16(%arg0: tensor<5x4xbf16>) -> () {
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
+ // CHECK: [[CST1:%.+]] = arith.constant 1.0
+ // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST1]]{{.*}}outs([[INIT]]
+ // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<5xf32>) dimensions = [1]
+ // CHECK: (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+ // CHECK: [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+ // CHECK: [[ACC:%.+]] = arith.mulf [[EXTF]], %[[ARG2]] : f32
+ // CHECK: linalg.yield [[ACC]] : f32
+ // CHECK: }
+ // CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<5xbf16>
+ // CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<5xf32>) outs([[INIT_RES]] : tensor<5xbf16>)
+ // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
+ // CHECK: linalg.yield [[TRUNCF]] : bf16
+ // CHECK: }
+ // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xbf16> into tensor<5x1xbf16>
+ %0 = tosa.reduce_product %arg0 {axis = 1 : i32} : (tensor<5x4xbf16>) -> tensor<5x1xbf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @reduce_float
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
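The following C++ sketch is a hedged reference model of the computation that the new `@reduce_product_bf16` CHECK lines expect. It fills a `tensor<5xf32>` accumulator with 1.0, applies `arith.extf` to each bf16 element, multiplies in f32 with `arith.mulf`, and finally applies `arith.truncf` elementwise back to bf16. The raw-bits `BF16` alias and the function names are hypothetical. The sketch assumes round-to-nearest-even for the truncation and omits the final `tensor.expand_shape` to `tensor<5x1xbf16>`.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

using BF16 = uint16_t; // hypothetical storage: raw bf16 bit pattern

// Models arith.extf bf16 -> f32: bf16 is the high half of an f32.
float extToF32(BF16 v) {
  const uint32_t bits = uint32_t(v) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Models arith.truncf f32 -> bf16 with round-to-nearest-even (assumed).
BF16 truncToBF16(float f) {
  assert(f == f && "NaN handling omitted in this sketch");
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return BF16(bits >> 16);
}

// tosa.reduce_product over axis 1 of a 5x4 bf16 tensor, lowered as in the test.
std::array<BF16, 5>
reduceProductAxis1(const std::array<std::array<BF16, 4>, 5> &in) {
  std::array<float, 5> acc;
  acc.fill(1.0f); // linalg.fill with 1.0 into tensor<5xf32>
  for (std::size_t i = 0; i < in.size(); ++i)
    for (BF16 e : in[i])
      acc[i] = extToF32(e) * acc[i]; // arith.extf + arith.mulf in f32
  std::array<BF16, 5> out;
  for (std::size_t i = 0; i < acc.size(); ++i)
    out[i] = truncToBF16(acc[i]); // linalg.generic with arith.truncf
  return out;
}
```

Because only the final result is truncated, a row of four 1.0078125 values yields 1.03125, the bf16 value nearest the exact product.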