[Mlir-commits] [mlir] [mlir][LLVM] Fix unsupported FP lowering in `VectorConvertToLLVMPattern` (PR #166513)
Matthias Springer
llvmlistbot at llvm.org
Tue Nov 4 23:10:31 PST 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/166513
Fixes a bug in `VectorConvertToLLVMPattern`: the pattern also converted operations whose floating-point types are not supported by the LLVM lowering. For example, `arith.addf ... : f4E2M1FN` was lowered to `llvm.fadd ... : i4`, which fails verification.
From 6eb0454c03888a79a4cd33f30fa8b3026f971439 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 5 Nov 2025 07:08:47 +0000
Subject: [PATCH] [mlir][LLVM] Fix unsupported FP lowering in
`VectorConvertToLLVMPattern`
---
.../Conversion/LLVMCommon/VectorPattern.h | 25 ++++++++++++++++++
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 26 +++++++++++++++++++
2 files changed, 51 insertions(+)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 964281592cc65..d8483114f1137 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -92,12 +92,37 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
+ /// Return `type` if it is a FloatType, or its element type if `type` is a
+ /// VectorType of floats. Return a null FloatType in every other case.
+ static FloatType getFloatingPointType(Type type) {
+ if (auto floatType = dyn_cast<FloatType>(type))
+ return floatType;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return dyn_cast<FloatType>(vecType.getElementType());
+ return nullptr;
+ }
+
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");
+
+ // The pattern should not apply if a floating-point operand is converted to
+ // a non-floating-point type. This indicates that the floating point type
+ // is not supported by the LLVM lowering. (Such types are converted to
+ // integers.)
+ for (Value operand : op->getOperands()) {
+ FloatType floatType = getFloatingPointType(operand.getType());
+ if (!floatType)
+ continue;
+ Type convertedType = this->getTypeConverter()->convertType(floatType);
+ if (!isa<FloatType>(convertedType))
+ return rewriter.notifyMatchFailure(op,
+ "unsupported floating point type");
+ }
+
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index ba12ff29ebef9..b5dcb01d3dc6b 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -747,3 +747,29 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
%2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
func.return %2 : memref<?xbf16>
}
+
+// -----
+
+// CHECK-LABEL: func @unsupported_fp_type
+// CHECK: arith.addf {{.*}} : f4E2M1FN
+// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
+// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
+func.func @unsupported_fp_type(%scalar: f4E2M1FN, %vec: vector<4xf4E2M1FN>, %mat: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+ %scalar_sum = arith.addf %scalar, %scalar : f4E2M1FN
+ %vec_sum = arith.addf %vec, %vec : vector<4xf4E2M1FN>
+ %mat_sum = arith.addf %mat, %mat : vector<8x4xf4E2M1FN>
+ func.return %scalar_sum, %vec_sum, %mat_sum : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+}
+
+// -----
+
+// CHECK-LABEL: func @supported_fp_type
+// CHECK: llvm.fadd {{.*}} : f32
+// CHECK: llvm.fadd {{.*}} : vector<4xf32>
+// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
+func.func @supported_fp_type(%scalar: f32, %vec: vector<4xf32>, %mat: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+ %scalar_sum = arith.addf %scalar, %scalar : f32
+ %vec_sum = arith.addf %vec, %vec : vector<4xf32>
+ %mat_sum = arith.addf %mat, %mat : vector<4x8xf32>
+ func.return %scalar_sum, %vec_sum, %mat_sum : f32, vector<4xf32>, vector<4x8xf32>
+}
More information about the Mlir-commits
mailing list