[Mlir-commits] [mlir] Add arith expansion of f8E8M0 type for extf/trunc ops (PR #140332)

Krzysztof Drewniak llvmlistbot at llvm.org
Fri May 16 21:40:37 PDT 2025


================
@@ -313,18 +313,120 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Now that the rounding-bias has been added, truncating the low bits
     // yields the correctly rounded result.
     Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
-    Value normalCaseResult_i16 =
+    Value normalCaseResultI16 =
         b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
     // Select either the above-computed result, or a quiet NaN constant
     // if the input was NaN.
     Value select =
-        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
     Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isF8E8M0FNU()) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+    }
+
+    if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
----------------
krzysz00 wrote:

I wouldn't hardcode a list of target types here. I'd just plow forward with a conversion to f32, and then, if the final target type has a bitwidth less than 32, truncate the result; if it has a bitwidth greater than 32, extend it.

https://github.com/llvm/llvm-project/pull/140332


More information about the Mlir-commits mailing list