[Mlir-commits] [mlir] 123415e - [mlir][spirv] Add OpenCL extended ops: exp, fabs, s_abs
Konrad Dobros
llvmlistbot at llvm.org
Thu Oct 8 05:54:49 PDT 2020
Author: Konrad Dobros
Date: 2020-10-08T14:54:22+02:00
New Revision: 123415eddaf7d55db8606597e6e2375869b3f395
URL: https://github.com/llvm/llvm-project/commit/123415eddaf7d55db8606597e6e2375869b3f395
DIFF: https://github.com/llvm/llvm-project/commit/123415eddaf7d55db8606597e6e2375869b3f395.diff
LOG: [mlir][spirv] Add OpenCL extended ops: exp, fabs, s_abs
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D88966
Added:
mlir/include/mlir/Dialect/SPIRV/SPIRVOCLOps.td
mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
mlir/test/Dialect/SPIRV/ocl-ops.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOCLOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOCLOps.td
new file mode 100644
index 000000000000..78151bafb654
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOCLOps.td
@@ -0,0 +1,169 @@
+//===- SPIRVOCLOps.td - OpenCL extended insts spec file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of OpenCL extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_OCL_OPS
+#define SPIRV_OCL_OPS
+
+include "mlir/Dialect/SPIRV/SPIRVBase.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V OpenCL opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Common base of all ops from the "OpenCL.std" extended instruction set.
+class SPV_OCLOp<string mnemonic, int opcode, list<OpTrait> traits = []>
+ : SPV_ExtInstOp<mnemonic, "OCL", "OpenCL.std", opcode, traits>;
+
+// Base class for OpenCL unary ops: one scalar-or-vector operand, one result.
+class SPV_OCLUnaryOp<string mnemonic, int opcode, Type resultType,
+ Type operandType, list<OpTrait> traits = []> :
+ SPV_OCLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
+
+ // Single operand whose (element) type is constrained by `operandType`.
+ let arguments = (ins SPV_ScalarOrVectorOf<operandType>:$operand);
+
+ // Single result whose (element) type is constrained by `resultType`.
+ let results = (outs SPV_ScalarOrVectorOf<resultType>:$result);
+
+ // Custom assembly form (`%r = spv.OCL.<op> %x : type`), handled by the
+ // parseUnaryOp/printUnaryOp helpers defined elsewhere in the dialect.
+ let parser = [{ return parseUnaryOp(parser, result); }];
+ let printer = [{ return printUnaryOp(getOperation(), p); }];
+
+ // No extra checks beyond the ODS type constraints declared above.
+ let verifier = [{ return success(); }];
+}
+
+// Base class for OpenCL unary arithmetic ops whose result type must match
+// the operand type (enforced by the SameOperandsAndResultType trait).
+class SPV_OCLUnaryArithmeticOp<string mnemonic, int opcode, Type type,
+ list<OpTrait> traits = []> : SPV_OCLUnaryOp<mnemonic, opcode, type, type,
+ !listconcat([SameOperandsAndResultType], traits)>;
+
+// Base class for OpenCL binary ops: two operands constrained by `operandType`.
+class SPV_OCLBinaryOp<string mnemonic, int opcode, Type resultType,
+ Type operandType, list<OpTrait> traits = []> :
+ SPV_OCLOp<mnemonic, opcode, !listconcat([NoSideEffect], traits)> {
+
+ // Left and right operands share the same element-type constraint.
+ let arguments = (ins SPV_ScalarOrVectorOf<operandType>:$lhs,
+ SPV_ScalarOrVectorOf<operandType>:$rhs);
+
+ // The single result is constrained by `resultType`.
+ let results = (outs SPV_ScalarOrVectorOf<resultType>:$result);
+
+ // Assembly form uses the generic one-result / same-operand-type helpers.
+ let parser = [{
+ return impl::parseOneResultSameOperandTypeOp(parser, result);
+ }];
+ let printer = [{ return impl::printOneResultOp(getOperation(), p); }];
+
+ let verifier = [{ return success(); }];
+}
+
+// Base class for OpenCL binary arithmetic ops whose operands and result
+// must all share one type (enforced by the SameOperandsAndResultType trait).
+class SPV_OCLBinaryArithmeticOp<string mnemonic, int opcode, Type type,
+ list<OpTrait> traits = []> : SPV_OCLBinaryOp<mnemonic, opcode, type, type,
+ !listconcat([SameOperandsAndResultType], traits)>;
+
+// -----
+
+def SPV_OCLExpOp : SPV_OCLUnaryArithmeticOp<"exp", 19, SPV_Float> {
+ let summary = "Exponentiation of operand";
+
+ let description = [{
+ Compute the base-e exponential of x (i.e. e^x).
+
+ Result Type and x must be floating-point or vector(2,3,4,8,16) of
+ floating-point values.
+
+ All of the operands, including the Result Type operand,
+ must be of the same type.
+
+ <!-- End of AutoGen section -->
+ ```
+ float-scalar-vector-type ::= float-type |
+ `vector<` integer-literal `x` float-type `>`
+ exp-op ::= ssa-id `=` `spv.OCL.exp` ssa-use `:`
+ float-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spv.OCL.exp %0 : f32
+ %3 = spv.OCL.exp %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
+def SPV_OCLFAbsOp : SPV_OCLUnaryArithmeticOp<"fabs", 23, SPV_Float> {
+ let summary = "Absolute value of operand";
+
+ let description = [{
+ Compute the absolute value of x.
+
+ Result Type and x must be floating-point or vector(2,3,4,8,16) of
+ floating-point values.
+
+ All of the operands, including the Result Type operand,
+ must be of the same type.
+
+ <!-- End of AutoGen section -->
+ ```
+ float-scalar-vector-type ::= float-type |
+ `vector<` integer-literal `x` float-type `>`
+ fabs-op ::= ssa-id `=` `spv.OCL.fabs` ssa-use `:`
+ float-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spv.OCL.fabs %0 : f32
+ %3 = spv.OCL.fabs %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
+def SPV_OCLSAbsOp : SPV_OCLUnaryArithmeticOp<"s_abs", 141, SPV_Integer> {
+ let summary = "Absolute value of operand";
+
+ let description = [{
+ Returns |x|, where x is treated as a signed integer.
+
+ Result Type and x must be integer or vector(2,3,4,8,16) of
+ integer values.
+
+ All of the operands, including the Result Type operand,
+ must be of the same type.
+
+ <!-- End of AutoGen section -->
+ ```
+ integer-scalar-vector-type ::= integer-type |
+ `vector<` integer-literal `x` integer-type `>`
+ s-abs-op ::= ssa-id `=` `spv.OCL.s_abs` ssa-use `:`
+ integer-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %2 = spv.OCL.s_abs %0 : i32
+ %3 = spv.OCL.s_abs %1 : vector<3xi16>
+ ```
+ }];
+}
+
+#endif // SPIRV_OCL_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index abc8e5d85552..e4bb134496f6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -34,6 +34,7 @@ include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
include "mlir/Dialect/SPIRV/SPIRVMatrixOps.td"
include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
+include "mlir/Dialect/SPIRV/SPIRVOCLOps.td"
include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
diff --git a/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
new file mode 100644
index 000000000000..5130b4915096
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Serialization/ocl-ops.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+
+spv.module Physical64 OpenCL requires #spv.vce<v1.0, [Kernel, Addresses], []> {
+ spv.func @float_insts(%val : f32) "None" {
+ // CHECK: {{%.*}} = spv.OCL.exp {{%.*}} : f32
+ %exp = spv.OCL.exp %val : f32
+ // CHECK: {{%.*}} = spv.OCL.fabs {{%.*}} : f32
+ %abs = spv.OCL.fabs %val : f32
+ spv.Return
+ }
+
+ spv.func @integer_insts(%val : i32) "None" {
+ // CHECK: {{%.*}} = spv.OCL.s_abs {{%.*}} : i32
+ %abs = spv.OCL.s_abs %val : i32
+ spv.Return
+ }
+}
diff --git a/mlir/test/Dialect/SPIRV/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/ocl-ops.mlir
new file mode 100644
index 000000000000..e60f74f0a10e
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/ocl-ops.mlir
@@ -0,0 +1,168 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.OCL.exp
+//===----------------------------------------------------------------------===//
+
+func @exp(%x : f32) {
+ // CHECK: spv.OCL.exp {{%.*}} : f32
+ %0 = spv.OCL.exp %x : f32
+ return
+}
+
+func @expvec(%x : vector<3xf16>) {
+ // CHECK: spv.OCL.exp {{%.*}} : vector<3xf16>
+ %0 = spv.OCL.exp %x : vector<3xf16>
+ return
+}
+
+// -----
+
+func @exp(%x : i32) {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spv.OCL.exp %x : i32
+ return
+}
+
+// -----
+
+func @exp(%x : vector<5xf32>) {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
+ %0 = spv.OCL.exp %x : vector<5xf32>
+ return
+}
+
+// -----
+
+func @exp(%arg0 : f32, %arg1 : f32) -> () {
+ // expected-error @+1 {{expected ':'}}
+ %2 = spv.OCL.exp %arg0, %arg1 : f32
+ return
+}
+
+// -----
+
+func @exp(%arg0 : f32) -> () {
+ // expected-error @+2 {{expected non-function type}}
+ %2 = spv.OCL.exp %arg0 :
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.OCL.fabs
+//===----------------------------------------------------------------------===//
+
+func @fabs(%x : f32) {
+ // CHECK: spv.OCL.fabs {{%.*}} : f32
+ %0 = spv.OCL.fabs %x : f32
+ return
+}
+
+func @fabsvec(%x : vector<3xf16>) {
+ // CHECK: spv.OCL.fabs {{%.*}} : vector<3xf16>
+ %0 = spv.OCL.fabs %x : vector<3xf16>
+ return
+}
+
+func @fabsf64(%x : f64) {
+ // CHECK: spv.OCL.fabs {{%.*}} : f64
+ %0 = spv.OCL.fabs %x : f64
+ return
+}
+
+// -----
+
+func @fabs(%x : i32) {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spv.OCL.fabs %x : i32
+ return
+}
+
+// -----
+
+func @fabs(%x : vector<5xf32>) {
+ // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
+ %0 = spv.OCL.fabs %x : vector<5xf32>
+ return
+}
+
+// -----
+
+func @fabs(%arg0 : f32, %arg1 : f32) -> () {
+ // expected-error @+1 {{expected ':'}}
+ %2 = spv.OCL.fabs %arg0, %arg1 : f32
+ return
+}
+
+// -----
+
+func @fabs(%arg0 : f32) -> () {
+ // expected-error @+2 {{expected non-function type}}
+ %2 = spv.OCL.fabs %arg0 :
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.OCL.s_abs
+//===----------------------------------------------------------------------===//
+
+func @sabs(%x : i32) {
+ // CHECK: spv.OCL.s_abs {{%.*}} : i32
+ %0 = spv.OCL.s_abs %x : i32
+ return
+}
+
+func @sabsvec(%x : vector<3xi16>) {
+ // CHECK: spv.OCL.s_abs {{%.*}} : vector<3xi16>
+ %0 = spv.OCL.s_abs %x : vector<3xi16>
+ return
+}
+
+func @sabsi64(%x : i64) {
+ // CHECK: spv.OCL.s_abs {{%.*}} : i64
+ %0 = spv.OCL.s_abs %x : i64
+ return
+}
+
+func @sabsi8(%x : i8) {
+ // CHECK: spv.OCL.s_abs {{%.*}} : i8
+ %0 = spv.OCL.s_abs %x : i8
+ return
+}
+
+// -----
+
+func @sabs(%x : f32) {
+ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+ %0 = spv.OCL.s_abs %x : f32
+ return
+}
+
+// -----
+
+func @sabs(%x : vector<5xi32>) {
+ // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}}
+ %0 = spv.OCL.s_abs %x : vector<5xi32>
+ return
+}
+
+// -----
+
+func @sabs(%x : i32, %y : i32) {
+ // expected-error @+1 {{expected ':'}}
+ %0 = spv.OCL.s_abs %x, %y : i32
+ return
+}
+
+// -----
+
+func @sabs(%x : i32) {
+ // expected-error @+2 {{expected non-function type}}
+ %0 = spv.OCL.s_abs %x :
+ return
+}
+
More information about the Mlir-commits
mailing list