[Mlir-commits] [mlir] 6422546 - [mlir][LLVM] Fix conversion of non-standard MLIR float types (#122634)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 12 06:17:16 PST 2025


Author: Matthias Springer
Date: 2025-01-12T15:17:12+01:00
New Revision: 6422546e996c769dda39a681da090fe28870a376

URL: https://github.com/llvm/llvm-project/commit/6422546e996c769dda39a681da090fe28870a376
DIFF: https://github.com/llvm/llvm-project/commit/6422546e996c769dda39a681da090fe28870a376.diff

LOG: [mlir][LLVM] Fix conversion of non-standard MLIR float types (#122634)

Certain non-standard float types (e.g., `tf32`) were passed through unchanged
by the LLVM type converter, resulting in invalid IR or failed assertions:

```
mlir-opt: mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp:638: FailureOr<Type> mlir::LLVMTypeConverter::convertVectorType(VectorType) const: Assertion `LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"' failed.
```

The LLVM type converter should not define invalid type conversion rules
for such types. If there is no type conversion rule, conversion patterns
will not apply to ops with such operand types, and those ops are left
unconverted unless the user registers a custom type conversion rule.

Added: 
    

Modified: 
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
    mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 72799e42cf3fd1..64bdb248dff430 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -294,13 +294,21 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
 }
 
 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
+  // Valid LLVM float types are used directly.
+  if (LLVM::isCompatibleType(type))
+    return type;
+
+  // F4, F6, F8 types are converted to integer types with the same bit width.
   if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
       type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
       type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
       type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
       type.isFloat8E8M0FNU())
     return IntegerType::get(&getContext(), type.getWidth());
-  return type;
+
+  // Other floating-point types: A custom type conversion rule must be
+  // specified by the user.
+  return Type();
 }
 
 // Convert a `ComplexType` to an LLVM type. The result is a complex number

diff  --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index a9dcc0a16b3dbd..1dabacfd8a47cc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -37,6 +37,8 @@ func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4
   return %1 : vector<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: @ops
 func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
 ^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
@@ -84,9 +86,14 @@ func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
   %20 = arith.shrsi %arg2, %arg3 : i32
 // CHECK: = llvm.lshr %arg2, %arg3 : i32
   %21 = arith.shrui %arg2, %arg3 : i32
+// CHECK: arith.constant 2.000000e+00 : tf32
+  // There is no type conversion rule for tf32.
+  %22 = arith.constant 2.0 : tf32
   return %0, %10 : f32, i32
 }
 
+// -----
+
 // Checking conversion of index types to integers using i1, assuming no target
 // system would have a 1-bit address space.  Otherwise, we would have had to
 // make this test dependent on the pointer size on the target system.
@@ -99,6 +106,8 @@ func.func @index_cast(%arg0: index, %arg1: i1) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @vector_index_cast
 func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
 // CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -108,6 +117,8 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
   return
 }
 
+// -----
+
 func.func @index_castui(%arg0: index, %arg1: i1) {
 // CHECK: = llvm.trunc %0 : i{{.*}} to i1
   %0 = arith.index_castui %arg0: index to i1
@@ -116,6 +127,8 @@ func.func @index_castui(%arg0: index, %arg1: i1) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @vector_index_castui
 func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
 // CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -125,6 +138,8 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
   return
 }
 
+// -----
+
 // Checking conversion of signed integer types to floating point.
 // CHECK-LABEL: @sitofp
 func.func @sitofp(%arg0 : i32, %arg1 : i64) {
@@ -139,6 +154,8 @@ func.func @sitofp(%arg0 : i32, %arg1 : i64) {
   return
 }
 
+// -----
+
 // Checking conversion of integer vectors to floating point vector types.
 // CHECK-LABEL: @sitofp_vector
 func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -157,6 +174,8 @@ func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
   return
 }
 
+// -----
+
 // Checking conversion of unsigned integer types to floating point.
 // CHECK-LABEL: @uitofp
 func.func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -171,6 +190,8 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
   return
 }
 
+// -----
+
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fpext
 func.func @fpext(%arg0 : f16, %arg1 : f32) {
@@ -183,6 +204,8 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
   return
 }
 
+// -----
+
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fpext
 func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
@@ -195,6 +218,8 @@ func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
   return
 }
 
+// -----
+
 // Checking conversion of floating point to integer types.
 // CHECK-LABEL: @fptosi
 func.func @fptosi(%arg0 : f32, %arg1 : f64) {
@@ -209,6 +234,8 @@ func.func @fptosi(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// -----
+
 // Checking conversion of floating point vectors to integer vector types.
 // CHECK-LABEL: @fptosi_vector
 func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -227,6 +254,8 @@ func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
   return
 }
 
+// -----
+
 // Checking conversion of floating point to integer types.
 // CHECK-LABEL: @fptoui
 func.func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -241,6 +270,8 @@ func.func @fptoui(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// -----
+
 // Checking conversion of floating point vectors to integer vector types.
 // CHECK-LABEL: @fptoui_vector
 func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -259,6 +290,8 @@ func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
   return
 }
 
+// -----
+
 // Checking conversion of integer vectors to floating point vector types.
 // CHECK-LABEL: @uitofp_vector
 func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -277,6 +310,8 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
   return
 }
 
+// -----
+
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc
 func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
@@ -289,6 +324,8 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// -----
+
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc
 func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
@@ -301,6 +338,8 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: experimental_constrained_fptrunc
 func.func @experimental_constrained_fptrunc(%arg0 : f64) {
 // CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
@@ -316,6 +355,8 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
   return
 }
 
+// -----
+
 // Check sign and zero extension and truncation of integers.
 // CHECK-LABEL: @integer_extension_and_truncation
 func.func @integer_extension_and_truncation(%arg0 : i3) {
@@ -328,6 +369,8 @@ func.func @integer_extension_and_truncation(%arg0 : i3) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: @integer_cast_0d_vector
 func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
 // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
@@ -340,6 +383,8 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @fcmp(%arg0: f32, %arg1: f32) {
 func.func @fcmp(f32, f32) -> () {
 ^bb0(%arg0: f32, %arg1: f32):

diff  --git a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
index 8396e5ad8ade15..22ac6eae73f534 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
@@ -555,6 +555,14 @@ func.func @index_arg(%arg0: index) -> index {
   return %arg1 : index
 }
 
+// There is no type conversion rule for tf32, so vector<1xtf32> and, therefore,
+// the func op cannot be converted.
+// CHECK: func.func @non_convertible_arg_type({{.*}}: vector<1xtf32>)
+// CHECK:   llvm.return
+func.func @non_convertible_arg_type(%arg: vector<1xtf32>) {
+  return
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %toplevel_module


        


More information about the Mlir-commits mailing list