[Mlir-commits] [mlir] 32389d0 - [mlir][spirv] Add OpenCL fma op and lowering

Ivan Butygin llvmlistbot at llvm.org
Tue Feb 15 00:29:52 PST 2022


Author: Ivan Butygin
Date: 2022-02-15T11:28:20+03:00
New Revision: 32389d0c2e2db54f9d3f78fae0e113060a9b4074

URL: https://github.com/llvm/llvm-project/commit/32389d0c2e2db54f9d3f78fae0e113060a9b4074
DIFF: https://github.com/llvm/llvm-project/commit/32389d0c2e2db54f9d3f78fae0e113060a9b4074.diff

LOG: [mlir][spirv] Add OpenCL fma op and lowering

Also, Khronos appears to have changed the HTML spec format, so a small adjustment to the script was needed.
Parsing of base ops is probably also broken.

Differential Revision: https://reviews.llvm.org/D119678

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
    mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
    mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
    mlir/test/Target/SPIRV/ocl-ops.mlir
    mlir/utils/spirv/gen_spirv_dialect.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
index fcc69a49d5593..c1d1fd480d4f5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td
@@ -82,15 +82,46 @@ class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }
 
+// Base class for OpenCL ternary ops.
+class SPV_OCLTernaryOp<string mnemonic, int opcode, Type resultType,
+                      Type operandType, list<Trait> traits = []> :
+  SPV_OCLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
+
+  let arguments = (ins
+    SPV_ScalarOrVectorOf<operandType>:$x,
+    SPV_ScalarOrVectorOf<operandType>:$y,
+    SPV_ScalarOrVectorOf<operandType>:$z
+  );
+
+  let results = (outs
+    SPV_ScalarOrVectorOf<resultType>:$result
+  );
+
+  let hasVerifier = 0;
+}
+
+// Base class for OpenCL ternary arithmetic ops where the operand types and
+// the return type match.
+class SPV_OCLTernaryArithmeticOp<string mnemonic, int opcode, Type type,
+                                list<Trait> traits = []> :
+  SPV_OCLTernaryOp<mnemonic, opcode, type, type,
+                  traits # [SameOperandsAndResultType]> {
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
+
 // -----
 
-def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
+def SPV_OCLFmaOp : SPV_OCLTernaryArithmeticOp<"fma", 26, SPV_Float> {
   let summary = [{
-    Error function of x encountered in integrating the normal distribution.
+    Compute the correctly rounded floating-point representation of the sum
+    of c with the infinitely precise product of a and b. Rounding of
+    intermediate products shall not occur. Edge case results are per the
+    IEEE 754-2008 standard.
   }];
 
   let description = [{
-    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    Result Type, a, b and c must be floating-point or vector(2,3,4,8,16) of
     floating-point values.
 
     All of the operands, including the Result Type operand, must be of the
@@ -99,17 +130,13 @@ def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
     <!-- End of AutoGen section -->
 
     ```
-    float-scalar-vector-type ::= float-type |
-                                 `vector<` integer-literal `x` float-type `>`
-    erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:`
+    fma-op ::= ssa-id `=` `spv.OCL.fma` ssa-use `,` ssa-use `,` ssa-use `:`
                float-scalar-vector-type
     ```mlir
 
-    #### Example:
-
     ```
-    %2 = spv.OCL.erf %0 : f32
-    %3 = spv.OCL.erf %1 : vector<3xf16>
+    %0 = spv.OCL.fma %a, %b, %c : f32
+    %1 = spv.OCL.fma %a, %b, %c : vector<3xf16>
     ```
   }];
 }
@@ -179,6 +206,38 @@ def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> {
 
 // -----
 
+def SPV_OCLErfOp : SPV_OCLUnaryArithmeticOp<"erf", 18, SPV_Float> {
+  let summary = [{
+    Error function of x encountered in integrating the normal distribution.
+  }];
+
+  let description = [{
+    Result Type and x must be floating-point or vector(2,3,4,8,16) of
+    floating-point values.
+
+    All of the operands, including the Result Type operand, must be of the
+    same type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    erf-op ::= ssa-id `=` `spv.OCL.erf` ssa-use `:`
+               float-scalar-vector-type
+    ```mlir
+
+    #### Example:
+
+    ```
+    %2 = spv.OCL.erf %0 : f32
+    %3 = spv.OCL.erf %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
 def SPV_OCLExpOp : SPV_OCLUnaryArithmeticOp<"exp", 19, SPV_Float> {
   let summary = "Exponentiation of Operand 1";
 

diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 90588ed9bd5f0..6d309d6760830 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -92,13 +92,13 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
            spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
            spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
+           spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
            spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
            spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
            spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
            spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
            spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
-           spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
-           spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>>(
+           spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
           typeConverter, patterns.getContext());
 
   // OpenCL patterns
@@ -109,6 +109,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
                spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
                spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
+               spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
                spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
                spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
                spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
index c996e7056783a..e172114f22eac 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -76,7 +76,7 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
   return
 }
 
-  // CHECK-LABEL: @float32_ternary_scalar
+// CHECK-LABEL: @float32_ternary_scalar
 func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
   // CHECK: spv.GLSL.Fma %{{.*}}: f32
   %0 = math.fma %a, %b, %c : f32

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index d0959efc98ab2..8a248edc9902f 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -78,4 +78,19 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
   return
 }
 
+// CHECK-LABEL: @float32_ternary_scalar
+func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
+  // CHECK: spv.OCL.fma %{{.*}}: f32
+  %0 = math.fma %a, %b, %c : f32
+  return
+}
+
+// CHECK-LABEL: @float32_ternary_vector
+func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
+                            %c: vector<4xf32>) {
+  // CHECK: spv.OCL.fma %{{.*}}: vector<4xf32>
+  %0 = math.fma %a, %b, %c : vector<4xf32>
+  return
+}
+
 } // end module

diff  --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index e60f74f0a10ed..c44add6db7279 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -166,3 +166,22 @@ func @sabs(%arg0 : i32) -> () {
   return
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.OCL.fma
+//===----------------------------------------------------------------------===//
+
+func @fma(%a : f32, %b : f32, %c : f32) -> () {
+  // CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
+  %2 = spv.OCL.fma %a, %b, %c : f32
+  return
+}
+
+// -----
+
+func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () {
+  // CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32>
+  %2 = spv.OCL.fma %a, %b, %c : vector<3xf32>
+  return
+}

diff  --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
index 62360e0240820..ff26a0480276f 100644
--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
@@ -38,4 +38,10 @@ spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
     %0 = spv.OCL.fabs %arg0 : vector<16xf32>
     spv.Return
   }
+
+  spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
+    // CHECK: spv.OCL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
+    %13 = spv.OCL.fma %arg0, %arg1, %arg2 : f32
+    spv.Return
+  }
 }

diff  --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index 2d0c7649f34f1..72db3493c126c 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -51,7 +51,7 @@ def get_spirv_doc_from_html_spec(url, settings):
   doc = {}
 
   if settings.gen_ocl_ops:
-    section_anchor = spirv.find('h2', {'id': '_a_id_binary_a_binary_form'})
+    section_anchor = spirv.find('h2', {'id': '_binary_form'})
     for section in section_anchor.parent.find_all('div', {'class': 'sect2'}):
       for table in section.find_all('table'):
         inst_html = table.tbody.tr.td


        


More information about the Mlir-commits mailing list