[llvm-branch-commits] [mlir] f60e0a9 - [MLIR][SPIRV] Add `UnsignedOp` trait.
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jan 6 06:33:34 PST 2021
Author: KareemErgawy-TomTom
Date: 2021-01-06T15:28:41+01:00
New Revision: f60e0a91fbdd8e3409f5ee883a05a6c77f70720c
URL: https://github.com/llvm/llvm-project/commit/f60e0a91fbdd8e3409f5ee883a05a6c77f70720c
DIFF: https://github.com/llvm/llvm-project/commit/f60e0a91fbdd8e3409f5ee883a05a6c77f70720c.diff
LOG: [MLIR][SPIRV] Add `UnsignedOp` trait.
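This commit adds a new trait that can be attached to ops that have
unsigned semantics, and uses it in StandardToSPIRV.cpp in place of the
hand-maintained `isUnsignedOp<>` template specializations.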
TODO:
- Check whether other places in the code can use the new trait (possibly in this patch).
- Add a similar `SignedOp` trait (in a new patch).
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D94068
Added:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 0d6dd015b7e3..609f5105e11b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -514,7 +514,7 @@ def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem", SPV_Integer, []> {
// -----
-def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, []> {
+def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, [UnsignedOp]> {
let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
let description = [{
@@ -546,7 +546,7 @@ def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, []> {
// -----
-def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer> {
+def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer, [UnsignedOp]> {
let summary = "Unsigned modulo operation of Operand 1 modulo Operand 2.";
let description = [{
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
index 1c9dbd758857..289e9a23bb35 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
@@ -438,7 +438,7 @@ def SPV_AtomicSMinOp : SPV_AtomicUpdateWithValueOp<"AtomicSMin", []> {
// -----
-def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", []> {
+def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", [UnsignedOp]> {
let summary = [{
Perform the following steps atomically with respect to any other atomic
accesses within Scope to the same location:
@@ -480,7 +480,7 @@ def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", []> {
// -----
-def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", []> {
+def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", [UnsignedOp]> {
let summary = [{
Perform the following steps atomically with respect to any other atomic
accesses within Scope to the same location:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2ed11015b960..a9603adc3df0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3115,6 +3115,8 @@ def InModuleScope : PredOpTrait<
"op must appear in a module-like op's block",
CPred<"isDirectInModuleLikeOp($_op.getParentOp())">>;
+def UnsignedOp : NativeOpTrait<"spirv::UnsignedOp">; // Unsigned-int semantics.
+
//===----------------------------------------------------------------------===//
// SPIR-V opcode specification
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 3df9798c4d81..173a031ebaee 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -232,7 +232,8 @@ def SPV_BitFieldSExtractOp : SPV_BitFieldExtractOp<"BitFieldSExtract", []> {
// -----
-def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract", []> {
+def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract",
+ [UnsignedOp]> {
let summary = "Extract a bit field from an object, without sign extension.";
let description = [{
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index 726f79f94fa1..20d4afd87035 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -196,7 +196,10 @@ def SPV_ConvertSToFOp : SPV_CastOp<"ConvertSToF", SPV_Float, SPV_Integer, []> {
// -----
-def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF", SPV_Float, SPV_Integer, []> {
+def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF",
+ SPV_Float,
+ SPV_Integer,
+ [UnsignedOp]> {
let summary = [{
Convert value numerically from unsigned integer to floating point.
}];
@@ -298,7 +301,10 @@ def SPV_SConvertOp : SPV_CastOp<"SConvert", SPV_Integer, SPV_Integer, []> {
// -----
-def SPV_UConvertOp : SPV_CastOp<"UConvert", SPV_Integer, SPV_Integer, []> {
+def SPV_UConvertOp : SPV_CastOp<"UConvert",
+ SPV_Integer,
+ SPV_Integer,
+ [UnsignedOp]> {
let summary = [{
Convert unsigned width. This is either a truncate or a zero extend.
}];
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 71254c17fbc3..46ff3d9e2b61 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -869,7 +869,9 @@ def SPV_SelectOp : SPV_Op<"Select",
// -----
-def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> {
+def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan",
+ SPV_Integer,
+ [UnsignedOp]> {
let summary = [{
Unsigned-integer comparison if Operand 1 is greater than Operand 2.
}];
@@ -902,7 +904,9 @@ def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> {
// -----
-def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integer, []> {
+def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual",
+ SPV_Integer,
+ [UnsignedOp]> {
let summary = [{
Unsigned-integer comparison if Operand 1 is greater than or equal to
Operand 2.
@@ -936,7 +940,9 @@ def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integ
// -----
-def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> {
+def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan",
+ SPV_Integer,
+ [UnsignedOp]> {
let summary = [{
Unsigned-integer comparison if Operand 1 is less than Operand 2.
}];
@@ -970,7 +976,7 @@ def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> {
// -----
def SPV_ULessThanEqualOp :
- SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, []> {
+ SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, [UnsignedOp]> {
let summary = [{
Unsigned-integer comparison if Operand 1 is less than or equal to
Operand 2.
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
index 9a22a930b62d..89eaf1d95e28 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
@@ -631,7 +631,9 @@ def SPV_GroupNonUniformSMinOp :
// -----
def SPV_GroupNonUniformUMaxOp :
- SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMax", SPV_Integer, []> {
+ SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMax",
+ SPV_Integer,
+ [UnsignedOp]> {
let summary = [{
An unsigned integer maximum group operation of all Value operands
contributed by active invocations in the group.
@@ -681,7 +683,9 @@ def SPV_GroupNonUniformUMaxOp :
// -----
def SPV_GroupNonUniformUMinOp :
- SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMin", SPV_Integer, []> {
+ SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMin",
+ SPV_Integer,
+ [UnsignedOp]> {
let summary = [{
An unsigned integer minimum group operation of all Value operands
contributed by active invocations in the group.
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
new file mode 100644
index 000000000000..3e6792373f58
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
@@ -0,0 +1,35 @@
+//===- SPIRVOpTraits.h - MLIR SPIR-V operation traits -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares C++ classes for some of the operation traits in the
+// SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_
+#define MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace spirv {
+
+/// Marker trait for SPIR-V ops with unsigned-integer semantics (e.g. UDiv,
+/// UMod, ULessThan, UConvert). It adds no methods and no verification; code
+/// queries it via `Op::hasTrait<OpTrait::spirv::UnsignedOp>()`. In ODS it is
+/// attached through the `UnsignedOp` NativeOpTrait declared in SPIRVBase.td.
+///
+template <typename ConcreteType>
+class UnsignedOp : public TraitBase<ConcreteType, UnsignedOp> {};
+
+} // namespace spirv
+} // namespace OpTrait
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
index 5e9a46d039d7..2de2bc0b4bcf 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
#define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
+#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 164ed36095ac..88d0a818b230 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -187,36 +187,6 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
offset);
}
-/// Returns true if the operator is operating on unsigned integers.
-/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information
-/// to the ops themselves.
-template <typename SPIRVOp>
-bool isUnsignedOp() {
- return false;
-}
-
-#define CHECK_UNSIGNED_OP(SPIRVOp) \
- template <> \
- bool isUnsignedOp<SPIRVOp>() { \
- return true; \
- }
-
-CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp)
-CHECK_UNSIGNED_OP(spirv::AtomicUMinOp)
-CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp)
-CHECK_UNSIGNED_OP(spirv::ConvertUToFOp)
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp)
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp)
-CHECK_UNSIGNED_OP(spirv::UConvertOp)
-CHECK_UNSIGNED_OP(spirv::UDivOp)
-CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp)
-CHECK_UNSIGNED_OP(spirv::UGreaterThanOp)
-CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp)
-CHECK_UNSIGNED_OP(spirv::ULessThanOp)
-CHECK_UNSIGNED_OP(spirv::UModOp)
-
-#undef CHECK_UNSIGNED_OP
-
/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
static bool isAllocationSupported(MemRefType t) {
// Currently only support workgroup local memory allocations with static
@@ -334,7 +304,8 @@ class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
auto dstType = this->typeConverter.convertType(operation.getType());
if (!dstType)
return failure();
- if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
+ if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
+ dstType != operation.getType()) {
return operation.emitError(
"bitwidth emulation is not implemented yet on unsigned op");
}
@@ -799,7 +770,7 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
- if (isUnsignedOp<spirvOp>() && \
+ if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
operandType != this->typeConverter.convertType(operandType)) { \
return cmpIOp.emitError( \
"bitwidth emulation is not implemented yet on unsigned op"); \
More information about the llvm-branch-commits
mailing list