[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jun 18 14:38:14 PDT 2025


================
@@ -322,6 +333,141 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+/// Expands `arith.extf` from `f4E2M1FN` (scalar or shaped) into integer bit
+/// manipulation that assembles the equivalent `f32` bit pattern, followed by
+/// a float conversion when the requested result type is not `f32`.
+///
+/// f4E2M1FN layout (see the masks below): [sign:1][exponent:2][mantissa:1].
+/// Normal values map to f32 with exponent `e + 126` (rebias 1 -> 127) and the
+/// mantissa bit placed in the f32 mantissa MSB. Exponent 0 encodes +-0.0 or
+/// +-0.5 and is handled separately so that the sign is preserved.
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!isa<Float4E2M1FNType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+    }
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value bitcast = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    // i4 field masks.
+    Value c0I4 = createConst(loc, i4Ty, 0, rewriter);
+    Value c1I4 = createConst(loc, i4Ty, 1, rewriter);
+    Value mantissaBitmask = c1I4;
+    Value exponentBitmask = createConst(loc, i4Ty, 0x6, rewriter);
+    Value signBitmask = createConst(loc, i4Ty, 0x8, rewriter);
+    // f32 shift amounts: f4 sign bit 3 -> bit 31, exponent -> bits 23..30,
+    // mantissa bit -> f32 mantissa MSB (bit 22).
+    Value c22 = createConst(loc, i32Ty, 22, rewriter);
+    Value c23 = createConst(loc, i32Ty, 23, rewriter);
+    Value c28 = createConst(loc, i32Ty, 28, rewriter);
+
+    // Sign: bit 3 of the f4 value becomes bit 31 of the f32 value.
+    Value f4SignBit = b.create<arith::AndIOp>(bitcast, signBitmask);
+    Value f32SignBits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
+    f32SignBits = b.create<arith::ShLIOp>(f32SignBits, c28);
+
+    // Exponent: rebias from f4 (bias 1) to f32 (bias 127).
+    Value biasAdjustment = createConst(loc, i32Ty, 126, rewriter);
+    Value f4ExpBits = b.create<arith::AndIOp>(bitcast, exponentBitmask);
+    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c1I4);
+    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
+    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
+    Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, c23);
+
+    // Mantissa: the single f4 mantissa bit becomes the f32 mantissa MSB.
+    Value f4ManBit = b.create<arith::AndIOp>(bitcast, mantissaBitmask);
+    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, c22);
+
+    // The three fields are disjoint, so OR assembles the normal encoding.
+    Value normalBits = b.create<arith::OrIOp>(f32SignBits, f32Exp);
+    normalBits = b.create<arith::OrIOp>(normalBits, f32ManBit);
+
+    // Exponent == 0: the value is +-0.0 (mantissa 0) or +-0.5 (mantissa 1).
+    // Build it as a bit pattern ORed with the sign bits so negative zero and
+    // -0.5 keep their sign; 0x3f000000 is the f32 encoding of 0.5.
+    Value isSubnormal =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ExpBits, c0I4);
+    Value isManSet =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c1I4);
+    Value halfBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
+    Value zeroBits = createConst(loc, i32Ty, 0, rewriter);
+    Value subnormalMagnitude =
+        b.create<arith::SelectOp>(isManSet, halfBits, zeroBits);
+    Value subnormalBits =
+        b.create<arith::OrIOp>(f32SignBits, subnormalMagnitude);
+
+    Value f32Bits =
+        b.create<arith::SelectOp>(isSubnormal, subnormalBits, normalBits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+
+    // Convert the computed f32 value (not the original f4 operand) to the
+    // requested, possibly shaped, result type: widen with extf or narrow
+    // with truncf depending on the destination width.
+    if (!isa<Float32Type>(resultETy)) {
+      if (cast<FloatType>(resultETy).getWidth() > 32)
+        result = b.create<arith::ExtFOp>(resultTy, result);
+      else
+        result = b.create<arith::TruncFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  ScalarF4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 2)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (isa<ShapedType>(operandTy))
+      return failure();
+
+    if (!isa<Float4E2M1FNType>(operandETy))
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+
+    SmallVector<int> values = {
+        0x00000000, // 0.0
+        0x3f000000, // 0.5
+        0x3f800000, // 1.0
+        0x3fc00000, // 1.5
+        0x40000000, // 2.0
+        0x40400000, // 3.0
+        0x40800000, // 4.0
+        0x40c00000  // 6.0
+    };
+    // auto type = RankedTensorType::get({8}, b.getI32Type());
+    VectorType type = VectorType::get({8}, b.getI32Type());
+    SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector(
+        values, [&](int v) -> Attribute { return b.getI32IntegerAttr(v); });
+    Value lookupTable = b.create<arith::ConstantOp>(
+        DenseIntElementsAttr::get(type, lookupTableAttr));
+
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type i64Ty = cloneToShapedType(operandTy, b.getI64Type());
+
+    Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    Value expManBitmask = createConst(op.getLoc(), i4Ty, 0x7, rewriter);
+    Value indexI4 = b.create<arith::AndIOp>(i4Bits, expManBitmask);
+    Value indexI64 = b.create<arith::ExtUIOp>(i64Ty, indexI4);
+    Value index = b.create<arith::IndexCastOp>(b.getIndexType(), indexI64);
+
+    Value c0x0000001c = createConst(op->getLoc(), i32Ty, 28, rewriter);
+    Value signBitmask = createConst(op.getLoc(), i4Ty, 0x8, rewriter);
+    Value signBitI4 = b.create<arith::AndIOp>(i4Bits, signBitmask);
----------------
kuhar wrote:

You can set the sign bit of the f32 result by first zero-extending the i4 value to i32, shifting right by 3, then shifting left by 31, and finally OR-ing that with the looked-up value.

https://github.com/llvm/llvm-project/pull/144157


More information about the Mlir-commits mailing list