[Mlir-commits] [mlir] Implement gpu.subgroup_reduce with DPP intrinsics on AMD GPUs (PR #133204)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 26 21:17:10 PDT 2025


https://github.com/Muzammiluddin-Syed-ECE created https://github.com/llvm/llvm-project/pull/133204

[DRAFT]
See related [Issue](https://github.com/iree-org/iree/issues/20007)
We can better leverage DPP ops in the AMDGPU dialect when lowering subgroup reduce ops. 

To this end, this PR implements a new pass that performs this lowering.

To do:
- Improve the lowering to `gpu.subgroup_reduce` for compatible matvecs (these get lowered directly to `gpu.shuffle` ops in an earlier pass)
- Add a test for the new pass


>From a7b81b4dcbf927442046568304bc835205aceefa Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Tue, 25 Mar 2025 14:04:06 -0500
Subject: [PATCH 1/4] temp

---
 .../GPU/Transforms/SubgroupReduceLowering.cpp | 162 +++++++++++++++++-
 1 file changed, 161 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 43eff3eddcc49..0b553274eceb4 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -11,10 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
@@ -24,6 +26,8 @@
 #include <cassert>
 #include <cstdint>
 
+#define DPP
+
 using namespace mlir;
 
 namespace {
@@ -188,6 +192,8 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                      function_ref<Value(Value)> unpackFn) {
   // Lane value always stays in the original type. We use it to perform arith
   // reductions.
+  llvm::errs() << "Cluster Stride: " << ci.clusterStride << "\n";
+  llvm::errs() << "Cluster Size: " << ci.clusterSize << "\n";
   Value laneVal = input;
   // Parallel reduction using butterfly shuffles.
   for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
@@ -206,6 +212,146 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
   return laneVal;
 }
 
+#ifdef DPP
+/// Lowers a subgroup reduction of `input` to a sequence of `amdgpu.dpp`
+/// permutations, each followed by an arith reduction.
+///
+/// A step is emitted for every cluster-size threshold that `ci.clusterSize`
+/// reaches: `row_shr` by 1 (>= 2), `row_shr` by 2 (>= 4), `row_half_mirror`
+/// (>= 8), `row_mirror` (>= 16), `row_bcast_15` (>= 32) and `row_bcast_31`
+/// (exactly 64). Every step combines the permuted value with the running
+/// value via `vector::makeArithReduction`, using the arith kind obtained from
+/// `gpu::convertReductionKind(mode)`.
+///
+/// Returns the running value after the last emitted step; it is asserted to
+/// have the same type as `input`.
+///
+/// NOTE(review): `ci.clusterStride`, `packFn` and `unpackFn` are not consulted
+/// here, so the DPP ops are created directly on values of `input`'s type.
+/// Confirm whether packing and strided clusters need to be handled here.
+Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
+                                 gpu::AllReduceOperation mode,
+                                 const ClusterInfo &ci,
+                                 function_ref<Value(Value)> packFn,
+                                 function_ref<Value(Value)> unpackFn) {
+  Value result = input;
+  const auto kind = gpu::convertReductionKind(mode);
+
+  // Combines a DPP-permuted copy of the running value into the running value.
+  auto foldIn = [&](Value permuted) {
+    result = vector::makeArithReduction(b, loc, kind, result, permuted);
+  };
+
+  if (ci.clusterSize >= 2) {
+    auto shiftByOne = b.getIntegerAttr(b.getIntegerType(32), 1);
+    Value permuted =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_shr, shiftByOne);
+    foldIn(permuted);
+  }
+
+  if (ci.clusterSize >= 4) {
+    auto shiftByTwo = b.getIntegerAttr(b.getIntegerType(32), 2);
+    Value permuted =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_shr, shiftByTwo);
+    foldIn(permuted);
+  }
+
+  if (ci.clusterSize >= 8) {
+    Value permuted = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result,
+        amdgpu::DPPPerm::row_half_mirror, b.getUnitAttr());
+    foldIn(permuted);
+  }
+
+  if (ci.clusterSize >= 16) {
+    Value permuted = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_mirror,
+        b.getUnitAttr());
+    foldIn(permuted);
+  }
+
+  if (ci.clusterSize >= 32) {
+    // Row mask 0xa, bank mask 0xf, bound control disabled.
+    Value permuted = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
+        b.getUnitAttr(), /*rowMask=*/0xa, /*bankMask=*/0xf,
+        /*boundCtrl=*/false);
+    foldIn(permuted);
+  }
+
+  if (ci.clusterSize == 64) {
+    // Row mask 0xc, bank mask 0xf, bound control disabled.
+    Value permuted = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_31,
+        b.getUnitAttr(), /*rowMask=*/0xc, /*bankMask=*/0xf,
+        /*boundCtrl=*/false);
+    foldIn(permuted);
+  }
+
+  // NOTE(review): a final `ROCDL::ReadLaneOp` of lane 63 was sketched here in
+  // the draft but left disabled; confirm whether the result must be broadcast.
+  assert(result.getType() == input.getType());
+  return result;
+}
+#endif
+
+// Value createSubgroupDPPReduction(OpBuilder &b, Location loc,
+//   Value input, gpu::AllReduceOperation mode,
+//   const ClusterInfo &ci,
+//   function_ref<Value(Value)> packFn,
+//   function_ref<Value(Value)> unpackFn) {
+
+//   Value result = input;
+//   if (ci.clusterSize >= 2) {
+//     auto permArg = b.getInt32(1);
+//     Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_shr, permArg);
+//     result = vector::makeArithReduction(builder, loc,
+//       gpu::convertReductionKind(mode),
+//       result, unpackFn(dppResult));
+//   }
+
+//   if (ci.clusterSize >= 4) {
+//     auto permArg = builder.getInt32(2);
+//     Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_shr, permArg);
+//     result = vector::makeArithReduction(builder, loc,
+//       gpu::convertReductionKind(mode),
+//       result, unpackFn(dppResult));
+//   }
+
+//   if (ci.clusterSize >= 8) {
+//     Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_half_mirror);
+//     result = vector::makeArithReduction(builder, loc,
+//       gpu::convertReductionKind(mode),
+//       result, unpackFn(dppResult));
+//   }
+
+//   if (ci.clusterSize >= 16) {
+//     Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_mirror);
+//     result = vector::makeArithReduction(builder, loc,
+//       gpu::convertReductionKind(mode),
+//       result, unpackFn(dppResult));
+//   }
+
+//   if (ci.clusterSize >= 32) {
+//     auto permArg = builder.getInt32(15);
+//     auto rowMask = builder.getInt32("0xa");
+//     auto bankMask = builder.getInt32("0xf");
+//     auto boundCtrl = builder.getBoolAttr(false);
+//     Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_bcast, permArg, rowMask, bankMask, boundCtrl);
+//     result = vector::makeArithReduction(builder, loc,
+//       gpu::convertReductionKind(mode),
+//       result, unpackFn(dppResult));
+//   }
+
+//   if (ci.clusterSize == 64) {
+//     auto permArg = builder.getInt32(31);
+//     auto rowMask = builder.getInt32("0xc");
+//     auto bankMask = builder.getInt32("0xf");
+//     auto boundCtrl = builder.getBoolAttr(false);
+//     Value dppResult = builder.create<amdgpu::DPPOp>(packFn(result), packFn(result), amdgpu::DPPPerm::row_bcast, permArg, rowMask, bankMask, boundCtrl);
+//     result = vector::makeArithReduction(builder, loc,
+//       gpu::convertReductionKind(mode),
+//       result, unpackFn(dppResult));
+//   }
+
+//   assert(result.getType() == input.getType());
+//   return result;
+// }
+
 /// Lowers scalar gpu subgroup reductions to a series of shuffles.
 struct ScalarSubgroupReduceToShuffles final
     : OpRewritePattern<gpu::SubgroupReduceOp> {
@@ -217,6 +363,7 @@ struct ScalarSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    llvm::errs() << "ScalarSubgroupReduceToShuffles" << "\n";
     if (op.getClusterSize().has_value() != matchClustered) {
       return rewriter.notifyMatchFailure(
           op, llvm::formatv("op is {0}clustered but pattern is configured to "
@@ -239,10 +386,17 @@ struct ScalarSubgroupReduceToShuffles final
     Location loc = op.getLoc();
     // Since this is already a native shuffle scalar, no packing is necessary.
     if (elemBitwidth == shuffleBitwidth) {
+      llvm::errs() << "ScalarSubgroupReduceToShuffles - 1" << "\n";
       auto identityFn = [](Value v) { return v; };
+#ifndef DPP
       rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                  rewriter, loc, op.getValue(), op.getOp(), *ci,
                                  identityFn, identityFn));
+#else
+      rewriter.replaceOp(op, createSubgroupDPPReduction(
+                                  rewriter, loc, op.getValue(), op.getOp(), *ci,
+                                  identityFn, identityFn));
+#endif
       return success();
     }
 
@@ -260,10 +414,15 @@ struct ScalarSubgroupReduceToShuffles final
           rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
       return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
     };
-
+    llvm::errs() << "ScalarSubgroupReduceToShuffles - 2" << "\n";
+#ifndef DPP
     rewriter.replaceOp(
         op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                            op.getOp(), *ci, packFn, unpackFn));
+#else
+    rewriter.replaceOp(op, createSubgroupDPPReduction(rewriter, loc, op.getValue(),
+    op.getOp(), *ci, packFn, unpackFn));
+#endif
     return success();
   }
 
@@ -284,6 +443,7 @@ struct VectorSubgroupReduceToShuffles final
 
   LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                 PatternRewriter &rewriter) const override {
+    llvm::errs() << "VectorSubgroupReduceToShuffles" << "\n";
     if (op.getClusterSize().has_value() != matchClustered) {
       return rewriter.notifyMatchFailure(
           op, llvm::formatv("op is {0}clustered but pattern is configured to "

>From a578aeb2baef34bbf9994619c972c9d4e0ec4a6d Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Tue, 25 Mar 2025 14:18:31 -0500
Subject: [PATCH 2/4] subgroup reduce change

---
 mlir/include/mlir/Conversion/Passes.h             |  1 +
 mlir/include/mlir/Conversion/Passes.td            | 15 +++++++++++++++
 .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp    |  1 +
 mlir/lib/Conversion/CMakeLists.txt                |  1 +
 4 files changed, 18 insertions(+)

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ccd862f67c068..1189423799092 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -34,6 +34,7 @@
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h"
 #include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..502ca5a84ee7b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -643,6 +643,21 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// GPUToAMDGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertGPUToAMDGPUPass : Pass<"convert-gpu-to-amdgpu"> {
+  let summary = "Generate AMDGPU operations for gpu operations";
+  let dependentDialects = [
+    "::mlir::gpu::GPUDialect",
+    "::mlir::amdgpu::AMDGPUDialect"
+  ];
+  // NOTE(review): the chipset option below is commented out, so the pass
+  // currently takes no options.
+  // let options = [Option<"chipset", "chipset", "std::string",
+  //                       /*default=*/"\"gfx000\"",
+  //                       "Chipset that these operations will run on">];
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertIndexToLLVMPass
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..5296f75571188 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1214,6 +1214,7 @@ struct ConvertAMDGPUToROCDLPass
   using Base::Base;
 
   void runOnOperation() override {
+    llvm::errs() << " WHEN DOES AMDGPU TO ROCDL RUN\n";
     MLIRContext *ctx = &getContext();
     FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
     if (failed(maybeChipset)) {
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index b6c21440c571c..b957a4473f1e6 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@ add_subdirectory(FuncToEmitC)
 add_subdirectory(FuncToLLVM)
 add_subdirectory(FuncToSPIRV)
 add_subdirectory(GPUCommon)
+add_subdirectory(GPUToAMDGPU)
 add_subdirectory(GPUToLLVMSPV)
 add_subdirectory(GPUToNVVM)
 add_subdirectory(GPUToROCDL)

>From 72a3b968600199c4cfd61f957df0ac8a7b38c241 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 26 Mar 2025 12:29:06 -0500
Subject: [PATCH 3/4] adding new dialect

---
 .../mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h |  26 +
 .../lib/Conversion/GPUToAMDGPU/CMakeLists.txt |  21 +
 .../Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp    |  44 ++
 .../Conversion/GPUToAMDGPU/gpu-to-amdgpu.mlir | 463 ++++++++++++++++++
 4 files changed, 554 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h
 create mode 100644 mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp
 create mode 100644 mlir/test/Conversion/GPUToAMDGPU/gpu-to-amdgpu.mlir

diff --git a/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h b/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h
new file mode 100644
index 0000000000000..77723c82cce5e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h
@@ -0,0 +1,26 @@
+//===- GPUToAMDGPU.h - Convert GPU to AMDGPU dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPU_H_
+#define MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPU_H_
+
+#include <memory>
+
+namespace mlir {
+
+// Forward declarations of the types referenced by the declarations below.
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+/// Adds the GPU-to-AMDGPU conversion patterns to `patterns`.
+/// NOTE(review): the definition in GPUToAMDGPU.cpp has an empty body in this
+/// patch, so no patterns are registered yet.
+void populateGPUToAMDGPUConversionPatterns(LLVMTypeConverter &converter,
+                                            RewritePatternSet &patterns);
+
+/// Creates an instance of the GPU-to-AMDGPU conversion pass.
+std::unique_ptr<Pass> createConvertGPUToAMDGPUPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPU_H_
\ No newline at end of file
diff --git a/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..2ca7347a3e097
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Conversion library MLIRGPUToAMDGPU, built from GPUToAMDGPU.cpp.
+add_mlir_conversion_library(MLIRGPUToAMDGPU
+  GPUToAMDGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/GPUToAMDGPU
+  
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRGPUDialect
+  MLIRAMDGPUDialect
+  MLIRROCDLDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp b/mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp
new file mode 100644
index 0000000000000..69c2edf580c12
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp
@@ -0,0 +1,44 @@
+//===- GPUToAMDGPU.cpp - GPU to AMDGPU dialect conversion -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h"
+#include "../PassDetail.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass that applies the GPU-to-AMDGPU conversion patterns as a partial
+/// conversion, with the LLVM, AMDGPU and ROCDL dialects marked legal.
+///
+/// NOTE(review): Passes.td declares this pass as `ConvertGPUToAMDGPUPass`;
+/// confirm that the generated base class is really named
+/// `ConvertGPUToAMDGPUBase`.
+struct ConvertGPUToAMDGPUPass
+    : public ConvertGPUToAMDGPUBase<ConvertGPUToAMDGPUPass> {
+  ConvertGPUToAMDGPUPass() = default;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    LLVMTypeConverter converter(&getContext());
+    populateGPUToAMDGPUConversionPatterns(converter, patterns);
+
+    // Ops of these dialects are legal results of the conversion. The AMDGPU
+    // dialect lives in the lower-case `mlir::amdgpu` namespace.
+    LLVMConversionTarget target(getContext());
+    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    target.addLegalDialect<::mlir::amdgpu::AMDGPUDialect>();
+    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Registers the GPU-to-AMDGPU conversion patterns into `patterns`.
+/// Currently a placeholder: the body is empty, so neither `converter` nor
+/// `patterns` is touched.
+void mlir::populateGPUToAMDGPUConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+}
+
+/// Returns a newly allocated `ConvertGPUToAMDGPUPass`.
+std::unique_ptr<Pass> mlir::createConvertGPUToAMDGPUPass() {
+  return std::make_unique<ConvertGPUToAMDGPUPass>();
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/GPUToAMDGPU/gpu-to-amdgpu.mlir b/mlir/test/Conversion/GPUToAMDGPU/gpu-to-amdgpu.mlir
new file mode 100644
index 0000000000000..8871b2ce0eadb
--- /dev/null
+++ b/mlir/test/Conversion/GPUToAMDGPU/gpu-to-amdgpu.mlir
@@ -0,0 +1,463 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx908 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX908
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx90a | FileCheck %s --check-prefixes=CHECK,GFX9,GFX90A
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9,GFX942
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10,RDNA
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11,RDNA
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1201 | FileCheck %s --check-prefixes=CHECK,GFX12,RDNA
+
+// Note: #gpu.address_space<global> is hardcoded to `1` here because the
+// test pass doesn't set up the GPU address space conversions.
+
+#gpu_global_addrspace = 1
+
+// CHECK-LABEL: func @fat_raw_buffer_cast
+func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<8xi32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1]
+  // CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
+  // CHECK-DAG: %[[sizes:.*]] = llvm.extractvalue %[[desc]][3]
+  // CHECK-DAG: %[[strides:.*]] = llvm.extractvalue %[[desc]][4]
+  // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]] : <1> to <7>
+  // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
+  // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
+  // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
+  // CHECK: %[[ret4:.*]] = llvm.insertvalue %[[sizes]], %[[ret3]][3]
+  // CHECK: %[[ret5:.*]] = llvm.insertvalue %[[strides]], %[[ret4]][4]
+  // CHECK: builtin.unrealized_conversion_cast %[[ret5]]
+  %ret = amdgpu.fat_raw_buffer_cast %buf : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_0d
+func.func @fat_raw_buffer_cast_0d(%buf: memref<i32, #gpu_global_addrspace>) -> memref<i32, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<i32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64)>
+  // CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1]
+  // CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
+  // CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[base]], %[[strideArg]], %[[numRecords]], %[[flags]]
+  // CHECK: %[[ret0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<7>, ptr<7>, i64)>
+  // CHECK: %[[ret1:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret0]][0]
+  // CHECK: %[[ret2:.*]] = llvm.insertvalue %[[fatBuf]], %[[ret1]][1]
+  // CHECK: %[[ret3:.*]] = llvm.insertvalue %[[offset]], %[[ret2]][2]
+  // CHECK: builtin.unrealized_conversion_cast %[[ret3]]
+  %ret = amdgpu.fat_raw_buffer_cast %buf : memref<i32, #gpu_global_addrspace> to memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dyn_size_offset
+func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[size0:.*]] = llvm.extractvalue %{{.*}}[3, 0]
+  // CHECK: %[[stride0:.*]] = llvm.extractvalue %{{.*}}[4, 0]
+  // CHECK: %[[maxVals:.*]] = llvm.mul %[[size0]], %[[stride0]]
+  // CHECK: %[[maxValsI32:.*]] = llvm.trunc %[[maxVals]] : i64 to i32
+  // CHECK: %[[byteSize:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[numRecords:.*]] = llvm.mul %[[maxValsI32]], %[[byteSize]]
+  // CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2]
+  // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+  // CHECK: llvm.insertvalue %[[offset]], %{{.*}}[2]
+  %ret = amdgpu.fat_raw_buffer_cast %buf : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<?xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset
+func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<?xi32, strided<[1], offset: ?>, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1]
+  // CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2]
+  // CHECK-DAG: %[[basePtr:.*]] = llvm.getelementptr %[[memRefPtr]][%[[memRefOff]]]
+  // CHECK-DAG: %[[zeroOff:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[fatBuf:.*]] = rocdl.make.buffer.rsrc %[[basePtr]], %{{.*}}, %{{.*}}, %{{.*}}
+  // CHECK: llvm.insertvalue %[[fatBuf]], %{{.*}}[1]
+  // CHECK: llvm.insertvalue %[[zeroOff]], %{{.*}}[2]
+  %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes
+func.func @fat_raw_buffer_cast_valid_bytes(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  // CHECK: %[[numRecords:.*]] = arith.constant -1 : i32
+  // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+  %cu32_max = arith.constant 0xffffffff : i32
+  %ret = amdgpu.fat_raw_buffer_cast %buf validBytes(%cu32_max) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_bounds_check
+func.func @fat_raw_buffer_cast_bounds_check(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(553807872 : i32)
+  // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %{{.*}}, %[[flags]]
+  %ret = amdgpu.fat_raw_buffer_cast %buf boundsCheck(false) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_cache_swizzle
+// CHECK-SAME: (%{{.*}}: memref<64x64xi32, 1>, %[[stride:.*]]: i14)
+func.func @fat_raw_buffer_cast_cache_swizzle(%buf: memref<64x64xi32, #gpu_global_addrspace>, %stride: i14) -> memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  // GFX908: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // GFX90A: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // RDNA: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // GFX942: %[[asI16:.*]] = llvm.zext %[[stride]] : i14 to i16
+  // GFX942: %[[cacheSwizzleOn:.*]] = llvm.mlir.constant(16384 : i16) : i16
+  // GFX942: %[[stride:.*]] = llvm.or disjoint %[[asI16]], %[[cacheSwizzleOn]]
+  // CHECK: rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %{{.*}}, %{{.*}}
+  %ret = amdgpu.fat_raw_buffer_cast %buf cacheSwizzleStride(%stride) : memref<64x64xi32, #gpu_global_addrspace> to memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<64x64xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_scalar_i32
+func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref<i32>) -> i32 {
+  // Extra constant for byte width
+  // CHECK: llvm.mlir.constant(4 : i32)
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
+  // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[] : memref<i32> -> i32
+  func.return %0 : i32
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32
+func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+  // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[stride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
+  // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32>, i32 -> i32
+  func.return %0 : i32
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_strided
+func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[?, ?], offset: ?>>, %i: i32, %j: i32) -> i32 {
+    // CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<16x16xi32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32) : i32
+    // CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %[[descriptor]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[offset:.*]] = llvm.extractvalue %[[descriptor]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+    // CHECK: %[[sz_i:.*]] = llvm.extractvalue %[[descriptor]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[stride_i:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[ext_i:.*]] = llvm.mul %[[sz_i]], %[[stride_i]] : i64
+    // CHECK: %[[sz_j:.*]] = llvm.extractvalue %[[descriptor]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[stride_j:.*]] = llvm.extractvalue %[[descriptor]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[ext_j:.*]] = llvm.mul %[[sz_j]], %[[stride_j]] : i64
+    // CHECK: %[[num_records:.*]] = llvm.intr.umax(%[[ext_i]], %[[ext_j]]) : (i64, i64) -> i64
+    // CHECK: %[[num_rec_i32:.*]] = llvm.trunc %[[num_records]] : i64 to i32
+    // CHECK: %[[elem_size_2:.*]] = llvm.mlir.constant(4 : i32) : i32
+    // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size_2]] : i32
+    // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+    // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %[[stride]], %[[num_rec_bytes_i32]], %{{.*}} : !llvm.ptr to <8>
+    // CHECK: %[[stride_i_1:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[stride_i_i32:.*]] = llvm.trunc %[[stride_i_1]] : i64 to i32
+    // CHECK: %[[t_0:.*]] = llvm.mul %{{.*}}, %[[stride_i_i32]] : i32
+    // CHECK: %[[stride_j_1:.*]] = llvm.extractvalue %[[descriptor]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[stride_j_i32:.*]] = llvm.trunc %[[stride_j_1]] : i64 to i32
+    // CHECK: %[[t_1:.*]] = llvm.mul %{{.*}}, %[[stride_j_i32]] : i32
+    // CHECK: %[[index:.*]] = llvm.add %[[t_0]], %[[t_1]] : i32
+    // CHECK: %[[vgpr_off:.*]] = llvm.mul %[[index]], %[[elem_size]] : i32
+    // CHECK: %[[zero_0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[sgpr_off:.*]] = llvm.mul %[[zero_0]], %[[elem_size]] : i32
+    // CHECK: %[[zero_1:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[v:.*]] = rocdl.raw.ptr.buffer.load %[[rsrc]], %[[vgpr_off]], %[[sgpr_off]], %[[zero_1]] : i32
+    // CHECK: return %[[v]] : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%i, %j] :  memref<16x16xi32, strided<[?, ?], offset: ?>>, i32, i32 -> i32
+  func.return %0 : i32
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_oob_off
+func.func @gpu_gcn_raw_buffer_load_i32_oob_off(%buf: memref<64xi32>, %idx: i32) -> i32 {
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(553807872 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %{{.*}}, %[[flags]]
+  // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = false} %buf[%idx] : memref<64xi32>, i32 -> i32
+  func.return %0 : i32
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_1xi32
+func.func @gpu_gcn_raw_buffer_load_1xi32(%buf: memref<64xi32>, %idx: i32) -> vector<1xi32> {
+  // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
+  // CHECK: %[[cast:.*]] = llvm.bitcast %[[ret]] : i32 to vector<1xi32>
+  // CHECK: return %[[cast]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32>, i32 -> vector<1xi32>
+  func.return %0 : vector<1xi32>
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_2xi32
+func.func @gpu_gcn_raw_buffer_load_2xi32(%buf: memref<64xi32>, %idx: i32) -> vector<2xi32> {
+  // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi32>
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32>, i32 -> vector<2xi32>
+  func.return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i8
+func.func @gpu_gcn_raw_buffer_load_i8(%buf: memref<64xi8>, %idx: i32) -> i8 {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+  // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i8
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi8>, i32 -> i8
+  func.return %0 : i8
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_2xi8
+func.func @gpu_gcn_raw_buffer_load_2xi8(%buf: memref<64xi8>, %idx: i32) -> vector<2xi8> {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+  // CHECK: %[[loaded:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i16
+  // CHECK: %[[ret:.*]] = llvm.bitcast %[[loaded]] : i16 to vector<2xi8>
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi8>, i32 -> vector<2xi8>
+  func.return %0 : vector<2xi8>
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_16xi8
+func.func @gpu_gcn_raw_buffer_load_16xi8(%buf: memref<64xi8>, %idx: i32) -> vector<16xi8> {
+  // CHECK: %[[loaded:.*]] = rocdl.raw.ptr.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi32>
+  // CHECK: %[[ret:.*]] = llvm.bitcast %[[loaded]] : vector<4xi32> to vector<16xi8>
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi8>, i32 -> vector<16xi8>
+  func.return %0 : vector<16xi8>
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ
+func.func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ(%buf: memref<64xf8E5M2FNUZ>, %idx: i32) -> f8E5M2FNUZ {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+  // CHECK: %[[loaded:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i8
+  // CHECK: %[[ret:.*]] = builtin.unrealized_conversion_cast %[[loaded]] : i8 to f8E5M2FNUZ
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xf8E5M2FNUZ>, i32 -> f8E5M2FNUZ
+  func.return %0 : f8E5M2FNUZ
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ
+func.func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ(%buf: memref<64xf8E4M3FNUZ>, %idx: i32) -> vector<4xf8E4M3FNUZ> {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
+  // CHECK: %[[loaded:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+  // CHECK: %[[cast:.*]] = llvm.bitcast %[[loaded]] : i32 to vector<4xi8>
+  // CHECK: %[[ret:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xf8E4M3FNUZ>, i32 -> vector<4xf8E4M3FNUZ>
+  func.return %0 : vector<4xf8E4M3FNUZ>
+}
+
+// Since the lowering logic is shared with loads, only bitcasts need to be rechecked
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_scalar_i32
+func.func @gpu_gcn_raw_buffer_store_scalar_i32(%value: i32, %buf: memref<i32>) {
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %{{.*}}, %[[flags]]
+  // CHECK: rocdl.raw.ptr.buffer.store %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+  amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[] : i32 -> memref<i32>
+  func.return
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_i32
+func.func @gpu_gcn_raw_buffer_store_i32(%value: i32, %buf: memref<64xi32>, %idx: i32) {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: rocdl.raw.ptr.buffer.store %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+  amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[%idx] : i32 -> memref<64xi32>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_1xf32
+func.func @gpu_gcn_raw_buffer_store_1xf32(%value: vector<1xf32>, %buf: memref<64xf32>, %idx: i32) {
+  // CHECK: %[[cast:.*]] = llvm.bitcast %{{.*}} : vector<1xf32> to f32
+  // CHECK: rocdl.raw.ptr.buffer.store %[[cast]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
+  amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[%idx] : vector<1xf32> -> memref<64xf32>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_2xi8
+func.func @gpu_gcn_raw_buffer_store_2xi8(%value: vector<2xi8>, %buf: memref<64xi8>, %idx: i32) {
+  // CHECK: %[[cast:.*]] = llvm.bitcast %{{.*}} : vector<2xi8> to i16
+  // CHECK: rocdl.raw.ptr.buffer.store %[[cast]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i16
+  amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[%idx] : vector<2xi8> -> memref<64xi8>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_16xi8
+func.func @gpu_gcn_raw_buffer_store_16xi8(%value: vector<16xi8>, %buf: memref<64xi8>, %idx: i32) {
+  // CHECK: %[[cast:.*]] = llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: rocdl.raw.ptr.buffer.store %[[cast]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi32>
+  amdgpu.raw_buffer_store {boundsCheck = true} %value -> %buf[%idx] : vector<16xi8> -> memref<64xi8>, i32
+  func.return
+}
+
+// And more so for atomic add
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_f32
+func.func @gpu_gcn_raw_buffer_atomic_fadd_f32(%value: f32, %buf: memref<64xf32>, %idx: i32) {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: rocdl.raw.ptr.buffer.atomic.fadd %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : f32
+  amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %value -> %buf[%idx] : f32 -> memref<64xf32>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_v2f16
+func.func @gpu_gcn_raw_buffer_atomic_fadd_v2f16(%value: vector<2xf16>, %buf: memref<64xf16>, %idx: i32) {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: rocdl.raw.ptr.buffer.atomic.fadd %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : vector<2xf16>
+  amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %value -> %buf[%idx] : vector<2xf16> -> memref<64xf16>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_v2bf16
+func.func @gpu_gcn_raw_buffer_atomic_fadd_v2bf16(%value: vector<2xbf16>, %buf: memref<64xbf16>, %idx: i32) {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: rocdl.raw.ptr.buffer.atomic.fadd %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : vector<2xbf16>
+  amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %value -> %buf[%idx] : vector<2xbf16> -> memref<64xbf16>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fmax_f32
+func.func @gpu_gcn_raw_buffer_atomic_fmax_f32(%value: f32, %buf: memref<64xf32>, %idx: i32) {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: rocdl.raw.ptr.buffer.atomic.fmax %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : f32
+  amdgpu.raw_buffer_atomic_fmax {boundsCheck = true} %value -> %buf[%idx] : f32 -> memref<64xf32>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_smax_i32
+func.func @gpu_gcn_raw_buffer_atomic_smax_i32(%value: i32, %buf: memref<64xi32>, %idx: i32) {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: rocdl.raw.ptr.buffer.atomic.smax %{{.*}} %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+  amdgpu.raw_buffer_atomic_smax {boundsCheck = true} %value -> %buf[%idx] : i32 -> memref<64xi32>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_umin_i32
+func.func @gpu_gcn_raw_buffer_atomic_umin_i32(%value: i32, %buf: memref<64xi32>, %idx: i32) {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: rocdl.raw.ptr.buffer.atomic.umin %{{.*}} %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+  amdgpu.raw_buffer_atomic_umin {boundsCheck = true} %value -> %buf[%idx] : i32 -> memref<64xi32>, i32
+  func.return
+}
+
+// CHECK-LABEL: func @amdgpu_raw_buffer_atomic_cmpswap_f32
+// CHECK-SAME: (%[[src:.*]]: f32, %[[cmp:.*]]: f32, {{.*}})
+func.func @amdgpu_raw_buffer_atomic_cmpswap_f32(%src : f32, %cmp : f32, %buf : memref<64xf32>, %idx: i32) -> f32 {
+  // CHECK: %[[srcCast:.*]] = llvm.bitcast %[[src]] : f32 to i32
+  // CHECK: %[[cmpCast:.*]] = llvm.bitcast %[[cmp]] : f32 to i32
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: %[[dst:.*]] = rocdl.raw.ptr.buffer.atomic.cmpswap %[[srcCast]], %[[cmpCast]], %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
+  // CHECK: %[[dstCast:.*]] = llvm.bitcast %[[dst]] : i32 to f32
+  // CHECK: return %[[dstCast]]
+  %dst = amdgpu.raw_buffer_atomic_cmpswap {boundsCheck = true} %src, %cmp -> %buf[%idx] : f32 -> memref<64xf32>, i32
+  func.return %dst : f32
+}
+
+// CHECK-LABEL: func @amdgpu_raw_buffer_atomic_cmpswap_i64
+// CHECK-SAME: (%[[src:.*]]: i64, %[[cmp:.*]]: i64, {{.*}})
+func.func @amdgpu_raw_buffer_atomic_cmpswap_i64(%src : i64, %cmp : i64, %buf : memref<64xi64>, %idx: i32) -> i64 {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(512 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: %[[dst:.*]] = rocdl.raw.ptr.buffer.atomic.cmpswap %[[src]], %[[cmp]], %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i64
+  // CHECK: return %[[dst]]
+  %dst = amdgpu.raw_buffer_atomic_cmpswap {boundsCheck = true} %src, %cmp -> %buf[%idx] : i64 -> memref<64xi64>, i32
+  func.return %dst : i64
+}
+
+// CHECK-LABEL: func @amdgpu_raw_buffer_atomic_cmpswap_v2f16
+// CHECK-SAME: (%[[src:.*]]: vector<2xf16>, %[[cmp:.*]]: vector<2xf16>, {{.*}})
+func.func @amdgpu_raw_buffer_atomic_cmpswap_v2f16(%src : vector<2xf16>, %cmp : vector<2xf16>, %buf : memref<64xf16>, %idx: i32) -> vector<2xf16> {
+  // CHECK-DAG: %[[srcBits:.+]] = llvm.bitcast %[[src]] : vector<2xf16> to i32
+  // CHECK-DAG: %[[cmpBits:.+]] = llvm.bitcast %[[cmp]] : vector<2xf16> to i32
+  // CHECK: %[[dstBits:.+]] = rocdl.raw.ptr.buffer.atomic.cmpswap %[[srcBits]], %[[cmpBits]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
+  // CHECK: %[[dst:.+]] = llvm.bitcast %[[dstBits]] : i32 to vector<2xf16>
+  // CHECK: return %[[dst]]
+  %dst = amdgpu.raw_buffer_atomic_cmpswap {boundsCheck = true} %src, %cmp -> %buf[%idx] : vector<2xf16> -> memref<64xf16>, i32
+  func.return %dst : vector<2xf16>
+}
+
+// CHECK-LABEL: func @lds_barrier
+func.func @lds_barrier() {
+  // GFX908: llvm.inline_asm has_side_effects asm_dialect = att
+  // GFX908-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier"
+  // GFX90A: rocdl.s.waitcnt -7937
+  // GFX90A-NEXT: rocdl.s.barrier
+  // GFX942: rocdl.s.waitcnt -7937
+  // GFX942-NEXT: rocdl.s.barrier
+  // GFX10:  rocdl.s.waitcnt -16129
+  // GFX10-NEXT: rocdl.s.barrier
+  // GFX11:  llvm.inline_asm has_side_effects asm_dialect = att
+  // GFX11-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier"
+  // GFX12:  rocdl.s.wait.dscnt 0
+  // GFX12-NEXT: rocdl.s.barrier.signal -1
+  // GFX12-NEXT: rocdl.s.barrier.wait -1
+  amdgpu.lds_barrier
+  func.return
+}
+
+// Covers every `allow` keyword of amdgpu.sched_barrier. The expected
+// rocdl.sched.barrier immediates are 0 for <none> and a distinct power of two
+// for each single keyword; a combined mask is expected to be the sum of its
+// parts (valu = 2 and all_vmem = 16, hence valu|all_vmem = 18).
+// CHECK-LABEL: func @sched_barrier
+func.func @sched_barrier() {
+  // CHECK: rocdl.sched.barrier 0
+  amdgpu.sched_barrier allow = <none>
+  // CHECK: rocdl.sched.barrier 1
+  amdgpu.sched_barrier allow = <non_mem_non_sideffect>
+  // CHECK: rocdl.sched.barrier 2
+  amdgpu.sched_barrier allow = <valu>
+  // CHECK: rocdl.sched.barrier 4
+  amdgpu.sched_barrier allow = <salu>
+  // CHECK: rocdl.sched.barrier 8
+  amdgpu.sched_barrier allow = <mfma_wmma>
+  // CHECK: rocdl.sched.barrier 16
+  amdgpu.sched_barrier allow = <all_vmem>
+  // CHECK: rocdl.sched.barrier 32
+  amdgpu.sched_barrier allow = <vmem_read>
+  // CHECK: rocdl.sched.barrier 64
+  amdgpu.sched_barrier allow = <vmem_write>
+  // CHECK: rocdl.sched.barrier 128
+  amdgpu.sched_barrier allow = <all_ds>
+  // CHECK: rocdl.sched.barrier 256
+  amdgpu.sched_barrier allow = <ds_read>
+  // CHECK: rocdl.sched.barrier 512
+  amdgpu.sched_barrier allow = <ds_write>
+  // CHECK: rocdl.sched.barrier 1024
+  amdgpu.sched_barrier allow = <transcendental>
+  // CHECK: rocdl.sched.barrier 18
+  amdgpu.sched_barrier allow = <valu|all_vmem>
+  func.return
+}

>From d766a5fe5340df4883af7a16a82616fda04e42b9 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Wed, 26 Mar 2025 23:02:43 -0500
Subject: [PATCH 4/4] making things compile

---
 .../mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h |  15 +-
 mlir/include/mlir/Conversion/Passes.td        |   9 +-
 .../lib/Conversion/GPUToAMDGPU/CMakeLists.txt |   1 +
 .../Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp    | 183 ++++++++++++++++--
 mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt |   1 +
 5 files changed, 190 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h b/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h
index 77723c82cce5e..fea9b7ed50bcc 100644
--- a/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h
@@ -8,18 +8,27 @@
 #ifndef MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPU_H_
 #define MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPU_H_
 
+
+#include "mlir/IR/PatternMatch.h"
 #include <memory>
+#include <string>
 
 namespace mlir {
 
 class LLVMTypeConverter;
 class RewritePatternSet;
+class TypeConverter;
 class Pass;
 
-void populateGPUToAMDGPUConversionPatterns(LLVMTypeConverter &converter,
-                                            RewritePatternSet &patterns);
+#define GEN_PASS_DECL_CONVERTGPUTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
 
-std::unique_ptr<Pass> createConvertGPUToAMDGPUPass();
+/// Adds to `patterns` the rewrite that lowers clustered `gpu.subgroup_reduce`
+/// ops to `amdgpu.dpp` ops combined with arithmetic reductions.
+/// `subgroupSize` is the subgroup width the lowering is configured for and
+/// `benefit` is the benefit given to the added pattern.
+void populateSubgroupReduceLoweringPatterns(LLVMTypeConverter &converter,
+                                            RewritePatternSet &patterns,
+                                            unsigned subgroupSize,
+                                            PatternBenefit benefit);
+// void populateGPUToAMDGPUConversionPatterns(LLVMTypeConverter &converter,
+//                                             RewritePatternSet &patterns);
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 502ca5a84ee7b..6a1deeb230794 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -650,12 +650,13 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
 def ConvertGPUToAMDGPUPass : Pass<"convert-gpu-to-amdgpu"> {
   let summary = "Generate AMDGPU operations for gpu operations";
   let dependentDialects = [
+    "LLVM::LLVMDialect",
     "::mlir::gpu::GPUDialect",
-    "amdgpu::AMDGPUDialect"
+    "amdgpu::AMDGPUDialect",
   ];
-  // let options = [Option<"chipset", "chipset", "std::string",
-  //                       /*default=*/"\"gfx000\"",
-  //                       "Chipset that these operations will run on">];
+  let options = [Option<"subgroupSize", "subgroup-size", "unsigned",
+                        /*default=*/"64",
+                        "Size of subgroup">];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt
index 2ca7347a3e097..9b82b5dc63d9c 100644
--- a/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRGPUToAMDGPU
   MLIRLLVMDialect
   MLIRGPUDialect
   MLIRAMDGPUDialect
+  MLIRAMDGPUUtils
   MLIRROCDLDialect
   MLIRPass
   MLIRTransforms
diff --git a/mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp b/mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp
index 69c2edf580c12..c2fc8b2e19ae6 100644
--- a/mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp
+++ b/mlir/lib/Conversion/GPUToAMDGPU/GPUToAMDGPU.cpp
@@ -7,27 +7,188 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPU.h"
-#include "../PassDetail.h"
+
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+#include <cstdint>
+
+#include "../LLVMCommon/MemRefDescriptor.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTGPUTOAMDGPUPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
 
 using namespace mlir;
 
 namespace {
+/// Cluster configuration of a `gpu.subgroup_reduce` op, resolved against the
+/// subgroup size the lowering was configured with.
+struct ClusterInfo {
+  /// Cluster stride reported by the op.
+  unsigned clusterStride = 0;
+  /// Cluster size of the op, or the subgroup size when the op has none.
+  unsigned clusterSize = 0;
+  /// Subgroup size the lowering was configured with.
+  unsigned subgroupSize = 0;
+};
+
+/// Resolves the cluster size and stride of `op` against `subgroupSize`.
+///
+/// Emits an op error and returns failure when the cluster size exceeds the
+/// subgroup size, or when the cluster stride is not smaller than it. An op
+/// without an explicit cluster size is treated as a single cluster spanning
+/// the whole subgroup.
+static FailureOr<ClusterInfo>
+getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
+  assert(llvm::isPowerOf2_32(subgroupSize));
+
+  // Non-power-of-two cluster sizes are expected to be rejected by the op
+  // verifier before this point.
+  const std::optional<uint32_t> requestedSize = op.getClusterSize();
+  assert(!requestedSize.has_value() || llvm::isPowerOf2_32(*requestedSize));
+  if (requestedSize.has_value() && subgroupSize < *requestedSize) {
+    return op.emitOpError()
+           << "cluster size " << *requestedSize
+           << " is greater than subgroup size " << subgroupSize;
+  }
+
+  // The same holds for the stride, which must also stay below the subgroup
+  // size.
+  const auto stride = op.getClusterStride();
+  assert(llvm::isPowerOf2_32(stride));
+  if (!(stride < subgroupSize)) {
+    return op.emitOpError()
+           << "cluster stride " << stride
+           << " is not less than subgroup size " << subgroupSize;
+  }
+
+  return ClusterInfo{stride, requestedSize.value_or(subgroupSize),
+                     subgroupSize};
+}
+
+/// Builds a reduction of `input` out of `amdgpu.dpp` ops.
+///
+/// For every step that fits in `ci.clusterSize` (2, 4, 8, 16, 32 and 64) one
+/// DPP permutation of the running value is created and folded back into it
+/// with the arithmetic reduction selected by `mode`. The returned value has
+/// the same type as `input`.
+///
+/// `ci.clusterStride` is not consulted here; callers have to make sure the
+/// lowering is only used where that is acceptable.
+Value createSubgroupDPPReduction(OpBuilder &b, Location loc, Value input,
+                                 gpu::AllReduceOperation mode,
+                                 const ClusterInfo &ci) {
+  Value result = input;
+
+  // Folds a permuted copy of the running value back into the running value.
+  auto accumulate = [&](Value dppResult) {
+    result = vector::makeArithReduction(b, loc, gpu::convertReductionKind(mode),
+                                        result, dppResult);
+  };
+
+  if (ci.clusterSize >= 2) {
+    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 1);
+    Value dppResult =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_shr, permArg);
+    accumulate(dppResult);
+  }
+
+  if (ci.clusterSize >= 4) {
+    auto permArg = b.getIntegerAttr(b.getIntegerType(32), 2);
+    Value dppResult =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_shr, permArg);
+    accumulate(dppResult);
+  }
+
+  if (ci.clusterSize >= 8) {
+    Value dppResult = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_half_mirror,
+        b.getUnitAttr());
+    accumulate(dppResult);
+  }
+
+  if (ci.clusterSize >= 16) {
+    Value dppResult =
+        b.create<amdgpu::DPPOp>(loc, result.getType(), result, result,
+                                amdgpu::DPPPerm::row_mirror, b.getUnitAttr());
+    accumulate(dppResult);
+  }
+
+  if (ci.clusterSize >= 32) {
+    // Row mask 0xa, bank mask 0xf, bound control off.
+    Value dppResult = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_15,
+        b.getUnitAttr(), 0xa, 0xf, false);
+    accumulate(dppResult);
+  }
+
+  if (ci.clusterSize == 64) {
+    // Row mask 0xc, bank mask 0xf, bound control off.
+    Value dppResult = b.create<amdgpu::DPPOp>(
+        loc, result.getType(), result, result, amdgpu::DPPPerm::row_bcast_31,
+        b.getUnitAttr(), 0xc, 0xf, false);
+    accumulate(dppResult);
+  }
+
+  // TODO: the draft intended to read the final value back from lane 63 (via
+  // ROCDL::ReadLaneOp) so that every lane sees it; that step is not emitted
+  // yet.
+  assert(result.getType() == input.getType());
+  return result;
+}
+
+/// Rewrites `gpu.subgroup_reduce` into the `amdgpu.dpp` based reduction built
+/// by `createSubgroupDPPReduction`.
+///
+/// `matchClustered` selects whether the pattern applies to ops that carry a
+/// cluster size (true) or to ops that do not (false); ops of the other kind
+/// are reported as a match failure and left untouched.
+struct ScalarSubgroupReduceToShuffles final
+    : OpRewritePattern<gpu::SubgroupReduceOp> {
+  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+                                 bool matchClustered, PatternBenefit benefit)
+      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+        matchClustered(matchClustered) {}
+
+  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getClusterSize().has_value() != matchClustered) {
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv("op is {0}clustered but pattern is configured to "
+                            "only match {1}clustered ops",
+                            matchClustered ? "non-" : "",
+                            matchClustered ? "" : "non-"));
+    }
+
+    // Emits its own diagnostic on failure.
+    auto ci = getAndValidateClusterInfo(op, subgroupSize);
+    if (failed(ci))
+      return failure();
+
+    // createSubgroupDPPReduction never reads the cluster stride, so an op
+    // with a non-unit stride would be lowered as if its lanes were
+    // contiguous. Leave such ops for another lowering instead.
+    if (ci->clusterStride != 1)
+      return rewriter.notifyMatchFailure(
+          op, "DPP lowering only supports clusters of contiguous lanes");
+
+    Location loc = op.getLoc();
+    rewriter.replaceOp(op, createSubgroupDPPReduction(
+                               rewriter, loc, op.getValue(), op.getOp(), *ci));
+    return success();
+  }
+
+private:
+  unsigned subgroupSize = 0;
+  bool matchClustered = false;
+};
+
 struct ConvertGPUToAMDGPUPass
-    : public ConvertGPUToAMDGPUBase<ConvertGPUToAMDGPUPass> {
-  ConvertGPUToAMDGPUPass() = default;
+    : public impl::ConvertGPUToAMDGPUPassBase<ConvertGPUToAMDGPUPass> {
+  using Base::Base;
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     LLVMTypeConverter converter(&getContext());
-    populateGPUToAMDGPUConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
-    target.addLegalDialect<::mlir::AMDGPU::AMDGPUDialect>();
+    target.addLegalDialect<::mlir::amdgpu::AMDGPUDialect>();
     target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
+
+    int subgroupSizeInt = static_cast<int>(subgroupSize);
+    populateSubgroupReduceLoweringPatterns(converter, patterns, subgroupSizeInt,
+                                           PatternBenefit(1));
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -35,10 +196,8 @@ struct ConvertGPUToAMDGPUPass
 };
 } // namespace
 
-void mlir::populateGPUToAMDGPUConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-}
-
-std::unique_ptr<Pass> mlir::createConvertGPUToAMDGPUPass() {
-  return std::make_unique<ConvertGPUToAMDGPUPass>();
+/// Registers the DPP-based `gpu.subgroup_reduce` lowering, restricted to
+/// clustered ops, with the given subgroup size and benefit. `converter` is
+/// not used by this function.
+void mlir::populateSubgroupReduceLoweringPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    unsigned subgroupSize, PatternBenefit benefit) {
+  MLIRContext *context = patterns.getContext();
+  const bool matchClustered = true;
+  patterns.add<ScalarSubgroupReduceToShuffles>(context, subgroupSize,
+                                               matchClustered, benefit);
 }
\ No newline at end of file
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 945e3ccdfa87b..52484ac69a3e2 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
   MLIRMathToLLVM
   MLIRMathToROCDL
   MLIRAMDGPUToROCDL
+  MLIRGPUToAMDGPU
   MLIRFuncToLLVM
   MLIRGPUDialect
   MLIRGPUToGPURuntimeTransforms



More information about the Mlir-commits mailing list