[Mlir-commits] [mlir] [MLIR][AMDGPU] Add refactor for shared-mem optimization (PR #81791)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 14 13:16:28 PST 2024


https://github.com/erman-gurses created https://github.com/llvm/llvm-project/pull/81791

This patch addresses the issues raised in the earlier PR: https://github.com/llvm/llvm-project/pull/81550

>From 5f40a6440dcd316d18954f6971b15d0372d61d46 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Wed, 14 Feb 2024 15:11:38 -0600
Subject: [PATCH] Add refactor for shmem-optimization

---
 .../Dialect/AMDGPU/Transforms/Transforms.h    |  2 +-
 mlir/lib/Dialect/AMDGPU/CMakeLists.txt        |  3 +-
 .../Transforms/OptimizeSharedMemory.cpp       | 37 ++++++++++---------
 ...transform_optimize_shmem_reads_writes.mlir | 10 -----
 4 files changed, 22 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index b4e9ad27003db1..3b8c4880f16c02 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -49,7 +49,7 @@ namespace amdgpu {
 mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                        Value memrefValue);
 
-void optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
+std::optional<mlir::LogicalResult> optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
 
 } // namespace amdgpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
index 63b4d8b99f53fd..ab9812c43b328a 100644
--- a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(IR)
-add_subdirectory(Utils)
 add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
+add_subdirectory(Utils)
+
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 7c50a876e78f45..771d48fadb36fd 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -50,12 +50,12 @@ constexpr int64_t kDefaultVectorSizeBits = 64;
 static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                  ArrayRef<Value> indices, MemRefType memrefTy,
                                  int64_t srcDim, int64_t tgtDim) {
-  /// Adjust the src index to change how often the permutation changes
-  /// if necessary.
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
   Value src = indices[srcDim];
 
-  /// We only want to permute every N iterations of the target dim where N is
-  /// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
   const int64_t permuteEveryN = std::max<int64_t>(
       1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                         memrefTy.getElementTypeBitWidth()) /
@@ -110,8 +110,8 @@ static void transformIndices(OpBuilder &builder, Location loc,
       permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
 }
 
-/// Return all operations within `parentOp` that read from or write to
-/// `shmMemRef`.
+// Return all operations within `parentOp` that read from or write to
+// `shmMemRef`.
 static LogicalResult
 getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                       SmallVector<Operation *, 16> &readOps,
@@ -131,8 +131,8 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
       writeOps.push_back(op);
   });
 
-  /// Restrict to a supported set of ops. We also require at least 2D access,
-  /// although this could be relaxed.
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
   if (llvm::any_of(readOps, [](Operation *op) {
         return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
                    op) ||
@@ -157,15 +157,15 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
     return failure();
 
-  /// Abort if the given value has any sub-views; we do not do any alias
-  /// analysis.
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
   bool hasSubView = false;
   parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
   if (hasSubView)
     return failure();
 
-  /// Check if this is necessary given the assumption of 128b accesses:
-  /// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  // Check if this is necessary given the assumption of 128b accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
   const int64_t rowsPerLine =
       (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
@@ -175,8 +175,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   if (rowsPerLine >= threadGroupSize)
     return failure();
 
-  /// Get sets of operations within the function that read/write to shared
-  /// memory.
+  // Get sets of operations within the function that read/write to shared
+  // memory.
   SmallVector<Operation *, 16> shmReadOps;
   SmallVector<Operation *, 16> shmWriteOps;
   if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
@@ -191,7 +191,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   int64_t tgtDim = memRefType.getRank() - 1;
   int64_t srcDim = memRefType.getRank() - 2;
 
-  /// Transform indices for the ops writing to shared memory.
+  // Transform indices for the ops writing to shared memory.
   while (!shmWriteOps.empty()) {
     Operation *shmWriteOp = shmWriteOps.pop_back_val();
     builder.setInsertionPoint(shmWriteOp);
@@ -203,7 +203,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     amdgpu::setIndices(shmWriteOp, transformedIndices);
   }
 
-  /// Transform indices for the ops reading from shared memory.
+  // Transform indices for the ops reading from shared memory.
   while (!shmReadOps.empty()) {
     Operation *shmReadOp = shmReadOps.pop_back_val();
     builder.setInsertionPoint(shmReadOp);
@@ -218,7 +218,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   return success();
 }
 
-void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+std::optional<mlir::LogicalResult> amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -228,8 +228,9 @@ void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   for (auto allocOp : shmAllocOps) {
     if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
                                                           allocOp.getMemref())))
-      return;
+      return failure();
   }
+   return success();
 }
 
 struct OptimizeSharedMemoryPass
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index dfdd1b17e244e3..143e7c2d270952 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -7,22 +7,17 @@
                     %fragRow: index, %fragCol: index, 
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
-    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16                  
     %cst = arith.constant 0.000000e+00 : f16
 
-    // CHECK: [[shmA:%.+]] = memref.alloc
-    // CHECK: [[shmB:%.+]] = memref.alloc
     %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
-    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
     // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
     // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
@@ -31,17 +26,13 @@
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
-    // CHECK:  vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
-
-    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
     // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
     // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
@@ -50,7 +41,6 @@
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK:  vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
   }



More information about the Mlir-commits mailing list