[Mlir-commits] [mlir] [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 (PR #144157)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jun 18 14:38:14 PDT 2025


================
@@ -322,6 +333,141 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };
 
+/// Expands `arith.extf` whose operand element type is `f4E2M1FN` into
+/// integer bit manipulation. The f4 value (1 sign bit, 2 exponent bits with
+/// bias 1, 1 mantissa bit) is bitcast to `i4`, each field is moved to its
+/// `f32` position, and the assembled `i32` is bitcast to `f32`. If the op's
+/// result element type is not 32 bits wide, that `f32` is then truncated or
+/// extended to the op's result type.
+struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  F4E2M1ExtFOpConverter(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<arith::ExtFOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!isa<Float4E2M1FNType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
+    }
+
+    // Integer/float types with the same shape as the operand.
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
+
+    // i4 field masks; 0x1 doubles as the shift that drops the mantissa bit.
+    Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
+    Value mantissaBitmask = c0x1;
+    Value exponentBitmask = createConst(loc, i4Ty, 0x6, rewriter);
+    Value signBitmask = createConst(loc, i4Ty, 0x8, rewriter);
+
+    // i32 shift amounts, named after their values. The f4 mantissa bit goes
+    // to f32 bit 22 (top mantissa bit), the exponent starts at bit 23, and
+    // the f4 sign (bit 3) goes to f32 bit 31.
+    Value c0x00000016 = createConst(loc, i32Ty, 0x16, rewriter);
+    Value c0x00000017 = createConst(loc, i32Ty, 0x17, rewriter);
+    Value c0x0000001c = createConst(loc, i32Ty, 0x1c, rewriter);
+
+    // Sign.
+    Value f4SignBit = b.create<arith::AndIOp>(i4Bits, signBitmask);
+    Value f32SignBits = b.create<arith::ExtUIOp>(i32Ty, f4SignBit);
+    f32SignBits = b.create<arith::ShLIOp>(f32SignBits, c0x0000001c);
+
+    // Exponent: rebias from f4 (bias 1) to f32 (bias 127), i.e. add 126.
+    Value biasAdjustment = createConst(loc, i32Ty, 126, rewriter);
+    Value f4ExpBits = b.create<arith::AndIOp>(i4Bits, exponentBitmask);
+    f4ExpBits = b.create<arith::ShRUIOp>(f4ExpBits, c0x1);
+    Value f32ExpBits = b.create<arith::ExtUIOp>(i32Ty, f4ExpBits);
+    f32ExpBits = b.create<arith::AddIOp>(f32ExpBits, biasAdjustment);
+    Value f32Exp = b.create<arith::ShLIOp>(f32ExpBits, c0x00000017);
+
+    // Mantissa.
+    Value f4ManBit = b.create<arith::AndIOp>(i4Bits, mantissaBitmask);
+    Value f32ManBit = b.create<arith::ExtUIOp>(i32Ty, f4ManBit);
+    f32ManBit = b.create<arith::ShLIOp>(f32ManBit, c0x00000016);
+
+    // The three fields occupy disjoint bits, so OR assembles them.
+    Value normalBits = b.create<arith::OrIOp>(f32SignBits, f32Exp);
+    normalBits = b.create<arith::OrIOp>(normalBits, f32ManBit);
+
+    // f4 exponent 00 encodes only +-0.0 (mantissa 0) and +-0.5 (mantissa 1).
+    // The rebias above would produce 0.5/0.75 for these, so build the
+    // magnitude directly (0x3F000000 is 0.5f) and OR in the sign so that
+    // -0.5 and -0.0 keep their sign.
+    Value isSubnormal = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
+                                                f32ExpBits, biasAdjustment);
+    Value isManSet =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4ManBit, c0x1);
+    Value halfBits = createConst(loc, i32Ty, 0x3F000000, rewriter);
+    Value zeroBits = createConst(loc, i32Ty, 0x0, rewriter);
+    Value subnormalBits =
+        b.create<arith::SelectOp>(isManSet, halfBits, zeroBits);
+    subnormalBits = b.create<arith::OrIOp>(f32SignBits, subnormalBits);
+
+    Value f32Bits =
+        b.create<arith::SelectOp>(isSubnormal, subnormalBits, normalBits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+
+    // Convert the computed f32 value (not the original f4 operand) to the
+    // op's full result type. truncf can only narrow, so wider results need
+    // extf. Every f4E2M1FN value is exactly representable in f16 and bf16.
+    unsigned resultWidth = resultETy.getIntOrFloatBitWidth();
+    if (resultWidth < 32) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultWidth > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct ScalarF4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
----------------
kuhar wrote:

What do you mean by `Scalar` in the op name?

https://github.com/llvm/llvm-project/pull/144157


More information about the Mlir-commits mailing list