[Mlir-commits] [mlir] [mlir][tosa] Introduce accumulator type for `reduce_sum` on bf16 (PR #158389)

Georgios Pinitas llvmlistbot at llvm.org
Sun Sep 14 05:16:53 PDT 2025


https://github.com/GeorgeARM updated https://github.com/llvm/llvm-project/pull/158389

>From ee039e2fe9b9f0bdcbdc917b3eb35c8007b5aa6e Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Sat, 13 Sep 2025 01:30:37 +0100
Subject: [PATCH 1/2] [mlir][tosa] Introduce accumulator type for `reduce_sum`
 on bf16

TOSA requires that `reduce_sum` operations on bf16 accumulate in fp32.
This change updates the `linalg` legalization to use an explicit fp32
accumulator in that case and truncate the reduced result back to bf16.

Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 55 ++++++++++++++-----
 .../TosaToLinalg/tosa-to-linalg.mlir          | 21 +++++++
 2 files changed, 61 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e2b31f640da2f..96eab7197a585 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
+  // bf16 reduce_sum accumulates in f32; all other reductions use elementTy.
+  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
+                    isa<FloatType>(elementTy) &&
+                    cast<FloatType>(elementTy).isBF16();
+  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
+
   SmallVector<int64_t> reduceShape;
   SmallVector<Value> dynDims;
   for (unsigned i = 0; i < inputTy.getRank(); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   inputs.push_back(input);
 
   // First fill the output buffer with the init value.
-  auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                             resultTy.getElementType(), dynDims)
-                         .getResult();
+  auto emptyTensor =
+      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
+          .getResult();
 
-  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
   if (!fillValueAttr)
     return rewriter.notifyMatchFailure(
         op, "No initial value found for reduction operation");
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
         std::array<Value, 2> binaryArgs{
             blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
-        auto result = createLinalgBodyCalculationForReduceOp(
-            op, binaryArgs, elementTy, rewriter);
+
+        // Widen the input element to the accumulator type when they differ.
+        if (binaryArgs[0].getType() != accTy)
+          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
+                                                binaryArgs[0]);
+
+        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
+                                                             accTy, rewriter);
         if (result)
           didEncounterError = true;
 
@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
 
     // Create a tensor full of NaNs.
     auto nanValueAttr = rewriter.getFloatAttr(
-        elementTy,
+        accTy,
         APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
     auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
     auto emptyNanTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
     auto nanFilledTensor =
         linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     // Create an empty tensor, non need to fill this since it will be
     // overwritten by the select.
     auto finalEmptyTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
 
     // Do a selection between the tensors akin to:
@@ -1304,9 +1314,24 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     linalgOp = linalgSelect;
   }
 
+  // Truncate the widened accumulator back to the result element type.
+  Value reducedRes = linalgOp->getResult(0);
+  if (widenAccTy) {
+    auto resEmptyOp =
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
+            .getResult();
+    reducedRes = linalg::MapOp::create(
+                     rewriter, loc, ValueRange{reducedRes}, resEmptyOp,
+                     [&](OpBuilder &builder, Location loc, ValueRange args) {
+                       Value val = arith::TruncFOp::create(builder, loc,
+                                                           elementTy, args[0]);
+                       linalg::YieldOp::create(builder, loc, ValueRange{val});
+                     })
+                     .getResult()[0];
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
-  uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
+  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
   reassociationMap.resize(expandInputRank);
 
   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1324,8 +1349,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   // since here we know which dimension to expand, and `tosa::ReshapeOp` would
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
-  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp->getResults()[0], reassociationMap);
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
+                                                     reassociationMap);
   return success();
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3fc513f823a1a..3b63bdf4f7219 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -912,6 +912,27 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
 
 // -----
 
+// CHECK-LABEL: @reduce_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32>
+  // CHECK: [[CST0:%.+]] = arith.constant 0.0
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
+  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<4xf32>) dimensions = [0]
+  // CHECK:  (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+  // CHECK:   [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+  // CHECK:   [[ACC:%.+]] = arith.addf [[EXTF]], %[[ARG2]] : f32
+  // CHECK:   linalg.yield [[ACC]] : f32
+  // CHECK:  }
+  // CHECK:  [[TRUNCF:%.+]] = tensor.empty() : tensor<4xbf16>
+  // CHECK:  [[RES:%.+]] = linalg.map { arith.truncf } ins([[REDUCE]]{{.*}}outs([[TRUNCF]]
+  // CHECK:  tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
 func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {

>From 9234525cc3f00572fdf3588ea77df22f479a49ea Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Sun, 14 Sep 2025 13:04:13 +0100
Subject: [PATCH 2/2] Address review comments and replace linalg::MapOp with
 linalg::GenericOp

Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 24 ++++++++++++-------
 .../TosaToLinalg/tosa-to-linalg.mlir          | 11 ++++++---
 2 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 96eab7197a585..0a6f2477560a1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1320,14 +1320,22 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     auto resEmptyOp =
         tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
             .getResult();
-    reducedRes = linalg::MapOp::create(
-                     rewriter, loc, ValueRange{reducedRes}, resEmptyOp,
-                     [&](OpBuilder &builder, Location loc, ValueRange args) {
-                       Value val = arith::TruncFOp::create(builder, loc,
-                                                           elementTy, args[0]);
-                       linalg::YieldOp::create(builder, loc, ValueRange{val});
-                     })
-                     .getResult()[0];
+
+    const unsigned reducedRank =
+        cast<ShapedType>(reducedRes.getType()).getRank();
+    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    reducedRes =
+        linalg::GenericOp::create(
+            rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
+            ValueRange{resEmptyOp},
+            ArrayRef<AffineMap>{identityMap, identityMap},
+            getNParallelLoopsAttrs(reducedRank),
+            [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+              Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
+                                                     elementTy, args[0]);
+              linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
+            })
+            .getResults()[0];
   }
 
   SmallVector<ReassociationExprs, 4> reassociationMap;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3b63bdf4f7219..37af8b8859852 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -912,6 +912,7 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: @reduce_bf16
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
 func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
@@ -924,9 +925,13 @@ func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
   // CHECK:   [[ACC:%.+]] = arith.addf [[EXTF]], %[[ARG2]] : f32
   // CHECK:   linalg.yield [[ACC]] : f32
   // CHECK:  }
-  // CHECK:  [[TRUNCF:%.+]] = tensor.empty() : tensor<4xbf16>
-  // CHECK:  [[RES:%.+]] = linalg.map { arith.truncf } ins([[REDUCE]]{{.*}}outs([[TRUNCF]]
-  // CHECK:  tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
+  // CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<4xbf16>
+  // CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<4xf32>) outs([[INIT_RES]] : tensor<4xbf16>)
+  // CHECK:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
+  // CHECK:   [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
+  // CHECK:   linalg.yield [[TRUNCF]] : bf16
+  // CHECK:  }
+  // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
   %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>
   return
 }



More information about the Mlir-commits mailing list