[Mlir-commits] [llvm] [mlir] [AMDGPU][MLIR]Add shmem-optimization as an op using transform dialect (PR #81550)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 13 17:33:53 PST 2024


https://github.com/erman-gurses updated https://github.com/llvm/llvm-project/pull/81550

>From 3e54be96bd8fc7e1410256be991586259298e47f Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Mon, 12 Feb 2024 17:35:32 -0600
Subject: [PATCH 1/5] Add shmem-optimization as an op using the transform dialect

---
 .../mlir/Dialect/AMDGPU/CMakeLists.txt        |  1 +
 .../AMDGPU/TransformOps/AMDGPUTransformOps.h  | 49 +++++++++++
 .../AMDGPU/TransformOps/AMDGPUTransformOps.td | 46 ++++++++++
 .../AMDGPU/TransformOps/CMakeLists.txt        |  4 +
 .../Dialect/AMDGPU/Transforms/Transforms.h    |  3 +
 mlir/include/mlir/InitAllExtensions.h         |  2 +
 mlir/lib/Dialect/AMDGPU/CMakeLists.txt        |  3 +-
 .../TransformOps/AMDGPUTransformOps.cpp       | 84 +++++++++++++++++++
 .../AMDGPU/TransformOps/CMakeLists.txt        | 25 ++++++
 .../Transforms/OptimizeSharedMemory.cpp       | 15 ++++
 ...transform_optimize_shmem_reads_writes.mlir | 64 ++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     | 66 +++++++++++++++
 12 files changed, 361 insertions(+), 1 deletion(-)
 create mode 100644 mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h
 create mode 100644 mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
 create mode 100644 mlir/include/mlir/Dialect/AMDGPU/TransformOps/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
 create mode 100644 mlir/lib/Dialect/AMDGPU/TransformOps/CMakeLists.txt
 create mode 100644 mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
index 9f57627c321fb0..660deb21479d29 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h
new file mode 100644
index 00000000000000..c7721f2f4e0ceb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h
@@ -0,0 +1,49 @@
+//===- AMDGPUTransformOps.h - AMDGPU transform ops ----------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMOPS_AMDGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_AMDGPU_TRANSFORMOPS_AMDGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+namespace mlir {
+namespace transform {
+class TransformHandleTypeInterface;
+} // namespace transform
+} // namespace mlir
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+class LinalgOp;
+} // namespace linalg
+
+namespace scf {
+class ForOp;
+} // namespace scf
+
+namespace amdgpu {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace amdgpu
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// AMDGPU Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h.inc"
+
+#endif // MLIR_DIALECT_AMDGPU_TRANSFORMOPS_AMDGPUTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
new file mode 100644
index 00000000000000..f028aa14097ebb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
@@ -0,0 +1,46 @@
+//===- AMDGPUTransformOps.td - AMDGPU transform ops ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AMDGPU_TRANSFORM_OPS
+#define AMDGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// ApplyOptimizeSharedMemoryReadsAndWritesOp
+//===----------------------------------------------------------------------===//
+
+def ApplyOptimizeSharedMemoryReadsAndWritesOp :
+  Op<Transform_Dialect, "amdgpu.optimize_shared_memory_reads_and_writes",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Reduce shared memory bank conflicts";
+  let description = [{ This opp adds a transformation and pass to the AMDGPU 
+    dialect that attempts to optimize reads/writes from a memref representing 
+    GPU shared memory in order to avoid bank conflicts.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::func::FuncOp funcOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+#endif // AMDGPU_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 00000000000000..07bfebc9f96d2e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS AMDGPUTransformOps.td)
+mlir_tablegen(AMDGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(AMDGPUTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRAMDGPUTransformOpsIncGen) 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index 140bc12deed690..b4e9ad27003db1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
 #define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -48,6 +49,8 @@ namespace amdgpu {
 mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                        Value memrefValue);
 
+void optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
+
 } // namespace amdgpu
 } // namespace mlir
 
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 7708ca5571de3b..23d88e00ded1fb 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h"
 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
@@ -66,6 +67,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   ub::registerConvertUBToLLVMInterface(registry);
 
   // Register all transform dialect extensions.
+  amdgpu::registerTransformDialectExtension(registry);
   affine::registerTransformDialectExtension(registry);
   bufferization::registerTransformDialectExtension(registry);
   func::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
index 31167e6af908b9..63b4d8b99f53fd 100644
--- a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
-add_subdirectory(Transforms)
 add_subdirectory(Utils)
+add_subdirectory(TransformOps)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
new file mode 100644
index 00000000000000..cc2be492286f8c
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
@@ -0,0 +1,84 @@
+//===- AMDGPUTransformOps.cpp - Implementation of AMDGPU transform ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::amdgpu;
+using namespace mlir::transform;
+
+#define DEBUG_TYPE "amdgpu-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
+                                       
+DiagnosedSilenceableFailure transform::ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
+    transform::TransformRewriter &rewriter, ::mlir::func::FuncOp funcOp,
+    ::mlir::transform::ApplyToEachResultList &results,
+    ::mlir::transform::TransformState &state) {
+  mlir::amdgpu::optimizeSharedMemoryReadsAndWritesOp(funcOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ApplyOptimizeSharedMemoryReadsAndWritesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class AMDGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          AMDGPUTransformDialectExtension> {
+public:
+  AMDGPUTransformDialectExtension() {
+    declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<affine::AffineDialect>();
+    declareGeneratedDialect<amdgpu::AMDGPUDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp.inc"
+
+void mlir::amdgpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<AMDGPUTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 00000000000000..c39a3b55eabca4
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,25 @@
+add_mlir_dialect_library(MLIRAMDGPUTransformOps
+  AMDGPUTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/TransformOps
+
+  DEPENDS
+  MLIRAMDGPUTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRArithDialect
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRAMDGPUDialect
+  MLIRAMDGPUTransforms
+  MLIRParser
+  MLIRSideEffectInterfaces
+  MLIRSCFDialect
+  MLIRSCFTransforms
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRVectorTransforms
+
+  )
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index c7001fc6d57d5f..4fb8242a0afcfd 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -220,6 +220,21 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   return success();
 }
 
+void mlir::amdgpu::optimizeSharedMemoryReadsAndWritesOp(
+    ::mlir::func::FuncOp funcOp) {
+  SmallVector<memref::AllocOp> shmAllocOps;
+  funcOp.walk([&](memref::AllocOp allocOp) {
+    if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
+      return;
+    shmAllocOps.push_back(allocOp);
+  });
+  for (auto allocOp : shmAllocOps) {
+    if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
+                                                          allocOp.getMemref())))
+      return;
+  }
+}
+
 struct OptimizeSharedMemoryPass
     : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
 public:
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
new file mode 100644
index 00000000000000..dfdd1b17e244e3
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt  %s -transform-interpreter  | FileCheck %s
+
+  // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>, 
+                    %readRow: index, %readCol: index,
+                    %writeRow: index, %writeCol: index,
+                    %fragRow: index, %fragCol: index, 
+                    %fragColPerm: index,
+                    %stRow: index, %stCol: index) {
+    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16                  
+    %cst = arith.constant 0.000000e+00 : f16
+
+    // CHECK: [[shmA:%.+]] = memref.alloc
+    // CHECK: [[shmB:%.+]] = memref.alloc
+    %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
+    %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
+
+    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
+    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    gpu.barrier
+    gpu.barrier
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]     
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
+    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
+    // CHECK:  vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
+    %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
+
+    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
+    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    gpu.barrier
+    gpu.barrier
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]          
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
+    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
+    // CHECK:  vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
+    %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
+    return
+  }
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 : (!transform.any_op) -> ()
+    transform.yield
+  } // @__transform_main
+} // module
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 821481ee272a56..2c534c7614b10d 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1511,6 +1511,70 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "AMDGPUTransformOps",
+    srcs = glob([
+        "lib/Dialect/AMDGPU/TransformOps/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/mlir/Dialect/AMDGPU/TransformOps/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":AMDGPUDialect",
+        ":AMDGPUTransformOpsIncGen",
+        ":AMDGPUTransforms",
+        ":AffineDialect",
+        ":Analysis",
+        ":ArithDialect",
+        ":ArithUtils",
+        ":DialectUtils",
+        ":GPUCommonTransforms",
+        ":GPUCompilationAttrInterfacesIncGen",
+        ":GPUDialect",
+        ":IR",
+        ":LLVMCommonConversion",
+        ":LinalgDialect",
+        ":MemRefDialect",
+        ":SCFDialect",
+        ":SCFTransforms",
+        ":Support",
+        ":TransformDialect",
+        ":VectorDialect",
+        "//llvm:Support",
+    ],
+)
+
+td_library(
+    name = "AMDGPUTransformOpsTdFiles",
+    srcs = glob([
+        "include/mlir/Dialect/AMDGPU/TransformOps/*.td",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":TransformDialectTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "AMDGPUTransformOpsIncGen",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td",
+    deps = [
+        ":AMDGPUTransformOpsTdFiles",
+    ],
+)
+
 gentbl_cc_library(
     name = "AMDGPUPassIncGen",
     tbl_outs = [
@@ -4614,6 +4678,7 @@ cc_library(
     name = "AllExtensions",
     hdrs = ["include/mlir/InitAllExtensions.h"],
     deps = [
+        ":AMDGPUTransformOps",
         ":AffineTransformOps",
         ":ArithToLLVM",
         ":BufferizationTransformOps",
@@ -8961,6 +9026,7 @@ cc_library(
     deps = [
         ":AMDGPUDialect",
         ":AMDGPUToROCDL",
+        ":AMDGPUTransformOps",
         ":AMDGPUTransforms",
         ":AMXDialect",
         ":AMXTransforms",

>From 60e42a53f5fa955c8ed061ff8cdbd0b530084568 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Mon, 12 Feb 2024 17:48:40 -0600
Subject: [PATCH 2/5] Fix formatting

---
 mlir/include/mlir/InitAllExtensions.h           |  2 +-
 .../AMDGPU/TransformOps/AMDGPUTransformOps.cpp  | 17 ++++++++++-------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 23d88e00ded1fb..b31fb26f00f8f4 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -23,10 +23,10 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
-#include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h"
 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
index cc2be492286f8c..fd82e2497a5799 100644
--- a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
@@ -1,4 +1,5 @@
-//===- AMDGPUTransformOps.cpp - Implementation of AMDGPU transform ops ------===//
+//===- AMDGPUTransformOps.cpp - Implementation of AMDGPU transform ops
+//------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -11,15 +12,15 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
-#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -39,8 +40,9 @@ using namespace mlir::transform;
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
-                                       
-DiagnosedSilenceableFailure transform::ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
+
+DiagnosedSilenceableFailure
+transform::ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
     transform::TransformRewriter &rewriter, ::mlir::func::FuncOp funcOp,
     ::mlir::transform::ApplyToEachResultList &results,
     ::mlir::transform::TransformState &state) {
@@ -79,6 +81,7 @@ class AMDGPUTransformDialectExtension
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp.inc"
 
-void mlir::amdgpu::registerTransformDialectExtension(DialectRegistry &registry) {
+void mlir::amdgpu::registerTransformDialectExtension(
+    DialectRegistry &registry) {
   registry.addExtensions<AMDGPUTransformDialectExtension>();
 }

>From 82045cad4518bda94319e57ae24f6b6e2506d22f Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Mon, 12 Feb 2024 19:25:51 -0600
Subject: [PATCH 3/5] Remove redundant namespaces

---
 .../TransformOps/AMDGPUTransformOps.cpp       | 21 ++++++++-----------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
index fd82e2497a5799..cf76691381b1ec 100644
--- a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
@@ -32,9 +31,9 @@
 #include "llvm/ADT/ArrayRef.h"
 
 using namespace mlir;
-using namespace mlir::linalg;
 using namespace mlir::amdgpu;
 using namespace mlir::transform;
+using namespace mlir::func;
 
 #define DEBUG_TYPE "amdgpu-transforms"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -42,18 +41,17 @@ using namespace mlir::transform;
 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
 
 DiagnosedSilenceableFailure
-transform::ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
-    transform::TransformRewriter &rewriter, ::mlir::func::FuncOp funcOp,
-    ::mlir::transform::ApplyToEachResultList &results,
-    ::mlir::transform::TransformState &state) {
-  mlir::amdgpu::optimizeSharedMemoryReadsAndWritesOp(funcOp);
+ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
+    TransformRewriter &rewriter, FuncOp funcOp, ApplyToEachResultList &results,
+    TransformState &state) {
+  optimizeSharedMemoryReadsAndWritesOp(funcOp);
   return DiagnosedSilenceableFailure::success();
 }
 
-void transform::ApplyOptimizeSharedMemoryReadsAndWritesOp::getEffects(
+void ApplyOptimizeSharedMemoryReadsAndWritesOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getTarget(), effects);
-  transform::modifiesPayload(effects);
+  onlyReadsHandle(getTarget(), effects);
+  modifiesPayload(effects);
 }
 
 //===----------------------------------------------------------------------===//
@@ -62,8 +60,7 @@ void transform::ApplyOptimizeSharedMemoryReadsAndWritesOp::getEffects(
 
 namespace {
 class AMDGPUTransformDialectExtension
-    : public transform::TransformDialectExtension<
-          AMDGPUTransformDialectExtension> {
+    : public TransformDialectExtension<AMDGPUTransformDialectExtension> {
 public:
   AMDGPUTransformDialectExtension() {
     declareGeneratedDialect<arith::ArithDialect>();

>From 1e6a448a32857882079daf44a264cd65e2d1fd53 Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Tue, 13 Feb 2024 11:03:53 -0600
Subject: [PATCH 4/5] Address review comments: fix file headers, reword op description, drop unused includes

---
 .../AMDGPU/TransformOps/AMDGPUTransformOps.h  |  3 +--
 .../AMDGPU/TransformOps/AMDGPUTransformOps.td |  7 +++---
 .../TransformOps/AMDGPUTransformOps.cpp       | 22 ++-----------------
 .../Transforms/OptimizeSharedMemory.cpp       |  6 ++---
 4 files changed, 8 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h
index c7721f2f4e0ceb..4fb4ab08a0da34 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h
@@ -1,5 +1,4 @@
-//===- AMDGPUTransformOps.h - AMDGPU transform ops ----------------*- C++
-//-*-===//
+//===- AMDGPUTransformOps.h - AMDGPU transform ops ---------------*- C++-*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
index f028aa14097ebb..23873d86b495c6 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
@@ -1,4 +1,4 @@
-//===- AMDGPUTransformOps.td - AMDGPU transform ops ----------*- tablegen -*-===//
+//===- AMDGPUTransformOps.td - AMDGPU transform ops --------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -24,9 +24,8 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformOpInterface, TransformEachOpTrait]> {
   let summary = "Reduce shared memory bank conflicts";
-  let description = [{ This opp adds a transformation and pass to the AMDGPU 
-    dialect that attempts to optimize reads/writes from a memref representing 
-    GPU shared memory in order to avoid bank conflicts.
+  let description = [{ This op attempts to optimize GPU Shared memory
+    reads/writes with the goal of avoiding bank conflicts.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target);
diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
index cf76691381b1ec..ff29f9f6938535 100644
--- a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
@@ -1,5 +1,4 @@
-//===- AMDGPUTransformOps.cpp - Implementation of AMDGPU transform ops
-//------===//
+//===- AMDGPUTransformOps.cpp - Implementation of AMDGPU transform ops-----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -9,26 +8,10 @@
 
 #include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.h"
 
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
-#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
-#include "llvm/ADT/ArrayRef.h"
 
 using namespace mlir;
 using namespace mlir::amdgpu;
@@ -78,7 +61,6 @@ class AMDGPUTransformDialectExtension
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp.inc"
 
-void mlir::amdgpu::registerTransformDialectExtension(
-    DialectRegistry &registry) {
+void amdgpu::registerTransformDialectExtension(DialectRegistry &registry) {
   registry.addExtensions<AMDGPUTransformDialectExtension>();
 }
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 4fb8242a0afcfd..66beaa3c58665f 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -24,8 +24,6 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/MathExtras.h"
 
 namespace mlir {
 namespace amdgpu {
@@ -220,8 +218,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   return success();
 }
 
-void mlir::amdgpu::optimizeSharedMemoryReadsAndWritesOp(
-    ::mlir::func::FuncOp funcOp) {
+void amdgpu::optimizeSharedMemoryReadsAndWritesOp(
+    func::FuncOp funcOp) {
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))

>From 889f1a14c3d4fee35ca9076b3110753ea07f90ed Mon Sep 17 00:00:00 2001
From: erman-gurses <erman at nod-labs.com>
Date: Tue, 13 Feb 2024 11:10:58 -0600
Subject: [PATCH 5/5] Fix the format

---
 .../Transforms/OptimizeSharedMemory.cpp       | 35 +++++++++----------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 66beaa3c58665f..7c50a876e78f45 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -50,12 +50,12 @@ constexpr int64_t kDefaultVectorSizeBits = 64;
 static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                  ArrayRef<Value> indices, MemRefType memrefTy,
                                  int64_t srcDim, int64_t tgtDim) {
-  // Adjust the src index to change how often the permutation changes
-  // if necessary.
+  /// Adjust the src index to change how often the permutation changes
+  /// if necessary.
   Value src = indices[srcDim];
 
-  // We only want to permute every N iterations of the target dim where N is
-  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  /// We only want to permute every N iterations of the target dim where N is
+  /// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
   const int64_t permuteEveryN = std::max<int64_t>(
       1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                         memrefTy.getElementTypeBitWidth()) /
@@ -81,8 +81,8 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
   Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
   srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
 
-  // Use the src bits to permute the target bits b[N:M] containing the
-  // vector offset.
+  /// Use the src bits to permute the target bits b[N:M] containing the
+  /// vector offset.
   if (permuteEveryN > 1) {
     int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
     if (shlBits > 0) {
@@ -131,8 +131,8 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
       writeOps.push_back(op);
   });
 
-  // Restrict to a supported set of ops. We also require at least 2D access,
-  // although this could be relaxed.
+  /// Restrict to a supported set of ops. We also require at least 2D access,
+  /// although this could be relaxed.
   if (llvm::any_of(readOps, [](Operation *op) {
         return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
                    op) ||
@@ -157,15 +157,15 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
     return failure();
 
-  // Abort if the given value has any sub-views; we do not do any alias
-  // analysis.
+  /// Abort if the given value has any sub-views; we do not do any alias
+  /// analysis.
   bool hasSubView = false;
   parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
   if (hasSubView)
     return failure();
 
-  // Check if this is necessary given the assumption of 128b accesses:
-  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  /// Check if this is necessary given the assumption of 128b accesses:
+  /// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
   const int64_t rowsPerLine =
       (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
@@ -175,8 +175,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   if (rowsPerLine >= threadGroupSize)
     return failure();
 
-  // Get sets of operations within the function that read/write to shared
-  // memory.
+  /// Get sets of operations within the function that read/write to shared
+  /// memory.
   SmallVector<Operation *, 16> shmReadOps;
   SmallVector<Operation *, 16> shmWriteOps;
   if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
@@ -191,7 +191,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   int64_t tgtDim = memRefType.getRank() - 1;
   int64_t srcDim = memRefType.getRank() - 2;
 
-  // Transform indices for the ops writing to shared memory.
+  /// Transform indices for the ops writing to shared memory.
   while (!shmWriteOps.empty()) {
     Operation *shmWriteOp = shmWriteOps.pop_back_val();
     builder.setInsertionPoint(shmWriteOp);
@@ -203,7 +203,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     amdgpu::setIndices(shmWriteOp, transformedIndices);
   }
 
-  // Transform indices for the ops reading from shared memory.
+  /// Transform indices for the ops reading from shared memory.
   while (!shmReadOps.empty()) {
     Operation *shmReadOp = shmReadOps.pop_back_val();
     builder.setInsertionPoint(shmReadOp);
@@ -218,8 +218,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   return success();
 }
 
-void amdgpu::optimizeSharedMemoryReadsAndWritesOp(
-    func::FuncOp funcOp) {
+void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))



More information about the Mlir-commits mailing list