[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jun 18 14:38:14 PDT 2025
================
@@ -322,6 +333,141 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+/// Expands an `arith.extf` whose source element type is F4E2M1FN into integer
+/// bit manipulation that assembles the equivalent f32 bit pattern, followed by
+/// an `arith.truncf` when the requested result element type is not f32.
+///
+/// The 4-bit source is decoded as [sign:1][exponent:2][mantissa:1] using the
+/// masks 0x8, 0x6 and 0x1. Works on scalars and shaped types alike: every
+/// intermediate type is cloned from the operand type.
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+
+  /// Rewrites `op` if its operand element type is F4E2M1FN; otherwise reports
+  /// a match failure and leaves the IR untouched.
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    auto loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!isa<Float4E2M1FNType>(operandETy))
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    // i4 field masks. The value 1 doubles as the mantissa mask and as the
+    // shift amount that moves the exponent field down to bit 0.
+    Value cOneI4 = createConst(loc, i4Ty, 0x1, rewriter);
+    Value mantissaBitmask = cOneI4;
+    Value exponentBitmask = createConst(loc, i4Ty, 0x6, rewriter);
+    Value signBitmask = createConst(loc, i4Ty, 0x8, rewriter);
+
+    // i32 shift amounts and bit patterns used to build the f32 encoding.
+    Value signShift = createConst(loc, i32Ty, 28, rewriter);     // bit 3 -> 31
+    Value exponentShift = createConst(loc, i32Ty, 23, rewriter); // f32 exp LSB
+    Value mantissaShift = createConst(loc, i32Ty, 22, rewriter); // f32 man MSB
+    Value biasAdjustment = createConst(loc, i32Ty, 126, rewriter);
+    Value halfBits = createConst(loc, i32Ty, 0x3f000000, rewriter); // 0.5f
+    Value zeroBits = createConst(loc, i32Ty, 0, rewriter);          // 0.0f
+
+    // Sign: isolate bit 3 and move it to bit 31 of the i32.
+    Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
+    Value f32SignBits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
+    f32SignBits = b.create<arith::ShLIOp>(f32SignBits, signShift);
+
+    // Exponent: extract the 2-bit field, rebias by adding 126, then shift it
+    // into the f32 exponent position.
+    Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
+    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, cOneI4);
+    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
+    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
+    Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, exponentShift);
+
+    // Mantissa: the single mantissa bit becomes the top f32 mantissa bit.
+    Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
+    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, mantissaShift);
+    Value normalMagnitude = b.create<arith::AddIOp>(f32Exp, f32ManBit);
+
+    // Special consideration for subnormal exponent (exp == 00): the magnitude
+    // is 0.5 when the mantissa bit is set and 0.0 otherwise. The selection is
+    // done on integer bit patterns so the sign bit can still be applied
+    // afterwards; selecting float constants here would drop the sign of -0.5
+    // and -0.0.
+    Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                f32ExpBits, biasAdjustment);
+    Value isManSet =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, cOneI4);
+    Value subnormalMagnitude =
+        b.create<arith::SelectOp>(isManSet, halfBits, zeroBits);
+    Value magnitude = b.create<arith::SelectOp>(isSubnormal, subnormalMagnitude,
+                                                normalMagnitude);
+
+    // Sign and magnitude occupy disjoint bits, so the add cannot carry.
+    Value f32Bits = b.create<arith::AddIOp>(f32SignBits, magnitude);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+
+    // Narrow the computed f32 value (not the original f4 operand) to the
+    // requested result type, keeping the result's shape.
+    if (!isa<Float32Type>(resultETy))
+      result = b.create<arith::TruncFOp>(resultTy, result);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ ScalarF4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 2)
+ : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (isa<ShapedType>(operandTy))
+ return failure();
+
+ if (!isa<Float4E2M1FNType>(operandETy))
+ return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+
+ SmallVector<int> values = {
+ 0x00000000, // 0.0
+ 0x3f000000, // 0.5
+ 0x3f800000, // 1.0
+ 0x3fc00000, // 1.5
+ 0x40000000, // 2.0
+ 0x40400000, // 3.0
+ 0x40800000, // 4.0
+ 0x40c00000 // 6.0
+ };
+ // auto type = RankedTensorType::get({8}, b.getI32Type());
+ VectorType type = VectorType::get({8}, b.getI32Type());
+ SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector(
+ values, [&](int v) -> Attribute { return b.getI32IntegerAttr(v); });
+ Value lookupTable = b.create<arith::ConstantOp>(
+ DenseIntElementsAttr::get(type, lookupTableAttr));
+
+ Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+ Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+ Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+ Type i64Ty = cloneToShapedType(operandTy, b.getI64Type());
+
+ Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+
+ Value expManBitmask = createConst(op.getLoc(), i4Ty, 0x7, rewriter);
+ Value indexI4 = b.create<arith::AndIOp>(i4Bits, expManBitmask);
+ Value indexI64 = b.create<arith::ExtUIOp>(i64Ty, indexI4);
+ Value index = b.create<arith::IndexCastOp>(b.getIndexType(), indexI64);
+
+ Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
+ Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
+ Value signBitI4 = b.create<arith::AndIOp>(i4Bits, signBitmask);
----------------
kuhar wrote:
You can set the sign bit of the f32 by first zero-extending to i32, shifting right by 3, then shifting left by 31, and OR-ing the result with the looked-up value at the end.
https://github.com/llvm/llvm-project/pull/144157
More information about the Mlir-commits
mailing list