[Mlir-commits] [mlir] 67c1615 - [MLIR] Add vector support for fpext and fptrunc.

Tim Shen llvmlistbot at llvm.org
Fri Feb 28 12:48:40 PST 2020


Author: Tim Shen
Date: 2020-02-28T12:24:45-08:00
New Revision: 67c1615440ca399109132b95290fc4b948968c9e

URL: https://github.com/llvm/llvm-project/commit/67c1615440ca399109132b95290fc4b948968c9e
DIFF: https://github.com/llvm/llvm-project/commit/67c1615440ca399109132b95290fc4b948968c9e.diff

LOG: [MLIR] Add vector support for fpext and fptrunc.

Differential Revision: https://reviews.llvm.org/D75150

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 65bd714d7881..4d88aa8e99da 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2601,6 +2601,10 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() < fb.getWidth();
+  if (auto va = a.dyn_cast<VectorType>())
+    if (auto vb = b.dyn_cast<VectorType>())
+      return va.getShape().equals(vb.getShape()) &&
+             areCastCompatible(va.getElementType(), vb.getElementType());
   return false;
 }
 
@@ -2612,6 +2616,10 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() > fb.getWidth();
+  if (auto va = a.dyn_cast<VectorType>())
+    if (auto vb = b.dyn_cast<VectorType>())
+      return va.getShape().equals(vb.getShape()) &&
+             areCastCompatible(va.getElementType(), vb.getElementType());
   return false;
 }
 

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 68aeef8a2e1f..660fc79fb310 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -485,6 +485,18 @@ func @fpext(%arg0 : f16, %arg1 : f32) {
   return
 }
 
+// Checking extension of floating point vector types to wider floating point vector types.
+// CHECK-LABEL: @fpext
+func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
+// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x half>"> to !llvm<"<2 x float>">
+  %0 = fpext %arg0: vector<2xf16> to vector<2xf32>
+// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x half>"> to !llvm<"<2 x double>">
+  %1 = fpext %arg0: vector<2xf16> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x double>">
+  %2 = fpext %arg1: vector<2xf32> to vector<2xf64>
+  return
+}
+
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc
 func @fptrunc(%arg0 : f32, %arg1 : f64) {
@@ -497,6 +509,18 @@ func @fptrunc(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// Checking truncation of floating point vector types to narrower floating point vector types.
+// CHECK-LABEL: @fptrunc
+func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x half>">
+  %0 = fptrunc %arg0: vector<2xf32> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x half>">
+  %1 = fptrunc %arg1: vector<2xf64> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x float>">
+  %2 = fptrunc %arg1: vector<2xf64> to vector<2xf32>
+  return
+}
+
 // Check sign and zero extension and truncation of integers.
 // CHECK-LABEL: @integer_extension_and_truncation
 func @integer_extension_and_truncation() {

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index c07931f01f8c..28ef33edba06 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -506,6 +506,12 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32>
   %142 = sqrt %t : tensor<4x4x?xf32>
 
+  // CHECK: = fpext {{.*}} : vector<4xf32> to vector<4xf64>
+  %143 = fpext %vcf32 : vector<4xf32> to vector<4xf64>
+
+  // CHECK: = fptrunc {{.*}} : vector<4xf32> to vector<4xf16>
+  %144 = fptrunc %vcf32 : vector<4xf32> to vector<4xf16>
+
   return
 }
 

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 7cc0331bd484..5b43103e9018 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -563,6 +563,46 @@ func @fpext_f32_to_i32(%arg0 : f32) {
 
 // -----
 
+func @fpext_vec(%arg0 : vector<2xf16>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fpext %arg0 : vector<2xf16> to vector<3xf32>
+  return
+}
+
+// -----
+
+func @fpext_vec_f32_to_f16(%arg0 : vector<2xf32>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fpext %arg0 : vector<2xf32> to vector<2xf16>
+  return
+}
+
+// -----
+
+func @fpext_vec_f16_to_f16(%arg0 : vector<2xf16>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fpext %arg0 : vector<2xf16> to vector<2xf16>
+  return
+}
+
+// -----
+
+func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fpext %arg0 : vector<2xi32> to vector<2xf32>
+  return
+}
+
+// -----
+
+func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fpext %arg0 : vector<2xf32> to vector<2xi32>
+  return
+}
+
+// -----
+
 func @fptrunc_f16_to_f32(%arg0 : f16) {
   // expected-error@+1 {{are cast incompatible}}
   %0 = fptrunc %arg0 : f16 to f32
@@ -595,6 +635,46 @@ func @fptrunc_f32_to_i32(%arg0 : f32) {
 
 // -----
 
+func @fptrunc_vec(%arg0 : vector<2xf16>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32>
+  return
+}
+
+// -----
+
+func @fptrunc_vec_f16_to_f32(%arg0 : vector<2xf16>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fptrunc %arg0 : vector<2xf16> to vector<2xf32>
+  return
+}
+
+// -----
+
+func @fptrunc_vec_f32_to_f32(%arg0 : vector<2xf32>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fptrunc %arg0 : vector<2xf32> to vector<2xf32>
+  return
+}
+
+// -----
+
+func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fptrunc %arg0 : vector<2xi32> to vector<2xf32>
+  return
+}
+
+// -----
+
+func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
+  // expected-error@+1 {{are cast incompatible}}
+  %0 = fptrunc %arg0 : vector<2xf32> to vector<2xi32>
+  return
+}
+
+// -----
+
 func @sexti_index_as_operand(%arg0 : index) {
   // expected-error@+1 {{'index' is not a valid operand type}}
   %0 = sexti %arg0 : index to i128


        


More information about the Mlir-commits mailing list