[Mlir-commits] [mlir] 6c640b8 - [mlir][LLVM] Fix unsupported FP lowering in `VectorConvertToLLVMPattern` (#166513)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 5 04:25:03 PST 2025


Author: Matthias Springer
Date: 2025-11-05T12:24:59Z
New Revision: 6c640b86e6e03298385231cb7e77d2f3524bc643

URL: https://github.com/llvm/llvm-project/commit/6c640b86e6e03298385231cb7e77d2f3524bc643
DIFF: https://github.com/llvm/llvm-project/commit/6c640b86e6e03298385231cb7e77d2f3524bc643.diff

LOG: [mlir][LLVM] Fix unsupported FP lowering in `VectorConvertToLLVMPattern` (#166513)

Fixes a bug in `VectorConvertToLLVMPattern`, which converted operations
with unsupported FP types. E.g., `arith.addf ... : f4E2M1FN` was lowered
to `llvm.fadd ... : i4`, which does not verify. There are a few more
patterns that have the same bug. Those will be fixed in follow-up PRs.

This commit is in preparation for adding an `APFloat`-based lowering for
`arith` operations with unsupported floating-point types.

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 964281592cc65..cad6cec761ab8 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -92,12 +92,43 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
+  /// Return `type` itself if it is a floating-point type, or the element type
+  /// of `type` if it is a vector of floats. Otherwise, return a null type.
+  static FloatType getFloatingPointType(Type type) {
+    if (auto floatType = dyn_cast<FloatType>(type))
+      return floatType;
+    if (auto vecType = dyn_cast<VectorType>(type))
+      return dyn_cast<FloatType>(vecType.getElementType());
+    return nullptr;
+  }
+
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     static_assert(
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
+
+    // The pattern should not apply if a floating-point operand is converted to
+    // a non-floating-point type. This indicates that the floating point type
+    // is not supported by the LLVM lowering. (Such types are converted to
+    // integers.)
+    auto checkType = [&](Value v) -> LogicalResult {
+      FloatType floatType = getFloatingPointType(v.getType());
+      if (!floatType)
+        return success();
+      Type convertedType = this->getTypeConverter()->convertType(floatType);
+      if (!isa_and_nonnull<FloatType>(convertedType))
+        return rewriter.notifyMatchFailure(op,
+                                           "unsupported floating point type");
+      return success();
+    };
+    for (Value operand : op->getOperands())
+      if (failed(checkType(operand)))
+        return failure();
+    if (failed(checkType(op->getResult(0))))
+      return failure();
+
     // Determine attributes for the target op
     AttrConvert<SourceOp, TargetOp> attrConvert(op);
 

diff  --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index ba12ff29ebef9..b5dcb01d3dc6b 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -747,3 +747,29 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
   %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
   func.return %2 : memref<?xbf16>
 }
+
+// -----
+
+// CHECK-LABEL: func @unsupported_fp_type
+//       CHECK:   arith.addf {{.*}} : f4E2M1FN
+//       CHECK:   arith.addf {{.*}} : vector<4xf4E2M1FN>
+//       CHECK:   arith.addf {{.*}} : vector<8x4xf4E2M1FN>
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+  %0 = arith.addf %arg0, %arg0 : f4E2M1FN  // Scalar: f4E2M1FN would become i4, so this must stay unconverted.
+  %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>  // 1-D vector of the unsupported FP type.
+  %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>  // 2-D vector of the unsupported FP type.
+  return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+}
+
+// -----
+
+//   CHECK-LABEL: func @supported_fp_type
+//         CHECK:   llvm.fadd {{.*}} : f32
+//         CHECK:   llvm.fadd {{.*}} : vector<4xf32>
+// CHECK-COUNT-4:   llvm.fadd {{.*}} : vector<8xf32>
+func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
+  %0 = arith.addf %arg0, %arg0 : f32  // Scalar: lowers to a single llvm.fadd on f32.
+  %1 = arith.addf %arg1, %arg1 : vector<4xf32>  // 1-D vector: lowers to a single llvm.fadd.
+  %2 = arith.addf %arg2, %arg2 : vector<4x8xf32>  // 2-D vector: expected as 4 llvm.fadd ops on vector<8xf32>.
+  return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
+}


        


More information about the Mlir-commits mailing list