[Mlir-commits] [mlir] Add arith expansion of f8E8M0 type for extf/trunc ops (PR #140332)
Krzysztof Drewniak
llvmlistbot at llvm.org
Fri May 16 21:40:37 PDT 2025
================
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
// Now that the rounding-bias has been added, truncating the low bits
// yields the correctly rounded result.
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
- Value normalCaseResult_i16 =
+ Value normalCaseResultI16 =
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
// Select either the above-computed result, or a quiet NaN constant
// if the input was NaN.
Value select =
- b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+ b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
Value result = b.create<arith::BitcastOp>(resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
};
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!operandETy.isF8E8M0FNU()) {
+ return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+ }
+
+ if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
+ return rewriter.notifyMatchFailure(
+ op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
+ }
+
+ Type i8Ty = b.getI8Type();
+ Type i32Ty = b.getI32Type();
+ Type f32Ty = b.getF32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+ i8Ty = shapedTy.clone(i8Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ f32Ty = shapedTy.clone(f32Ty);
+ }
+
+ Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+ // create constants for NaNs
+ Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+ Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+ Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+ Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+ Value isNan =
+ b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+ // select for NaNs
+ f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+ Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+ if (resultETy.isBF16()) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ } else if (resultETy.isF16()) {
+ result = b.create<arith::TruncFOp>(resultTy, result);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/*
TruncF to F8E8M0 is expected to extract the exponent bits of the F32 value.
Since all Infs and NaNs share the same (all-ones) exponent bits in the F32 type,
they all map to NaN in the F8E8M0 type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultTy = op.getType();
+ Type resultETy = getElementTypeOrSelf(resultTy);
+ if (!resultETy.isF8E8M0FNU()) {
+ return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+ }
+ if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
----------------
krzysz00 wrote:
Same note: extend or truncate to f32 as needed
https://github.com/llvm/llvm-project/pull/140332
More information about the Mlir-commits
mailing list