[Mlir-commits] [mlir] 25d882e - [MLIR][SPIRV] Add `UsableInSpecConstantOp` trait.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 8 06:07:52 PST 2021


Author: KareemErgawy-TomTom
Date: 2021-01-08T15:07:40+01:00
New Revision: 25d882e758cc9af5de12d8a118cb6eecad14d316

URL: https://github.com/llvm/llvm-project/commit/25d882e758cc9af5de12d8a118cb6eecad14d316
DIFF: https://github.com/llvm/llvm-project/commit/25d882e758cc9af5de12d8a118cb6eecad14d316.diff

LOG: [MLIR][SPIRV] Add `UsableInSpecConstantOp` trait.

Instead of explicitly checking whether an op is usable
inside a `SpecConstantOperationOp`, this commit adds a new trait to
filter such ops.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D94288

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 609f5105e11b..cb5c7d6a54b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -268,7 +268,9 @@ def SPV_FSubOp : SPV_ArithmeticBinaryOp<"FSub", SPV_Float, []> {
 
 // -----
 
-def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd", SPV_Integer, [Commutative]> {
+def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd",
+                                        SPV_Integer,
+                                        [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer addition of Operand 1 and Operand 2.";
 
   let description = [{
@@ -306,7 +308,9 @@ def SPV_IAddOp : SPV_ArithmeticBinaryOp<"IAdd", SPV_Integer, [Commutative]> {
 
 // -----
 
-def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul", SPV_Integer, [Commutative]> {
+def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul",
+                                        SPV_Integer,
+                                        [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -344,7 +348,9 @@ def SPV_IMulOp : SPV_ArithmeticBinaryOp<"IMul", SPV_Integer, [Commutative]> {
 
 // -----
 
-def SPV_ISubOp : SPV_ArithmeticBinaryOp<"ISub", SPV_Integer, []> {
+def SPV_ISubOp : SPV_ArithmeticBinaryOp<"ISub",
+                                        SPV_Integer,
+                                        [UsableInSpecConstantOp]> {
   let summary = "Integer subtraction of Operand 2 from Operand 1.";
 
   let description = [{
@@ -382,7 +388,9 @@ def SPV_ISubOp : SPV_ArithmeticBinaryOp<"ISub", SPV_Integer, []> {
 
 // -----
 
-def SPV_SDivOp : SPV_ArithmeticBinaryOp<"SDiv", SPV_Integer, []> {
+def SPV_SDivOp : SPV_ArithmeticBinaryOp<"SDiv",
+                                        SPV_Integer,
+                                        [UsableInSpecConstantOp]> {
   let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -415,7 +423,9 @@ def SPV_SDivOp : SPV_ArithmeticBinaryOp<"SDiv", SPV_Integer, []> {
 
 // -----
 
-def SPV_SModOp : SPV_ArithmeticBinaryOp<"SMod", SPV_Integer, []> {
+def SPV_SModOp : SPV_ArithmeticBinaryOp<"SMod",
+                                        SPV_Integer,
+                                        [UsableInSpecConstantOp]> {
   let summary = [{
     Signed remainder operation for the remainder whose sign matches the sign
     of Operand 2.
@@ -452,7 +462,9 @@ def SPV_SModOp : SPV_ArithmeticBinaryOp<"SMod", SPV_Integer, []> {
 
 // -----
 
-def SPV_SNegateOp : SPV_ArithmeticUnaryOp<"SNegate", SPV_Integer, []> {
+def SPV_SNegateOp : SPV_ArithmeticUnaryOp<"SNegate",
+                                          SPV_Integer,
+                                          [UsableInSpecConstantOp]> {
   let summary = "Signed-integer subtract of Operand from zero.";
 
   let description = [{
@@ -477,7 +489,9 @@ def SPV_SNegateOp : SPV_ArithmeticUnaryOp<"SNegate", SPV_Integer, []> {
 
 // -----
 
-def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem", SPV_Integer, []> {
+def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem",
+                                        SPV_Integer,
+                                        [UsableInSpecConstantOp]> {
   let summary = [{
     Signed remainder operation for the remainder whose sign matches the sign
     of Operand 1.
@@ -514,7 +528,9 @@ def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem", SPV_Integer, []> {
 
 // -----
 
-def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, [UnsignedOp]> {
+def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv",
+                                        SPV_Integer,
+                                        [UnsignedOp, UsableInSpecConstantOp]> {
   let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -546,7 +562,9 @@ def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, [UnsignedOp]> {
 
 // -----
 
-def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer, [UnsignedOp]> {
+def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod",
+                                        SPV_Integer,
+                                        [UnsignedOp, UsableInSpecConstantOp]> {
   let summary = "Unsigned modulo operation of Operand 1 modulo Operand 2.";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index a9603adc3df0..76374ca481fb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3117,6 +3117,8 @@ def InModuleScope : PredOpTrait<
 
 def UnsignedOp : NativeOpTrait<"spirv::UnsignedOp">;
 
+def UsableInSpecConstantOp : NativeOpTrait<"spirv::UsableInSpecConstantOp">;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V opcode specification
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 173a031ebaee..446066ae067e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -309,7 +309,8 @@ def SPV_BitReverseOp : SPV_BitUnaryOp<"BitReverse", []> {
 
 // -----
 
-def SPV_BitwiseAndOp : SPV_BitBinaryOp<"BitwiseAnd", [Commutative]> {
+def SPV_BitwiseAndOp : SPV_BitBinaryOp<"BitwiseAnd",
+                                       [Commutative, UsableInSpecConstantOp]> {
   let summary = [{
     Result is 1 if both Operand 1 and Operand 2 are 1. Result is 0 if either
     Operand 1 or Operand 2 are 0.
@@ -350,7 +351,8 @@ def SPV_BitwiseAndOp : SPV_BitBinaryOp<"BitwiseAnd", [Commutative]> {
 
 // -----
 
-def SPV_BitwiseOrOp : SPV_BitBinaryOp<"BitwiseOr", [Commutative]> {
+def SPV_BitwiseOrOp : SPV_BitBinaryOp<"BitwiseOr",
+                                      [Commutative, UsableInSpecConstantOp]> {
   let summary = [{
     Result is 1 if either Operand 1 or Operand 2 is 1. Result is 0 if both
     Operand 1 and Operand 2 are 0.
@@ -391,7 +393,8 @@ def SPV_BitwiseOrOp : SPV_BitBinaryOp<"BitwiseOr", [Commutative]> {
 
 // -----
 
-def SPV_BitwiseXorOp : SPV_BitBinaryOp<"BitwiseXor", [Commutative]> {
+def SPV_BitwiseXorOp : SPV_BitBinaryOp<"BitwiseXor",
+                                       [Commutative, UsableInSpecConstantOp]> {
   let summary = [{
     Result is 1 if exactly one of Operand 1 or Operand 2 is 1. Result is 0
     if Operand 1 and Operand 2 have the same value.
@@ -432,7 +435,8 @@ def SPV_BitwiseXorOp : SPV_BitBinaryOp<"BitwiseXor", [Commutative]> {
 
 // -----
 
-def SPV_ShiftLeftLogicalOp : SPV_ShiftOp<"ShiftLeftLogical", []> {
+def SPV_ShiftLeftLogicalOp : SPV_ShiftOp<"ShiftLeftLogical",
+                                         [UsableInSpecConstantOp]> {
   let summary = [{
     Shift the bits in Base left by the number of bits specified in Shift.
     The least-significant bits will be zero filled.
@@ -483,7 +487,8 @@ def SPV_ShiftLeftLogicalOp : SPV_ShiftOp<"ShiftLeftLogical", []> {
 
 // -----
 
-def SPV_ShiftRightArithmeticOp : SPV_ShiftOp<"ShiftRightArithmetic", []> {
+def SPV_ShiftRightArithmeticOp : SPV_ShiftOp<"ShiftRightArithmetic",
+                                             [UsableInSpecConstantOp]> {
   let summary = [{
     Shift the bits in Base right by the number of bits specified in Shift.
     The most-significant bits will be filled with the sign bit from Base.
@@ -531,7 +536,8 @@ def SPV_ShiftRightArithmeticOp : SPV_ShiftOp<"ShiftRightArithmetic", []> {
 
 // -----
 
-def SPV_ShiftRightLogicalOp : SPV_ShiftOp<"ShiftRightLogical", []> {
+def SPV_ShiftRightLogicalOp : SPV_ShiftOp<"ShiftRightLogical",
+                                          [UsableInSpecConstantOp]> {
   let summary = [{
     Shift the bits in Base right by the number of bits specified in Shift.
     The most-significant bits will be zero filled.
@@ -580,7 +586,7 @@ def SPV_ShiftRightLogicalOp : SPV_ShiftOp<"ShiftRightLogical", []> {
 
 // -----
 
-def SPV_NotOp : SPV_BitUnaryOp<"Not", []> {
+def SPV_NotOp : SPV_BitUnaryOp<"Not", [UsableInSpecConstantOp]> {
   let summary = "Complement the bits of Operand.";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index 20d4afd87035..931f16425fe2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -232,7 +232,10 @@ def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF",
 
 // -----
 
-def SPV_FConvertOp : SPV_CastOp<"FConvert", SPV_Float, SPV_Float, []> {
+def SPV_FConvertOp : SPV_CastOp<"FConvert",
+                                SPV_Float,
+                                SPV_Float,
+                                [UsableInSpecConstantOp]> {
   let summary = [{
     Convert value numerically from one floating-point width to another
     width.
@@ -267,7 +270,7 @@ def SPV_FConvertOp : SPV_CastOp<"FConvert", SPV_Float, SPV_Float, []> {
 
 // -----
 
-def SPV_SConvertOp : SPV_CastOp<"SConvert", SPV_Integer, SPV_Integer, []> {
+def SPV_SConvertOp : SPV_CastOp<"SConvert", SPV_Integer, SPV_Integer, [UsableInSpecConstantOp]> {
   let summary = [{
     Convert signed width.  This is either a truncate or a sign extend.
   }];
@@ -304,7 +307,7 @@ def SPV_SConvertOp : SPV_CastOp<"SConvert", SPV_Integer, SPV_Integer, []> {
 def SPV_UConvertOp : SPV_CastOp<"UConvert",
                                 SPV_Integer,
                                 SPV_Integer,
-                                [UnsignedOp]> {
+                                [UnsignedOp, UsableInSpecConstantOp]> {
   let summary = [{
     Convert unsigned width. This is either a truncate or a zero extend.
   }];

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index 035d73f1c842..e384dac65acb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -69,7 +69,8 @@ def SPV_CompositeConstructOp : SPV_Op<"CompositeConstruct", [NoSideEffect]> {
 
 // -----
 
-def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
+def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract",
+                                    [NoSideEffect, UsableInSpecConstantOp]> {
   let summary = "Extract a part of a composite object.";
 
   let description = [{
@@ -119,7 +120,8 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
 
 // -----
 
-def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> {
+def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert",
+                                   [NoSideEffect, UsableInSpecConstantOp]> {
   let summary = [{
     Make a copy of a composite object, while modifying one part of it.
   }];

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 46ff3d9e2b61..b4c5662217c8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -441,7 +441,9 @@ def SPV_FUnordNotEqualOp : SPV_LogicalBinaryOp<"FUnordNotEqual", SPV_Float, [Com
 
 // -----
 
-def SPV_IEqualOp : SPV_LogicalBinaryOp<"IEqual", SPV_Integer, [Commutative]> {
+def SPV_IEqualOp : SPV_LogicalBinaryOp<"IEqual",
+                                       SPV_Integer,
+                                       [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer comparison for equality.";
 
   let description = [{
@@ -472,7 +474,9 @@ def SPV_IEqualOp : SPV_LogicalBinaryOp<"IEqual", SPV_Integer, [Commutative]> {
 
 // -----
 
-def SPV_INotEqualOp : SPV_LogicalBinaryOp<"INotEqual", SPV_Integer, [Commutative]> {
+def SPV_INotEqualOp : SPV_LogicalBinaryOp<"INotEqual",
+                                          SPV_Integer,
+                                          [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer comparison for inequality.";
 
   let description = [{
@@ -503,7 +507,10 @@ def SPV_INotEqualOp : SPV_LogicalBinaryOp<"INotEqual", SPV_Integer, [Commutative
 
 // -----
 
-def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative]> {
+def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd",
+                                           SPV_Bool,
+                                           [Commutative,
+                                            UsableInSpecConstantOp]> {
   let summary = [{
     Result is true if both Operand 1 and Operand 2 are true. Result is false
     if either Operand 1 or Operand 2 are false.
@@ -538,7 +545,10 @@ def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative]
 
 // -----
 
-def SPV_LogicalEqualOp : SPV_LogicalBinaryOp<"LogicalEqual", SPV_Bool, [Commutative]> {
+def SPV_LogicalEqualOp : SPV_LogicalBinaryOp<"LogicalEqual",
+                                             SPV_Bool,
+                                             [Commutative,
+                                              UsableInSpecConstantOp]> {
   let summary = [{
     Result is true if Operand 1 and Operand 2 have the same value. Result is
     false if Operand 1 and Operand 2 have different values.
@@ -571,7 +581,9 @@ def SPV_LogicalEqualOp : SPV_LogicalBinaryOp<"LogicalEqual", SPV_Bool, [Commutat
 
 // -----
 
-def SPV_LogicalNotOp : SPV_LogicalUnaryOp<"LogicalNot", SPV_Bool, []> {
+def SPV_LogicalNotOp : SPV_LogicalUnaryOp<"LogicalNot",
+                                          SPV_Bool,
+                                          [UsableInSpecConstantOp]> {
   let summary = [{
     Result is true if Operand is false.  Result is false if Operand is true.
   }];
@@ -602,7 +614,10 @@ def SPV_LogicalNotOp : SPV_LogicalUnaryOp<"LogicalNot", SPV_Bool, []> {
 
 // -----
 
-def SPV_LogicalNotEqualOp : SPV_LogicalBinaryOp<"LogicalNotEqual", SPV_Bool, [Commutative]> {
+def SPV_LogicalNotEqualOp : SPV_LogicalBinaryOp<"LogicalNotEqual",
+                                                SPV_Bool,
+                                                [Commutative,
+                                                 UsableInSpecConstantOp]> {
   let summary = [{
     Result is true if Operand 1 and Operand 2 have different values. Result
     is false if Operand 1 and Operand 2 have the same value.
@@ -635,7 +650,10 @@ def SPV_LogicalNotEqualOp : SPV_LogicalBinaryOp<"LogicalNotEqual", SPV_Bool, [Co
 
 // -----
 
-def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr", SPV_Bool, [Commutative]> {
+def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr",
+                                          SPV_Bool,
+                                          [Commutative,
+                                           UsableInSpecConstantOp]> {
   let summary = [{
     Result is true if either Operand 1 or Operand 2 is true. Result is false
     if both Operand 1 and Operand 2 are false.
@@ -670,7 +688,9 @@ def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr", SPV_Bool, [Commutative]>
 
 // -----
 
-def SPV_SGreaterThanOp : SPV_LogicalBinaryOp<"SGreaterThan", SPV_Integer, []> {
+def SPV_SGreaterThanOp : SPV_LogicalBinaryOp<"SGreaterThan",
+                                             SPV_Integer,
+                                             [UsableInSpecConstantOp]> {
   let summary = [{
     Signed-integer comparison if Operand 1 is greater than  Operand 2.
   }];
@@ -703,7 +723,9 @@ def SPV_SGreaterThanOp : SPV_LogicalBinaryOp<"SGreaterThan", SPV_Integer, []> {
 
 // -----
 
-def SPV_SGreaterThanEqualOp : SPV_LogicalBinaryOp<"SGreaterThanEqual", SPV_Integer, []> {
+def SPV_SGreaterThanEqualOp : SPV_LogicalBinaryOp<"SGreaterThanEqual",
+                                                  SPV_Integer,
+                                                  [UsableInSpecConstantOp]> {
   let summary = [{
     Signed-integer comparison if Operand 1 is greater than or equal to
     Operand 2.
@@ -737,7 +759,9 @@ def SPV_SGreaterThanEqualOp : SPV_LogicalBinaryOp<"SGreaterThanEqual", SPV_Integ
 
 // -----
 
-def SPV_SLessThanOp : SPV_LogicalBinaryOp<"SLessThan", SPV_Integer, []> {
+def SPV_SLessThanOp : SPV_LogicalBinaryOp<"SLessThan",
+                                          SPV_Integer,
+                                          [UsableInSpecConstantOp]> {
   let summary = [{
     Signed-integer comparison if Operand 1 is less than Operand 2.
   }];
@@ -770,7 +794,9 @@ def SPV_SLessThanOp : SPV_LogicalBinaryOp<"SLessThan", SPV_Integer, []> {
 
 // -----
 
-def SPV_SLessThanEqualOp : SPV_LogicalBinaryOp<"SLessThanEqual", SPV_Integer, []> {
+def SPV_SLessThanEqualOp : SPV_LogicalBinaryOp<"SLessThanEqual",
+                                               SPV_Integer,
+                                               [UsableInSpecConstantOp]> {
   let summary = [{
     Signed-integer comparison if Operand 1 is less than or equal to Operand
     2.
@@ -805,7 +831,9 @@ def SPV_SLessThanEqualOp : SPV_LogicalBinaryOp<"SLessThanEqual", SPV_Integer, []
 // -----
 
 def SPV_SelectOp : SPV_Op<"Select",
-    [NoSideEffect, AllTypesMatch<["true_value", "false_value", "result"]>]> {
+    [NoSideEffect,
+     AllTypesMatch<["true_value", "false_value", "result"]>,
+     UsableInSpecConstantOp]> {
   let summary = [{
     Select between two objects. Before version 1.4, results are only
     computed per component.
@@ -871,7 +899,8 @@ def SPV_SelectOp : SPV_Op<"Select",
 
 def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan",
                                              SPV_Integer,
-                                             [UnsignedOp]> {
+                                             [UnsignedOp,
+                                              UsableInSpecConstantOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is greater than  Operand 2.
   }];
@@ -906,7 +935,8 @@ def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan",
 
 def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual",
                                                   SPV_Integer,
-                                                  [UnsignedOp]> {
+                                                  [UnsignedOp,
+                                                   UsableInSpecConstantOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is greater than or equal to
     Operand 2.
@@ -942,7 +972,7 @@ def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual",
 
 def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan",
                                           SPV_Integer,
-                                          [UnsignedOp]> {
+                                          [UnsignedOp, UsableInSpecConstantOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is less than Operand 2.
   }];
@@ -975,8 +1005,10 @@ def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan",
 
 // -----
 
-def SPV_ULessThanEqualOp :
-  SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, [UnsignedOp]> {
+def SPV_ULessThanEqualOp : SPV_LogicalBinaryOp<"ULessThanEqual",
+                                               SPV_Integer,
+                                               [UnsignedOp,
+                                                UsableInSpecConstantOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is less than or equal to
     Operand 2.

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
index 3e6792373f58..e3e78e51b32a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
@@ -23,6 +23,12 @@ namespace spirv {
 template <typename ConcreteType>
 class UnsignedOp : public TraitBase<ConcreteType, UnsignedOp> {};
 
+/// A trait to mark ops that can be enclosed/wrapped in a
+/// `SpecConstantOperation` op.
+template <typename ConcreteType>
+class UsableInSpecConstantOp
+    : public TraitBase<ConcreteType, UsableInSpecConstantOp> {};
+
 } // namespace spirv
 } // namespace OpTrait
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 6cada4a7b01d..ad3e78f618b7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -15,12 +15,14 @@
 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/FunctionImplementation.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "llvm/ADT/StringExtras.h"
@@ -3439,21 +3441,7 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
 
   Operation &enclosedOp = block.getOperations().front();
 
-  // TODO Add a `UsableInSpecConstantOp` trait and mark ops from the list below
-  // with it instead.
-  if (!isa<spirv::SConvertOp, spirv::UConvertOp, spirv::FConvertOp,
-           spirv::SNegateOp, spirv::NotOp, spirv::IAddOp, spirv::ISubOp,
-           spirv::IMulOp, spirv::UDivOp, spirv::SDivOp, spirv::UModOp,
-           spirv::SRemOp, spirv::SModOp, spirv::ShiftRightLogicalOp,
-           spirv::ShiftRightArithmeticOp, spirv::ShiftLeftLogicalOp,
-           spirv::BitwiseOrOp, spirv::BitwiseXorOp, spirv::BitwiseAndOp,
-           spirv::CompositeExtractOp, spirv::CompositeInsertOp,
-           spirv::LogicalOrOp, spirv::LogicalAndOp, spirv::LogicalNotOp,
-           spirv::LogicalEqualOp, spirv::LogicalNotEqualOp, spirv::SelectOp,
-           spirv::IEqualOp, spirv::INotEqualOp, spirv::ULessThanOp,
-           spirv::SLessThanOp, spirv::UGreaterThanOp, spirv::SGreaterThanOp,
-           spirv::ULessThanEqualOp, spirv::SLessThanEqualOp,
-           spirv::UGreaterThanEqualOp, spirv::SGreaterThanEqualOp>(enclosedOp))
+  if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
     return constOp.emitOpError("invalid enclosed op");
 
   for (auto operand : enclosedOp.getOperands())


        


More information about the Mlir-commits mailing list