[Mlir-commits] [mlir] ef7a255 - Add support for casting elements in vectors for certain Std dialect type conversion operations.

Lubomir Litchev llvmlistbot at llvm.org
Mon Sep 14 07:45:57 PDT 2020


Author: Lubomir Litchev
Date: 2020-09-14T07:45:46-07:00
New Revision: ef7a255c037ca462f71ddd3d2b5a46310b08f6eb

URL: https://github.com/llvm/llvm-project/commit/ef7a255c037ca462f71ddd3d2b5a46310b08f6eb
DIFF: https://github.com/llvm/llvm-project/commit/ef7a255c037ca462f71ddd3d2b5a46310b08f6eb.diff

LOG: Add support for casting elements in vectors for certain Std dialect type conversion operations.

Added support to the Std dialect cast operations to do casts in vector types when feasible.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D87410

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index afdc3edae86c..4d0cf76ec9d8 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2443,10 +2443,10 @@ def SignExtendIOp : Std_Op<"sexti",
 def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from integer type to floating-point";
   let description = [{
-    Cast from a value interpreted as signed integer to the corresponding
-    floating-point value. If the value cannot be exactly represented, it is
-    rounded using the default rounding mode. Only scalars are currently
-    supported.
+    Cast from a value interpreted as a signed integer or vector of signed
+    integers to the corresponding floating-point scalar or vector value. If the
+    value cannot be exactly represented, it is rounded using the default
+    rounding mode. Scalars and vector types are currently supported.
   }];
 
   let extraClassDeclaration = [{
@@ -3124,10 +3124,10 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
 def UIToFPOp : CastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from unsigned integer type to floating-point";
   let description = [{
-    Cast from a value interpreted as unsigned integer to the corresponding
-    floating-point value. If the value cannot be exactly represented, it is
-    rounded using the default rounding mode. Only scalars are currently
-    supported.
+    Cast from a value interpreted as an unsigned integer or vector of unsigned
+    integers to the corresponding scalar or vector floating-point value. If the
+    value cannot be exactly represented, it is rounded using the default
+    rounding mode. Scalars and vector types are currently supported.
   }];
 
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index cf085a604b46..c77bc12cca33 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -217,6 +217,26 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
+//===----------------------------------------------------------------------===//
+// Common cast compatibility check for vector types.
+//===----------------------------------------------------------------------===//
+
+/// This method checks for cast compatibility of vector types.
+/// If 'a' and 'b' are both vector types with equal shapes, it calls
+/// the 'areElementsCastCompatible' function on their element types to
+/// check for element cast compatibility.
+/// Returns 'true' if the vector types are cast compatible, and 'false'
+/// otherwise.
+static bool areVectorCastSimpleCompatible(
+    Type a, Type b, function_ref<bool(Type, Type)> areElementsCastCompatible) {
+  if (auto va = a.dyn_cast<VectorType>())
+    if (auto vb = b.dyn_cast<VectorType>())
+      return va.getShape().equals(vb.getShape()) &&
+             areElementsCastCompatible(va.getElementType(),
+                                       vb.getElementType());
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
@@ -1816,11 +1836,7 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() < fb.getWidth();
-  if (auto va = a.dyn_cast<VectorType>())
-    if (auto vb = b.dyn_cast<VectorType>())
-      return va.getShape().equals(vb.getShape()) &&
-             areCastCompatible(va.getElementType(), vb.getElementType());
-  return false;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1828,7 +1844,9 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
 //===----------------------------------------------------------------------===//
 
 bool FPToSIOp::areCastCompatible(Type a, Type b) {
-  return a.isa<FloatType>() && b.isSignlessInteger();
+  if (a.isa<FloatType>() && b.isSignlessInteger())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1836,7 +1854,9 @@ bool FPToSIOp::areCastCompatible(Type a, Type b) {
 //===----------------------------------------------------------------------===//
 
 bool FPToUIOp::areCastCompatible(Type a, Type b) {
-  return a.isa<FloatType>() && b.isSignlessInteger();
+  if (a.isa<FloatType>() && b.isSignlessInteger())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1847,11 +1867,7 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
   if (auto fa = a.dyn_cast<FloatType>())
     if (auto fb = b.dyn_cast<FloatType>())
       return fa.getWidth() > fb.getWidth();
-  if (auto va = a.dyn_cast<VectorType>())
-    if (auto vb = b.dyn_cast<VectorType>())
-      return va.getShape().equals(vb.getShape()) &&
-             areCastCompatible(va.getElementType(), vb.getElementType());
-  return false;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2291,7 +2307,9 @@ OpFoldResult SignedRemIOp::fold(ArrayRef<Attribute> operands) {
 
 // sitofp is applicable from integer types to float types.
 bool SIToFPOp::areCastCompatible(Type a, Type b) {
-  return a.isSignlessInteger() && b.isa<FloatType>();
+  if (a.isSignlessInteger() && b.isa<FloatType>())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2371,7 +2389,9 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
 
 // uitofp is applicable from integer types to float types.
 bool UIToFPOp::areCastCompatible(Type a, Type b) {
-  return a.isSignlessInteger() && b.isa<FloatType>();
+  if (a.isSignlessInteger() && b.isa<FloatType>())
+    return true;
+  return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 62be4783e364..bb0363b1cba5 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -594,6 +594,24 @@ func @sitofp(%arg0 : i32, %arg1 : i64) {
   return
 }
 
+// Checking conversion of integer vectors to floating point vector types.
+// CHECK-LABEL: @sitofp_vector
+func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
+  %0 = sitofp %arg0: vector<2xi16> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
+  %1 = sitofp %arg0: vector<2xi16> to vector<2xf64>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
+  %2 = sitofp %arg1: vector<2xi32> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
+  %3 = sitofp %arg1: vector<2xi32> to vector<2xf64>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
+  %4 = sitofp %arg2: vector<2xi64> to vector<2xf32>
+// CHECK-NEXT: = llvm.sitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
+  %5 = sitofp %arg2: vector<2xi64> to vector<2xf64>
+  return
+}
+
 // Checking conversion of unsigned integer types to floating point.
 // CHECK-LABEL: @uitofp
 func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -646,6 +664,24 @@ func @fptosi(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// Checking conversion of floating point vectors to integer vector types.
+// CHECK-LABEL: @fptosi_vector
+func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
+  %0 = fptosi %arg0: vector<2xf16> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
+  %1 = fptosi %arg0: vector<2xf16> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
+  %2 = fptosi %arg1: vector<2xf32> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
+  %3 = fptosi %arg1: vector<2xf32> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
+  %4 = fptosi %arg2: vector<2xf64> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptosi {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
+  %5 = fptosi %arg2: vector<2xf64> to vector<2xi64>
+  return
+}
+
 // Checking conversion of floating point to integer types.
 // CHECK-LABEL: @fptoui
 func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -660,6 +696,41 @@ func @fptoui(%arg0 : f32, %arg1 : f64) {
   return
 }
 
+// Checking conversion of floating point vectors to integer vector types.
+// CHECK-LABEL: @fptoui_vector
+func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i32>
+  %0 = fptoui %arg0: vector<2xf16> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x half> to !llvm.vec<2 x i64>
+  %1 = fptoui %arg0: vector<2xf16> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i32>
+  %2 = fptoui %arg1: vector<2xf32> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x float> to !llvm.vec<2 x i64>
+  %3 = fptoui %arg1: vector<2xf32> to vector<2xi64>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i32>
+  %4 = fptoui %arg2: vector<2xf64> to vector<2xi32>
+// CHECK-NEXT: = llvm.fptoui {{.*}} : !llvm.vec<2 x double> to !llvm.vec<2 x i64>
+  %5 = fptoui %arg2: vector<2xf64> to vector<2xi64>
+  return
+}
+
+// Checking conversion of integer vectors to floating point vector types.
+// CHECK-LABEL: @uitofp_vector
+func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x float>
+  %0 = uitofp %arg0: vector<2xi16> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i16> to !llvm.vec<2 x double>
+  %1 = uitofp %arg0: vector<2xi16> to vector<2xf64>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x float>
+  %2 = uitofp %arg1: vector<2xi32> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i32> to !llvm.vec<2 x double>
+  %3 = uitofp %arg1: vector<2xi32> to vector<2xf64>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x float>
+  %4 = uitofp %arg2: vector<2xi64> to vector<2xf32>
+// CHECK-NEXT: = llvm.uitofp {{.*}} : !llvm.vec<2 x i64> to !llvm.vec<2 x double>
+  %5 = uitofp %arg2: vector<2xi64> to vector<2xf64>
+  return
+}
 
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc


        


More information about the Mlir-commits mailing list