[Mlir-commits] [mlir] [mlir][LLVM] refactor FailOnUnsupportedFP (PR #172054)
Maksim Levental
llvmlistbot at llvm.org
Fri Dec 12 09:44:17 PST 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/172054
Enable `FailOnUnsupportedFP` for `ConvertOpToLLVMPattern` (not just `VectorConvertToLLVMPattern`) and use it in several MathToLLVM lowerings.
>From 67cfcc98da6cd0412a831542183a8fe3bc9fa9cf Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 16:25:44 -0800
Subject: [PATCH] [mlir][LLVM] refactor FailOnUnsupportedFP
---
.../mlir/Conversion/LLVMCommon/Pattern.h | 23 +++++++++++++-
.../Conversion/LLVMCommon/VectorPattern.h | 25 +++++----------
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 31 +++++++++++++++++++
.../Conversion/LLVMCommon/VectorPattern.cpp | 21 -------------
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 27 +++++++++++-----
5 files changed, 79 insertions(+), 48 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f8e0ccc093f8b..cacd500d41291 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -54,6 +54,15 @@ LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter);
+/// Return "true" if the given type is a floating point type that the type
+/// converter either fails to convert or converts to a non-float type. For a
+/// vector type, its floating point element type is checked instead.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+ Type type);
+/// Return "true" if the given op has any unsupported floating point
+/// types (either operands or results).
+bool opHasUnsupportedFloatingPointTypes(Operation *op,
+ const TypeConverter &typeConverter);
} // namespace detail
/// Decomposes a `src` value into a set of values of type `dstType` through
@@ -203,7 +212,7 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// Utility class for operation conversions targeting the LLVM dialect that
/// match exactly one source operation.
-template <typename SourceOp>
+template <typename SourceOp, bool FailOnUnsupportedFP = false>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
@@ -220,12 +229,24 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
+ // Bail on unsupported floating point types. (These are type-converted to
+ // integer types.)
+ if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+ op, *this->typeConverter)) {
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+ }
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
+ // Bail on unsupported floating point types. (These are type-converted to
+ // integer types.)
+ if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+ op, *this->typeConverter)) {
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+ }
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 32dd8ba2bc391..65988a2466318 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -60,12 +60,6 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
Attribute propertiesAttr,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
-
-/// Return "true" if the given type is an unsupported floating point type. In
-/// case of a vector type, return "true" if the element type is an unsupported
-/// floating point type.
-bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
- Type type);
} // namespace detail
} // namespace LLVM
@@ -98,9 +92,11 @@ template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough,
bool FailOnUnsupportedFP = false>
-class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
+class VectorConvertToLLVMPattern
+ : public ConvertOpToLLVMPattern<SourceOp, FailOnUnsupportedFP> {
public:
- using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+ using ConvertOpToLLVMPattern<SourceOp,
+ FailOnUnsupportedFP>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
LogicalResult
@@ -112,16 +108,9 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
// Bail on unsupported floating point types. (These are type-converted to
// integer types.)
- if (FailOnUnsupportedFP) {
- for (Value operand : op->getOperands())
- if (LLVM::detail::isUnsupportedFloatingPointType(
- *this->getTypeConverter(), operand.getType()))
- return rewriter.notifyMatchFailure(op,
- "unsupported floating point type");
- if (LLVM::detail::isUnsupportedFloatingPointType(
- *this->getTypeConverter(), op->getResult(0).getType()))
- return rewriter.notifyMatchFailure(op,
- "unsupported floating point type");
+ if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+ op, *this->typeConverter)) {
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
}
// Determine attributes for the target op
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index f28a6ccb42455..b5120abfd226f 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -516,3 +516,34 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
base, index, noWrapFlags)
: base;
}
+
+/// Return `type` itself if it is a floating point type. For a vector type,
+/// return its element type if that element type is floating point. Return a
+/// null FloatType in every other case (including vectors of non-floats).
+static FloatType getFloatingPointType(Type type) {
+  // Peel off a vector wrapper so that the element type is inspected instead.
+  if (auto vecType = dyn_cast<VectorType>(type))
+    type = vecType.getElementType();
+  return dyn_cast<FloatType>(type);
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+    const TypeConverter &typeConverter, Type type) {
+  // Types that are not (vectors of) floats can never be "unsupported FP".
+  FloatType fpType = getFloatingPointType(type);
+  if (!fpType)
+    return false;
+  // Unsupported: conversion fails, or yields a non-float (e.g. integer) type.
+  Type lowered = typeConverter.convertType(fpType);
+  return !lowered || !isa<FloatType>(lowered);
+}
+
+bool LLVM::detail::opHasUnsupportedFloatingPointTypes(
+    Operation *op, const TypeConverter &typeConverter) {
+  // Check all results, not just #0: ops may have zero or multiple results.
+  auto isUnsupported = [&](Type type) {
+    return isUnsupportedFloatingPointType(typeConverter, type);
+  };
+  return llvm::any_of(op->getOperandTypes(), isUnsupported) ||
+         llvm::any_of(op->getResultTypes(), isUnsupported);
+}
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e5969c2539566..24b01259f0499 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -130,24 +130,3 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}
-
-/// Return the given type if it's a floating point type. If the given type is
-/// a vector type, return its element type if it's a floating point type.
-static FloatType getFloatingPointType(Type type) {
- if (auto floatType = dyn_cast<FloatType>(type))
- return floatType;
- if (auto vecType = dyn_cast<VectorType>(type))
- return dyn_cast<FloatType>(vecType.getElementType());
- return nullptr;
-}
-
-bool LLVM::detail::isUnsupportedFloatingPointType(
- const TypeConverter &typeConverter, Type type) {
- FloatType floatType = getFloatingPointType(type);
- if (!floatType)
- return false;
- Type convertedType = typeConverter.convertType(floatType);
- if (!convertedType)
- return true;
- return !isa<FloatType>(convertedType);
-}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 7cce324f94295..431290a45eaf4 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -32,11 +32,15 @@ namespace {
template <typename SourceOp, typename TargetOp>
using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
-template <typename SourceOp, typename TargetOp>
+template <typename SourceOp, typename TargetOp,
+ bool FailOnUnsupportedFP = false>
using ConvertFMFMathToLLVMPattern =
- VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
+ VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
+ FailOnUnsupportedFP>;
-using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
+using AbsFOpLowering =
+ ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
+ /*FailOnUnsupportedFP=*/true>;
using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
using CopySignOpLowering =
ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
@@ -49,7 +53,8 @@ using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using FloorOpLowering =
ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
-using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
+using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp,
+ /*FailOnUnsupportedFP=*/true>;
using Log10OpLowering =
ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
@@ -339,8 +344,11 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
}
};
-struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
- using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
+struct IsNaNOpLowering
+ : public ConvertOpToLLVMPattern<math::IsNaNOp,
+ /*FailOnUnsupportedFP=*/true> {
+ using ConvertOpToLLVMPattern<
+ math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
@@ -358,8 +366,11 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
}
};
-struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
- using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
+struct IsFiniteOpLowering
+ : public ConvertOpToLLVMPattern<math::IsFiniteOp,
+ /*FailOnUnsupportedFP=*/true> {
+ using ConvertOpToLLVMPattern<
+ math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
More information about the Mlir-commits
mailing list