[Mlir-commits] [mlir] [mlir][amdgpu] Shared memory access optimization pass (PR #75627)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 18 14:08:12 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-amdgpu

Author: None (erman-gurses)

<details>
<summary>Changes</summary>

It implements a transformation to optimize accesses to shared memory.

Reference: https://reviews.llvm.org/D127457

_This change adds a transformation and pass to the NvGPU dialect that
attempts to optimize reads/writes from a memref representing GPU shared
memory in order to avoid bank conflicts. Given a value representing a
shared memory memref, it traverses all reads/writes within the parent op
and, subject to suitable conditions, rewrites all last dimension index
values such that element locations in the final (col) dimension are
given by newColIdx = col % vecSize + perm[row](col / vecSize, row)
where perm is a permutation function indexed by row and vecSize
is the vector access size in elements (currently assumes 128bit
vectorized accesses, but this can be made a parameter). This specific
transformation can help optimize typical distributed & vectorized accesses
common to loading matrix multiplication operands to/from shared memory._

---

Patch is 25.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75627.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+27) 
- (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h (+4) 
- (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td (+8) 
- (added) mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h (+54) 
- (added) mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h (+21) 
- (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+15) 
- (modified) mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt (+2) 
- (added) mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp (+252) 
- (added) mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp (+48) 
- (added) mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir (+57) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ffb302fcedd732..324c656f47599e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -29,6 +29,33 @@ def AMDGPU_Dialect : Dialect {
     "gpu::GPUDialect"
   ];
   let useDefaultAttributePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// a gpu::AddressSpaceAttr memory space with value 'workgroup'.
+    static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute is an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup'.
+    static bool isSharedMemoryAddressSpace(Attribute type);
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in global memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kGlobalMemoryAddressSpace = 1;
+
+    /// Defines the MemRef memory space attribute numeric value for private
+    /// memory; it should match ROCDL. NOTE(review): the LLVM AMDGPU target
+    /// uses 5 for private memory and 2 for "region" -- confirm this value.
+    static constexpr unsigned kPrivateMemoryAddressSpace = 2;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in shared memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kSharedMemoryAddressSpace = 3;
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 8dd5ff1a4b198a..752078cd6930e3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -21,6 +21,10 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+
+/// Create a pass to optimize shared memory reads and writes.
+std::unique_ptr<Pass> createOptimizeSharedMemoryPass();
+
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index e6b27aa842dfcd..1c12ca98271127 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -30,4 +30,12 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
                         "Chipset that these operations will run on">];
 }
 
+def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
+  let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+  let constructor = "mlir::amdgpu::createOptimizeSharedMemoryPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
new file mode 100644
index 00000000000000..140bc12deed690
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -0,0 +1,54 @@
+//===- Transforms.h - AMDGPU Dialect transformations --------------*-
+// C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions that assist transformations for the amdgpu
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class RewriterBase;
+
+namespace amdgpu {
+
+///
+/// Passes
+///
+
+/// Optimizes vectorized accesses to a shared memory buffer specified by
+/// memrefValue. This transformation assumes the following:
+/// 1) All relevant accesses to `memrefValue` are contained within `parentOp`.
+/// 2) The function will fail precondition checks if `parentOp` contains any
+/// `memref.subview` op (no alias analysis is done). All reads/writes to
+/// `memrefValue` should occur through `memrefValue` directly.
+///
+/// Shared memory bank conflicts occur when multiple threads attempt to read or
+/// write locations assigned to the same shared memory bank. For `2^N` byte
+/// vectorized accesses, we need to be concerned with conflicts among threads
+/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
+/// changes any indexed memory access (vector.load, memref.load, etc)
+/// such that the final dimension's index value is permuted such that
+/// `newColIndex = oldColIndex % vectorSize +
+/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)` where `rowIndex` is the
+/// index for the second-to-last dimension and `perm[rowIndex]` is a permutation
+/// function that depends on the row index. The permutation function is chosen
+/// to ensure that sequential distributed+vectorized reads/writes down a single
+/// dimension of the memref have minimal conflicts.
+mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                       Value memrefValue);
+
+} // namespace amdgpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
new file mode 100644
index 00000000000000..bee3af1914feef
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -0,0 +1,21 @@
+//===- Utils.h - Transform utilities -----------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace amdgpu {
+
+/// Get the indices that the given load/store operation is operating on.
+Operation::operand_range getIndices(Operation *op);
+
+/// Set the indices that the given load/store operation is operating on.
+void setIndices(Operation *op, ArrayRef<Value> indices);
+
+} // namespace amdgpu
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2575ad4984814b..4e72fbf56b80a4 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
       >();
 }
 
+bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace) // A null attribute never denotes shared memory.
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+  return false; // Any other attribute kind is not treated as shared memory.
+}
+
+bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isSharedMemoryAddressSpace(memorySpace);
+}
+
 //===----------------------------------------------------------------------===//
 // 8-bit float ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index e11b6cc88bf224..a1a91270bc55c4 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
+  OptimizeSharedMemory.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
new file mode 100644
index 00000000000000..0a2f04f4e6487f
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -0,0 +1,252 @@
+//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation
+//----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to optimize accesses to shared memory.
+// It is inspired by
+// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace amdgpu {
+#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace amdgpu
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// The size of a shared memory line according to AMD documentation.
+/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
+constexpr int64_t kSharedMemoryLineSizeBytes = 64;
+/// We optimize for 64bit accesses, but this can be made an argument in the
+/// future.
+constexpr int64_t kDefaultVectorSizeBits = 64;
+
+/// Uses `indices[srcDim]` to permute `indices[tgtDim]` via
+/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
+///               floordiv(tgtIdxVal,vectorSize))
+///            + tgtIdxVal % vectorSize`
+/// Emitted as `arith` ops. NOTE(review): seems to assume static dim sizes.
+static Value permuteVectorOffset(OpBuilder &b, Location loc,
+                                 ArrayRef<Value> indices, MemRefType memrefTy,
+                                 int64_t srcDim, int64_t tgtDim) {
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
+  Value src = indices[srcDim];
+
+  // We only want to permute every N iterations of the target dim where N is
+  // max(1, kSharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  const int64_t permuteEveryN = std::max<int64_t>(
+      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+                                        memrefTy.getElementTypeBitWidth()) /
+                                       8));
+
+  // clang-format off
+  // Index bit representation (b0 = least significant bit) for dim(1)
+  // of a `memref<?x?xDT>` is as follows:
+  // N := log2(kDefaultVectorSizeBits/elementSizeBits)
+  // M := log2(dimSize(1))
+  // then
+  // bits[0:N] = sub-vector element offset
+  // bits[N:M] = vector index
+  // clang-format on
+  int64_t n =
+      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
+
+  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
+  int64_t mask = (1LL << (m - n)) - 1;
+  if (permuteEveryN > 1)
+    mask = mask << llvm::Log2_64(permuteEveryN);
+  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
+  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
+
+  // Use the src bits to permute the target bits b[N:M] containing the
+  // vector offset.
+  if (permuteEveryN > 1) {
+    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
+    if (shlBits > 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
+      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+    } else if (shlBits < 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
+      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
+    }
+  } else {
+    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
+    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+  }
+
+  Value permutedVectorIdx =
+      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
+  return permutedVectorIdx;
+}
+
+static void transformIndices(OpBuilder &builder, Location loc,
+                             SmallVector<Value, 4> &indices,
+                             MemRefType memrefTy, int64_t srcDim,
+                             int64_t tgtDim) {
+  indices[tgtDim] =
+      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+}
+
+/// Collect the operations within `parentOp` that read from or write to
+/// `shmMemRef`; fails on unsupported op kinds or ops with < 2 indices.
+static LogicalResult
+getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
+                      SmallVector<Operation *, 16> &readOps,
+                      SmallVector<Operation *, 16> &writeOps) {
+  parentOp->walk([&](Operation *op) {
+    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
+    if (!iface)
+      return;
+    std::optional<MemoryEffects::EffectInstance> effect =
+        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
+    if (effect) {
+      readOps.push_back(op);
+      return;
+    }
+    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
+    if (effect)
+      writeOps.push_back(op);
+  });
+
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
+  if (llvm::any_of(readOps, [](Operation *op) {
+        return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+  if (llvm::any_of(writeOps, [](Operation *op) {
+        return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+
+  return success();
+}
+
+mlir::LogicalResult
+mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                 Value memrefValue) {
+  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
+  if (!memRefType ||
+      !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
+    return failure();
+
+  // Abort if `parentOp` contains any memref.subview op (not only subviews of
+  // `memrefValue`); we do not do any alias analysis.
+  bool hasSubView = false;
+  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
+  if (hasSubView)
+    return failure();
+
+  // Check if this is necessary given kDefaultVectorSizeBits-wide accesses:
+  // bail out if one shared memory line holds >= threadGroupSize rows.
+  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+  const int64_t rowsPerLine =
+      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+      rowSize;
+  const int64_t threadGroupSize =
+      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+  if (rowsPerLine >= threadGroupSize)
+    return failure();
+
+  // Get sets of operations within `parentOp` that read/write to shared
+  // memory.
+  SmallVector<Operation *, 16> shmReadOps;
+  SmallVector<Operation *, 16> shmWriteOps;
+  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
+                                   shmWriteOps)))
+    return failure();
+
+  if (shmReadOps.empty() || shmWriteOps.empty())
+    return failure();
+
+  OpBuilder builder(parentOp->getContext());
+
+  int64_t tgtDim = memRefType.getRank() - 1;
+  int64_t srcDim = memRefType.getRank() - 2;
+
+  // Transform indices for the ops writing to shared memory.
+  while (!shmWriteOps.empty()) {
+    Operation *shmWriteOp = shmWriteOps.back();
+    shmWriteOps.pop_back();
+    builder.setInsertionPoint(shmWriteOp);
+
+    auto indices = amdgpu::getIndices(shmWriteOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmWriteOp, transformedIndices);
+  }
+
+  // Transform indices for the ops reading from shared memory.
+  while (!shmReadOps.empty()) {
+    Operation *shmReadOp = shmReadOps.back();
+    shmReadOps.pop_back();
+    builder.setInsertionPoint(shmReadOp);
+
+    auto indices = amdgpu::getIndices(shmReadOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmReadOp, transformedIndices);
+  }
+
+  return success();
+}
+
+namespace {
+class OptimizeSharedMemoryPass
+    : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+public:
+  OptimizeSharedMemoryPass() = default;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    SmallVector<memref::AllocOp> shmAllocOps;
+    op->walk([&](memref::AllocOp allocOp) {
+      if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(
+              allocOp.getType()))
+        return;
+      shmAllocOps.push_back(allocOp);
+    });
+    for (auto allocOp : shmAllocOps) {
+      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
+                                                    allocOp.getMemref())))
+        return; // First failure ends the pass; remaining allocs are skipped.
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::amdgpu::createOptimizeSharedMemoryPass() {
+  return std::make_unique<OptimizeSharedMemoryPass>();
+}
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
new file mode 100644
index 00000000000000..a1dc6cf70e7bf8
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -0,0 +1,48 @@
+//===- Utils.cpp - Transform utilities ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+Operation::operand_range amdgpu::getIndices(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndices();
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndices();
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndices();
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndices();
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndices();
+  llvm_unreachable("unsupported op type");
+}
+
+void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndicesMutable().assign(indices);
+  if (auto storeOp = dyn_cast<memref:...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/75627


More information about the Mlir-commits mailing list