[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jun 18 14:38:14 PDT 2025
================
@@ -322,6 +333,141 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+/// Expands `arith.extf` whose operand element type is f4E2M1FN into integer
+/// bit manipulation that assembles the equivalent f32 bit pattern, followed
+/// by a float conversion when the requested result element type is not f32.
+///
+/// f4E2M1FN layout (MSB to LSB): 1 sign bit, 2 exponent bits (bias 1),
+/// 1 mantissa bit. The format has no Inf/NaN encodings. An exponent field of
+/// zero encodes +/-0.0 (mantissa clear) or +/-0.5 (mantissa set).
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+
+  /// Fails to match unless the operand element type is f4E2M1FN. Otherwise
+  /// replaces `op` with the expanded sequence and returns success.
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!isa<Float4E2M1FNType>(operandETy))
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+
+    // Work on types shaped like the operand so vectors/tensors are handled
+    // elementwise.
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    // i4 field masks; the constant 1 doubles as the mantissa mask.
+    Value cI4One = createConst(loc, i4Ty, 0x1, rewriter);
+    Value mantissaBitmask = cI4One;
+    Value exponentBitmask = createConst(loc, i4Ty, 0x6, rewriter);
+    Value signBitmask = createConst(loc, i4Ty, 0x8, rewriter);
+
+    // i32 constants, named after what they mean rather than a hex spelling.
+    Value cSignShift = createConst(loc, i32Ty, 28, rewriter);  // bit 3 -> 31
+    Value cExpShift = createConst(loc, i32Ty, 23, rewriter);   // f32 exp LSB
+    Value cManShift = createConst(loc, i32Ty, 22, rewriter);   // f32 man MSB
+    Value cBiasAdjust = createConst(loc, i32Ty, 126, rewriter); // 127 - 1
+    Value cHalfBits = createConst(loc, i32Ty, 0x3F000000, rewriter); // 0.5f
+    Value cZeroBits = createConst(loc, i32Ty, 0, rewriter);
+
+    Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    // Sign: move f4 bit 3 to f32 bit 31.
+    Value f4SignBit = b.create<arith::AndIOp>(i4Bits, signBitmask);
+    Value f32Sign = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
+    f32Sign = b.create<arith::ShLIOp>(f32Sign, cSignShift);
+
+    // Exponent: extract the 2-bit field and rebias it (f4 bias 1 -> f32 bias
+    // 127).
+    Value f4ExpBits = b.create<arith::AndIOp>(i4Bits, exponentBitmask);
+    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, cI4One);
+    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
+    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, cBiasAdjust);
+    Value normalBits = b.create<arith::ShLIOp>(f32ExpBits, cExpShift);
+
+    // Mantissa: the single f4 mantissa bit becomes the top f32 mantissa bit.
+    Value f4ManBit = b.create<arith::AndIOp>(i4Bits, mantissaBitmask);
+    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, cManShift);
+    normalBits = b.create<arith::OrIOp>(normalBits, f32ManBit);
+
+    // Special consideration for the zero exponent field (exp == 00): the
+    // magnitude is 0.5 when the mantissa bit is set and 0.0 otherwise. The
+    // selection is done on the magnitude bits only so that the sign is
+    // preserved for -0.5 and -0.0.
+    Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                f32ExpBits, cBiasAdjust);
+    Value isManSet =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, cI4One);
+    Value subnormalBits =
+        b.create<arith::SelectOp>(isManSet, cHalfBits, cZeroBits);
+
+    Value magnitudeBits =
+        b.create<arith::SelectOp>(isSubnormal, subnormalBits, normalBits);
+    Value f32Bits = b.create<arith::OrIOp>(f32Sign, magnitudeBits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+
+    // Convert the computed f32 value (not the original f4 operand) to the
+    // requested result type, keeping the result's shape.
+    if (!isa<Float32Type>(resultETy)) {
+      if (resultETy.getIntOrFloatBitWidth() > 32)
+        result = b.create<arith::ExtFOp>(resultTy, result);
+      else
+        result = b.create<arith::TruncFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+
+struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
----------------
kuhar wrote:
What do you mean by `Scalar` in the op name?
https://github.com/llvm/llvm-project/pull/144157
More information about the Mlir-commits
mailing list