[Mlir-commits] [mlir] 1cb13fd - [mlir] spirv: Add some atomic ops

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 3 04:52:03 PDT 2021


Author: Butygin
Date: 2021-11-03T14:47:12+03:00
New Revision: 1cb13fddb9d87177dc9f543027cb573de45a94bc

URL: https://github.com/llvm/llvm-project/commit/1cb13fddb9d87177dc9f543027cb573de45a94bc
DIFF: https://github.com/llvm/llvm-project/commit/1cb13fddb9d87177dc9f543027cb573de45a94bc.diff

LOG: [mlir] spirv: Add some atomic ops

Differential Revision: https://reviews.llvm.org/D112812

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
    mlir/test/Target/SPIRV/atomic-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
index 122987a212d0e..406f6647a13d0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
@@ -102,6 +102,77 @@ def SPV_AtomicAndOp : SPV_AtomicUpdateWithValueOp<"AtomicAnd", []> {
 
 // -----
 
+def SPV_AtomicCompareExchangeOp : SPV_Op<"AtomicCompareExchange", []> {
+  let summary = [{
+    Perform the following steps atomically with respect to any other atomic
+    accesses within Scope to the same location:
+  }];
+
+  let description = [{
+    1) load through Pointer to get an Original Value,
+
+    2) get a New Value from Value only if Original Value equals Comparator,
+    and
+
+    3) store the New Value back through Pointer only if Original Value
+    equaled Comparator.
+
+    The instruction's result is the Original Value.
+
+    Result Type must be an integer type scalar.
+
+    Use Equal for the memory semantics of this instruction when Value and
+    Original Value compare equal.
+
+    Use Unequal for the memory semantics of this instruction when Value and
+    Original Value compare unequal. Unequal must not be set to Release or
+    Acquire and Release. In addition, Unequal cannot be set to a stronger
+    memory-order than Equal.
+
+    The type of Value must be the same as Result Type.  The type of the
+    value pointed to by Pointer must be the same as Result Type.  This type
+    must also match the type of Comparator.
+
+    Memory is a memory Scope.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    atomic-compare-exchange-op ::=
+        `spv.AtomicCompareExchange` scope memory-semantics memory-semantics
+                                    ssa-use `,` ssa-use `,` ssa-use
+                                    `:` spv-pointer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.AtomicCompareExchange "Workgroup" "Acquire" "None"
+                                    %pointer, %value, %comparator
+                                    : !spv.ptr<i32, Workgroup>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$pointer,
+    SPV_ScopeAttr:$memory_scope,
+    SPV_MemorySemanticsAttr:$equal_semantics,
+    SPV_MemorySemanticsAttr:$unequal_semantics,
+    SPV_Integer:$value,
+    SPV_Integer:$comparator
+  );
+
+  let results = (outs
+    SPV_Integer:$result
+  );
+
+  let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }];
+  let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }];
+  let verifier = [{ return ::verifyAtomicCompareExchangeImpl(*this); }];
+}
+
+// -----
+
 def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
   let summary = "Deprecated (use OpAtomicCompareExchange).";
 
@@ -147,6 +218,62 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
   let results = (outs
     SPV_Integer:$result
   );
+
+  let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }];
+  let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }];
+  let verifier = [{ return ::verifyAtomicCompareExchangeImpl(*this); }];
+}
+
+// -----
+
+def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
+  let summary = [{
+    Perform the following steps atomically with respect to any other atomic
+    accesses within Scope to the same location:
+  }];
+
+  let description = [{
+    1) load through Pointer to get an Original Value,
+
+    2) get a New Value from copying Value, and
+
+    3) store the New Value back through Pointer.
+
+    The instruction's result is the Original Value.
+
+    Result Type must be a scalar of integer type or floating-point type.
+
+    The type of Value must be the same as Result Type.  The type of the
+    value pointed to by Pointer must be the same as Result Type.
+
+    Memory is a memory Scope.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    atomic-exchange-op ::=
+        `spv.AtomicExchange` scope memory-semantics
+                             ssa-use `,` ssa-use `:` spv-pointer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.AtomicExchange "Workgroup" "Acquire" %pointer, %value
+                            : !spv.ptr<i32, Workgroup>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$pointer,
+    SPV_ScopeAttr:$memory_scope,
+    SPV_MemorySemanticsAttr:$semantics,
+    SPV_Numerical:$value
+  );
+
+  let results = (outs
+    SPV_Numerical:$result
+  );
 }
 
 // -----

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1bbe01a5e8a1b..3b54c80ba0da6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3352,6 +3352,8 @@ def SPV_OC_OpBitReverse                : I32EnumAttrCase<"OpBitReverse", 204>;
 def SPV_OC_OpBitCount                  : I32EnumAttrCase<"OpBitCount", 205>;
 def SPV_OC_OpControlBarrier            : I32EnumAttrCase<"OpControlBarrier", 224>;
 def SPV_OC_OpMemoryBarrier             : I32EnumAttrCase<"OpMemoryBarrier", 225>;
+def SPV_OC_OpAtomicExchange            : I32EnumAttrCase<"OpAtomicExchange", 229>;
+def SPV_OC_OpAtomicCompareExchange     : I32EnumAttrCase<"OpAtomicCompareExchange", 230>;
 def SPV_OC_OpAtomicCompareExchangeWeak : I32EnumAttrCase<"OpAtomicCompareExchangeWeak", 231>;
 def SPV_OC_OpAtomicIIncrement          : I32EnumAttrCase<"OpAtomicIIncrement", 232>;
 def SPV_OC_OpAtomicIDecrement          : I32EnumAttrCase<"OpAtomicIDecrement", 233>;
@@ -3442,6 +3444,7 @@ def SPV_OpcodeAttr :
       SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert,
       SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse,
       SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
+      SPV_OC_OpAtomicExchange, SPV_OC_OpAtomicCompareExchange,
       SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement,
       SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub,
       SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax,

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 14d58ef107684..554248f9c5c19 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1138,12 +1138,16 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spv.AtomicCompareExchangeWeak
-//===----------------------------------------------------------------------===//
+template <typename T>
+static void printAtomicCompareExchangeImpl(T op, OpAsmPrinter &printer) {
+  printer << " \"" << stringifyScope(op.memory_scope()) << "\" \"";
+  printer << stringifyMemorySemantics(op.equal_semantics()) << "\" \"";
+  printer << stringifyMemorySemantics(op.unequal_semantics()) << "\" ";
+  printer << op.getOperands() << " : " << op.pointer().getType();
+}
 
-static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
-                                                    OperationState &state) {
+static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
+                                                  OperationState &state) {
   spirv::Scope memoryScope;
   spirv::MemorySemantics equalSemantics, unequalSemantics;
   SmallVector<OpAsmParser::OperandType, 3> operandInfo;
@@ -1173,15 +1177,8 @@ static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
   return parser.addTypeToList(ptrType.getPointeeType(), state.types);
 }
 
-static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
-                  OpAsmPrinter &printer) {
-  printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
-          << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
-          << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
-          << atomOp.getOperands() << " : " << atomOp.pointer().getType();
-}
-
-static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
+template <typename T>
+static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
   // According to the spec:
   // "The type of Value must be the same as Result Type. The type of the value
   // pointed to by Pointer must be the same as Result Type. This type must also
@@ -1197,8 +1194,10 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
                "result, but found ")
            << atomOp.comparator().getType() << " vs " << atomOp.getType();
 
-  Type pointeeType =
-      atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
+  Type pointeeType = atomOp.pointer()
+                         .getType()
+                         .template cast<spirv::PointerType>()
+                         .getPointeeType();
   if (atomOp.getType() != pointeeType)
     return atomOp.emitOpError(
                "pointer operand's pointee type must have the same "
@@ -1211,6 +1210,59 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.AtomicExchange
+//===----------------------------------------------------------------------===//
+
+static void print(spirv::AtomicExchangeOp op, OpAsmPrinter &printer) {
+  printer << " \"" << stringifyScope(op.memory_scope()) << "\" \"";
+  printer << stringifyMemorySemantics(op.semantics()) << "\" ";
+  printer << op.getOperands() << " : " << op.pointer().getType();
+}
+
+static ParseResult parseAtomicExchangeOp(OpAsmParser &parser,
+                                         OperationState &state) {
+  // Syntax: "<scope>" "<semantics>" %pointer, %value : <pointer-type>
+  spirv::Scope scope;
+  spirv::MemorySemantics memorySemantics;
+  SmallVector<OpAsmParser::OperandType, 2> operands;
+  if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
+      parseEnumStrAttr(memorySemantics, parser, state, kSemanticsAttrName) ||
+      parser.parseOperandList(operands, 2))
+    return failure();
+
+  auto typeLoc = parser.getCurrentLocation();
+  Type parsedType;
+  if (parser.parseColonType(parsedType))
+    return failure();
+
+  auto pointerType = parsedType.dyn_cast<spirv::PointerType>();
+  if (!pointerType)
+    return parser.emitError(typeLoc, "expected pointer type");
+  Type elementType = pointerType.getPointeeType();
+  if (parser.resolveOperands(operands, {pointerType, elementType},
+                             parser.getNameLoc(), state.operands))
+    return failure();
+  return parser.addTypeToList(elementType, state.types);
+}
+
+static LogicalResult verify(spirv::AtomicExchangeOp atomOp) {
+  Type resultType = atomOp.getType();
+  Type valueType = atomOp.value().getType();
+  if (valueType != resultType)
+    return atomOp.emitOpError("value operand must have the same type as the op "
+                              "result, but found ")
+           << valueType << " vs " << resultType;
+  auto ptrType = atomOp.pointer().getType().cast<spirv::PointerType>();
+  Type pointeeType = ptrType.getPointeeType();
+  if (pointeeType == resultType)
+    return success();
+  return atomOp.emitOpError(
+             "pointer operand's pointee type must have the same "
+             "as the op result type, but found ")
+         << pointeeType << " vs " << resultType;
+}
+
 //===----------------------------------------------------------------------===//
 // spv.BitcastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
index 7a10878fad5ef..2bc800025f989 100644
--- a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
@@ -37,6 +37,42 @@ func @atomic_and(%ptr : !spv.ptr<i32, StorageBuffer>, %value : i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.AtomicCompareExchange
+//===----------------------------------------------------------------------===//
+
+func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 {
+  // CHECK: spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %{{.*}}, %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+  %0 = spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr<i32, Workgroup>
+  return %0: i32
+}
+
+// -----
+
+func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64, %comparator: i32) -> i32 {
+  // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
+  %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
+  return %0: i32
+}
+
+// -----
+
+func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i16) -> i32 {
+  // expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}}
+  %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
+  return %0: i32
+}
+
+// -----
+
+func @atomic_compare_exchange(%ptr: !spv.ptr<i64, Workgroup>, %value: i32, %comparator: i32) -> i32 {
+  // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
+  %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
+  return %0: i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.AtomicCompareExchangeWeak
 //===----------------------------------------------------------------------===//
@@ -73,6 +109,34 @@ func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i64, Workgroup>, %value: i32,
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.AtomicExchange
+//===----------------------------------------------------------------------===//
+
+func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32) -> i32 {
+  // CHECK: spv.AtomicExchange "Workgroup" "Release" %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+  %0 = spv.AtomicExchange "Workgroup" "Release" %ptr, %value: !spv.ptr<i32, Workgroup>
+  return %0: i32
+}
+
+// -----
+
+func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64) -> i32 {
+  // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
+  %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr<i32, Workgroup>, i64) -> (i32)
+  return %0: i32
+}
+
+// -----
+
+func @atomic_exchange(%ptr: !spv.ptr<i64, Workgroup>, %value: i32) -> i32 {
+  // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
+  %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr<i64, Workgroup>, i32) -> (i32)
+  return %0: i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.AtomicIAdd
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Target/SPIRV/atomic-ops.mlir b/mlir/test/Target/SPIRV/atomic-ops.mlir
index 6bf32af37155b..252d3d5deee23 100644
--- a/mlir/test/Target/SPIRV/atomic-ops.mlir
+++ b/mlir/test/Target/SPIRV/atomic-ops.mlir
@@ -27,6 +27,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %10 = spv.AtomicUMin "Device" "Release" %ptr, %value : !spv.ptr<i32, Workgroup>
     // CHECK: spv.AtomicXor "Workgroup" "AcquireRelease" %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
     %11 = spv.AtomicXor "Workgroup" "AcquireRelease" %ptr, %value : !spv.ptr<i32, Workgroup>
+    // CHECK: spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %{{.*}}, %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+    %12 = spv.AtomicCompareExchange "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr<i32, Workgroup>
+    // CHECK: spv.AtomicExchange "Workgroup" "Release" %{{.*}}, %{{.*}} : !spv.ptr<i32, Workgroup>
+    %13 = spv.AtomicExchange "Workgroup" "Release" %ptr, %value: !spv.ptr<i32, Workgroup>
     spv.ReturnValue %0: i32
   }
 }


        


More information about the Mlir-commits mailing list