[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jun 18 14:38:14 PDT 2025


================
@@ -322,6 +333,141 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+/// Expands `arith.extf` whose operand element type is f4E2M1FN into integer
+/// bit manipulation that assembles the equivalent f32 bit pattern.
+///
+/// The 4-bit input is treated as `S EE M` (sign mask 0x8, exponent mask 0x6,
+/// mantissa mask 0x1). Each field is zero-extended to i32 and shifted into
+/// its f32 position; the exponent is re-biased by adding 126. Inputs whose
+/// exponent field is zero are handled separately and map to +/-0.0 or +/-0.5
+/// depending on the mantissa bit. If the requested result element type is not
+/// f32, the f32 value is converted with a trailing `arith.truncf`.
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!isa<Float4E2M1FNType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+    }
+
+    // Integer/float types with the same shape (if any) as the operand.
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value f4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    // Field masks of the 4-bit encoding.
+    Value mantissaBitmask = createConst(loc, i4Ty, 0x1, rewriter);
+    Value exponentBitmask = createConst(loc, i4Ty, 0x6, rewriter);
+    Value signBitmask = createConst(loc, i4Ty, 0x8, rewriter);
+
+    // Shift amounts that move each extended field to its f32 position.
+    Value mantissaShift = createConst(loc, i32Ty, 22, rewriter);
+    Value exponentShift = createConst(loc, i32Ty, 23, rewriter);
+    Value signShift = createConst(loc, i32Ty, 28, rewriter);
+
+    // Sign: mask bit 3 and move it to bit 31.
+    Value f4SignBit = b.create<arith::AndIOp>(f4Bits, signBitmask);
+    Value f32SignBits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
+    f32SignBits = b.create<arith::ShLIOp>(f32SignBits, signShift);
+
+    // Exponent: extract the 2-bit field, re-bias it, and shift into place.
+    Value biasAdjustment = createConst(loc, i32Ty, 126, rewriter);
+    Value f4ExpBits = b.create<arith::AndIOp>(f4Bits, exponentBitmask);
+    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, mantissaBitmask);
+    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
+    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
+    Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, exponentShift);
+
+    // Mantissa: the single mantissa bit becomes bit 22.
+    Value f4ManBit = b.create<arith::AndIOp>(f4Bits, mantissaBitmask);
+    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, mantissaShift);
+
+    // The three fields occupy disjoint bits, so addition merges them.
+    Value normalBits = b.create<arith::AddIOp>(f32SignBits, f32Exp);
+    normalBits = b.create<arith::AddIOp>(normalBits, f32ManBit);
+
+    // Special consideration for subnormal exponent (exp == 00): the magnitude
+    // is 0.5 when the mantissa bit is set and 0.0 otherwise. This is done on
+    // the bit pattern so that the sign bit is preserved (-0.5 and -0.0).
+    Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                f32ExpBits, biasAdjustment);
+    Value isManSet = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                             f4ManBit, mantissaBitmask);
+    Value halfBits = createConst(loc, i32Ty, 0x3f000000, rewriter); // 0.5f
+    Value zeroBits = createConst(loc, i32Ty, 0, rewriter);          // 0.0f
+    Value subnormalBits =
+        b.create<arith::SelectOp>(isManSet, halfBits, zeroBits);
+    subnormalBits = b.create<arith::AddIOp>(f32SignBits, subnormalBits);
+
+    Value f32Bits =
+        b.create<arith::SelectOp>(isSubnormal, subnormalBits, normalBits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+
+    // Convert the computed f32 value (not the original f4 operand) to the
+    // requested result type, keeping the result's shape.
+    // NOTE(review): a result element type wider than f32 would need an
+    // extension rather than a truncation here -- confirm whether such types
+    // can reach this pattern.
+    if (!isa<Float32Type>(resultETy)) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  ScalarF4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 2)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (isa<ShapedType>(operandTy))
+      return failure();
+
+    if (!isa<Float4E2M1FNType>(operandETy))
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+
+    SmallVector<int> values = {
+        0x00000000, // 0.0
+        0x3f000000, // 0.5
+        0x3f800000, // 1.0
+        0x3fc00000, // 1.5
+        0x40000000, // 2.0
+        0x40400000, // 3.0
+        0x40800000, // 4.0
+        0x40c00000  // 6.0
+    };
+    // auto type = RankedTensorType::get({8}, b.getI32Type());
+    VectorType type = VectorType::get({8}, b.getI32Type());
+    SmallVector<Attribute> lookupTableAttr = llvm::map_to_vector(
+        values, [&](int v) -> Attribute { return b.getI32IntegerAttr(v); });
+    Value lookupTable = b.create<arith::ConstantOp>(
+        DenseIntElementsAttr::get(type, lookupTableAttr));
+
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type i64Ty = cloneToShapedType(operandTy, b.getI64Type());
----------------
kuhar wrote:

Why do we need to extend this all the way to i64? I'd think we should be able to do everything within i32. The issue with i64 is that some targets don't support it natively and require software emulation.

https://github.com/llvm/llvm-project/pull/144157


More information about the Mlir-commits mailing list