[Mlir-commits] [mlir] c97035c - [mlir][math][spirv] Add `math.roundeven` lowering to SPIR-V

Jakub Kuderski llvmlistbot at llvm.org
Tue Nov 1 09:51:48 PDT 2022


Author: Jakub Kuderski
Date: 2022-11-01T12:51:25-04:00
New Revision: c97035c49d941e5b196a938f0393f811d1adbd57

URL: https://github.com/llvm/llvm-project/commit/c97035c49d941e5b196a938f0393f811d1adbd57
DIFF: https://github.com/llvm/llvm-project/commit/c97035c49d941e5b196a938f0393f811d1adbd57.diff

LOG: [mlir][math][spirv] Add `math.roundeven` lowering to SPIR-V

This has two lowering paths, one for each extended instruction set:
-  to OpenGL's `RoundEven`,
-  to OpenCL's `rint`.

Implement those two ops and add minimal tests.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D137171

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
    mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
    mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
    mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index 187622e55d35c..1e7ca2b97a0bd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -486,6 +486,39 @@ def SPIRV_CLRoundOp : SPIRV_CLUnaryArithmeticOp<"round", 55, SPIRV_Float> {
 
 // -----
 
+def SPIRV_CLRintOp : SPIRV_CLUnaryArithmeticOp<"rint", 53, SPIRV_Float> {
+  let summary = [{
+    Round x to integral value (using round to nearest even rounding mode) in
+    floating-point format.
+  }];
+
+  let description = [{
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    rint-op ::= ssa-id `=` `spirv.CL.rint` ssa-use `:`
+                float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.CL.rint %0 : f32
+    %3 = spirv.CL.rint %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPIRV_CLRsqrtOp : SPIRV_CLUnaryArithmeticOp<"rsqrt", 56, SPIRV_Float> {
   let summary = "Compute inverse square root of x.";
 
@@ -688,6 +721,8 @@ def SPIRV_CLUMaxOp : SPIRV_CLBinaryArithmeticOp<"u_max", 157, SPIRV_Integer> {
   }];
 }
 
+// -----
+
 def SPIRV_CLSMinOp : SPIRV_CLBinaryArithmeticOp<"s_min", 158, SPIRV_Integer> {
   let summary = "Return minimum of two signed integer operands";
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
index 3dd5219592b87..377eae3858e37 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
@@ -456,10 +456,13 @@ def SPIRV_GLFloorOp : SPIRV_GLUnaryArithmeticOp<"Floor", 8, SPIRV_Float> {
 // -----
 
 def SPIRV_GLRoundOp: SPIRV_GLUnaryArithmeticOp<"Round", 1, SPIRV_Float> {
-  let summary = "Rounds to the whole number";
+  let summary = "Rounds to the nearest whole number";
 
   let description = [{
-    Result is the value equal to the nearest whole number.
+    Result is the value equal to the nearest whole number to x. The fraction
+    0.5 will round in a direction chosen by the implementation, presumably
+    the direction that is fastest. This includes the possibility that
+    Round x is the same value as RoundEven x for all values of x.
 
     The operand x must be a scalar or vector whose component type is
     floating-point.
@@ -471,7 +474,7 @@ def SPIRV_GLRoundOp: SPIRV_GLUnaryArithmeticOp<"Round", 1, SPIRV_Float> {
     ```
     float-scalar-vector-type ::= float-type |
                                  `vector<` integer-literal `x` float-type `>`
-    floor-op ::= ssa-id `=` `spirv.GL.Round` ssa-use `:`
+    round-op ::= ssa-id `=` `spirv.GL.Round` ssa-use `:`
                 float-scalar-vector-type
     ```
     #### Example:
@@ -485,6 +488,38 @@ def SPIRV_GLRoundOp: SPIRV_GLUnaryArithmeticOp<"Round", 1, SPIRV_Float> {
 
 // -----
 
+def SPIRV_GLRoundEvenOp: SPIRV_GLUnaryArithmeticOp<"RoundEven", 2, SPIRV_Float> {
+  let summary = "Rounds to the nearest whole number, rounding ties to even";
+
+  let description = [{
+    Result is the value equal to the nearest whole number to x. A fractional
+    part of 0.5 will round toward the nearest even whole number. (Both 3.5 and
+    4.5 for x will be 4.0.)
+
+    The operand x must be a scalar or vector whose component type is
+    floating-point.
+
+    Result Type and the type of x must be the same type. Results are computed
+    per component.
+
+    <!-- End of AutoGen section -->
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    round-even-op ::= ssa-id `=` `spirv.GL.RoundEven` ssa-use `:`
+                      float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %2 = spirv.GL.RoundEven %0 : f32
+    %3 = spirv.GL.RoundEven %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPIRV_GLInverseSqrtOp : SPIRV_GLUnaryArithmeticOp<"InverseSqrt", 32, SPIRV_Float> {
   let summary = "Reciprocal of sqrt(operand)";
 

diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 55a242991bbcf..242000485d554 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -295,6 +295,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
            spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
            spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
            spirv::ElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
+           spirv::ElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
            spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
            spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
            spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
@@ -312,6 +313,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                spirv::ElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
                spirv::ElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
                spirv::ElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
+               spirv::ElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
                spirv::ElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
                spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
                spirv::ElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 5241c8a857106..df39036277ee6 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -20,20 +20,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.GL.Log %[[ADDONE]]
   %4 = math.log1p %arg0 : f32
+  // CHECK: spirv.GL.RoundEven %{{.*}}: f32
+  %5 = math.roundeven %arg0 : f32
   // CHECK: spirv.GL.InverseSqrt %{{.*}}: f32
-  %5 = math.rsqrt %arg0 : f32
+  %6 = math.rsqrt %arg0 : f32
   // CHECK: spirv.GL.Sqrt %{{.*}}: f32
-  %6 = math.sqrt %arg0 : f32
+  %7 = math.sqrt %arg0 : f32
   // CHECK: spirv.GL.Tanh %{{.*}}: f32
-  %7 = math.tanh %arg0 : f32
+  %8 = math.tanh %arg0 : f32
   // CHECK: spirv.GL.Sin %{{.*}}: f32
-  %8 = math.sin %arg0 : f32
+  %9 = math.sin %arg0 : f32
   // CHECK: spirv.GL.FAbs %{{.*}}: f32
-  %9 = math.absf %arg0 : f32
+  %10 = math.absf %arg0 : f32
   // CHECK: spirv.GL.Ceil %{{.*}}: f32
-  %10 = math.ceil %arg0 : f32
+  %11 = math.ceil %arg0 : f32
   // CHECK: spirv.GL.Floor %{{.*}}: f32
-  %11 = math.floor %arg0 : f32
+  %12 = math.floor %arg0 : f32
   return
 }
 
@@ -53,14 +55,16 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.GL.Log %[[ADDONE]]
   %4 = math.log1p %arg0 : vector<3xf32>
+  // CHECK: spirv.GL.RoundEven %{{.*}}: vector<3xf32>
+  %5 = math.roundeven %arg0 : vector<3xf32>
   // CHECK: spirv.GL.InverseSqrt %{{.*}}: vector<3xf32>
-  %5 = math.rsqrt %arg0 : vector<3xf32>
+  %6 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sqrt %{{.*}}: vector<3xf32>
-  %6 = math.sqrt %arg0 : vector<3xf32>
+  %7 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Tanh %{{.*}}: vector<3xf32>
-  %7 = math.tanh %arg0 : vector<3xf32>
+  %8 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
-  %8 = math.sin %arg0 : vector<3xf32>
+  %9 = math.sin %arg0 : vector<3xf32>
   return
 }
 

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index a85c7e6e06c9a..6897cfd9f2f5a 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -18,24 +18,26 @@ func.func @float32_unary_scalar(%arg0: f32) {
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.CL.log %[[ADDONE]]
   %4 = math.log1p %arg0 : f32
+  // CHECK: spirv.CL.rint %{{.*}}: f32
+  %5 = math.roundeven %arg0 : f32
   // CHECK: spirv.CL.rsqrt %{{.*}}: f32
-  %5 = math.rsqrt %arg0 : f32
+  %6 = math.rsqrt %arg0 : f32
   // CHECK: spirv.CL.sqrt %{{.*}}: f32
-  %6 = math.sqrt %arg0 : f32
+  %7 = math.sqrt %arg0 : f32
   // CHECK: spirv.CL.tanh %{{.*}}: f32
-  %7 = math.tanh %arg0 : f32
+  %8 = math.tanh %arg0 : f32
   // CHECK: spirv.CL.sin %{{.*}}: f32
-  %8 = math.sin %arg0 : f32
+  %9 = math.sin %arg0 : f32
   // CHECK: spirv.CL.fabs %{{.*}}: f32
-  %9 = math.absf %arg0 : f32
+  %10 = math.absf %arg0 : f32
   // CHECK: spirv.CL.ceil %{{.*}}: f32
-  %10 = math.ceil %arg0 : f32
+  %11 = math.ceil %arg0 : f32
   // CHECK: spirv.CL.floor %{{.*}}: f32
-  %11 = math.floor %arg0 : f32
+  %12 = math.floor %arg0 : f32
   // CHECK: spirv.CL.erf %{{.*}}: f32
-  %12 = math.erf %arg0 : f32
+  %13 = math.erf %arg0 : f32
   // CHECK: spirv.CL.round %{{.*}}: f32
-  %13 = math.round %arg0 : f32
+  %14 = math.round %arg0 : f32
   return
 }
 
@@ -55,14 +57,16 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
   // CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
   // CHECK: spirv.CL.log %[[ADDONE]]
   %4 = math.log1p %arg0 : vector<3xf32>
+  // CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
+  %5 = math.roundeven %arg0 : vector<3xf32>
   // CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
-  %5 = math.rsqrt %arg0 : vector<3xf32>
+  %6 = math.rsqrt %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
-  %6 = math.sqrt %arg0 : vector<3xf32>
+  %7 = math.sqrt %arg0 : vector<3xf32>
   // CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
-  %7 = math.tanh %arg0 : vector<3xf32>
+  %8 = math.tanh %arg0 : vector<3xf32>
   // CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
-  %8 = math.sin %arg0 : vector<3xf32>
+  %9 = math.sin %arg0 : vector<3xf32>
   return
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 760ffadaa9eaf..3683e5b469b17 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -288,6 +288,22 @@ func.func @roundvec(%arg0 : vector<3xf16>) -> () {
   return
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GL.RoundEven
+//===----------------------------------------------------------------------===//
+
+func.func @round_even(%arg0 : f32) -> () {
+  // CHECK: spirv.GL.RoundEven %{{.*}} : f32
+  %rounded = spirv.GL.RoundEven %arg0 : f32
+  return
+}
+
+func.func @round_even_vec(%arg0 : vector<3xf16>) -> () {
+  // CHECK: spirv.GL.RoundEven %{{.*}} : vector<3xf16>
+  %rounded = spirv.GL.RoundEven %arg0 : vector<3xf16>
+  return
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 8e25c272220f8..57a65d446d3fb 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -227,3 +227,23 @@ func.func @iminmax(%arg0: i32, %arg1: i32) {
   %4 = spirv.CL.u_min %arg0, %arg1 : i32
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.rint
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @rint(
+func.func @rint(%arg0 : f32) -> () {
+  // CHECK: spirv.CL.rint %{{.*}} : f32
+  %result = spirv.CL.rint %arg0 : f32
+  return
+}
+
+// CHECK-LABEL: func.func @rintvec(
+func.func @rintvec(%arg0 : vector<3xf16>) -> () {
+  // CHECK: spirv.CL.rint %{{.*}} : vector<3xf16>
+  %result = spirv.CL.rint %arg0 : vector<3xf16>
+  return
+}


        


More information about the Mlir-commits mailing list