[Mlir-commits] [mlir] magic bf16 (PR #83180)

Benoit Jacob llvmlistbot at llvm.org
Tue Feb 27 12:47:14 PST 2024


https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/83180

None

>From e9ef520975fd99ae006f1f5e23ad2df91b68766c Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 27 Feb 2024 15:45:36 -0500
Subject: [PATCH] magic bf16

---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 74 ++++++-------------
 1 file changed, 22 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 8deb8f028ba458..1538eb61f5f170 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -261,68 +261,38 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
       return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
     }
 
-    Type i1Ty = b.getI1Type();
     Type i16Ty = b.getI16Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();
     if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i1Ty = shapedTy.clone(i1Ty);
       i16Ty = shapedTy.clone(i16Ty);
       i32Ty = shapedTy.clone(i32Ty);
       f32Ty = shapedTy.clone(f32Ty);
     }
 
-    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
-
-    Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
-    Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
-    Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
-    Value expMask =
-        createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
-    Value expMax =
-        createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);
-
-    // Grab the sign bit.
-    Value sign = b.create<arith::ShRUIOp>(bitcast, c31);
-
-    // Our mantissa rounding value depends on the sign bit and the last
-    // truncated bit.
-    Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
-    cManRound = b.create<arith::SubIOp>(cManRound, sign);
-
-    // Grab out the mantissa and directly apply rounding.
-    Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
-    Value manRound = b.create<arith::AddIOp>(man, cManRound);
-
-    // Grab the overflow bit and shift right if we overflow.
-    Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
-    Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);
-
-    // Grab the exponent and round using the mantissa's carry bit.
-    Value exp = b.create<arith::AndIOp>(bitcast, expMask);
-    Value expCarry = b.create<arith::AddIOp>(exp, manRound);
-    expCarry = b.create<arith::AndIOp>(expCarry, expMask);
-
-    // If the exponent is saturated, we keep the max value.
-    Value expCmp =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
-    exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);
-
-    // If the exponent is max and we rolled over, keep the old mantissa.
-    Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
-    Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
-    man = b.create<arith::SelectOp>(keepOldMan, man, manNew);
-
-    // Assemble the now rounded f32 value (as an i32).
-    Value rounded = b.create<arith::ShLIOp>(sign, c31);
-    rounded = b.create<arith::OrIOp>(rounded, exp);
-    rounded = b.create<arith::OrIOp>(rounded, man);
-
+    // Algorithm borrowed from this excellent code:
+    // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
+    // The key idea is to add the rounding bias (0x7fff plus bit 16, i.e.
+    // round-to-nearest-even) to the whole 32-bit pattern and let any mantissa
+    // carry overflow into the exponent bits. This is subtle, but it is
+    // well-tested code and it yields more concise and efficient IR.
+    Value isNan =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
+    Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
+    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
     Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
-    Value shr = b.create<arith::ShRUIOp>(rounded, c16);
-    Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
-    Value result = b.create<arith::BitcastOp>(resultTy, trunc);
-
+    Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
+    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+    Value bit16 =
+        b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
+    Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
+    Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
+    Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
+    Value normalCaseResult_i16 =
+        b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
+    Value select =
+        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+    Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }



More information about the Mlir-commits mailing list