[Mlir-commits] [mlir] Add arith expansion of f8E8M0 type for extf/trunc ops (PR #140332)

Prashant Kumar llvmlistbot at llvm.org
Wed May 21 15:03:34 PDT 2025


================
@@ -313,18 +313,113 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     // Now that the rounding-bias has been added, truncating the low bits
     // yields the correctly rounded result.
     Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
-    Value normalCaseResult_i16 =
+    Value normalCaseResultI16 =
         b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
     // Select either the above-computed result, or a quiet NaN constant
     // if the input was NaN.
     Value select =
-        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
     Value result = b.create<arith::BitcastOp>(resultTy, select);
     rewriter.replaceOp(op, result);
     return success();
   }
 };
 
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {  // Expands arith.extf from f8E8M0FNU using integer ops.
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {  // Shaped operand: build i8/i32/f32 types from the operand's shaped type.
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
+
+    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);  // Raw 8 bits of the input.
+    // Create constants: the two NaN bit patterns and the shift amount (23).
+    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);  // Zero-extend the 8 bits to 32 bits.
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);  // Shift the 8 bits left by 23.
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // An input equal to 0xff yields the all-ones pattern instead of the shifted bits.
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.getIntOrFloatBitWidth() < 32) {  // Narrower result type: truncate from f32.
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.getIntOrFloatBitWidth() > 32) {  // Wider result type: extend from f32.
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/*
+TruncF to F8E8M0 is expected to extract the exponent bits out of the F32 type.
+Since all kinds of Infs and NaNs are mapped to the same exponent bits in the
+F32 type, they all map to NaN in the F8E8M0 type.
+*/
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(resultTy);
+    if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+    }
+
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i8Ty = b.getI8Type();
+    Type i32Ty = b.getI32Type();
+    Type f32Ty = b.getF32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i8Ty = shapedTy.clone(i8Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+      f32Ty = shapedTy.clone(f32Ty);
+    }
----------------
pashu123 wrote:

I would expect something like 
```
auto cloneIfShaped = [&](Type baseTy) -> Type {
  if (auto shapedTy = dyn_cast<ShapedType>(operandTy))
    return shapedTy.clone(baseTy);
  return baseTy;
};

Type i8Ty = cloneIfShaped(b.getI8Type());
Type i32Ty = cloneIfShaped(b.getI32Type());
Type f32Ty = cloneIfShaped(b.getF32Type());
```
Also, since the same type-cloning logic is used in the ExtF rewrite above, this lambda can be made a static function shared by both patterns.


https://github.com/llvm/llvm-project/pull/140332


More information about the Mlir-commits mailing list