[Mlir-commits] [mlir] magic bf16 (PR #83180)

Benoit Jacob llvmlistbot at llvm.org
Wed Feb 28 09:32:45 PST 2024


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/83180

>From 190380fe6090023b8ffdec4e858d877ab07155dd Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 27 Feb 2024 15:45:36 -0500
Subject: [PATCH] magic bf16

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 98 +++++++++----------
 mlir/test/Dialect/Arith/expand-ops.mlir       | 45 +++------
 .../mlir-cpu-runner/expand-arith-ops.mlir     | 47 ++++++---
 3 files changed, 96 insertions(+), 94 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 8deb8f028ba458..7f246daf99ff3c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -261,68 +261,62 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
       return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
     }
 
-    Type i1Ty = b.getI1Type();
     Type i16Ty = b.getI16Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();
     if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i1Ty = shapedTy.clone(i1Ty);
       i16Ty = shapedTy.clone(i16Ty);
       i32Ty = shapedTy.clone(i32Ty);
       f32Ty = shapedTy.clone(f32Ty);
     }
 
-    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
-
-    Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
-    Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
-    Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
-    Value expMask =
-        createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
-    Value expMax =
-        createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);
-
-    // Grab the sign bit.
-    Value sign = b.create<arith::ShRUIOp>(bitcast, c31);
-
-    // Our mantissa rounding value depends on the sign bit and the last
-    // truncated bit.
-    Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
-    cManRound = b.create<arith::SubIOp>(cManRound, sign);
-
-    // Grab out the mantissa and directly apply rounding.
-    Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
-    Value manRound = b.create<arith::AddIOp>(man, cManRound);
-
-    // Grab the overflow bit and shift right if we overflow.
-    Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
-    Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);
-
-    // Grab the exponent and round using the mantissa's carry bit.
-    Value exp = b.create<arith::AndIOp>(bitcast, expMask);
-    Value expCarry = b.create<arith::AddIOp>(exp, manRound);
-    expCarry = b.create<arith::AndIOp>(expCarry, expMask);
-
-    // If the exponent is saturated, we keep the max value.
-    Value expCmp =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
-    exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);
-
-    // If the exponent is max and we rolled over, keep the old mantissa.
-    Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
-    Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
-    man = b.create<arith::SelectOp>(keepOldMan, man, manNew);
-
-    // Assemble the now rounded f32 value (as an i32).
-    Value rounded = b.create<arith::ShLIOp>(sign, c31);
-    rounded = b.create<arith::OrIOp>(rounded, exp);
-    rounded = b.create<arith::OrIOp>(rounded, man);
-
+    // Algorithm borrowed from this excellent code:
+    // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
+    // There is a magic idea there, to let the addition of the rounding_bias to
+    // the mantissa simply overflow into the exponent bits. It's a bit of an
+    // aggressive, obfuscating optimization, but it is well-tested code, and it
+    // results in more concise and efficient IR.
+    // The case of NaN is handled separately (see isNaN and the final select).
+    // The case of infinities is NOT handled separately, which deserves an
+    // explanation. As the encoding of infinities has zero mantissa, the
+    // rounding-bias addition never carries into the exponent, so the bias
+    // simply gets truncated away. As bfloat16 and float32 have the same
+    // number of exponent bits, that simple truncation is the desired outcome
+    // for infinities.
+    Value isNan =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
+    // Constant used to make the rounding bias.
+    Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
+    // Constant used to generate a quiet NaN.
+    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    // Small constants used to address bits.
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
-    Value shr = b.create<arith::ShRUIOp>(rounded, c16);
-    Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
-    Value result = b.create<arith::BitcastOp>(resultTy, trunc);
-
+    Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
+    // Reinterpret the input f32 value as bits.
+    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+    // Read bit 16 as a value in {0,1}.
+    Value bit16 =
+        b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
+    // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
+    // on bit 16, implementing the tie-breaking "to nearest even".
+    Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
+    // Add the rounding bias. Generally we want this to be added to the
+    // mantissa, but nothing prevents this from carrying into the exponent
+    // bits, which would feel like a bug, but this is the magic trick here:
+    // when that happens, the mantissa gets reset to zero and the exponent
+    // gets incremented by the carry... which is actually exactly what we
+    // want.
+    Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
+    // Now that the rounding-bias has been added, truncating the low bits
+    // yields the correctly rounded result.
+    Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
+    Value normalCaseResult_i16 =
+        b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
+    // Select either the above-computed result, or a quiet NaN constant
+    // if the input was NaN.
+    Value select =
+        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+    Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 046e8ff64fba6d..91f652e5a270e3 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -255,36 +255,21 @@ func.func @truncf_f32(%arg0 : f32) -> bf16 {
 }
 
 // CHECK-LABEL: @truncf_f32
-
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16
-// CHECK-DAG: %[[C32768:.+]] = arith.constant 32768
-// CHECK-DAG: %[[C2130706432:.+]] = arith.constant 2130706432
-// CHECK-DAG: %[[C2139095040:.+]] = arith.constant 2139095040
-// CHECK-DAG: %[[C8388607:.+]] = arith.constant 8388607
-// CHECK-DAG: %[[C31:.+]] = arith.constant 31
-// CHECK-DAG: %[[C23:.+]] = arith.constant 23
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0
-// CHECK-DAG: %[[SIGN:.+]] = arith.shrui %[[BITCAST:.+]], %[[C31]]
-// CHECK-DAG: %[[ROUND:.+]] = arith.subi %[[C32768]], %[[SIGN]]
-// CHECK-DAG: %[[MANTISSA:.+]] = arith.andi %[[BITCAST]], %[[C8388607]]
-// CHECK-DAG: %[[ROUNDED:.+]] = arith.addi %[[MANTISSA]], %[[ROUND]]
-// CHECK-DAG: %[[ROLL:.+]] = arith.shrui %[[ROUNDED]], %[[C23]]
-// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[ROUNDED]], %[[ROLL]]
-// CHECK-DAG: %[[EXP:.+]] = arith.andi %0, %[[C2139095040]]
-// CHECK-DAG: %[[EXPROUND:.+]] = arith.addi %[[EXP]], %[[ROUNDED]]
-// CHECK-DAG: %[[EXPROLL:.+]] = arith.andi %[[EXPROUND]], %[[C2139095040]]
-// CHECK-DAG: %[[EXPMAX:.+]] = arith.cmpi uge, %[[EXP]], %[[C2130706432]]
-// CHECK-DAG: %[[EXPNEW:.+]] = arith.select %[[EXPMAX]], %[[EXP]], %[[EXPROLL]]
-// CHECK-DAG: %[[OVERFLOW_B:.+]] = arith.trunci %[[ROLL]]
-// CHECK-DAG: %[[KEEP_MAN:.+]] = arith.andi %[[EXPMAX]], %[[OVERFLOW_B]]
-// CHECK-DAG: %[[MANNEW:.+]] = arith.select %[[KEEP_MAN]], %[[MANTISSA]], %[[SHR]]
-// CHECK-DAG: %[[NEWSIGN:.+]] = arith.shli %[[SIGN]], %[[C31]]
-// CHECK-DAG: %[[WITHEXP:.+]] = arith.ori %[[NEWSIGN]], %[[EXPNEW]]
-// CHECK-DAG: %[[WITHMAN:.+]] = arith.ori %[[WITHEXP]], %[[MANNEW]]
-// CHECK-DAG: %[[SHIFT:.+]] = arith.shrui %[[WITHMAN]], %[[C16]]
-// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHIFT]]
-// CHECK-DAG: %[[RES:.+]] = arith.bitcast %[[TRUNC]]
-// CHECK: return %[[RES]]
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C7FC0_i16:.+]] = arith.constant 32704 : i16
+// CHECK-DAG: %[[C7FFF:.+]] = arith.constant 32767 : i32
+// CHECK-DAG: %[[ISNAN:.+]] = arith.cmpf une, %arg0, %arg0 : f32
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
+// CHECK-DAG: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C16]] : i32
+// CHECK-DAG: %[[BIT16:.+]] = arith.andi %[[SHRUI]], %[[C1]] : i32
+// CHECK-DAG: %[[ROUNDING_BIAS:.+]] = arith.addi %[[BIT16]], %[[C7FFF]] : i32
+// CHECK-DAG: %[[BIASED:.+]] = arith.addi %[[BITCAST]], %[[ROUNDING_BIAS]] : i32
+// CHECK-DAG: %[[BIASED_SHIFTED:.+]] = arith.shrui %[[BIASED]], %[[C16]] : i32
+// CHECK-DAG: %[[NORMAL_CASE_RESULT_i16:.+]] = arith.trunci %[[BIASED_SHIFTED]] : i32 to i16
+// CHECK-DAG: %[[SELECT:.+]] = arith.select %[[ISNAN]], %[[C7FC0_i16]], %[[NORMAL_CASE_RESULT_i16]] : i16
+// CHECK-DAG: %[[RESULT:.+]] = arith.bitcast %[[SELECT]] : i16 to bf16
+// CHECK: return %[[RESULT]]
 
 // -----
 
diff --git a/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
index 44141cc4eeaf42..0bf6523c5e5d5c 100644
--- a/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
+++ b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
@@ -13,10 +13,21 @@ func.func @trunc_bf16(%a : f32) {
 }
 
 func.func @main() {
-  // CHECK: 1.00781
-  %roundOneI = arith.constant 0x3f808000 : i32
-  %roundOneF = arith.bitcast %roundOneI : i32 to f32
-  call @trunc_bf16(%roundOneF): (f32) -> ()
+  // Note: this is a tie (low 16 bits are 0x8000). We expect the rounding behavior
+  // to break ties "to nearest-even", which in this case means downwards,
+  // since bit 16 is not set.
+  // CHECK: 1
+  %value_1_00391_I = arith.constant 0x3f808000 : i32
+  %value_1_00391_F = arith.bitcast %value_1_00391_I : i32 to f32
+  call @trunc_bf16(%value_1_00391_F): (f32) -> ()
+
+  // Note: this is a tie (low 16 bits are 0x8000). We expect the rounding behavior
+  // to break ties "to nearest-even", which in this case means upwards,
+  // since bit 16 is set.
+  // CHECK-NEXT: 1.01562
+  %value_1_01172_I = arith.constant 0x3f818000 : i32
+  %value_1_01172_F = arith.bitcast %value_1_01172_I : i32 to f32
+  call @trunc_bf16(%value_1_01172_F): (f32) -> ()
 
   // CHECK-NEXT: -1
   %noRoundNegOneI = arith.constant 0xbf808000 : i32
@@ -38,15 +49,27 @@ func.func @main() {
   %neginff = arith.bitcast %neginfi : i32 to f32
   call @trunc_bf16(%neginff): (f32) -> ()
 
+  // Note: this rounds upwards. As the mantissa was already saturated, this rounding
+  // causes the exponent to be incremented. As the exponent was already the
+  // maximum exponent value for finite values, this increment of the exponent
+  // causes this to overflow to +inf.
+  // CHECK-NEXT: inf
+  %big_overflowing_i = arith.constant 0x7f7fffff : i32
+  %big_overflowing_f = arith.bitcast %big_overflowing_i : i32 to f32
+  call @trunc_bf16(%big_overflowing_f): (f32) -> ()
+
+  // Same as the previous testcase but negative.
+  // CHECK-NEXT: -inf
+  %negbig_overflowing_i = arith.constant 0xff7fffff : i32
+  %negbig_overflowing_f = arith.bitcast %negbig_overflowing_i : i32 to f32
+  call @trunc_bf16(%negbig_overflowing_f): (f32) -> ()
+
+  // In contrast to the previous two testcases, the upwards-rounding here
+  // does not cause overflow.
   // CHECK-NEXT: 3.38953e+38
-  %bigi = arith.constant 0x7f7fffff : i32
-  %bigf = arith.bitcast %bigi : i32 to f32
-  call @trunc_bf16(%bigf): (f32) -> ()
-
-  // CHECK-NEXT: -3.38953e+38
-  %negbigi = arith.constant 0xff7fffff : i32
-  %negbigf = arith.bitcast %negbigi : i32 to f32
-  call @trunc_bf16(%negbigf): (f32) -> ()
+  %big_nonoverflowing_i = arith.constant 0x7f7effff : i32
+  %big_nonoverflowing_f = arith.bitcast %big_nonoverflowing_i : i32 to f32
+  call @trunc_bf16(%big_nonoverflowing_f): (f32) -> ()
 
   // CHECK-NEXT: 1.625
   %exprolli = arith.constant 0x3fcfffff : i32



More information about the Mlir-commits mailing list