[Mlir-commits] [mlir] Fix the lowering of `arith.truncf : f32 to bf16`. (PR #83180)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 28 10:38:39 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

<details>
<summary>Changes</summary>

This lowering was not correctly handling the case where rounding up a saturated mantissa carries into the exponent, incrementing it. The new code borrows, with credit, the idea from https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79 and adds comments to explain the magic trick going on here and why it's correct. Hat tip to its original author, whom I believe to be @Maratyszcza.

A test case also required a tie to be broken upwards in a case where round-to-nearest-even requires rounding downwards. The fact that it used to pass suggests that there was another bug in the old code.

---
Full diff: https://github.com/llvm/llvm-project/pull/83180.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+46-52) 
- (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+15-30) 
- (modified) mlir/test/mlir-cpu-runner/expand-arith-ops.mlir (+35-12) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 8deb8f028ba458..7f246daf99ff3c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -261,68 +261,62 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
       return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
     }
 
-    Type i1Ty = b.getI1Type();
     Type i16Ty = b.getI16Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();
     if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i1Ty = shapedTy.clone(i1Ty);
       i16Ty = shapedTy.clone(i16Ty);
       i32Ty = shapedTy.clone(i32Ty);
       f32Ty = shapedTy.clone(f32Ty);
     }
 
-    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
-
-    Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
-    Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
-    Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
-    Value expMask =
-        createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
-    Value expMax =
-        createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);
-
-    // Grab the sign bit.
-    Value sign = b.create<arith::ShRUIOp>(bitcast, c31);
-
-    // Our mantissa rounding value depends on the sign bit and the last
-    // truncated bit.
-    Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
-    cManRound = b.create<arith::SubIOp>(cManRound, sign);
-
-    // Grab out the mantissa and directly apply rounding.
-    Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
-    Value manRound = b.create<arith::AddIOp>(man, cManRound);
-
-    // Grab the overflow bit and shift right if we overflow.
-    Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
-    Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);
-
-    // Grab the exponent and round using the mantissa's carry bit.
-    Value exp = b.create<arith::AndIOp>(bitcast, expMask);
-    Value expCarry = b.create<arith::AddIOp>(exp, manRound);
-    expCarry = b.create<arith::AndIOp>(expCarry, expMask);
-
-    // If the exponent is saturated, we keep the max value.
-    Value expCmp =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
-    exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);
-
-    // If the exponent is max and we rolled over, keep the old mantissa.
-    Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
-    Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
-    man = b.create<arith::SelectOp>(keepOldMan, man, manNew);
-
-    // Assemble the now rounded f32 value (as an i32).
-    Value rounded = b.create<arith::ShLIOp>(sign, c31);
-    rounded = b.create<arith::OrIOp>(rounded, exp);
-    rounded = b.create<arith::OrIOp>(rounded, man);
-
+    // Algorithm borrowed from this excellent code:
+    // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
+    // There is a magic idea there, to let the addition of the rounding_bias to
+    // the mantissa simply overflow into the exponent bits. It's a bit of an
+    // aggressive, obfuscating optimization, but it is well-tested code, and it
+    // results in more concise and efficient IR.
+    // The case of NaN is handled separately (see isNan and the final select).
+    // The case of infinities is NOT handled separately, which deserves an
+    // explanation. As the encoding of infinities has zero mantissa, the
+    // rounding-bias addition never carries into the exponent, so it just gets
+    // truncated away, and as bfloat16 and float32 have the same number of
+    // exponent bits, that simple truncation is the desired outcome for
+    // infinities.
+    Value isNan =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
+    // Constant used to make the rounding bias.
+    Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
+    // Constant used to generate a quiet NaN.
+    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    // Small constants used to address bits.
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
-    Value shr = b.create<arith::ShRUIOp>(rounded, c16);
-    Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
-    Value result = b.create<arith::BitcastOp>(resultTy, trunc);
-
+    Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
+    // Reinterpret the input f32 value as bits.
+    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+    // Read bit 16 as a value in {0,1}.
+    Value bit16 =
+        b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
+    // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
+    // on bit 16, implementing the tie-breaking "to nearest even".
+    Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
+    // Add the rounding bias. Generally we want this to be added to the
+    // mantissa, but nothing prevents this from carrying into the exponent
+    // bits, which would feel like a bug, but this is the magic trick here:
+    // when that happens, the mantissa gets reset to zero and the exponent
+    // gets incremented by the carry... which is actually exactly what we
+    // want.
+    Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
+    // Now that the rounding-bias has been added, truncating the low bits
+    // yields the correctly rounded result.
+    Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
+    Value normalCaseResult_i16 =
+        b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
+    // Select either the above-computed result, or a quiet NaN constant
+    // if the input was NaN.
+    Value select =
+        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+    Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 046e8ff64fba6d..91f652e5a270e3 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -255,36 +255,21 @@ func.func @truncf_f32(%arg0 : f32) -> bf16 {
 }
 
 // CHECK-LABEL: @truncf_f32
-
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16
-// CHECK-DAG: %[[C32768:.+]] = arith.constant 32768
-// CHECK-DAG: %[[C2130706432:.+]] = arith.constant 2130706432
-// CHECK-DAG: %[[C2139095040:.+]] = arith.constant 2139095040
-// CHECK-DAG: %[[C8388607:.+]] = arith.constant 8388607
-// CHECK-DAG: %[[C31:.+]] = arith.constant 31
-// CHECK-DAG: %[[C23:.+]] = arith.constant 23
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0
-// CHECK-DAG: %[[SIGN:.+]] = arith.shrui %[[BITCAST:.+]], %[[C31]]
-// CHECK-DAG: %[[ROUND:.+]] = arith.subi %[[C32768]], %[[SIGN]]
-// CHECK-DAG: %[[MANTISSA:.+]] = arith.andi %[[BITCAST]], %[[C8388607]]
-// CHECK-DAG: %[[ROUNDED:.+]] = arith.addi %[[MANTISSA]], %[[ROUND]]
-// CHECK-DAG: %[[ROLL:.+]] = arith.shrui %[[ROUNDED]], %[[C23]]
-// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[ROUNDED]], %[[ROLL]]
-// CHECK-DAG: %[[EXP:.+]] = arith.andi %0, %[[C2139095040]]
-// CHECK-DAG: %[[EXPROUND:.+]] = arith.addi %[[EXP]], %[[ROUNDED]]
-// CHECK-DAG: %[[EXPROLL:.+]] = arith.andi %[[EXPROUND]], %[[C2139095040]]
-// CHECK-DAG: %[[EXPMAX:.+]] = arith.cmpi uge, %[[EXP]], %[[C2130706432]]
-// CHECK-DAG: %[[EXPNEW:.+]] = arith.select %[[EXPMAX]], %[[EXP]], %[[EXPROLL]]
-// CHECK-DAG: %[[OVERFLOW_B:.+]] = arith.trunci %[[ROLL]]
-// CHECK-DAG: %[[KEEP_MAN:.+]] = arith.andi %[[EXPMAX]], %[[OVERFLOW_B]]
-// CHECK-DAG: %[[MANNEW:.+]] = arith.select %[[KEEP_MAN]], %[[MANTISSA]], %[[SHR]]
-// CHECK-DAG: %[[NEWSIGN:.+]] = arith.shli %[[SIGN]], %[[C31]]
-// CHECK-DAG: %[[WITHEXP:.+]] = arith.ori %[[NEWSIGN]], %[[EXPNEW]]
-// CHECK-DAG: %[[WITHMAN:.+]] = arith.ori %[[WITHEXP]], %[[MANNEW]]
-// CHECK-DAG: %[[SHIFT:.+]] = arith.shrui %[[WITHMAN]], %[[C16]]
-// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHIFT]]
-// CHECK-DAG: %[[RES:.+]] = arith.bitcast %[[TRUNC]]
-// CHECK: return %[[RES]]
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C7FC0_i16:.+]] = arith.constant 32704 : i16
+// CHECK-DAG: %[[C7FFF:.+]] = arith.constant 32767 : i32
+// CHECK-DAG: %[[ISNAN:.+]] = arith.cmpf une, %arg0, %arg0 : f32
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK-DAG: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C16]] : i32
+// CHECK-DAG: %[[BIT16:.+]] = arith.andi %[[SHRUI]], %[[C1]] : i32
+// CHECK-DAG: %[[ROUNDING_BIAS:.+]] = arith.addi %[[BIT16]], %[[C7FFF]] : i32
+// CHECK-DAG: %[[BIASED:.+]] = arith.addi %[[BITCAST]], %[[ROUNDING_BIAS]] : i32
+// CHECK-DAG: %[[BIASED_SHIFTED:.+]] = arith.shrui %[[BIASED]], %[[C16]] : i32
+// CHECK-DAG: %[[NORMAL_CASE_RESULT_i16:.+]] = arith.trunci %[[BIASED_SHIFTED]] : i32 to i16
+// CHECK-DAG: %[[SELECT:.+]] = arith.select %[[ISNAN]], %[[C7FC0_i16]], %[[NORMAL_CASE_RESULT_i16]] : i16
+// CHECK-DAG: %[[RESULT:.+]] = arith.bitcast %[[SELECT]] : i16 to bf16
+// CHECK: return %[[RESULT]]
 
 // -----
 
diff --git a/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
index 44141cc4eeaf42..0bf6523c5e5d5c 100644
--- a/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
+++ b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
@@ -13,10 +13,21 @@ func.func @trunc_bf16(%a : f32) {
 }
 
 func.func @main() {
-  // CHECK: 1.00781
-  %roundOneI = arith.constant 0x3f808000 : i32
-  %roundOneF = arith.bitcast %roundOneI : i32 to f32
-  call @trunc_bf16(%roundOneF): (f32) -> ()
+  // Note: this is a tie (low 16 bits are 0x8000). We expect the rounding behavior
+  // to break ties "to nearest-even", which in this case means downwards,
+  // since bit 16 is not set.
+  // CHECK: 1
+  %value_1_00391_I = arith.constant 0x3f808000 : i32
+  %value_1_00391_F = arith.bitcast %value_1_00391_I : i32 to f32
+  call @trunc_bf16(%value_1_00391_F): (f32) -> ()
+
+  // Note: this is a tie (low 16 bits are 0x8000). We expect the rounding behavior
+  // to break ties "to nearest-even", which in this case means upwards,
+  // since bit 16 is set.
+  // CHECK-NEXT: 1.01562
+  %value_1_01172_I = arith.constant 0x3f818000 : i32
+  %value_1_01172_F = arith.bitcast %value_1_01172_I : i32 to f32
+  call @trunc_bf16(%value_1_01172_F): (f32) -> ()
 
   // CHECK-NEXT: -1
   %noRoundNegOneI = arith.constant 0xbf808000 : i32
@@ -38,15 +49,27 @@ func.func @main() {
   %neginff = arith.bitcast %neginfi : i32 to f32
   call @trunc_bf16(%neginff): (f32) -> ()
 
+  // Note: this rounds upwards. As the mantissa was already saturated, this rounding
+  // causes the exponent to be incremented. As the exponent was already the
+  // maximum exponent value for finite values, this increment of the exponent
+  // causes this to overflow to +inf.
+  // CHECK-NEXT: inf
+  %big_overflowing_i = arith.constant 0x7f7fffff : i32
+  %big_overflowing_f = arith.bitcast %big_overflowing_i : i32 to f32
+  call @trunc_bf16(%big_overflowing_f): (f32) -> ()
+
+  // Same as the previous testcase but negative.
+  // CHECK-NEXT: -inf
+  %negbig_overflowing_i = arith.constant 0xff7fffff : i32
+  %negbig_overflowing_f = arith.bitcast %negbig_overflowing_i : i32 to f32
+  call @trunc_bf16(%negbig_overflowing_f): (f32) -> ()
+
+  // In contrast to the previous two testcases, the upwards-rounding here
+  // does not cause overflow.
   // CHECK-NEXT: 3.38953e+38
-  %bigi = arith.constant 0x7f7fffff : i32
-  %bigf = arith.bitcast %bigi : i32 to f32
-  call @trunc_bf16(%bigf): (f32) -> ()
-
-  // CHECK-NEXT: -3.38953e+38
-  %negbigi = arith.constant 0xff7fffff : i32
-  %negbigf = arith.bitcast %negbigi : i32 to f32
-  call @trunc_bf16(%negbigf): (f32) -> ()
+  %big_nonoverflowing_i = arith.constant 0x7f7effff : i32
+  %big_nonoverflowing_f = arith.bitcast %big_nonoverflowing_i : i32 to f32
+  call @trunc_bf16(%big_nonoverflowing_f): (f32) -> ()
 
   // CHECK-NEXT: 1.625
   %exprolli = arith.constant 0x3fcfffff : i32

``````````

</details>


https://github.com/llvm/llvm-project/pull/83180


More information about the Mlir-commits mailing list