[Mlir-commits] [mlir] [mlir][tosa] Fix bf16 reduction accumulator widening (PR #192045)

Georgios Pinitas llvmlistbot at llvm.org
Tue Apr 14 05:42:14 PDT 2026


https://github.com/GeorgeARM created https://github.com/llvm/llvm-project/pull/192045

Use an f32 accumulator when lowering bf16 `tosa.reduce_product` in `TosaToLinalg`, as is already done for `tosa.reduce_sum`, then truncate the result back to bf16.
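A minimal sketch of why the accumulator type matters. It models bf16 values as floats that have been rounded to bf16 (round-to-nearest-even), and `roundToBf16`, `sumWithBf16Acc` and `sumWithF32Acc` are hypothetical helpers, not MLIR APIs. A sum is used because its result is easy to check by hand; the same per-step rounding applies to products.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Round a finite f32 value to the nearest bf16 value (round-to-nearest-even)
// and return it widened back to f32, so bf16 arithmetic can be modelled with
// ordinary floats.
float roundToBf16(float x) {
  assert(std::isfinite(x) && "NaN/Inf handling omitted in this sketch");
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits += 0x7FFFu + ((bits >> 16) & 1u); // round-to-nearest-even on the low 16 bits
  bits &= 0xFFFF0000u;                    // keep sign, exponent and 7 mantissa bits
  float out;
  std::memcpy(&out, &bits, sizeof out);
  return out;
}

// bf16 accumulator: every partial sum is rounded back to bf16.
float sumWithBf16Acc(const std::vector<float> &xs) {
  float acc = 0.0f;
  for (float x : xs)
    acc = roundToBf16(acc + x);
  return acc;
}

// f32 accumulator: round to bf16 only once at the end, analogous to the
// trailing arith.truncf emitted by the lowering.
float sumWithF32Acc(const std::vector<float> &xs) {
  float acc = 0.0f;
  for (float x : xs)
    acc += x;
  return roundToBf16(acc);
}
```

For example, summing `std::vector<float>(1000, 1.0f)` gives 256 with `sumWithBf16Acc`: once the partial sum reaches 256, adding 1 produces the tie 257, which rounds back to 256 because the bf16 spacing there is 2. `sumWithF32Acc` returns 1000, which is exactly representable in bf16.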

From c6c26cdf43a313e7a04ab7ea9fcf72b570103504 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Tue, 14 Apr 2026 13:23:49 +0100
Subject: [PATCH] [mlir][tosa] Fix bf16 reduction accumulator widening

Use an f32 accumulator when lowering bf16 arithmetic reductions in
`TosaToLinalg`, then truncate the result back to bf16.

Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |  8 +++---
 .../TosaToLinalg/tosa-to-linalg.mlir          | 25 +++++++++++++++++++
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..50663ddd27346 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1172,9 +1172,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   Value input = op->getOperand(0);
 
   // Figure out the accType if needed
-  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
-                    isa<FloatType>(elementTy) &&
-                    cast<FloatType>(elementTy).isBF16();
+  const bool needsFp32AccTy =
+      isa<FloatType>(elementTy) && cast<FloatType>(elementTy).isBF16();
+  const bool widenAccTy = (std::is_same_v<OpTy, tosa::ReduceSumOp> ||
+                           std::is_same_v<OpTy, tosa::ReduceProductOp>) &&
+                          needsFp32AccTy;
   Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
 
   SmallVector<int64_t> reduceShape;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index e6bd800a0cf0a..20c93c671f48c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -983,6 +983,31 @@ func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
 
 // -----
 
+// CHECK-LABEL: @reduce_product_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_product_bf16(%arg0: tensor<5x4xbf16>) -> () {
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
+  // CHECK: [[CST1:%.+]] = arith.constant 1.0
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST1]]{{.*}}outs([[INIT]]
+  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<5xf32>) dimensions = [1]
+  // CHECK:  (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+  // CHECK:   [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+  // CHECK:   [[ACC:%.+]] = arith.mulf [[EXTF]], %[[ARG2]] : f32
+  // CHECK:   linalg.yield [[ACC]] : f32
+  // CHECK:  }
+  // CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<5xbf16>
+  // CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<5xf32>) outs([[INIT_RES]] : tensor<5xbf16>)
+  // CHECK:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
+  // CHECK:   [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
+  // CHECK:   linalg.yield [[TRUNCF]] : bf16
+  // CHECK:  }
+  // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xbf16> into tensor<5x1xbf16>
+  %0 = tosa.reduce_product %arg0 {axis = 1 : i32} : (tensor<5x4xbf16>) -> tensor<5x1xbf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
 func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
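To make the expected lowering concrete, the following C++ sketch mirrors the patched `widenAccTy` predicate and the computation that the `@reduce_product_bf16` CHECK lines describe. It is an illustration under assumptions, not the lowering itself: the tag structs, the `bf16` alias (raw 16-bit patterns) and the helper names are hypothetical, `arith.truncf` is modelled as round-to-nearest-even, and NaN/Inf inputs are not handled.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical tag types standing in for tosa::ReduceSumOp,
// tosa::ReduceProductOp and tosa::ReduceMaxOp.
struct ReduceSumTag {};
struct ReduceProductTag {};
struct ReduceMaxTag {};

// Mirrors the patched `widenAccTy`: only bf16 sum and product reductions
// get an f32 accumulator.
template <typename OpTag>
constexpr bool widenAccTy(bool elementIsBF16) {
  return (std::is_same_v<OpTag, ReduceSumTag> ||
          std::is_same_v<OpTag, ReduceProductTag>) &&
         elementIsBF16;
}
static_assert(widenAccTy<ReduceProductTag>(true));
static_assert(!widenAccTy<ReduceMaxTag>(true));
static_assert(!widenAccTy<ReduceSumTag>(false));

using bf16 = std::uint16_t; // raw bf16 bit pattern (hypothetical model)

// arith.extf bf16 -> f32 is exact: the bf16 bits become the upper half of an f32.
float extfBf16ToF32(bf16 x) {
  const std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
  float out;
  std::memcpy(&out, &bits, sizeof out);
  return out;
}

// arith.truncf f32 -> bf16, modelled with round-to-nearest-even.
bf16 truncfF32ToBf16(float x) {
  assert(std::isfinite(x) && "NaN/Inf handling omitted in this sketch");
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<bf16>(bits >> 16);
}

// Reference semantics of the lowered @reduce_product_bf16: reduce dimension 1
// of a 5x4 bf16 tensor with an f32 accumulator, then truncate each result.
std::array<bf16, 5>
reduceProductBf16(const std::array<std::array<bf16, 4>, 5> &in) {
  std::array<bf16, 5> out{};
  for (std::size_t i = 0; i < in.size(); ++i) {
    float acc = 1.0f; // linalg.fill with arith.constant 1.0 into tensor<5xf32>
    for (std::size_t j = 0; j < in[i].size(); ++j)
      acc = extfBf16ToF32(in[i][j]) * acc; // arith.extf, then arith.mulf in f32
    out[i] = truncfF32ToBf16(acc); // arith.truncf in the trailing linalg.generic
  }
  return out;
}
```

Keeping the widening exact and rounding only in `truncfF32ToBf16` gives one bf16 rounding per output element instead of one per multiply, which is the structure the CHECK lines expect.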