[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 18 15:02:01 PDT 2025
================
@@ -322,6 +333,141 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+/// Expands `arith.extf` whose source element type is f4E2M1FN (scalar or
+/// shaped) into integer bit manipulation that assembles the equivalent f32
+/// bit pattern. If the requested result element type is not f32, a final
+/// f32 -> result conversion (`extf` when wider, `truncf` when narrower) is
+/// emitted.
+///
+/// f4E2M1FN bit layout as decoded below: sign = bit 3 (mask 0x8),
+/// exponent = bits 2..1 (mask 0x6), mantissa = bit 0 (mask 0x1). The FN
+/// variant has no Inf/NaN encodings, so none are handled here.
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!isa<Float4E2M1FNType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+    }
+
+    // Integer/float types with the same shape as the operand (or scalars).
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    // i4 field masks (and the shift-by-one used on the exponent field).
+    Value c1I4 = createConst(loc, i4Ty, 1, rewriter);
+    Value mantissaBitmask = c1I4;
+    Value exponentBitmask = createConst(loc, i4Ty, 0x6, rewriter);
+    Value signBitmask = createConst(loc, i4Ty, 0x8, rewriter);
+
+    // i32 shift amounts that move each f4 field into its f32 position.
+    Value cManShift = createConst(loc, i32Ty, 22, rewriter);  // bit 0 -> 22
+    Value cExpShift = createConst(loc, i32Ty, 23, rewriter);  // -> bits 30..23
+    Value cSignShift = createConst(loc, i32Ty, 28, rewriter); // bit 3 -> 31
+    // f32 biased exponent = f4 exponent - 1 (f4 bias) + 127 (f32 bias).
+    Value biasAdjustment = createConst(loc, i32Ty, 126, rewriter);
+
+    // Sign: bit 3 of the f4 value becomes bit 31 of the f32 value.
+    Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
+    Value f32SignBit = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
+    f32SignBit = b.create<arith::ShLIOp>(f32SignBit, cSignShift);
+
+    // Exponent: extract the 2-bit field, rebias, and place it at bit 23.
+    Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
+    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c1I4);
+    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
+    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
+    Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, cExpShift);
+
+    // Mantissa: the single f4 mantissa bit becomes the top f32 mantissa bit.
+    Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
+    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, cManShift);
+
+    // The three fields occupy disjoint bit ranges, so OR them together.
+    Value f32Bits = b.create<arith::OrIOp>(f32SignBit, f32Exp);
+    f32Bits = b.create<arith::OrIOp>(f32Bits, f32ManBit);
+
+    // Subnormal f4 exponent (exp == 00): the magnitude is 0.5 when the
+    // mantissa bit is set, otherwise 0. Build it as f32 bits (0x3F000000 is
+    // 0.5f) and OR in the sign so -0.5 and -0.0 keep their sign.
+    Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                f32ExpBits, biasAdjustment);
+    Value isManSet =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c1I4);
+    Value cHalfBits = createConst(loc, i32Ty, 0x3F000000, rewriter);
+    Value cZeroBits = createConst(loc, i32Ty, 0, rewriter);
+    Value subnormalBits =
+        b.create<arith::SelectOp>(isManSet, cHalfBits, cZeroBits);
+    subnormalBits = b.create<arith::OrIOp>(subnormalBits, f32SignBit);
+    f32Bits = b.create<arith::SelectOp>(isSubnormal, subnormalBits, f32Bits);
+
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+
+    // Convert the computed f32 value (not the original f4 operand) to the
+    // requested result type, which keeps the operand's shape.
+    if (!isa<Float32Type>(resultETy)) {
+      if (resultETy.getIntOrFloatBitWidth() > 32)
+        result = b.create<arith::ExtFOp>(resultTy, result);
+      else
+        result = b.create<arith::TruncFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
----------------
Muzammiluddin-Syed-ECE wrote:
To use a lookup table for the conversion, we need to derive an index from the operand of the extf. The current implementation, which casts the low-order 3 bits of the value to an index into the table, doesn't work when the operand has rank greater than 0 (i.e., is a vector or tensor). Supporting that would require more logic, and possibly a 2-D lookup table.
To avoid this, I kept that implementation separate and restricted it to scalar operands. I agree, though, that this is not ideal, so I will implement extf with selects instead.
https://github.com/llvm/llvm-project/pull/144157
More information about the Mlir-commits
mailing list