[Mlir-commits] [mlir] Add parameterization for optimized shared memory variables (PR #82508)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 21 09:14:02 PST 2024


https://github.com/erman-gurses updated https://github.com/llvm/llvm-project/pull/82508

>From 65453c098f70d0c387dac85c9422cb0c7736ff07 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Wed, 21 Feb 2024 10:58:30 -0600
Subject: [PATCH 1/2] Add parameterization for optimized shared memory
 variables

---
 .../AMDGPU/TransformOps/AMDGPUTransformOps.td |  6 ++-
 .../Transforms/OptimizeSharedMemory.cpp       | 33 ++++++++++--
 .../AMDGPU/optimize_shmem_reads_writes.mlir   | 50 ++++++++-----------
 ...transform_optimize_shmem_reads_writes.mlir | 46 ++++++++---------
 4 files changed, 79 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
index 23873d86b495c6..9a9446155bf27f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
@@ -13,8 +13,8 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-
 //===----------------------------------------------------------------------===//
 // ApplyOptimizeSharedMemoryReadsAndWritesOp
 //===----------------------------------------------------------------------===//
@@ -28,7 +28,9 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
     reads/writes with the goal of avoiding bank conflicts.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                    DefaultValuedOptionalAttr<I64Attr, "128">:$kSharedMemoryLineSizeBytes,
+                    DefaultValuedOptionalAttr<I64Attr, "128">:$kDefaultVectorSizeBits);
   let results = (outs);
 
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 6bd03ed833898d..00b70d673adc9e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -37,11 +37,18 @@ using namespace mlir::amdgpu;
 
 /// The size of a shared memory line according to AMD documentation.
 /// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
-constexpr int64_t kSharedMemoryLineSizeBytes = 64;
+int64_t kSharedMemoryLineSizeBytes;
 /// We optimize for 64bit accesses, but this can be made an argument in the
 /// future.
-constexpr int64_t kDefaultVectorSizeBits = 64;
+int64_t kDefaultVectorSizeBits;
 
+void setMemoryLineSize(int64_t _kSharedMemoryLineSizeBytes) {
+  kSharedMemoryLineSizeBytes = _kSharedMemoryLineSizeBytes;
+}
+
+void setDefaultVectorSize(int64_t _kDefaultVectorSizeBits) {
+  kDefaultVectorSizeBits = _kDefaultVectorSizeBits;
+}
 /// Uses `srcIndexValue` to permute `tgtIndexValue` via
 /// `result = xor(floordiv(srcIdxVal,permuteEveryN),
 ///               floordiv(tgtIdxVal,vectorSize)))
@@ -151,6 +158,7 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
 
 LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                          Value memrefValue) {
+
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -219,6 +227,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
 
 std::optional<LogicalResult>
 amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+  //setMemoryLineSize(_kSharedMemoryLineSizeBytes);
+  //setDefaultVectorSize(_kDefaultVectorSizeBits);
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -235,10 +245,23 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
 
 struct OptimizeSharedMemoryPass
     : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+
 public:
-  OptimizeSharedMemoryPass() = default;
+  OptimizeSharedMemoryPass()
+      : OptimizeSharedMemoryBase(),
+        _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes = 128),
+        _kDefaultVectorSizeBits(kDefaultVectorSizeBits = 128){};
+
+  OptimizeSharedMemoryPass(int64_t kSharedMemoryLineSizeBytes,
+                           int64_t kDefaultVectorSizeBits)
+      : OptimizeSharedMemoryBase(),
+        _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes),
+        _kDefaultVectorSizeBits(kDefaultVectorSizeBits){};
 
   void runOnOperation() override {
+    setMemoryLineSize(_kSharedMemoryLineSizeBytes);
+    setDefaultVectorSize(_kDefaultVectorSizeBits);
+
     Operation *op = getOperation();
     SmallVector<memref::AllocOp> shmAllocOps;
     op->walk([&](memref::AllocOp allocOp) {
@@ -253,4 +276,8 @@ struct OptimizeSharedMemoryPass
         return;
     }
   }
+
+private:
+  int64_t _kSharedMemoryLineSizeBytes;
+  int64_t _kDefaultVectorSizeBits;
 };
diff --git a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
index a1de1ff87c229f..983eee732e2afe 100644
--- a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
@@ -1,13 +1,13 @@
-// RUN: mlir-opt  %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
   
   // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
-  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>, 
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                     %readRow: index, %readCol: index,
                     %writeRow: index, %writeCol: index,
-                    %fragRow: index, %fragCol: index, 
+                    %fragRow: index, %fragCol: index,
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
-    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16                  
+    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16
     %cst = arith.constant 0.000000e+00 : f16
 
     // CHECK: [[shmA:%.+]] = memref.alloc
@@ -15,42 +15,36 @@
     %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
-    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]     
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
-    // CHECK:  vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
 
-    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]          
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK:  vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
   }
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index 143e7c2d270952..83fcc2520f3ce7 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt  %s -transform-interpreter  | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
 
   // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
-  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>, 
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                     %readRow: index, %readCol: index,
                     %writeRow: index, %writeCol: index,
-                    %fragRow: index, %fragCol: index, 
+                    %fragRow: index, %fragCol: index,
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
     %cst = arith.constant 0.000000e+00 : f16
@@ -13,33 +13,33 @@
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]     
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
-    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]          
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
@@ -48,7 +48,7 @@
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
-    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 : (!transform.any_op) -> ()
+    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 {kSharedMemoryLineSizeBytes = 128, kDefaultVectorSizeBits = 128}: (!transform.any_op) -> ()
     transform.yield
   } // @__transform_main
 } // module

>From e9ea568f71f1feb9b66a9042451cb2e027e8341a Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Wed, 21 Feb 2024 11:13:28 -0600
Subject: [PATCH 2/2] Add formatting

---
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 00b70d673adc9e..99ca266a737b3b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -227,8 +227,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
 
 std::optional<LogicalResult>
 amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
-  //setMemoryLineSize(_kSharedMemoryLineSizeBytes);
-  //setDefaultVectorSize(_kDefaultVectorSizeBits);
+  // setMemoryLineSize(_kSharedMemoryLineSizeBytes);
+  // setDefaultVectorSize(_kDefaultVectorSizeBits);
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))



More information about the Mlir-commits mailing list