[Mlir-commits] [mlir] 75a1bee - [mlir][spirv] Add math to OpenCL conversion

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 23 15:31:58 PST 2021


Author: Butygin
Date: 2021-11-24T02:31:21+03:00
New Revision: 75a1bee05db7ca4277cf93545834110409c75bc9

URL: https://github.com/llvm/llvm-project/commit/75a1bee05db7ca4277cf93545834110409c75bc9
DIFF: https://github.com/llvm/llvm-project/commit/75a1bee05db7ca4277cf93545834110409c75bc9.diff

LOG: [mlir][spirv] Add math to OpenCL conversion

Differential Revision: https://reviews.llvm.org/D113780

Added: 
    mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
    mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
    mlir/test/Target/SPIRV/ocl-ops.mlir

Removed: 
    mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
index 55b3a67bc336c..437805b508ecf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
@@ -22,7 +22,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // Base class for all GLSL ops.
 class SPV_GLSLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
-  SPV_ExtInstOp<mnemonic, "GLSL", "GLSL.std.450", opcode, traits>;
+  SPV_ExtInstOp<mnemonic, "GLSL", "GLSL.std.450", opcode, traits> {
+  // Derived ops default to SPIR-V 1.0-1.5 with the Shader capability.
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Shader]>
+  ];
+}
 
 // Base class for GLSL unary ops.
 class SPV_GLSLUnaryOp<string mnemonic, int opcode, Type resultType,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
index ecbe2b63356dd..ac773fd146c2c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
@@ -21,7 +21,15 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
 
 // Base class for all OpenCL ops.
 class SPV_OCLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
-  SPV_ExtInstOp<mnemonic, "OCL", "OpenCL.std", opcode, traits>;
+  SPV_ExtInstOp<mnemonic, "OCL", "OpenCL.std", opcode, traits> {
+  // Derived ops default to SPIR-V 1.0-1.5 with the Kernel capability.
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Kernel]>
+  ];
+}
 
 // Base class for OpenCL unary ops.
 class SPV_OCLUnaryOp<string mnemonic, int opcode, Type resultType,
@@ -78,6 +86,69 @@ class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
 
 // -----
 
+def SPV_OCLTanhOp : SPV_OCLUnaryArithmeticOp<"tanh", 63, SPV_Float> {
+  let summary = "Compute hyperbolic tangent of x radians.";
+
+  let description = [{
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    tanh-op ::= ssa-id `=` `spv.OCL.tanh` ssa-use `:`
+               float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.tanh %0 : f32
+    %3 = spv.OCL.tanh %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_OCLCeilOp : SPV_OCLUnaryArithmeticOp<"ceil", 12, SPV_Float> {
+  let summary = [{
+    Round x to integral value using the round to positive infinity rounding
+    mode.
+  }];
+
+  let description = [{
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    ceil-op ::= ssa-id `=` `spv.OCL.ceil` ssa-use `:`
+               float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.ceil %0 : f32
+    %3 = spv.OCL.ceil %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {
   let summary = "Compute the cosine of x radians.";
 
@@ -93,7 +164,7 @@ def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {
     ```
     float-scalar-vector-type ::= float-type |
                                  `vector<` integer-literal `x` float-type `>`
-    abs-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:`
+    cos-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:`
                float-scalar-vector-type
     ```mlir
 
@@ -168,6 +239,39 @@ def SPV_OCLFAbsOp : SPV_OCLUnaryArithmeticOp<"fabs", 23, SPV_Float> {
 
 // -----
 
+def SPV_OCLFloorOp : SPV_OCLUnaryArithmeticOp<"floor", 25, SPV_Float> {
+  let summary = [{
+    Round x to the integral value using the round to negative infinity
+    rounding mode.
+  }];
+
+  let description = [{
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    floor-op ::= ssa-id `=` `spv.OCL.floor` ssa-use `:`
+               float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.floor %0 : f32
+    %3 = spv.OCL.floor %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
   let summary = "Compute the natural logarithm of x.";
 
@@ -183,7 +287,7 @@ def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
     ```
     float-scalar-vector-type ::= float-type |
                                  `vector<` integer-literal `x` float-type `>`
-    abs-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:`
+    log-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:`
                float-scalar-vector-type
     ```mlir
 
@@ -198,6 +302,67 @@ def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> {
 
 // -----
 
+def SPV_OCLPowOp : SPV_OCLBinaryArithmeticOp<"pow", 48, SPV_Float> {
+  let summary = "Compute x to the power y.";
+
+  let description = [{
+    Result Type, x and y must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    pow-op ::= ssa-id `=` `spv.OCL.pow`
+               ssa-use `,` ssa-use `:`
+               float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.pow %0, %1 : f32
+    %5 = spv.OCL.pow %3, %4 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> {
+  let summary = "Compute inverse square root of x.";
+
+  let description = [{
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    rsqrt-op ::= ssa-id `=` `spv.OCL.rsqrt` ssa-use `:`
+               float-scalar-vector-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.OCL.rsqrt %0 : f32
+    %3 = spv.OCL.rsqrt %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_OCLSinOp : SPV_OCLUnaryArithmeticOp<"sin", 57, SPV_Float> {
   let summary = "Compute sine of x radians.";
 
@@ -213,7 +378,7 @@ def SPV_OCLSinOp : SPV_OCLUnaryArithmeticOp<"sin", 57, SPV_Float> {
     ```
     float-scalar-vector-type ::= float-type |
                                  `vector<` integer-literal `x` float-type `>`
-    abs-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:`
+    sin-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:`
                float-scalar-vector-type
     ```mlir
 
@@ -243,7 +408,7 @@ def SPV_OCLSqrtOp : SPV_OCLUnaryArithmeticOp<"sqrt", 61, SPV_Float> {
     ```
     float-scalar-vector-type ::= float-type |
                                  `vector<` integer-literal `x` float-type `>`
-    abs-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:`
+    sqrt-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:`
                float-scalar-vector-type
     ```mlir
 

diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index d1aa7865f88b1..9e96829e79c1d 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -34,6 +34,7 @@ namespace {
 ///
 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
 /// these operations.
+template <typename LogOp>
 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
 public:
   using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
@@ -48,7 +49,7 @@ class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
     auto onePlus =
         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
-    rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
+    rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
     return success();
   }
 };
@@ -61,8 +62,10 @@ class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
 namespace mlir {
 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                  RewritePatternSet &patterns) {
+
+  // GLSL patterns
   patterns.add<
-      Log1pOpPattern,
+      Log1pOpPattern<spirv::GLSLLogOp>,
       spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
       spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
       spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
@@ -75,6 +78,21 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
       spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
       typeConverter, patterns.getContext());
+
+  // OpenCL patterns
+  patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
+               spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
+               spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::OCLCeilOp>,
+               spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::OCLCosOp>,
+               spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::OCLExpOp>,
+               spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::OCLFloorOp>,
+               spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::OCLLogOp>,
+               spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::OCLPowOp>,
+               spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
+               spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::OCLSinOp>,
+               spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
+               spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
+      typeConverter, patterns.getContext());
 }
 
 } // namespace mlir

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index 8a41a90a2fc0b..5ab13e780d0d4 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -6,7 +6,7 @@
 
 module attributes {
   spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, {}>
+    #spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, {}>
 } {
 
 // Check integer operation conversions.

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
similarity index 95%
rename from mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir
rename to mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index 65bdc18909e35..ad32a88a876ea 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, {}> } {
+
 // CHECK-LABEL: @float32_unary_scalar
 func @float32_unary_scalar(%arg0: f32) {
   // CHECK: spv.GLSL.Cos %{{.*}}: f32
@@ -59,3 +61,5 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
   %0 = math.powf %lhs, %rhs : vector<4xf32>
   return
 }
+
+} // end module

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
new file mode 100644
index 0000000000000..8a1a3acc5f0cd
--- /dev/null
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+// With only the Kernel capability available, math ops must lower to
+// OpenCL.std extended instructions rather than GLSL.std.450 ones.
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, {}> } {
+
+// CHECK-LABEL: @float32_unary_scalar
+func @float32_unary_scalar(%arg0: f32) {
+  // CHECK: spv.OCL.cos %{{.*}}: f32
+  %0 = math.cos %arg0 : f32
+  // CHECK: spv.OCL.exp %{{.*}}: f32
+  %1 = math.exp %arg0 : f32
+  // CHECK: spv.OCL.log %{{.*}}: f32
+  %2 = math.log %arg0 : f32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32
+  // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
+  // CHECK: spv.OCL.log %[[ADDONE]]
+  %3 = math.log1p %arg0 : f32
+  // CHECK: spv.OCL.rsqrt %{{.*}}: f32
+  %4 = math.rsqrt %arg0 : f32
+  // CHECK: spv.OCL.sqrt %{{.*}}: f32
+  %5 = math.sqrt %arg0 : f32
+  // CHECK: spv.OCL.tanh %{{.*}}: f32
+  %6 = math.tanh %arg0 : f32
+  // CHECK: spv.OCL.sin %{{.*}}: f32
+  %7 = math.sin %arg0 : f32
+  // CHECK: spv.OCL.fabs %{{.*}}: f32
+  %8 = math.abs %arg0 : f32
+  // CHECK: spv.OCL.ceil %{{.*}}: f32
+  %9 = math.ceil %arg0 : f32
+  // CHECK: spv.OCL.floor %{{.*}}: f32
+  %10 = math.floor %arg0 : f32
+  return
+}
+
+// CHECK-LABEL: @float32_unary_vector
+func @float32_unary_vector(%arg0: vector<3xf32>) {
+  // CHECK: spv.OCL.cos %{{.*}}: vector<3xf32>
+  %0 = math.cos %arg0 : vector<3xf32>
+  // CHECK: spv.OCL.exp %{{.*}}: vector<3xf32>
+  %1 = math.exp %arg0 : vector<3xf32>
+  // CHECK: spv.OCL.log %{{.*}}: vector<3xf32>
+  %2 = math.log %arg0 : vector<3xf32>
+  // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32>
+  // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}}
+  // CHECK: spv.OCL.log %[[ADDONE]]
+  %3 = math.log1p %arg0 : vector<3xf32>
+  // CHECK: spv.OCL.rsqrt %{{.*}}: vector<3xf32>
+  %4 = math.rsqrt %arg0 : vector<3xf32>
+  // CHECK: spv.OCL.sqrt %{{.*}}: vector<3xf32>
+  %5 = math.sqrt %arg0 : vector<3xf32>
+  // CHECK: spv.OCL.tanh %{{.*}}: vector<3xf32>
+  %6 = math.tanh %arg0 : vector<3xf32>
+  // CHECK: spv.OCL.sin %{{.*}}: vector<3xf32>
+  %7 = math.sin %arg0 : vector<3xf32>
+  // CHECK: spv.OCL.fabs %{{.*}}: vector<3xf32>
+  %8 = math.abs %arg0 : vector<3xf32>
+  // CHECK: spv.OCL.ceil %{{.*}}: vector<3xf32>
+  %9 = math.ceil %arg0 : vector<3xf32>
+  // CHECK: spv.OCL.floor %{{.*}}: vector<3xf32>
+  %10 = math.floor %arg0 : vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: @float32_binary_scalar
+func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
+  // CHECK: spv.OCL.pow %{{.*}}: f32
+  %0 = math.powf %lhs, %rhs : f32
+  return
+}
+
+// CHECK-LABEL: @float32_binary_vector
+func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
+  // CHECK: spv.OCL.pow %{{.*}}: vector<4xf32>
+  %0 = math.powf %lhs, %rhs : vector<4xf32>
+  return
+}
+
+} // end module

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index db11f41c11e43..029c161b3df45 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -6,7 +6,7 @@
 
 module attributes {
   spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, {}>
+    #spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, {}>
 } {
 
 // Check integer operation conversions.

diff  --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
index 37ab7d90e2bc5..c6c9af28f97d6 100644
--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
@@ -14,6 +14,14 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
     %4 = spv.OCL.log %arg0 : f32
     // CHECK: {{%.*}} = spv.OCL.sqrt {{%.*}} : f32
     %5 = spv.OCL.sqrt %arg0 : f32
+    // CHECK: {{%.*}} = spv.OCL.ceil {{%.*}} : f32
+    %6 = spv.OCL.ceil %arg0 : f32
+    // CHECK: {{%.*}} = spv.OCL.floor {{%.*}} : f32
+    %7 = spv.OCL.floor %arg0 : f32
+    // CHECK: {{%.*}} = spv.OCL.pow {{%.*}}, {{%.*}} : f32
+    %8 = spv.OCL.pow %arg0, %arg0 : f32
+    // CHECK: {{%.*}} = spv.OCL.rsqrt {{%.*}} : f32
+    %9 = spv.OCL.rsqrt %arg0 : f32
     spv.Return
   }
 


        


More information about the Mlir-commits mailing list