[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)
Umang Yadav
llvmlistbot at llvm.org
Fri May 30 11:43:24 PDT 2025
================
@@ -409,6 +421,125 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};
+struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value inputOperand = op.getIn();
+ Value scaleOperand = op.getScale();
+ Type scaleETy = getElementTypeOrSelf(scaleOperand);
+ // allow implicit exponent extraction from 16/32 bits floats
+ if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+ scaleETy = b.getF8E8M0Type();
+ scaleOperand = b.create<arith::TruncFOp>(scaleETy, scaleOperand);
+ }
+ if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
+ return rewriter.notifyMatchFailure(
+ op, "scaling extf is not using scale operand of type f8E8M0FNU");
+ }
+ Type resultTy = op.getType();
+ // extf on scale will essentially create f32 number that is 2^scale and will
----------------
umangyadav wrote:
> why f32? can't resultTy be any float type
Changed comment to better reflect what it's doing.
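For reference, here is a scalar sketch of the semantics under discussion. It rests on several assumptions. It uses the OCP MX definition of E8M0: 8 exponent bits, bias 127, no sign or mantissa, and 0xFF encoding NaN. It assumes the "implicit exponent extraction" from a wide float scale keeps the biased exponent field of an f32 and drops the sign and mantissa bits. It also assumes the input has already been extended to f32, and that the decoded scale is applied by multiplication, which is what the truncated comment in the hunk appears to describe. The helpers `decodeE8M0`, `extractE8M0FromF32` and `scalingExtF` are hypothetical illustrations, not MLIR APIs. Note that the decoded scale is 2^(E - 127) for stored bits E, not 2^E.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Hypothetical helper: decode an f8E8M0FNU bit pattern to float, assuming
// the OCP MX E8M0 encoding (bias 127, no sign, no mantissa, 0xFF = NaN).
float decodeE8M0(uint8_t bits) {
  if (bits == 0xFF)
    return std::numeric_limits<float>::quiet_NaN();
  // 2^(bits - 127). std::ldexp produces the exact power of two, including
  // 2^-127, which is a subnormal in f32 but still representable.
  return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}

// Hypothetical helper: exponent extraction from an f32 scale, assuming
// truncation semantics (keep the 8-bit biased exponent, drop sign and
// mantissa). f32 Inf/NaN have exponent 0xFF and therefore map to E8M0 NaN.
uint8_t extractE8M0FromF32(float scale) {
  uint32_t bits;
  std::memcpy(&bits, &scale, sizeof(bits));
  return static_cast<uint8_t>((bits >> 23) & 0xFF);
}

// Sketch of scaling_extf on one element, assuming the input is already
// extended to f32: result = in * 2^(scale - 127).
float scalingExtF(float in, uint8_t scaleBits) {
  return in * decodeE8M0(scaleBits);
}
```

For example, `scalingExtF(1.5f, 130)` gives 12.0f, since 130 - 127 = 3 and 1.5 * 2^3 = 12. A negative scale such as -8.0f extracts the same exponent byte (130) as 8.0f, because the sign bit is discarded.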
> should we check if resultTy >= Float8E8M0FNU and >= inputType
As part of verification, it checks that the output dtype is wider than the input dtype.
https://github.com/umangyadav/llvm-project/blob/d1543414578abf95a495b4eb6fe9b6201de8e9f6/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L1460
https://github.com/llvm/llvm-project/pull/141965