[Mlir-commits] [mlir] [mlir][arith] Add `arith.fptofp` op (PR #188041)
Matthias Springer
llvmlistbot at llvm.org
Tue Mar 24 06:32:33 PDT 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/188041
>From 9f464e76a28a9c1f9ffc3af304876fac46e69888 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 23 Mar 2026 13:42:54 +0000
Subject: [PATCH 1/4] [mlir][arith] Add `arith.fptofp` op
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 34 +++++
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 70 +++++++++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 44 +++++++
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 116 ++++++++++++++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 11 ++
mlir/test/Dialect/Arith/invalid.mlir | 40 ++++++
mlir/test/Dialect/Arith/ops.mlir | 80 ++++++++++++
7 files changed, 395 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4b830c05bf585..06ff3f9eeac44 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1425,6 +1425,40 @@ def Arith_TruncFOp :
attr-dict `:` type($in) `to` type($out) }];
}
+//===----------------------------------------------------------------------===//
+// FPToFPOp
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): FPToFPOp::fold (ArithOps.cpp) falls back to to_nearest_even
+// when $roundingmode is absent. PATCH 3/4 of this series restricts this op to
+// casts between element types of the same bitwidth.
+def Arith_FPToFPOp :
+ Arith_Op<"fptofp",
+ [Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
+ DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ DeclareOpInterfaceMethods<CastOpInterface>]>,
+ Arguments<(ins FloatLike:$in,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+ OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+ Results<(outs FloatLike:$out)> {
+ let summary = "cast between floating-point types";
+ let description = [{
+ Cast a floating-point value to a different floating-point type. Unlike
+ `arith.extf` and `arith.truncf`, this operation supports arbitrary
+ conversions between floating-point types, including conversions between
+ types of the same bitwidth but different semantics (e.g., f16 to bf16).
+
+ The source and destination element types must be different. If the value
+ cannot be exactly represented, it is rounded using the provided rounding
+ mode or the default one if no rounding mode is provided. When operating
+ on vectors, casts elementwise.
+ }];
+
+ let hasFolder = 1;
+ let hasVerifier = 1;
+ let assemblyFormat = [{ $in ($roundingmode^)?
+ (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($in) `to` type($out) }];
+}
+
//===----------------------------------------------------------------------===//
// Scaling TruncFOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index e7f561e8a4d67..a6fe481304c06 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -262,6 +262,75 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Lower arith.fptofp to the appropriate LLVM op(s).
+///
+/// - If src is wider than dst: llvm.fptrunc
+/// - If src is narrower than dst: llvm.fpext
+/// - bf16 <-> f16: llvm.fpext to f32, then llvm.fptrunc to dst.
+/// - Other FP types: not supported by the LLVM dialect.
+struct FPToFPOpLowering : public ConvertOpToLLVMPattern<arith::FPToFPOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::FPToFPOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // NOTE(review): neither $roundingmode nor $fastmath is read in this
+ // pattern, so a requested non-default rounding mode is dropped when
+ // llvm.fptrunc is emitted.
+ if (LLVM::detail::opHasUnsupportedFloatingPointTypes(op,
+ *getTypeConverter()))
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
+ auto srcFloat = cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
+ auto dstFloat = cast<FloatType>(getElementTypeOrSelf(op.getType()));
+
+ Type convertedType = getTypeConverter()->convertType(op.getType());
+ if (!convertedType)
+ return rewriter.notifyMatchFailure(op, "failed to convert result type");
+
+ Value input = adaptor.getIn();
+ Location loc = op.getLoc();
+ Type operandType = input.getType();
+
+ // A converted operand that is not an LLVM array (scalar or 1-D vector) is
+ // converted in one step; otherwise (multi-dimensional vectors, see
+ // @fptofp_ext_multidim_vector) the conversion is emitted per 1-D vector.
+ if (!isa<LLVM::LLVMArrayType>(operandType)) {
+ rewriter.replaceOp(op, emitConversion(rewriter, loc, input, convertedType,
+ srcFloat, dstFloat));
+ return success();
+ }
+
+ if (!isa<VectorType>(op.getType()))
+ return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+ return LLVM::detail::handleMultidimensionalVectors(
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ return emitConversion(rewriter, loc, operands.front(), llvm1DVectorTy,
+ srcFloat, dstFloat);
+ },
+ rewriter);
+ }
+
+private:
+ /// Emits the LLVM cast(s) turning `input` (srcFloat elements) into
+ /// `targetType` (dstFloat elements), selected by comparing bitwidths.
+ static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc,
+ Value input, Type targetType, FloatType srcFloat,
+ FloatType dstFloat) {
+ unsigned srcWidth = srcFloat.getWidth();
+ unsigned dstWidth = dstFloat.getWidth();
+ if (srcWidth > dstWidth)
+ return LLVM::FPTruncOp::create(rewriter, loc, targetType, input);
+ if (srcWidth < dstWidth)
+ return LLVM::FPExtOp::create(rewriter, loc, targetType, input);
+
+ // Same width, different semantics: bf16 <-> f16
+ // NOTE(review): && and || are mixed below without inner parentheses,
+ // which compiler parentheses warnings typically flag.
+ assert((srcFloat.isBF16() && dstFloat.isF16() ||
+ srcFloat.isF16() && dstFloat.isBF16()) &&
+ "only bf16 <-> f16 conversions are supported");
+ Type f32Scalar = Float32Type::get(rewriter.getContext());
+ Type f32Ty = f32Scalar;
+ // NOTE(review): this VectorType::get call does not pass
+ // vecTy.getScalableDims(), so a scalable target (e.g. vector<[8]xf16>)
+ // would get a fixed-length f32 intermediate type.
+ if (auto vecTy = dyn_cast<VectorType>(targetType))
+ f32Ty = VectorType::get(vecTy.getShape(), f32Scalar);
+ Value ext = LLVM::FPExtOp::create(rewriter, loc, f32Ty, input);
+ return LLVM::FPTruncOp::create(rewriter, loc, targetType, ext);
+ }
+};
+
struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
using Adaptor = ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
@@ -642,6 +711,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
ExtFOpLowering,
ExtSIOpLowering,
ExtUIOpLowering,
+ FPToFPOpLowering,
FPToSIOpLowering,
FPToUIOpLowering,
IndexCastOpSILowering,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 155edc5070a9d..774eeb03536be 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1699,6 +1699,50 @@ LogicalResult arith::TruncFOp::verify() {
return verifyTruncateOp<FloatType>(*this);
}
+//===----------------------------------------------------------------------===//
+// FPToFPOp
+//===----------------------------------------------------------------------===//
+
+/// Constant-folds arith.fptofp element-wise via constFoldCastOp. Each element
+/// is converted with the op's rounding mode, or to_nearest_even when none is
+/// set. If convertFloatValue fails for an element, castStatus is cleared to
+/// signal the failure to constFoldCastOp.
+OpFoldResult arith::FPToFPOp::fold(FoldAdaptor adaptor) {
+ auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+ const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
+ return constFoldCastOp<FloatAttr, FloatAttr>(
+ adaptor.getOperands(), getType(),
+ [this, &targetSemantics](const APFloat &a, bool &castStatus) {
+ RoundingMode roundingMode =
+ getRoundingmode().value_or(RoundingMode::to_nearest_even);
+ llvm::RoundingMode llvmRoundingMode =
+ convertArithRoundingModeToLLVMIR(roundingMode);
+ FailureOr<APFloat> result =
+ convertFloatValue(a, targetSemantics, llvmRoundingMode);
+ if (failed(result)) {
+ castStatus = false;
+ return a;
+ }
+ return *result;
+ });
+}
+
+/// CastOpInterface hook: the inputs/outputs must pass
+/// areValidCastInputsAndOutputs, both must be float-like, and their float
+/// types must differ.
+bool arith::FPToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+ if (!areValidCastInputsAndOutputs(inputs, outputs))
+ return false;
+ auto srcType = getTypeIfLike<FloatType>(inputs.front());
+ auto dstType = getTypeIfLike<FloatType>(outputs.front());
+ if (!srcType || !dstType)
+ return false;
+ return srcType != dstType;
+}
+
+/// Verifies that the operand and result element types differ.
+// NOTE(review): areCastCompatible already returns false for equal element
+// types, and invalid.mlir expects "are cast incompatible" for f32 -> f32, so
+// this diagnostic appears to be unreachable in practice.
+LogicalResult arith::FPToFPOp::verify() {
+ Type srcType = getElementTypeOrSelf(getIn().getType());
+ Type dstType = getElementTypeOrSelf(getType());
+ if (srcType == dstType)
+ return emitError("result element type ")
+ << dstType << " must be different from operand element type "
+ << srcType;
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ScalingTruncFOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 47069906fa110..5c43639442338 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -838,3 +838,119 @@ func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8
%3 = arith.cmpf oeq, %arg0, %arg3 : f32
return
}
+
+// -----
+
+// arith.fptofp lowering tests for PATCH 1/4, where the op accepts any pair of
+// FP types: narrower->wider lowers to llvm.fpext, wider->narrower to
+// llvm.fptrunc, and bf16 <-> f16 goes through f32. PATCH 3/4 removes these
+// tests again when the op is restricted to same-bitwidth casts.
+// CHECK-LABEL: func @fptofp_ext(
+// CHECK-SAME: %[[ARG0:.*]]: f16, %[[ARG1:.*]]: f32)
+func.func @fptofp_ext(%arg0 : f16, %arg1 : f32) {
+// CHECK-NEXT: = llvm.fpext %[[ARG0]] : f16 to f32
+ %0 = arith.fptofp %arg0 : f16 to f32
+// CHECK-NEXT: = llvm.fpext %[[ARG0]] : f16 to f64
+ %1 = arith.fptofp %arg0 : f16 to f64
+// CHECK-NEXT: = llvm.fpext %[[ARG1]] : f32 to f64
+ %2 = arith.fptofp %arg1 : f32 to f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @fptofp_trunc(
+// CHECK-SAME: %[[ARG0:.*]]: f32, %[[ARG1:.*]]: f64)
+func.func @fptofp_trunc(%arg0 : f32, %arg1 : f64) {
+// CHECK-NEXT: = llvm.fptrunc %[[ARG0]] : f32 to f16
+ %0 = arith.fptofp %arg0 : f32 to f16
+// CHECK-NEXT: = llvm.fptrunc %[[ARG1]] : f64 to f16
+ %1 = arith.fptofp %arg1 : f64 to f16
+// CHECK-NEXT: = llvm.fptrunc %[[ARG1]] : f64 to f32
+ %2 = arith.fptofp %arg1 : f64 to f32
+ return
+}
+
+// -----
+
+// bf16 <-> f16: same width, different semantics. Lowered via f32.
+// CHECK-LABEL: func @fptofp_same_width(
+// CHECK-SAME: %[[ARG0:.*]]: bf16, %[[ARG1:.*]]: f16)
+func.func @fptofp_same_width(%arg0 : bf16, %arg1 : f16) {
+// CHECK-NEXT: %[[EXT0:.*]] = llvm.fpext %[[ARG0]] : bf16 to f32
+// CHECK-NEXT: = llvm.fptrunc %[[EXT0]] : f32 to f16
+ %0 = arith.fptofp %arg0 : bf16 to f16
+// CHECK-NEXT: %[[EXT1:.*]] = llvm.fpext %[[ARG1]] : f16 to f32
+// CHECK-NEXT: = llvm.fptrunc %[[EXT1]] : f32 to bf16
+ %1 = arith.fptofp %arg1 : f16 to bf16
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @fptofp_ext_vector(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xf16>, %[[ARG1:.*]]: vector<2xf32>)
+func.func @fptofp_ext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
+// CHECK-NEXT: = llvm.fpext %[[ARG0]] : vector<2xf16> to vector<2xf32>
+ %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf32>
+// CHECK-NEXT: = llvm.fpext %[[ARG0]] : vector<2xf16> to vector<2xf64>
+ %1 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext %[[ARG1]] : vector<2xf32> to vector<2xf64>
+ %2 = arith.fptofp %arg1 : vector<2xf32> to vector<2xf64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @fptofp_trunc_vector(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>, %[[ARG1:.*]]: vector<2xf64>)
+func.func @fptofp_trunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptrunc %[[ARG0]] : vector<2xf32> to vector<2xf16>
+ %0 = arith.fptofp %arg0 : vector<2xf32> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc %[[ARG1]] : vector<2xf64> to vector<2xf16>
+ %1 = arith.fptofp %arg1 : vector<2xf64> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc %[[ARG1]] : vector<2xf64> to vector<2xf32>
+ %2 = arith.fptofp %arg1 : vector<2xf64> to vector<2xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @fptofp_same_width_vector(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xbf16>, %[[ARG1:.*]]: vector<2xf16>)
+func.func @fptofp_same_width_vector(%arg0 : vector<2xbf16>, %arg1 : vector<2xf16>) {
+// CHECK-NEXT: %[[EXT0:.*]] = llvm.fpext %[[ARG0]] : vector<2xbf16> to vector<2xf32>
+// CHECK-NEXT: = llvm.fptrunc %[[EXT0]] : vector<2xf32> to vector<2xf16>
+ %0 = arith.fptofp %arg0 : vector<2xbf16> to vector<2xf16>
+// CHECK-NEXT: %[[EXT1:.*]] = llvm.fpext %[[ARG1]] : vector<2xf16> to vector<2xf32>
+// CHECK-NEXT: = llvm.fptrunc %[[EXT1]] : vector<2xf32> to vector<2xbf16>
+ %1 = arith.fptofp %arg1 : vector<2xf16> to vector<2xbf16>
+ return
+}
+
+// -----
+
+// Multi-dimensional vectors are unrolled.
+// CHECK-LABEL: @fptofp_ext_multidim_vector
+func.func @fptofp_ext_multidim_vector(%arg0 : vector<2x3xf16>) {
+// CHECK-COUNT-2: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf32>
+ %0 = arith.fptofp %arg0 : vector<2x3xf16> to vector<2x3xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @fptofp_trunc_multidim_vector
+func.func @fptofp_trunc_multidim_vector(%arg0 : vector<2x3xf64>) {
+// CHECK-COUNT-2: llvm.fptrunc %{{.*}} : vector<3xf64> to vector<3xf32>
+ %0 = arith.fptofp %arg0 : vector<2x3xf64> to vector<2x3xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @fptofp_same_width_multidim_vector
+func.func @fptofp_same_width_multidim_vector(%arg0 : vector<2x3xbf16>) {
+// CHECK: llvm.fpext %{{.*}} : vector<3xbf16> to vector<3xf32>
+// CHECK: llvm.fptrunc %{{.*}} : vector<3xf32> to vector<3xf16>
+// CHECK: llvm.fpext %{{.*}} : vector<3xbf16> to vector<3xf32>
+// CHECK: llvm.fptrunc %{{.*}} : vector<3xf32> to vector<3xf16>
+ %0 = arith.fptofp %arg0 : vector<2x3xbf16> to vector<2x3xf16>
+ return
+}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 643e4e076e7c6..5d3bc34636669 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3553,3 +3553,14 @@ func.func @truncf_neg_inf_to_finite_only_no_fold() -> f4E2M1FN {
return %result : f4E2M1FN
}
+// -----
+
+// 2.0 is a power of two that is exactly representable in both f8E4M3FN and
+// f8E5M2, so the conversion is exact and the cast folds to a constant.
+// CHECK-LABEL: @fptofp_fold_f8
+// CHECK: %[[C:.*]] = arith.constant 2.000000e+00 : f8E5M2
+// CHECK: return %[[C]]
+func.func @fptofp_fold_f8() -> f8E5M2 {
+ %c = arith.constant 2.0 : f8E4M3FN
+ %result = arith.fptofp %c : f8E4M3FN to f8E5M2
+ return %result : f8E5M2
+}
+
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 0ea614e0d4b97..722d36d6448be 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -1016,3 +1016,43 @@ func.func @index_castui_i0(%a: i0) -> index {
%0 = arith.index_castui %a : i0 to index
return %0 : index
}
+
+// -----
+
+// Equal operand/result element types are rejected by the CastOpInterface
+// compatibility hook (FPToFPOp::areCastCompatible), hence "are cast
+// incompatible" below. Shape and element-kind mismatches are reported by the
+// SameOperandsAndResultShape trait and the FloatLike type constraints.
+func.func @fptofp_same_type(%arg0 : f32) {
+ // expected-error @+1 {{are cast incompatible}}
+ %0 = arith.fptofp %arg0 : f32 to f32
+ return
+}
+
+// -----
+
+func.func @fptofp_same_type_vec(%arg0 : vector<2xf16>) {
+ // expected-error @+1 {{are cast incompatible}}
+ %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf16>
+ return
+}
+
+// -----
+
+func.func @fptofp_shape_mismatch(%arg0 : vector<2xf16>) {
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
+ %0 = arith.fptofp %arg0 : vector<2xf16> to vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @fptofp_int_input(%arg0 : i32) {
+ // expected-error @+1 {{op operand #0 must be floating-point-like, but got 'i32'}}
+ %0 = arith.fptofp %arg0 : i32 to f32
+ return
+}
+
+// -----
+
+func.func @fptofp_int_output(%arg0 : f32) {
+ // expected-error @+1 {{op result #0 must be floating-point-like, but got 'i32'}}
+ %0 = arith.fptofp %arg0 : f32 to i32
+ return
+}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 9765db69d6dd5..97a359357819f 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1228,3 +1228,83 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
%4 = arith.trunci %arg0 overflow<nsw, nuw> : i64 to i32
return
}
+
+// Round-trip (parse/print) tests for arith.fptofp on scalars, vectors
+// (including scalable), tensors, and with the optional rounding mode and
+// fastmath attributes.
+// CHECK-LABEL: func @test_fptofp(
+// CHECK-SAME: %[[ARG0:.*]]: f16
+func.func @test_fptofp(%arg0 : f16) -> f32 {
+ // CHECK: arith.fptofp %[[ARG0]] : f16 to f32
+ %0 = arith.fptofp %arg0 : f16 to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @test_fptofp_bf16_to_f16(
+// CHECK-SAME: %[[ARG0:.*]]: bf16
+func.func @test_fptofp_bf16_to_f16(%arg0 : bf16) -> f16 {
+ // CHECK: arith.fptofp %[[ARG0]] : bf16 to f16
+ %0 = arith.fptofp %arg0 : bf16 to f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: func @test_fptofp_vector(
+// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>
+func.func @test_fptofp_vector(%arg0 : vector<8xf16>) -> vector<8xf32> {
+ // CHECK: arith.fptofp %[[ARG0]] : vector<8xf16> to vector<8xf32>
+ %0 = arith.fptofp %arg0 : vector<8xf16> to vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @test_fptofp_scalable_vector(
+// CHECK-SAME: %[[ARG0:.*]]: vector<[8]xf16>
+func.func @test_fptofp_scalable_vector(%arg0 : vector<[8]xf16>) -> vector<[8]xf32> {
+ // CHECK: arith.fptofp %[[ARG0]] : vector<[8]xf16> to vector<[8]xf32>
+ %0 = arith.fptofp %arg0 : vector<[8]xf16> to vector<[8]xf32>
+ return %0 : vector<[8]xf32>
+}
+
+// CHECK-LABEL: func @test_fptofp_tensor(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x8xf32>
+func.func @test_fptofp_tensor(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf64> {
+ // CHECK: arith.fptofp %[[ARG0]] : tensor<8x8xf32> to tensor<8x8xf64>
+ %0 = arith.fptofp %arg0 : tensor<8x8xf32> to tensor<8x8xf64>
+ return %0 : tensor<8x8xf64>
+}
+
+// CHECK-LABEL: func @test_fptofp_tensor_encoding(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x8xf32, "foo">
+func.func @test_fptofp_tensor_encoding(%arg0 : tensor<8x8xf32, "foo">) -> tensor<8x8xf64, "foo"> {
+ // CHECK: arith.fptofp %[[ARG0]] : tensor<8x8xf32, "foo"> to tensor<8x8xf64, "foo">
+ %0 = arith.fptofp %arg0 : tensor<8x8xf32, "foo"> to tensor<8x8xf64, "foo">
+ return %0 : tensor<8x8xf64, "foo">
+}
+
+// CHECK-LABEL: func @test_fptofp_rounding_mode(
+// CHECK-SAME: %[[ARG0:.*]]: f64
+func.func @test_fptofp_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
+ // CHECK: arith.fptofp %[[ARG0]] to_nearest_even : f64 to f32
+ %0 = arith.fptofp %arg0 to_nearest_even : f64 to f32
+ // CHECK: arith.fptofp %[[ARG0]] downward : f64 to f32
+ %1 = arith.fptofp %arg0 downward : f64 to f32
+ // CHECK: arith.fptofp %[[ARG0]] upward : f64 to f32
+ %2 = arith.fptofp %arg0 upward : f64 to f32
+ // CHECK: arith.fptofp %[[ARG0]] toward_zero : f64 to f32
+ %3 = arith.fptofp %arg0 toward_zero : f64 to f32
+ // CHECK: arith.fptofp %[[ARG0]] to_nearest_away : f64 to f32
+ %4 = arith.fptofp %arg0 to_nearest_away : f64 to f32
+ return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
+}
+
+// CHECK-LABEL: func @test_fptofp_fastmath(
+// CHECK-SAME: %[[ARG0:.*]]: f32
+func.func @test_fptofp_fastmath(%arg0 : f32) -> f64 {
+ // CHECK: arith.fptofp %[[ARG0]] fastmath<nnan> : f32 to f64
+ %0 = arith.fptofp %arg0 fastmath<nnan> : f32 to f64
+ return %0 : f64
+}
+
+// CHECK-LABEL: func @test_fptofp_rounding_mode_and_fastmath(
+// CHECK-SAME: %[[ARG0:.*]]: f64
+func.func @test_fptofp_rounding_mode_and_fastmath(%arg0 : f64) -> f32 {
+ // CHECK: arith.fptofp %[[ARG0]] to_nearest_even fastmath<nnan> : f64 to f32
+ %0 = arith.fptofp %arg0 to_nearest_even fastmath<nnan> : f64 to f32
+ return %0 : f32
+}
>From 27a31bd3223c8d7c6c09990a2c5cded8c7be0538 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 23 Mar 2026 14:10:57 +0000
Subject: [PATCH 2/4] address reviews
---
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 56 ++++++++++---------
1 file changed, 31 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index a6fe481304c06..866ead531f812 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -262,12 +262,11 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
-/// Lower arith.fptofp to the appropriate LLVM op(s).
+/// Lower arith.fptofp to the appropriate arith/LLVM op(s).
///
-/// - If src is wider than dst: llvm.fptrunc
-/// - If src is narrower than dst: llvm.fpext
+/// - If src is wider than dst: arith.truncf (recursively lowered)
+/// - If src is narrower than dst: arith.extf (recursively lowered)
/// - bf16 <-> f16: llvm.fpext to f32, then llvm.fptrunc to dst.
-/// - Other FP types: not supported by the LLVM dialect.
struct FPToFPOpLowering : public ConvertOpToLLVMPattern<arith::FPToFPOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -277,21 +276,39 @@ struct FPToFPOpLowering : public ConvertOpToLLVMPattern<arith::FPToFPOp> {
if (LLVM::detail::opHasUnsupportedFloatingPointTypes(op,
*getTypeConverter()))
return rewriter.notifyMatchFailure(op, "unsupported floating point type");
-
auto srcFloat = cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
auto dstFloat = cast<FloatType>(getElementTypeOrSelf(op.getType()));
+ unsigned srcWidth = srcFloat.getWidth();
+ unsigned dstWidth = dstFloat.getWidth();
+
+ if (srcWidth > dstWidth) {
+ rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, op.getType(), op.getIn(),
+ op.getRoundingmodeAttr(),
+ op.getFastmathAttr());
+ return success();
+ }
+ if (srcWidth < dstWidth) {
+ rewriter.replaceOpWithNewOp<arith::ExtFOp>(op, op.getType(), op.getIn(),
+ op.getFastmathAttr());
+ return success();
+ }
+ // Same width, different semantic: lower directly to LLVM. Only bf16 <-> f16
+ // conversions are supported. There is currently no other pair of FP types
+ // that are valid LLVM types.
+ assert((srcFloat.isBF16() && dstFloat.isF16()) ||
+ (srcFloat.isF16() && dstFloat.isBF16()) &&
+ "only bf16 <-> f16 conversions are supported");
+ // NOTE(review): && binds tighter than ||, so the message string above only
+ // attaches to the second alternative. The check still holds (a string
+ // literal is always true), but parenthesizing the whole disjunction would
+ // make the intent explicit.
Type convertedType = getTypeConverter()->convertType(op.getType());
if (!convertedType)
return rewriter.notifyMatchFailure(op, "failed to convert result type");
Value input = adaptor.getIn();
Location loc = op.getLoc();
- Type operandType = input.getType();
- if (!isa<LLVM::LLVMArrayType>(operandType)) {
- rewriter.replaceOp(op, emitConversion(rewriter, loc, input, convertedType,
- srcFloat, dstFloat));
+ if (!isa<LLVM::LLVMArrayType>(input.getType())) {
+ rewriter.replaceOp(
+ op, emitSameWidthConversion(rewriter, loc, input, convertedType));
return success();
}
@@ -301,27 +318,16 @@ struct FPToFPOpLowering : public ConvertOpToLLVMPattern<arith::FPToFPOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
- return emitConversion(rewriter, loc, operands.front(), llvm1DVectorTy,
- srcFloat, dstFloat);
+ return emitSameWidthConversion(rewriter, loc, operands.front(),
+ llvm1DVectorTy);
},
rewriter);
}
private:
- static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc,
- Value input, Type targetType, FloatType srcFloat,
- FloatType dstFloat) {
- unsigned srcWidth = srcFloat.getWidth();
- unsigned dstWidth = dstFloat.getWidth();
- if (srcWidth > dstWidth)
- return LLVM::FPTruncOp::create(rewriter, loc, targetType, input);
- if (srcWidth < dstWidth)
- return LLVM::FPExtOp::create(rewriter, loc, targetType, input);
-
- // Same width, different semantics: bf16 <-> f16
- assert((srcFloat.isBF16() && dstFloat.isF16() ||
- srcFloat.isF16() && dstFloat.isBF16()) &&
- "only bf16 <-> f16 conversions are supported");
+ static Value emitSameWidthConversion(ConversionPatternRewriter &rewriter,
+ Location loc, Value input,
+ Type targetType) {
Type f32Scalar = Float32Type::get(rewriter.getContext());
Type f32Ty = f32Scalar;
if (auto vecTy = dyn_cast<VectorType>(targetType))
>From be25c1d2bdd977272da8a097b37bf5f97b112694 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 24 Mar 2026 10:30:06 +0000
Subject: [PATCH 3/4] address comments
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 20 +--
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 50 +++---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 11 +-
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 148 ++++--------------
.../convert-nd-vector-to-llvmir.mlir | 15 ++
mlir/test/Dialect/Arith/invalid.mlir | 48 ++++++
mlir/test/Dialect/Arith/ops.mlir | 114 +++++---------
7 files changed, 167 insertions(+), 239 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 06ff3f9eeac44..7f2c7337ba706 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1439,17 +1439,17 @@ def Arith_FPToFPOp :
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
Results<(outs FloatLike:$out)> {
- let summary = "cast between floating-point types";
+ let summary = "cast between floating-point types of the same bitwidth";
+ // NOTE(review): "the default one" in the description below is
+ // to_nearest_even, the mode FPToFPOp::fold falls back to when
+ // $roundingmode is absent.
let description = [{
- Cast a floating-point value to a different floating-point type. Unlike
- `arith.extf` and `arith.truncf`, this operation supports arbitrary
- conversions between floating-point types, including conversions between
- types of the same bitwidth but different semantics (e.g., f16 to bf16).
-
- The source and destination element types must be different. If the value
- cannot be exactly represented, it is rounded using the provided rounding
- mode or the default one if no rounding mode is provided. When operating
- on vectors, casts elementwise.
+ Cast a floating-point value to a different floating-point type of the same
+ bitwidth. This operation handles conversions between types that have the
+ same bitwidth but different semantics (e.g., f16 to bf16), which cannot
+ be represented by `arith.extf` or `arith.truncf`.
+
+ The source and destination element types must be different and must have
+ the same bitwidth. If the value cannot be exactly represented, it is
+ rounded using the provided rounding mode or the default one if no rounding
+ mode is provided. When operating on vectors, casts elementwise.
}];
let hasFolder = 1;
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 866ead531f812..0d3460eb66ee9 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -262,11 +262,11 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
-/// Lower arith.fptofp to the appropriate arith/LLVM op(s).
+/// Lower arith.fptofp (same-bitwidth FP cast) to LLVM.
///
-/// - If src is wider than dst: arith.truncf (recursively lowered)
-/// - If src is narrower than dst: arith.extf (recursively lowered)
-/// - bf16 <-> f16: llvm.fpext to f32, then llvm.fptrunc to dst.
+/// Extends to f32 via llvm.fpext, then truncates to the target type via
+/// llvm.fptrunc. This handles bf16 <-> f16, which is the only same-bitwidth
+/// pair of LLVM-supported FP types.
+///
+/// Every f16 and bf16 value is exactly representable in f32, so the final
+/// llvm.fptrunc is the only rounding step. llvm.fptrunc assumes the default
+/// floating-point environment (round to nearest, ties to even), so ops that
+/// request any other rounding mode are not matched instead of being lowered
+/// with the wrong rounding.
struct FPToFPOpLowering : public ConvertOpToLLVMPattern<arith::FPToFPOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -276,29 +276,15 @@ struct FPToFPOpLowering : public ConvertOpToLLVMPattern<arith::FPToFPOp> {
if (LLVM::detail::opHasUnsupportedFloatingPointTypes(op,
*getTypeConverter()))
return rewriter.notifyMatchFailure(op, "unsupported floating point type");
- auto srcFloat = cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
- auto dstFloat = cast<FloatType>(getElementTypeOrSelf(op.getType()));
- unsigned srcWidth = srcFloat.getWidth();
- unsigned dstWidth = dstFloat.getWidth();
-
- if (srcWidth > dstWidth) {
- rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, op.getType(), op.getIn(),
- op.getRoundingmodeAttr(),
- op.getFastmathAttr());
- return success();
- }
- if (srcWidth < dstWidth) {
- rewriter.replaceOpWithNewOp<arith::ExtFOp>(op, op.getType(), op.getIn(),
- op.getFastmathAttr());
- return success();
- }
- // Same width, different semantic: lower directly to LLVM. Only bf16 <-> f16
- // conversions are supported. There is currently no other pair of FP types
- // that are valid LLVM types.
- assert((srcFloat.isBF16() && dstFloat.isF16()) ||
- (srcFloat.isF16() && dstFloat.isBF16()) &&
- "only bf16 <-> f16 conversions are supported");
+ // Only bf16 <-> f16 conversions are supported. There is currently no other
+ // pair of FP types that are valid LLVM types. This is a real check rather
+ // than an assert: the element types are then also used in NDEBUG builds
+ // (no unused-variable warning), and an unexpected pair fails to match
+ // instead of producing invalid IR.
+ Type srcType = getElementTypeOrSelf(op.getIn().getType());
+ Type dstType = getElementTypeOrSelf(op.getType());
+ if (!((srcType.isBF16() && dstType.isF16()) ||
+ (srcType.isF16() && dstType.isBF16())))
+ return rewriter.notifyMatchFailure(
+ op, "only bf16 <-> f16 conversions are supported");
+
+ // llvm.fptrunc always rounds to nearest-even; do not silently drop a
+ // different requested rounding mode.
+ std::optional<arith::RoundingMode> roundingMode = op.getRoundingmode();
+ if (roundingMode && *roundingMode != arith::RoundingMode::to_nearest_even)
+ return rewriter.notifyMatchFailure(
+ op, "only the to_nearest_even rounding mode is supported");
+
Type convertedType = getTypeConverter()->convertType(op.getType());
if (!convertedType)
return rewriter.notifyMatchFailure(op, "failed to convert result type");
@@ -307,8 +293,8 @@ struct FPToFPOpLowering : public ConvertOpToLLVMPattern<arith::FPToFPOp> {
Location loc = op.getLoc();
if (!isa<LLVM::LLVMArrayType>(input.getType())) {
- rewriter.replaceOp(
- op, emitSameWidthConversion(rewriter, loc, input, convertedType));
+ rewriter.replaceOp(op,
+ emitConversion(rewriter, loc, input, convertedType));
return success();
}
@@ -318,20 +304,20 @@ struct FPToFPOpLowering : public ConvertOpToLLVMPattern<arith::FPToFPOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
- return emitSameWidthConversion(rewriter, loc, operands.front(),
- llvm1DVectorTy);
+ return emitConversion(rewriter, loc, operands.front(),
+ llvm1DVectorTy);
},
rewriter);
}
private:
- static Value emitSameWidthConversion(ConversionPatternRewriter &rewriter,
- Location loc, Value input,
- Type targetType) {
+ /// Emits llvm.fpext to an f32 (vector) type with the same shape as
+ /// `targetType`, followed by llvm.fptrunc to `targetType`.
+ static Value emitConversion(ConversionPatternRewriter &rewriter, Location loc,
+ Value input, Type targetType) {
Type f32Scalar = Float32Type::get(rewriter.getContext());
Type f32Ty = f32Scalar;
+ // Keep the scalable dims of the target type (e.g. vector<[8]xbf16>):
+ // without them VectorType::get builds a fixed-length f32 vector that does
+ // not match the scalable operand and result.
if (auto vecTy = dyn_cast<VectorType>(targetType))
- f32Ty = VectorType::get(vecTy.getShape(), f32Scalar);
+ f32Ty = VectorType::get(vecTy.getShape(), f32Scalar,
+ vecTy.getScalableDims());
+
Value ext = LLVM::FPExtOp::create(rewriter, loc, f32Ty, input);
return LLVM::FPTruncOp::create(rewriter, loc, targetType, ext);
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 774eeb03536be..01c1fb2fec828 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1730,16 +1730,21 @@ bool arith::FPToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
auto dstType = getTypeIfLike<FloatType>(outputs.front());
if (!srcType || !dstType)
return false;
- return srcType != dstType;
+ // Element types must differ and have the same bitwidth.
+ return srcType != dstType &&
+ srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}
LogicalResult arith::FPToFPOp::verify() {
- Type srcType = getElementTypeOrSelf(getIn().getType());
- Type dstType = getElementTypeOrSelf(getType());
+ auto srcType = cast<FloatType>(getElementTypeOrSelf(getIn().getType()));
+ auto dstType = cast<FloatType>(getElementTypeOrSelf(getType()));
+ // NOTE(review): areCastCompatible above already rejects both of the cases
+ // below, and invalid.mlir expects "are cast incompatible" for an equal-type
+ // cast, so these diagnostics may never be emitted in practice.
if (srcType == dstType)
return emitError("result element type ")
<< dstType << " must be different from operand element type "
<< srcType;
+ if (srcType.getWidth() != dstType.getWidth())
+ return emitError("result element type ")
+ << dstType << " must have the same bitwidth as operand element type "
+ << srcType;
return success();
}
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 5c43639442338..f05394f22c6fe 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -377,6 +377,39 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
// -----
+// Same-bitwidth arith.fptofp lowers to llvm.fpext to f32 followed by
+// llvm.fptrunc to the destination type.
+// NOTE(review): no scalable-vector (vector<[N]x...>) case is exercised here.
+// CHECK-LABEL: @fptofp_f16_to_bf16
+func.func @fptofp_f16_to_bf16(%arg0 : f16) -> bf16 {
+// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32
+// CHECK-NEXT: %[[TRUNC:.*]] = llvm.fptrunc %[[EXT]] : f32 to bf16
+ %0 = arith.fptofp %arg0 : f16 to bf16
+// CHECK-NEXT: return %[[TRUNC]]
+ return %0 : bf16
+}
+
+// -----
+
+// CHECK-LABEL: @fptofp_bf16_to_f16
+func.func @fptofp_bf16_to_f16(%arg0 : bf16) -> f16 {
+// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : bf16 to f32
+// CHECK-NEXT: %[[TRUNC:.*]] = llvm.fptrunc %[[EXT]] : f32 to f16
+ %0 = arith.fptofp %arg0 : bf16 to f16
+// CHECK-NEXT: return %[[TRUNC]]
+ return %0 : f16
+}
+
+// -----
+
+// CHECK-LABEL: @fptofp_vector
+func.func @fptofp_vector(%arg0 : vector<2xf16>) -> vector<2xbf16> {
+// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : vector<2xf16> to vector<2xf32>
+// CHECK-NEXT: %[[TRUNC:.*]] = llvm.fptrunc %[[EXT]] : vector<2xf32> to vector<2xbf16>
+ %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xbf16>
+// CHECK-NEXT: return %[[TRUNC]]
+ return %0 : vector<2xbf16>
+}
+
+// -----
+
// Check sign and zero extension and truncation of integers.
// CHECK-LABEL: @integer_extension_and_truncation
func.func @integer_extension_and_truncation(%arg0 : i3) {
@@ -839,118 +872,3 @@ func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8
return
}
-// -----
-
-// CHECK-LABEL: func @fptofp_ext(
-// CHECK-SAME: %[[ARG0:.*]]: f16, %[[ARG1:.*]]: f32)
-func.func @fptofp_ext(%arg0 : f16, %arg1 : f32) {
-// CHECK-NEXT: = llvm.fpext %[[ARG0]] : f16 to f32
- %0 = arith.fptofp %arg0 : f16 to f32
-// CHECK-NEXT: = llvm.fpext %[[ARG0]] : f16 to f64
- %1 = arith.fptofp %arg0 : f16 to f64
-// CHECK-NEXT: = llvm.fpext %[[ARG1]] : f32 to f64
- %2 = arith.fptofp %arg1 : f32 to f64
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @fptofp_trunc(
-// CHECK-SAME: %[[ARG0:.*]]: f32, %[[ARG1:.*]]: f64)
-func.func @fptofp_trunc(%arg0 : f32, %arg1 : f64) {
-// CHECK-NEXT: = llvm.fptrunc %[[ARG0]] : f32 to f16
- %0 = arith.fptofp %arg0 : f32 to f16
-// CHECK-NEXT: = llvm.fptrunc %[[ARG1]] : f64 to f16
- %1 = arith.fptofp %arg1 : f64 to f16
-// CHECK-NEXT: = llvm.fptrunc %[[ARG1]] : f64 to f32
- %2 = arith.fptofp %arg1 : f64 to f32
- return
-}
-
-// -----
-
-// bf16 <-> f16: same width, different semantics. Lowered via f32.
-// CHECK-LABEL: func @fptofp_same_width(
-// CHECK-SAME: %[[ARG0:.*]]: bf16, %[[ARG1:.*]]: f16)
-func.func @fptofp_same_width(%arg0 : bf16, %arg1 : f16) {
-// CHECK-NEXT: %[[EXT0:.*]] = llvm.fpext %[[ARG0]] : bf16 to f32
-// CHECK-NEXT: = llvm.fptrunc %[[EXT0]] : f32 to f16
- %0 = arith.fptofp %arg0 : bf16 to f16
-// CHECK-NEXT: %[[EXT1:.*]] = llvm.fpext %[[ARG1]] : f16 to f32
-// CHECK-NEXT: = llvm.fptrunc %[[EXT1]] : f32 to bf16
- %1 = arith.fptofp %arg1 : f16 to bf16
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @fptofp_ext_vector(
-// CHECK-SAME: %[[ARG0:.*]]: vector<2xf16>, %[[ARG1:.*]]: vector<2xf32>)
-func.func @fptofp_ext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
-// CHECK-NEXT: = llvm.fpext %[[ARG0]] : vector<2xf16> to vector<2xf32>
- %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf32>
-// CHECK-NEXT: = llvm.fpext %[[ARG0]] : vector<2xf16> to vector<2xf64>
- %1 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf64>
-// CHECK-NEXT: = llvm.fpext %[[ARG1]] : vector<2xf32> to vector<2xf64>
- %2 = arith.fptofp %arg1 : vector<2xf32> to vector<2xf64>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @fptofp_trunc_vector(
-// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>, %[[ARG1:.*]]: vector<2xf64>)
-func.func @fptofp_trunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
-// CHECK-NEXT: = llvm.fptrunc %[[ARG0]] : vector<2xf32> to vector<2xf16>
- %0 = arith.fptofp %arg0 : vector<2xf32> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc %[[ARG1]] : vector<2xf64> to vector<2xf16>
- %1 = arith.fptofp %arg1 : vector<2xf64> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc %[[ARG1]] : vector<2xf64> to vector<2xf32>
- %2 = arith.fptofp %arg1 : vector<2xf64> to vector<2xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @fptofp_same_width_vector(
-// CHECK-SAME: %[[ARG0:.*]]: vector<2xbf16>, %[[ARG1:.*]]: vector<2xf16>)
-func.func @fptofp_same_width_vector(%arg0 : vector<2xbf16>, %arg1 : vector<2xf16>) {
-// CHECK-NEXT: %[[EXT0:.*]] = llvm.fpext %[[ARG0]] : vector<2xbf16> to vector<2xf32>
-// CHECK-NEXT: = llvm.fptrunc %[[EXT0]] : vector<2xf32> to vector<2xf16>
- %0 = arith.fptofp %arg0 : vector<2xbf16> to vector<2xf16>
-// CHECK-NEXT: %[[EXT1:.*]] = llvm.fpext %[[ARG1]] : vector<2xf16> to vector<2xf32>
-// CHECK-NEXT: = llvm.fptrunc %[[EXT1]] : vector<2xf32> to vector<2xbf16>
- %1 = arith.fptofp %arg1 : vector<2xf16> to vector<2xbf16>
- return
-}
-
-// -----
-
-// Multi-dimensional vectors are unrolled.
-// CHECK-LABEL: @fptofp_ext_multidim_vector
-func.func @fptofp_ext_multidim_vector(%arg0 : vector<2x3xf16>) {
-// CHECK-COUNT-2: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf32>
- %0 = arith.fptofp %arg0 : vector<2x3xf16> to vector<2x3xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @fptofp_trunc_multidim_vector
-func.func @fptofp_trunc_multidim_vector(%arg0 : vector<2x3xf64>) {
-// CHECK-COUNT-2: llvm.fptrunc %{{.*}} : vector<3xf64> to vector<3xf32>
- %0 = arith.fptofp %arg0 : vector<2x3xf64> to vector<2x3xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @fptofp_same_width_multidim_vector
-func.func @fptofp_same_width_multidim_vector(%arg0 : vector<2x3xbf16>) {
-// CHECK: llvm.fpext %{{.*}} : vector<3xbf16> to vector<3xf32>
-// CHECK: llvm.fptrunc %{{.*}} : vector<3xf32> to vector<3xf16>
-// CHECK: llvm.fpext %{{.*}} : vector<3xbf16> to vector<3xf32>
-// CHECK: llvm.fptrunc %{{.*}} : vector<3xf32> to vector<3xf16>
- %0 = arith.fptofp %arg0 : vector<2x3xbf16> to vector<2x3xf16>
- return
-}
diff --git a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
index 86f0be81ce99e..1e079f4c2cd9e 100644
--- a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -125,6 +125,21 @@ func.func @fptrunc_vector(%arg0 : vector<1x2x3xf64>) -> vector<1x2x3xf16> {
return %0 : vector<1x2x3xf16>
}
+// CHECK-LABEL: @fptofp
+func.func @fptofp_vector(%arg0 : vector<1x2x3xf16>) -> vector<1x2x3xbf16> {
+ // CHECK: llvm.mlir.poison : !llvm.array<1 x array<2 x vector<3xbf16>>>
+ // CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf16>>>
+ // CHECK: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf32>
+ // CHECK: llvm.fptrunc %{{.*}} : vector<3xf32> to vector<3xbf16>
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xbf16>>>
+ // CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf16>>>
+ // CHECK: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf32>
+ // CHECK: llvm.fptrunc %{{.*}} : vector<3xf32> to vector<3xbf16>
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xbf16>>>
+ %0 = arith.fptofp %arg0: vector<1x2x3xf16> to vector<1x2x3xbf16>
+ return %0 : vector<1x2x3xbf16>
+}
+
// CHECK-LABEL: @trunci
func.func @trunci_vector(%arg0 : vector<1x2x3xi64>) -> vector<1x2x3xi16> {
// CHECK: llvm.mlir.poison : !llvm.array<1 x array<2 x vector<3xi16>>>
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 722d36d6448be..cc5c8997f7cbb 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -535,6 +535,54 @@ func.func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
// -----
+func.func @fptofp_same_type(%arg0 : f16) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = arith.fptofp %arg0 : f16 to f16
+ return
+}
+
+// -----
+
+func.func @fptofp_different_bitwidth(%arg0 : f16) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = arith.fptofp %arg0 : f16 to f32
+ return
+}
+
+// -----
+
+func.func @fptofp_different_bitwidth_trunc(%arg0 : f32) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = arith.fptofp %arg0 : f32 to f16
+ return
+}
+
+// -----
+
+func.func @fptofp_vec_same_type(%arg0 : vector<2xf16>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf16>
+ return
+}
+
+// -----
+
+func.func @fptofp_vec_different_bitwidth(%arg0 : vector<2xf16>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf32>
+ return
+}
+
+// -----
+
+func.func @fptofp_shape_mismatch(%arg0 : vector<2xf16>) {
+ // expected-error@+1 {{op requires the same shape for all operands and results}}
+ %0 = arith.fptofp %arg0 : vector<2xf16> to vector<3xbf16>
+ return
+}
+
+// -----
+
func.func @sexti_index_as_operand(%arg0 : index) {
// expected-error at +1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
%0 = arith.extsi %arg0 : index to i128
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 97a359357819f..541eb514c1867 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -751,6 +751,41 @@ func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
}
+// CHECK-LABEL: test_fptofp
+func.func @test_fptofp(%arg0 : f16) -> bf16 {
+ // CHECK: arith.fptofp %arg0 : f16 to bf16
+ %0 = arith.fptofp %arg0 : f16 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: test_fptofp_vector
+func.func @test_fptofp_vector(%arg0 : vector<8xf16>) -> vector<8xbf16> {
+ // CHECK: arith.fptofp %arg0 : vector<8xf16> to vector<8xbf16>
+ %0 = arith.fptofp %arg0 : vector<8xf16> to vector<8xbf16>
+ return %0 : vector<8xbf16>
+}
+
+// CHECK-LABEL: test_fptofp_scalable_vector
+func.func @test_fptofp_scalable_vector(%arg0 : vector<[8]xbf16>) -> vector<[8]xf16> {
+ // CHECK: arith.fptofp %arg0 : vector<[8]xbf16> to vector<[8]xf16>
+ %0 = arith.fptofp %arg0 : vector<[8]xbf16> to vector<[8]xf16>
+ return %0 : vector<[8]xf16>
+}
+
+// CHECK-LABEL: test_fptofp_tensor
+func.func @test_fptofp_tensor(%arg0 : tensor<8x8xf16>) -> tensor<8x8xbf16> {
+ // CHECK: arith.fptofp %arg0 : tensor<8x8xf16> to tensor<8x8xbf16>
+ %0 = arith.fptofp %arg0 : tensor<8x8xf16> to tensor<8x8xbf16>
+ return %0 : tensor<8x8xbf16>
+}
+
+// CHECK-LABEL: test_fptofp_rounding_mode
+func.func @test_fptofp_rounding_mode(%arg0 : bf16) -> f16 {
+ // CHECK: arith.fptofp %arg0 to_nearest_even : bf16 to f16
+ %0 = arith.fptofp %arg0 to_nearest_even : bf16 to f16
+ return %0 : f16
+}
+
// CHECK-LABEL: test_uitofp
func.func @test_uitofp(%arg0 : i32) -> f32 {
%0 = arith.uitofp %arg0 : i32 to f32
@@ -1229,82 +1264,3 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
return
}
-// CHECK-LABEL: func @test_fptofp(
-// CHECK-SAME: %[[ARG0:.*]]: f16
-func.func @test_fptofp(%arg0 : f16) -> f32 {
- // CHECK: arith.fptofp %[[ARG0]] : f16 to f32
- %0 = arith.fptofp %arg0 : f16 to f32
- return %0 : f32
-}
-
-// CHECK-LABEL: func @test_fptofp_bf16_to_f16(
-// CHECK-SAME: %[[ARG0:.*]]: bf16
-func.func @test_fptofp_bf16_to_f16(%arg0 : bf16) -> f16 {
- // CHECK: arith.fptofp %[[ARG0]] : bf16 to f16
- %0 = arith.fptofp %arg0 : bf16 to f16
- return %0 : f16
-}
-
-// CHECK-LABEL: func @test_fptofp_vector(
-// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>
-func.func @test_fptofp_vector(%arg0 : vector<8xf16>) -> vector<8xf32> {
- // CHECK: arith.fptofp %[[ARG0]] : vector<8xf16> to vector<8xf32>
- %0 = arith.fptofp %arg0 : vector<8xf16> to vector<8xf32>
- return %0 : vector<8xf32>
-}
-
-// CHECK-LABEL: func @test_fptofp_scalable_vector(
-// CHECK-SAME: %[[ARG0:.*]]: vector<[8]xf16>
-func.func @test_fptofp_scalable_vector(%arg0 : vector<[8]xf16>) -> vector<[8]xf32> {
- // CHECK: arith.fptofp %[[ARG0]] : vector<[8]xf16> to vector<[8]xf32>
- %0 = arith.fptofp %arg0 : vector<[8]xf16> to vector<[8]xf32>
- return %0 : vector<[8]xf32>
-}
-
-// CHECK-LABEL: func @test_fptofp_tensor(
-// CHECK-SAME: %[[ARG0:.*]]: tensor<8x8xf32>
-func.func @test_fptofp_tensor(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf64> {
- // CHECK: arith.fptofp %[[ARG0]] : tensor<8x8xf32> to tensor<8x8xf64>
- %0 = arith.fptofp %arg0 : tensor<8x8xf32> to tensor<8x8xf64>
- return %0 : tensor<8x8xf64>
-}
-
-// CHECK-LABEL: func @test_fptofp_tensor_encoding(
-// CHECK-SAME: %[[ARG0:.*]]: tensor<8x8xf32, "foo">
-func.func @test_fptofp_tensor_encoding(%arg0 : tensor<8x8xf32, "foo">) -> tensor<8x8xf64, "foo"> {
- // CHECK: arith.fptofp %[[ARG0]] : tensor<8x8xf32, "foo"> to tensor<8x8xf64, "foo">
- %0 = arith.fptofp %arg0 : tensor<8x8xf32, "foo"> to tensor<8x8xf64, "foo">
- return %0 : tensor<8x8xf64, "foo">
-}
-
-// CHECK-LABEL: func @test_fptofp_rounding_mode(
-// CHECK-SAME: %[[ARG0:.*]]: f64
-func.func @test_fptofp_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
- // CHECK: arith.fptofp %[[ARG0]] to_nearest_even : f64 to f32
- %0 = arith.fptofp %arg0 to_nearest_even : f64 to f32
- // CHECK: arith.fptofp %[[ARG0]] downward : f64 to f32
- %1 = arith.fptofp %arg0 downward : f64 to f32
- // CHECK: arith.fptofp %[[ARG0]] upward : f64 to f32
- %2 = arith.fptofp %arg0 upward : f64 to f32
- // CHECK: arith.fptofp %[[ARG0]] toward_zero : f64 to f32
- %3 = arith.fptofp %arg0 toward_zero : f64 to f32
- // CHECK: arith.fptofp %[[ARG0]] to_nearest_away : f64 to f32
- %4 = arith.fptofp %arg0 to_nearest_away : f64 to f32
- return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
-}
-
-// CHECK-LABEL: func @test_fptofp_fastmath(
-// CHECK-SAME: %[[ARG0:.*]]: f32
-func.func @test_fptofp_fastmath(%arg0 : f32) -> f64 {
- // CHECK: arith.fptofp %[[ARG0]] fastmath<nnan> : f32 to f64
- %0 = arith.fptofp %arg0 fastmath<nnan> : f32 to f64
- return %0 : f64
-}
-
-// CHECK-LABEL: func @test_fptofp_rounding_mode_and_fastmath(
-// CHECK-SAME: %[[ARG0:.*]]: f64
-func.func @test_fptofp_rounding_mode_and_fastmath(%arg0 : f64) -> f32 {
- // CHECK: arith.fptofp %[[ARG0]] to_nearest_even fastmath<nnan> : f64 to f32
- %0 = arith.fptofp %arg0 to_nearest_even fastmath<nnan> : f64 to f32
- return %0 : f32
-}
>From a3318edc3381d64f90924bbacd4d318d5127ff14 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 24 Mar 2026 13:28:42 +0000
Subject: [PATCH 4/4] rename op
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 6 +--
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 8 ++--
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 8 ++--
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 18 ++++----
.../convert-nd-vector-to-llvmir.mlir | 6 +--
mlir/test/Dialect/Arith/canonicalize.mlir | 6 +--
mlir/test/Dialect/Arith/invalid.mlir | 44 +++++++++----------
mlir/test/Dialect/Arith/ops.mlir | 40 ++++++++---------
8 files changed, 68 insertions(+), 68 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 7f2c7337ba706..8f0755c8a8144 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1426,11 +1426,11 @@ def Arith_TruncFOp :
}
//===----------------------------------------------------------------------===//
-// FPToFPOp
+// ConvertFOp
//===----------------------------------------------------------------------===//
-def Arith_FPToFPOp :
- Arith_Op<"fptofp",
+def Arith_ConvertFOp :
+ Arith_Op<"convertf",
[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 0d3460eb66ee9..a0346ec6f4fb6 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -262,16 +262,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
-/// Lower arith.fptofp (same-bitwidth FP cast) to LLVM.
+/// Lower arith.convertf (same-bitwidth FP cast) to LLVM.
///
/// Extends to f32 via llvm.fpext, then truncates to the target type via
/// llvm.fptrunc. This handles bf16 <-> f16, which is the only same-bitwidth
/// pair of LLVM-supported FP types.
-struct FPToFPOpLowering : public ConvertOpToLLVMPattern<arith::FPToFPOp> {
+struct ConvertFOpLowering : public ConvertOpToLLVMPattern<arith::ConvertFOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(arith::FPToFPOp op, OpAdaptor adaptor,
+ matchAndRewrite(arith::ConvertFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (LLVM::detail::opHasUnsupportedFloatingPointTypes(op,
*getTypeConverter()))
@@ -703,7 +703,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
ExtFOpLowering,
ExtSIOpLowering,
ExtUIOpLowering,
- FPToFPOpLowering,
+ ConvertFOpLowering,
FPToSIOpLowering,
FPToUIOpLowering,
IndexCastOpSILowering,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 01c1fb2fec828..6999e9153bb9a 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1700,10 +1700,10 @@ LogicalResult arith::TruncFOp::verify() {
}
//===----------------------------------------------------------------------===//
-// FPToFPOp
+// ConvertFOp
//===----------------------------------------------------------------------===//
-OpFoldResult arith::FPToFPOp::fold(FoldAdaptor adaptor) {
+OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
@@ -1723,7 +1723,7 @@ OpFoldResult arith::FPToFPOp::fold(FoldAdaptor adaptor) {
});
}
-bool arith::FPToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
+bool arith::ConvertFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!areValidCastInputsAndOutputs(inputs, outputs))
return false;
auto srcType = getTypeIfLike<FloatType>(inputs.front());
@@ -1734,7 +1734,7 @@ bool arith::FPToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}
-LogicalResult arith::FPToFPOp::verify() {
+LogicalResult arith::ConvertFOp::verify() {
auto srcType = cast<FloatType>(getElementTypeOrSelf(getIn().getType()));
auto dstType = cast<FloatType>(getElementTypeOrSelf(getType()));
if (srcType == dstType)
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index f05394f22c6fe..6a6016c4f5b16 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -377,33 +377,33 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
// -----
-// CHECK-LABEL: @fptofp_f16_to_bf16
-func.func @fptofp_f16_to_bf16(%arg0 : f16) -> bf16 {
+// CHECK-LABEL: @convertf_f16_to_bf16
+func.func @convertf_f16_to_bf16(%arg0 : f16) -> bf16 {
// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32
// CHECK-NEXT: %[[TRUNC:.*]] = llvm.fptrunc %[[EXT]] : f32 to bf16
- %0 = arith.fptofp %arg0 : f16 to bf16
+ %0 = arith.convertf %arg0 : f16 to bf16
// CHECK-NEXT: return %[[TRUNC]]
return %0 : bf16
}
// -----
-// CHECK-LABEL: @fptofp_bf16_to_f16
-func.func @fptofp_bf16_to_f16(%arg0 : bf16) -> f16 {
+// CHECK-LABEL: @convertf_bf16_to_f16
+func.func @convertf_bf16_to_f16(%arg0 : bf16) -> f16 {
// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : bf16 to f32
// CHECK-NEXT: %[[TRUNC:.*]] = llvm.fptrunc %[[EXT]] : f32 to f16
- %0 = arith.fptofp %arg0 : bf16 to f16
+ %0 = arith.convertf %arg0 : bf16 to f16
// CHECK-NEXT: return %[[TRUNC]]
return %0 : f16
}
// -----
-// CHECK-LABEL: @fptofp_vector
-func.func @fptofp_vector(%arg0 : vector<2xf16>) -> vector<2xbf16> {
+// CHECK-LABEL: @convertf_vector
+func.func @convertf_vector(%arg0 : vector<2xf16>) -> vector<2xbf16> {
// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : vector<2xf16> to vector<2xf32>
// CHECK-NEXT: %[[TRUNC:.*]] = llvm.fptrunc %[[EXT]] : vector<2xf32> to vector<2xbf16>
- %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xbf16>
+ %0 = arith.convertf %arg0 : vector<2xf16> to vector<2xbf16>
// CHECK-NEXT: return %[[TRUNC]]
return %0 : vector<2xbf16>
}
diff --git a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
index 1e079f4c2cd9e..bf1e8580a5b76 100644
--- a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -125,8 +125,8 @@ func.func @fptrunc_vector(%arg0 : vector<1x2x3xf64>) -> vector<1x2x3xf16> {
return %0 : vector<1x2x3xf16>
}
-// CHECK-LABEL: @fptofp
-func.func @fptofp_vector(%arg0 : vector<1x2x3xf16>) -> vector<1x2x3xbf16> {
+// CHECK-LABEL: @convertf
+func.func @convertf_vector(%arg0 : vector<1x2x3xf16>) -> vector<1x2x3xbf16> {
// CHECK: llvm.mlir.poison : !llvm.array<1 x array<2 x vector<3xbf16>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf16>>>
// CHECK: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf32>
@@ -136,7 +136,7 @@ func.func @fptofp_vector(%arg0 : vector<1x2x3xf16>) -> vector<1x2x3xbf16> {
// CHECK: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf32>
// CHECK: llvm.fptrunc %{{.*}} : vector<3xf32> to vector<3xbf16>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xbf16>>>
- %0 = arith.fptofp %arg0: vector<1x2x3xf16> to vector<1x2x3xbf16>
+ %0 = arith.convertf %arg0: vector<1x2x3xf16> to vector<1x2x3xbf16>
return %0 : vector<1x2x3xbf16>
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 5d3bc34636669..18665e2eb6f4a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3555,12 +3555,12 @@ func.func @truncf_neg_inf_to_finite_only_no_fold() -> f4E2M1FN {
// -----
-// CHECK-LABEL: @fptofp_fold_f8
+// CHECK-LABEL: @convertf_fold_f8
// CHECK: %[[C:.*]] = arith.constant 2.000000e+00 : f8E5M2
// CHECK: return %[[C]]
-func.func @fptofp_fold_f8() -> f8E5M2 {
+func.func @convertf_fold_f8() -> f8E5M2 {
%c = arith.constant 2.0 : f8E4M3FN
- %result = arith.fptofp %c : f8E4M3FN to f8E5M2
+ %result = arith.convertf %c : f8E4M3FN to f8E5M2
return %result : f8E5M2
}
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index cc5c8997f7cbb..96013a4fadde5 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -535,49 +535,49 @@ func.func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
// -----
-func.func @fptofp_same_type(%arg0 : f16) {
+func.func @convertf_same_type(%arg0 : f16) {
// expected-error@+1 {{are cast incompatible}}
- %0 = arith.fptofp %arg0 : f16 to f16
+ %0 = arith.convertf %arg0 : f16 to f16
return
}
// -----
-func.func @fptofp_different_bitwidth(%arg0 : f16) {
+func.func @convertf_different_bitwidth(%arg0 : f16) {
// expected-error@+1 {{are cast incompatible}}
- %0 = arith.fptofp %arg0 : f16 to f32
+ %0 = arith.convertf %arg0 : f16 to f32
return
}
// -----
-func.func @fptofp_different_bitwidth_trunc(%arg0 : f32) {
+func.func @convertf_different_bitwidth_trunc(%arg0 : f32) {
// expected-error@+1 {{are cast incompatible}}
- %0 = arith.fptofp %arg0 : f32 to f16
+ %0 = arith.convertf %arg0 : f32 to f16
return
}
// -----
-func.func @fptofp_vec_same_type(%arg0 : vector<2xf16>) {
+func.func @convertf_vec_same_type(%arg0 : vector<2xf16>) {
// expected-error@+1 {{are cast incompatible}}
- %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf16>
+ %0 = arith.convertf %arg0 : vector<2xf16> to vector<2xf16>
return
}
// -----
-func.func @fptofp_vec_different_bitwidth(%arg0 : vector<2xf16>) {
+func.func @convertf_vec_different_bitwidth(%arg0 : vector<2xf16>) {
// expected-error@+1 {{are cast incompatible}}
- %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf32>
+ %0 = arith.convertf %arg0 : vector<2xf16> to vector<2xf32>
return
}
// -----
-func.func @fptofp_shape_mismatch(%arg0 : vector<2xf16>) {
+func.func @convertf_shape_mismatch(%arg0 : vector<2xf16>) {
// expected-error@+1 {{op requires the same shape for all operands and results}}
- %0 = arith.fptofp %arg0 : vector<2xf16> to vector<3xbf16>
+ %0 = arith.convertf %arg0 : vector<2xf16> to vector<3xbf16>
return
}
@@ -1067,40 +1067,40 @@ func.func @index_castui_i0(%a: i0) -> index {
// -----
-func.func @fptofp_same_type(%arg0 : f32) {
+func.func @convertf_same_type(%arg0 : f32) {
// expected-error @+1 {{are cast incompatible}}
- %0 = arith.fptofp %arg0 : f32 to f32
+ %0 = arith.convertf %arg0 : f32 to f32
return
}
// -----
-func.func @fptofp_same_type_vec(%arg0 : vector<2xf16>) {
+func.func @convertf_same_type_vec(%arg0 : vector<2xf16>) {
// expected-error @+1 {{are cast incompatible}}
- %0 = arith.fptofp %arg0 : vector<2xf16> to vector<2xf16>
+ %0 = arith.convertf %arg0 : vector<2xf16> to vector<2xf16>
return
}
// -----
-func.func @fptofp_shape_mismatch(%arg0 : vector<2xf16>) {
+func.func @convertf_shape_mismatch(%arg0 : vector<2xf16>) {
// expected-error @+1 {{op requires the same shape for all operands and results}}
- %0 = arith.fptofp %arg0 : vector<2xf16> to vector<3xf32>
+ %0 = arith.convertf %arg0 : vector<2xf16> to vector<3xf32>
return
}
// -----
-func.func @fptofp_int_input(%arg0 : i32) {
+func.func @convertf_int_input(%arg0 : i32) {
// expected-error @+1 {{op operand #0 must be floating-point-like, but got 'i32'}}
- %0 = arith.fptofp %arg0 : i32 to f32
+ %0 = arith.convertf %arg0 : i32 to f32
return
}
// -----
-func.func @fptofp_int_output(%arg0 : f32) {
+func.func @convertf_int_output(%arg0 : f32) {
// expected-error @+1 {{op result #0 must be floating-point-like, but got 'i32'}}
- %0 = arith.fptofp %arg0 : f32 to i32
+ %0 = arith.convertf %arg0 : f32 to i32
return
}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 541eb514c1867..2c5371de9ff24 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -751,38 +751,38 @@ func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
}
-// CHECK-LABEL: test_fptofp
-func.func @test_fptofp(%arg0 : f16) -> bf16 {
- // CHECK: arith.fptofp %arg0 : f16 to bf16
- %0 = arith.fptofp %arg0 : f16 to bf16
+// CHECK-LABEL: test_convertf
+func.func @test_convertf(%arg0 : f16) -> bf16 {
+ // CHECK: arith.convertf %arg0 : f16 to bf16
+ %0 = arith.convertf %arg0 : f16 to bf16
return %0 : bf16
}
-// CHECK-LABEL: test_fptofp_vector
-func.func @test_fptofp_vector(%arg0 : vector<8xf16>) -> vector<8xbf16> {
- // CHECK: arith.fptofp %arg0 : vector<8xf16> to vector<8xbf16>
- %0 = arith.fptofp %arg0 : vector<8xf16> to vector<8xbf16>
+// CHECK-LABEL: test_convertf_vector
+func.func @test_convertf_vector(%arg0 : vector<8xf16>) -> vector<8xbf16> {
+ // CHECK: arith.convertf %arg0 : vector<8xf16> to vector<8xbf16>
+ %0 = arith.convertf %arg0 : vector<8xf16> to vector<8xbf16>
return %0 : vector<8xbf16>
}
-// CHECK-LABEL: test_fptofp_scalable_vector
-func.func @test_fptofp_scalable_vector(%arg0 : vector<[8]xbf16>) -> vector<[8]xf16> {
- // CHECK: arith.fptofp %arg0 : vector<[8]xbf16> to vector<[8]xf16>
- %0 = arith.fptofp %arg0 : vector<[8]xbf16> to vector<[8]xf16>
+// CHECK-LABEL: test_convertf_scalable_vector
+func.func @test_convertf_scalable_vector(%arg0 : vector<[8]xbf16>) -> vector<[8]xf16> {
+ // CHECK: arith.convertf %arg0 : vector<[8]xbf16> to vector<[8]xf16>
+ %0 = arith.convertf %arg0 : vector<[8]xbf16> to vector<[8]xf16>
return %0 : vector<[8]xf16>
}
-// CHECK-LABEL: test_fptofp_tensor
-func.func @test_fptofp_tensor(%arg0 : tensor<8x8xf16>) -> tensor<8x8xbf16> {
- // CHECK: arith.fptofp %arg0 : tensor<8x8xf16> to tensor<8x8xbf16>
- %0 = arith.fptofp %arg0 : tensor<8x8xf16> to tensor<8x8xbf16>
+// CHECK-LABEL: test_convertf_tensor
+func.func @test_convertf_tensor(%arg0 : tensor<8x8xf16>) -> tensor<8x8xbf16> {
+ // CHECK: arith.convertf %arg0 : tensor<8x8xf16> to tensor<8x8xbf16>
+ %0 = arith.convertf %arg0 : tensor<8x8xf16> to tensor<8x8xbf16>
return %0 : tensor<8x8xbf16>
}
-// CHECK-LABEL: test_fptofp_rounding_mode
-func.func @test_fptofp_rounding_mode(%arg0 : bf16) -> f16 {
- // CHECK: arith.fptofp %arg0 to_nearest_even : bf16 to f16
- %0 = arith.fptofp %arg0 to_nearest_even : bf16 to f16
+// CHECK-LABEL: test_convertf_rounding_mode
+func.func @test_convertf_rounding_mode(%arg0 : bf16) -> f16 {
+ // CHECK: arith.convertf %arg0 to_nearest_even : bf16 to f16
+ %0 = arith.convertf %arg0 to_nearest_even : bf16 to f16
return %0 : f16
}
More information about the Mlir-commits
mailing list