[Mlir-commits] [mlir] [mlir][LLVM] refactor FailOnUnsupportedFP (PR #172054)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 12 10:37:26 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

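Enable `FailOnUnsupportedFP` for `ConvertOpToLLVMPattern`. The template parameter, previously available only on `VectorConvertToLLVMPattern`, now lets any single-op LLVM conversion pattern bail out when an operand or result has a floating-point type that the type converter does not map to an LLVM float type. The `isUnsupportedFloatingPointType` helper moves from `VectorPattern.h` to `Pattern.h`, alongside a new `opHasUnsupportedFloatingPointTypes` helper. The MathToLLVM patterns now enable the check.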

---
Full diff: https://github.com/llvm/llvm-project/pull/172054.diff


6 Files Affected:

- (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+22-1) 
- (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+7-18) 
- (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+31) 
- (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (-21) 
- (modified) mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp (+40-17) 
- (modified) mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir (+13) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f8e0ccc093f8b..cacd500d41291 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -54,6 +54,15 @@ LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
                                const LLVMTypeConverter &typeConverter,
                                RewriterBase &rewriter);
 
+/// Return "true" if the given type is an unsupported floating point type.
+/// In case of a vector type, return "true" if the element type is an
+/// unsupported floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+                                    Type type);
+/// Return "true" if the given op has any unsupported floating point
+/// types (either operands or results).
+bool opHasUnsupportedFloatingPointTypes(Operation *op,
+                                        const TypeConverter &typeConverter);
 } // namespace detail
 
 /// Decomposes a `src` value into a set of values of type `dstType` through
@@ -203,7 +212,7 @@ class ConvertToLLVMPattern : public ConversionPattern {
 
 /// Utility class for operation conversions targeting the LLVM dialect that
 /// match exactly one source operation.
-template <typename SourceOp>
+template <typename SourceOp, bool FailOnUnsupportedFP = false>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
@@ -220,12 +229,24 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    // Bail on unsupported floating point types. (These are type-converted to
+    // integer types.)
+    if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+                                   op, *this->typeConverter)) {
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+    }
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    // Bail on unsupported floating point types. (These are type-converted to
+    // integer types.)
+    if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+                                   op, *this->typeConverter)) {
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+    }
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
                            rewriter);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 32dd8ba2bc391..65988a2466318 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -60,12 +60,6 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
                                     Attribute propertiesAttr,
                                     const LLVMTypeConverter &typeConverter,
                                     ConversionPatternRewriter &rewriter);
-
-/// Return "true" if the given type is an unsupported floating point type. In
-/// case of a vector type, return "true" if the element type is an unsupported
-/// floating point type.
-bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
-                                    Type type);
 } // namespace detail
 } // namespace LLVM
 
@@ -98,9 +92,11 @@ template <typename SourceOp, typename TargetOp,
           template <typename, typename> typename AttrConvert =
               AttrConvertPassThrough,
           bool FailOnUnsupportedFP = false>
-class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
+class VectorConvertToLLVMPattern
+    : public ConvertOpToLLVMPattern<SourceOp, FailOnUnsupportedFP> {
 public:
-  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<SourceOp,
+                               FailOnUnsupportedFP>::ConvertOpToLLVMPattern;
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
   LogicalResult
@@ -112,16 +108,9 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 
     // Bail on unsupported floating point types. (These are type-converted to
     // integer types.)
-    if (FailOnUnsupportedFP) {
-      for (Value operand : op->getOperands())
-        if (LLVM::detail::isUnsupportedFloatingPointType(
-                *this->getTypeConverter(), operand.getType()))
-          return rewriter.notifyMatchFailure(op,
-                                             "unsupported floating point type");
-      if (LLVM::detail::isUnsupportedFloatingPointType(
-              *this->getTypeConverter(), op->getResult(0).getType()))
-        return rewriter.notifyMatchFailure(op,
-                                           "unsupported floating point type");
+    if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+                                   op, *this->typeConverter)) {
+      return rewriter.notifyMatchFailure(op, "unsupported floating point type");
     }
 
     // Determine attributes for the target op
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index f28a6ccb42455..640ff3d7c3c7d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -516,3 +516,34 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
                                    base, index, noWrapFlags)
              : base;
 }
+
+/// Return the given type if it's a floating point type. If the given type is
+/// a vector type, return its element type if it's a floating point type.
+static FloatType getFloatingPointType(Type type) {
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return floatType;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    return dyn_cast<FloatType>(vecType.getElementType());
+  return nullptr;
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+    const TypeConverter &typeConverter, Type type) {
+  FloatType floatType = getFloatingPointType(type);
+  if (!floatType)
+    return false;
+  Type convertedType = typeConverter.convertType(floatType);
+  if (!convertedType)
+    return true;
+  return !isa<FloatType>(convertedType);
+}
+
+bool LLVM::detail::opHasUnsupportedFloatingPointTypes(
+    Operation *op, const TypeConverter &typeConverter) {
+  for (Value operand : op->getOperands())
+    if (isUnsupportedFloatingPointType(typeConverter, operand.getType()))
+      return true;
+  return llvm::any_of(op->getResults(), [&typeConverter](OpResult r) {
+    return isUnsupportedFloatingPointType(typeConverter, r.getType());
+  });
+}
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e5969c2539566..24b01259f0499 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -130,24 +130,3 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
                                        rewriter);
 }
-
-/// Return the given type if it's a floating point type. If the given type is
-/// a vector type, return its element type if it's a floating point type.
-static FloatType getFloatingPointType(Type type) {
-  if (auto floatType = dyn_cast<FloatType>(type))
-    return floatType;
-  if (auto vecType = dyn_cast<VectorType>(type))
-    return dyn_cast<FloatType>(vecType.getElementType());
-  return nullptr;
-}
-
-bool LLVM::detail::isUnsupportedFloatingPointType(
-    const TypeConverter &typeConverter, Type type) {
-  FloatType floatType = getFloatingPointType(type);
-  if (!floatType)
-    return false;
-  Type convertedType = typeConverter.convertType(floatType);
-  if (!convertedType)
-    return true;
-  return !isa<FloatType>(convertedType);
-}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 7cce324f94295..faa4182943f67 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -32,9 +32,10 @@ namespace {
 template <typename SourceOp, typename TargetOp>
 using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
 
-template <typename SourceOp, typename TargetOp>
+template <typename SourceOp, typename TargetOp, bool FailOnUnsupportedFP = true>
 using ConvertFMFMathToLLVMPattern =
-    VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
+    VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
+                               FailOnUnsupportedFP>;
 
 using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
 using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
@@ -44,7 +45,9 @@ using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
 using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
 using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
 using CtPopFOpLowering =
-    VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
+    VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp,
+                               AttrConvertPassThrough,
+                               /*FailOnUnsupportedFP*/ true>;
 using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
 using FloorOpLowering =
@@ -76,8 +79,10 @@ using ATan2OpLowering =
 // TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
 // may be better to separate the patterns.
 template <typename MathOp, typename LLVMOp>
-struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
-  using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
+struct IntOpWithFlagLowering
+    : public ConvertOpToLLVMPattern<MathOp, /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      MathOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
   using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
 
   LogicalResult
@@ -122,8 +127,11 @@ using CountTrailingZerosOpLowering =
 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
 
 // A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
-struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
-  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+struct SincosOpLowering
+    : public ConvertOpToLLVMPattern<math::SincosOp,
+                                    /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      math::SincosOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
@@ -154,8 +162,11 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
 };
 
 // A `expm1` is converted into `exp - 1`.
-struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
-  using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
+struct ExpM1OpLowering
+    : public ConvertOpToLLVMPattern<math::ExpM1Op,
+                                    /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      math::ExpM1Op, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
@@ -216,8 +227,11 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
 };
 
 // A `log1p` is converted into `log(1 + ...)`.
-struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
-  using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
+struct Log1pOpLowering
+    : public ConvertOpToLLVMPattern<math::Log1pOp,
+                                    /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      math::Log1pOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
@@ -278,8 +292,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
 };
 
 // A `rsqrt` is converted into `1 / sqrt`.
-struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
-  using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
+struct RsqrtOpLowering
+    : public ConvertOpToLLVMPattern<math::RsqrtOp,
+                                    /*FailOnUnsupportedFP*/ true> {
+  using ConvertOpToLLVMPattern<
+      math::RsqrtOp, /*FailOnUnsupportedFP*/ true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
@@ -339,8 +356,11 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
   }
 };
 
-struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
-  using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
+struct IsNaNOpLowering
+    : public ConvertOpToLLVMPattern<math::IsNaNOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
@@ -358,8 +378,11 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
   }
 };
 
-struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
-  using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
+struct IsFiniteOpLowering
+    : public ConvertOpToLLVMPattern<math::IsFiniteOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
 
   LogicalResult
   matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index f7d27120d4207..394aca876ff08 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -628,3 +628,16 @@ func.func @fastmath(%arg0 : f32, %arg1 : vector<4xf32>) {
   %3 = math.fma %arg0, %arg0, %arg0 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @unsupported_fp_type
+//       CHECK:   math.absf {{.*}} : f4E2M1FN
+//       CHECK:   math.cos {{.*}} : f4E2M1FN
+//       CHECK:   math.fma {{.*}} : f4E2M1FN
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: f4E2M1FN, %arg2: f4E2M1FN) {
+  %0 = math.absf %arg0 : f4E2M1FN
+  %1 = math.cos %arg0 : f4E2M1FN
+  %2 = math.fma %arg1, %arg1, %arg2 : f4E2M1FN
+  return
+}
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/172054


More information about the Mlir-commits mailing list