[Mlir-commits] [mlir] [mlir][tosa] Fix bf16 reduction accumulator widening (PR #192045)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 14 05:42:47 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Georgios Pinitas (GeorgeARM)
<details>
<summary>Changes</summary>
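Use an f32 accumulator when lowering bf16 arithmetic reductions in `TosaToLinalg`, then truncate the result back to bf16. The widening previously applied only to `tosa.reduce_sum`; this change extends it to `tosa.reduce_product`.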
---
Full diff: https://github.com/llvm/llvm-project/pull/192045.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-3)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+25)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..50663ddd27346 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1172,9 +1172,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
Value input = op->getOperand(0);
// Figure out the accType if needed
- bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
- isa<FloatType>(elementTy) &&
- cast<FloatType>(elementTy).isBF16();
+ const bool needsFp32AccTy =
+ isa<FloatType>(elementTy) && cast<FloatType>(elementTy).isBF16();
+ const bool widenAccTy = (std::is_same_v<OpTy, tosa::ReduceSumOp> ||
+ std::is_same_v<OpTy, tosa::ReduceProductOp>) &&
+ needsFp32AccTy;
Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
SmallVector<int64_t> reduceShape;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index e6bd800a0cf0a..20c93c671f48c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -983,6 +983,31 @@ func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
// -----
+// CHECK-LABEL: @reduce_product_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_product_bf16(%arg0: tensor<5x4xbf16>) -> () {
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
+ // CHECK: [[CST1:%.+]] = arith.constant 1.0
+ // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST1]]{{.*}}outs([[INIT]]
+ // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<5xf32>) dimensions = [1]
+ // CHECK: (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+ // CHECK: [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+ // CHECK: [[ACC:%.+]] = arith.mulf [[EXTF]], %[[ARG2]] : f32
+ // CHECK: linalg.yield [[ACC]] : f32
+ // CHECK: }
+ // CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<5xbf16>
+ // CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<5xf32>) outs([[INIT_RES]] : tensor<5xbf16>)
+ // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
+ // CHECK: linalg.yield [[TRUNCF]] : bf16
+ // CHECK: }
+ // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xbf16> into tensor<5x1xbf16>
+ %0 = tosa.reduce_product %arg0 {axis = 1 : i32} : (tensor<5x4xbf16>) -> tensor<5x1xbf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @reduce_float
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
``````````
</details>
https://github.com/llvm/llvm-project/pull/192045
More information about the Mlir-commits mailing list