[Mlir-commits] [mlir] [AMDGPU] Add parameterization for optimized shared memory variables (PR #82508)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 27 16:40:47 PST 2024


https://github.com/erman-gurses updated https://github.com/llvm/llvm-project/pull/82508

>From 3659cbf8eee983c6e3fe7829377b82b0ea10c599 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Wed, 21 Feb 2024 10:58:30 -0600
Subject: [PATCH 1/8] Add parameterization for optimized shared memory
 variables

---
 .../AMDGPU/TransformOps/AMDGPUTransformOps.td |  6 ++-
 .../Transforms/OptimizeSharedMemory.cpp       | 33 ++++++++++--
 .../AMDGPU/optimize_shmem_reads_writes.mlir   | 50 ++++++++-----------
 ...transform_optimize_shmem_reads_writes.mlir | 46 ++++++++---------
 4 files changed, 79 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
index 23873d86b495c6..9a9446155bf27f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
@@ -13,8 +13,8 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-
 //===----------------------------------------------------------------------===//
 // ApplyOptimizeSharedMemoryReadsAndWritesOp
 //===----------------------------------------------------------------------===//
@@ -28,7 +28,9 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
     reads/writes with the goal of avoiding bank conflicts.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                    DefaultValuedOptionalAttr<I64Attr, "128">:$kSharedMemoryLineSizeBytes,
+                    DefaultValuedOptionalAttr<I64Attr, "128">:$kDefaultVectorSizeBits);
   let results = (outs);
 
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 6bd03ed833898d..00b70d673adc9e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -37,11 +37,18 @@ using namespace mlir::amdgpu;
 
 /// The size of a shared memory line according to AMD documentation.
 /// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
-constexpr int64_t kSharedMemoryLineSizeBytes = 64;
+int64_t kSharedMemoryLineSizeBytes;
 /// We optimize for 64bit accesses, but this can be made an argument in the
 /// future.
-constexpr int64_t kDefaultVectorSizeBits = 64;
+int64_t kDefaultVectorSizeBits;
 
+void setMemoryLineSize(int64_t _kSharedMemoryLineSizeBytes) {
+  kSharedMemoryLineSizeBytes = _kSharedMemoryLineSizeBytes;
+}
+
+void setDefaultVectorSize(int64_t _kDefaultVectorSizeBits) {
+  kDefaultVectorSizeBits = _kDefaultVectorSizeBits;
+}
 /// Uses `srcIndexValue` to permute `tgtIndexValue` via
 /// `result = xor(floordiv(srcIdxVal,permuteEveryN),
 ///               floordiv(tgtIdxVal,vectorSize)))
@@ -151,6 +158,7 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
 
 LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                          Value memrefValue) {
+
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -219,6 +227,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
 
 std::optional<LogicalResult>
 amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+  //setMemoryLineSize(_kSharedMemoryLineSizeBytes);
+  //setDefaultVectorSize(_kDefaultVectorSizeBits);
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -235,10 +245,23 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
 
 struct OptimizeSharedMemoryPass
     : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+
 public:
-  OptimizeSharedMemoryPass() = default;
+  OptimizeSharedMemoryPass()
+      : OptimizeSharedMemoryBase(),
+        _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes = 128),
+        _kDefaultVectorSizeBits(kDefaultVectorSizeBits = 128){};
+
+  OptimizeSharedMemoryPass(int64_t kSharedMemoryLineSizeBytes,
+                           int64_t kDefaultVectorSizeBits)
+      : OptimizeSharedMemoryBase(),
+        _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes),
+        _kDefaultVectorSizeBits(kDefaultVectorSizeBits){};
 
   void runOnOperation() override {
+    setMemoryLineSize(_kSharedMemoryLineSizeBytes);
+    setDefaultVectorSize(_kDefaultVectorSizeBits);
+
     Operation *op = getOperation();
     SmallVector<memref::AllocOp> shmAllocOps;
     op->walk([&](memref::AllocOp allocOp) {
@@ -253,4 +276,8 @@ struct OptimizeSharedMemoryPass
         return;
     }
   }
+
+private:
+  int64_t _kSharedMemoryLineSizeBytes;
+  int64_t _kDefaultVectorSizeBits;
 };
diff --git a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
index a1de1ff87c229f..983eee732e2afe 100644
--- a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
@@ -1,13 +1,13 @@
-// RUN: mlir-opt  %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
   
   // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
-  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>, 
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                     %readRow: index, %readCol: index,
                     %writeRow: index, %writeCol: index,
-                    %fragRow: index, %fragCol: index, 
+                    %fragRow: index, %fragCol: index,
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
-    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16                  
+    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16
     %cst = arith.constant 0.000000e+00 : f16
 
     // CHECK: [[shmA:%.+]] = memref.alloc
@@ -15,42 +15,36 @@
     %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
-    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]     
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
-    // CHECK:  vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
 
-    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]          
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK:  vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
   }
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index 143e7c2d270952..83fcc2520f3ce7 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt  %s -transform-interpreter  | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
 
   // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
-  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>, 
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                     %readRow: index, %readCol: index,
                     %writeRow: index, %writeCol: index,
-                    %fragRow: index, %fragCol: index, 
+                    %fragRow: index, %fragCol: index,
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
     %cst = arith.constant 0.000000e+00 : f16
@@ -13,33 +13,33 @@
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]     
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
-    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]          
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
@@ -48,7 +48,7 @@
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
-    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 : (!transform.any_op) -> ()
+    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 {kSharedMemoryLineSizeBytes = 128, kDefaultVectorSizeBits = 128}: (!transform.any_op) -> ()
     transform.yield
   } // @__transform_main
 } // module

>From 72a38aa5f01a569af9098027bb8dba3cef4725e9 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Wed, 21 Feb 2024 11:13:28 -0600
Subject: [PATCH 2/8] Add formatting

---
 .../mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td      | 2 +-
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp     | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
index 9a9446155bf27f..9419c8b14069e2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
@@ -13,7 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
-include "mlir/IR/EnumAttr.td"
+
 include "mlir/Interfaces/SideEffectInterfaces.td"
 //===----------------------------------------------------------------------===//
 // ApplyOptimizeSharedMemoryReadsAndWritesOp
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 00b70d673adc9e..a9be32567b6eaa 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -227,8 +227,6 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
 
 std::optional<LogicalResult>
 amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
-  //setMemoryLineSize(_kSharedMemoryLineSizeBytes);
-  //setDefaultVectorSize(_kDefaultVectorSizeBits);
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))

>From 030b5211734a0341c176ad6195469e2c294362ba Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Mon, 26 Feb 2024 16:10:00 -0600
Subject: [PATCH 3/8] Add fix for default values of transform dialect

---
 mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h    | 4 +++-
 mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp | 3 ++-
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 6 +++++-
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index 79f9ab71a2b430..bb234d3a285e97 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -49,7 +49,9 @@ LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                  Value memrefValue);
 
 std::optional<LogicalResult>
-optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp);
+optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+                                     int64_t kSharedMemoryLineSizeBytes,
+                                     int64_t kDefaultVectorSizeBits);
 
 } // namespace amdgpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
index ff29f9f6938535..08b57f7c8182f4 100644
--- a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
@@ -27,7 +27,8 @@ DiagnosedSilenceableFailure
 ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
     TransformRewriter &rewriter, FuncOp funcOp, ApplyToEachResultList &results,
     TransformState &state) {
-  optimizeSharedMemoryReadsAndWritesOp(funcOp);
+  optimizeSharedMemoryReadsAndWritesOp(funcOp, getKSharedMemoryLineSizeBytes(),
+                                       getKDefaultVectorSizeBits());
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index a9be32567b6eaa..6cf017cb52eb4b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -226,7 +226,11 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
 }
 
 std::optional<LogicalResult>
-amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+                                             int64_t kSharedMemoryLineSizeBytes,
+                                             int64_t kDefaultVectorSizeBits) {
+  setMemoryLineSize(kSharedMemoryLineSizeBytes);
+  setDefaultVectorSize(kDefaultVectorSizeBits);
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))

>From 1a193b2f60c44d3538f90de5ab374b993e8fec66 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Mon, 26 Feb 2024 16:16:19 -0600
Subject: [PATCH 4/8] Update comments

---
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 6cf017cb52eb4b..c48a9e1a9a6422 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -38,7 +38,7 @@ using namespace mlir::amdgpu;
 /// The size of a shared memory line according to AMD documentation.
 /// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
 int64_t kSharedMemoryLineSizeBytes;
-/// We optimize for 64bit accesses, but this can be made an argument in the
+/// We optimize for 128 bit accesses, but this can be made an argument in the
 /// future.
 int64_t kDefaultVectorSizeBits;
 

>From 0214796f7d84bd95965b1c0a4726449825667c3f Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Mon, 26 Feb 2024 22:15:01 -0600
Subject: [PATCH 5/8] Remove global vars

---
 .../Dialect/AMDGPU/Transforms/Transforms.h    |  6 +-
 .../Transforms/OptimizeSharedMemory.cpp       | 62 ++++++++-----------
 2 files changed, 30 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index bb234d3a285e97..50e35c8d263720 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -45,8 +45,10 @@ namespace amdgpu {
 /// function that depends on the row Index. The permutation function is chosen
 /// to ensure that sequential distributed+vectorized reads/writes down a single
 /// dimension of the memref have minimal conflicts.
-LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
-                                                 Value memrefValue);
+LogicalResult
+optimizeSharedMemoryReadsAndWrites(Operation *parentOp, Value memrefValue,
+                                   int64_t kSharedMemoryLineSizeBytes,
+                                   int64_t kDefaultVectorSizeBits);
 
 std::optional<LogicalResult>
 optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index c48a9e1a9a6422..8119cc697fb54e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -35,20 +35,6 @@ namespace amdgpu {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
-/// The size of a shared memory line according to AMD documentation.
-/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
-int64_t kSharedMemoryLineSizeBytes;
-/// We optimize for 128 bit accesses, but this can be made an argument in the
-/// future.
-int64_t kDefaultVectorSizeBits;
-
-void setMemoryLineSize(int64_t _kSharedMemoryLineSizeBytes) {
-  kSharedMemoryLineSizeBytes = _kSharedMemoryLineSizeBytes;
-}
-
-void setDefaultVectorSize(int64_t _kDefaultVectorSizeBits) {
-  kDefaultVectorSizeBits = _kDefaultVectorSizeBits;
-}
 /// Uses `srcIndexValue` to permute `tgtIndexValue` via
 /// `result = xor(floordiv(srcIdxVal,permuteEveryN),
 ///               floordiv(tgtIdxVal,vectorSize)))
@@ -56,7 +42,9 @@ void setDefaultVectorSize(int64_t _kDefaultVectorSizeBits) {
 /// This is done using an optimized sequence of `arith` operations.
 static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                  ArrayRef<Value> indices, MemRefType memrefTy,
-                                 int64_t srcDim, int64_t tgtDim) {
+                                 int64_t srcDim, int64_t tgtDim,
+                                 int64_t kSharedMemoryLineSizeBytes,
+                                 int64_t kDefaultVectorSizeBits) {
   // Adjust the src index to change how often the permutation changes
   // if necessary.
   Value src = indices[srcDim];
@@ -112,9 +100,11 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
 static void transformIndices(OpBuilder &builder, Location loc,
                              SmallVector<Value, 4> &indices,
                              MemRefType memrefTy, int64_t srcDim,
-                             int64_t tgtDim) {
+                             int64_t tgtDim, int64_t kSharedMemoryLineSizeBytes,
+                             int64_t kDefaultVectorSizeBits) {
   indices[tgtDim] =
-      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim,
+                          kSharedMemoryLineSizeBytes, kDefaultVectorSizeBits);
 }
 
 // Return all operations within `parentOp` that read from or write to
@@ -156,8 +146,9 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
   return success();
 }
 
-LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
-                                                         Value memrefValue) {
+LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
+    Operation *parentOp, Value memrefValue, int64_t kSharedMemoryLineSizeBytes,
+    int64_t kDefaultVectorSizeBits) {
 
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
@@ -206,7 +197,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     auto indices = amdgpu::getIndices(shmWriteOp);
     SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
     transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
-                     memRefType, srcDim, tgtDim);
+                     memRefType, srcDim, tgtDim, kSharedMemoryLineSizeBytes,
+                     kDefaultVectorSizeBits);
     amdgpu::setIndices(shmWriteOp, transformedIndices);
   }
 
@@ -218,7 +210,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     auto indices = amdgpu::getIndices(shmReadOp);
     SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
     transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
-                     memRefType, srcDim, tgtDim);
+                     memRefType, srcDim, tgtDim, kSharedMemoryLineSizeBytes,
+                     kDefaultVectorSizeBits);
     amdgpu::setIndices(shmReadOp, transformedIndices);
   }
 
@@ -229,8 +222,6 @@ std::optional<LogicalResult>
 amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
                                              int64_t kSharedMemoryLineSizeBytes,
                                              int64_t kDefaultVectorSizeBits) {
-  setMemoryLineSize(kSharedMemoryLineSizeBytes);
-  setDefaultVectorSize(kDefaultVectorSizeBits);
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -238,8 +229,9 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
     shmAllocOps.push_back(allocOp);
   });
   for (auto allocOp : shmAllocOps) {
-    if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
-                                                          allocOp.getMemref())))
+    if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(
+            funcOp, allocOp.getMemref(), kSharedMemoryLineSizeBytes,
+            kDefaultVectorSizeBits)))
       return failure();
   }
   return success();
@@ -251,19 +243,16 @@ struct OptimizeSharedMemoryPass
 public:
   OptimizeSharedMemoryPass()
       : OptimizeSharedMemoryBase(),
-        _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes = 128),
-        _kDefaultVectorSizeBits(kDefaultVectorSizeBits = 128){};
+        kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes = 128),
+        kDefaultVectorSizeBits(kDefaultVectorSizeBits = 128){};
 
   OptimizeSharedMemoryPass(int64_t kSharedMemoryLineSizeBytes,
                            int64_t kDefaultVectorSizeBits)
       : OptimizeSharedMemoryBase(),
-        _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes),
-        _kDefaultVectorSizeBits(kDefaultVectorSizeBits){};
+        kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes),
+        kDefaultVectorSizeBits(kDefaultVectorSizeBits){};
 
   void runOnOperation() override {
-    setMemoryLineSize(_kSharedMemoryLineSizeBytes);
-    setDefaultVectorSize(_kDefaultVectorSizeBits);
-
     Operation *op = getOperation();
     SmallVector<memref::AllocOp> shmAllocOps;
     op->walk([&](memref::AllocOp allocOp) {
@@ -273,13 +262,14 @@ struct OptimizeSharedMemoryPass
       shmAllocOps.push_back(allocOp);
     });
     for (auto allocOp : shmAllocOps) {
-      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
-                                                    allocOp.getMemref())))
+      if (failed(optimizeSharedMemoryReadsAndWrites(op, allocOp.getMemref(),
+                                                    kSharedMemoryLineSizeBytes,
+                                                    kDefaultVectorSizeBits)))
         return;
     }
   }
 
 private:
-  int64_t _kSharedMemoryLineSizeBytes;
-  int64_t _kDefaultVectorSizeBits;
+  int64_t kSharedMemoryLineSizeBytes;
+  int64_t kDefaultVectorSizeBits;
 };

>From 3e21d96865fb4ab85f9516a870bdd935324fd7f2 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Tue, 27 Feb 2024 10:24:32 -0600
Subject: [PATCH 6/8] Update variable names

---
 .../Dialect/AMDGPU/Transforms/Transforms.h    |  8 +--
 .../Transforms/OptimizeSharedMemory.cpp       | 58 +++++++++----------
 2 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index 50e35c8d263720..843cea2c503b9a 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -47,13 +47,13 @@ namespace amdgpu {
 /// dimension of the memref have minimal conflicts.
 LogicalResult
 optimizeSharedMemoryReadsAndWrites(Operation *parentOp, Value memrefValue,
-                                   int64_t kSharedMemoryLineSizeBytes,
-                                   int64_t kDefaultVectorSizeBits);
+                                   int64_t sharedMemoryLineSizeBytes,
+                                   int64_t defaultVectorSizeBits);
 
 std::optional<LogicalResult>
 optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
-                                     int64_t kSharedMemoryLineSizeBytes,
-                                     int64_t kDefaultVectorSizeBits);
+                                     int64_t sharedMemoryLineSizeBytes,
+                                     int64_t defaultVectorSizeBits);
 
 } // namespace amdgpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 8119cc697fb54e..83b78a3e7d3c68 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -43,8 +43,8 @@ using namespace mlir::amdgpu;
 static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                  ArrayRef<Value> indices, MemRefType memrefTy,
                                  int64_t srcDim, int64_t tgtDim,
-                                 int64_t kSharedMemoryLineSizeBytes,
-                                 int64_t kDefaultVectorSizeBits) {
+                                 int64_t sharedMemoryLineSizeBytes,
+                                 int64_t defaultVectorSizeBits) {
   // Adjust the src index to change how often the permutation changes
   // if necessary.
   Value src = indices[srcDim];
@@ -52,7 +52,7 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
   // We only want to permute every N iterations of the target dim where N is
   // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
   const int64_t permuteEveryN = std::max<int64_t>(
-      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+      1, sharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                         memrefTy.getElementTypeBitWidth()) /
                                        8));
 
@@ -66,7 +66,7 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
   // bits[N:M] = vector index
   // clang-format on
   int64_t n =
-      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+      llvm::Log2_64(defaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
   int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
 
   // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
@@ -100,11 +100,11 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
 static void transformIndices(OpBuilder &builder, Location loc,
                              SmallVector<Value, 4> &indices,
                              MemRefType memrefTy, int64_t srcDim,
-                             int64_t tgtDim, int64_t kSharedMemoryLineSizeBytes,
-                             int64_t kDefaultVectorSizeBits) {
+                             int64_t tgtDim, int64_t sharedMemoryLineSizeBytes,
+                             int64_t defaultVectorSizeBits) {
   indices[tgtDim] =
       permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim,
-                          kSharedMemoryLineSizeBytes, kDefaultVectorSizeBits);
+                          sharedMemoryLineSizeBytes, defaultVectorSizeBits);
 }
 
 // Return all operations within `parentOp` that read from or write to
@@ -147,8 +147,8 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
 }
 
 LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
-    Operation *parentOp, Value memrefValue, int64_t kSharedMemoryLineSizeBytes,
-    int64_t kDefaultVectorSizeBits) {
+    Operation *parentOp, Value memrefValue, int64_t sharedMemoryLineSizeBytes,
+    int64_t defaultVectorSizeBits) {
 
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
@@ -166,10 +166,10 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
   // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
   const int64_t rowsPerLine =
-      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+      (8 * sharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
       rowSize;
   const int64_t threadGroupSize =
-      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+      1LL << (7 - llvm::Log2_64(defaultVectorSizeBits / 8));
   if (rowsPerLine >= threadGroupSize)
     return failure();
 
@@ -197,8 +197,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
     auto indices = amdgpu::getIndices(shmWriteOp);
     SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
     transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
-                     memRefType, srcDim, tgtDim, kSharedMemoryLineSizeBytes,
-                     kDefaultVectorSizeBits);
+                     memRefType, srcDim, tgtDim, sharedMemoryLineSizeBytes,
+                     defaultVectorSizeBits);
     amdgpu::setIndices(shmWriteOp, transformedIndices);
   }
 
@@ -210,8 +210,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
     auto indices = amdgpu::getIndices(shmReadOp);
     SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
     transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
-                     memRefType, srcDim, tgtDim, kSharedMemoryLineSizeBytes,
-                     kDefaultVectorSizeBits);
+                     memRefType, srcDim, tgtDim, sharedMemoryLineSizeBytes,
+                     defaultVectorSizeBits);
     amdgpu::setIndices(shmReadOp, transformedIndices);
   }
 
@@ -220,8 +220,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
 
 std::optional<LogicalResult>
 amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
-                                             int64_t kSharedMemoryLineSizeBytes,
-                                             int64_t kDefaultVectorSizeBits) {
+                                             int64_t sharedMemoryLineSizeBytes,
+                                             int64_t defaultVectorSizeBits) {
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -230,8 +230,8 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
   });
   for (auto allocOp : shmAllocOps) {
     if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(
-            funcOp, allocOp.getMemref(), kSharedMemoryLineSizeBytes,
-            kDefaultVectorSizeBits)))
+            funcOp, allocOp.getMemref(), sharedMemoryLineSizeBytes,
+            defaultVectorSizeBits)))
       return failure();
   }
   return success();
@@ -243,14 +243,14 @@ struct OptimizeSharedMemoryPass
 public:
   OptimizeSharedMemoryPass()
       : OptimizeSharedMemoryBase(),
-        kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes = 128),
-        kDefaultVectorSizeBits(kDefaultVectorSizeBits = 128){};
+        sharedMemoryLineSizeBytes(sharedMemoryLineSizeBytes = 128),
+        defaultVectorSizeBits(defaultVectorSizeBits = 128){};
 
-  OptimizeSharedMemoryPass(int64_t kSharedMemoryLineSizeBytes,
-                           int64_t kDefaultVectorSizeBits)
+  OptimizeSharedMemoryPass(int64_t sharedMemoryLineSizeBytes,
+                           int64_t defaultVectorSizeBits)
       : OptimizeSharedMemoryBase(),
-        kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes),
-        kDefaultVectorSizeBits(kDefaultVectorSizeBits){};
+        sharedMemoryLineSizeBytes(sharedMemoryLineSizeBytes),
+        defaultVectorSizeBits(defaultVectorSizeBits){};
 
   void runOnOperation() override {
     Operation *op = getOperation();
@@ -263,13 +263,13 @@ struct OptimizeSharedMemoryPass
     });
     for (auto allocOp : shmAllocOps) {
       if (failed(optimizeSharedMemoryReadsAndWrites(op, allocOp.getMemref(),
-                                                    kSharedMemoryLineSizeBytes,
-                                                    kDefaultVectorSizeBits)))
+                                                    sharedMemoryLineSizeBytes,
+                                                    defaultVectorSizeBits)))
         return;
     }
   }
 
 private:
-  int64_t kSharedMemoryLineSizeBytes;
-  int64_t kDefaultVectorSizeBits;
+  int64_t sharedMemoryLineSizeBytes;
+  int64_t defaultVectorSizeBits;
 };

>From 1025f2b621a4272becb5e8c9aaf30190d15b7710 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Tue, 27 Feb 2024 10:32:51 -0600
Subject: [PATCH 7/8] Drop k prefix from transform op attributes and apply clang-format

---
 .../mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td    | 4 ++--
 mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp   | 4 ++--
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp   | 4 ++--
 .../Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
index 9419c8b14069e2..0eb67050608630 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
@@ -29,8 +29,8 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                    DefaultValuedOptionalAttr<I64Attr, "128">:$kSharedMemoryLineSizeBytes,
-                    DefaultValuedOptionalAttr<I64Attr, "128">:$kDefaultVectorSizeBits);
+                    DefaultValuedOptionalAttr<I64Attr, "128">:$sharedMemoryLineSizeBytes,
+                    DefaultValuedOptionalAttr<I64Attr, "128">:$defaultVectorSizeBits);
   let results = (outs);
 
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
index 08b57f7c8182f4..b7e17a92897389 100644
--- a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
@@ -27,8 +27,8 @@ DiagnosedSilenceableFailure
 ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
     TransformRewriter &rewriter, FuncOp funcOp, ApplyToEachResultList &results,
     TransformState &state) {
-  optimizeSharedMemoryReadsAndWritesOp(funcOp, getKSharedMemoryLineSizeBytes(),
-                                       getKDefaultVectorSizeBits());
+  optimizeSharedMemoryReadsAndWritesOp(funcOp, getSharedMemoryLineSizeBytes(),
+                                       getDefaultVectorSizeBits());
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 83b78a3e7d3c68..797ba7f50ab8e4 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -53,8 +53,8 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
   // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
   const int64_t permuteEveryN = std::max<int64_t>(
       1, sharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
-                                        memrefTy.getElementTypeBitWidth()) /
-                                       8));
+                                       memrefTy.getElementTypeBitWidth()) /
+                                      8));
 
   // clang-format off
   // Index bit representation (b0 = least significant bit) for dim(1)
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index 83fcc2520f3ce7..b1bb91ffc29721 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -48,7 +48,7 @@
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
-    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 {kSharedMemoryLineSizeBytes = 128, kDefaultVectorSizeBits = 128}: (!transform.any_op) -> ()
+    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 {sharedMemoryLineSizeBytes = 128, defaultVectorSizeBits = 128}: (!transform.any_op) -> ()
     transform.yield
   } // @__transform_main
 } // module

>From a74928c17e94fcc37f072ed2189fe1596072312c Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Tue, 27 Feb 2024 18:38:35 -0600
Subject: [PATCH 8/8] Expose shared memory size parameters as ODS pass options

---
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td  |  9 +++++++-
 .../Transforms/OptimizeSharedMemory.cpp       | 21 ++++---------------
 2 files changed, 12 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index c8059e6d316e8a..f30283f1b55925 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -37,10 +37,17 @@ def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
     attempts to optimize reads/writes from a memref representing GPU shared
     memory in order to avoid bank conflicts.
   }];
-
   let dependentDialects = [
     "memref::MemRefDialect", "vector::VectorDialect"
   ];
+  let options = [
+    Option<"sharedMemoryLineSizeBytes", "shared-memory-Line-size-bytes", "uint64_t",
+           /*default=*/"128",
+           "Shared memory line size in bytes">,
+    Option<"defaultVectorSizeBits", "default-vector-size-bits", "uint64_t",
+           /*default=*/"128",
+           "Default vector size in bits">,
+  ];
 }
 
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 797ba7f50ab8e4..a0cde5952f3a65 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -149,7 +149,6 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
 LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
     Operation *parentOp, Value memrefValue, int64_t sharedMemoryLineSizeBytes,
     int64_t defaultVectorSizeBits) {
-
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -239,20 +238,12 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
 
 struct OptimizeSharedMemoryPass
     : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
-
 public:
-  OptimizeSharedMemoryPass()
-      : OptimizeSharedMemoryBase(),
-        sharedMemoryLineSizeBytes(sharedMemoryLineSizeBytes = 128),
-        defaultVectorSizeBits(defaultVectorSizeBits = 128){};
-
-  OptimizeSharedMemoryPass(int64_t sharedMemoryLineSizeBytes,
-                           int64_t defaultVectorSizeBits)
-      : OptimizeSharedMemoryBase(),
-        sharedMemoryLineSizeBytes(sharedMemoryLineSizeBytes),
-        defaultVectorSizeBits(defaultVectorSizeBits){};
-
+  OptimizeSharedMemoryPass() = default;
+  OptimizeSharedMemoryPass(const OptimizeSharedMemoryOptions &options)
+      : OptimizeSharedMemoryBase(options) {}
   void runOnOperation() override {
+
     Operation *op = getOperation();
     SmallVector<memref::AllocOp> shmAllocOps;
     op->walk([&](memref::AllocOp allocOp) {
@@ -268,8 +259,4 @@ struct OptimizeSharedMemoryPass
         return;
     }
   }
-
-private:
-  int64_t sharedMemoryLineSizeBytes;
-  int64_t defaultVectorSizeBits;
 };



More information about the Mlir-commits mailing list