[Mlir-commits] [mlir] 78e172f - [mlir][spirv] Support i32 memref.atomic_rmw conversion
Lei Zhang
llvmlistbot at llvm.org
Wed Feb 15 09:54:11 PST 2023
Author: Lei Zhang
Date: 2023-02-15T17:53:56Z
New Revision: 78e172fc92e74be3347409e4a67432c97f071818
URL: https://github.com/llvm/llvm-project/commit/78e172fc92e74be3347409e4a67432c97f071818
DIFF: https://github.com/llvm/llvm-project/commit/78e172fc92e74be3347409e4a67432c97f071818.diff
LOG: [mlir][spirv] Support i32 memref.atomic_rmw conversion
These cases map directly to SPIR-V integer atomic ops (spirv.AtomicIAdd, spirv.AtomicSMax/UMax/SMin/UMin, spirv.AtomicOr, and spirv.AtomicAnd).
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D143952
Added:
mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 5a37806034018..35aa87e11a751 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -10,11 +10,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -182,6 +183,17 @@ class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Rewrites memref.atomic_rmw into the equivalent SPIR-V atomic instruction.
+/// Only integer kinds with a direct SPIR-V counterpart are handled; every other
+/// case is reported as a match failure by matchAndRewrite.
+class AtomicRMWOpPattern final
+    : public OpConversionPattern<memref::AtomicRMWOp> {
+public:
+  // Inherit the (type converter, context) constructors of the base pattern.
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
/// Removed a deallocation if it is a supported allocation. Currently only
/// removes deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
@@ -303,6 +315,62 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
return success();
}
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+/// Lowers an integer memref.atomic_rmw to the matching SPIR-V atomic op
+/// (spirv.AtomicIAdd / AtomicSMax / AtomicUMax / AtomicSMin / AtomicUMin /
+/// AtomicOr / AtomicAnd) applied to the element pointer computed from the
+/// converted memref and indices.
+///
+/// The scope comes from the memref memory space via getAtomicOpScope and the
+/// memory semantics are always AcquireRelease. Reports a match failure for:
+/// floating-point values, unsupported memory spaces, result types the type
+/// converter cannot convert, element types the converter widens to a larger
+/// storage scalar (including i1), failure to build the element pointer, and
+/// atomic kinds without a direct SPIR-V counterpart.
+LogicalResult
+AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
+                                    OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  if (atomicOp.getType().isa<FloatType>())
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "unimplemented floating-point case");
+
+  auto memrefType = atomicOp.getMemref().getType().cast<MemRefType>();
+  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
+  if (!scope)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "unsupported memref memory space");
+
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  Type resultType = typeConverter.convertType(atomicOp.getType());
+  if (!resultType)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "failed to convert result type");
+
+  // A single SPIR-V atomic is only equivalent when each source element owns
+  // its own storage scalar. If the type converter widened the element (e.g.
+  // i8/i16 emulated with 32-bit words when the matching storage capability is
+  // missing), an atomic on the word would also clobber neighbouring packed
+  // elements. i1 is likewise kept in wider storage, and SPIR-V integer atomics
+  // do not take boolean operands. Such cases need a CAS-loop lowering instead.
+  Type elementType = memrefType.getElementType();
+  if (elementType.isInteger(1) ||
+      resultType.getIntOrFloatBitWidth() >
+          elementType.getIntOrFloatBitWidth())
+    return rewriter.notifyMatchFailure(
+        atomicOp, "unsupported emulated narrow integer element type");
+
+  Location loc = atomicOp.getLoc();
+  Value ptr =
+      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
+                           adaptor.getIndices(), loc, rewriter);
+  if (!ptr)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "failed to compute element pointer");
+
+// Each supported kind maps 1:1 onto a SPIR-V atomic op that takes
+// (result type, pointer, scope, memory semantics, value).
+#define ATOMIC_CASE(kind, spirvOp)                                             \
+  case arith::AtomicRMWKind::kind:                                             \
+    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
+        atomicOp, resultType, ptr, *scope,                                     \
+        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
+    break
+
+  switch (atomicOp.getKind()) {
+    ATOMIC_CASE(addi, AtomicIAddOp);
+    ATOMIC_CASE(maxs, AtomicSMaxOp);
+    ATOMIC_CASE(maxu, AtomicUMaxOp);
+    ATOMIC_CASE(mins, AtomicSMinOp);
+    ATOMIC_CASE(minu, AtomicUMinOp);
+    ATOMIC_CASE(ori, AtomicOrOp);
+    ATOMIC_CASE(andi, AtomicAndOp);
+  default:
+    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
+  }
+
+#undef ATOMIC_CASE
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
@@ -656,9 +724,9 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern,
- IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
- MemorySpaceCastOpPattern, StoreOpPattern>(typeConverter,
- patterns.getContext());
+ // AtomicRMWOpPattern is registered with the same SPIR-V type converter and
+ // context as the other memref patterns listed here.
+ patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+ DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
+ LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern>(
+ typeConverter, patterns.getContext());
}
} // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
new file mode 100644
index 0000000000000..f72e12611a97e
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -split-input-file -convert-memref-to-spirv %s -o - | FileCheck %s
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
+
+// Target: SPIR-V v1.3 with only the Shader capability and no extensions.
+// Each function lowers one integer memref.atomic_rmw kind on an i32 memref and
+// verifies the SPIR-V atomic op it maps to, applied to the spirv.AccessChain
+// pointer of the indexed element. The expected scope follows the memref storage
+// class ("Device" for StorageBuffer, "Workgroup" for Workgroup), and every case
+// uses "AcquireRelease" memory semantics.
+
+// CHECK: func.func @atomic_addi_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_addi_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+ // CHECK: %[[AC:.+]] = spirv.AccessChain
+ // CHECK: %[[ATOMIC:.+]] = spirv.AtomicIAdd "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+ // CHECK: return %[[ATOMIC]]
+ %0 = memref.atomic_rmw "addi" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+ return %0: i32
+}
+
+// CHECK: func.func @atomic_maxs_workgroup
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_maxs_workgroup(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<Workgroup>>, %i0: index, %i1: index, %i2: index) -> i32 {
+ // CHECK: %[[AC:.+]] = spirv.AccessChain
+ // CHECK: %[[ATOMIC:.+]] = spirv.AtomicSMax "Workgroup" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, Workgroup>
+ // CHECK: return %[[ATOMIC]]
+ %0 = memref.atomic_rmw "maxs" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<Workgroup>>) -> i32
+ return %0: i32
+}
+
+// CHECK: func.func @atomic_maxu_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_maxu_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+ // CHECK: %[[AC:.+]] = spirv.AccessChain
+ // CHECK: %[[ATOMIC:.+]] = spirv.AtomicUMax "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+ // CHECK: return %[[ATOMIC]]
+ %0 = memref.atomic_rmw "maxu" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+ return %0: i32
+}
+
+// CHECK: func.func @atomic_mins_workgroup
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_mins_workgroup(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<Workgroup>>, %i0: index, %i1: index, %i2: index) -> i32 {
+ // CHECK: %[[AC:.+]] = spirv.AccessChain
+ // CHECK: %[[ATOMIC:.+]] = spirv.AtomicSMin "Workgroup" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, Workgroup>
+ // CHECK: return %[[ATOMIC]]
+ %0 = memref.atomic_rmw "mins" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<Workgroup>>) -> i32
+ return %0: i32
+}
+
+// CHECK: func.func @atomic_minu_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_minu_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+ // CHECK: %[[AC:.+]] = spirv.AccessChain
+ // CHECK: %[[ATOMIC:.+]] = spirv.AtomicUMin "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+ // CHECK: return %[[ATOMIC]]
+ %0 = memref.atomic_rmw "minu" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+ return %0: i32
+}
+
+// CHECK: func.func @atomic_ori_workgroup
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_ori_workgroup(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<Workgroup>>, %i0: index, %i1: index, %i2: index) -> i32 {
+ // CHECK: %[[AC:.+]] = spirv.AccessChain
+ // CHECK: %[[ATOMIC:.+]] = spirv.AtomicOr "Workgroup" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, Workgroup>
+ // CHECK: return %[[ATOMIC]]
+ %0 = memref.atomic_rmw "ori" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<Workgroup>>) -> i32
+ return %0: i32
+}
+
+// CHECK: func.func @atomic_andi_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_andi_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+ // CHECK: %[[AC:.+]] = spirv.AccessChain
+ // CHECK: %[[ATOMIC:.+]] = spirv.AtomicAnd "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+ // CHECK: return %[[ATOMIC]]
+ %0 = memref.atomic_rmw "andi" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+ return %0: i32
+}
+
+}
+
More information about the Mlir-commits mailing list