[Mlir-commits] [mlir] 1cb13fd - [mlir] spirv: Add some atomic ops
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 3 04:52:03 PDT 2021
Author: Butygin
Date: 2021-11-03T14:47:12+03:00
New Revision: 1cb13fddb9d87177dc9f543027cb573de45a94bc
URL: https://github.com/llvm/llvm-project/commit/1cb13fddb9d87177dc9f543027cb573de45a94bc
DIFF: https://github.com/llvm/llvm-project/commit/1cb13fddb9d87177dc9f543027cb573de45a94bc.diff
LOG: [mlir] spirv: Add some atomic ops
Differential Revision: https://reviews.llvm.org/D112812
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
mlir/test/Target/SPIRV/atomic-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
index 122987a212d0e..406f6647a13d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
@@ -102,6 +102,77 @@ def SPV_AtomicAndOp : SPV_AtomicUpdateWithValueOp<"AtomicAnd", []> {
// -----
+def SPV_AtomicCompareExchangeOp : SPV_Op<"AtomicCompareExchange", []> {
+  let summary = [{
+    Perform the following steps atomically with respect to any other atomic
+    accesses within Scope to the same location:
+  }];
+
+  let description = [{
+    1) load through Pointer to get an Original Value,
+
+    2) get a New Value from Value only if Original Value equals Comparator,
+    and
+
+    3) store the New Value back through Pointer 'only if' Original Value
+    equaled Comparator.
+
+    The instruction's result is the Original Value.
+
+    Result Type must be an integer type scalar.
+
+    Use Equal for the memory semantics of this instruction when Value and
+    Original Value compare equal.
+
+    Use Unequal for the memory semantics of this instruction when Value and
+    Original Value compare unequal. Unequal must not be set to Release or
+    Acquire and Release. In addition, Unequal cannot be set to a stronger
+    memory-order than Equal.
+
+    The type of Value must be the same as Result Type. The type of the
+    value pointed to by Pointer must be the same as Result Type. This type
+    must also match the type of Comparator.
+
+    Memory is a memory Scope.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    atomic-compare-exchange-op ::=
+      `spv.AtomicCompareExchange` scope memory-semantics memory-semantics
+                                  ssa-use `,` ssa-use `,` ssa-use
+                                  `:` spv-pointer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.AtomicCompareExchange "Workgroup" "Acquire" "None"
+                                   %pointer, %value, %comparator
+                                   : !spv.ptr<i32, Workgroup>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$pointer,
+    SPV_ScopeAttr:$memory_scope,
+    SPV_MemorySemanticsAttr:$equal_semantics,
+    SPV_MemorySemanticsAttr:$unequal_semantics,
+    SPV_Integer:$value,
+    SPV_Integer:$comparator
+  );
+
+  let results = (outs
+    SPV_Integer:$result
+  );
+
+  let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }];
+  let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }];
+  let verifier = [{ return ::verifyAtomicCompareExchangeImpl(*this); }];
+}
+
+// -----
+
def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
let summary = "Deprecated (use OpAtomicCompareExchange).";
@@ -147,6 +218,62 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
let results = (outs
SPV_Integer:$result
);
+
+ let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }];
+ let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }];
+ let verifier = [{ return ::verifyAtomicCompareExchangeImpl(*this); }];
+}
+
+// -----
+
+def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
+  let summary = [{
+    Perform the following steps atomically with respect to any other atomic
+    accesses within Scope to the same location:
+  }];
+
+  let description = [{
+    1) load through Pointer to get an Original Value,
+
+    2) get a New Value from copying Value, and
+
+    3) store the New Value back through Pointer.
+
+    The instruction's result is the Original Value.
+
+    Result Type must be a scalar of integer type or floating-point type.
+
+    The type of Value must be the same as Result Type. The type of the
+    value pointed to by Pointer must be the same as Result Type.
+
+    Memory is a memory Scope.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    atomic-exchange-op ::=
+      `spv.AtomicExchange` scope memory-semantics
+                           ssa-use `,` ssa-use `:` spv-pointer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.AtomicExchange "Workgroup" "Acquire" %pointer, %value
+                            : !spv.ptr<i32, Workgroup>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$pointer,
+    SPV_ScopeAttr:$memory_scope,
+    SPV_MemorySemanticsAttr:$semantics,
+    SPV_Numerical:$value
+  );
+
+  let results = (outs
+    SPV_Numerical:$result
+  );
}
// -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1bbe01a5e8a1b..3b54c80ba0da6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3352,6 +3352,8 @@ def SPV_OC_OpBitReverse : I32EnumAttrCase<"OpBitReverse", 204>;
def SPV_OC_OpBitCount : I32EnumAttrCase<"OpBitCount", 205>;
def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>;
def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>;
+def SPV_OC_OpAtomicExchange : I32EnumAttrCase<"OpAtomicExchange", 229>;
+def SPV_OC_OpAtomicCompareExchange : I32EnumAttrCase<"OpAtomicCompareExchange", 230>;
def SPV_OC_OpAtomicCompareExchangeWeak : I32EnumAttrCase<"OpAtomicCompareExchangeWeak", 231>;
def SPV_OC_OpAtomicIIncrement : I32EnumAttrCase<"OpAtomicIIncrement", 232>;
def SPV_OC_OpAtomicIDecrement : I32EnumAttrCase<"OpAtomicIDecrement", 233>;
@@ -3442,6 +3444,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
+ SPV_OC_OpAtomicExchange, SPV_OC_OpAtomicCompareExchange,
SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 14d58ef107684..554248f9c5c19 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1138,12 +1138,16 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
return success();
}
-//===----------------------------------------------------------------------===//
-// spv.AtomicCompareExchangeWeak
-//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// spv.AtomicCompareExchange and spv.AtomicCompareExchangeWeak
+//===----------------------------------------------------------------------===//
+
+/// Prints the custom form shared by both compare-exchange ops:
+///   "<scope>" "<equal-semantics>" "<unequal-semantics>" <operands> : <ptr-type>
+template <typename T>
+static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
+  auto scopeStr = stringifyScope(atomOp.memory_scope());
+  auto equalStr = stringifyMemorySemantics(atomOp.equal_semantics());
+  auto unequalStr = stringifyMemorySemantics(atomOp.unequal_semantics());
+
+  printer << " \"" << scopeStr << "\" \"" << equalStr << "\" \"" << unequalStr
+          << "\" " << atomOp.getOperands() << " : "
+          << atomOp.pointer().getType();
+}
-static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
- OperationState &state) {
+static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
+ OperationState &state) {
spirv::Scope memoryScope;
spirv::MemorySemantics equalSemantics, unequalSemantics;
SmallVector<OpAsmParser::OperandType, 3> operandInfo;
@@ -1173,15 +1177,8 @@ static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
return parser.addTypeToList(ptrType.getPointeeType(), state.types);
}
-static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
- OpAsmPrinter &printer) {
- printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
- << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
- << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
- << atomOp.getOperands() << " : " << atomOp.pointer().getType();
-}
-
-static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
+template <typename T>
+static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
// According to the spec:
// "The type of Value must be the same as Result Type. The type of the value
// pointed to by Pointer must be the same as Result Type. This type must also
@@ -1197,8 +1194,10 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
"result, but found ")
<< atomOp.comparator().getType() << " vs " << atomOp.getType();
- Type pointeeType =
- atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
+ Type pointeeType = atomOp.pointer()
+ .getType()
+ .template cast<spirv::PointerType>()
+ .getPointeeType();
if (atomOp.getType() != pointeeType)
return atomOp.emitOpError(
"pointer operand's pointee type must have the same "
@@ -1211,6 +1210,59 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
return success();
}
+//===----------------------------------------------------------------------===//
+// spv.AtomicExchange
+//===----------------------------------------------------------------------===//
+
+/// Prints `"<scope>" "<semantics>" <pointer>, <value> : <pointer-type>`.
+static void print(spirv::AtomicExchangeOp atomOp, OpAsmPrinter &printer) {
+  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\"";
+  printer << " \"" << stringifyMemorySemantics(atomOp.semantics()) << "\" ";
+  printer << atomOp.getOperands() << " : " << atomOp.pointer().getType();
+}
+
+/// Parses `"<scope>" "<semantics>" <pointer>, <value> : <pointer-type>`.
+/// The value operand and the single result both take the pointee type.
+static ParseResult parseAtomicExchangeOp(OpAsmParser &parser,
+                                         OperationState &state) {
+  spirv::Scope scope;
+  spirv::MemorySemantics memSemantics;
+  SmallVector<OpAsmParser::OperandType, 2> operands;
+
+  // Two leading string-enum attributes, then exactly two operands.
+  if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName))
+    return failure();
+  if (parseEnumStrAttr(memSemantics, parser, state, kSemanticsAttrName))
+    return failure();
+  if (parser.parseOperandList(operands, 2))
+    return failure();
+
+  // Location used when reporting a non-pointer trailing type.
+  auto typeLoc = parser.getCurrentLocation();
+  Type parsedType;
+  if (parser.parseColonType(parsedType))
+    return failure();
+
+  auto ptrType = parsedType.dyn_cast<spirv::PointerType>();
+  if (!ptrType)
+    return parser.emitError(typeLoc, "expected pointer type");
+
+  Type elementType = ptrType.getPointeeType();
+  if (parser.resolveOperands(operands, {ptrType, elementType},
+                             parser.getNameLoc(), state.operands))
+    return failure();
+  return parser.addTypeToList(elementType, state.types);
+}
+
+/// Checks that the value operand, the pointee of the pointer operand and the
+/// op result all have the same type.
+static LogicalResult verify(spirv::AtomicExchangeOp atomOp) {
+  Type resultType = atomOp.getType();
+
+  Type valueType = atomOp.value().getType();
+  if (valueType != resultType)
+    return atomOp.emitOpError("value operand must have the same type as the op "
+                              "result, but found ")
+           << valueType << " vs " << resultType;
+
+  auto ptrType = atomOp.pointer().getType().cast<spirv::PointerType>();
+  Type pointeeType = ptrType.getPointeeType();
+  if (pointeeType != resultType)
+    return atomOp.emitOpError(
+               "pointer operand's pointee type must have the same "
+               "as the op result type, but found ")
+           << pointeeType << " vs " << resultType;
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// spv.BitcastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
index 7a10878fad5ef..2bc800025f989 100644
--- a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
@@ -37,6 +37,42 @@ func @atomic_and(%ptr : !spv.ptr<i32, StorageBuffer>, %value : i32) -> i32 {
// -----
+//===----------------------------------------------------------------------===//
+// spv.AtomicCompareExchange
+//===----------------------------------------------------------------------===//
+
+// Parses and prints the custom assembly form of spv.AtomicCompareExchange.
+func @atomic_compare_exchange(%pointer: !spv.ptr<i32, Workgroup>, %new_value: i32, %expected: i32) -> i32 {
+  // CHECK: spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %{{.*}}, %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+  %old = spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %pointer, %new_value, %expected : !spv.ptr<i32, Workgroup>
+  return %old : i32
+}
+
+// -----
+
+// The value operand must have the result type.
+func @atomic_compare_exchange(%pointer: !spv.ptr<i32, Workgroup>, %new_value: i64, %expected: i32) -> i32 {
+  // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
+  %old = "spv.AtomicCompareExchange"(%pointer, %new_value, %expected) {equal_semantics = 4 : i32, memory_scope = 4 : i32, unequal_semantics = 2 : i32} : (!spv.ptr<i32, Workgroup>, i64, i32) -> i32
+  return %old : i32
+}
+
+// -----
+
+// The comparator operand must have the result type.
+func @atomic_compare_exchange(%pointer: !spv.ptr<i32, Workgroup>, %new_value: i32, %expected: i16) -> i32 {
+  // expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}}
+  %old = "spv.AtomicCompareExchange"(%pointer, %new_value, %expected) {equal_semantics = 4 : i32, memory_scope = 4 : i32, unequal_semantics = 2 : i32} : (!spv.ptr<i32, Workgroup>, i32, i16) -> i32
+  return %old : i32
+}
+
+// -----
+
+// The pointer's pointee type must match the result type.
+func @atomic_compare_exchange(%pointer: !spv.ptr<i64, Workgroup>, %new_value: i32, %expected: i32) -> i32 {
+  // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
+  %old = "spv.AtomicCompareExchange"(%pointer, %new_value, %expected) {equal_semantics = 4 : i32, memory_scope = 4 : i32, unequal_semantics = 2 : i32} : (!spv.ptr<i64, Workgroup>, i32, i32) -> i32
+  return %old : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spv.AtomicCompareExchangeWeak
//===----------------------------------------------------------------------===//
@@ -73,6 +109,34 @@ func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i64, Workgroup>, %value: i32,
// -----
+//===----------------------------------------------------------------------===//
+// spv.AtomicExchange
+//===----------------------------------------------------------------------===//
+
+// Parses and prints the custom assembly form of spv.AtomicExchange for an
+// integer element type.
+func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32) -> i32 {
+  // CHECK: spv.AtomicExchange "Workgroup" "Release" %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+  %0 = spv.AtomicExchange "Workgroup" "Release" %ptr, %value: !spv.ptr<i32, Workgroup>
+  return %0: i32
+}
+
+// The op's value/result are SPV_Numerical, so floating-point element types
+// must be accepted as well.
+func @atomic_exchange_float(%ptr: !spv.ptr<f32, Workgroup>, %value: f32) -> f32 {
+  // CHECK: spv.AtomicExchange "Workgroup" "Release" %{{.*}}, %{{.*}} : !spv.ptr<f32, Workgroup>
+  %0 = spv.AtomicExchange "Workgroup" "Release" %ptr, %value: !spv.ptr<f32, Workgroup>
+  return %0: f32
+}
+
+// -----
+
+// The value operand must have the result type.
+func @atomic_exchange(%pointer: !spv.ptr<i32, Workgroup>, %new_value: i64) -> i32 {
+  // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
+  %old = "spv.AtomicExchange"(%pointer, %new_value) {memory_scope = 4 : i32, semantics = 4 : i32} : (!spv.ptr<i32, Workgroup>, i64) -> i32
+  return %old : i32
+}
+
+// -----
+
+// The pointer's pointee type must match the result type.
+func @atomic_exchange(%pointer: !spv.ptr<i64, Workgroup>, %new_value: i32) -> i32 {
+  // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
+  %old = "spv.AtomicExchange"(%pointer, %new_value) {memory_scope = 4 : i32, semantics = 4 : i32} : (!spv.ptr<i64, Workgroup>, i32) -> i32
+  return %old : i32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spv.AtomicIAdd
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/atomic-ops.mlir b/mlir/test/Target/SPIRV/atomic-ops.mlir
index 6bf32af37155b..252d3d5deee23 100644
--- a/mlir/test/Target/SPIRV/atomic-ops.mlir
+++ b/mlir/test/Target/SPIRV/atomic-ops.mlir
@@ -27,6 +27,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%10 = spv.AtomicUMin "Device" "Release" %ptr, %value : !spv.ptr<i32, Workgroup>
// CHECK: spv.AtomicXor "Workgroup" "AcquireRelease" %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
%11 = spv.AtomicXor "Workgroup" "AcquireRelease" %ptr, %value : !spv.ptr<i32, Workgroup>
+ // CHECK: spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %{{.*}}, %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+ %12 = spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr<i32, Workgroup>
+ // CHECK: spv.AtomicExchange "Workgroup" "Release" %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+ %13 = spv.AtomicExchange "Workgroup" "Release" %ptr, %value: !spv.ptr<i32, Workgroup>
spv.ReturnValue %0: i32
}
}
More information about the Mlir-commits
mailing list