[Mlir-commits] [mlir] [MLIR][AMDGPU] Add refactoring for shared-mem optimization (PR #81791)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 14 14:04:47 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-mlir
Author: None (erman-gurses)
<details>
<summary>Changes</summary>
This change addresses the issues raised in the earlier PR: https://github.com/llvm/llvm-project/pull/81550
---
Full diff: https://github.com/llvm/llvm-project/pull/81791.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h (+2-1)
- (modified) mlir/lib/Dialect/AMDGPU/CMakeLists.txt (+1-1)
- (modified) mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp (+20-18)
- (modified) mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir (-10)
``````````diff
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index b4e9ad27003db1..22bc9b9e0cf842 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -49,7 +49,8 @@ namespace amdgpu {
mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue);
-void optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
+std::optional<mlir::LogicalResult>
+optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
} // namespace amdgpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
index 63b4d8b99f53fd..c47e4c5495c17b 100644
--- a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
@@ -1,4 +1,4 @@
add_subdirectory(IR)
-add_subdirectory(Utils)
add_subdirectory(TransformOps)
add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 7c50a876e78f45..c33608a496470e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -50,12 +50,12 @@ constexpr int64_t kDefaultVectorSizeBits = 64;
static Value permuteVectorOffset(OpBuilder &b, Location loc,
ArrayRef<Value> indices, MemRefType memrefTy,
int64_t srcDim, int64_t tgtDim) {
- /// Adjust the src index to change how often the permutation changes
- /// if necessary.
+ // Adjust the src index to change how often the permutation changes
+ // if necessary.
Value src = indices[srcDim];
- /// We only want to permute every N iterations of the target dim where N is
- /// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+ // We only want to permute every N iterations of the target dim where N is
+ // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
const int64_t permuteEveryN = std::max<int64_t>(
1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
memrefTy.getElementTypeBitWidth()) /
@@ -110,8 +110,8 @@ static void transformIndices(OpBuilder &builder, Location loc,
permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}
-/// Return all operations within `parentOp` that read from or write to
-/// `shmMemRef`.
+// Return all operations within `parentOp` that read from or write to
+// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
SmallVector<Operation *, 16> &readOps,
@@ -131,8 +131,8 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
writeOps.push_back(op);
});
- /// Restrict to a supported set of ops. We also require at least 2D access,
- /// although this could be relaxed.
+ // Restrict to a supported set of ops. We also require at least 2D access,
+ // although this could be relaxed.
if (llvm::any_of(readOps, [](Operation *op) {
return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
op) ||
@@ -157,15 +157,15 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
return failure();
- /// Abort if the given value has any sub-views; we do not do any alias
- /// analysis.
+ // Abort if the given value has any sub-views; we do not do any alias
+ // analysis.
bool hasSubView = false;
parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
if (hasSubView)
return failure();
- /// Check if this is necessary given the assumption of 128b accesses:
- /// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+ // Check if this is necessary given the assumption of 128b accesses:
+ // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
const int64_t rowsPerLine =
(8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
@@ -175,8 +175,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
if (rowsPerLine >= threadGroupSize)
return failure();
- /// Get sets of operations within the function that read/write to shared
- /// memory.
+ // Get sets of operations within the function that read/write to shared
+ // memory.
SmallVector<Operation *, 16> shmReadOps;
SmallVector<Operation *, 16> shmWriteOps;
if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
@@ -191,7 +191,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
int64_t tgtDim = memRefType.getRank() - 1;
int64_t srcDim = memRefType.getRank() - 2;
- /// Transform indices for the ops writing to shared memory.
+ // Transform indices for the ops writing to shared memory.
while (!shmWriteOps.empty()) {
Operation *shmWriteOp = shmWriteOps.pop_back_val();
builder.setInsertionPoint(shmWriteOp);
@@ -203,7 +203,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
amdgpu::setIndices(shmWriteOp, transformedIndices);
}
- /// Transform indices for the ops reading from shared memory.
+ // Transform indices for the ops reading from shared memory.
while (!shmReadOps.empty()) {
Operation *shmReadOp = shmReadOps.pop_back_val();
builder.setInsertionPoint(shmReadOp);
@@ -218,7 +218,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
return success();
}
-void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+std::optional<mlir::LogicalResult>
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
SmallVector<memref::AllocOp> shmAllocOps;
funcOp.walk([&](memref::AllocOp allocOp) {
if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -228,8 +229,9 @@ void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
for (auto allocOp : shmAllocOps) {
if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
allocOp.getMemref())))
- return;
+ return failure();
}
+ return success();
}
struct OptimizeSharedMemoryPass
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index dfdd1b17e244e3..143e7c2d270952 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -7,22 +7,17 @@
%fragRow: index, %fragCol: index,
%fragColPerm: index,
%stRow: index, %stCol: index) {
- // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
%cst = arith.constant 0.000000e+00 : f16
- // CHECK: [[shmA:%.+]] = memref.alloc
- // CHECK: [[shmB:%.+]] = memref.alloc
%shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
- // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
- // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
gpu.barrier
gpu.barrier
@@ -31,17 +26,13 @@
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
- // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
-
- // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
- // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
gpu.barrier
gpu.barrier
@@ -50,7 +41,6 @@
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
- // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/81791
More information about the Mlir-commits mailing list