[Mlir-commits] [mlir] 04381c1 - [MLIR][AMDGPU] Add refactoring for shared-mem optimization (#81791)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 15 10:53:19 PST 2024


Author: erman-gurses
Date: 2024-02-15T13:53:15-05:00
New Revision: 04381c106f1cb0d3219ddabe231fe900113a9474

URL: https://github.com/llvm/llvm-project/commit/04381c106f1cb0d3219ddabe231fe900113a9474
DIFF: https://github.com/llvm/llvm-project/commit/04381c106f1cb0d3219ddabe231fe900113a9474.diff

LOG: [MLIR][AMDGPU] Add refactoring for shared-mem optimization (#81791)

This commit addresses the issues raised in the following PR:
https://github.com/llvm/llvm-project/pull/81550

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
    mlir/lib/Dialect/AMDGPU/CMakeLists.txt
    mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
    mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index b4e9ad27003db1..79f9ab71a2b430 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -1,5 +1,4 @@
-//===- Transforms.h - AMDGPU Dialect transformations --------------*-
-// C++-*-===//
+//===- Transforms.h - AMDGPU Dialect transformations -------------*- C++-*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -46,10 +45,11 @@ namespace amdgpu {
 /// function that depends on the row Index. The permutation function is chosen
 /// to ensure that sequential distributed+vectorized reads/writes down a single
 /// dimension of the memref have minimal conflicts.
-mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
-                                                       Value memrefValue);
+LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                 Value memrefValue);
 
-void optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
+std::optional<LogicalResult>
+optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp);
 
 } // namespace amdgpu
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
index 63b4d8b99f53fd..c47e4c5495c17b 100644
--- a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
@@ -1,4 +1,4 @@
 add_subdirectory(IR)
-add_subdirectory(Utils)
 add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
+add_subdirectory(Utils)

diff  --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 7c50a876e78f45..6bd03ed833898d 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -50,12 +50,12 @@ constexpr int64_t kDefaultVectorSizeBits = 64;
 static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                  ArrayRef<Value> indices, MemRefType memrefTy,
                                  int64_t srcDim, int64_t tgtDim) {
-  /// Adjust the src index to change how often the permutation changes
-  /// if necessary.
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
   Value src = indices[srcDim];
 
-  /// We only want to permute every N iterations of the target dim where N is
-  /// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
   const int64_t permuteEveryN = std::max<int64_t>(
       1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                         memrefTy.getElementTypeBitWidth()) /
@@ -110,8 +110,8 @@ static void transformIndices(OpBuilder &builder, Location loc,
       permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
 }
 
-/// Return all operations within `parentOp` that read from or write to
-/// `shmMemRef`.
+// Return all operations within `parentOp` that read from or write to
+// `shmMemRef`.
 static LogicalResult
 getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                       SmallVector<Operation *, 16> &readOps,
@@ -131,8 +131,8 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
       writeOps.push_back(op);
   });
 
-  /// Restrict to a supported set of ops. We also require at least 2D access,
-  /// although this could be relaxed.
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
   if (llvm::any_of(readOps, [](Operation *op) {
         return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
                    op) ||
@@ -149,23 +149,22 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
   return success();
 }
 
-mlir::LogicalResult
-mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
-                                                 Value memrefValue) {
+LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                         Value memrefValue) {
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
     return failure();
 
-  /// Abort if the given value has any sub-views; we do not do any alias
-  /// analysis.
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
   bool hasSubView = false;
   parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
   if (hasSubView)
     return failure();
 
-  /// Check if this is necessary given the assumption of 128b accesses:
-  /// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  // Check if this is necessary given the assumption of 128b accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
   const int64_t rowsPerLine =
       (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
@@ -175,8 +174,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   if (rowsPerLine >= threadGroupSize)
     return failure();
 
-  /// Get sets of operations within the function that read/write to shared
-  /// memory.
+  // Get sets of operations within the function that read/write to shared
+  // memory.
   SmallVector<Operation *, 16> shmReadOps;
   SmallVector<Operation *, 16> shmWriteOps;
   if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
@@ -191,7 +190,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   int64_t tgtDim = memRefType.getRank() - 1;
   int64_t srcDim = memRefType.getRank() - 2;
 
-  /// Transform indices for the ops writing to shared memory.
+  // Transform indices for the ops writing to shared memory.
   while (!shmWriteOps.empty()) {
     Operation *shmWriteOp = shmWriteOps.pop_back_val();
     builder.setInsertionPoint(shmWriteOp);
@@ -203,7 +202,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     amdgpu::setIndices(shmWriteOp, transformedIndices);
   }
 
-  /// Transform indices for the ops reading from shared memory.
+  // Transform indices for the ops reading from shared memory.
   while (!shmReadOps.empty()) {
     Operation *shmReadOp = shmReadOps.pop_back_val();
     builder.setInsertionPoint(shmReadOp);
@@ -218,7 +217,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   return success();
 }
 
-void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+std::optional<LogicalResult>
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -228,8 +228,9 @@ void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   for (auto allocOp : shmAllocOps) {
     if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
                                                           allocOp.getMemref())))
-      return;
+      return failure();
   }
+  return success();
 }
 
 struct OptimizeSharedMemoryPass

diff  --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index dfdd1b17e244e3..143e7c2d270952 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -7,22 +7,17 @@
                     %fragRow: index, %fragCol: index, 
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
-    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16                  
     %cst = arith.constant 0.000000e+00 : f16
 
-    // CHECK: [[shmA:%.+]] = memref.alloc
-    // CHECK: [[shmB:%.+]] = memref.alloc
     %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
-    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
     // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
     // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
@@ -31,17 +26,13 @@
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
-    // CHECK:  vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
-
-    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
     // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
     // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
@@ -50,7 +41,6 @@
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK:  vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
   }


        


More information about the Mlir-commits mailing list