[Mlir-commits] [mlir] 78e172f - [mlir][spirv] Support i32 memref.atomic_rmw conversion

Lei Zhang llvmlistbot at llvm.org
Wed Feb 15 09:54:11 PST 2023


Author: Lei Zhang
Date: 2023-02-15T17:53:56Z
New Revision: 78e172fc92e74be3347409e4a67432c97f071818

URL: https://github.com/llvm/llvm-project/commit/78e172fc92e74be3347409e4a67432c97f071818
DIFF: https://github.com/llvm/llvm-project/commit/78e172fc92e74be3347409e4a67432c97f071818.diff

LOG: [mlir][spirv] Support i32 memref.atomic_rmw conversion

These cases can be directly mapped to spirv.AtomicI* ops.
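
For illustration, a minimal sketch of the mapping (the 1-D memref shape and SSA names below are made up for brevity; the "Device" scope and "AcquireRelease" semantics mirror the new test file added in this commit, and the exact output depends on the type converter and the memref's storage class):

    // Input: an i32 atomic add on a storage-buffer memref.
    %0 = memref.atomic_rmw "addi" %value, %memref[%i]
        : (i32, memref<8xi32, #spirv.storage_class<StorageBuffer>>) -> i32

    // After -convert-memref-to-spirv, the pattern materializes an element
    // pointer via spirv.AccessChain and replaces the op with the matching
    // SPIR-V atomic, roughly:
    //   %ptr = spirv.AccessChain ...
    //   %0   = spirv.AtomicIAdd "Device" "AcquireRelease" %ptr, %value : !spirv.ptr<i32, StorageBuffer>

Floating-point kinds and memref memory spaces without a known atomic scope are not handled; the pattern bails out with a match failure for those.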

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D143952

Added: 
    mlir/test/Conversion/MemRefToSPIRV/atomic.mlir

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 5a37806034018..35aa87e11a751 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -10,11 +10,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
@@ -182,6 +183,17 @@ class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
+class AtomicRMWOpPattern final
+    : public OpConversionPattern<memref::AtomicRMWOp> {
+public:
+  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Removes a deallocation if it is a supported allocation. Currently only
 /// removes deallocation if the memory space is workgroup memory.
 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
@@ -303,6 +315,62 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
+                                    OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  if (atomicOp.getType().isa<FloatType>())
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "unimplemented floating-point case");
+
+  auto memrefType = atomicOp.getMemref().getType().cast<MemRefType>();
+  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
+  if (!scope)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "unsupported memref memory space");
+
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  Type resultType = typeConverter.convertType(atomicOp.getType());
+  if (!resultType)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "failed to convert result type");
+
+  auto loc = atomicOp.getLoc();
+  Value ptr =
+      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
+                           adaptor.getIndices(), loc, rewriter);
+
+  if (!ptr)
+    return failure();
+
+#define ATOMIC_CASE(kind, spirvOp)                                             \
+  case arith::AtomicRMWKind::kind:                                             \
+    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
+        atomicOp, resultType, ptr, *scope,                                     \
+        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
+    break
+
+  switch (atomicOp.getKind()) {
+    ATOMIC_CASE(addi, AtomicIAddOp);
+    ATOMIC_CASE(maxs, AtomicSMaxOp);
+    ATOMIC_CASE(maxu, AtomicUMaxOp);
+    ATOMIC_CASE(mins, AtomicSMinOp);
+    ATOMIC_CASE(minu, AtomicUMinOp);
+    ATOMIC_CASE(ori, AtomicOrOp);
+    ATOMIC_CASE(andi, AtomicAndOp);
+  default:
+    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
+  }
+
+#undef ATOMIC_CASE
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
@@ -656,9 +724,9 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns.add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern,
-               IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
-               MemorySpaceCastOpPattern, StoreOpPattern>(typeConverter,
-                                                         patterns.getContext());
+  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
+               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern>(
+      typeConverter, patterns.getContext());
 }
 } // namespace mlir

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
new file mode 100644
index 0000000000000..f72e12611a97e
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -split-input-file -convert-memref-to-spirv %s -o - | FileCheck %s
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
+
+//      CHECK: func.func @atomic_addi_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_addi_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicIAdd "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "addi" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+  return %0: i32
+}
+
+//      CHECK: func.func @atomic_maxs_workgroup
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_maxs_workgroup(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<Workgroup>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicSMax "Workgroup" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, Workgroup>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "maxs" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<Workgroup>>) -> i32
+  return %0: i32
+}
+
+//      CHECK: func.func @atomic_maxu_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_maxu_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicUMax "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "maxu" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+  return %0: i32
+}
+
+//      CHECK: func.func @atomic_mins_workgroup
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_mins_workgroup(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<Workgroup>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicSMin "Workgroup" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, Workgroup>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "mins" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<Workgroup>>) -> i32
+  return %0: i32
+}
+
+//      CHECK: func.func @atomic_minu_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_minu_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicUMin "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "minu" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+  return %0: i32
+}
+
+//      CHECK: func.func @atomic_ori_workgroup
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_ori_workgroup(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<Workgroup>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicOr "Workgroup" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, Workgroup>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "ori" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<Workgroup>>) -> i32
+  return %0: i32
+}
+
+//      CHECK: func.func @atomic_andi_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_andi_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicAnd "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "andi" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+  return %0: i32
+}
+
+}
+

More information about the Mlir-commits mailing list