[Mlir-commits] [mlir] [AMDGPU] add utils for common usage (PR #75097)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 11 13:11:15 PST 2023
https://github.com/erman-gurses created https://github.com/llvm/llvm-project/pull/75097
This PR generalizes helper functions so that they can be shared by the swizzling implementations of the AMDGPU and NVGPU dialects.
From d2326f06ec1b85d3a8c076ccec8587a38bde2443 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Mon, 11 Dec 2023 09:50:33 -0800
Subject: [PATCH] [AMDGPU] add utils for common usage
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 21 ++++++++++
.../mlir/Dialect/NVGPU/Transforms/Utils.h | 6 ---
.../mlir/Dialect/Utils/IndexingUtils.h | 13 ++++++
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 17 +++++++-
.../NVGPU/Transforms/CreateAsyncGroups.cpp | 5 ++-
.../NVGPU/Transforms/OptimizeSharedMemory.cpp | 2 +-
mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp | 40 ------------------
mlir/lib/Dialect/Utils/IndexingUtils.cpp | 42 +++++++++++++++++++
8 files changed, 96 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ffb302fcedd73..110bdcb571fdf 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -29,6 +29,27 @@ def AMDGPU_Dialect : Dialect {
"gpu::GPUDialect"
];
let useDefaultAttributePrinterParser = 1;
+  let extraClassDeclaration = [{
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
+    static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given memory space attribute is an IntegerAttr
+    /// matching the ROCDL shared memory address space or is a
+    /// gpu::AddressSpaceAttr attribute with value `workgroup`.
+    static bool isSharedMemoryAddressSpace(Attribute memorySpace);
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in global memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kGlobalMemoryAddressSpace = 1;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in shared memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kSharedMemoryAddressSpace = 3;
+  }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/NVGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/NVGPU/Transforms/Utils.h
index 64bce441722af..b845aef888ca6 100644
--- a/mlir/include/mlir/Dialect/NVGPU/Transforms/Utils.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Transforms/Utils.h
@@ -11,12 +11,6 @@
namespace mlir {
namespace nvgpu {
-/// Get the indices that the given load/store operation is operating on.
-Operation::operand_range getIndices(Operation *op);
-
-/// Set the indices that the given load/store operation is operating on.
-void setIndices(Operation *op, ArrayRef<Value> indices);
-
/// Get the value that is stored by the given store operation.
Value getValueStored(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index f51a8b28b7548..295add0312792 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -14,6 +14,9 @@
#ifndef MLIR_DIALECT_UTILS_INDEXINGUTILS_H
#define MLIR_DIALECT_UTILS_INDEXINGUTILS_H
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
@@ -387,6 +390,16 @@ class StaticTileOffsetRange {
IteratorTy beginValue;
IteratorTy pastEndValue;
};
+
+//===----------------------------------------------------------------------===//
+// Load/store utils.
+//===----------------------------------------------------------------------===//
+
+/// Get the indices that the given load/store operation is operating on.
+/// Supported ops: memref::LoadOp, memref::StoreOp, vector::LoadOp,
+/// vector::StoreOp, vector::TransferReadOp, vector::TransferWriteOp,
+/// nvgpu::LdMatrixOp and nvgpu::DeviceAsyncCopyOp (for which the destination
+/// indices are returned). Passing any other op hits llvm_unreachable.
+Operation::operand_range getIndices(Operation *op);
+
+/// Set the indices that the given load/store operation is operating on.
+/// Supports the same ops as getIndices; for nvgpu::DeviceAsyncCopyOp the
+/// destination indices are replaced. Passing any other op hits
+/// llvm_unreachable.
+void setIndices(Operation *op, ArrayRef<Value> indices);
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2575ad4984814..f54ec46c7476d 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -9,7 +9,7 @@
// This file implements the AMDGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//
-
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
>();
}
+/// Classifies a memref memory-space attribute as shared (workgroup) memory.
+/// Accepts either the numeric ROCDL address space or the GPU dialect's
+/// symbolic address space; anything else (including no attribute) is not
+/// shared memory.
+bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+  // A null attribute must be rejected before dyn_cast is applied to it.
+  if (!memorySpace)
+    return false;
+  if (auto gpuSpace = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuSpace.getValue() == gpu::AddressSpace::Workgroup;
+  auto numericSpace = llvm::dyn_cast<IntegerAttr>(memorySpace);
+  return numericSpace &&
+         numericSpace.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
+}
+
+/// Returns true when the memory space attribute of `type` denotes shared
+/// memory, as decided by isSharedMemoryAddressSpace.
+bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+  return isSharedMemoryAddressSpace(type.getMemorySpace());
+}
+
//===----------------------------------------------------------------------===//
// 8-bit float ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index f63825cdc8f61..235cc28abd8f0 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -254,9 +255,9 @@ void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
// bypass_l1 only possible with 16 byte transfer.
Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
- /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
+ /*dst=*/storeBase, /*dstIndices=*/getIndices(writeOp),
/*src=*/loadBase,
- /*srcIndices=*/nvgpu::getIndices(readOp),
+ /*srcIndices=*/getIndices(readOp),
/*dstElements=*/rewriter.getIndexAttr(numElements),
/*srcElements=*/numReadElements,
/*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index 693bb53cacff6..bfffeaa32fbe2 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -9,7 +9,7 @@
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//
-
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp
index a782ed5ddd85e..213b453d4bf9f 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp
@@ -15,46 +15,6 @@
using namespace mlir;
using namespace mlir::nvgpu;
-Operation::operand_range nvgpu::getIndices(Operation *op) {
- if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
- return ldmatrixOp.getIndices();
- if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
- return copyOp.getDstIndices();
- if (auto loadOp = dyn_cast<memref::LoadOp>(op))
- return loadOp.getIndices();
- if (auto storeOp = dyn_cast<memref::StoreOp>(op))
- return storeOp.getIndices();
- if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
- return vectorReadOp.getIndices();
- if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
- return vectorStoreOp.getIndices();
- if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
- return transferReadOp.getIndices();
- if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
- return transferWriteOp.getIndices();
- llvm_unreachable("unsupported op type");
-}
-
-void nvgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
- if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
- return ldmatrixOp.getIndicesMutable().assign(indices);
- if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
- return copyOp.getDstIndicesMutable().assign(indices);
- if (auto loadOp = dyn_cast<memref::LoadOp>(op))
- return loadOp.getIndicesMutable().assign(indices);
- if (auto storeOp = dyn_cast<memref::StoreOp>(op))
- return storeOp.getIndicesMutable().assign(indices);
- if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
- return vectorReadOp.getIndicesMutable().assign(indices);
- if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
- return vectorStoreOp.getIndicesMutable().assign(indices);
- if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
- return transferReadOp.getIndicesMutable().assign(indices);
- if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
- return transferWriteOp.getIndicesMutable().assign(indices);
- llvm_unreachable("unsupported op type");
-}
-
Value nvgpu::getValueStored(Operation *op) {
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
return storeOp.getValueToStore();
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index bb8a0d5912d7c..8bbd84918eed9 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -8,6 +8,8 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -352,3 +354,43 @@ mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
return mlir::computeElementwiseMul(tileCoords,
getAffineConstantExprs(tileShape, ctx));
}
+
+/// Returns the index operands of a supported load/store-like op, dispatching
+/// on its concrete type. The checks are mutually exclusive (an op has exactly
+/// one concrete type), so their order does not affect the result.
+Operation::operand_range mlir::getIndices(Operation *op) {
+  // memref dialect accesses.
+  if (auto memrefLoad = dyn_cast<memref::LoadOp>(op))
+    return memrefLoad.getIndices();
+  if (auto memrefStore = dyn_cast<memref::StoreOp>(op))
+    return memrefStore.getIndices();
+  // vector dialect accesses.
+  if (auto vecLoad = dyn_cast<vector::LoadOp>(op))
+    return vecLoad.getIndices();
+  if (auto vecStore = dyn_cast<vector::StoreOp>(op))
+    return vecStore.getIndices();
+  if (auto xferRead = dyn_cast<vector::TransferReadOp>(op))
+    return xferRead.getIndices();
+  if (auto xferWrite = dyn_cast<vector::TransferWriteOp>(op))
+    return xferWrite.getIndices();
+  // nvgpu dialect accesses.
+  if (auto ldMatrix = dyn_cast<nvgpu::LdMatrixOp>(op))
+    return ldMatrix.getIndices();
+  // For async copies only the destination indices are reported.
+  if (auto asyncCopy = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op))
+    return asyncCopy.getDstIndices();
+  llvm_unreachable("unsupported op type");
+}
+
+/// Overwrites the index operands of a supported load/store-like op with
+/// `indices`, dispatching on its concrete type. The checks are mutually
+/// exclusive, so their order does not affect the result.
+void mlir::setIndices(Operation *op, ArrayRef<Value> indices) {
+  // memref dialect accesses.
+  if (auto memrefLoad = dyn_cast<memref::LoadOp>(op))
+    return memrefLoad.getIndicesMutable().assign(indices);
+  if (auto memrefStore = dyn_cast<memref::StoreOp>(op))
+    return memrefStore.getIndicesMutable().assign(indices);
+  // vector dialect accesses.
+  if (auto vecLoad = dyn_cast<vector::LoadOp>(op))
+    return vecLoad.getIndicesMutable().assign(indices);
+  if (auto vecStore = dyn_cast<vector::StoreOp>(op))
+    return vecStore.getIndicesMutable().assign(indices);
+  if (auto xferRead = dyn_cast<vector::TransferReadOp>(op))
+    return xferRead.getIndicesMutable().assign(indices);
+  if (auto xferWrite = dyn_cast<vector::TransferWriteOp>(op))
+    return xferWrite.getIndicesMutable().assign(indices);
+  // nvgpu dialect accesses.
+  if (auto ldMatrix = dyn_cast<nvgpu::LdMatrixOp>(op))
+    return ldMatrix.getIndicesMutable().assign(indices);
+  // For async copies only the destination indices are rewritten.
+  if (auto asyncCopy = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op))
+    return asyncCopy.getDstIndicesMutable().assign(indices);
+  llvm_unreachable("unsupported op type");
+}
More information about the Mlir-commits
mailing list