[Mlir-commits] [mlir] [mlir][LLVM] Fix conversion of MLIR float types (PR #122634)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 12 02:29:08 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Certain non-standard float types were passed through unchanged by the LLVM type converter, resulting in invalid IR or failed assertions:

```
mlir-opt: mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp:638: FailureOr<Type> mlir::LLVMTypeConverter::convertVectorType(VectorType) const: Assertion `LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"' failed.
```

The type converter should not define a type conversion rule for such types. Conversion patterns will not apply to ops with such operand types; users who need these types converted must specify a custom type conversion rule.


---
Full diff: https://github.com/llvm/llvm-project/pull/122634.diff


3 Files Affected:

- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+9-1) 
- (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+45) 
- (modified) mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir (+8) 


``````````diff
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 72799e42cf3fd1..64bdb248dff430 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -294,13 +294,21 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
 }
 
 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
+  // Valid LLVM float types are used directly.
+  if (LLVM::isCompatibleType(type))
+    return type;
+
+  // F4, F6, F8 types are converted to integer types with the same bit width.
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
       type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
       type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
       type.isFloat8E8M0FNU())
     return IntegerType::get(&getContext(), type.getWidth());
-  return type;
+
+  // Other floating-point types: A custom type conversion rule must be
+  // specified by the user.
+  return Type();
 }
 
 // Convert a `ComplexType` to an LLVM type. The result is a complex number
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index a9dcc0a16b3dbd..1dabacfd8a47cc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -37,6 +37,8 @@ func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4
   return %1 : vector<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @ops
 func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
 ^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
@@ -84,9 +86,14 @@ func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
   %20 = arith.shrsi %arg2, %arg3 : i32
 // CHECK: = llvm.lshr %arg2, %arg3 : i32
   %21 = arith.shrui %arg2, %arg3 : i32
+// CHECK: arith.constant 2.000000e+00 : tf32
+  // There is no type conversion rule for tf32.
+  %22 = arith.constant 2.0 : tf32
   return %0, %10 : f32, i32
 }
 
+// -----
+
 // Checking conversion of index types to integers using i1, assuming no target
 // system would have a 1-bit address space.  Otherwise, we would have had to
 // make this test dependent on the pointer size on the target system.
@@ -99,6 +106,8 @@ func.func @index_cast(%arg0: index, %arg1: i1) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @vector_index_cast
 func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
 // CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -108,6 +117,8 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
   return
 }
 
+// -----
+
 func.func @index_castui(%arg0: index, %arg1: i1) {
 // CHECK: = llvm.trunc %0 : i{{.*}} to i1
   %0 = arith.index_castui %arg0: index to i1
@@ -116,6 +127,8 @@ func.func @index_castui(%arg0: index, %arg1: i1) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @vector_index_castui
 func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
 // CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -125,6 +138,8 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
   return
 }
 
+// -----
+
 // Checking conversion of signed integer types to floating point.
 // CHECK-LABEL: @sitofp
 func.func @sitofp(%arg0 : i32, %arg1 : i64) {
@@ -139,6 +154,8 @@ func.func @sitofp(%arg0 : i32, %arg1 : i64) {
   return
 }
 
+// -----
+
 // Checking conversion of integer vectors to floating point vector types.
 // CHECK-LABEL: @sitofp_vector
 func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -157,6 +174,8 @@ func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
   return
 }
 
+// -----
+
 // Checking conversion of unsigned integer types to floating point.
 // CHECK-LABEL: @uitofp
 func.func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -171,6 +190,8 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
   return
 }
 
+// -----
+
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fpext
 func.func @fpext(%arg0 : f16, %arg1 : f32) {
@@ -183,6 +204,8 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
   return
 }
 
+// -----
+
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fpext
 func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
@@ -195,6 +218,8 @@ func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
   return
 }
 
+// -----
+
 // Checking conversion of floating point to integer types.
 // CHECK-LABEL: @fptosi
 func.func @fptosi(%arg0 : f32, %arg1 : f64) {
@@ -209,6 +234,8 @@ func.func @fptosi(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// -----
+
 // Checking conversion of floating point vectors to integer vector types.
 // CHECK-LABEL: @fptosi_vector
 func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -227,6 +254,8 @@ func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
   return
 }
 
+// -----
+
 // Checking conversion of floating point to integer types.
 // CHECK-LABEL: @fptoui
 func.func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -241,6 +270,8 @@ func.func @fptoui(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// -----
+
 // Checking conversion of floating point vectors to integer vector types.
 // CHECK-LABEL: @fptoui_vector
 func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -259,6 +290,8 @@ func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
   return
 }
 
+// -----
+
 // Checking conversion of integer vectors to floating point vector types.
 // CHECK-LABEL: @uitofp_vector
 func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -277,6 +310,8 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
   return
 }
 
+// -----
+
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc
 func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
@@ -289,6 +324,8 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// -----
+
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc
 func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
@@ -301,6 +338,8 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: experimental_constrained_fptrunc
 func.func @experimental_constrained_fptrunc(%arg0 : f64) {
 // CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
@@ -316,6 +355,8 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
   return
 }
 
+// -----
+
 // Check sign and zero extension and truncation of integers.
 // CHECK-LABEL: @integer_extension_and_truncation
 func.func @integer_extension_and_truncation(%arg0 : i3) {
@@ -328,6 +369,8 @@ func.func @integer_extension_and_truncation(%arg0 : i3) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @integer_cast_0d_vector
 func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
 // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
@@ -340,6 +383,8 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @fcmp(%arg0: f32, %arg1: f32) {
 func.func @fcmp(f32, f32) -> () {
 ^bb0(%arg0: f32, %arg1: f32):
diff --git a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
index 8396e5ad8ade15..22ac6eae73f534 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
@@ -555,6 +555,14 @@ func.func @index_arg(%arg0: index) -> index {
   return %arg1 : index
 }
 
+// There is no type conversion rule for tf32, so vector<1xtf32> and, therefore,
+// the func op cannot be converted.
+// CHECK: func.func @non_convertible_arg_type({{.*}}: vector<1xtf32>)
+// CHECK:   llvm.return
+func.func @non_convertible_arg_type(%arg: vector<1xtf32>) {
+  return
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %toplevel_module

``````````

</details>


https://github.com/llvm/llvm-project/pull/122634


More information about the Mlir-commits mailing list