[Mlir-commits] [mlir] 67c1615 - [MLIR] Add vector support for fpexp and fptrunc.
Tim Shen
llvmlistbot at llvm.org
Fri Feb 28 12:48:40 PST 2020
Author: Tim Shen
Date: 2020-02-28T12:24:45-08:00
New Revision: 67c1615440ca399109132b95290fc4b948968c9e
URL: https://github.com/llvm/llvm-project/commit/67c1615440ca399109132b95290fc4b948968c9e
DIFF: https://github.com/llvm/llvm-project/commit/67c1615440ca399109132b95290fc4b948968c9e.diff
LOG: [MLIR] Add vector support for fpexp and fptrunc.
Differential Revision: https://reviews.llvm.org/D75150
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 65bd714d7881..4d88aa8e99da 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2601,6 +2601,10 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
if (auto fa = a.dyn_cast<FloatType>())
if (auto fb = b.dyn_cast<FloatType>())
return fa.getWidth() < fb.getWidth();
+ if (auto va = a.dyn_cast<VectorType>()) // Same shape; elements must widen.
+ if (auto vb = b.dyn_cast<VectorType>())
+ return va.getShape().equals(vb.getShape()) &&
+ areCastCompatible(va.getElementType(), vb.getElementType());
return false;
}
@@ -2612,6 +2616,10 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
if (auto fa = a.dyn_cast<FloatType>())
if (auto fb = b.dyn_cast<FloatType>())
return fa.getWidth() > fb.getWidth();
+ if (auto va = a.dyn_cast<VectorType>()) // Same shape; elements must narrow.
+ if (auto vb = b.dyn_cast<VectorType>())
+ return va.getShape().equals(vb.getShape()) &&
+ areCastCompatible(va.getElementType(), vb.getElementType());
return false;
}
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 68aeef8a2e1f..660fc79fb310 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -485,6 +485,18 @@ func @fpext(%arg0 : f16, %arg1 : f32) {
return
}
+// Checking extension of floating point vectors to wider element types.
+// CHECK-LABEL: @fpext_vector
+func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
+// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x half>"> to !llvm<"<2 x float>">
+ %0 = fpext %arg0: vector<2xf16> to vector<2xf32>
+// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x half>"> to !llvm<"<2 x double>">
+ %1 = fpext %arg0: vector<2xf16> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x double>">
+ %2 = fpext %arg1: vector<2xf32> to vector<2xf64>
+ return
+}
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func @fptrunc(%arg0 : f32, %arg1 : f64) {
@@ -497,6 +509,18 @@ func @fptrunc(%arg0 : f32, %arg1 : f64) {
return
}
+// Checking truncation of floating point vectors to narrower element types.
+// CHECK-LABEL: @fptrunc_vector
+func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x half>">
+ %0 = fptrunc %arg0: vector<2xf32> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x half>">
+ %1 = fptrunc %arg1: vector<2xf64> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x float>">
+ %2 = fptrunc %arg1: vector<2xf64> to vector<2xf32>
+ return
+}
+
// Check sign and zero extension and truncation of integers.
// CHECK-LABEL: @integer_extension_and_truncation
func @integer_extension_and_truncation() {
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index c07931f01f8c..28ef33edba06 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -506,6 +506,12 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32>
%142 = sqrt %t : tensor<4x4x?xf32>
+ // CHECK: %{{[0-9]+}} = fpext %{{.*}} : vector<4xf32> to vector<4xf64>
+ %143 = fpext %vcf32 : vector<4xf32> to vector<4xf64>
+
+ // CHECK: %{{[0-9]+}} = fptrunc %{{.*}} : vector<4xf32> to vector<4xf16>
+ %144 = fptrunc %vcf32 : vector<4xf32> to vector<4xf16>
+
return
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 7cc0331bd484..5b43103e9018 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -563,6 +563,46 @@ func @fpext_f32_to_i32(%arg0 : f32) {
// -----
+func @fpext_vec(%arg0 : vector<2xf16>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fpext %arg0 : vector<2xf16> to vector<3xf32>
+ return
+}
+
+// -----
+
+func @fpext_vec_f32_to_f16(%arg0 : vector<2xf32>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fpext %arg0 : vector<2xf32> to vector<2xf16>
+ return
+}
+
+// -----
+
+func @fpext_vec_f16_to_f16(%arg0 : vector<2xf16>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fpext %arg0 : vector<2xf16> to vector<2xf16>
+ return
+}
+
+// -----
+
+func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fpext %arg0 : vector<2xi32> to vector<2xf32>
+ return
+}
+
+// -----
+
+func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fpext %arg0 : vector<2xf32> to vector<2xi32>
+ return
+}
+
+// -----
+
func @fptrunc_f16_to_f32(%arg0 : f16) {
// expected-error@+1 {{are cast incompatible}}
%0 = fptrunc %arg0 : f16 to f32
@@ -595,6 +635,46 @@ func @fptrunc_f32_to_i32(%arg0 : f32) {
// -----
+func @fptrunc_vec(%arg0 : vector<2xf32>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fptrunc %arg0 : vector<2xf32> to vector<3xf16>
+ return
+}
+
+// -----
+
+func @fptrunc_vec_f16_to_f32(%arg0 : vector<2xf16>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fptrunc %arg0 : vector<2xf16> to vector<2xf32>
+ return
+}
+
+// -----
+
+func @fptrunc_vec_f32_to_f32(%arg0 : vector<2xf32>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fptrunc %arg0 : vector<2xf32> to vector<2xf32>
+ return
+}
+
+// -----
+
+func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fptrunc %arg0 : vector<2xi32> to vector<2xf32>
+ return
+}
+
+// -----
+
+func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
+ // expected-error@+1 {{are cast incompatible}}
+ %0 = fptrunc %arg0 : vector<2xf32> to vector<2xi32>
+ return
+}
+
+// -----
+
func @sexti_index_as_operand(%arg0 : index) {
// expected-error@+1 {{'index' is not a valid operand type}}
%0 = sexti %arg0 : index to i128
More information about the Mlir-commits
mailing list