[Mlir-commits] [mlir] Add Lowerings for GPU WMMA F16/F32 ops to ROCDL dialect (PR #69357)

Navdeep Katel llvmlistbot at llvm.org
Tue Oct 17 10:10:37 PDT 2023


https://github.com/navdeepkk-polymagelabs created https://github.com/llvm/llvm-project/pull/69357

The following support is added (an illustrative example of the targeted IR follows the list):
1.) Lowering for the GPU WMMA load op for AOp, BOp, and COp. The lowering supports transposed and non-transposed loads for AOp and BOp. Only non-transposed loads are supported for COp. Loading for COp also supports the opSelect bit.
2.) Lowering for the GPU WMMA mma op with support for the opSelect bit.
3.) Lowering for the GPU WMMA store op with support for the opSelect bit.
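
For reference, a minimal sketch of the input IR these lowerings target (a
16x16x16 f16*f16+f32 tile; the function name, buffer shapes, and leading
dimensions are illustrative only):

  gpu.module @kernels {
    func.func @wmma_example(%a: memref<32x32xf16>, %b: memref<32x32xf16>,
                            %c: memref<32x32xf32>) {
      %c0 = arith.constant 0 : index
      %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 32 : index}
          : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
      %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 32 : index}
          : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
      %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 32 : index}
          : memref<32x32xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
      %D = gpu.subgroup_mma_compute %A, %B, %C
          : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
          -> !gpu.mma_matrix<16x16xf32, "COp">
      gpu.subgroup_mma_store_matrix %D, %c[%c0, %c0] {leadDimension = 32 : index}
          : !gpu.mma_matrix<16x16xf32, "COp">, memref<32x32xf32>
      return
    }
  }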

>From 73ac46aaf54daf6482455858c4463d0b94e4faff Mon Sep 17 00:00:00 2001
From: Navdeep Katel <navdeep at polymagelabs.com>
Date: Tue, 19 Sep 2023 15:50:18 +0530
Subject: [PATCH 1/3] [MLIR][AMDGPU] Expose a utility to get the laneID of the
 current lane

Expose a utility to get the laneID of the current lane. The implementation
is borrowed from the `gpu.lane_id` to ROCDL conversion pattern.
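
For context, the utility emits roughly the following (pseudo-IR mirroring the
comment in the implementation); the result is then sign-extended or truncated
to the requested index bitwidth:

  %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
  %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)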
---
 .../Conversion/GPUToROCDL/GPUToROCDLPass.h    | 27 ++++++++++
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      | 51 ++++++++-----------
 2 files changed, 48 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
index 5647787712997b5..311490f1d05f0c3 100644
--- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
+++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
@@ -10,11 +10,19 @@
 
 #include "mlir/Conversion/GPUToROCDL/Runtimes.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include <memory>
 
 namespace mlir {
 class LLVMTypeConverter;
 class ConversionTarget;
+class OpBuilder;
+class Location;
 class RewritePatternSet;
 
 template <typename OpT>
@@ -27,6 +35,25 @@ class GPUModuleOp;
 #define GEN_PASS_DECL_CONVERTGPUOPSTOROCDLOPS
 #include "mlir/Conversion/Passes.h.inc"
 
+namespace amd {
+/// Constant representing 32 workitems in a wavefront.
+const unsigned kWaveFrontSize32 = 32;
+
+/// Constant representing 64 workitems in a wavefront.
+const unsigned kWaveFrontSize64 = 64;
+
+/// Wavefront sizes that are supported by the GPU to ROCDL lowerings.
+const unsigned kWMMASupportedWaveFrontSizes[] = {kWaveFrontSize32,
+                                                 kWaveFrontSize64};
+
+/// Generate ops to get the laneId of the current lane and return it.
+Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
+                unsigned indexBitwidth);
+
+/// Return the LLVM Type corresponding to the MMAMatrixType.
+Type convertWMMAToROCDLLLVMType(gpu::MMAMatrixType matrixType);
+} // namespace amd
+
 /// Collect a set of patterns to convert from the GPU dialect to ROCDL.
 /// If `runtime` is Unknown, gpu.printf will not be lowered
 /// The resulting pattern set should be run over a gpu.module op
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e2cb3687d87288f..938dc5a6909fe04 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -64,15 +64,27 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
   return canBeBare;
 }
 
-Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
+Value amd::getLaneId(ConversionPatternRewriter &rewriter, Location loc,
                 const unsigned indexBitwidth) {
-  auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+  // convert to:  %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
+  // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
+  MLIRContext *context = rewriter.getContext();
+  Type intTy = IntegerType::get(context, 32);
   Value zero = rewriter.createOrFold<arith::ConstantIntOp>(loc, 0, 32);
   Value minus1 = rewriter.createOrFold<arith::ConstantIntOp>(loc, -1, 32);
-  Value mbcntLo = rewriter.create<ROCDL::MbcntLoOp>(loc, int32Type,
-                                                    ValueRange{minus1, zero});
-  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, int32Type,
+  Value mbcntLo =
+      rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
+  Value laneId = rewriter.create<ROCDL::MbcntHiOp>(loc, intTy,
                                                    ValueRange{minus1, mbcntLo});
+  // Truncate or extend the result depending on the index bitwidth specified
+  // by the LLVMTypeConverter options.
+  if (indexBitwidth > 32) {
+    laneId = rewriter.create<LLVM::SExtOp>(
+        loc, IntegerType::get(context, indexBitwidth), laneId);
+  } else if (indexBitwidth < 32) {
+    laneId = rewriter.create<LLVM::TruncOp>(
+        loc, IntegerType::get(context, indexBitwidth), laneId);
+  }
   return laneId;
 }
 
@@ -83,29 +95,9 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   LogicalResult
   matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
-    MLIRContext *context = rewriter.getContext();
-    // convert to:  %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
-    // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
-
-    Type intTy = IntegerType::get(context, 32);
-    Value zero = rewriter.createOrFold<arith::ConstantIntOp>(loc, 0, 32);
-    Value minus1 = rewriter.createOrFold<arith::ConstantIntOp>(loc, -1, 32);
-    Value mbcntLo =
-        rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
-    Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
-        loc, intTy, ValueRange{minus1, mbcntLo});
-    // Truncate or extend the result depending on the index bitwidth specified
-    // by the LLVMTypeConverter options.
-    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
-    if (indexBitwidth > 32) {
-      laneId = rewriter.create<LLVM::SExtOp>(
-          loc, IntegerType::get(context, indexBitwidth), laneId);
-    } else if (indexBitwidth < 32) {
-      laneId = rewriter.create<LLVM::TruncOp>(
-          loc, IntegerType::get(context, indexBitwidth), laneId);
-    }
-    rewriter.replaceOp(op, {laneId});
+    rewriter.replaceOp(
+        op, {amd::getLaneId(rewriter, op->getLoc(),
+                            getTypeConverter()->getIndexTypeBitwidth())});
     return success();
   }
 };
@@ -136,8 +128,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     // TODO: Add support for non 32-bit shuffle values.
     if (adaptor.getValue().getType().getIntOrFloatBitWidth() != 32)
       return failure();
-    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
-    Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
+    Value srcLaneId = amd::getLaneId(rewriter, op->getLoc(), 32);
 
     auto int32Type = IntegerType::get(rewriter.getContext(), 32);
     Value width = adaptor.getWidth();

>From 1a7c22904a23b145e6c9408efc48d9e0ba5f92e5 Mon Sep 17 00:00:00 2001
From: Navdeep Katel <navdeep at polymagelabs.com>
Date: Sat, 23 Sep 2023 14:54:05 +0530
Subject: [PATCH 2/3] [MLIR][AMDGPU] Add `convert-gpu-to-amdgpu` pass

Add the `convert-gpu-to-amdgpu` pass. This pass currently converts only the
`gpu.subgroup_mma_compute` op.
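
Roughly, the pass rewrites a compute op as follows (mirroring the new test
added below; the unrealized casts come from the MMAMatrixType-to-vector type
conversion):

  // Input:
  %0 = gpu.subgroup_mma_compute %arg0, %arg1, %arg2
      : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
      -> !gpu.mma_matrix<16x16xf32, "COp">

  // After convert-gpu-to-amdgpu:
  %a = builtin.unrealized_conversion_cast %arg0
      : !gpu.mma_matrix<16x16xf16, "AOp"> to vector<16xf16>
  %b = builtin.unrealized_conversion_cast %arg1
      : !gpu.mma_matrix<16x16xf16, "BOp"> to vector<16xf16>
  %c = builtin.unrealized_conversion_cast %arg2
      : !gpu.mma_matrix<16x16xf32, "COp"> to vector<8xf32>
  %d = amdgpu.wmma %a * %b + %c : vector<16xf16>, vector<16xf16>, vector<8xf32>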
---
 .../Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h  |  85 +++++++++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  24 +++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../lib/Conversion/GPUToAMDGPU/CMakeLists.txt |  18 ++
 .../GPUToAMDGPU/LowerGPUOpsToAMDGPUOps.cpp    | 101 ++++++++++
 .../GPUToAMDGPU/WmmaOpsToAMDGPU.cpp           | 180 ++++++++++++++++++
 ...mma-ops-to-amdgpu-unsupported-chipset.mlir |  10 +
 ...ma-ops-to-amdgpu-unsupported-operands.mlir |  33 ++++
 ...ma-ops-to-amdgpu-unsupported-warpsize.mlir |  10 +
 .../GPUToAMDGPU/wmma-ops-to-amdgpu.mlir       |  34 ++++
 11 files changed, 497 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h
 create mode 100644 mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/GPUToAMDGPU/LowerGPUOpsToAMDGPUOps.cpp
 create mode 100644 mlir/lib/Conversion/GPUToAMDGPU/WmmaOpsToAMDGPU.cpp
 create mode 100644 mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-chipset.mlir
 create mode 100644 mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-operands.mlir
 create mode 100644 mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-warpsize.mlir
 create mode 100644 mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu.mlir

diff --git a/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h b/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h
new file mode 100644
index 000000000000000..b5d0ab97d0ec6ca
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h
@@ -0,0 +1,85 @@
+//===- GPUToAMDGPUPass.h - Convert GPU kernel to AMDGPU dialect -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPUPASS_H_
+#define MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPUPASS_H_
+
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include <memory>
+
+namespace llvm {
+class StringRef;
+} // namespace llvm
+
+namespace mlir {
+class ConversionTarget;
+class OpBuilder;
+class Location;
+class RewritePatternSet;
+class Type;
+class TypeConverter;
+
+template <typename OpT>
+class OperationPass;
+
+namespace gpu {
+class GPUModuleOp;
+class MMAMatrixType;
+} // namespace gpu
+
+#define GEN_PASS_DECL_CONVERTGPUOPSTOAMDGPUOPS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace amd {
+/// Return the vector type corresponding to the MMAMatrixType.
+Type convertWMMAToVectorType(gpu::MMAMatrixType matrixType);
+
+/// String to represent the `opSelect` attribute name.
+constexpr char kAMDGpuOpselectAttrName[] = "opSelect";
+} // namespace amd
+
+/// Collect a set of patterns to convert from the GPU dialect to AMDGPU.
+/// The resulting pattern set should be run over a gpu.module op. `chipset` is
+/// the chip we are targeting. `warpSize` is the warp size to use when
+/// generating WMMA intrinsics. The `opSelect` bit is used in the lowering of
+/// f16 versions of WMMA ops involving the `C` operand: if it is true, the
+/// upper half of the general-purpose 32-bit registers is used for storing the
+/// values; if false, the lower half is used.
+void populateGpuToAMDGPUConversionPatterns(TypeConverter &typeConverter,
+                                           RewritePatternSet &patterns,
+                                           llvm::StringRef chipset = "gfx1100",
+                                           unsigned warpSize = 32);
+
+/// Creates a pass that lowers GPU dialect operations to AMDGPU counterparts.
+/// The index bitwidth used for the lowering of the device side index
+/// computations is configurable. AMD GPUs have a configurable warp size; valid
+/// choices are 32 and 64. We choose 32 as the default size. The `opSelect` bit
+/// is used in the lowering of f16 versions of WMMA ops involving the `C`
+/// operand: if it is true, the upper half of the general-purpose 32-bit
+/// registers is used for storing the values; if false, the lower half is used.
+std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
+createLowerGpuOpsToAMDGPUOpsPass(const std::string &chipset = "gfx1100",
+                                 unsigned warpSize = 32);
+
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to AMDGPU.
+/// `chipset` is the target chip for which the IR is being generated.
+/// `warpSize` is the warp size to use when generating WMMA intrinsics.
+void populateGpuWMMAToAMDGPUConversionPatterns(TypeConverter &typeConverter,
+                                               RewritePatternSet &patterns,
+                                               llvm::StringRef chipset,
+                                               unsigned warpSize);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTOAMDGPU_GPUTOAMDGPUPASS_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index e714f5070f23db8..9a4f9812253d81b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -30,6 +30,7 @@
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a269fb4a83af41f..688f505a5b1ee3a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -495,6 +495,30 @@ def LowerHostCodeToLLVMPass : Pass<"lower-host-to-llvm", "ModuleOp"> {
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// GPUToAMDGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertGpuOpsToAMDGPUOps : Pass<"convert-gpu-to-amdgpu", "gpu::GPUModuleOp"> {
+  let summary = "Generate AMD GPU operations for gpu operations";
+  let constructor = "mlir::createLowerGpuOpsToAMDGPUOpsPass()";
+  let dependentDialects = [
+    "amdgpu::AMDGPUDialect",
+  ];
+  let options = [
+    Option<"chipset", "chipset", "std::string",
+           /*default=*/"\"gfx000\"",
+           "Chipset that these operations will run on">,
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
+           "Bitwidth of the index type, 0 to use size of machine word">,
+    Option<"warpSize", "warp-size", "unsigned",
+           /*default=*/"32",
+           "AMD GPUs have a configurable warp size; valid choices are 32 and "
+           "64. 32 is used as the default size.">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // GPUToNVVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 35790254be137be..6a7bee3a10866cd 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -19,6 +19,7 @@ add_subdirectory(ConvertToLLVM)
 add_subdirectory(FuncToLLVM)
 add_subdirectory(FuncToSPIRV)
 add_subdirectory(GPUCommon)
+add_subdirectory(GPUToAMDGPU)
 add_subdirectory(GPUToNVVM)
 add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
diff --git a/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt b/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt
new file mode 100644
index 000000000000000..7e201484a76cf30
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToAMDGPU/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRGPUToAMDGPUTransforms
+  LowerGPUOpsToAMDGPUOps.cpp
+  WmmaOpsToAMDGPU.cpp
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithToLLVM
+  MLIRFuncToLLVM
+  MLIRGPUDialect
+  MLIRGPUToGPURuntimeTransforms
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRMemRefToLLVM
+  MLIRROCDLDialect
+  MLIRPass
+  )
diff --git a/mlir/lib/Conversion/GPUToAMDGPU/LowerGPUOpsToAMDGPUOps.cpp b/mlir/lib/Conversion/GPUToAMDGPU/LowerGPUOpsToAMDGPUOps.cpp
new file mode 100644
index 000000000000000..c20d8eedea13361
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToAMDGPU/LowerGPUOpsToAMDGPUOps.cpp
@@ -0,0 +1,101 @@
+//===- LowerGpuOpsToAMDGPUOps.cpp - MLIR GPU to AMD GPU lowering passes ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to generate AMDGPU operations for higher-level
+// GPU operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTGPUOPSTOAMDGPUOPS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct LowerGpuOpsToAMDGPUOpsPass
+    : public impl::ConvertGpuOpsToAMDGPUOpsBase<LowerGpuOpsToAMDGPUOpsPass> {
+  LowerGpuOpsToAMDGPUOpsPass() = default;
+  LowerGpuOpsToAMDGPUOpsPass(const std::string &chipset, unsigned warpSize) {
+    if (this->chipset.getNumOccurrences() == 0)
+      this->chipset = chipset;
+    if (this->warpSize.getNumOccurrences() == 0)
+      this->warpSize = warpSize;
+  }
+
+  void runOnOperation() override {
+    gpu::GPUModuleOp m = getOperation();
+    MLIRContext *ctx = m.getContext();
+
+    // Request C wrapper emission.
+    for (auto func : m.getOps<func::FuncOp>()) {
+      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
+                    UnitAttr::get(ctx));
+    }
+
+    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+    if (failed(maybeChipset)) {
+      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
+      return signalPassFailure();
+    }
+
+    TypeConverter converter;
+
+    RewritePatternSet amdgpuPatterns(ctx);
+
+    populateGpuToAMDGPUConversionPatterns(converter, amdgpuPatterns,
+                                          this->chipset, this->warpSize);
+    ConversionTarget target(*ctx);
+    // We do not mark the GPU dialect illegal, as other GPU ops and WMMA ops
+    // unsupported by the patterns defined here are still allowed.
+    target.addLegalDialect<amdgpu::AMDGPUDialect>();
+
+    if (failed(applyPartialConversion(m, target, std::move(amdgpuPatterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::populateGpuToAMDGPUConversionPatterns(TypeConverter &converter,
+                                                 RewritePatternSet &patterns,
+                                                 StringRef chipset,
+                                                 unsigned warpSize) {
+  // Lowering for MMAMatrixType.
+  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+    return amd::convertWMMAToROCDLLLVMType(type);
+  });
+
+  // We need to add target and source materializations so that the IR still
+  // remains valid after the `gpu.mma_matrix` type conversion is done.
+  auto buildUnrealizedCast = [](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return std::optional<Value>(cast.getResult(0));
+  };
+  converter.addSourceMaterialization(buildUnrealizedCast);
+  converter.addTargetMaterialization(buildUnrealizedCast);
+
+  // Collect a set of patterns to convert WMMA ops from the GPU dialect to
+  // AMDGPU.
+  populateGpuWMMAToAMDGPUConversionPatterns(converter, patterns, chipset,
+                                            warpSize);
+}
+
+std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
+mlir::createLowerGpuOpsToAMDGPUOpsPass(const std::string &chipset,
+                                       unsigned warpSize) {
+  return std::make_unique<LowerGpuOpsToAMDGPUOpsPass>(chipset, warpSize);
+}
diff --git a/mlir/lib/Conversion/GPUToAMDGPU/WmmaOpsToAMDGPU.cpp b/mlir/lib/Conversion/GPUToAMDGPU/WmmaOpsToAMDGPU.cpp
new file mode 100644
index 000000000000000..9828e93daa3bfbd
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToAMDGPU/WmmaOpsToAMDGPU.cpp
@@ -0,0 +1,180 @@
+//===-------- WmmaOpsToAMDGPU.cpp - GPU WMMA ops to AMD GPU lowering ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// AMD GPU Dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+
+using namespace mlir;
+
+namespace {
+
+static LogicalResult areAllVectorTypes(Operation *op, ValueRange operands,
+                                       ConversionPatternRewriter &rewriter) {
+  if (!llvm::all_of(operands, [](Value value) {
+        return isa<mlir::VectorType>(value.getType());
+      })) {
+    return rewriter.notifyMatchFailure(
+        op, "cannot convert if operands aren't of Vector type.");
+  }
+
+  return success();
+}
+
+/// Create a WMMA compute intrinsic doing the multiply-add operation as:
+///
+///  `cOp` = `aOp` * `bOp` + `cOp`
+///
+/// and return the generated op in `computeOp`.
+static LogicalResult createWMMAComputeIntrinsic(Value aOp, Value bOp, Value cOp,
+                                                Location loc, bool opSelect,
+                                                PatternRewriter &rewriter,
+                                                Value &computeOp) {
+  Type aType = aOp.getType();
+  Type bType = bOp.getType();
+  Type cType = cOp.getType();
+
+  // All the intrinsics present currently operate on vector types.
+  auto checkVecType = [](Value value, StringRef op) {
+    if (!isa<VectorType>(value.getType())) {
+      return mlir::emitError(value.getLoc(), op + " should be of vector type");
+    }
+    return InFlightDiagnostic();
+  };
+
+  if (failed(checkVecType(aOp, "aOp")))
+    return failure();
+  if (failed(checkVecType(bOp, "bOp")))
+    return failure();
+  if (failed(checkVecType(cOp, "cOp")))
+    return failure();
+
+  auto aVecType = aType.cast<VectorType>();
+  auto bVecType = bType.cast<VectorType>();
+  auto cVecType = cType.cast<VectorType>();
+
+  if (aVecType != bVecType)
+    return emitError(aOp.getLoc(), "aOp and bOp must be of the same type");
+
+  Type aEltType = aVecType.getElementType();
+  Type cEltType = cVecType.getElementType();
+
+  // We support lowering for the mixed-precision and full fp16 WMMA intrinsics
+  // currently.
+  if (aEltType.isF16() && cEltType.isF32()) {
+    // subwordOffset is always false for F32 `C` operands as they occupy all 32
+    // bits in the VGPR.
+    computeOp = rewriter.create<amdgpu::WMMAOp>(loc, cType, aOp, bOp, cOp,
+                                                /*subwordOffset=*/false);
+    return success();
+  }
+  if (aEltType.isF16() && cEltType.isF16()) {
+    computeOp =
+        rewriter.create<amdgpu::WMMAOp>(loc, cType, aOp, bOp, cOp, opSelect);
+    return success();
+  }
+
+  return failure();
+}
+
+/// This class implements the conversion of the GPU subgroup MMA compute op to
+/// the `amdgpu.wmma` op in the AMDGPU dialect.
+struct WmmaMmaOpToAMDGPULowering
+    : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+  WmmaMmaOpToAMDGPULowering(TypeConverter &typeConverter, MLIRContext *context,
+                            StringRef chip, unsigned warpSize)
+      : OpConversionPattern<gpu::SubgroupMmaComputeOp>::OpConversionPattern(
+            typeConverter, context),
+        warpSize(warpSize), chip(chip){};
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(areAllVectorTypes(subgroupMmaComputeOp.getOperation(),
+                                 adaptor.getOperands(), rewriter)))
+      return failure();
+
+    std::size_t firstPos = chip.find("gfx11");
+    std::size_t lastPos = chip.rfind("gfx11");
+    if (firstPos != 0 || (firstPos != lastPos))
+      return subgroupMmaComputeOp->emitError(
+          "wmma lowering is supported for gfx11 series only");
+
+    if (warpSize != amd::kWaveFrontSize32)
+      return subgroupMmaComputeOp->emitError(
+          "wavefront of size 32 only supported");
+
+    auto aTranspose = subgroupMmaComputeOp.getATranspose();
+    auto bTranspose = subgroupMmaComputeOp.getBTranspose();
+
+    if ((aTranspose.has_value() && aTranspose.value()) ||
+        (bTranspose.has_value() && bTranspose.value()))
+      return subgroupMmaComputeOp->emitError(
+          "lowering with transpose is not supported. Please "
+          "use transpose while loading/storing the operands.");
+
+    Location loc = subgroupMmaComputeOp->getLoc();
+
+    gpu::MMAMatrixType aType =
+        subgroupMmaComputeOp.getOpA().getType().cast<gpu::MMAMatrixType>();
+    gpu::MMAMatrixType bType =
+        subgroupMmaComputeOp.getOpB().getType().cast<gpu::MMAMatrixType>();
+    gpu::MMAMatrixType cType =
+        subgroupMmaComputeOp.getOpC().getType().cast<gpu::MMAMatrixType>();
+
+    SmallVector<int64_t> aTypeShape(aType.getShape());
+    SmallVector<int64_t> bTypeShape(bType.getShape());
+    SmallVector<int64_t> cTypeShape(cType.getShape());
+    SmallVector<SmallVector<int64_t>> allShapes = {aTypeShape, bTypeShape,
+                                                   cTypeShape};
+
+    if (!llvm::all_of(allShapes, [](ArrayRef<int64_t> shape) {
+          return llvm::all_of(shape, [](int dim) { return dim == 16; });
+        }))
+      return subgroupMmaComputeOp->emitError(
+          "wmma ops of shape 16x16x16 are only supported.");
+
+    // Get the WMMA intrinsic to map to.
+    bool opSelect = subgroupMmaComputeOp->hasAttrOfType<UnitAttr>(
+        amd::kAMDGpuOpselectAttrName);
+    Value computeOp;
+    if (failed(createWMMAComputeIntrinsic(adaptor.getOpA(), adaptor.getOpB(),
+                                          adaptor.getOpC(), loc, opSelect,
+                                          rewriter, computeOp)))
+      return rewriter.notifyMatchFailure(subgroupMmaComputeOp,
+                                         "unsupported mma op variant.");
+
+    rewriter.replaceOp(subgroupMmaComputeOp, computeOp);
+    return success();
+  }
+
+  /// `warpSize` is the warp size to use when generating WMMA intrinsics.
+  unsigned warpSize;
+
+  /// The target chip for which to generate the lowering.
+  std::string chip;
+};
+
+} // namespace
+
+void mlir::populateGpuWMMAToAMDGPUConversionPatterns(
+    TypeConverter &converter, RewritePatternSet &patterns, StringRef chip,
+    unsigned warpSize) {
+  patterns.add<WmmaMmaOpToAMDGPULowering>(converter, patterns.getContext(),
+                                          chip, warpSize);
+}
diff --git a/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-chipset.mlir b/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-chipset.mlir
new file mode 100644
index 000000000000000..3722a1ee43b966e
--- /dev/null
+++ b/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-chipset.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -convert-gpu-to-amdgpu='chipset=gfx900 index-bitwidth=32' -split-input-file -verify-diagnostics
+
+gpu.module @test_module {
+  func.func @compute_op_f32_f16(%arg0: !gpu.mma_matrix<16x16xf16, "AOp">, %arg1: !gpu.mma_matrix<16x16xf16, "BOp">, %arg2: !gpu.mma_matrix<16x16xf32, "COp">) -> (!gpu.mma_matrix<16x16xf32, "COp">) {
+    %0 = gpu.subgroup_mma_compute %arg0, %arg1, %arg2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+    // expected-error at -1 {{wmma lowering is supported for gfx11 series only}}
+    return %0 : !gpu.mma_matrix<16x16xf32, "COp">
+  }
+}
+
diff --git a/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-operands.mlir b/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-operands.mlir
new file mode 100644
index 000000000000000..871bda752a5f397
--- /dev/null
+++ b/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-operands.mlir
@@ -0,0 +1,33 @@
+// This file tests that we error out properly when unsupported ops are
+// encountered during the GPU WMMA ops to AMDGPU conversion.
+
+// RUN: mlir-opt %s -convert-gpu-to-amdgpu='chipset=gfx1100 index-bitwidth=32' -split-input-file -verify-diagnostics
+gpu.module @test_module {
+  // CHECK-LABEL: compute_op_f32_f16_transpose
+  func.func @compute_op_f32_f16(%arg0: !gpu.mma_matrix<16x16xf16, "AOp">, %arg1: !gpu.mma_matrix<16x16xf16, "BOp">, %arg2: !gpu.mma_matrix<16x16xf32, "COp">) -> (!gpu.mma_matrix<16x16xf32, "COp">) {
+    %0 = gpu.subgroup_mma_compute %arg0, %arg1, %arg2 {a_transpose}: !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+    // expected-error at -1 {{lowering with transpose is not supported. Please use transpose while loading/storing the operands.}}
+    return %0 : !gpu.mma_matrix<16x16xf32, "COp">
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: compute_op_f32_f16_transpose
+  func.func @compute_op_f32_f16(%arg0: !gpu.mma_matrix<16x16xf16, "AOp">, %arg1: !gpu.mma_matrix<16x16xf16, "BOp">, %arg2: !gpu.mma_matrix<16x16xf32, "COp">) -> (!gpu.mma_matrix<16x16xf32, "COp">) {
+    %0 = gpu.subgroup_mma_compute %arg0, %arg1, %arg2 {b_transpose}: !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+    // expected-error at -1 {{lowering with transpose is not supported. Please use transpose while loading/storing the operands.}}
+    return %0 : !gpu.mma_matrix<16x16xf32, "COp">
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  func.func @compute_op_f32_f16(%arg0: !gpu.mma_matrix<16x8xf16, "AOp">, %arg1: !gpu.mma_matrix<8x16xf16, "BOp">, %arg2: !gpu.mma_matrix<16x16xf32, "COp">) -> (!gpu.mma_matrix<16x16xf32, "COp">) {
+    %0 = gpu.subgroup_mma_compute %arg0, %arg1, %arg2 : !gpu.mma_matrix<16x8xf16, "AOp">, !gpu.mma_matrix<8x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+    // expected-error at -1 {{wmma ops of shape 16x16x16 are only supported.}}
+    return %0 : !gpu.mma_matrix<16x16xf32, "COp">
+  }
+}
diff --git a/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-warpsize.mlir b/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-warpsize.mlir
new file mode 100644
index 000000000000000..7d6410e2768d33d
--- /dev/null
+++ b/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu-unsupported-warpsize.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -convert-gpu-to-amdgpu='chipset=gfx1100 warp-size=64' -split-input-file -verify-diagnostics
+
+gpu.module @test_module {
+  func.func @compute_op_f32_f16(%arg0: !gpu.mma_matrix<16x16xf16, "AOp">, %arg1: !gpu.mma_matrix<16x16xf16, "BOp">, %arg2: !gpu.mma_matrix<16x16xf32, "COp">) -> (!gpu.mma_matrix<16x16xf32, "COp">) {
+    %0 = gpu.subgroup_mma_compute %arg0, %arg1, %arg2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+    // expected-error at -1 {{wavefront of size 32 only supported}}
+    return %0 : !gpu.mma_matrix<16x16xf32, "COp">
+  }
+}
+
diff --git a/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu.mlir b/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu.mlir
new file mode 100644
index 000000000000000..502ad767baae453
--- /dev/null
+++ b/mlir/test/Conversion/GPUToAMDGPU/wmma-ops-to-amdgpu.mlir
@@ -0,0 +1,34 @@
+// This file tests the conversion of GPU WMMA ops to the AMDGPU dialect.
+// RUN: mlir-opt %s -convert-gpu-to-amdgpu='chipset=gfx1100 index-bitwidth=32' -split-input-file | FileCheck %s
+
+gpu.module @test_module {
+  // CHECK-LABEL: compute_op_f32_f16
+  // CHECK-SAME: (%[[AOP:.*]]: !gpu.mma_matrix<16x16xf16, "AOp">, %[[BOP:.*]]: !gpu.mma_matrix<16x16xf16, "BOp">, %[[COP:.*]]: !gpu.mma_matrix<16x16xf32, "COp">)
+  func.func @compute_op_f32_f16(%arg0: !gpu.mma_matrix<16x16xf16, "AOp">, %arg1: !gpu.mma_matrix<16x16xf16, "BOp">, %arg2: !gpu.mma_matrix<16x16xf32, "COp">) -> (!gpu.mma_matrix<16x16xf32, "COp">) {
+    %0 = gpu.subgroup_mma_compute %arg0, %arg1, %arg2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+    // CHECK:      %[[AOPMAT:.*]] = builtin.unrealized_conversion_cast %[[AOP]] : !gpu.mma_matrix<16x16xf16, "AOp"> to vector<16xf16>
+    // CHECK-NEXT: %[[BOPMAT:.*]] = builtin.unrealized_conversion_cast %[[BOP]] : !gpu.mma_matrix<16x16xf16, "BOp"> to vector<16xf16>
+    // CHECK-NEXT: %[[COPMAT:.*]] = builtin.unrealized_conversion_cast %[[COP]] : !gpu.mma_matrix<16x16xf32, "COp"> to vector<8xf32>
+    // CHECK-NEXT: %[[OUTVEC:.*]] = amdgpu.wmma %[[AOPMAT]] * %[[BOPMAT]] + %[[COPMAT]] : vector<16xf16>, vector<16xf16>, vector<8xf32>
+    // CHECK-NEXT: %[[OUTMAT:.*]] = builtin.unrealized_conversion_cast %[[OUTVEC]] : vector<8xf32> to !gpu.mma_matrix<16x16xf32, "COp">
+    // CHECK-NEXT: return %[[OUTMAT]] : !gpu.mma_matrix<16x16xf32, "COp">
+    return %0 : !gpu.mma_matrix<16x16xf32, "COp">
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: compute_op_f16_f16
+  // CHECK-SAME: (%[[AOP:.*]]: !gpu.mma_matrix<16x16xf16, "AOp">, %[[BOP:.*]]: !gpu.mma_matrix<16x16xf16, "BOp">, %[[COP:.*]]: !gpu.mma_matrix<16x16xf16, "COp">)
+  func.func @compute_op_f16_f16(%arg0: !gpu.mma_matrix<16x16xf16, "AOp">, %arg1: !gpu.mma_matrix<16x16xf16, "BOp">, %arg2: !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
+    %0 = gpu.subgroup_mma_compute %arg0, %arg1, %arg2 {opSelect} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+    // CHECK:      %[[AOPMAT:.*]] = builtin.unrealized_conversion_cast %[[AOP]] : !gpu.mma_matrix<16x16xf16, "AOp"> to vector<16xf16>
+    // CHECK-NEXT: %[[BOPMAT:.*]] = builtin.unrealized_conversion_cast %[[BOP]] : !gpu.mma_matrix<16x16xf16, "BOp"> to vector<16xf16>
+    // CHECK-NEXT: %[[COPMAT:.*]] = builtin.unrealized_conversion_cast %[[COP]] : !gpu.mma_matrix<16x16xf16, "COp"> to vector<16xf16>
+    // CHECK-NEXT: %[[OUTVEC:.*]] = amdgpu.wmma %[[AOPMAT]] * %[[BOPMAT]] + %[[COPMAT]] {subwordOffset = 1 : i32} : vector<16xf16>, vector<16xf16>, vector<16xf16>
+    // CHECK-NEXT: %[[OUTMAT:.*]] = builtin.unrealized_conversion_cast %[[OUTVEC]] : vector<16xf16> to !gpu.mma_matrix<16x16xf16, "COp">
+    // CHECK-NEXT: return %[[OUTMAT]] : !gpu.mma_matrix<16x16xf16, "COp">
+    return %0 : !gpu.mma_matrix<16x16xf16, "COp">
+  }
+}

>From 449409a91af5f5f099c620782ec7ecf7e0991eab Mon Sep 17 00:00:00 2001
From: navdeep <navdeep at polymagelab.com>
Date: Tue, 19 Sep 2023 15:13:44 +0530
Subject: [PATCH 3/3] [MLIR][AMDGPU] Add Lowerings for GPU WMMA load/store
 F16/F32 ops to ROCDL dialect

The following support is added (an illustrative example follows the list):
1.) Lowering for GPU WMMA load op for AOp, BOp, COp. The lowering
  supports transposed and non-transposed loads for AOp and BOp. Only
  non-transposed loads are supported for COp. Loading for COp also
  supports the opSelect bit.
2.) Lowering for GPU WMMA store op with support for the opSelect bit.
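
For illustration, the kind of load/store ops these patterns handle (%buf, %D,
and %out are assumed to be defined elsewhere; shapes and leading dimensions
are illustrative; the `transpose` unit attribute selects the transposed-load
path for AOp/BOp):

  %c0 = arith.constant 0 : index
  // Transposed load of the A fragment.
  %A = gpu.subgroup_mma_load_matrix %buf[%c0, %c0]
      {leadDimension = 32 : index, transpose}
      : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
  // Non-transposed store of the C fragment.
  gpu.subgroup_mma_store_matrix %D, %out[%c0, %c0] {leadDimension = 32 : index}
      : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>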

Differential Revision: https://reviews.llvm.org/D157228
---
 .../Conversion/GPUToROCDL/GPUToROCDLPass.h    |  38 +-
 mlir/include/mlir/Conversion/Passes.td        |  27 +-
 .../mlir/Dialect/LLVMIR/CMakeLists.txt        |   4 +
 .../mlir/Dialect/LLVMIR/ROCDLDialect.h        |   2 +
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td  |  13 +
 mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt |   1 +
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      |  38 +-
 .../Conversion/GPUToROCDL/WmmaOpsToROCDL.cpp  | 512 ++++++++++++++++++
 mlir/test/CMakeLists.txt                      |   3 +
 ...wmma-ops-to-rocdl-unsupported-chipset.mlir |  30 +
 .../wmma-ops-to-rocdl-unsupported.mlir        | 181 +++++++
 .../GPUToROCDL/wmma-ops-to-rocdl.mlir         | 442 +++++++++++++++
 .../Integration/GPU/ROCM/WMMA/lit.local.cfg   |   5 +
 .../GPU/ROCM/WMMA/wmma_f16_16_16_16_f16.mlir  |  95 ++++
 .../WMMA/wmma_f16_16_16_16_f16_opselect.mlir  |  95 ++++
 .../ROCM/WMMA/wmma_f16_16_16_16_f16_x2.mlir   | 100 ++++
 .../GPU/ROCM/WMMA/wmma_f32_16_16_16_f16.mlir  |  86 +++
 .../wmma_f32_16_16_16_f16_a_b_transpose.mlir  |  84 +++
 mlir/test/lit.site.cfg.py.in                  |   1 +
 19 files changed, 1729 insertions(+), 28 deletions(-)
 create mode 100644 mlir/lib/Conversion/GPUToROCDL/WmmaOpsToROCDL.cpp
 create mode 100644 mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported-chipset.mlir
 create mode 100644 mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported.mlir
 create mode 100644 mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl.mlir
 create mode 100644 mlir/test/Integration/GPU/ROCM/WMMA/lit.local.cfg
 create mode 100644 mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16.mlir
 create mode 100644 mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_opselect.mlir
 create mode 100644 mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_x2.mlir
 create mode 100644 mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16.mlir
 create mode 100644 mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16_a_b_transpose.mlir

diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
index 311490f1d05f0c3..7b0e845cf81a520 100644
--- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
+++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
@@ -18,18 +18,24 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include <memory>
 
+namespace llvm {
+class StringRef;
+} // namespace llvm
+
 namespace mlir {
 class LLVMTypeConverter;
 class ConversionTarget;
 class OpBuilder;
 class Location;
 class RewritePatternSet;
+class Type;
 
 template <typename OpT>
 class OperationPass;
 
 namespace gpu {
 class GPUModuleOp;
+class MMAMatrixType;
 } // namespace gpu
 
 #define GEN_PASS_DECL_CONVERTGPUOPSTOROCDLOPS
@@ -47,7 +53,7 @@ const unsigned kWMMASupportedWaveFrontSizes[] = {kWaveFrontSize32,
                                                  kWaveFrontSize64};
 
 /// Generate ops to get the laneId of the current lane and return it.
-Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
+Value getLaneId(PatternRewriter &rewriter, Location loc,
                 unsigned indexBitwidth);
 
 /// Return the LLVM Type corresponding to the MMAMatrixType.
@@ -55,24 +61,40 @@ Type convertWMMAToROCDLLLVMType(gpu::MMAMatrixType matrixType);
 } // namespace amd
 
 /// Collect a set of patterns to convert from the GPU dialect to ROCDL.
-/// If `runtime` is Unknown, gpu.printf will not be lowered
-/// The resulting pattern set should be run over a gpu.module op
-void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
-                                          RewritePatternSet &patterns,
-                                          gpu::amd::Runtime runtime);
+/// If `runtime` is Unknown, gpu.printf will not be lowered. The resulting
+/// pattern set should be run over a gpu.module op. `chipset` is the chip we are
+/// targeting. `indexBitwidth` is the bitwidth to be used while converting index
+/// types. `warpSize` is the warp size to use when generating WMMA intrinsics.
+void populateGpuToROCDLConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    gpu::amd::Runtime runtime, llvm::StringRef chipset = "gfx900",
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
+    unsigned warpSize = 32);
 
 /// Configure target to convert from the GPU dialect to ROCDL.
 void configureGpuToROCDLConversionLegality(ConversionTarget &target);
 
 /// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The
 /// index bitwidth used for the lowering of the device side index computations
-/// is configurable.
+/// is configurable. AMD GPUs have a configurable warp size; valid choices are
+/// 32 and 64. We choose 32 as the default size.
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
 createLowerGpuOpsToROCDLOpsPass(
     const std::string &chipset = "gfx900",
     unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
     bool useBarePtrCallConv = false,
-    gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
+    gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown,
+    unsigned warpSize = 32);
+
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to ROCDL.
+/// `chipset` is the target chip for which the IR is being generated.
+/// `indexBitwidth` is the bitwidth to be used while converting index types.
+/// `warpSize` is the warp size to use when generating WMMA intrinsics.
+void populateGpuWMMAToROCDLConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    llvm::StringRef chipset = "gfx900",
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
+    unsigned warpSize = 32);
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 688f505a5b1ee3a..5ea284774a9823b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -563,23 +563,30 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
            /*default=*/"\"gfx000\"",
            "Chipset that these operations will run on">,
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
-           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
            "Bitwidth of the index type, 0 to use size of machine word">,
     Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
            /*default=*/"false",
            "Replace memref arguments in GPU functions with bare pointers."
            "All memrefs must have static shape">,
     Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
-          "::mlir::gpu::amd::Runtime::Unknown",
-          "Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
-          [{::llvm::cl::values(
-            clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", "Unknown (default)"),
-            clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
-            clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", "OpenCL")
-          )}]>,
+           "::mlir::gpu::amd::Runtime::Unknown",
+           "Runtime code will be run on (default is Unknown, can also use HIP "
+           "or OpenCl)",
+           [{::llvm::cl::values(
+               clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown",
+                          "Unknown (default)"),
+               clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
+               clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL",
+                          "OpenCL"))}]>,
     Option<"useOpaquePointers", "use-opaque-pointers", "bool",
-               /*default=*/"true", "Generate LLVM IR using opaque pointers "
-               "instead of typed pointers">,
+           /*default=*/"true",
+           "Generate LLVM IR using opaque pointers "
+           "instead of typed pointers">,
+    Option<"warpSize", "warp-size", "unsigned",
+           /*default=*/"32",
+           "AMD GPUs have a configurable warp size; valid choices are 32 and "
+           "64. 32 is used as the default size.">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 64de028c7fe4061..4d0caae203c7d31 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -65,6 +65,10 @@ add_public_tablegen_target(MLIRNVVMConversionsIncGen)
 add_mlir_dialect(ROCDLOps rocdl)
 add_mlir_doc(ROCDLOps ROCDLDialect Dialects/ -gen-dialect-doc -dialect=rocdl)
 set(LLVM_TARGET_DEFINITIONS ROCDLOps.td)
+mlir_tablegen(ROCDLOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ROCDLOpsEnums.cpp.inc -gen-enum-defs)
 mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions)
 mlir_tablegen(ROCDLOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=rocdl)
 mlir_tablegen(ROCDLOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=rocdl)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
index c2a82ffc1c43cf6..54e9980bb213f59 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
@@ -28,6 +28,8 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "mlir/Dialect/LLVMIR/ROCDLOpsEnums.h.inc"
+
 ///// Ops /////
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/LLVMIR/ROCDLOpsAttributes.h.inc"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 6c6419bf238b457..55d5c018f7430bb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -262,6 +263,18 @@ class ROCDL_Wmma_IntrOp<string mnemonic, list<Trait> traits = []> :
     "$args attr-dict `:` functional-type($args, $res)";
 }
 
+def ROCDLWMMAFragA : I32EnumAttrCase<"a", 0>;
+def ROCDLWMMAFragB : I32EnumAttrCase<"b", 1>;
+def ROCDLWMMAFragC : I32EnumAttrCase<"c", 2>;
+
+/// Enum attribute of the different frag types.
+def ROCDLWMMAFrag
+    : I32EnumAttr<"ROCDLWMMAFrag", "ROCDL WMMA frag type",
+                  [ROCDLWMMAFragA, ROCDLWMMAFragB, ROCDLWMMAFragC]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::ROCDL";
+}
+
 // Available on RDNA3
 def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16">;
 def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16">;
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 70707b5c3a0494c..932878173202c57 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -4,6 +4,7 @@ add_public_tablegen_target(MLIRGPUToROCDLIncGen)
 
 add_mlir_conversion_library(MLIRGPUToROCDLTransforms
   LowerGpuOpsToROCDLOps.cpp
+  WmmaOpsToROCDL.cpp
 
   DEPENDS
   MLIRConversionPassIncGen
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 938dc5a6909fe04..4ecb2afcd29207e 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -64,8 +65,8 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
   return canBeBare;
 }
 
-Value amd::getLaneId(ConversionPatternRewriter &rewriter, Location loc,
-                const unsigned indexBitwidth) {
+Value amd::getLaneId(PatternRewriter &rewriter, Location loc,
+                     const unsigned indexBitwidth) {
   // convert to:  %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
   // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
   MLIRContext *context = rewriter.getContext();
@@ -187,8 +188,8 @@ struct LowerGpuOpsToROCDLOpsPass
     : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
   LowerGpuOpsToROCDLOpsPass() = default;
   LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
-                            bool useBarePtrCallConv,
-                            gpu::amd::Runtime runtime) {
+                            bool useBarePtrCallConv, gpu::amd::Runtime runtime,
+                            unsigned warpSize) {
     if (this->chipset.getNumOccurrences() == 0)
       this->chipset = chipset;
     if (this->indexBitwidth.getNumOccurrences() == 0)
@@ -197,6 +198,8 @@ struct LowerGpuOpsToROCDLOpsPass
       this->useBarePtrCallConv = useBarePtrCallConv;
     if (this->runtime.getNumOccurrences() == 0)
       this->runtime = runtime;
+    if (this->warpSize.getNumOccurrences() == 0)
+      this->warpSize = warpSize;
   }
 
   void runOnOperation() override {
@@ -272,7 +275,9 @@ struct LowerGpuOpsToROCDLOpsPass
     cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
     populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
     populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
-    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
+    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
+                                         this->chipset, this->indexBitwidth,
+                                         this->warpSize);
     LLVMConversionTarget target(getContext());
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
@@ -324,11 +329,19 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
 }
 
-void mlir::populateGpuToROCDLConversionPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    mlir::gpu::amd::Runtime runtime) {
+void mlir::populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter,
+                                                RewritePatternSet &patterns,
+                                                mlir::gpu::amd::Runtime runtime,
+                                                StringRef chipset,
+                                                unsigned indexBitwidth,
+                                                unsigned warpSize) {
   using mlir::gpu::amd::Runtime;
 
+  // Lowering for MMAMatrixType.
+  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+    return amd::convertWMMAToROCDLLLVMType(type);
+  });
+
   populateWithGenerated(patterns);
   patterns
       .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
@@ -358,6 +371,10 @@ void mlir::populateGpuToROCDLConversionPatterns(
 
   patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
 
+  /// Collect a set of patterns to convert WMMA ops from GPU dialect to ROCDL.
+  populateGpuWMMAToROCDLConversionPatterns(converter, patterns, chipset,
+                                           indexBitwidth, warpSize);
+
   populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
                                    "__ocml_fabs_f64");
   populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
@@ -408,7 +425,8 @@ std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
 mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
                                       unsigned indexBitwidth,
                                       bool useBarePtrCallConv,
-                                      gpu::amd::Runtime runtime) {
+                                      gpu::amd::Runtime runtime,
+                                      unsigned warpSize) {
   return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
-      chipset, indexBitwidth, useBarePtrCallConv, runtime);
+      chipset, indexBitwidth, useBarePtrCallConv, runtime, warpSize);
 }
diff --git a/mlir/lib/Conversion/GPUToROCDL/WmmaOpsToROCDL.cpp b/mlir/lib/Conversion/GPUToROCDL/WmmaOpsToROCDL.cpp
new file mode 100644
index 000000000000000..dfcd7f64836cd0f
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToROCDL/WmmaOpsToROCDL.cpp
@@ -0,0 +1,512 @@
+//===--------- WmmaOpsToROCDL.cpp - GPU WMMA ops to ROCDL lowering --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// ROCDL Dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
+#include "mlir/Conversion/GPUToAMDGPU/GPUToAMDGPUPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Checks if all the operands of the op being lowered are of LLVM Types. The
+/// types are expected to be converted by the `LLVMTypeConverter` before the op
+/// is actually lowered. If the type of an operand is not already converted, it
+/// hints at a missing type conversion and failure is returned in that case.
+static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                     ConversionPatternRewriter &rewriter) {
+  if (!llvm::all_of(operands, [](Value value) {
+        return LLVM::isCompatibleType(value.getType());
+      })) {
+    return rewriter.notifyMatchFailure(
+        op, "cannot convert if operands aren't of LLVM type.");
+  }
+
+  return success();
+}
+
+/// Return the WMMA fragment corresponding to `operandName`.
+static ROCDL::ROCDLWMMAFrag convertOperand(StringRef operandName) {
+  if (operandName.equals("AOp"))
+    return ROCDL::ROCDLWMMAFrag::a;
+  if (operandName.equals("BOp"))
+    return ROCDL::ROCDLWMMAFrag::b;
+  if (operandName.equals("COp"))
+    return ROCDL::ROCDLWMMAFrag::c;
+  llvm_unreachable("Unknown operand name");
+}
+
+/// Generate load ops for `AOp` or `BOp`. `dataPtr` is the base address starting
+/// from which values will be loaded. `laneId` is the lane ID of the thread
+/// loading the values. `vecType` is the vector type of the values that will be
+/// loaded. The loaded values are returned in `loadedValues`. The address for
+/// loading the values is generated in the following manner:
+///
+/// wrappedLaneId = laneId % 16
+/// for i in vectorSize {
+///   loadedValues[i] = dataPtr + ((wrappedLaneId * leadingDim) + i);
+/// }
+static void generateAbLoadOpsVecFirst(Location loc, Value dataPtr, Value laneId,
+                                      Value leadingDim, VectorType vecType,
+                                      PatternRewriter &rewriter,
+                                      Value &loadedValues) {
+  // We wrap the laneId to 16 because of matrix replication in RDNA 3.
+  Value wrapSize = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                     /*value=*/16);
+  Value wrappedLaneId = rewriter.create<LLVM::SRemOp>(loc, laneId, wrapSize);
+  loadedValues = rewriter.create<LLVM::UndefOp>(loc, vecType);
+  Value laneIdLdm =
+      rewriter.create<LLVM::MulOp>(loc, wrappedLaneId, leadingDim);
+  for (unsigned i = 0; i < vecType.getNumElements(); ++i) {
+    Value iter = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                   /*value=*/i);
+    Value curInx = rewriter.create<LLVM::AddOp>(loc, laneIdLdm, iter);
+    Value curAddress = rewriter.create<LLVM::GEPOp>(
+        loc, dataPtr.getType(), vecType.getElementType(), dataPtr, curInx);
+    // Load the value from the current index.
+    Value loaded = rewriter.create<LLVM::LoadOp>(loc, vecType.getElementType(),
+                                                 curAddress);
+    loadedValues = rewriter.create<LLVM::InsertElementOp>(
+        loc, vecType, loadedValues, loaded, iter);
+  }
+}
+
+/// Generate load ops for `AOp` or `BOp`. `dataPtr` is the base address starting
+/// from which values will be loaded. `laneId` is the lane ID of the thread
+/// loading the values. `vecType` is the vector type of the values that will be
+/// loaded. The loaded values are returned in `loadedValues`. The address for
+/// loading the values is generated in the following manner:
+///
+/// wrappedLaneId = laneId % 16
+/// for i in vectorSize {
+///   loadedValues[i] = dataPtr + ((i * leadingDim) + wrappedLaneId);
+/// }
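+///
+/// As an illustrative example (not emitted by the lowering), with laneId = 20
+/// and leadingDim = 32: wrappedLaneId = 4, so element i is read from
+/// dataPtr + i * 32 + 4, i.e. the lane reads 16 values of column 4 strided by
+/// the leading dimension.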
+static void generateAbLoadOpsLaneFirst(Location loc, Value dataPtr,
+                                       Value laneId, Value leadingDim,
+                                       VectorType vecType,
+                                       PatternRewriter &rewriter,
+                                       Value &loadedValues) {
+  // We wrap the laneId to 16 because of matrix replication in RDNA 3.
+  Value wrapSize = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                     /*value=*/16);
+  Value wrappedLaneId = rewriter.create<LLVM::SRemOp>(loc, laneId, wrapSize);
+  loadedValues = rewriter.create<LLVM::UndefOp>(loc, vecType);
+  for (unsigned i = 0; i < vecType.getNumElements(); ++i) {
+    Value iter = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                   /*value=*/i);
+    Value iterLdm = rewriter.create<LLVM::MulOp>(loc, iter, leadingDim);
+    Value curInx = rewriter.create<LLVM::AddOp>(loc, iterLdm, wrappedLaneId);
+    Value curAddress = rewriter.create<LLVM::GEPOp>(
+        loc, dataPtr.getType(), vecType.getElementType(), dataPtr, curInx);
+    // Load the value from the current index.
+    Value loaded = rewriter.create<LLVM::LoadOp>(loc, vecType.getElementType(),
+                                                 curAddress);
+    loadedValues = rewriter.create<LLVM::InsertElementOp>(
+        loc, vecType, loadedValues, loaded, iter);
+  }
+}
+
+/// Generate load ops for `COp`. `dataPtr` is the base address starting
+/// from which values will be loaded. `laneId` is the lane ID of the
+/// thread loading the values. `vecType` is the vector type of the values that
+/// will be loaded. The loaded values are returned in `loadedValues`. The
+/// address for loading the values is generated in the following manner:
+///
+/// wrappedLaneId = laneId % 16
+/// for i in vectorSize {
+///   row = i * 2 + (laneId / 16)
+///   if opSelect
+///     loadedValues[i * 2 + 1] = dataPtr + ((row * leadingDim) +
+///     wrappedLaneId);
+///   else
+///     loadedValues[i * 2] = dataPtr + ((row * leadingDim) + wrappedLaneId);
+/// }
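+///
+/// As an illustrative example for f16 values (not emitted by the lowering),
+/// with laneId = 20, leadingDim = 32 and i = 1: row = 1 * 2 + 1 = 3, so the
+/// value is read from dataPtr + 3 * 32 + 4 and inserted at vector index 3 if
+/// opSelect is set, or at index 2 otherwise. For f32 values the vector index
+/// is simply i.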
+static void generateCLoadOpsLaneFirst(bool opSelect, Location loc,
+                                      Value dataPtr, Value laneId,
+                                      Value leadingDim, VectorType vecType,
+                                      PatternRewriter &rewriter,
+                                      Value &loadedValues) {
+  // We wrap the laneId to 16 because of matrix replication in RDNA 3.
+  Value wrapSize = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                     /*value=*/16);
+  Value wrappedLaneId = rewriter.create<LLVM::SRemOp>(loc, laneId, wrapSize);
+  loadedValues = rewriter.create<LLVM::UndefOp>(loc, vecType);
+  Value constTwo = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                     /*value=*/2);
+  Value sixteen = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                    /*value=*/16);
+  Value laneIdHalf = rewriter.create<LLVM::SDivOp>(loc, laneId, sixteen);
+  for (unsigned i = 0; i < vecType.getNumElements(); ++i) {
+    Value iter = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                   /*value=*/i);
+    Value iterTwo = rewriter.create<LLVM::MulOp>(loc, iter, constTwo);
+    Value row = rewriter.create<LLVM::AddOp>(loc, iterTwo, laneIdHalf);
+    Value rowLdm = rewriter.create<LLVM::MulOp>(loc, row, leadingDim);
+    Value curInx = rewriter.create<LLVM::AddOp>(loc, rowLdm, wrappedLaneId);
+    Value curAddress = rewriter.create<LLVM::GEPOp>(
+        loc, dataPtr.getType(), vecType.getElementType(), dataPtr, curInx);
+    // Load the value from the current index.
+    Value loaded = rewriter.create<LLVM::LoadOp>(loc, vecType.getElementType(),
+                                                 curAddress);
+    // For f16, values occupy every other vector element: odd indices when
+    // opSelect is set, even indices otherwise.
+    Value inx = iter;
+    if (vecType.getElementType().isF16()) {
+      if (opSelect) {
+        Value constOne =
+            rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                              /*value=*/1);
+        inx = rewriter.create<LLVM::AddOp>(loc, iterTwo, constOne);
+      } else {
+        inx = iterTwo;
+      }
+    }
+    loadedValues = rewriter.create<LLVM::InsertElementOp>(
+        loc, vecType, loadedValues, loaded, inx);
+  }
+}
+
+/// Generate load ops for `AOp`, `BOp`, or `COp`. `opSelect` is the opSelect bit
+/// governing how half-precision `COp` values are loaded. `transpose` tells
+/// if the matrix has to be loaded in a transposed manner. `frag` is the type of
+/// the WMMA operand being loaded. `dataPtr` is the base address starting from
+/// which values will be loaded. `vecType` is the vector type of the values that
+/// will be loaded. The loaded values are returned in `loadedValues`.
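+///
+/// The dispatch below is: f16 `AOp` loaded non-transposed and f16 `BOp` loaded
+/// transposed use the contiguous (vector-first) loads; f16 `AOp` loaded
+/// transposed and f16 `BOp` loaded non-transposed use the strided (lane-first)
+/// loads; non-transposed f16/f32 `COp` uses the dedicated `COp` loads. All
+/// other combinations fail to match.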
+static LogicalResult generateLoadOps(bool opSelect, bool transpose,
+                                     Location loc, ROCDL::ROCDLWMMAFrag frag,
+                                     unsigned indexBitwidth, Value dataPtr,
+                                     Value leadingDim, VectorType vecType,
+                                     PatternRewriter &rewriter,
+                                     Value &loadedValues) {
+  Value laneId = amd::getLaneId(rewriter, loc, indexBitwidth);
+  Type eltType = vecType.getElementType();
+  if (frag == ROCDL::ROCDLWMMAFrag::a && !transpose && eltType.isF16()) {
+    generateAbLoadOpsVecFirst(loc, dataPtr, laneId, leadingDim, vecType,
+                              rewriter, loadedValues);
+    return success();
+  }
+  if (frag == ROCDL::ROCDLWMMAFrag::a && transpose && eltType.isF16()) {
+    generateAbLoadOpsLaneFirst(loc, dataPtr, laneId, leadingDim, vecType,
+                               rewriter, loadedValues);
+    return success();
+  }
+  if (frag == ROCDL::ROCDLWMMAFrag::b && transpose && eltType.isF16()) {
+    generateAbLoadOpsVecFirst(loc, dataPtr, laneId, leadingDim, vecType,
+                              rewriter, loadedValues);
+    return success();
+  }
+  if (frag == ROCDL::ROCDLWMMAFrag::b && !transpose && eltType.isF16()) {
+    generateAbLoadOpsLaneFirst(loc, dataPtr, laneId, leadingDim, vecType,
+                               rewriter, loadedValues);
+    return success();
+  }
+  if (frag == ROCDL::ROCDLWMMAFrag::c && !transpose &&
+      (eltType.isF32() || eltType.isF16())) {
+    generateCLoadOpsLaneFirst(opSelect, loc, dataPtr, laneId, leadingDim,
+                              vecType, rewriter, loadedValues);
+    return success();
+  }
+
+  return failure();
+}
+
+/// This class implements the conversion of GPU subgroup MMA loadOp. The
+/// conversion emits explicit LLVM load and insertelement ops that materialize
+/// the matrix fragment held by each lane as a vector value.
+struct WmmaLoadOpToROCDLLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
+  WmmaLoadOpToROCDLLowering(LLVMTypeConverter &typeConverter, StringRef chip,
+                            unsigned indexBitwidth, unsigned warpSize)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>(typeConverter),
+        indexBitwidth(indexBitwidth), warpSize(warpSize), chip(chip) {}
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(areAllLLVMTypes(subgroupMmaLoadMatrixOp.getOperation(),
+                               adaptor.getOperands(), rewriter)))
+      return failure();
+
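+    // This lowering is only implemented for the gfx11 series; accept chip
+    // names that begin with "gfx11" and contain it exactly once.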
+    std::size_t firstPos = chip.find("gfx11");
+    std::size_t lastPos = chip.rfind("gfx11");
+    if (firstPos != 0 || (firstPos != lastPos))
+      return subgroupMmaLoadMatrixOp->emitError(
+          "wmma lowering is supported for gfx11 series only");
+
+    if (warpSize != amd::kWaveFrontSize32)
+      return subgroupMmaLoadMatrixOp->emitError(
+          "only size 32 wavefronts are supported");
+
+    auto transpose = subgroupMmaLoadMatrixOp.getTranspose();
+    gpu::MMAMatrixType retType =
+        subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
+    SmallVector<int64_t> retTypeShape(retType.getShape());
+
+    if (!llvm::all_of(retTypeShape, [](int64_t dim) { return dim == 16; }))
+      return subgroupMmaLoadMatrixOp->emitError(
+          "wmma ops of shape 16x16x16 are only supported.");
+
+    auto srcMemrefType =
+        subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast<MemRefType>();
+
+    if (srcMemrefType.getElementType() != retType.getElementType())
+      return subgroupMmaLoadMatrixOp->emitError(
+          "src memref type and mma matrix element type must be same");
+
+    // Get the LLVM type corresponding to the result MMAMatrixType.
+    Type llvmRetType = amd::convertWMMAToROCDLLLVMType(retType);
+
+    // We need to declare a vector type and then emit instructions to load the
+    // elements into the vector type.
+    Location loc = subgroupMmaLoadMatrixOp.getLoc();
+    Value dataPtr =
+        getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
+                             adaptor.getIndices(), rewriter);
+
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
+
+    Value loadedValues;
+    ROCDL::ROCDLWMMAFrag operand = convertOperand(retType.getOperand());
+    if (auto vecType = dyn_cast<VectorType>(llvmRetType)) {
+      bool opSelect = subgroupMmaLoadMatrixOp->hasAttrOfType<UnitAttr>(
+          amd::kAMDGpuOpselectAttrName);
+      if (failed(generateLoadOps(opSelect,
+                                 transpose.has_value() && transpose.value(),
+                                 loc, operand, indexBitwidth, dataPtr,
+                                 leadingDim, vecType, rewriter, loadedValues)))
+        return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp,
+                                           "unsupported load op variant.");
+      rewriter.replaceOp(subgroupMmaLoadMatrixOp, loadedValues);
+      return success();
+    }
+    return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp,
+                                       "unsupported load op variant.");
+  }
+
+  /// Index bitwidth to use in any index calculation.
+  unsigned indexBitwidth;
+
+  /// `warpSize` is the warp size to use when generating WMMA intrinsics.
+  unsigned warpSize;
+
+  /// The target chip for which to generate the lowering.
+  std::string chip;
+};
+
+/// Generate store ops for `COp`. `dataPtr` is the base address starting
+/// from which the values will be stored. `laneId` is the lane ID of the
+/// thread storing the values. `vecType` is the vector type of the values that
+/// are being stored. The values to be stored are supplied in `toStore`. The
+/// address for storing the values is generated in the following manner:
+///
+/// wrappedLaneId = laneId % 16
+/// for i in vectorSize {
+///   row = i * 2 + (laneId / 16)
+///   if opSelect
+///     store toStore[i * 2 + 1], dataPtr + ((row * leadingDim) + wrappedLaneId)
+///   else
+///     store toStore[i * 2], dataPtr + ((row * leadingDim) + wrappedLaneId)
+/// }
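+///
+/// As an illustrative example for f16 values (not emitted by the lowering),
+/// with laneId = 20, leadingDim = 32 and i = 1: row = 1 * 2 + 1 = 3, so
+/// toStore[3] (when opSelect is set) or toStore[2] (otherwise) is written to
+/// dataPtr + 3 * 32 + 4. For f32 values element i is stored directly.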
+static void generateCStoreOpsLaneFirst(bool opSelect, Location loc,
+                                       Value dataPtr, Value laneId,
+                                       Value leadingDim, VectorType vecType,
+                                       Value toStore,
+                                       PatternRewriter &rewriter) {
+  // We wrap the laneId to 16 because of matrix replication in RDNA 3.
+  Value wrapSize = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                     /*value=*/16);
+  Value wrappedLaneId = rewriter.create<LLVM::SRemOp>(loc, laneId, wrapSize);
+  Value constSixteen =
+      rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                        /*value=*/16);
+  Value constTwo = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                     /*value=*/2);
+  Value laneIdHalf = rewriter.create<LLVM::SDivOp>(loc, laneId, constSixteen);
+  for (int i = 0; i < vecType.getNumElements(); ++i) {
+    Value inx = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                                  /*value=*/i);
+    Value inxTimesTwo = rewriter.create<LLVM::MulOp>(loc, inx, constTwo);
+    Value row = rewriter.create<LLVM::AddOp>(loc, laneIdHalf, inxTimesTwo);
+    Value rowLdm = rewriter.create<LLVM::MulOp>(loc, row, leadingDim);
+    Value offset = rewriter.create<LLVM::AddOp>(loc, rowLdm, wrappedLaneId);
+    Value storeAddress = rewriter.create<LLVM::GEPOp>(
+        loc, dataPtr.getType(), vecType.getElementType(), dataPtr, offset);
+    Value toStoreAtInx;
+    if (vecType.getElementType().isF16()) {
+      if (!opSelect) {
+        toStoreAtInx = rewriter.create<LLVM::ExtractElementOp>(
+            loc, vecType.getElementType(), toStore, inxTimesTwo);
+
+      } else {
+        Value constOne =
+            rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
+                                              /*value=*/1);
+        Value inxTimesTwoAddOne =
+            rewriter.create<LLVM::AddOp>(loc, inxTimesTwo, constOne);
+        toStoreAtInx = rewriter.create<LLVM::ExtractElementOp>(
+            loc, vecType.getElementType(), toStore, inxTimesTwoAddOne);
+      }
+    } else if (vecType.getElementType().isF32()) {
+      toStoreAtInx = rewriter.create<LLVM::ExtractElementOp>(
+          loc, vecType.getElementType(), toStore, inx);
+    }
+    rewriter.create<LLVM::StoreOp>(loc, toStoreAtInx, storeAddress);
+  }
+}
+
+/// Generate store ops for `COp`. `opSelect` is the opSelect bit governing how
+/// half-precision `COp` values are stored. `frag` is the type of the WMMA
+/// operand being stored. `dataPtr` is the base address starting from which the
+/// values will be stored. `vecType` is the vector type of the values being
+/// stored. `toStore` contains the values to be stored.
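+///
+/// Note that only `COp` fragments can be stored; for f16 fragments the
+/// opSelect bit selects between the even (opSelect unset) and odd (opSelect
+/// set) vector elements.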
+static LogicalResult generateStoreOps(bool opSelect, Location loc,
+                                      ROCDL::ROCDLWMMAFrag frag, Value dataPtr,
+                                      unsigned indexBitwidth, Value leadingDim,
+                                      VectorType vecType, Value toStore,
+                                      PatternRewriter &rewriter) {
+  // Store ops can only be generated for C operands.
+  if (frag != ROCDL::ROCDLWMMAFrag::c)
+    return emitError(toStore.getLoc(), "only COp can be stored");
+
+  // Get the laneID.
+  Value laneId = amd::getLaneId(rewriter, loc, indexBitwidth);
+  Type eltType = vecType.getElementType();
+  if (eltType.isF16() || eltType.isF32()) {
+    generateCStoreOpsLaneFirst(opSelect, loc, dataPtr, laneId, leadingDim,
+                               vecType, toStore, rewriter);
+    return success();
+  }
+
+  return failure();
+}
+
+/// This class implements the conversion of GPU subgroup MMA storeOp. The
+/// conversion emits explicit extractelement and LLVM store ops that write the
+/// fragment's values back to the destination memref.
+struct WmmaStoreOpToROCDLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
+  WmmaStoreOpToROCDLowering(LLVMTypeConverter &typeConverter, StringRef chip,
+                            unsigned indexBitwidth, unsigned warpSize)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>(typeConverter),
+        indexBitwidth(indexBitwidth), warpSize(warpSize), chip(chip) {}
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(areAllLLVMTypes(subgroupMmaStoreMatrixOp.getOperation(),
+                               adaptor.getOperands(), rewriter)))
+      return failure();
+
+    std::size_t firstPos = chip.find("gfx11");
+    std::size_t lastPos = chip.rfind("gfx11");
+    if (firstPos != 0 || (firstPos != lastPos))
+      return subgroupMmaStoreMatrixOp->emitError(
+          "wmma lowering is supported for gfx11 series only");
+
+    if (warpSize != amd::kWaveFrontSize32)
+      return subgroupMmaStoreMatrixOp->emitError(
+          "only size 32 wavefronts are supported");
+
+    Location loc = subgroupMmaStoreMatrixOp->getLoc();
+
+    auto transpose = subgroupMmaStoreMatrixOp.getTranspose();
+    if (transpose.has_value() && transpose.value())
+      return subgroupMmaStoreMatrixOp->emitError(
+          "lowering with transpose is not supported.");
+
+    gpu::MMAMatrixType retType =
+        subgroupMmaStoreMatrixOp.getSrc().getType().cast<gpu::MMAMatrixType>();
+    SmallVector<int64_t> retTypeShape(retType.getShape());
+
+    if (!llvm::all_of(retTypeShape, [](int64_t dim) { return dim == 16; }))
+      return subgroupMmaStoreMatrixOp->emitError(
+          "wmma ops of shape 16x16x16 are only supported.");
+
+    auto dstMemrefType =
+        subgroupMmaStoreMatrixOp.getDstMemref().getType().cast<MemRefType>();
+
+    if (dstMemrefType.getElementType() != retType.getElementType())
+      return subgroupMmaStoreMatrixOp->emitError(
+          "dst memref type and mma matrix element type must be same");
+
+    Value dataPtr =
+        getStridedElementPtr(loc, dstMemrefType, adaptor.getDstMemref(),
+                             adaptor.getIndices(), rewriter);
+    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
+
+    // Get the LLVM type corresponding to the source MMAMatrixType.
+    Type llvmRetType = amd::convertWMMAToROCDLLLVMType(retType);
+
+    Value toStore = adaptor.getSrc();
+
+    bool opSelect = subgroupMmaStoreMatrixOp->hasAttrOfType<UnitAttr>(
+        amd::kAMDGpuOpselectAttrName);
+    if (auto vecType = dyn_cast<VectorType>(llvmRetType)) {
+      if (failed(generateStoreOps(
+              opSelect, loc, convertOperand(retType.getOperand()), dataPtr,
+              indexBitwidth, leadingDim, vecType, toStore, rewriter)))
+        return rewriter.notifyMatchFailure(subgroupMmaStoreMatrixOp,
+                                           "unsupported store op variant.");
+    }
+    rewriter.eraseOp(subgroupMmaStoreMatrixOp);
+    return success();
+  }
+
+  /// Index bitwidth to use in any index calculation.
+  unsigned indexBitwidth;
+
+  /// `warpSize` is the warp size to use when generating WMMA intrinsics.
+  unsigned warpSize;
+
+  /// The target chip for which to generate the lowering.
+  std::string chip;
+};
+} // namespace
+
+// Convert the MMAMatrixType to an LLVM type based on the element type and
+// operand kind of the MMAMatrixType.
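+//
+// For example, 16x16xf16 "AOp"/"BOp"/"COp" fragments map to vector<16xf16>,
+// and 16x16xf32 "COp" fragments map to vector<8xf32>.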
+Type mlir::amd::convertWMMAToROCDLLLVMType(
+    mlir::gpu::MMAMatrixType matrixType) {
+  Type eltType = matrixType.getElementType();
+  ROCDL::ROCDLWMMAFrag frag = convertOperand(matrixType.getOperand());
+  if (eltType.isF16() &&
+      (frag == ROCDL::ROCDLWMMAFrag::a || frag == ROCDL::ROCDLWMMAFrag::b ||
+       frag == ROCDL::ROCDLWMMAFrag::c))
+    return VectorType::get({16}, eltType);
+  if (eltType.isF32() && frag == ROCDL::ROCDLWMMAFrag::c)
+    return VectorType::get({8}, eltType);
+
+  llvm_unreachable("Unsupported data type");
+}
+
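+// Collect the patterns that lower GPU subgroup MMA load/store ops to explicit
+// LLVM/ROCDL IR for AMD WMMA targets. The patterns emit errors for chips
+// outside the gfx11 series and for wavefront sizes other than 32.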
+void mlir::populateGpuWMMAToROCDLConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef chip,
+    unsigned indexBitwidth, unsigned warpSize) {
+  patterns.add<WmmaLoadOpToROCDLLowering, WmmaStoreOpToROCDLowering>(
+      converter, chip, indexBitwidth, warpSize);
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 6fc9ae0f3fc58fa..dbbe1d3a13ec512 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -31,12 +31,14 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
   option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
   option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.")
   option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
+  option(MLIR_RUN_ROCM_WMMA_TESTS "Run WMMA tests for AMD GPU.")
   option(MLIR_RUN_CUDA_SM80_TESTS "Run CUDA A100 tests.")
   option(MLIR_RUN_CUDA_SM80_LT_TESTS "Run CUDA A100 structured sparsity tests.")
   option(MLIR_RUN_CUDA_SM90_TESTS "Run CUDA H100 tests.")
   option(MLIR_RUN_ARM_SVE_TESTS "Run Arm SVE tests.")
   option(MLIR_RUN_ARM_SME_TESTS "Run Arm SME tests.")
 
+  set(GFX_WMMA_TARGET "gfx1100")
 
   # The native target may not be enabled when cross compiling, raise an error.
   if(NOT MLIR_ENABLE_EXECUTION_ENGINE)
@@ -69,6 +71,7 @@ llvm_canonicalize_cmake_booleans(
   MLIR_INCLUDE_INTEGRATION_TESTS
   MLIR_RUN_AMX_TESTS
   MLIR_RUN_CUDA_TENSOR_CORE_TESTS
+  MLIR_RUN_ROCM_WMMA_TESTS
   MLIR_RUN_X86VECTOR_TESTS
   MLIR_RUN_ARM_SVE_TESTS
   MLIR_RUN_ARM_SME_TESTS
diff --git a/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported-chipset.mlir b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported-chipset.mlir
new file mode 100644
index 000000000000000..be81717db6aa195
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported-chipset.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx900 index-bitwidth=32' -split-input-file -verify-diagnostics
+
+gpu.module @main {
+  // CHECK-LABEL: load_a_op_16_16_16_no_transpose
+  func.func @load_a_op_16_16_16_no_transpose()->(!gpu.mma_matrix<16x16xf16, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    // expected-error at -1 {{wmma lowering is supported for gfx11 series only}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: store_cop_f32
+  func.func @store_cop_f32(%arg0: !gpu.mma_matrix<16x16xf32, "COp">) -> () {
+    %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf32, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<32x32xf32, 3>
+    // expected-error at -1 {{wmma lowering is supported for gfx11 series only}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_store_matrix' that was explicitly marked illegal}}
+    return
+  }
+}
+
diff --git a/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported.mlir b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported.mlir
new file mode 100644
index 000000000000000..4b9ea96913c00cc
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported.mlir
@@ -0,0 +1,181 @@
+// This file tests that we error out properly when unsupported ops are
+// encountered during the GPU WMMA ops to ROCDL conversion.
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx1100 index-bitwidth=32' -split-input-file -verify-diagnostics
+
+gpu.module @main {
+  // CHECK-LABEL: load_a_op_16_16_16_no_transpose
+  func.func @load_a_op_16_16_16_no_transpose()->(!gpu.mma_matrix<32x8xf16, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "AOp">
+    // expected-error at -1 {{wmma ops of shape 16x16x16 are only supported.}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<32x8xf16, "AOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_a_op_16_16_16_transpose
+  func.func @load_a_op_16_16_16_transpose()->(!gpu.mma_matrix<32x8xf16, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "AOp">
+    // expected-error at -1 {{wmma ops of shape 16x16x16 are only supported.}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<32x8xf16, "AOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_a_op_16_16_16_no_transpose_invalid_types
+  func.func @load_a_op_16_16_16_no_transpose_invalid_types()->(!gpu.mma_matrix<16x16xf16, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf32, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf32, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    // expected-error at -1 {{src memref type and mma matrix element type must be same}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_b_op_16_16_16_no_transpose
+  func.func @load_b_op_16_16_16_no_transpose()->(!gpu.mma_matrix<32x8xf16, "BOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "BOp">
+    // expected-error at -1 {{wmma ops of shape 16x16x16 are only supported.}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<32x8xf16, "BOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_b_op_16_16_16_transpose
+  func.func @load_b_op_16_16_16_transpose()->(!gpu.mma_matrix<32x8xf16, "BOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "BOp">
+    // expected-error at -1 {{wmma ops of shape 16x16x16 are only supported.}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<32x8xf16, "BOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_b_op_16_16_16_no_transpose_invalid_types
+  func.func @load_b_op_16_16_16_no_transpose_invalid_types()->(!gpu.mma_matrix<16x16xf16, "BOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf32, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf32, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    // expected-error at -1 {{src memref type and mma matrix element type must be same}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<16x16xf16, "BOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_c_op_16_16_16_no_transpose
+  func.func @load_c_op_16_16_16_no_transpose()->(!gpu.mma_matrix<32x8xf16, "COp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "COp">
+    // expected-error at -1 {{wmma ops of shape 16x16x16 are only supported.}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<32x8xf16, "COp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_c_op_16_16_16_transpose
+  func.func @load_c_op_16_16_16_transpose()->(!gpu.mma_matrix<32x8xf16, "COp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "COp">
+    // expected-error at -1 {{wmma ops of shape 16x16x16 are only supported.}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<32x8xf16, "COp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_c_op_16_16_16_no_transpose_invalid_types
+  func.func @load_c_op_16_16_16_no_transpose_invalid_types()->(!gpu.mma_matrix<16x16xf16, "COp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf32, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf32, 3> -> !gpu.mma_matrix<16x16xf16, "COp">
+    // expected-error at -1 {{src memref type and mma matrix element type must be same}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}}
+    return %0 : !gpu.mma_matrix<16x16xf16, "COp">
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: store_cop_f32
+  func.func @store_cop_f32(%arg0: !gpu.mma_matrix<32x8xf32, "COp">) -> () {
+    %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf32, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<32x8xf32, "COp">, memref<32x32xf32, 3>
+    // expected-error at -1 {{wmma ops of shape 16x16x16 are only supported.}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_store_matrix' that was explicitly marked illegal}}
+    return
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: store_cop_f32
+  func.func @store_cop_f32(%arg0: !gpu.mma_matrix<16x16xf32, "COp">) -> () {
+    %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf32, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index, transpose} : !gpu.mma_matrix<16x16xf32, "COp">, memref<32x32xf32, 3>
+    // expected-error at -1 {{lowering with transpose is not supported.}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_store_matrix' that was explicitly marked illegal}}
+    return
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: store_cop_f32
+  func.func @store_cop_f32(%arg0: !gpu.mma_matrix<16x16xf32, "COp">) -> () {
+    %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<32x32xf16, 3>
+    // expected-error at -1 {{dst memref type and mma matrix element type must be same}}
+    // expected-error at -2 {{failed to legalize operation 'gpu.subgroup_mma_store_matrix' that was explicitly marked illegal}}
+    return
+  }
+}
diff --git a/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl.mlir
new file mode 100644
index 000000000000000..1ae5e54660e9eb7
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl.mlir
@@ -0,0 +1,442 @@
+// This file tests the conversion of GPU WMMA ops to the ROCDL dialect.
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx1100 index-bitwidth=32' -split-input-file | FileCheck %s
+
+gpu.module @main {
+  // CHECK-LABEL: load_a_op_16_16_16_no_transpose
+  func.func @load_a_op_16_16_16_no_transpose()->(!gpu.mma_matrix<16x16xf16, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK-NEXT:  %[[C0_I32:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[C_1_I32:.*]] = llvm.mlir.constant(-1 : i32) : i32
+    // CHECK-NEXT:  %[[MBCNT_LO:.*]] = rocdl.mbcnt.lo %[[C_1_I32]], %[[C0_I32]] : (i32, i32) -> i32
+    // CHECK-NEXT:  %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %[[C_1_I32]], %[[MBCNT_LO]] : (i32, i32) -> i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]]  : i32
+    // The IR checked up to this point is common to most of the WMMA op
+    // lowerings, since the same utility emits it. Checking these lines is
+    // skipped in the subsequent tests; only the values used later are matched.
+    // CHECK-NEXT:  %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16>
+    // CHECK-NEXT:  %[[WRAPPEDTID32:.*]] = llvm.mul %[[WRAPPEDTID]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[WRAPPEDTID32]], %[[C0]]  : i32
+    // CHECK-NEXT:  %[[LOADADDR0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED0:.*]] = llvm.load %[[LOADADDR0]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[WRAPPEDTID32]], %[[C1]]  : i32
+    // CHECK-NEXT:  %[[LOADADDR1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED1:.*]] = llvm.load %[[LOADADDR1]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[LOADEDVALS1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[C1]] : i32] : vector<16xf16>
+    // We only check the loading and insertion of the first two values; the
+    // remaining values need not be checked as they are emitted by the same
+    // loop with different parameters.
+    // CHECK:       %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK:       %[[OFFSET15:.*]] = llvm.add %[[WRAPPEDTID32]], %{{.*}} : i32
+    // CHECK-NEXT:  %[[LOADADDR15:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET15]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADEDVALS15:.*]] = llvm.load %[[LOADADDR15]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[RES:.*]] = llvm.insertelement %[[LOADEDVALS15]], %{{.*}}[%[[C15]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.return %[[RES]] : vector<16xf16>
+    return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_a_op_16_16_16_transpose
+  func.func @load_a_op_16_16_16_transpose()->(!gpu.mma_matrix<16x16xf16, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[WRAPPEDTID:.*]] = llvm.srem %{{.*}}, {{.*}}  : i32
+    // CHECK-NEXT:  %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16>
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[ROW0:.*]] = llvm.mul %[[C0]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[ROW0]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[ROW1:.*]] = llvm.mul %[[C1]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[ROW1]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[C1]] : i32] : vector<16xf16>
+    // We only check the loading and insertion of the first two values; the
+    // remaining values need not be checked as they are emitted by the same
+    // loop with different parameters.
+    // CHECK:       %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK-NEXT:  %[[ROW15:.*]] = llvm.mul %[[C15]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET15:.*]] = llvm.add %[[ROW15]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS15:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET15]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED15:.*]] = llvm.load %[[ADDRESS15]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[RES:.*]] = llvm.insertelement %[[LOADED15]], %{{.*}}[%[[C15]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.return %[[RES]] : vector<16xf16>
+    return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_b_op_16_16_16_no_transpose
+  func.func @load_b_op_16_16_16_no_transpose()->(!gpu.mma_matrix<16x16xf16, "BOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[WRAPPEDTID:.*]] = llvm.srem %{{.*}}, {{.*}}  : i32
+    // CHECK:       %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16>
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[ROW0:.*]] = llvm.mul %[[C0]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[ROW0]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[ROW1:.*]] = llvm.mul %[[C1]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[ROW1]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[C1]] : i32] : vector<16xf16>
+    // We only check the loading and insertion of the first two values; the
+    // remaining values need not be checked as they are emitted by the same
+    // loop with different parameters.
+    // CHECK:       %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK-NEXT:  %[[ROW15:.*]] = llvm.mul %[[C15]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET15:.*]] = llvm.add %[[ROW15]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS15:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET15]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED15:.*]] = llvm.load %[[ADDRESS15]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[RES:.*]] = llvm.insertelement %[[LOADED15]], %{{.*}}[%[[C15]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.return %[[RES]] : vector<16xf16>
+    return %0 : !gpu.mma_matrix<16x16xf16, "BOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_b_op_16_16_16_transpose
+  func.func @load_b_op_16_16_16_transpose()->(!gpu.mma_matrix<16x16xf16, "BOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[WRAPPEDTID:.*]] = llvm.srem %{{.*}}, %{{.*}}  : i32
+    // CHECK-NEXT:  %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16>
+    // CHECK-NEXT:  %[[WRAPPEDTID32:.*]] = llvm.mul %[[WRAPPEDTID]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[WRAPPEDTID32]], %[[C0]]  : i32
+    // CHECK-NEXT:  %[[LOADADDR0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED0:.*]] = llvm.load %[[LOADADDR0]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[WRAPPEDTID32]], %[[C1]]  : i32
+    // CHECK-NEXT:  %[[LOADADDR1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED1:.*]] = llvm.load %[[LOADADDR1]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[LOADEDVALS1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[C1]] : i32] : vector<16xf16>
+    // We only check the loading and insertion of the first two values; the
+    // remaining values need not be checked as they are emitted by the same
+    // loop with different parameters.
+    // CHECK:       %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK:       %[[OFFSET15:.*]] = llvm.add %[[WRAPPEDTID32]], %{{.*}} : i32
+    // CHECK-NEXT:  %[[LOADADDR15:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET15]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADEDVALS15:.*]] = llvm.load %[[LOADADDR15]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[RES:.*]] = llvm.insertelement %[[LOADEDVALS15]], %{{.*}}[%[[C15]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.return %[[RES]] : vector<16xf16>
+    return %0 : !gpu.mma_matrix<16x16xf16, "BOp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_c_op_16_16_16_no_opselect
+  func.func @load_c_op_16_16_16_no_opselect()->(!gpu.mma_matrix<16x16xf32, "COp">) {
+    %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf32, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg_1[%i, %j] {leadDimension = 32 : index} : memref<32x32xf32, 3> -> !gpu.mma_matrix<16x16xf32, "COp">
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<8xf32>
+    // CHECK-NEXT:  %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]]  : i32
+    // CHECK-NEXT:  %[[ROW0:.*]] = llvm.add %[[ITER0]], %[[WTIDDIV16]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+    // CHECK-NEXT:  %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f32
+    // CHECK-NEXT:  %[[LOADEDVAL0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<8xf32>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]]  : i32
+    // CHECK-NEXT:  %[[ROW1:.*]] = llvm.add %[[ITER1]], %[[WTIDDIV16]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+    // CHECK-NEXT:  %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f32
+    // CHECK-NEXT:  %[[LOADEDVAL1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVAL0]][%[[C1]] : i32] : vector<8xf32>
+    // We only check the loading and insertion of the first two values; the
+    // remaining values need not be checked as they are emitted by the same
+    // loop with different parameters.
+    // CHECK:       %[[C7:.*]] = llvm.mlir.constant(7 : i32) : i32
+    // CHECK:       %[[RES:.*]] = llvm.insertelement %{{.*}}, %{{.*}}[%[[C7]] : i32] : vector<8xf32>
+    // CHECK-NEXT:  llvm.return %[[RES]] : vector<8xf32>
+    return %0 : !gpu.mma_matrix<16x16xf32, "COp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_c_op_16_16_16_no_opselect
+  func.func @load_c_op_16_16_16_no_opselect()->(!gpu.mma_matrix<16x16xf16, "COp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "COp">
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16>
+    // CHECK-NEXT:  %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]]  : i32
+    // CHECK-NEXT:  %[[ROW0:.*]] = llvm.add %[[ITER0]], %[[WTIDDIV16]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[ITER0]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]]  : i32
+    // CHECK-NEXT:  %[[ROW1:.*]] = llvm.add %[[ITER1]], %[[WTIDDIV16]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[LOADEDVALS1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[ITER1]] : i32] : vector<16xf16>
+    // CHECK:       %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK:       %[[RES:.*]] = llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.return %[[RES]] : vector<16xf16>
+    return %0 : !gpu.mma_matrix<16x16xf16, "COp">
+  }
+}
+
+// -----
+
+gpu.module @main {
+  // CHECK-LABEL: load_c_op_16_16_16_opselect
+  func.func @load_c_op_16_16_16_opselect()->(!gpu.mma_matrix<16x16xf16, "COp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, opSelect} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "COp">
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16>
+    // CHECK-NEXT:  %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]]  : i32
+    // CHECK-NEXT:  %[[ROW0:.*]] = llvm.add %[[ITER0]], %[[WTIDDIV16]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[C1C:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[VECOFFSET0:.*]] = llvm.add %[[ITER0]], %[[C1C]]  : i32
+    // CHECK-NEXT:  %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[VECOFFSET0]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]]  : i32
+    // CHECK-NEXT:  %[[ROW1:.*]] = llvm.add %[[ITER1]], %[[WTIDDIV16]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f16
+    // CHECK-NEXT:  %[[C1C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[VECOFFSET1:.*]] = llvm.add %[[ITER1]], %[[C1C1]]  : i32
+    // CHECK-NEXT:  %[[LOADEDVALS1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[VECOFFSET1]] : i32] : vector<16xf16>
+    // CHECK:       %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK:       %[[RES:.*]] = llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.return %[[RES]] : vector<16xf16>
+    return %0 : !gpu.mma_matrix<16x16xf16, "COp">
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: store_cop_f32
+  // CHECK-SAME: (%[[SRC:.*]]: vector<8xf32>)
+  func.func @store_cop_f32(%arg0: !gpu.mma_matrix<16x16xf32, "COp">) -> () {
+    %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf32, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<32x32xf32, 3>
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-NEXT:  %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]] : i32
+    // CHECK-NEXT:  %[[ROW0:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER0]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+    // CHECK-NEXT:  %[[ELE0:.*]] = llvm.extractelement %[[SRC]][%[[C0]] : i32] : vector<8xf32>
+    // CHECK-NEXT:  llvm.store %[[ELE0]], %[[ADDRESS0]] : f32, !llvm.ptr<3>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]] : i32
+    // CHECK-NEXT:  %[[ROW1:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER1]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
+    // CHECK-NEXT:  %[[ELE1:.*]] = llvm.extractelement %[[SRC]][%[[C1]] : i32] : vector<8xf32>
+    // CHECK-NEXT:  llvm.store %[[ELE1]], %[[ADDRESS1]] : f32, !llvm.ptr<3>
+    // CHECK:       %[[C7:.*]] = llvm.mlir.constant(7 : i32) : i32
+    // CHECK:       %[[ELE7:.*]] = llvm.extractelement %[[SRC]][%[[C7]] : i32] : vector<8xf32>
+    // CHECK-NEXT:  llvm.store %[[ELE7]], %{{.*}} : f32, !llvm.ptr<3>
+    return
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: store_cop_f16_no_opsel
+  // CHECK-SAME: (%[[SRC:.*]]: vector<16xf16>)
+  func.func @store_cop_f16_no_opsel(%arg0: !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+    %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-NEXT:  %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]] : i32
+    // CHECK-NEXT:  %[[ROW0:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER0]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[ELE0:.*]] = llvm.extractelement %[[SRC]][%[[ITER0]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.store %[[ELE0]], %[[ADDRESS0]] : f16, !llvm.ptr<3>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]] : i32
+    // CHECK-NEXT:  %[[ROW1:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER1]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[ELE1:.*]] = llvm.extractelement %[[SRC]][%[[ITER1]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.store %[[ELE1]], %[[ADDRESS1]] : f16, !llvm.ptr<3>
+    // CHECK:       %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK:       %[[ITER15:.*]] = llvm.mul %[[C15]], %[[C2]] : i32
+    // CHECK:       %[[ADDRESS15:.*]] = llvm.getelementptr %[[BASE]][%{{.*}}] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[ELE15:.*]] = llvm.extractelement %[[SRC]][%[[ITER15]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.store %[[ELE15]], %[[ADDRESS15]] : f16, !llvm.ptr<3>
+    return
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: store_cop_f16_opsel
+  // CHECK-SAME: (%[[SRC:.*]]: vector<16xf16>)
+  func.func @store_cop_f16_opsel(%arg0: !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+    %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index, opSelect} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
+    // CHECK:       llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK:       %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:       %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+    // CHECK-NEXT:  %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-NEXT:  %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]]  : i32
+    // CHECK-NEXT:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK-NEXT:  %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]] : i32
+    // CHECK-NEXT:  %[[ROW0:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER0]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[C01:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[INX0:.*]] = llvm.add %[[ITER0]], %[[C01]]  : i32
+    // CHECK-NEXT:  %[[ELE0:.*]] = llvm.extractelement %[[SRC]][%[[INX0]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.store %[[ELE0]], %[[ADDRESS0]] : f16, !llvm.ptr<3>
+    // CHECK-NEXT:  %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]] : i32
+    // CHECK-NEXT:  %[[ROW1:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER1]]  : i32
+    // CHECK-NEXT:  %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]]  : i32
+    // CHECK-NEXT:  %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]]  : i32
+    // CHECK-NEXT:  %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[C11:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[INX1:.*]] = llvm.add %[[ITER1]], %[[C11]]  : i32
+    // CHECK-NEXT:  %[[ELE1:.*]] = llvm.extractelement %[[SRC]][%[[INX1]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.store %[[ELE1]], %[[ADDRESS1]] : f16, !llvm.ptr<3>
+    // CHECK:       %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32
+    // CHECK-NEXT:  %[[ITER15:.*]] = llvm.mul %[[C15]], %[[C2]] : i32
+    // CHECK:       %[[ADDRESS15:.*]] = llvm.getelementptr %[[BASE]][%{{.*}}] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
+    // CHECK-NEXT:  %[[C151:.*]] = llvm.mlir.constant(1 : i32) : i32
+    // CHECK-NEXT:  %[[INX15:.*]] = llvm.add %[[ITER15]], %[[C151]]  : i32
+    // CHECK-NEXT:  %[[ELE15:.*]] = llvm.extractelement %[[SRC]][%[[INX15]] : i32] : vector<16xf16>
+    // CHECK-NEXT:  llvm.store %[[ELE15]], %[[ADDRESS15]] : f16, !llvm.ptr<3>
+    return
+  }
+}
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/lit.local.cfg b/mlir/test/Integration/GPU/ROCM/WMMA/lit.local.cfg
new file mode 100644
index 000000000000000..15dca5bd9ca9338
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/lit.local.cfg
@@ -0,0 +1,5 @@
+import sys
+
+# WMMA tests must be enabled via a build flag.
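+# The value of config.mlir_run_rocm_wmma_tests is filled in from
+# lit.site.cfg.py.in; it is presumably enabled at CMake configure time via
+# -DMLIR_RUN_ROCM_WMMA_TESTS=ON, mirroring the other MLIR_RUN_* test flags.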
+if not config.mlir_run_rocm_wmma_tests:
+  config.unsupported = True
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16.mlir b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16.mlir
new file mode 100644
index 000000000000000..07874ed383f2e6a
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -convert-scf-to-cf \
+// RUN: | mlir-opt -gpu-kernel-outlining \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-amdgpu{chipset=%chip index-bitwidth=32},convert-gpu-to-rocdl{chipset=%chip index-bitwidth=32},reconcile-unrealized-casts),rocdl-attach-target{chip=%chip})' \
+// RUN: | mlir-opt -gpu-to-llvm -gpu-module-to-binary \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+func.func @main() {
+  %0 = memref.alloc() : memref<16x16xf16>
+  %22 = memref.alloc() : memref<16x16xf16>
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f16
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  // Initialize the input matrix with the sum of its row and column indices.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %cast_c = arith.index_cast %arg1 : index to i16
+      %cast_r = arith.index_cast %arg0 : index to i16
+      %add = arith.addi %cast_r, %cast_c : i16
+      %float = arith.sitofp %add : i16 to f16
+      memref.store %float, %0[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+  // Initialize the accumulator matrix with zeros.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      memref.store %f0, %22[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+
+  %2 = memref.cast %0 : memref<16x16xf16> to memref<*xf16>
+
+  %stream = gpu.wait async
+  %gpu_in, %asyncToken_0 = gpu.alloc async [%stream] () : memref<16x16xf16>
+  %gpu_out, %asyncToken_1 = gpu.alloc async [%stream] () : memref<16x16xf16>
+
+  %asyncToken_2 = gpu.memcpy async [%stream] %gpu_in, %0 : memref<16x16xf16>, memref<16x16xf16>
+  %asyncToken_3 = gpu.memcpy async [%stream] %gpu_out, %22 : memref<16x16xf16>, memref<16x16xf16>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %gpu_out[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+    gpu.terminator
+  }
+
+  %asyncToken_4 = gpu.memcpy async [%stream] %22, %gpu_out : memref<16x16xf16>, memref<16x16xf16>
+  gpu.wait [%stream]
+
+  %res_f32 = memref.alloc() : memref<16x16xf32>
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %load = memref.load %gpu_out[%arg0, %arg1] : memref<16x16xf16>
+      %ext = arith.extf %load : f16 to f32
+      memref.store %ext, %res_f32[%arg0, %arg1] : memref<16x16xf32>
+    }
+  }
+  %res_f32_cast = memref.cast %res_f32 : memref<16x16xf32> to memref<*xf32>
+
+  // Print the memref after computation.
+  call @printMemrefF32(%res_f32_cast) : (memref<*xf32>) -> ()
+  // CHECK:      [1240,   1360,   1480,   1600,   1720,   1840,   1960,   2080,   2200,   2320,   2440,   2560,   2680,   2800,   2920,   3040],
+  // CHECK-NEXT: [1360,   1496,   1632,   1768,   1904,   2040,   2176,   2312,   2448,   2584,   2720,   2856,   2992,   3128,   3264,   3400],
+  // CHECK-NEXT: [1480,   1632,   1784,   1936,   2088,   2240,   2392,   2544,   2696,   2848,   3000,   3152,   3304,   3456,   3608,   3760],
+  // CHECK-NEXT: [1600,   1768,   1936,   2104,   2272,   2440,   2608,   2776,   2944,   3112,   3280,   3448,   3616,   3784,   3952,   4120],
+  // CHECK-NEXT: [1720,   1904,   2088,   2272,   2456,   2640,   2824,   3008,   3192,   3376,   3560,   3744,   3928,   4112,   4296,   4480],
+  // CHECK-NEXT: [1840,   2040,   2240,   2440,   2640,   2840,   3040,   3240,   3440,   3640,   3840,   4040,   4240,   4440,   4640,   4840],
+  // CHECK-NEXT: [1960,   2176,   2392,   2608,   2824,   3040,   3256,   3472,   3688,   3904,   4120,   4336,   4552,   4768,   4984,   5200],
+  // CHECK-NEXT: [2080,   2312,   2544,   2776,   3008,   3240,   3472,   3704,   3936,   4168,   4400,   4632,   4864,   5100,   5328,   5556],
+  // CHECK-NEXT: [2200,   2448,   2696,   2944,   3192,   3440,   3688,   3936,   4184,   4432,   4680,   4928,   5172,   5424,   5676,   5920],
+  // CHECK-NEXT: [2320,   2584,   2848,   3112,   3376,   3640,   3904,   4168,   4432,   4696,   4960,   5228,   5488,   5748,   6016,   6284],
+  // CHECK-NEXT: [2440,   2720,   3000,   3280,   3560,   3840,   4120,   4400,   4680,   4960,   5236,   5520,   5804,   6080,   6356,   6640],
+  // CHECK-NEXT: [2560,   2856,   3152,   3448,   3744,   4040,   4336,   4632,   4928,   5228,   5520,   5812,   6112,   6412,   6704,   6996],
+  // CHECK-NEXT: [2680,   2992,   3304,   3616,   3928,   4240,   4552,   4864,   5172,   5488,   5804,   6112,   6420,   6736,   7052,   7360],
+  // CHECK-NEXT: [2800,   3128,   3456,   3784,   4112,   4440,   4768,   5100,   5424,   5748,   6080,   6412,   6736,   7060,   7392,   7724],
+  // CHECK-NEXT: [2920,   3264,   3608,   3952,   4296,   4640,   4984,   5328,   5676,   6016,   6356,   6704,   7052,   7392,   7732,   8080],
+  // CHECK-NEXT: [3040,   3400,   3760,   4120,   4480,   4840,   5200,   5556,   5920,   6284,   6640,   6996,   7360,   7724,   8080,   8440]
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_opselect.mlir b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_opselect.mlir
new file mode 100644
index 000000000000000..e151b0844db48ab
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_opselect.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -convert-scf-to-cf \
+// RUN: | mlir-opt -gpu-kernel-outlining \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-amdgpu{chipset=%chip index-bitwidth=32},convert-gpu-to-rocdl{chipset=%chip index-bitwidth=32},reconcile-unrealized-casts),rocdl-attach-target{chip=%chip})' \
+// RUN: | mlir-opt -gpu-to-llvm -gpu-module-to-binary \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+func.func @main() {
+  %0 = memref.alloc() : memref<16x16xf16>
+  %22 = memref.alloc() : memref<16x16xf16>
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f16
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  // Initialize the input matrix with the sum of its row and column indices.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %cast_c = arith.index_cast %arg1 : index to i16
+      %cast_r = arith.index_cast %arg0 : index to i16
+      %add = arith.addi %cast_r, %cast_c : i16
+      %float = arith.sitofp %add : i16 to f16
+      memref.store %float, %0[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+  // Initialize the accumulator matrix with zeros.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      memref.store %f0, %22[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+
+  %2 = memref.cast %0 : memref<16x16xf16> to memref<*xf16>
+
+  %stream = gpu.wait async
+  %gpu_in, %asyncToken_0 = gpu.alloc async [%stream] () : memref<16x16xf16>
+  %gpu_out, %asyncToken_1 = gpu.alloc async [%stream] () : memref<16x16xf16>
+
+  %asyncToken_2 = gpu.memcpy async [%stream] %gpu_in, %0 : memref<16x16xf16>, memref<16x16xf16>
+  %asyncToken_3 = gpu.memcpy async [%stream] %gpu_out, %22 : memref<16x16xf16>, memref<16x16xf16>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c0] {leadDimension = 16 : index, opSelect} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C {opSelect} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %gpu_out[%c0, %c0] {leadDimension = 16 : index, opSelect}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+    gpu.terminator
+  }
+
+  %asyncToken_4 = gpu.memcpy async [%stream] %22, %gpu_out : memref<16x16xf16>, memref<16x16xf16>
+  gpu.wait [%stream]
+
+  %res_f32 = memref.alloc() : memref<16x16xf32>
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %load = memref.load %gpu_out[%arg0, %arg1] : memref<16x16xf16>
+      %ext = arith.extf %load : f16 to f32
+      memref.store %ext, %res_f32[%arg0, %arg1] : memref<16x16xf32>
+    }
+  }
+  %res_f32_cast = memref.cast %res_f32 : memref<16x16xf32> to memref<*xf32>
+
+  // Print the memref after computation.
+  call @printMemrefF32(%res_f32_cast) : (memref<*xf32>) -> ()
+  // CHECK:      [1240,   1360,   1480,   1600,   1720,   1840,   1960,   2080,   2200,   2320,   2440,   2560,   2680,   2800,   2920,   3040],
+  // CHECK-NEXT: [1360,   1496,   1632,   1768,   1904,   2040,   2176,   2312,   2448,   2584,   2720,   2856,   2992,   3128,   3264,   3400],
+  // CHECK-NEXT: [1480,   1632,   1784,   1936,   2088,   2240,   2392,   2544,   2696,   2848,   3000,   3152,   3304,   3456,   3608,   3760],
+  // CHECK-NEXT: [1600,   1768,   1936,   2104,   2272,   2440,   2608,   2776,   2944,   3112,   3280,   3448,   3616,   3784,   3952,   4120],
+  // CHECK-NEXT: [1720,   1904,   2088,   2272,   2456,   2640,   2824,   3008,   3192,   3376,   3560,   3744,   3928,   4112,   4296,   4480],
+  // CHECK-NEXT: [1840,   2040,   2240,   2440,   2640,   2840,   3040,   3240,   3440,   3640,   3840,   4040,   4240,   4440,   4640,   4840],
+  // CHECK-NEXT: [1960,   2176,   2392,   2608,   2824,   3040,   3256,   3472,   3688,   3904,   4120,   4336,   4552,   4768,   4984,   5200],
+  // CHECK-NEXT: [2080,   2312,   2544,   2776,   3008,   3240,   3472,   3704,   3936,   4168,   4400,   4632,   4864,   5100,   5328,   5556],
+  // CHECK-NEXT: [2200,   2448,   2696,   2944,   3192,   3440,   3688,   3936,   4184,   4432,   4680,   4928,   5172,   5424,   5676,   5920],
+  // CHECK-NEXT: [2320,   2584,   2848,   3112,   3376,   3640,   3904,   4168,   4432,   4696,   4960,   5228,   5488,   5748,   6016,   6284],
+  // CHECK-NEXT: [2440,   2720,   3000,   3280,   3560,   3840,   4120,   4400,   4680,   4960,   5236,   5520,   5804,   6080,   6356,   6640],
+  // CHECK-NEXT: [2560,   2856,   3152,   3448,   3744,   4040,   4336,   4632,   4928,   5228,   5520,   5812,   6112,   6412,   6704,   6996],
+  // CHECK-NEXT: [2680,   2992,   3304,   3616,   3928,   4240,   4552,   4864,   5172,   5488,   5804,   6112,   6420,   6736,   7052,   7360],
+  // CHECK-NEXT: [2800,   3128,   3456,   3784,   4112,   4440,   4768,   5100,   5424,   5748,   6080,   6412,   6736,   7060,   7392,   7724],
+  // CHECK-NEXT: [2920,   3264,   3608,   3952,   4296,   4640,   4984,   5328,   5676,   6016,   6356,   6704,   7052,   7392,   7732,   8080],
+  // CHECK-NEXT: [3040,   3400,   3760,   4120,   4480,   4840,   5200,   5556,   5920,   6284,   6640,   6996,   7360,   7724,   8080,   8440]
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_x2.mlir b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_x2.mlir
new file mode 100644
index 000000000000000..f3ea57ad3265ff1
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_x2.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -convert-scf-to-cf \
+// RUN: | mlir-opt -gpu-kernel-outlining \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-amdgpu{chipset=%chip index-bitwidth=32},convert-gpu-to-rocdl{chipset=%chip index-bitwidth=32},reconcile-unrealized-casts),rocdl-attach-target{chip=%chip})' \
+// RUN: | mlir-opt -gpu-to-llvm -gpu-module-to-binary \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+// This test performs a WMMA computation with two subgroup_mma_compute
+// operations that use the lower and upper halves of the 32-bit accumulator
+// registers, respectively.
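+//
+// A sketch of the assumed register layout: each lane holds its 16x16xf16
+// "COp" fragment as two f16 values packed into every 32-bit accumulator
+// register. The opSelect unit attribute (which the lowering maps to the WMMA
+// opsel bit) selects whether the upper or lower 16 bits of each register are
+// read and written, which is what lets %C_lo/%R_lo and %C_hi/%R_hi below
+// share one set of accumulator registers.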
+
+func.func @main() {
+  %0 = memref.alloc() : memref<16x16xf16>
+  %22 = memref.alloc() : memref<16x32xf16>
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f16
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  // Initialize the input matrix with the sum of its row and column indices.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %cast_c = arith.index_cast %arg1 : index to i16
+      %cast_r = arith.index_cast %arg0 : index to i16
+      %add = arith.addi %cast_r, %cast_c : i16
+      %float = arith.sitofp %add : i16 to f16
+      memref.store %float, %0[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+  // Initialize the accumulator matrix with zeros.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c32 step %c1 {
+      memref.store %f0, %22[%arg0, %arg1] : memref<16x32xf16>
+    }
+  }
+
+  %stream = gpu.wait async
+  %gpu_in, %asyncToken_0 = gpu.alloc async [%stream] () : memref<16x16xf16>
+  %gpu_out, %asyncToken_1 = gpu.alloc async [%stream] () : memref<16x32xf16>
+
+  %asyncToken_2 = gpu.memcpy async [%stream] %gpu_in, %0 : memref<16x16xf16>, memref<16x16xf16>
+  %asyncToken_3 = gpu.memcpy async [%stream] %gpu_out, %22 : memref<16x32xf16>, memref<16x32xf16>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %C_lo = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c0] {leadDimension = 32 : index} : memref<16x32xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+    %C_hi = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c16] {leadDimension = 32 : index, opSelect} : memref<16x32xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    %R_lo = gpu.subgroup_mma_compute %A, %B, %C_lo : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+    %R_hi = gpu.subgroup_mma_compute %A, %B, %C_hi {opSelect} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    gpu.subgroup_mma_store_matrix %R_lo, %gpu_out[%c0, %c0] {leadDimension = 32 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x32xf16>
+    gpu.subgroup_mma_store_matrix %R_hi, %gpu_out[%c0, %c16] {leadDimension = 32 : index, opSelect}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x32xf16>
+    gpu.terminator
+  }
+
+  %asyncToken_4 = gpu.memcpy async [%stream] %22, %gpu_out : memref<16x32xf16>, memref<16x32xf16>
+  gpu.wait [%stream]
+
+  %res_f32 = memref.alloc() : memref<16x32xf32>
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c32 step %c1 {
+      %load = memref.load %gpu_out[%arg0, %arg1] : memref<16x32xf16>
+      %ext = arith.extf %load : f16 to f32
+      memref.store %ext, %res_f32[%arg0, %arg1] : memref<16x32xf32>
+    }
+  }
+  %res_f32_cast = memref.cast %res_f32 : memref<16x32xf32> to memref<*xf32>
+
+  // Print the memref after computation.
+  call @printMemrefF32(%res_f32_cast) : (memref<*xf32>) -> ()
+  // CHECK:      [1240,   1360,   1480,   1600,   1720,   1840,   1960,   2080,   2200,   2320,   2440,   2560,   2680,   2800,   2920,   3040,   1240,   1360,   1480,   1600,   1720,   1840,   1960,   2080,   2200,   2320,   2440,   2560,   2680,   2800,   2920,   3040],
+  // CHECK-NEXT: [1360,   1496,   1632,   1768,   1904,   2040,   2176,   2312,   2448,   2584,   2720,   2856,   2992,   3128,   3264,   3400,   1360,   1496,   1632,   1768,   1904,   2040,   2176,   2312,   2448,   2584,   2720,   2856,   2992,   3128,   3264,   3400],
+  // CHECK-NEXT: [1480,   1632,   1784,   1936,   2088,   2240,   2392,   2544,   2696,   2848,   3000,   3152,   3304,   3456,   3608,   3760,   1480,   1632,   1784,   1936,   2088,   2240,   2392,   2544,   2696,   2848,   3000,   3152,   3304,   3456,   3608,   3760],
+  // CHECK-NEXT: [1600,   1768,   1936,   2104,   2272,   2440,   2608,   2776,   2944,   3112,   3280,   3448,   3616,   3784,   3952,   4120,   1600,   1768,   1936,   2104,   2272,   2440,   2608,   2776,   2944,   3112,   3280,   3448,   3616,   3784,   3952,   4120],
+  // CHECK-NEXT: [1720,   1904,   2088,   2272,   2456,   2640,   2824,   3008,   3192,   3376,   3560,   3744,   3928,   4112,   4296,   4480,   1720,   1904,   2088,   2272,   2456,   2640,   2824,   3008,   3192,   3376,   3560,   3744,   3928,   4112,   4296,   4480],
+  // CHECK-NEXT: [1840,   2040,   2240,   2440,   2640,   2840,   3040,   3240,   3440,   3640,   3840,   4040,   4240,   4440,   4640,   4840,   1840,   2040,   2240,   2440,   2640,   2840,   3040,   3240,   3440,   3640,   3840,   4040,   4240,   4440,   4640,   4840],
+  // CHECK-NEXT: [1960,   2176,   2392,   2608,   2824,   3040,   3256,   3472,   3688,   3904,   4120,   4336,   4552,   4768,   4984,   5200,   1960,   2176,   2392,   2608,   2824,   3040,   3256,   3472,   3688,   3904,   4120,   4336,   4552,   4768,   4984,   5200],
+  // CHECK-NEXT: [2080,   2312,   2544,   2776,   3008,   3240,   3472,   3704,   3936,   4168,   4400,   4632,   4864,   5100,   5328,   5556,   2080,   2312,   2544,   2776,   3008,   3240,   3472,   3704,   3936,   4168,   4400,   4632,   4864,   5100,   5328,   5556],
+  // CHECK-NEXT: [2200,   2448,   2696,   2944,   3192,   3440,   3688,   3936,   4184,   4432,   4680,   4928,   5172,   5424,   5676,   5920,   2200,   2448,   2696,   2944,   3192,   3440,   3688,   3936,   4184,   4432,   4680,   4928,   5172,   5424,   5676,   5920],
+  // CHECK-NEXT: [2320,   2584,   2848,   3112,   3376,   3640,   3904,   4168,   4432,   4696,   4960,   5228,   5488,   5748,   6016,   6284,   2320,   2584,   2848,   3112,   3376,   3640,   3904,   4168,   4432,   4696,   4960,   5228,   5488,   5748,   6016,   6284],
+  // CHECK-NEXT: [2440,   2720,   3000,   3280,   3560,   3840,   4120,   4400,   4680,   4960,   5236,   5520,   5804,   6080,   6356,   6640,   2440,   2720,   3000,   3280,   3560,   3840,   4120,   4400,   4680,   4960,   5236,   5520,   5804,   6080,   6356,   6640],
+  // CHECK-NEXT: [2560,   2856,   3152,   3448,   3744,   4040,   4336,   4632,   4928,   5228,   5520,   5812,   6112,   6412,   6704,   6996,   2560,   2856,   3152,   3448,   3744,   4040,   4336,   4632,   4928,   5228,   5520,   5812,   6112,   6412,   6704,   6996],
+  // CHECK-NEXT: [2680,   2992,   3304,   3616,   3928,   4240,   4552,   4864,   5172,   5488,   5804,   6112,   6420,   6736,   7052,   7360,   2680,   2992,   3304,   3616,   3928,   4240,   4552,   4864,   5172,   5488,   5804,   6112,   6420,   6736,   7052,   7360],
+  // CHECK-NEXT: [2800,   3128,   3456,   3784,   4112,   4440,   4768,   5100,   5424,   5748,   6080,   6412,   6736,   7060,   7392,   7724,   2800,   3128,   3456,   3784,   4112,   4440,   4768,   5100,   5424,   5748,   6080,   6412,   6736,   7060,   7392,   7724],
+  // CHECK-NEXT: [2920,   3264,   3608,   3952,   4296,   4640,   4984,   5328,   5676,   6016,   6356,   6704,   7052,   7392,   7732,   8080,   2920,   3264,   3608,   3952,   4296,   4640,   4984,   5328,   5676,   6016,   6356,   6704,   7052,   7392,   7732,   8080],
+  // CHECK-NEXT: [3040,   3400,   3760,   4120,   4480,   4840,   5200,   5556,   5920,   6284,   6640,   6996,   7360,   7724,   8080,   8440,   3040,   3400,   3760,   4120,   4480,   4840,   5200,   5556,   5920,   6284,   6640,   6996,   7360,   7724,   8080,   8440]
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16.mlir b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16.mlir
new file mode 100644
index 000000000000000..597ff836aaf1776
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -convert-scf-to-cf \
+// RUN: | mlir-opt -gpu-kernel-outlining \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-amdgpu{chipset=%chip index-bitwidth=32},convert-gpu-to-rocdl{chipset=%chip index-bitwidth=32},reconcile-unrealized-casts),rocdl-attach-target{chip=%chip})' \
+// RUN: | mlir-opt -gpu-to-llvm -gpu-module-to-binary \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+func.func @main() {
+  %0 = memref.alloc() : memref<16x16xf16>
+  %22 = memref.alloc() : memref<16x16xf32>
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  // Initialize the input matrix with the sum of its row and column indices.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %cast_c = arith.index_cast %arg1 : index to i16
+      %cast_r = arith.index_cast %arg0 : index to i16
+      %add = arith.addi %cast_r, %cast_c : i16
+      %float = arith.sitofp %add : i16 to f16
+      memref.store %float, %0[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+  // Initialize the accumulator matrix with zeros.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      memref.store %f0, %22[%arg0, %arg1] : memref<16x16xf32>
+    }
+  }
+
+  %2 = memref.cast %0 : memref<16x16xf16> to memref<*xf16>
+  %33 = memref.cast %22 : memref<16x16xf32> to memref<*xf32>
+
+  %stream = gpu.wait async
+  %gpu_in, %asyncToken_0 = gpu.alloc async [%stream] () : memref<16x16xf16>
+  %gpu_out, %asyncToken_1 = gpu.alloc async [%stream] () : memref<16x16xf32>
+
+  %asyncToken_2 = gpu.memcpy async [%stream] %gpu_in, %0 : memref<16x16xf16>, memref<16x16xf16>
+  %asyncToken_3 = gpu.memcpy async [%stream] %gpu_out, %22 : memref<16x16xf32>, memref<16x16xf32>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %gpu_out[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
+    gpu.terminator
+  }
+
+  %asyncToken_4 = gpu.memcpy async [%stream] %22, %gpu_out : memref<16x16xf32>, memref<16x16xf32>
+  gpu.wait [%stream]
+
+  // Print the memref after computation.
+  call @printMemrefF32(%33) : (memref<*xf32>) -> ()
+  // CHECK:        [1240,   1360,   1480,   1600,   1720,   1840,   1960,   2080,   2200,   2320,   2440,   2560,   2680,   2800,   2920,   3040],
+  // CHECK-NEXT:   [1360,   1496,   1632,   1768,   1904,   2040,   2176,   2312,   2448,   2584,   2720,   2856,   2992,   3128,   3264,   3400],
+  // CHECK-NEXT:   [1480,   1632,   1784,   1936,   2088,   2240,   2392,   2544,   2696,   2848,   3000,   3152,   3304,   3456,   3608,   3760],
+  // CHECK-NEXT:   [1600,   1768,   1936,   2104,   2272,   2440,   2608,   2776,   2944,   3112,   3280,   3448,   3616,   3784,   3952,   4120],
+  // CHECK-NEXT:   [1720,   1904,   2088,   2272,   2456,   2640,   2824,   3008,   3192,   3376,   3560,   3744,   3928,   4112,   4296,   4480],
+  // CHECK-NEXT:   [1840,   2040,   2240,   2440,   2640,   2840,   3040,   3240,   3440,   3640,   3840,   4040,   4240,   4440,   4640,   4840],
+  // CHECK-NEXT:   [1960,   2176,   2392,   2608,   2824,   3040,   3256,   3472,   3688,   3904,   4120,   4336,   4552,   4768,   4984,   5200],
+  // CHECK-NEXT:   [2080,   2312,   2544,   2776,   3008,   3240,   3472,   3704,   3936,   4168,   4400,   4632,   4864,   5096,   5328,   5560],
+  // CHECK-NEXT:   [2200,   2448,   2696,   2944,   3192,   3440,   3688,   3936,   4184,   4432,   4680,   4928,   5176,   5424,   5672,   5920],
+  // CHECK-NEXT:   [2320,   2584,   2848,   3112,   3376,   3640,   3904,   4168,   4432,   4696,   4960,   5224,   5488,   5752,   6016,   6280],
+  // CHECK-NEXT:   [2440,   2720,   3000,   3280,   3560,   3840,   4120,   4400,   4680,   4960,   5240,   5520,   5800,   6080,   6360,   6640],
+  // CHECK-NEXT:   [2560,   2856,   3152,   3448,   3744,   4040,   4336,   4632,   4928,   5224,   5520,   5816,   6112,   6408,   6704,   7000],
+  // CHECK-NEXT:   [2680,   2992,   3304,   3616,   3928,   4240,   4552,   4864,   5176,   5488,   5800,   6112,   6424,   6736,   7048,   7360],
+  // CHECK-NEXT:   [2800,   3128,   3456,   3784,   4112,   4440,   4768,   5096,   5424,   5752,   6080,   6408,   6736,   7064,   7392,   7720],
+  // CHECK-NEXT:   [2920,   3264,   3608,   3952,   4296,   4640,   4984,   5328,   5672,   6016,   6360,   6704,   7048,   7392,   7736,   8080],
+  // CHECK-NEXT:   [3040,   3400,   3760,   4120,   4480,   4840,   5200,   5560,   5920,   6280,   6640,   7000,   7360,   7720,   8080,   8440]
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16_a_b_transpose.mlir b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16_a_b_transpose.mlir
new file mode 100644
index 000000000000000..cb8c37777baa03d
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16_a_b_transpose.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -convert-scf-to-cf \
+// RUN: | mlir-opt -gpu-kernel-outlining \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-amdgpu{chipset=%chip index-bitwidth=32},convert-gpu-to-rocdl{chipset=%chip index-bitwidth=32},reconcile-unrealized-casts),rocdl-attach-target{chip=%chip})' \
+// RUN: | mlir-opt -gpu-to-llvm -gpu-module-to-binary \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+func.func @main() {
+  %0 = memref.alloc() : memref<16x16xf16>
+  %22 = memref.alloc() : memref<16x16xf32>
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  // Initialize the input matrix with its column index.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %cast_c = arith.index_cast %arg1 : index to i16
+      %float = arith.sitofp %cast_c : i16 to f16
+      memref.store %float, %0[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+  // Initialize the accumulator matrix with zeros.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      memref.store %f0, %22[%arg0, %arg1] : memref<16x16xf32>
+    }
+  }
+
+  %2 = memref.cast %0 : memref<16x16xf16> to memref<*xf16>
+  %33 = memref.cast %22 : memref<16x16xf32> to memref<*xf32>
+
+  %stream = gpu.wait async
+  %gpu_in, %asyncToken_0 = gpu.alloc async [%stream] () : memref<16x16xf16>
+  %gpu_out, %asyncToken_1 = gpu.alloc async [%stream] () : memref<16x16xf32>
+
+  %asyncToken_2 = gpu.memcpy async [%stream] %gpu_in, %0 : memref<16x16xf16>, memref<16x16xf16>
+  %asyncToken_3 = gpu.memcpy async [%stream] %gpu_out, %22 : memref<16x16xf32>, memref<16x16xf32>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %gpu_out[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
+    gpu.terminator
+  }
+
+  %asyncToken_4 = gpu.memcpy async [%stream] %22, %gpu_out : memref<16x16xf32>, memref<16x16xf32>
+  gpu.wait [%stream]
+
+  // Print the memref after computation.
+  call @printMemrefF32(%33) : (memref<*xf32>) -> ()
+  // CHECK:      [0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
+  // CHECK-NEXT: [120,   120,   120,   120,   120,   120,   120,   120,   120,   120,   120,   120,   120,   120,   120,   120],
+  // CHECK-NEXT: [240,   240,   240,   240,   240,   240,   240,   240,   240,   240,   240,   240,   240,   240,   240,   240],
+  // CHECK-NEXT: [360,   360,   360,   360,   360,   360,   360,   360,   360,   360,   360,   360,   360,   360,   360,   360],
+  // CHECK-NEXT: [480,   480,   480,   480,   480,   480,   480,   480,   480,   480,   480,   480,   480,   480,   480,   480],
+  // CHECK-NEXT: [600,   600,   600,   600,   600,   600,   600,   600,   600,   600,   600,   600,   600,   600,   600,   600],
+  // CHECK-NEXT: [720,   720,   720,   720,   720,   720,   720,   720,   720,   720,   720,   720,   720,   720,   720,   720],
+  // CHECK-NEXT: [840,   840,   840,   840,   840,   840,   840,   840,   840,   840,   840,   840,   840,   840,   840,   840],
+  // CHECK-NEXT: [960,   960,   960,   960,   960,   960,   960,   960,   960,   960,   960,   960,   960,   960,   960,   960],
+  // CHECK-NEXT: [1080,   1080,   1080,   1080,   1080,   1080,   1080,   1080,   1080,   1080,   1080,   1080,   1080,   1080,   1080,   1080],
+  // CHECK-NEXT: [1200,   1200,   1200,   1200,   1200,   1200,   1200,   1200,   1200,   1200,   1200,   1200,   1200,   1200,   1200,   1200],
+  // CHECK-NEXT: [1320,   1320,   1320,   1320,   1320,   1320,   1320,   1320,   1320,   1320,   1320,   1320,   1320,   1320,   1320,   1320],
+  // CHECK-NEXT: [1440,   1440,   1440,   1440,   1440,   1440,   1440,   1440,   1440,   1440,   1440,   1440,   1440,   1440,   1440,   1440],
+  // CHECK-NEXT: [1560,   1560,   1560,   1560,   1560,   1560,   1560,   1560,   1560,   1560,   1560,   1560,   1560,   1560,   1560,   1560],
+  // CHECK-NEXT: [1680,   1680,   1680,   1680,   1680,   1680,   1680,   1680,   1680,   1680,   1680,   1680,   1680,   1680,   1680,   1680],
+  // CHECK-NEXT: [1800,   1800,   1800,   1800,   1800,   1800,   1800,   1800,   1800,   1800,   1800,   1800,   1800,   1800,   1800,   1800]
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index 2de40ba5e8e57e6..20fd9ad63311602 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -47,6 +47,7 @@ config.mlir_run_arm_sme_tests = @MLIR_RUN_ARM_SME_TESTS@
 config.mlir_run_x86vector_tests = @MLIR_RUN_X86VECTOR_TESTS@
 config.mlir_run_riscv_vector_tests = "@MLIR_RUN_RISCV_VECTOR_TESTS@"
 config.mlir_run_cuda_tensor_core_tests = @MLIR_RUN_CUDA_TENSOR_CORE_TESTS@
+config.mlir_run_rocm_wmma_tests = @MLIR_RUN_ROCM_WMMA_TESTS@
 config.mlir_run_cuda_sm80_tests = @MLIR_RUN_CUDA_SM80_TESTS@
 config.mlir_run_cuda_sm80_lt_tests = @MLIR_RUN_CUDA_SM80_LT_TESTS@
 config.mlir_run_cuda_sm90_tests = @MLIR_RUN_CUDA_SM90_TESTS@


