[Mlir-commits] [mlir] b142912 - [mlir][arith] Fix `arith.cmpf` lowering with unsupported FP types (#166684)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 27 19:01:19 PST 2025
Author: Matthias Springer
Date: 2025-11-28T03:01:14Z
New Revision: b14291238a987675b1fb39938efc938afcab8446
URL: https://github.com/llvm/llvm-project/commit/b14291238a987675b1fb39938efc938afcab8446
DIFF: https://github.com/llvm/llvm-project/commit/b14291238a987675b1fb39938efc938afcab8446.diff
LOG: [mlir][arith] Fix `arith.cmpf` lowering with unsupported FP types (#166684)
The `arith.cmpf` lowering pattern used to generate invalid IR when an
unsupported floating-point type was used. The pattern now fails to match
for such types, leaving the `arith.cmpf` op unconverted.
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 47b8381eefda8..32dd8ba2bc391 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -60,6 +60,12 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
Attribute propertiesAttr,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
+
+/// Return "true" if the given type is an unsupported floating point type. In
+/// case of a vector type, return "true" if the element type is an unsupported
+/// floating point type.
+bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
+ Type type);
} // namespace detail
} // namespace LLVM
@@ -97,16 +103,6 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
- /// Return the given type if it's a floating point type. If the given type is
- /// a vector type, return its element type if it's a floating point type.
- static FloatType getFloatingPointType(Type type) {
- if (auto floatType = dyn_cast<FloatType>(type))
- return floatType;
- if (auto vecType = dyn_cast<VectorType>(type))
- return dyn_cast<FloatType>(vecType.getElementType());
- return nullptr;
- }
-
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -114,26 +110,18 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
- // The pattern should not apply if a floating-point operand is converted to
- // a non-floating-point type. This indicates that the floating point type
- // is not supported by the LLVM lowering. (Such types are converted to
- // integers.)
- auto checkType = [&](Value v) -> LogicalResult {
- FloatType floatType = getFloatingPointType(v.getType());
- if (!floatType)
- return success();
- Type convertedType = this->getTypeConverter()->convertType(floatType);
- if (!isa_and_nonnull<FloatType>(convertedType))
- return rewriter.notifyMatchFailure(op,
- "unsupported floating point type");
- return success();
- };
+ // Bail on unsupported floating point types. (These are type-converted to
+ // integer types.)
if (FailOnUnsupportedFP) {
for (Value operand : op->getOperands())
- if (failed(checkType(operand)))
- return failure();
- if (failed(checkType(op->getResult(0))))
- return failure();
+ if (LLVM::detail::isUnsupportedFloatingPointType(
+ *this->getTypeConverter(), operand.getType()))
+ return rewriter.notifyMatchFailure(op,
+ "unsupported floating point type");
+ if (LLVM::detail::isUnsupportedFloatingPointType(
+ *this->getTypeConverter(), op->getResult(0).getType()))
+ return rewriter.notifyMatchFailure(op,
+ "unsupported floating point type");
}
// Determine attributes for the target op
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index cc3e8468f298b..220826dc5f3ac 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -483,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
LogicalResult
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
+ if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
+ op.getLhs().getType()))
+ return rewriter.notifyMatchFailure(op, "unsupported floating point type");
+
Type operandType = adaptor.getLhs().getType();
Type resultType = op.getResult().getType();
LLVM::FastmathFlags fmf =
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 24b01259f0499..e5969c2539566 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -130,3 +130,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}
+
+/// Return `type` if it is a FloatType, or the element type of a VectorType
+/// when that element type is a FloatType; otherwise return a null FloatType.
+static FloatType getFloatingPointType(Type type) {
+ if (auto floatType = dyn_cast<FloatType>(type))
+ return floatType;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return dyn_cast<FloatType>(vecType.getElementType()); // null if not float
+ return nullptr; // neither a float type nor a vector type
+}
+
+bool LLVM::detail::isUnsupportedFloatingPointType(
+ const TypeConverter &typeConverter, Type type) {
+ FloatType floatType = getFloatingPointType(type);
+ if (!floatType)
+ return false; // not a float (or vector-of-float) type: nothing to reject
+ Type convertedType = typeConverter.convertType(floatType);
+ if (!convertedType)
+ return true; // the converter could not map this float type at all
+ return !isa<FloatType>(convertedType); // e.g. unsupported FP -> integer
+}
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 6fdc1104d2609..b53c52d75c0aa 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -770,12 +770,14 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
// CHECK: arith.addf {{.*}} : f4E2M1FN
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
+// CHECK: arith.cmpf {{.*}} : f4E2M1FN
// CHECK: llvm.select {{.*}} : i1, i4
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
- %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+ %3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
+ %4 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
return
}
@@ -785,9 +787,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2
// CHECK: llvm.fadd {{.*}} : f32
// CHECK: llvm.fadd {{.*}} : vector<4xf32>
// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
-func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+// CHECK: llvm.fcmp {{.*}} : f32
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
%0 = arith.addf %arg0, %arg0 : f32
%1 = arith.addf %arg1, %arg1 : vector<4xf32>
%2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
- return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
+ %3 = arith.cmpf oeq, %arg0, %arg3 : f32
+ return
}
More information about the Mlir-commits
mailing list