[Mlir-commits] [mlir] [mlir][LLVM] refactor FailOnUnsupportedFP (PR #172054)
Maksim Levental
llvmlistbot at llvm.org
Fri Dec 12 09:52:10 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/172054
>From b2571a47491a504d987d8f63f5fd9568171e45cf Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Wed, 10 Dec 2025 16:25:44 -0800
Subject: [PATCH] [mlir][LLVM] refactor FailOnUnsupportedFP
---
.../mlir/Conversion/LLVMCommon/Pattern.h | 23 +++++++-
.../Conversion/LLVMCommon/VectorPattern.h | 25 +++-----
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 31 ++++++++++
.../Conversion/LLVMCommon/VectorPattern.cpp | 21 -------
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 57 +++++++++++++------
5 files changed, 100 insertions(+), 57 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f8e0ccc093f8b..cacd500d41291 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -54,6 +54,15 @@ LogicalResult intrinsicRewrite(Operation *op, StringRef intrinsic,
const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter);
+/// Return "true" if the given type is an unsupported floating point type,
+/// i.e. a float type that `typeConverter` either fails to convert or converts
+/// to a non-float type. In case of a vector type, return "true" if the element
+/// type is such an unsupported floating point type. Non-float types (and
+/// vectors of non-float elements) are never reported as unsupported.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+ Type type);
+/// Return "true" if any operand type or any result type of `op` is an
+/// unsupported floating point type in the sense of
+/// `isUnsupportedFloatingPointType` (vector types are checked by element).
+bool opHasUnsupportedFloatingPointTypes(Operation *op,
+ const TypeConverter &typeConverter);
} // namespace detail
/// Decomposes a `src` value into a set of values of type `dstType` through
@@ -203,7 +212,7 @@ class ConvertToLLVMPattern : public ConversionPattern {
/// Utility class for operation conversions targeting the LLVM dialect that
/// match exactly one source operation.
-template <typename SourceOp>
+///
+/// When `FailOnUnsupportedFP` is true, the pattern reports a match failure for
+/// any op with an operand or result of an unsupported floating point type (see
+/// `LLVM::detail::opHasUnsupportedFloatingPointTypes`), before the typed
+/// `matchAndRewrite` overload is invoked.
+template <typename SourceOp, bool FailOnUnsupportedFP = false>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
@@ -220,12 +229,24 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
+ // Bail on unsupported floating point types. (These are type-converted to
+ // integer types.)
+ if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+ op, *this->typeConverter)) {
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+ }
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
+ // Bail on unsupported floating point types. (These are type-converted to
+ // integer types.)
+ if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
+ op, *this->typeConverter)) {
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+ }
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 32dd8ba2bc391..65988a2466318 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -60,12 +60,6 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
Attribute propertiesAttr,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
-
-/// Return "true" if the given type is an unsupported floating point type. In
-/// case of a vector type, return "true" if the element type is an unsupported
-/// floating point type.
-bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
- Type type);
} // namespace detail
} // namespace LLVM
@@ -98,9 +92,11 @@ template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough,
bool FailOnUnsupportedFP = false>
-class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
+class VectorConvertToLLVMPattern
+    : public ConvertOpToLLVMPattern<SourceOp, FailOnUnsupportedFP> {
public:
- using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
- using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
+  using ConvertOpToLLVMPattern<SourceOp,
+                               FailOnUnsupportedFP>::ConvertOpToLLVMPattern;
+  // Spell out every template argument: with only <SourceOp, TargetOp>, `Super`
+  // would name the default-argument instantiation, which is a different class
+  // whenever AttrConvert or FailOnUnsupportedFP is non-default.
+  using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
+                                           FailOnUnsupportedFP>;
LogicalResult
@@ -112,16 +108,9 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
- // Bail on unsupported floating point types. (These are type-converted to
- // integer types.)
- if (FailOnUnsupportedFP) {
- for (Value operand : op->getOperands())
- if (LLVM::detail::isUnsupportedFloatingPointType(
- *this->getTypeConverter(), operand.getType()))
- return rewriter.notifyMatchFailure(op,
- "unsupported floating point type");
- if (LLVM::detail::isUnsupportedFloatingPointType(
- *this->getTypeConverter(), op->getResult(0).getType()))
- return rewriter.notifyMatchFailure(op,
- "unsupported floating point type");
- }
+    // Unsupported floating point types (which are type-converted to integer
+    // types) are already rejected by the `final` Operation-based
+    // matchAndRewrite overloads of the
+    // ConvertOpToLLVMPattern<SourceOp, FailOnUnsupportedFP> base before they
+    // dispatch here, so no additional check is needed.
// Determine attributes for the target op
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index f28a6ccb42455..640ff3d7c3c7d 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -516,3 +516,34 @@ Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
base, index, noWrapFlags)
: base;
}
+
+/// Yield `type` itself when it is a floating point type; for a vector type,
+/// yield its element type when that element type is a floating point type.
+/// Any other type yields a null FloatType.
+static FloatType getFloatingPointType(Type type) {
+  Type scalarType = type;
+  if (auto vectorType = dyn_cast<VectorType>(type))
+    scalarType = vectorType.getElementType();
+  return dyn_cast<FloatType>(scalarType);
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+    const TypeConverter &typeConverter, Type type) {
+  // Non-float types (and vectors of non-float elements) are never reported.
+  FloatType fpType = getFloatingPointType(type);
+  if (!fpType)
+    return false;
+  // The float type is supported only if it converts to another float type;
+  // a failed (null) conversion or a non-float result (e.g. an integer type)
+  // makes it unsupported.
+  Type loweredType = typeConverter.convertType(fpType);
+  return !loweredType || !isa<FloatType>(loweredType);
+}
+
+/// Return true if any operand type or result type of `op` is an unsupported
+/// floating point type (vector types are checked by their element type).
+bool LLVM::detail::opHasUnsupportedFloatingPointTypes(
+    Operation *op, const TypeConverter &typeConverter) {
+  auto isUnsupported = [&typeConverter](Type type) {
+    return isUnsupportedFloatingPointType(typeConverter, type);
+  };
+  // Check operands and results uniformly through their type ranges; only the
+  // types matter, so there is no need to walk Values / OpResults.
+  return llvm::any_of(op->getOperandTypes(), isUnsupported) ||
+         llvm::any_of(op->getResultTypes(), isUnsupported);
+}
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e5969c2539566..24b01259f0499 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -130,24 +130,3 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}
-
-/// Return the given type if it's a floating point type. If the given type is
-/// a vector type, return its element type if it's a floating point type.
-static FloatType getFloatingPointType(Type type) {
- if (auto floatType = dyn_cast<FloatType>(type))
- return floatType;
- if (auto vecType = dyn_cast<VectorType>(type))
- return dyn_cast<FloatType>(vecType.getElementType());
- return nullptr;
-}
-
-bool LLVM::detail::isUnsupportedFloatingPointType(
- const TypeConverter &typeConverter, Type type) {
- FloatType floatType = getFloatingPointType(type);
- if (!floatType)
- return false;
- Type convertedType = typeConverter.convertType(floatType);
- if (!convertedType)
- return true;
- return !isa<FloatType>(convertedType);
-}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 7cce324f94295..faa4182943f67 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -32,9 +32,10 @@ namespace {
template <typename SourceOp, typename TargetOp>
using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
-template <typename SourceOp, typename TargetOp>
+/// Lowering of `SourceOp` to `TargetOp` through VectorConvertToLLVMPattern,
+/// converting attributes with `ConvertFastMath`. Unless overridden
+/// (`FailOnUnsupportedFP = true` by default), the pattern does not match ops
+/// involving unsupported floating point types.
+template <typename SourceOp, typename TargetOp,
+          bool FailOnUnsupportedFP = true>
using ConvertFMFMathToLLVMPattern =
- VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
+    VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
+                               FailOnUnsupportedFP>;
using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
@@ -44,7 +45,9 @@ using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
using CtPopFOpLowering =
- VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
+    VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp,
+                               AttrConvertPassThrough,
+                               /*FailOnUnsupportedFP=*/true>;
using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using FloorOpLowering =
@@ -76,8 +79,10 @@ using ATan2OpLowering =
// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
// may be better to separate the patterns.
template <typename MathOp, typename LLVMOp>
-struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
- using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
+struct IntOpWithFlagLowering
+    : public ConvertOpToLLVMPattern<MathOp, /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      MathOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
LogicalResult
@@ -122,8 +127,11 @@ using CountTrailingZerosOpLowering =
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
-struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
- using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+struct SincosOpLowering
+    : public ConvertOpToLLVMPattern<math::SincosOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::SincosOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
@@ -154,8 +162,11 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
};
// A `expm1` is converted into `exp - 1`.
-struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
- using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
+struct ExpM1OpLowering
+    : public ConvertOpToLLVMPattern<math::ExpM1Op,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::ExpM1Op, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
@@ -216,8 +227,11 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
};
// A `log1p` is converted into `log(1 + ...)`.
-struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
- using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
+struct Log1pOpLowering
+    : public ConvertOpToLLVMPattern<math::Log1pOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::Log1pOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
@@ -278,8 +292,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
};
// A `rsqrt` is converted into `1 / sqrt`.
-struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
- using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
+struct RsqrtOpLowering
+    : public ConvertOpToLLVMPattern<math::RsqrtOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::RsqrtOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
@@ -339,8 +356,11 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
}
};
-struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
- using ConvertOpToLLVMPattern<math::IsNaNOp>::ConvertOpToLLVMPattern;
+struct IsNaNOpLowering
+ : public ConvertOpToLLVMPattern<math::IsNaNOp,
+ /*FailOnUnsupportedFP=*/true> {
+ using ConvertOpToLLVMPattern<
+ math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
@@ -358,8 +378,11 @@ struct IsNaNOpLowering : public ConvertOpToLLVMPattern<math::IsNaNOp> {
}
};
-struct IsFiniteOpLowering : public ConvertOpToLLVMPattern<math::IsFiniteOp> {
- using ConvertOpToLLVMPattern<math::IsFiniteOp>::ConvertOpToLLVMPattern;
+struct IsFiniteOpLowering
+ : public ConvertOpToLLVMPattern<math::IsFiniteOp,
+ /*FailOnUnsupportedFP=*/true> {
+ using ConvertOpToLLVMPattern<
+ math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
More information about the Mlir-commits
mailing list