[Mlir-commits] [mlir] Add arith expansion of f8E8M0 type for extf/trunc ops (PR #140332)
Prashant Kumar
llvmlistbot at llvm.org
Wed May 21 15:03:34 PDT 2025
================
@@ -313,18 +313,113 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
- Value normalCaseResult_i16 =
+ Value normalCaseResultI16 =
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
- b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+ b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
Value result = b.create<arith::BitcastOp>(resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
};
+/// Expands `arith.extf` from f8E8M0FNU by manipulating the raw bit pattern.
+///
+/// f8E8M0FNU is an exponent-only format: the 8 bits `e` encode 2^(e-127),
+/// and 0xFF encodes NaN. For 1 <= e <= 254 the f32 with the same value is
+/// obtained by placing `e` into the f32 exponent field (shift left by the f32
+/// mantissa width, 23). The encoding e == 0 (2^-127) has no normal f32
+/// representation; it is materialized as the f32 subnormal 0x00400000
+/// instead of collapsing to +0.0. The f32 value is then truncated or
+/// extended to the requested result element type. Scalar and shaped
+/// operands are both handled.
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+    }
+
+    // Give each scalar helper type the operand's shape, if it has one.
+    auto cloneIfShaped = [&](Type baseTy) -> Type {
+      if (auto shapedTy = dyn_cast<ShapedType>(operandTy))
+        return shapedTy.clone(baseTy);
+      return baseTy;
+    };
+    Type i8Ty = cloneIfShaped(b.getI8Type());
+    Type i32Ty = cloneIfShaped(b.getI32Type());
+    Type f32Ty = cloneIfShaped(b.getF32Type());
+
+    Location loc = op.getLoc();
+    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+    // Special encodings on the f8 side and their f32 bit patterns.
+    Value cF8NaN = createConst(loc, i8Ty, 0xff, rewriter);
+    Value cF8MinExp = createConst(loc, i8Ty, 0x00, rewriter);
+    Value cF32NaN = createConst(loc, i32Ty, 0xffffffff, rewriter);
+    // 2^-127 as an f32 subnormal: exponent field 0, top mantissa bit set.
+    Value cF32MinExp = createConst(loc, i32Ty, 0x00400000, rewriter);
+    Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter);
+
+    // Common case: move the 8 exponent bits into the f32 exponent field.
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    // e == 0 would otherwise shift to +0.0, losing the value 2^-127.
+    Value isMinExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8MinExp);
+    f32Bits = b.create<arith::SelectOp>(isMinExp, cF32MinExp, f32Bits);
+
+    // e == 0xFF encodes NaN.
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    // Bridge from f32 to the actual destination width.
+    if (resultETy.getIntOrFloatBitWidth() < 32) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/*
+TruncF to F8E8M0 is expected to extract the exponent bits of the F32 input.
+Since all Infs and NaNs share the same (all-ones) exponent bits in the F32
+type, they all map to NaN in the F8E8M0 type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultTy = op.getType();
+ Type resultETy = getElementTypeOrSelf(resultTy);
+ if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
+ return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+ }
+
+ if (op.getRoundingmodeAttr()) {
+ return rewriter.notifyMatchFailure(
+ op, "only applicable to default rounding mode.");
+ }
+
+ Type i8Ty = b.getI8Type();
+ Type i32Ty = b.getI32Type();
+ Type f32Ty = b.getF32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+ i8Ty = shapedTy.clone(i8Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ f32Ty = shapedTy.clone(f32Ty);
+ }
----------------
pashu123 wrote:
I would expect something like
```
auto cloneIfShaped = [&](Type baseTy) -> Type {
if (auto shapedTy = dyn_cast<ShapedType>(operandTy))
return shapedTy.clone(baseTy);
return baseTy;
};
Type i8Ty = cloneIfShaped(b.getI8Type());
Type i32Ty = cloneIfShaped(b.getI32Type());
Type f32Ty = cloneIfShaped(b.getF32Type());
```
Also, since the same shape-cloning logic is used in the ExtF rewrite above, this lambda can be made a static function shared by both patterns.
https://github.com/llvm/llvm-project/pull/140332
More information about the Mlir-commits
mailing list