[Mlir-commits] [mlir] c97035c - [mlir][math][spirv] Add `math.roundeven` lowering to SPIR-V
Jakub Kuderski
llvmlistbot at llvm.org
Tue Nov 1 09:51:48 PDT 2022
Author: Jakub Kuderski
Date: 2022-11-01T12:51:25-04:00
New Revision: c97035c49d941e5b196a938f0393f811d1adbd57
URL: https://github.com/llvm/llvm-project/commit/c97035c49d941e5b196a938f0393f811d1adbd57
DIFF: https://github.com/llvm/llvm-project/commit/c97035c49d941e5b196a938f0393f811d1adbd57.diff
LOG: [mlir][math][spirv] Add `math.roundeven` lowering to SPIR-V
This has two lowering paths, one for each extended instruction set:
- to OpenGL's `RoundEven`,
- to OpenCL's `rint`.
Implement those two ops and add minimal tests.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D137171
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index 187622e55d35c..1e7ca2b97a0bd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -486,6 +486,39 @@ def SPIRV_CLRoundOp : SPIRV_CLUnaryArithmeticOp<"round", 55, SPIRV_Float> {
// -----
+def SPIRV_CLRintOp : SPIRV_CLUnaryArithmeticOp<"rint", 53, SPIRV_Float> {
+ let summary = [{
+ Round x to integral value (using round to nearest even rounding mode) in
+ floating-point format.
+ }];
+
+ let description = [{
+ Result Type and x must be floating-point or vector(2,3,4,8,16) of
+ floating-point values.
+
+ All of the operands, including the Result Type operand, must be of the
+ same type.
+
+ <!-- End of AutoGen section -->
+
+ ```
+ float-scalar-vector-type ::= float-type |
+ `vector<` integer-literal `x` float-type `>`
+ rint-op ::= ssa-id `=` `spirv.CL.rint` ssa-use `:`
+ float-scalar-vector-type
+ ```
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.CL.rint %0 : f32
+ %3 = spirv.CL.rint %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
def SPIRV_CLRsqrtOp : SPIRV_CLUnaryArithmeticOp<"rsqrt", 56, SPIRV_Float> {
let summary = "Compute inverse square root of x.";
@@ -688,6 +721,8 @@ def SPIRV_CLUMaxOp : SPIRV_CLBinaryArithmeticOp<"u_max", 157, SPIRV_Integer> {
}];
}
+// -----
+
def SPIRV_CLSMinOp : SPIRV_CLBinaryArithmeticOp<"s_min", 158, SPIRV_Integer> {
let summary = "Return minimum of two signed integer operands";
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
index 3dd5219592b87..377eae3858e37 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
@@ -456,10 +456,13 @@ def SPIRV_GLFloorOp : SPIRV_GLUnaryArithmeticOp<"Floor", 8, SPIRV_Float> {
// -----
def SPIRV_GLRoundOp: SPIRV_GLUnaryArithmeticOp<"Round", 1, SPIRV_Float> {
- let summary = "Rounds to the whole number";
+ let summary = "Rounds to the nearest whole number";
let description = [{
- Result is the value equal to the nearest whole number.
+ Result is the value equal to the nearest whole number to x. The fraction
+ 0.5 will round in a direction chosen by the implementation, presumably
+ the direction that is fastest. This includes the possibility that
+ Round x is the same value as RoundEven x for all values of x.
The operand x must be a scalar or vector whose component type is
floating-point.
@@ -471,7 +474,7 @@ def SPIRV_GLRoundOp: SPIRV_GLUnaryArithmeticOp<"Round", 1, SPIRV_Float> {
```
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
- floor-op ::= ssa-id `=` `spirv.GL.Round` ssa-use `:`
+ round-op ::= ssa-id `=` `spirv.GL.Round` ssa-use `:`
float-scalar-vector-type
```
#### Example:
@@ -485,6 +488,38 @@ def SPIRV_GLRoundOp: SPIRV_GLUnaryArithmeticOp<"Round", 1, SPIRV_Float> {
// -----
+def SPIRV_GLRoundEvenOp: SPIRV_GLUnaryArithmeticOp<"RoundEven", 2, SPIRV_Float> {
+ let summary = "Rounds to the nearest whole number, with ties rounding to even";
+
+ let description = [{
+ Result is the value equal to the nearest whole number to x. A fractional
+ part of 0.5 will round toward the nearest even whole number. (Both 3.5 and
+ 4.5 for x will be 4.0.)
+
+ The operand x must be a scalar or vector whose component type is
+ floating-point.
+
+ Result Type and the type of x must be the same type. Results are computed
+ per component.
+
+ <!-- End of AutoGen section -->
+ ```
+ float-scalar-vector-type ::= float-type |
+ `vector<` integer-literal `x` float-type `>`
+ round-even-op ::= ssa-id `=` `spirv.GL.RoundEven` ssa-use `:`
+ float-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spirv.GL.RoundEven %0 : f32
+ %3 = spirv.GL.RoundEven %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
def SPIRV_GLInverseSqrtOp : SPIRV_GLUnaryArithmeticOp<"InverseSqrt", 32, SPIRV_Float> {
let summary = "Reciprocal of sqrt(operand)";
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 55a242991bbcf..242000485d554 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -295,6 +295,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
+ spirv::ElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
@@ -312,6 +313,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
spirv::ElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
+ spirv::ElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
spirv::ElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
spirv::ElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
index 5241c8a857106..df39036277ee6 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
@@ -20,20 +20,22 @@ func.func @float32_unary_scalar(%arg0: f32) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.GL.Log %[[ADDONE]]
%4 = math.log1p %arg0 : f32
+ // CHECK: spirv.GL.RoundEven %{{.*}}: f32
+ %5 = math.roundeven %arg0 : f32
// CHECK: spirv.GL.InverseSqrt %{{.*}}: f32
- %5 = math.rsqrt %arg0 : f32
+ %6 = math.rsqrt %arg0 : f32
// CHECK: spirv.GL.Sqrt %{{.*}}: f32
- %6 = math.sqrt %arg0 : f32
+ %7 = math.sqrt %arg0 : f32
// CHECK: spirv.GL.Tanh %{{.*}}: f32
- %7 = math.tanh %arg0 : f32
+ %8 = math.tanh %arg0 : f32
// CHECK: spirv.GL.Sin %{{.*}}: f32
- %8 = math.sin %arg0 : f32
+ %9 = math.sin %arg0 : f32
// CHECK: spirv.GL.FAbs %{{.*}}: f32
- %9 = math.absf %arg0 : f32
+ %10 = math.absf %arg0 : f32
// CHECK: spirv.GL.Ceil %{{.*}}: f32
- %10 = math.ceil %arg0 : f32
+ %11 = math.ceil %arg0 : f32
// CHECK: spirv.GL.Floor %{{.*}}: f32
- %11 = math.floor %arg0 : f32
+ %12 = math.floor %arg0 : f32
return
}
@@ -53,14 +55,16 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.GL.Log %[[ADDONE]]
%4 = math.log1p %arg0 : vector<3xf32>
+ // CHECK: spirv.GL.RoundEven %{{.*}}: vector<3xf32>
+ %5 = math.roundeven %arg0 : vector<3xf32>
// CHECK: spirv.GL.InverseSqrt %{{.*}}: vector<3xf32>
- %5 = math.rsqrt %arg0 : vector<3xf32>
+ %6 = math.rsqrt %arg0 : vector<3xf32>
// CHECK: spirv.GL.Sqrt %{{.*}}: vector<3xf32>
- %6 = math.sqrt %arg0 : vector<3xf32>
+ %7 = math.sqrt %arg0 : vector<3xf32>
// CHECK: spirv.GL.Tanh %{{.*}}: vector<3xf32>
- %7 = math.tanh %arg0 : vector<3xf32>
+ %8 = math.tanh %arg0 : vector<3xf32>
// CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
- %8 = math.sin %arg0 : vector<3xf32>
+ %9 = math.sin %arg0 : vector<3xf32>
return
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index a85c7e6e06c9a..6897cfd9f2f5a 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -18,24 +18,26 @@ func.func @float32_unary_scalar(%arg0: f32) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.CL.log %[[ADDONE]]
%4 = math.log1p %arg0 : f32
+ // CHECK: spirv.CL.rint %{{.*}}: f32
+ %5 = math.roundeven %arg0 : f32
// CHECK: spirv.CL.rsqrt %{{.*}}: f32
- %5 = math.rsqrt %arg0 : f32
+ %6 = math.rsqrt %arg0 : f32
// CHECK: spirv.CL.sqrt %{{.*}}: f32
- %6 = math.sqrt %arg0 : f32
+ %7 = math.sqrt %arg0 : f32
// CHECK: spirv.CL.tanh %{{.*}}: f32
- %7 = math.tanh %arg0 : f32
+ %8 = math.tanh %arg0 : f32
// CHECK: spirv.CL.sin %{{.*}}: f32
- %8 = math.sin %arg0 : f32
+ %9 = math.sin %arg0 : f32
// CHECK: spirv.CL.fabs %{{.*}}: f32
- %9 = math.absf %arg0 : f32
+ %10 = math.absf %arg0 : f32
// CHECK: spirv.CL.ceil %{{.*}}: f32
- %10 = math.ceil %arg0 : f32
+ %11 = math.ceil %arg0 : f32
// CHECK: spirv.CL.floor %{{.*}}: f32
- %11 = math.floor %arg0 : f32
+ %12 = math.floor %arg0 : f32
// CHECK: spirv.CL.erf %{{.*}}: f32
- %12 = math.erf %arg0 : f32
+ %13 = math.erf %arg0 : f32
// CHECK: spirv.CL.round %{{.*}}: f32
- %13 = math.round %arg0 : f32
+ %14 = math.round %arg0 : f32
return
}
@@ -55,14 +57,16 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.CL.log %[[ADDONE]]
%4 = math.log1p %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
+ %5 = math.roundeven %arg0 : vector<3xf32>
// CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
- %5 = math.rsqrt %arg0 : vector<3xf32>
+ %6 = math.rsqrt %arg0 : vector<3xf32>
// CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
- %6 = math.sqrt %arg0 : vector<3xf32>
+ %7 = math.sqrt %arg0 : vector<3xf32>
// CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
- %7 = math.tanh %arg0 : vector<3xf32>
+ %8 = math.tanh %arg0 : vector<3xf32>
// CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
- %8 = math.sin %arg0 : vector<3xf32>
+ %9 = math.sin %arg0 : vector<3xf32>
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 760ffadaa9eaf..3683e5b469b17 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -288,6 +288,22 @@ func.func @roundvec(%arg0 : vector<3xf16>) -> () {
return
}
+//===----------------------------------------------------------------------===//
+// spirv.GL.RoundEven
+//===----------------------------------------------------------------------===//
+
+func.func @round_even(%arg0 : f32) -> () {
+ // CHECK: spirv.GL.RoundEven {{%.*}} : f32
+ %2 = spirv.GL.RoundEven %arg0 : f32
+ return
+}
+
+func.func @round_even_vec(%arg0 : vector<3xf16>) -> () {
+ // CHECK: spirv.GL.RoundEven {{%.*}} : vector<3xf16>
+ %2 = spirv.GL.RoundEven %arg0 : vector<3xf16>
+ return
+}
+
// -----
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 8e25c272220f8..57a65d446d3fb 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -227,3 +227,23 @@ func.func @iminmax(%arg0: i32, %arg1: i32) {
%4 = spirv.CL.u_min %arg0, %arg1 : i32
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.rint
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @rint(
+func.func @rint(%arg0 : f32) -> () {
+ // CHECK: spirv.CL.rint {{%.*}} : f32
+ %0 = spirv.CL.rint %arg0 : f32
+ return
+}
+
+// CHECK-LABEL: func.func @rintvec(
+func.func @rintvec(%arg0 : vector<3xf16>) -> () {
+ // CHECK: spirv.CL.rint {{%.*}} : vector<3xf16>
+ %0 = spirv.CL.rint %arg0 : vector<3xf16>
+ return
+}
More information about the Mlir-commits
mailing list