[Mlir-commits] [mlir] [mlir][LLVM] refactor FailOnUnsupportedFP (PR #172054)

Maksim Levental llvmlistbot at llvm.org
Fri Dec 12 09:44:17 PST 2025


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/172054

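Enable `FailOnUnsupportedFP` for `ConvertOpToLLVMPattern`: the unsupported floating point type check moves out of `VectorConvertToLLVMPattern` into `ConvertOpToLLVMPattern`, so any single-op LLVM lowering can opt in to bailing out on floating point types that the type converter does not lower to floating point types.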

>From 67cfcc98da6cd0412a831542183a8fe3bc9fa9cf Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 16:25:44 -0800
Subject: [PATCH] [mlir][LLVM] refactor FailOnUnsupportedFP

---
 .../mlir/Conversion/LLVMCommon/Pattern.h      | 23 +++++++++++++-
 .../Conversion/LLVMCommon/VectorPattern.h     | 25 +++++----------
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp    | 31 +++++++++++++++++++
 .../Conversion/LLVMCommon/VectorPattern.cpp   | 21 -------------
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 27 +++++++++++-----
 5 files changed, 79 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f8e0ccc093f8b..cacd500d41291 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -54,6 +54,15 @@ LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
                                const LLVMTypeConverter &typeConverter,
                                RewriterBase &rewriter);
 
+/// Return "true" if `type` (or, for a vector type, its element type) is a
+/// floating point type that `typeConverter` fails to convert or converts to a
+/// non-floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+                                    Type type);
+/// Return "true" if any operand or result type of `op` is an unsupported
+/// floating point type as defined by `isUnsupportedFloatingPointType`.
+bool opHasUnsupportedFloatingPointTypes(Operation *op,
+                                        const TypeConverter &typeConverter);
 } // namespace detail
 
 /// Decomposes a `src` value into a set of values of type `dstType` through
@@ -203,7 +212,7 @@ class ConvertToLLVMPattern : public ConversionPattern {
 
 /// Utility class for operation conversions targeting the LLVM dialect that
 /// match exactly one source operation.
-template <typename SourceOp>
+template <typename SourceOp, bool FailOnUnsupportedFP = false>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
@@ -220,12 +229,24 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    // With FailOnUnsupportedFP, refuse to match ops that use unsupported
+    // floating point types (e.g. ones type-converted to integer types).
+    if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+                                   op, *this->typeConverter)) {
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+    }
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    // With FailOnUnsupportedFP, refuse to match ops that use unsupported
+    // floating point types (e.g. ones type-converted to integer types).
+    if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+                                   op, *this->typeConverter)) {
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+    }
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
                            rewriter);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 32dd8ba2bc391..65988a2466318 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -60,12 +60,6 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
                                     Attribute propertiesAttr,
                                     const LLVMTypeConverter &typeConverter,
                                     ConversionPatternRewriter &rewriter);
-
-/// Return "true" if the given type is an unsupported floating point type. In
-/// case of a vector type, return "true" if the element type is an unsupported
-/// floating point type.
-bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
-                                    Type type);
 } // namespace detail
 } // namespace LLVM
 
@@ -98,9 +92,11 @@ template <typename SourceOp, typename TargetOp,
           template <typename, typename> typename AttrConvert =
               AttrConvertPassThrough,
           bool FailOnUnsupportedFP = false>
-class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
+class VectorConvertToLLVMPattern
+    : public ConvertOpToLLVMPattern<SourceOp, FailOnUnsupportedFP> {
 public:
-  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<SourceOp,
+                               FailOnUnsupportedFP>::ConvertOpToLLVMPattern;
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
   LogicalResult
@@ -112,16 +108,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {

-    // Bail on unsupported floating point types. (These are type-converted to
-    // integer types.)
-    if (FailOnUnsupportedFP) {
-      for (Value operand : op->getOperands())
-        if (LLVM::detail::isUnsupportedFloatingPointType(
-                *this->getTypeConverter(), operand.getType()))
-          return rewriter.notifyMatchFailure(op,
-                                             "unsupported floating point type");
-      if (LLVM::detail::isUnsupportedFloatingPointType(
-              *this->getTypeConverter(), op->getResult(0).getType()))
-        return rewriter.notifyMatchFailure(op,
-                                           "unsupported floating point type");
-    }
+    // No FailOnUnsupportedFP check is needed here: the flag is forwarded to
+    // ConvertOpToLLVMPattern, whose matchAndRewrite(Operation *, ...)
+    // overloads already bail on such ops before dispatching to this method.

     // Determine attributes for the target op
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index f28a6ccb42455..b5120abfd226f 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -516,3 +516,34 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
                                    base, index, noWrapFlags)
              : base;
 }
+
+/// Return `type` or its vector element type if it is a FloatType, else null.
+static FloatType getFloatingPointType(Type type) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return floatType;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return dyn_cast<FloatType>(vecType.getElementType());
+  return nullptr;
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+    const TypeConverter &typeConverter, Type type) {
+  FloatType floatType = getFloatingPointType(type);
+  if (!floatType)
+    return false;
+  // Unsupported if the converter rejects the type or maps it to a non-float.
+  Type convertedType = typeConverter.convertType(floatType);
+  if (!convertedType)
+    return true;
+  return !isa<FloatType>(convertedType);
+}
+
+bool LLVM::detail::opHasUnsupportedFloatingPointTypes(
+    Operation *op, const TypeConverter &typeConverter) {
+  // Check all results, not just result #0: ops may have zero or many results.
+  auto isUnsupported = [&](Type type) {
+    return isUnsupportedFloatingPointType(typeConverter, type);
+  };
+  return llvm::any_of(op->getOperandTypes(), isUnsupported) ||
+         llvm::any_of(op->getResultTypes(), isUnsupported);
+}
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e5969c2539566..24b01259f0499 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -130,24 +130,3 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
                                        rewriter);
 }
-
-/// Return the given type if it's a floating point type. If the given type is
-/// a vector type, return its element type if it's a floating point type.
-static FloatType getFloatingPointType(Type type) {
-  if (auto floatType = dyn_cast<FloatType>(type))
-    return floatType;
-  if (auto vecType = dyn_cast<VectorType>(type))
-    return dyn_cast<FloatType>(vecType.getElementType());
-  return nullptr;
-}
-
-bool LLVM::detail::isUnsupportedFloatingPointType(
-    const TypeConverter &typeConverter, Type type) {
-  FloatType floatType = getFloatingPointType(type);
-  if (!floatType)
-    return false;
-  Type convertedType = typeConverter.convertType(floatType);
-  if (!convertedType)
-    return true;
-  return !isa<FloatType>(convertedType);
-}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 7cce324f94295..431290a45eaf4 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -32,11 +32,15 @@ namespace {
 template <typename SourceOp, typename TargetOp>
 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
 
-template <typename SourceOp, typename TargetOp>
+template <typename SourceOp, typename TargetOp,
+          bool FailOnUnsupportedFP = false>
 using ConvertFMFMathToLLVMPattern =
-    VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
+    VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
+                               FailOnUnsupportedFP>;
 
-using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
+using AbsFOpLowering =
+    ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
+                                /*FailOnUnsupportedFP=*/true>;
 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
 using CopySignOpLowering =
     ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
@@ -49,7 +53,8 @@ using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
 using FloorOpLowering =
     ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
-using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
+using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp,
+                                                  /*FailOnUnsupportedFP=*/true>;
 using Log10OpLowering =
     ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
@@ -339,8 +344,11 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
   }
 };
 
-struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
-  using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
+struct IsNaNOpLowering
+    : public ConvertOpToLLVMPattern<math::IsNaNOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
@@ -358,8 +366,11 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
   }
 };
 
-struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
-  using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
+struct IsFiniteOpLowering
+    : public ConvertOpToLLVMPattern<math::IsFiniteOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,



More information about the Mlir-commits mailing list