[llvm-branch-commits] [mlir] f60e0a9 - [MLIR][SPIRV] Add `UnsignedOp` trait.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 6 06:33:34 PST 2021


Author: KareemErgawy-TomTom
Date: 2021-01-06T15:28:41+01:00
New Revision: f60e0a91fbdd8e3409f5ee883a05a6c77f70720c

URL: https://github.com/llvm/llvm-project/commit/f60e0a91fbdd8e3409f5ee883a05a6c77f70720c
DIFF: https://github.com/llvm/llvm-project/commit/f60e0a91fbdd8e3409f5ee883a05a6c77f70720c.diff

LOG: [MLIR][SPIRV] Add `UnsignedOp` trait.

This commit adds a new trait that can be attached to ops that have
unsigned semantics.

TODO:
- Check whether other places in the code can use the new trait (possibly in this patch).
- Add a similar `SignedOp` trait (in a separate patch).

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D94068

Added: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 0d6dd015b7e3..609f5105e11b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -514,7 +514,7 @@ def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem", SPV_Integer, []> {
 
 // -----
 
-def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, []> {
+def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, [UnsignedOp]> {
   let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -546,7 +546,7 @@ def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, []> {
 
 // -----
 
-def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer> {
+def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer, [UnsignedOp]> {
   let summary = "Unsigned modulo operation of Operand 1 modulo Operand 2.";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
index 1c9dbd758857..289e9a23bb35 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
@@ -438,7 +438,7 @@ def SPV_AtomicSMinOp : SPV_AtomicUpdateWithValueOp<"AtomicSMin", []> {
 
 // -----
 
-def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", []> {
+def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", [UnsignedOp]> {
   let summary = [{
     Perform the following steps atomically with respect to any other atomic
     accesses within Scope to the same location:
@@ -480,7 +480,7 @@ def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", []> {
 
 // -----
 
-def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", []> {
+def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", [UnsignedOp]> {
   let summary = [{
     Perform the following steps atomically with respect to any other atomic
     accesses within Scope to the same location:

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2ed11015b960..a9603adc3df0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3115,6 +3115,8 @@ def InModuleScope : PredOpTrait<
   "op must appear in a module-like op's block",
   CPred<"isDirectInModuleLikeOp($_op.getParentOp())">>;
 
+def UnsignedOp : NativeOpTrait<"spirv::UnsignedOp">;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V opcode specification
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 3df9798c4d81..173a031ebaee 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -232,7 +232,8 @@ def SPV_BitFieldSExtractOp : SPV_BitFieldExtractOp<"BitFieldSExtract", []> {
 
 // -----
 
-def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract", []> {
+def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract",
+                                                   [UnsignedOp]> {
   let summary = "Extract a bit field from an object, without sign extension.";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index 726f79f94fa1..20d4afd87035 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -196,7 +196,10 @@ def SPV_ConvertSToFOp : SPV_CastOp<"ConvertSToF", SPV_Float, SPV_Integer, []> {
 
 // -----
 
-def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF", SPV_Float, SPV_Integer, []> {
+def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF",
+                                   SPV_Float,
+                                   SPV_Integer,
+                                   [UnsignedOp]> {
   let summary = [{
     Convert value numerically from unsigned integer to floating point.
   }];
@@ -298,7 +301,10 @@ def SPV_SConvertOp : SPV_CastOp<"SConvert", SPV_Integer, SPV_Integer, []> {
 
 // -----
 
-def SPV_UConvertOp : SPV_CastOp<"UConvert", SPV_Integer, SPV_Integer, []> {
+def SPV_UConvertOp : SPV_CastOp<"UConvert",
+                                SPV_Integer,
+                                SPV_Integer,
+                                [UnsignedOp]> {
   let summary = [{
     Convert unsigned width. This is either a truncate or a zero extend.
   }];

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 71254c17fbc3..46ff3d9e2b61 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -869,7 +869,9 @@ def SPV_SelectOp : SPV_Op<"Select",
 
 // -----
 
-def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> {
+def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan",
+                                             SPV_Integer,
+                                             [UnsignedOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is greater than  Operand 2.
   }];
@@ -902,7 +904,9 @@ def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> {
 
 // -----
 
-def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integer, []> {
+def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual",
+                                                  SPV_Integer,
+                                                  [UnsignedOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is greater than or equal to
     Operand 2.
@@ -936,7 +940,9 @@ def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integ
 
 // -----
 
-def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> {
+def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan",
+                                          SPV_Integer,
+                                          [UnsignedOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is less than Operand 2.
   }];
@@ -970,7 +976,7 @@ def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> {
 // -----
 
 def SPV_ULessThanEqualOp :
-  SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, []> {
+  SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, [UnsignedOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is less than or equal to
     Operand 2.

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 9a22a930b62d..89eaf1d95e28 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -631,7 +631,9 @@ def SPV_GroupNonUniformSMinOp :
 // -----
 
 def SPV_GroupNonUniformUMaxOp :
-    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMax", SPV_Integer, []> {
+    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMax",
+                                    SPV_Integer,
+                                    [UnsignedOp]> {
   let summary = [{
     An unsigned integer maximum group operation of all Value operands
     contributed by active invocations in the group.
@@ -681,7 +683,9 @@ def SPV_GroupNonUniformUMaxOp :
 // -----
 
 def SPV_GroupNonUniformUMinOp :
-    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMin", SPV_Integer, []> {
+    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMin",
+                                    SPV_Integer,
+                                    [UnsignedOp]> {
   let summary = [{
     An unsigned integer minimum group operation of all Value operands
     contributed by active invocations in the group.

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
new file mode 100644
index 000000000000..3e6792373f58
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
@@ -0,0 +1,33 @@
+//===- SPIRVOpTraits.h - MLIR SPIR-V operation traits -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares C++ classes for some of the operation traits in the
+// SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_
+#define MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace spirv {
+
+/// Marker trait for SPIR-V ops whose semantics treat integer operands as
+/// unsigned (e.g. UDivOp, ULessThanOp). It adds no verification logic; clients
+/// query it with `OpTy::hasTrait<OpTrait::spirv::UnsignedOp>()`.
+template <typename ConcreteType>
+class UnsignedOp : public TraitBase<ConcreteType, UnsignedOp> {};
+
+} // namespace spirv
+} // namespace OpTrait
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
index 5e9a46d039d7..2de2bc0b4bcf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
 #define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 164ed36095ac..88d0a818b230 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -187,36 +187,6 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                                                    offset);
 }
 
-/// Returns true if the operator is operating on unsigned integers.
-/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information
-/// to the ops themselves.
-template <typename SPIRVOp>
-bool isUnsignedOp() {
-  return false;
-}
-
-#define CHECK_UNSIGNED_OP(SPIRVOp)                                             \
-  template <>                                                                  \
-  bool isUnsignedOp<SPIRVOp>() {                                               \
-    return true;                                                               \
-  }
-
-CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp)
-CHECK_UNSIGNED_OP(spirv::AtomicUMinOp)
-CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp)
-CHECK_UNSIGNED_OP(spirv::ConvertUToFOp)
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp)
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp)
-CHECK_UNSIGNED_OP(spirv::UConvertOp)
-CHECK_UNSIGNED_OP(spirv::UDivOp)
-CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp)
-CHECK_UNSIGNED_OP(spirv::UGreaterThanOp)
-CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp)
-CHECK_UNSIGNED_OP(spirv::ULessThanOp)
-CHECK_UNSIGNED_OP(spirv::UModOp)
-
-#undef CHECK_UNSIGNED_OP
-
 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
 static bool isAllocationSupported(MemRefType t) {
   // Currently only support workgroup local memory allocations with static
@@ -334,7 +304,8 @@ class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
     auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
-    if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
+    if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
+        dstType != operation.getType()) {
       return operation.emitError(
           "bitwidth emulation is not implemented yet on unsigned op");
     }
@@ -799,7 +770,7 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
   switch (cmpIOp.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
-    if (isUnsignedOp<spirvOp>() &&                                             \
+    if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
         operandType != this->typeConverter.convertType(operandType)) {         \
       return cmpIOp.emitError(                                                 \
           "bitwidth emulation is not implemented yet on unsigned op");         \


        


More information about the llvm-branch-commits mailing list