[Mlir-commits] [mlir] b142912 - [mlir][arith] Fix `arith.cmpf` lowering with unsupported FP types (#166684)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 27 19:01:19 PST 2025


Author: Matthias Springer
Date: 2025-11-28T03:01:14Z
New Revision: b14291238a987675b1fb39938efc938afcab8446

URL: https://github.com/llvm/llvm-project/commit/b14291238a987675b1fb39938efc938afcab8446
DIFF: https://github.com/llvm/llvm-project/commit/b14291238a987675b1fb39938efc938afcab8446.diff

LOG: [mlir][arith] Fix `arith.cmpf` lowering with unsupported FP types (#166684)

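The `arith.cmpf` lowering pattern used to generate invalid IR when an
unsupported floating-point type was used. Such types are type-converted to
integer types, so the pattern now bails out for them, reusing the new shared
helper `LLVM::detail::isUnsupportedFloatingPointType`.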

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 47b8381eefda8..32dd8ba2bc391 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -60,6 +60,12 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
                                     Attribute propertiesAttr,
                                     const LLVMTypeConverter &typeConverter,
                                     ConversionPatternRewriter &rewriter);
+
+/// Return "true" if the given type is an unsupported floating point type. In
+/// case of a vector type, return "true" if the element type is an unsupported
+/// floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+                                    Type type);
 } // namespace detail
 } // namespace LLVM
 
@@ -97,16 +103,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
-  /// Return the given type if it's a floating point type. If the given type is
-  /// a vector type, return its element type if it's a floating point type.
-  static FloatType getFloatingPointType(Type type) {
-    if (auto floatType = dyn_cast<FloatType>(type))
-      return floatType;
-    if (auto vecType = dyn_cast<VectorType>(type))
-      return dyn_cast<FloatType>(vecType.getElementType());
-    return nullptr;
-  }
-
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -114,26 +110,18 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    // The pattern should not apply if a floating-point operand is converted to
-    // a non-floating-point type. This indicates that the floating point type
-    // is not supported by the LLVM lowering. (Such types are converted to
-    // integers.)
-    auto checkType = [&](Value v) -> LogicalResult {
-      FloatType floatType = getFloatingPointType(v.getType());
-      if (!floatType)
-        return success();
-      Type convertedType = this->getTypeConverter()->convertType(floatType);
-      if (!isa_and_nonnull<FloatType>(convertedType))
-        return rewriter.notifyMatchFailure(op,
-                                           "unsupported floating point type");
-      return success();
-    };
+    // Bail on unsupported floating point types. (These are type-converted to
+    // integer types.)
     if (FailOnUnsupportedFP) {
       for (Value operand : op->getOperands())
-        if (failed(checkType(operand)))
-          return failure();
-      if (failed(checkType(op->getResult(0))))
-        return failure();
+        if (LLVM::detail::isUnsupportedFloatingPointType(
+                *this->getTypeConverter(), operand.getType()))
+          return rewriter.notifyMatchFailure(op,
+                                             "unsupported floating point type");
+      if (LLVM::detail::isUnsupportedFloatingPointType(
+              *this->getTypeConverter(), op->getResult(0).getType()))
+        return rewriter.notifyMatchFailure(op,
+                                           "unsupported floating point type");
     }
 
     // Determine attributes for the target op

diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index cc3e8468f298b..220826dc5f3ac 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -483,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
 LogicalResult
 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
+  if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
+                                                   op.getLhs().getType()))
+    return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
   Type operandType = adaptor.getLhs().getType();
   Type resultType = op.getResult().getType();
   LLVM::FastmathFlags fmf =

diff  --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 24b01259f0499..e5969c2539566 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -130,3 +130,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
                                        rewriter);
 }
+
+/// Return the given type if it's a floating point type. If the given type is
+/// a vector type, return its element type if it's a floating point type.
+static FloatType getFloatingPointType(Type type) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return floatType;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return dyn_cast<FloatType>(vecType.getElementType());
+  return nullptr;
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+    const TypeConverter &typeConverter, Type type) {
+  FloatType floatType = getFloatingPointType(type);
+  if (!floatType)
+    return false;
+  Type convertedType = typeConverter.convertType(floatType);
+  if (!convertedType)
+    return true;
+  return !isa<FloatType>(convertedType);
+}

diff  --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 6fdc1104d2609..b53c52d75c0aa 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -770,12 +770,14 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
 //       CHECK:   arith.addf {{.*}} : f4E2M1FN
 //       CHECK:   arith.addf {{.*}} : vector<4xf4E2M1FN>
 //       CHECK:   arith.addf {{.*}} : vector<8x4xf4E2M1FN>
+//       CHECK:   arith.cmpf {{.*}} : f4E2M1FN
 //       CHECK:   llvm.select {{.*}} : i1, i4
 func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
   %0 = arith.addf %arg0, %arg0 : f4E2M1FN
   %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
   %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
-  %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+  %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
+  %4 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
   return
 }
 
@@ -785,9 +787,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2
 //         CHECK:   llvm.fadd {{.*}} : f32
 //         CHECK:   llvm.fadd {{.*}} : vector<4xf32>
 // CHECK-COUNT-4:   llvm.fadd {{.*}} : vector<8xf32>
-func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+//         CHECK:   llvm.fcmp {{.*}} : f32
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
   %0 = arith.addf %arg0, %arg0 : f32
   %1 = arith.addf %arg1, %arg1 : vector<4xf32>
   %2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
-  return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
+  %3 = arith.cmpf oeq, %arg0, %arg3 : f32
+  return
 }


        


More information about the Mlir-commits mailing list