[Mlir-commits] [mlir] [mlir][amdgpu] Shared memory access optimization pass (PR #75627)
llvmlistbot at llvm.org
Tue Jan 16 16:21:19 PST 2024
https://github.com/erman-gurses updated https://github.com/llvm/llvm-project/pull/75627
From 87689f16e2c06dcebfef2e02aa97946476a1cc0b Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Fri, 15 Dec 2023 09:02:41 -0800
Subject: [PATCH 01/11] [mlir][amdgpu] Shared memory access optimization pass
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 27 ++
.../mlir/Dialect/AMDGPU/Transforms/Passes.h | 4 +
.../mlir/Dialect/AMDGPU/Transforms/Passes.td | 8 +
.../Dialect/AMDGPU/Transforms/Transforms.h | 54 ++++
.../mlir/Dialect/AMDGPU/Transforms/Utils.h | 21 ++
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 15 ++
.../Dialect/AMDGPU/Transforms/CMakeLists.txt | 2 +
.../Transforms/OptimizeSharedMemory.cpp | 252 ++++++++++++++++++
mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp | 48 ++++
.../AMDGPU/optimize_shmem_reads_writes.mlir | 57 ++++
10 files changed, 488 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
create mode 100644 mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
create mode 100644 mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ffb302fcedd732..324c656f47599e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -29,6 +29,33 @@ def AMDGPU_Dialect : Dialect {
"gpu::GPUDialect"
];
let useDefaultAttributePrinterParser = 1;
+
+ let extraClassDeclaration = [{
+ /// Return true if the given MemRefType has an integer address
+ /// space that matches the ROCDL shared memory address space or
+ /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+ static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+ /// Return true if the given Attribute has an integer address
+ /// space that matches the ROCDL shared memory address space or
+ /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
+ static bool isSharedMemoryAddressSpace(Attribute type);
+
+ /// Defines the MemRef memory space attribute numeric value that indicates
+ /// a memref is located in global memory. This should correspond to the
+ /// value used in ROCDL.
+ static constexpr unsigned kGlobalMemoryAddressSpace = 1;
+
+ /// Defines the MemRef memory space attribute numeric value that indicates
+ /// a memref is located in private memory. This should correspond to the
+ /// value used in ROCDL.
+ static constexpr unsigned kPrivateMemoryAddressSpace = 2;
+
+ /// Defines the MemRef memory space attribute numeric value that indicates
+ /// a memref is located in shared memory. This should correspond to the
+ /// value used in ROCDL.
+ static constexpr unsigned kSharedMemoryAddressSpace = 3;
+ }];
}
//===----------------------------------------------------------------------===//
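As a quick illustration of the address-space helpers declared above, here is a minimal standalone sketch (not part of the patch; it assumes an MLIRContext with the AMDGPU and GPU dialects loaded): both a bare integer memory space of 3 and a gpu::AddressSpaceAttr of `workgroup` should be classified as shared memory.

    #include <cassert>
    #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
    #include "mlir/Dialect/GPU/IR/GPUDialect.h"
    #include "mlir/IR/Builders.h"

    void checkSharedSpaces(mlir::MLIRContext &ctx) {
      // Assumes ctx.loadDialect<amdgpu::AMDGPUDialect, gpu::GPUDialect>()
      // has already been called.
      mlir::Builder b(&ctx);
      // Integer memory space 3 matches kSharedMemoryAddressSpace.
      mlir::Attribute intSpace = b.getI64IntegerAttr(3);
      // The dialect-attribute spelling, #gpu.address_space<workgroup>.
      mlir::Attribute wgSpace = mlir::gpu::AddressSpaceAttr::get(
          &ctx, mlir::gpu::AddressSpace::Workgroup);
      assert(mlir::amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(intSpace));
      assert(mlir::amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(wgSpace));
    }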
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 8dd5ff1a4b198a..752078cd6930e3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -21,6 +21,10 @@ class ConversionTarget;
namespace amdgpu {
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+
+/// Create a pass to optimize shared memory reads and writes.
+std::unique_ptr<Pass> createOptimizeSharedMemoryPass();
+
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index e6b27aa842dfcd..1c12ca98271127 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -30,4 +30,12 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
"Chipset that these operations will run on">];
}
+def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
+ let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+ let constructor = "mlir::amdgpu::createOptimizeSharedMemoryPass()";
+ let dependentDialects = [
+ "memref::MemRefDialect", "vector::VectorDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
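For reference, a minimal sketch of driving the new pass from C++ (a hypothetical harness, not part of the patch; it uses the createOptimizeSharedMemoryPass() entry point declared in Passes.h above, which a later patch in this series replaces with the autogenerated constructor):

    #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Pass/PassManager.h"

    mlir::LogicalResult runSharedMemoryOptimization(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      // Nest under func.func, matching the test's --pass-pipeline invocation.
      pm.addNestedPass<mlir::func::FuncOp>(
          mlir::amdgpu::createOptimizeSharedMemoryPass());
      return pm.run(module);
    }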
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
new file mode 100644
index 00000000000000..140bc12deed690
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -0,0 +1,54 @@
+//===- Transforms.h - AMDGPU Dialect transformations ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions that assist transformations for the amdgpu
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class RewriterBase;
+
+namespace amdgpu {
+
+///
+/// Passes
+///
+
+/// Optimizes vectorized accesses to a shared memory buffer specified by
+/// memrefValue. This transformation assumes the following:
+/// 1) All relevant accesses to `memrefValue` are contained within `parentOp`.
+/// 2) The function will fail precondition checks if any subviews are
+/// taken of `memrefValue`. All reads/writes to `memrefValue` should occur
+/// through `memrefValue` directly.
+///
+/// Shared memory bank conflicts occur when multiple threads attempt to read or
+/// write locations assigned to the same shared memory bank. For `2^N` byte
+/// vectorized accesses, we need to be concerned with conflicts among threads
+/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
+/// changes any indexed memory access (vector.load, memref.load, etc.) so
+/// that the final dimension's index value is permuted to
+/// `newColIndex = oldColIndex % vectorSize +
+/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)`, where `rowIndex` is the
+/// index for the second-to-last dimension and `perm[rowIndex]` is a permutation
+/// function that depends on the row index. The permutation function is chosen
+/// to ensure that sequential distributed+vectorized reads/writes down a single
+/// dimension of the memref have minimal conflicts.
+mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+ Value memrefValue);
+
+} // namespace amdgpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
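To make the thread-grouping formula above concrete: with the default 64-bit accesses (8-byte vectors, so N = 3), conflicts arise among threads grouped as tid.floordiv(2^4), i.e. groups of 16 threads. A standalone sketch of that arithmetic (mirroring the kDefaultVectorSizeBits constant defined in the implementation below; not part of the patch):

    #include <cstdint>

    constexpr int64_t kDefaultVectorSizeBits = 64; // matches the pass default
    constexpr int64_t kN = 3;                      // log2(64 bits / 8 bits per byte)
    // For 2^N-byte accesses, threads conflict in groups of tid / 2^(7-N).
    constexpr int64_t kThreadGroupSize = int64_t(1) << (7 - kN);
    static_assert(kThreadGroupSize == 16,
                  "64-bit accesses give conflict groups of 16 threads");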
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
new file mode 100644
index 00000000000000..bee3af1914feef
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -0,0 +1,21 @@
+//===- Utils.h - Transform utilities ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace amdgpu {
+
+/// Get the indices that the given load/store operation is operating on.
+Operation::operand_range getIndices(Operation *op);
+
+/// Set the indices that the given load/store operation is operating on.
+void setIndices(Operation *op, ArrayRef<Value> indices);
+
+} // namespace amdgpu
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2575ad4984814b..4e72fbf56b80a4 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
>();
}
+bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+ if (!memorySpace)
+ return false;
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+ return intAttr.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
+ if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+ return false;
+}
+
+bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+ Attribute memorySpace = type.getMemorySpace();
+ return isSharedMemoryAddressSpace(memorySpace);
+}
+
//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index e11b6cc88bf224..a1a91270bc55c4 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
+ OptimizeSharedMemory.cpp
+ Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
new file mode 100644
index 00000000000000..0a2f04f4e6487f
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -0,0 +1,252 @@
+//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation
+//----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to optimize accesses to shared memory.
+// It is inspired by
+// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace amdgpu {
+#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace amdgpu
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// The size of a shared memory line according to AMD documentation.
+/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
+constexpr int64_t kSharedMemoryLineSizeBytes = 64;
+/// We optimize for 64-bit accesses, but this can be made an argument in the
+/// future.
+constexpr int64_t kDefaultVectorSizeBits = 64;
+
+/// Uses `srcIndexValue` to permute `tgtIndexValue` via
+/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
+/// floordiv(tgtIdxVal, vectorSize))
+/// + tgtIdxVal % vectorSize`
+/// This is done using an optimized sequence of `arith` operations.
+static Value permuteVectorOffset(OpBuilder &b, Location loc,
+ ArrayRef<Value> indices, MemRefType memrefTy,
+ int64_t srcDim, int64_t tgtDim) {
+ // Adjust the src index to change how often the permutation changes
+ // if necessary.
+ Value src = indices[srcDim];
+
+ // We only want to permute every N iterations of the target dim where N is
+ // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+ const int64_t permuteEveryN = std::max<int64_t>(
+ 1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+ memrefTy.getElementTypeBitWidth()) /
+ 8));
+
+ // clang-format off
+ // Index bit representation (b0 = least significant bit) for dim(1)
+ // of a `memref<?x?xDT>` is as follows:
+ // N := log2(kDefaultVectorSizeBits / elementSizeBits)
+ // M := log2(dimSize(1))
+ // then
+ // bits[0:N] = sub-vector element offset
+ // bits[N:M] = vector index
+ // clang-format on
+ int64_t n =
+ llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+ int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
+
+ // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
+ int64_t mask = (1LL << (m - n)) - 1;
+ if (permuteEveryN > 1)
+ mask = mask << llvm::Log2_64(permuteEveryN);
+ Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
+ srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
+
+ // Use the src bits to permute the target bits b[N:M] containing the
+ // vector offset.
+ if (permuteEveryN > 1) {
+ int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
+ if (shlBits > 0) {
+ Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
+ srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+ } else if (shlBits < 0) {
+ Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
+ srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
+ }
+ } else {
+ Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
+ srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+ }
+
+ Value permutedVectorIdx =
+ b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
+ return permutedVectorIdx;
+}
+
+static void transformIndices(OpBuilder &builder, Location loc,
+ SmallVector<Value, 4> &indices,
+ MemRefType memrefTy, int64_t srcDim,
+ int64_t tgtDim) {
+ indices[tgtDim] =
+ permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+}
+
+/// Return all operations within `parentOp` that read from or write to
+/// `shmMemRef`.
+static LogicalResult
+getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
+ SmallVector<Operation *, 16> &readOps,
+ SmallVector<Operation *, 16> &writeOps) {
+ parentOp->walk([&](Operation *op) {
+ MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
+ if (!iface)
+ return;
+ std::optional<MemoryEffects::EffectInstance> effect =
+ iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
+ if (effect) {
+ readOps.push_back(op);
+ return;
+ }
+ effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
+ if (effect)
+ writeOps.push_back(op);
+ });
+
+ // Restrict to a supported set of ops. We also require at least 2D access,
+ // although this could be relaxed.
+ if (llvm::any_of(readOps, [](Operation *op) {
+ return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
+ op) ||
+ amdgpu::getIndices(op).size() < 2;
+ }))
+ return failure();
+ if (llvm::any_of(writeOps, [](Operation *op) {
+ return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
+ op) ||
+ amdgpu::getIndices(op).size() < 2;
+ }))
+ return failure();
+
+ return success();
+}
+
+mlir::LogicalResult
+mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+ Value memrefValue) {
+ auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
+ if (!memRefType ||
+ !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
+ return failure();
+
+ // Abort if the given value has any sub-views; we do not do any alias
+ // analysis.
+ bool hasSubView = false;
+ parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
+ if (hasSubView)
+ return failure();
+
+ // Check if this is necessary given the assumption of 64-bit accesses:
+ // skip if dim[rank-1] is small enough that 8 rows fit in a 64B line.
+ const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+ const int64_t rowsPerLine =
+ (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+ rowSize;
+ const int64_t threadGroupSize =
+ 1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+ if (rowsPerLine >= threadGroupSize)
+ return failure();
+
+ // Get sets of operations within the function that read/write to shared
+ // memory.
+ SmallVector<Operation *, 16> shmReadOps;
+ SmallVector<Operation *, 16> shmWriteOps;
+ if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
+ shmWriteOps)))
+ return failure();
+
+ if (shmReadOps.empty() || shmWriteOps.empty())
+ return failure();
+
+ OpBuilder builder(parentOp->getContext());
+
+ int64_t tgtDim = memRefType.getRank() - 1;
+ int64_t srcDim = memRefType.getRank() - 2;
+
+ // Transform indices for the ops writing to shared memory.
+ while (!shmWriteOps.empty()) {
+ Operation *shmWriteOp = shmWriteOps.back();
+ shmWriteOps.pop_back();
+ builder.setInsertionPoint(shmWriteOp);
+
+ auto indices = amdgpu::getIndices(shmWriteOp);
+ SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+ transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
+ memRefType, srcDim, tgtDim);
+ amdgpu::setIndices(shmWriteOp, transformedIndices);
+ }
+
+ // Transform indices for the ops reading from shared memory.
+ while (!shmReadOps.empty()) {
+ Operation *shmReadOp = shmReadOps.back();
+ shmReadOps.pop_back();
+ builder.setInsertionPoint(shmReadOp);
+
+ auto indices = amdgpu::getIndices(shmReadOp);
+ SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+ transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
+ memRefType, srcDim, tgtDim);
+ amdgpu::setIndices(shmReadOp, transformedIndices);
+ }
+
+ return success();
+}
+
+namespace {
+class OptimizeSharedMemoryPass
+ : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+public:
+ OptimizeSharedMemoryPass() = default;
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ SmallVector<memref::AllocOp> shmAllocOps;
+ op->walk([&](memref::AllocOp allocOp) {
+ if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(
+ allocOp.getType()))
+ return;
+ shmAllocOps.push_back(allocOp);
+ });
+ for (auto allocOp : shmAllocOps) {
+ if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
+ allocOp.getMemref())))
+ return;
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::amdgpu::createOptimizeSharedMemoryPass() {
+ return std::make_unique<OptimizeSharedMemoryPass>();
+}
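To see what permuteVectorOffset computes on the shapes exercised by the test below: for a memref<256x32xf16> with 64-bit accesses, permuteEveryN = max(1, 64 / ((32 * 16) / 8)) = 1, n = log2(64 / 16) = 2, and m = log2(32) = 5, so the mask is (1 << 3) - 1 = 7 and the masked row bits are shifted left by 2 before being XORed into the column index. A constant-folded standalone sketch of that arithmetic (hypothetical, not part of the patch):

    #include <cstdint>
    #include <cstdio>

    // Constant-folded form of the emitted `andi 7` / `shli 2` / `xori`
    // sequence for memref<256x32xf16> with 64-bit accesses.
    int64_t permuteCol(int64_t row, int64_t col) {
      const int64_t mask = 7; // (1 << (m - n)) - 1 with m = 5, n = 2
      return col ^ ((row & mask) << 2);
    }

    int main() {
      for (int64_t r = 0; r < 4; ++r)
        std::printf("row %lld: col 0 -> col %lld\n", (long long)r,
                    (long long)permuteCol(r, 0));
      return 0;
    }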
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
new file mode 100644
index 00000000000000..a1dc6cf70e7bf8
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -0,0 +1,48 @@
+//===- Utils.cpp - Transform utilities ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+Operation::operand_range amdgpu::getIndices(Operation *op) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getIndices();
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getIndices();
+ if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+ return vectorReadOp.getIndices();
+ if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+ return vectorStoreOp.getIndices();
+ if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+ return transferReadOp.getIndices();
+ if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+ return transferWriteOp.getIndices();
+ llvm_unreachable("unsupported op type");
+}
+
+void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getIndicesMutable().assign(indices);
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getIndicesMutable().assign(indices);
+ if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+ return vectorReadOp.getIndicesMutable().assign(indices);
+ if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+ return vectorStoreOp.getIndicesMutable().assign(indices);
+ if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+ return transferReadOp.getIndicesMutable().assign(indices);
+ if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+ return transferWriteOp.getIndicesMutable().assign(indices);
+ llvm_unreachable("unsupported op type");
+}
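A minimal usage sketch for these helpers (hypothetical, not part of the patch; it mirrors how the pass rewrites indices, and note that at this point in the series getIndices still returns an operand_range directly):

    #include <cassert>
    #include <utility>
    #include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
    #include "llvm/ADT/SmallVector.h"

    // Swap the last two indices of a supported load/store op in place.
    void swapLastTwoIndices(mlir::Operation *op) {
      mlir::Operation::operand_range indices = mlir::amdgpu::getIndices(op);
      llvm::SmallVector<mlir::Value, 4> newIndices(indices.begin(),
                                                   indices.end());
      assert(newIndices.size() >= 2 && "expected at least a 2-D access");
      std::swap(newIndices[newIndices.size() - 1],
                newIndices[newIndices.size() - 2]);
      mlir::amdgpu::setIndices(op, newIndices);
    }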
diff --git a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
new file mode 100644
index 00000000000000..41111dddda5205
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+
+ // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
+ func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
+ %readRow: index, %readCol: index,
+ %writeRow: index, %writeCol: index,
+ %fragRow: index, %fragCol: index,
+ %fragColPerm: index,
+ %stRow: index, %stCol: index) {
+ // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
+ %cst = arith.constant 0.000000e+00 : f16
+
+ // CHECK: [[shmA:%.+]] = memref.alloc
+ // CHECK: [[shmB:%.+]] = memref.alloc
+ %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
+ %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
+
+ // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+ %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+ // CHECK: [[c7:%.+]] = arith.constant 7 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
+ // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+ vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+ gpu.barrier
+ gpu.barrier
+ // CHECK: [[c7:%.+]] = arith.constant 7 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
+ // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
+ %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
+
+ // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+ %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+ // CHECK: [[c7:%.+]] = arith.constant 7 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
+ // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+ vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+ gpu.barrier
+ gpu.barrier
+ // CHECK: [[c7:%.+]] = arith.constant 7 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
+ // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
+ %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
+ return
+ }
+
\ No newline at end of file
From c072744c97c328360afbb96eaabd6033665dc965 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Sun, 7 Jan 2024 22:21:29 -0800
Subject: [PATCH 02/11] Add a fix for bad line break
---
mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 0a2f04f4e6487f..c80beed0ed1d91 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -1,5 +1,4 @@
-//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation
-//----------===//
+//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
From 5384ebaed5b156ce5f7446da6f054a59a459b668 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Mon, 8 Jan 2024 15:35:15 -0500
Subject: [PATCH 03/11] Remove constructor to enable autogeneration
---
mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h | 5 +----
mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td | 1 -
mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 4 ----
3 files changed, 1 insertion(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 752078cd6930e3..11d182ba5823e8 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -20,10 +20,7 @@ namespace mlir {
class ConversionTarget;
namespace amdgpu {
-#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
-
-/// Create a pass to optimize shared memory reads and writes.
-std::unique_ptr<Pass> createOptimizeSharedMemoryPass();
+#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 1c12ca98271127..1b1543c2d38971 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -32,7 +32,6 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
- let constructor = "mlir::amdgpu::createOptimizeSharedMemoryPass()";
let dependentDialects = [
"memref::MemRefDialect", "vector::VectorDialect"
];
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index c80beed0ed1d91..81d98c9225de42 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -245,7 +245,3 @@ class OptimizeSharedMemoryPass
}
};
} // namespace
-
-std::unique_ptr<Pass> mlir::amdgpu::createOptimizeSharedMemoryPass() {
- return std::make_unique<OptimizeSharedMemoryPass>();
-}
From d78fd01e7f856caf0ab6fee41692f5c959befe7c Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Mon, 8 Jan 2024 19:38:02 -0800
Subject: [PATCH 04/11] Remove unused constant expressions
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 10 ----------
1 file changed, 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 324c656f47599e..b4bf1b5191232d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,16 +41,6 @@ def AMDGPU_Dialect : Dialect {
/// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
static bool isSharedMemoryAddressSpace(Attribute type);
- /// Defines the MemRef memory space attribute numeric value that indicates
- /// a memref is located in global memory. This should correspond to the
- /// value used in ROCDL.
- static constexpr unsigned kGlobalMemoryAddressSpace = 1;
-
- /// Defines the MemRef memory space attribute numeric value that indicates
- /// a memref is located in private memory. This should correspond to the
- /// value used in ROCDL.
- static constexpr unsigned kPrivateMemoryAddressSpace = 2;
-
/// Defines the MemRef memory space attribute numeric value that indicates
/// a memref is located in shared memory. This should correspond to the
/// value used in ROCDL.
From 811fea3dd804bcf5f87b16f2c2588479a9a55fcb Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Tue, 9 Jan 2024 08:17:04 -0800
Subject: [PATCH 05/11] Add simplification for read/write ops initialization
---
mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 81d98c9225de42..d004d258bebe61 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -195,8 +195,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
// Transform indices for the ops writing to shared memory.
while (!shmWriteOps.empty()) {
- Operation *shmWriteOp = shmWriteOps.back();
- shmWriteOps.pop_back();
+ Operation *shmWriteOp = shmWriteOps.pop_back_val();
builder.setInsertionPoint(shmWriteOp);
auto indices = amdgpu::getIndices(shmWriteOp);
@@ -208,8 +207,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
// Transform indices for the ops reading from shared memory.
while (!shmReadOps.empty()) {
- Operation *shmReadOp = shmReadOps.back();
- shmReadOps.pop_back();
+ Operation *shmReadOp = shmReadOps.pop_back_val();
builder.setInsertionPoint(shmReadOp);
auto indices = amdgpu::getIndices(shmReadOp);
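For reference, SmallVector::pop_back_val() folds the back()/pop_back() pair into a single call that returns the removed element:

    #include "llvm/ADT/SmallVector.h"

    void popBackValDemo() {
      llvm::SmallVector<int, 4> v = {1, 2, 3};
      int last = v.pop_back_val(); // last == 3; v now holds {1, 2}
      (void)last;
    }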
From f8b4c06dfcef6af4a84a36af0a15fc0eb0ba6849 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Wed, 10 Jan 2024 08:31:30 -0800
Subject: [PATCH 06/11] Add description for utils
---
mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
index bee3af1914feef..b39e25d1a8826f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -11,10 +11,13 @@
namespace mlir {
namespace amdgpu {
-/// Get the indices that the given load/store operation is operating on.
+/// Get and set the indices that the given load/store operation is operating on.
+/// Preconditions:
+/// - The Op must have memory effects
+/// - Considers memref::LoadOp, vector::LoadOp, vector::TransferReadOp
+/// - Considers memref::StoreOp, vector::StoreOp, vector::TransferWriteOp
+/// - Excludes subview op
Operation::operand_range getIndices(Operation *op);
-
-/// Set the indices that the given load/store operation is operating on.
void setIndices(Operation *op, ArrayRef<Value> indices);
} // namespace amdgpu
From 653c1ae031d9a05b6606c63b525b2ad6f7559e44 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Thu, 11 Jan 2024 07:06:24 -0800
Subject: [PATCH 07/11] Add description
---
mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 1b1543c2d38971..c8059e6d316e8a 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -32,6 +32,12 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+ let description = [{
+ This pass attempts to optimize reads/writes from a memref representing
+ GPU shared memory in order to avoid bank conflicts.
+ }];
+
let dependentDialects = [
"memref::MemRefDialect", "vector::VectorDialect"
];
From fe9f5a957c4e588acff3b2564a5e24fa3ce5e1fe Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Thu, 11 Jan 2024 07:13:28 -0800
Subject: [PATCH 08/11] Change the pass data type to a struct
---
mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index d004d258bebe61..4a7dc2f20afd98 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -221,7 +221,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
}
namespace {
-class OptimizeSharedMemoryPass
+struct OptimizeSharedMemoryPass
: public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
OptimizeSharedMemoryPass() = default;
From 983e956e175bbb9c689a53e11555ab3ca17e1612 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Thu, 11 Jan 2024 07:19:40 -0800
Subject: [PATCH 09/11] Remove anonymous namespace
---
mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 4a7dc2f20afd98..2ce1ed72856dc2 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -220,7 +220,6 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
return success();
}
-namespace {
struct OptimizeSharedMemoryPass
: public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
@@ -242,4 +241,3 @@ struct OptimizeSharedMemoryPass
}
}
};
-} // namespace
From b722dd7bd16bd96aac7f3516b1b85cca44c24ef4 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Thu, 11 Jan 2024 11:58:44 -0800
Subject: [PATCH 10/11] Add optional for the util function
---
mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h | 2 +-
.../Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 8 ++++----
mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp | 5 ++---
3 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
index b39e25d1a8826f..6be57ca54b15f8 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -17,7 +17,7 @@ namespace amdgpu {
/// - Considers memref::LoadOp, vector::LoadOp, vector::TransferReadOp
/// - Considers memref::StoreOp, vector::StoreOp, vector::TransferWriteOp
/// - Excludes subview op
-Operation::operand_range getIndices(Operation *op);
+std::optional<Operation::operand_range> getIndices(Operation *op);
void setIndices(Operation *op, ArrayRef<Value> indices);
} // namespace amdgpu
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 2ce1ed72856dc2..c7001fc6d57d5f 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -138,13 +138,13 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
if (llvm::any_of(readOps, [](Operation *op) {
return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
op) ||
- amdgpu::getIndices(op).size() < 2;
+ amdgpu::getIndices(op)->size() < 2;
}))
return failure();
if (llvm::any_of(writeOps, [](Operation *op) {
return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
op) ||
- amdgpu::getIndices(op).size() < 2;
+ amdgpu::getIndices(op)->size() < 2;
}))
return failure();
@@ -199,7 +199,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
builder.setInsertionPoint(shmWriteOp);
auto indices = amdgpu::getIndices(shmWriteOp);
- SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+ SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
memRefType, srcDim, tgtDim);
amdgpu::setIndices(shmWriteOp, transformedIndices);
@@ -211,7 +211,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
builder.setInsertionPoint(shmReadOp);
auto indices = amdgpu::getIndices(shmReadOp);
- SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+ SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
memRefType, srcDim, tgtDim);
amdgpu::setIndices(shmReadOp, transformedIndices);
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
index a1dc6cf70e7bf8..05ac29bfcfaec3 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -15,7 +15,7 @@
using namespace mlir;
using namespace mlir::amdgpu;
-Operation::operand_range amdgpu::getIndices(Operation *op) {
+std::optional<Operation::operand_range> amdgpu::getIndices(Operation *op) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
return loadOp.getIndices();
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
@@ -28,7 +28,7 @@ Operation::operand_range amdgpu::getIndices(Operation *op) {
return transferReadOp.getIndices();
if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteOp.getIndices();
- llvm_unreachable("unsupported op type");
+ return std::nullopt;
}
void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
@@ -44,5 +44,4 @@ void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
return transferReadOp.getIndicesMutable().assign(indices);
if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteOp.getIndicesMutable().assign(indices);
- llvm_unreachable("unsupported op type");
}
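Since getIndices can now return std::nullopt for unsupported ops, here is a caller-side sketch of defensive handling (hypothetical; the pass itself only reaches these calls with supported ops, so it dereferences the optional directly):

    #include <optional>
    #include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
    #include "llvm/ADT/SmallVector.h"

    void rewriteIndicesIfSupported(mlir::Operation *op) {
      std::optional<mlir::Operation::operand_range> indices =
          mlir::amdgpu::getIndices(op);
      if (!indices)
        return; // Not one of the six supported load/store ops.
      llvm::SmallVector<mlir::Value, 4> newIndices(indices->begin(),
                                                   indices->end());
      // ... permute newIndices here ...
      mlir::amdgpu::setIndices(op, newIndices);
    }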
From 59d5e32296f2d3043d459f4d6401514bc0dd971a Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Tue, 16 Jan 2024 16:19:26 -0800
Subject: [PATCH 11/11] Add interface for get and set indices functions
---
mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp | 78 +++++++++++++-------
1 file changed, 53 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
index 05ac29bfcfaec3..e4cbcf59c90239 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -15,33 +15,61 @@
using namespace mlir;
using namespace mlir::amdgpu;
+// Define an interface for operations with indices
+class IndicesInterface {
+public:
+ virtual std::optional<Operation::operand_range> getIndices() = 0;
+ virtual void setIndices(ArrayRef<Value> indices) = 0;
+ virtual ~IndicesInterface() = default;
+};
+
+// Implement a generic class that uses IndicesInterface
+class OperationWithIndices : public IndicesInterface {
+private:
+ Operation *op;
+
+public:
+ OperationWithIndices(Operation *op) : op(op) {}
+
+ std::optional<Operation::operand_range> getIndices() override {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return loadOp.getIndices();
+ else if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return storeOp.getIndices();
+ else if (auto vectorLoadOp = dyn_cast<vector::LoadOp>(op))
+ return vectorLoadOp.getIndices();
+ else if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+ return vectorStoreOp.getIndices();
+ else if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+ return transferReadOp.getIndices();
+ else if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+ return transferWriteOp.getIndices();
+ else
+ return std::nullopt;
+ }
+
+ void setIndices(ArrayRef<Value> indices) override {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ loadOp.getIndicesMutable().assign(indices);
+ else if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ storeOp.getIndicesMutable().assign(indices);
+ else if (auto vectorLoadOp = dyn_cast<vector::LoadOp>(op))
+ vectorLoadOp.getIndicesMutable().assign(indices);
+ else if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+ vectorStoreOp.getIndicesMutable().assign(indices);
+ else if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+ transferReadOp.getIndicesMutable().assign(indices);
+ else if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+ transferWriteOp.getIndicesMutable().assign(indices);
+ }
+};
+
std::optional<Operation::operand_range> amdgpu::getIndices(Operation *op) {
- if (auto loadOp = dyn_cast<memref::LoadOp>(op))
- return loadOp.getIndices();
- if (auto storeOp = dyn_cast<memref::StoreOp>(op))
- return storeOp.getIndices();
- if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
- return vectorReadOp.getIndices();
- if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
- return vectorStoreOp.getIndices();
- if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
- return transferReadOp.getIndices();
- if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
- return transferWriteOp.getIndices();
- return std::nullopt;
+ OperationWithIndices operationWithIndices(op);
+ return operationWithIndices.getIndices();
}
void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
- if (auto loadOp = dyn_cast<memref::LoadOp>(op))
- return loadOp.getIndicesMutable().assign(indices);
- if (auto storeOp = dyn_cast<memref::StoreOp>(op))
- return storeOp.getIndicesMutable().assign(indices);
- if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
- return vectorReadOp.getIndicesMutable().assign(indices);
- if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
- return vectorStoreOp.getIndicesMutable().assign(indices);
- if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
- return transferReadOp.getIndicesMutable().assign(indices);
- if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
- return transferWriteOp.getIndicesMutable().assign(indices);
+ OperationWithIndices operationWithIndices(op);
+ operationWithIndices.setIndices(indices);
}
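An alternative sketch of the same dispatch (not part of the patch) using llvm::TypeSwitch, which keeps the case table in one place without the virtual interface or the per-call wrapper object:

    #include <optional>
    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/Dialect/Vector/IR/VectorOps.h"
    #include "llvm/ADT/TypeSwitch.h"

    std::optional<mlir::Operation::operand_range>
    getIndicesViaTypeSwitch(mlir::Operation *op) {
      return llvm::TypeSwitch<mlir::Operation *,
                              std::optional<mlir::Operation::operand_range>>(op)
          .Case([](mlir::memref::LoadOp o) { return o.getIndices(); })
          .Case([](mlir::memref::StoreOp o) { return o.getIndices(); })
          .Case([](mlir::vector::LoadOp o) { return o.getIndices(); })
          .Case([](mlir::vector::StoreOp o) { return o.getIndices(); })
          .Case([](mlir::vector::TransferReadOp o) { return o.getIndices(); })
          .Case([](mlir::vector::TransferWriteOp o) { return o.getIndices(); })
          .Default([](mlir::Operation *) { return std::nullopt; });
    }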