[Mlir-commits] [mlir] [mlir][LLVM] Fix conversion of MLIR float types (PR #122634)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 12 02:29:08 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Certain non-standard float types were directly passed through in the LLVM type converter, resulting in invalid IR or failed assertions:
```
mlir-opt: mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp:638: FailureOr<Type> mlir::LLVMTypeConverter::convertVectorType(VectorType) const: Assertion `LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"' failed.
```
The type converter should not define a type conversion rule for such types. As a result, conversion patterns will not apply to ops with such operand types.
---
Full diff: https://github.com/llvm/llvm-project/pull/122634.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+9-1)
- (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+45)
- (modified) mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir (+8)
``````````diff
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 72799e42cf3fd1..64bdb248dff430 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -294,13 +294,21 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
}
Type LLVMTypeConverter::convertFloatType(FloatType type) const {
+ // Float types that are already LLVM-compatible are returned unchanged.
+ if (LLVM::isCompatibleType(type))
+ return type;
+
+ // Remaining F4, F6 and F8 types are converted to same-width integer types.
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
type.isFloat8E8M0FNU())
return IntegerType::get(&getContext(), type.getWidth());
- return type;
+
+ // Other floating-point types have no default conversion; returning a null
+ // Type() means the user must specify a custom type conversion rule.
+ return Type();
}
// Convert a `ComplexType` to an LLVM type. The result is a complex number
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index a9dcc0a16b3dbd..1dabacfd8a47cc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -37,6 +37,8 @@ func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4
return %1 : vector<4xf32>
}
+// -----
+
// CHECK-LABEL: @ops
func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
@@ -84,9 +86,14 @@ func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
%20 = arith.shrsi %arg2, %arg3 : i32
// CHECK: = llvm.lshr %arg2, %arg3 : i32
%21 = arith.shrui %arg2, %arg3 : i32
+// CHECK: arith.constant 2.000000e+00 : tf32
+ // There is no type conversion rule for tf32.
+ %22 = arith.constant 2.0 : tf32
return %0, %10 : f32, i32
}
+// -----
+
// Checking conversion of index types to integers using i1, assuming no target
// system would have a 1-bit address space. Otherwise, we would have had to
// make this test dependent on the pointer size on the target system.
@@ -99,6 +106,8 @@ func.func @index_cast(%arg0: index, %arg1: i1) {
return
}
+// -----
+
// CHECK-LABEL: @vector_index_cast
func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -108,6 +117,8 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
return
}
+// -----
+
func.func @index_castui(%arg0: index, %arg1: i1) {
// CHECK: = llvm.trunc %0 : i{{.*}} to i1
%0 = arith.index_castui %arg0: index to i1
@@ -116,6 +127,8 @@ func.func @index_castui(%arg0: index, %arg1: i1) {
return
}
+// -----
+
// CHECK-LABEL: @vector_index_castui
func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -125,6 +138,8 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
return
}
+// -----
+
// Checking conversion of signed integer types to floating point.
// CHECK-LABEL: @sitofp
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
@@ -139,6 +154,8 @@ func.func @sitofp(%arg0 : i32, %arg1 : i64) {
return
}
+// -----
+
// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @sitofp_vector
func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -157,6 +174,8 @@ func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of unsigned integer types to floating point.
// CHECK-LABEL: @uitofp
func.func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -171,6 +190,8 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext(%arg0 : f16, %arg1 : f32) {
@@ -183,6 +204,8 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
@@ -195,6 +218,8 @@ func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
return
}
+// -----
+
// Checking conversion of floating point to integer types.
// CHECK-LABEL: @fptosi
func.func @fptosi(%arg0 : f32, %arg1 : f64) {
@@ -209,6 +234,8 @@ func.func @fptosi(%arg0 : f32, %arg1 : f64) {
return
}
+// -----
+
// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptosi_vector
func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -227,6 +254,8 @@ func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of floating point to integer types.
// CHECK-LABEL: @fptoui
func.func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -241,6 +270,8 @@ func.func @fptoui(%arg0 : f32, %arg1 : f64) {
return
}
+// -----
+
// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptoui_vector
func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -259,6 +290,8 @@ func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @uitofp_vector
func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -277,6 +310,8 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
@@ -289,6 +324,8 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
@@ -301,6 +338,8 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
return
}
+// -----
+
// CHECK-LABEL: experimental_constrained_fptrunc
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
@@ -316,6 +355,8 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
return
}
+// -----
+
// Check sign and zero extension and truncation of integers.
// CHECK-LABEL: @integer_extension_and_truncation
func.func @integer_extension_and_truncation(%arg0 : i3) {
@@ -328,6 +369,8 @@ func.func @integer_extension_and_truncation(%arg0 : i3) {
return
}
+// -----
+
// CHECK-LABEL: @integer_cast_0d_vector
func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
@@ -340,6 +383,8 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
return
}
+// -----
+
// CHECK-LABEL: func @fcmp(%arg0: f32, %arg1: f32) {
func.func @fcmp(f32, f32) -> () {
^bb0(%arg0: f32, %arg1: f32):
diff --git a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
index 8396e5ad8ade15..22ac6eae73f534 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
@@ -555,6 +555,14 @@ func.func @index_arg(%arg0: index) -> index {
return %arg1 : index
}
+// There is no type conversion rule for tf32, so vector<1xtf32> cannot be
+// converted: the func op keeps its signature; only the return is lowered.
+// CHECK: func.func @non_convertible_arg_type({{.*}}: vector<1xtf32>)
+// CHECK: llvm.return
+func.func @non_convertible_arg_type(%arg: vector<1xtf32>) {
+ return
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
``````````
</details>
https://github.com/llvm/llvm-project/pull/122634
More information about the Mlir-commits
mailing list