[Mlir-commits] [mlir] 1ca772e - [MLIR][GPU] Add NvGpu mma.sync path to the VectorToGPU pass

Christopher Bate llvmlistbot at llvm.org
Fri May 20 08:44:18 PDT 2022


Author: Christopher Bate
Date: 2022-05-20T09:42:55-06:00
New Revision: 1ca772ed951e6412ef006459b56ae9a21691a97c

URL: https://github.com/llvm/llvm-project/commit/1ca772ed951e6412ef006459b56ae9a21691a97c
DIFF: https://github.com/llvm/llvm-project/commit/1ca772ed951e6412ef006459b56ae9a21691a97c.diff

LOG: [MLIR][GPU] Add NvGpu mma.sync path to the VectorToGPU pass

This change adds the option to lower to NvGpu dialect ops during the
VectorToGPU conversion pass. Because this transformation reuses
existing VectorToGPU logic, a separate VectorToNvGpu conversion pass is
not created. The option `use-nvgpu` is added to the VectorToGPU pass.
When this is true, the pass will attempt to convert slices rooted at
`vector.contract` operations into `nvgpu.mma.sync` ops, and
`vector.transfer_read` ops are converted to either `nvgpu.ldmatrix` or
one or more `vector.load` operations.  The specific data loaded will
depend on the thread id within a subgroup (warp). These index
calculations depend on data type and shape of the MMA op
according to the downstream PTX specification. The code for supporting
these details is separated into `NvGpuSupport.cpp|h`.

Differential Revision: https://reviews.llvm.org/D122940

Added: 
    mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
    mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
    mlir/lib/Conversion/PassDetail.h
    mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 41e7b29f15d0d..6d9863e723487 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -851,8 +851,13 @@ def ConvertVectorToGPU : Pass<"convert-vector-to-gpu"> {
                 "dialect";
   let constructor = "mlir::createConvertVectorToGPUPass()";
   let dependentDialects = [
-    "memref::MemRefDialect",
-    "gpu::GPUDialect"
+    "memref::MemRefDialect", "gpu::GPUDialect", "AffineDialect", 
+    "vector::VectorDialect", "nvgpu::NVGPUDialect"
+  ];
+
+  let options = [
+    Option<"useNvGpu", "use-nvgpu", "bool", /*default=*/"false", 
+      "convert to NvGPU ops instead of GPU dialect ops">
   ];
 }
 

diff  --git a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
index 266fa0eac4c48..1ba5b3f90d9a8 100644
--- a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
+++ b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
@@ -17,16 +17,25 @@ class Pass;
 class RewritePatternSet;
 
 /// Patterns to transform vector ops into a canonical form to convert to MMA
-/// matrix operations.
-void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns);
+/// matrix operations. If `useNvGpu` is true, then the populated patterns
+/// prepare for conversion to `nvgpu` mma operations rather than the `gpu`
+/// dialect WMMA operations.
+void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
+                                        bool useNvGpu = false);
 
 /// Convert vector ops to MMA matrix operations nested under `rootOp`. This will
 /// convert slice of operations that can be legally converted to MMA operations.
 /// The rest of the vector operations are left untouched.
 void convertVectorToMMAOps(Operation *rootOp);
 
+/// Convert vector ops nested under `rootOp` to vector and GPU operations
+/// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice
+/// of operations that can be legally lowered on this path while the rest of
+/// the vector operations are left untouched.
+LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp);
+
 /// Convert from vector to GPU ops.
-std::unique_ptr<Pass> createConvertVectorToGPUPass();
+std::unique_ptr<Pass> createConvertVectorToGPUPass(bool useNvGpu = false);
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index e05004061dd4d..530e156024fdc 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -55,6 +55,10 @@ namespace LLVM {
 class LLVMDialect;
 } // namespace LLVM
 
+namespace nvgpu {
+class NVGPUDialect;
+} // namespace nvgpu
+
 namespace NVVM {
 class NVVMDialect;
 } // namespace NVVM

diff  --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
index 06758c5fe1266..778f2c42eebe5 100644
--- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(MLIRVectorToGPU
   VectorToGPU.cpp
+  NvGpuSupport.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU

diff  --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
new file mode 100644
index 0000000000000..a2820c3e88f8c
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
@@ -0,0 +1,327 @@
+//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utilities to assist in the lowering of Vector operations
+// to NvGPU dialect MMA operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NvGpuSupport.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir {
+namespace nvgpu {
+namespace {
+
+/// Number of lanes (threads) that share one [128|256|512]-bit tile row; this
+/// is always 4.
+constexpr int64_t kThreadsPerRow{4};
+
+/// Number of rows in one tile of a warp-level matrix operand.
+constexpr int64_t kNumRowsPerTile{8};
+
+/// Returns true when `operandType` designates the accumulator/result (C)
+/// operand; A and B operands are matmul inputs.
+bool isAccumulatorOrResult(MatMulOperandRole operandType) {
+  switch (operandType) {
+  case MatMulOperandRole::C:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Returns the number of registers which compose the matrix fragment a single
+/// thread holds, computed from the operand shape, element bit width and tile
+/// row width.
+int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
+  const int64_t tileWidthBits = inferTileWidthInBits(type);
+  VectorType vecTy = type.vectorType;
+  ArrayRef<int64_t> dims = vecTy.getShape();
+  const int64_t tileRows = dims[0] / kNumRowsPerTile;
+  const int64_t rowBits =
+      dims[1] * vecTy.getElementType().getIntOrFloatBitWidth();
+  // Multiply before dividing: (tileRows * rowBits) / tileWidthBits.
+  return tileRows * rowBits / tileWidthBits;
+}
+
+/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
+/// operand shape, as {tiles along the rows, tiles along a row}.
+std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
+                                    Type elementType, int64_t lineSizeBits) {
+  // For each 8x128bit square, a thread is responsible for one 32bit register.
+  std::array<int64_t, 2> tiles;
+  tiles[0] = operandShape[0] / kNumRowsPerTile;
+  const int64_t rowBits =
+      operandShape[1] * elementType.getIntOrFloatBitWidth();
+  tiles[1] = rowBits / lineSizeBits;
+  return tiles;
+}
+
+} // namespace
+
+FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
+  WarpMatrixInfo info;
+
+  // Writes carry the vector as their stored value; the other supported ops
+  // produce it as their first result.
+  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
+    info.vectorType = writeOp.getVectorType();
+  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
+                 arith::ConstantOp>(op)) {
+    info.vectorType = op->getResult(0).getType().cast<VectorType>();
+  } else {
+    return op->emitError()
+           << "unhandled operation type in nvgpu.mma.sync conversion path";
+  }
+
+  // The first `vector.contract` user that takes this value as its LHS (or
+  // RHS) makes it an A (or B) operand; otherwise it is the accumulator/result.
+  info.operandRole = MatMulOperandRole::C;
+  for (Operation *user : op->getUsers()) {
+    auto contract = dyn_cast<vector::ContractionOp>(user);
+    if (!contract)
+      continue;
+    // Only reached when `op` has a user, hence at least one result.
+    Value value = op->getResult(0);
+    if (contract.getLhs() == value) {
+      info.operandRole = MatMulOperandRole::A;
+      break;
+    }
+    if (contract.getRhs() == value) {
+      info.operandRole = MatMulOperandRole::B;
+      break;
+    }
+  }
+  return info;
+}
+
+/// Tile row width in bits: 512 for 64-bit accumulators, 256 for 64-bit inputs
+/// and 32-bit accumulators, and 128 for everything else.
+int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
+  const bool isAcc = isAccumulatorOrResult(type.operandRole);
+  const unsigned bitWidth =
+      type.vectorType.getElementType().getIntOrFloatBitWidth();
+  switch (bitWidth) {
+  case 64:
+    return isAcc ? 512 : 256;
+  case 32:
+    // 32-bit inputs use the default 128-bit row width.
+    if (isAcc)
+      return 256;
+    break;
+  default:
+    break;
+  }
+  return 128;
+}
+
+FailureOr<FragmentElementInfo>
+getMmaSyncRegisterType(const WarpMatrixInfo &type) {
+  MLIRContext *ctx = type.vectorType.getContext();
+  const bool isAccum = isAccumulatorOrResult(type.operandRole);
+  Type elType = type.vectorType.getElementType();
+
+  // Builds the descriptor. The register count is only computed once the
+  // element type has been recognized as supported.
+  auto makeInfo = [&](Type registerTy, int64_t elemsPerReg,
+                      int64_t regBits) -> FragmentElementInfo {
+    return FragmentElementInfo{registerTy, elemsPerReg, regBits,
+                               inferNumRegistersPerMatrixFragment(type)};
+  };
+
+  // f16: two halves per 32-bit register.
+  if (elType.isF16())
+    return makeInfo(LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32);
+
+  // f64: a pair per 128-bit register for accumulators, scalar otherwise.
+  if (elType.isF64()) {
+    Type f64Ty = Float64Type::get(ctx);
+    return isAccum ? makeInfo(LLVM::getFixedVectorType(f64Ty, 2), 2, 128)
+                   : makeInfo(f64Ty, 1, 64);
+  }
+
+  // int8: four values per 32-bit register.
+  if (elType.isInteger(8))
+    return makeInfo(LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4,
+                    32);
+
+  // Integer 32-bit accumulators: a pair per 64-bit register.
+  if (elType.isInteger(32))
+    return makeInfo(LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2,
+                    64);
+
+  // f32: a pair per 64-bit register for accumulators, scalar otherwise.
+  if (elType.isF32()) {
+    Type f32Ty = Float32Type::get(ctx);
+    return isAccum ? makeInfo(LLVM::getFixedVectorType(f32Ty, 2), 2, 64)
+                   : makeInfo(f32Ty, 1, 32);
+  }
+
+  return failure();
+}
+
+/// Returns a 2-d affine map whose results are the (row, column) offset of the
+/// tile holding the register that `logicalValueId` (dimension 1) belongs to.
+/// Registers are assigned down the column of tiles first, then across to the
+/// next tile column. NOTE: `isAccumulator` is currently unused.
+static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
+                                                 Type elementType,
+                                                 ArrayRef<int64_t> operandShape,
+                                                 bool isAccumulator,
+                                                 int64_t elementsPerRegister,
+                                                 AffineExpr logicalValueId) {
+  MLIRContext *ctx = elementType.getContext();
+  const int64_t elementsPerLine =
+      lineSize / elementType.getIntOrFloatBitWidth();
+  const int64_t numTileRows =
+      getTileShape(operandShape, elementType, lineSize)[0];
+  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
+  AffineExpr tileRowOffset = (registerIdx % numTileRows) * kNumRowsPerTile;
+  AffineExpr tileColOffset =
+      (registerIdx.floorDiv(numTileRows)) * elementsPerLine;
+  return AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
+                        {tileRowOffset, tileColOffset}, ctx);
+}
+
+FailureOr<AffineMap>
+getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+                                  const WarpMatrixInfo &fragmentType) {
+  MLIRContext *ctx = builder.getContext();
+  Type elementType = fragmentType.vectorType.getElementType();
+  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
+
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      getMmaSyncRegisterType(fragmentType);
+  if (failed(regInfo))
+    return failure();
+
+  const int64_t bitsPerElement = elementType.getIntOrFloatBitWidth();
+  const int64_t elementsPerRegister =
+      regInfo->registerWidthBits / bitsPerElement;
+  const int64_t lineSize = inferTileWidthInBits(fragmentType);
+
+  AffineExpr laneId, valueId;
+  bindDims(ctx, laneId, valueId);
+
+  // Offset of the tile containing the register that `valueId` falls into.
+  AffineMap tileOffsets = getRegisterIndexToTileOffsetMap(
+      lineSize, elementType, operandShape,
+      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
+      valueId);
+
+  // Inside a tile, each group of kThreadsPerRow lanes covers one row, and each
+  // lane holds `elementsPerRegister` consecutive elements of that row.
+  AffineExpr row =
+      tileOffsets.getResult(0) + laneId.floorDiv(kThreadsPerRow);
+  AffineExpr col = tileOffsets.getResult(1) +
+                   (laneId % kThreadsPerRow) * elementsPerRegister +
+                   (valueId % elementsPerRegister);
+  return AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {row, col}, ctx);
+}
+
+/// Computes the parameters for loading the warp-level operand described by
+/// `type` with `nvgpu.ldmatrix`.
+///
+/// A and C operands target a row-major layout, B operands a column-major one.
+/// `transpose` selects which iterator dimension is contiguous in memory: the
+/// parallel dimension when true, the reduction dimension otherwise.
+/// `numTiles` counts 8 x 128-bit tiles of the operand in the target layout.
+/// Fails when the operand does not cover at least one full tile.
+FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
+                                                   bool transpose) {
+  LdMatrixParams params;
+  Type elType = type.vectorType.getElementType();
+  params.fragmentType = type.vectorType;
+  // Initialize every field so that callers never observe an indeterminate
+  // `isAccum`; it uses the same role classification as the rest of this file.
+  params.isAccum = isAccumulatorOrResult(type.operandRole);
+  if (type.operandRole == MatMulOperandRole::A ||
+      type.operandRole == MatMulOperandRole::C) {
+    params.targetLayout = NVVM::MMALayout::row;
+  } else {
+    params.targetLayout = NVVM::MMALayout::col;
+  }
+  ArrayRef<int64_t> shape = type.vectorType.getShape();
+  params.contiguousDimType =
+      transpose ? IteratorType::Parallel : IteratorType::Reduction;
+
+  // Tiles are kNumRowsPerTile rows by 128 bits in the target layout.
+  if (params.targetLayout == NVVM::MMALayout::row) {
+    params.numTiles = (shape[0] / kNumRowsPerTile) *
+                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
+  } else {
+    params.numTiles = (shape[1] / kNumRowsPerTile) *
+                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
+  }
+
+  // Smaller than a single tile: not loadable with this scheme.
+  if (params.numTiles == 0)
+    return failure();
+
+  return params;
+}
+
+FailureOr<AffineMap>
+getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+                               const LdMatrixParams &params) {
+  MLIRContext *ctx = builder.getContext();
+  // One thread per 128b row.
+  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
+  const int bitsPerElement = static_cast<int>(
+      params.fragmentType.getElementType().getIntOrFloatBitWidth());
+  const int kElementsPer128b = 128 / bitsPerElement;
+  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
+  AffineExpr laneId = getAffineDimExpr(0, ctx);
+
+  switch (params.contiguousDimType) {
+  case IteratorType::Reduction: {
+    // Row-major A|C or col-major B operands: lanes walk down the rows first,
+    // then step to the next 128-bit column chunk.
+    AffineExpr row = laneId % (operandShape[0]);
+    AffineExpr col = laneId.floorDiv(operandShape[0]) * (kElementsPer128b);
+    return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {row, col}, ctx);
+  }
+  case IteratorType::Parallel: {
+    // Col-major A|C or row-major B operands. The operandShape given is already
+    // pre-transposed (e.g. 8x16 = KxN). Threads are assigned in groups of 8,
+    // first across columns and then to rows. This is the transpose of what
+    // `ldmatrix` expects, but when `ldmatrix` gets the `.trans` qualifier the
+    // final effect is to transpose just the blocks.
+    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
+    AffineExpr groupIdx = laneId.floorDiv(kNumThreadsPerTile);
+    AffineExpr tileCol = groupIdx % num8x128bCols;
+    AffineExpr tileRow = groupIdx.floorDiv(num8x128bCols);
+    return AffineMap::get(
+        /*dimCount=*/1, /*symbolCount=*/0,
+        {tileCol * kElementsPer128b,
+         tileRow * kNumRowsPerTile + (laneId % kNumRowsPerTile)},
+        ctx);
+  }
+  }
+  return failure();
+}
+
+/// Rewrites a 2-parallel/1-reduction `vector.contract` whose indexing maps
+/// describe a matmul in a transposed and/or swapped layout into the canonical
+/// "TNT" form: lhs (m, k), rhs (n, k), acc/result (m, n). Operands are
+/// transposed with `vector.transpose` and swapped as needed. Fails without
+/// rewriting if the op is already canonical, has a different iterator
+/// structure, or uses an unsupported layout.
+LogicalResult
+PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
+                                             PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value lhs = op.getLhs();
+  Value rhs = op.getRhs();
+  Value res = op.getAcc();
+
+  // Set up the parallel/reduction structure in right form.
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m;
+  AffineExpr n;
+  AffineExpr k;
+  bindDims(rewriter.getContext(), m, n, k);
+  static constexpr std::array<int64_t, 2> perm = {1, 0};
+  auto iteratorTypes = op.getIteratorTypes().getValue();
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  if (iteratorTypes.size() != 3)
+    return failure();
+  if (!(isParallelIterator(iteratorTypes[0]) &&
+        isParallelIterator(iteratorTypes[1]) &&
+        isReductionIterator(iteratorTypes[2])))
+    return failure();
+
+  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
+  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
+  if (maps == canonicalForm) {
+    return failure();
+  }
+  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+    // Only B is (k, n): transpose it to (n, k).
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+    // Only A is (k, m): transpose it to (m, k). This must not match the
+    // both-transposed layout below, or that branch becomes unreachable.
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+    // Both A and B are in the wrong orientation.
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+    // The result is (n, m): compute it as B^T * A^T by swapping operands.
+    std::swap(rhs, lhs);
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+    std::swap(rhs, lhs);
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+    std::swap(lhs, rhs);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+    std::swap(lhs, rhs);
+  } else {
+    return failure();
+  }
+  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
+      op.getIteratorTypes());
+  return success();
+}
+
+} // namespace nvgpu
+} // namespace mlir
\ No newline at end of file

diff  --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
new file mode 100644
index 0000000000000..9902faa835a6f
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
@@ -0,0 +1,100 @@
+//===- NvGpuSupport.h - MLIR Vector to GPU lowering support --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utilities to assist in the lowering of Vector operations
+// to NvGPU dialect MMA operations.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace nvgpu {
+
+/// Role of a warp-level matrix operand in a matmul: `A` is consumed as the
+/// left-hand side of a `vector.contract`, `B` as its right-hand side, and `C`
+/// is the accumulator/result.
+enum class MatMulOperandRole : int32_t { A = 0, B, C };
+
+/// Collects information about a warp-level matrix operand represented by a
+/// VectorType.
+struct WarpMatrixInfo {
+  /// Vector type of the full warp-level operand.
+  VectorType vectorType;
+  /// Which matmul operand (A, B, or accumulator/result C) this value is.
+  MatMulOperandRole operandRole;
+};
+
+/// Given an op that operates on a VectorType representing a warp-level matrix
+/// operand, the function returns a struct containing relevant type information.
+FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);
+
+/// Returns the number of bits in a single tile row. It is either 128, 256, or
+/// 512 bits depending on the data type and whether the operand is an
+/// accumulator/result operand.
+int64_t inferTileWidthInBits(const WarpMatrixInfo &type);
+
+/// Specifies information about the registers which compose a matrix fragment
+/// according to the PTX documentation.
+struct FragmentElementInfo {
+  /// Type of one register: a scalar or a small fixed-length vector.
+  Type registerLLVMType;
+  /// Number of matrix elements packed into one register.
+  int64_t elementsPerRegister;
+  /// Total bit width of one register.
+  int64_t registerWidthBits;
+  /// Number of registers one thread holds for the whole fragment.
+  int64_t numRegistersPerFragment;
+};
+
+/// Returns a FragmentElementInfo struct describing the register types for the
+/// given matrix fragment type.
+FailureOr<FragmentElementInfo>
+getMmaSyncRegisterType(const WarpMatrixInfo &type);
+
+/// Returns an AffineMap which maps two dimensions representing (laneId,
+/// logicalValueId) and returns two results representing offsets within a
+/// matrix operand. The offsets point to the values the thread is responsible
+/// for (AKA the matrix fragment values) during a warp-collective matrix
+/// operation. For a visual reference of this LaneId -> (row, col) mapping,
+/// please see NVIDIA's PTX documentation:
+/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
+FailureOr<AffineMap>
+getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+                                  const WarpMatrixInfo &fragmentType);
+
+/// Parameters describing how a warp-level operand is loaded with
+/// `nvgpu.ldmatrix`.
+struct LdMatrixParams {
+  /// Vector type of the whole warp-level operand.
+  VectorType fragmentType;
+  /// NOTE(review): presumably true for accumulator/result operands; confirm
+  /// that every producer of this struct initializes it.
+  bool isAccum;
+  /// Number of 8 x 128-bit tiles covering the operand.
+  int64_t numTiles;
+  /// Iterator dimension (parallel or reduction) contiguous in memory.
+  IteratorType contiguousDimType;
+  /// Layout (row or col) the loaded fragment is expected to have.
+  NVVM::MMALayout targetLayout;
+};
+
+FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
+                                            bool transpose);
+/// Returns an AffineMap which maps a single dimension representing the laneId
+/// to two results representing offsets within the matrix operand that should
+/// be the pointer locations a thread should pass to the ldmatrix instruction.
+FailureOr<AffineMap>
+getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+                               const LdMatrixParams &params);
+
+/// Transform contract into (m, k)x(n, k)x(m, n) form (A row-major, B
+/// column-major, C row-major) so that it can be converted to an mma.sync
+/// matmul. Operands are transposed and/or swapped as required; contractions
+/// in other layouts are not rewritten.
+struct PrepareContractToGPUMMASync
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+} // namespace nvgpu
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 9ed1c3483c11d..a6e122c380315 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -12,6 +12,7 @@
 
 #include <type_traits>
 
+#include "NvGpuSupport.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 
 #include "../PassDetail.h"
@@ -19,6 +20,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -27,11 +29,39 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
+/// Fills the initially empty `indices` with the indices of `xferOp`. Every
+/// index addressed by a dimension result of the op's permutation map is then
+/// shifted by the next result of `offsetMap`, consumed in order. Values for
+/// the dimensions of `offsetMap` come from `dimValues`; the original index is
+/// bound to one extra trailing dimension.
+template <typename TransferOpType>
+static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
+                           AffineMap offsetMap, ArrayRef<Value> dimValues,
+                           SmallVector<Value, 4> &indices) {
+  auto xferIndices = xferOp.getIndices();
+  indices.append(xferIndices.begin(), xferIndices.end());
+  Location loc = xferOp.getLoc();
+  // The original index is passed as the dimension just past offsetMap's own.
+  AffineExpr prevIdxDim = b.getAffineDimExpr(offsetMap.getNumDims());
+  unsigned nextOffset = 0;
+  for (AffineExpr result : xferOp.getPermutationMap().getResults()) {
+    auto dimExpr = result.dyn_cast<AffineDimExpr>();
+    if (!dimExpr)
+      continue;
+    unsigned pos = dimExpr.getPosition();
+    SmallVector<Value, 3> operands(dimValues.begin(), dimValues.end());
+    operands.push_back(indices[pos]);
+    indices[pos] = makeComposedAffineApply(
+        b, loc, prevIdxDim + offsetMap.getResult(nextOffset++), operands);
+  }
+}
+
 // Return true if the contract op can be convert to MMA matmul.
-static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
+static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
+                                          bool useNvGpu) {
   if (llvm::size(contract.getMasks()) != 0)
     return false;
 
@@ -47,7 +77,10 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
 
   // The contract needs to represent a matmul to be able to convert to
   // MMAMatrix matmul.
-  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
+  if (!useNvGpu &&
+      contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
+    return false;
+  if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}}))
     return false;
 
   return true;
@@ -61,7 +94,7 @@ getMemrefConstantHorizontalStride(ShapedType type) {
   if (!memrefType)
     return false;
   // If the memref is 0 or 1D the horizontal stride is 0.
-  if(memrefType.getRank() < 2)
+  if (memrefType.getRank() < 2)
     return 0;
   int64_t offset = 0;
   SmallVector<int64_t, 2> strides;
@@ -75,7 +108,8 @@ getMemrefConstantHorizontalStride(ShapedType type) {
 }
 
 // Return true if the transfer op can be converted to a MMA matrix load.
-static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
+static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
+                                              bool useNvGpu) {
   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
       readOp.getVectorType().getRank() != 2)
     return false;
@@ -87,9 +121,14 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
   AffineExpr zero = b.getAffineConstantExpr(0);
   auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
                                           readOp.getContext());
-  // TODO: Support transpose once it is added to GPU dialect ops.
-  // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
-  return !(!map.isMinorIdentity() && map != broadcastInnerDim);
+
+  if (!useNvGpu) {
+    // TODO: Support transpose once it is added to GPU dialect ops.
+    // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
+    return map.isMinorIdentity() || map == broadcastInnerDim;
+  }
+
+  return true;
 }
 
 // Return true if the transfer op can be converted to a MMA matrix store.
@@ -147,15 +186,15 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) {
   return convertElementwiseOpToMMA(op).hasValue();
 }
 
-static bool supportsMMaMatrixType(Operation *op) {
+static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadSupportsMMAMatrixType(transferRead);
+    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto contract = dyn_cast<vector::ContractionOp>(op))
-    return contractSupportsMMAMatrixType(contract);
+    return contractSupportsMMAMatrixType(contract, useNvGpu);
   if (auto constant = dyn_cast<arith::ConstantOp>(op))
     return constantSupportsMMAMatrixType(constant);
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
@@ -203,7 +242,8 @@ static SetVector<Operation *> getSliceContract(Operation *op,
 
 // Analyze slice of operations based on convert op to figure out if the whole
 // slice can be converted to MMA operations.
-static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
+static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
+                                             bool useNvGpu) {
   auto hasVectorDest = [](Operation *op) {
     return llvm::any_of(op->getResultTypes(),
                         [](Type t) { return t.isa<VectorType>(); });
@@ -221,8 +261,9 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
     // If any instruction cannot use MMA matrix type drop the whole
     // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
-    if (llvm::any_of(dependentOps,
-                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
+    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
+          return !supportsMMaMatrixType(op, useNvGpu);
+        }))
       return;
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
@@ -351,7 +392,7 @@ static const char *inferFragType(OpTy op) {
 static void convertTransferReadOp(vector::TransferReadOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
-  assert(transferReadSupportsMMAMatrixType(op));
+  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
   Optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
   AffineMap map = op.getPermutationMap();
@@ -386,6 +427,250 @@ static void convertTransferWriteOp(vector::TransferWriteOp op,
   op.erase();
 }
 
+/// Computes the per-thread vector type that holds one mma.sync matrix
+/// fragment: a 2-D vector of shape
+/// `numRegistersPerFragment x elementsPerRegister`. When the register's LLVM
+/// type is itself a vector, its scalar element type is used instead.
+static VectorType
+getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
+  Type scalarType = regInfo.registerLLVMType;
+  if (VectorType registerVecTy = scalarType.dyn_cast<VectorType>())
+    scalarType = registerVecTy.getElementType();
+
+  SmallVector<int64_t> fragmentShape;
+  fragmentShape.push_back(regInfo.numRegistersPerFragment);
+  fragmentShape.push_back(regInfo.elementsPerRegister);
+  return VectorType::get(fragmentShape, scalarType);
+}
+
+/// Converts a splat `arith.constant` that feeds the mma.sync path into a splat
+/// `arith.constant` of the per-thread fragment vector type (as computed by
+/// `getMmaSyncVectorOperandType`). Returns failure if the constant is not a
+/// splat or if its fragment layout cannot be determined.
+static LogicalResult
+convertConstantOpMmaSync(arith::ConstantOp op,
+                         llvm::DenseMap<Value, Value> &valueMapping) {
+  // Only splat constants can be redistributed without knowing which matrix
+  // elements each thread owns: every thread receives the same value.
+  auto splat = op.getValue().dyn_cast<SplatElementsAttr>();
+  if (!splat)
+    return failure();
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(regInfo))
+    return failure();
+
+  OpBuilder b(op);
+  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+  Value result = b.create<arith::ConstantOp>(
+      op.getLoc(), vectorType,
+      DenseElementsAttr::get(vectorType, splat.getSplatValue<Attribute>()));
+  valueMapping[op.getResult()] = result;
+  return success();
+}
+
+/// Lowers a `vector.transfer_read` to a single `nvgpu.ldmatrix` whose result is
+/// the per-thread mma.sync fragment. The read is treated as a transposed load
+/// when its permutation map is not the minor identity. Each lane's base
+/// coordinate is derived from `gpu.lane_id`. On success the fragment is
+/// recorded in `valueMapping` under the transfer's result.
+static LogicalResult
+creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
+                             llvm::DenseMap<Value, Value> &valueMapping) {
+  Location loc = op->getLoc();
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(regInfo))
+    return failure();
+
+  bool isTranspose = !op.getPermutationMap().isMinorIdentity();
+  FailureOr<nvgpu::LdMatrixParams> params =
+      nvgpu::getLdMatrixParams(*warpMatrixInfo, /*transpose=*/isTranspose);
+  if (failed(params)) {
+    return op->emitError()
+           << "failed to convert vector.transfer_read to ldmatrix; this op "
+              "likely should not be converted to a nvgpu.ldmatrix call.";
+  }
+
+  FailureOr<AffineMap> offsets =
+      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
+  if (failed(offsets))
+    return failure();
+
+  // Materialize the lane id only after every fallible step has succeeded, so
+  // that a failed conversion leaves no stray ops behind.
+  Value laneId = builder.create<gpu::LaneIdOp>(loc);
+  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+
+  // Adjust the load offset by the lane's position within the ldmatrix tiles.
+  SmallVector<Value, 4> indices;
+  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
+                                         indices);
+  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
+      loc, vectorType, op.getSource(), indices, isTranspose, params->numTiles);
+  valueMapping[op.getResult()] = newOp->getResult(0);
+  return success();
+}
+
+/// Lowers a `vector.transfer_read` that cannot use ldmatrix into per-lane
+/// loads that assemble the mma.sync fragment piece by piece:
+///  - non-transposed reads of row-layout operands (everything except B) use
+///    one `vector.load` per fragment register;
+///  - transposed reads of the col-layout B operand use one scalar
+///    `memref.load` per fragment element.
+/// Any other combination returns failure before any IR is created. On success
+/// the fragment is recorded in `valueMapping` under the transfer's result.
+static LogicalResult
+createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
+                       llvm::DenseMap<Value, Value> &valueMapping) {
+  Location loc = op.getLoc();
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(regInfo)) {
+    op->emitError() << "Failed to deduce register fragment type during "
+                       "conversion to distributed non-ldmatrix compatible load";
+    return failure();
+  }
+
+  NVVM::MMALayout targetLayout =
+      warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B
+          ? NVVM::MMALayout::col
+          : NVVM::MMALayout::row;
+  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
+  bool isVectorizedLoad =
+      !isTransposeLoad && targetLayout == NVVM::MMALayout::row;
+  bool isElementwiseLoad =
+      isTransposeLoad && targetLayout == NVVM::MMALayout::col;
+  // Reject unsupported layout combinations before creating any IR.
+  if (!isVectorizedLoad && !isElementwiseLoad)
+    return failure();
+
+  // The (laneId, valueId) -> (row, col) map does not depend on which element
+  // is being loaded, so compute it once instead of once per load.
+  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
+      loc, builder, *warpMatrixInfo);
+  if (failed(coords))
+    return failure();
+
+  Value laneId = builder.create<gpu::LaneIdOp>(loc);
+
+  // Per-register type reported by the fragment info.
+  Type loadedElType = regInfo->registerLLVMType;
+  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+  ArrayRef<int64_t> fragmentShape = vectorType.getShape();
+
+  // Start from a zero-filled fragment and insert each loaded piece into it.
+  Value fill = builder.create<arith::ConstantOp>(
+      loc, vectorType.getElementType(),
+      builder.getZeroAttr(vectorType.getElementType()));
+  Value result = builder.create<vector::SplatOp>(loc, fill, vectorType);
+
+  if (isVectorizedLoad) {
+    // One vector.load per fragment register.
+    if (!loadedElType.isa<VectorType>())
+      loadedElType = VectorType::get({1}, loadedElType);
+
+    for (int64_t i = 0; i < fragmentShape[0]; ++i) {
+      Value logicalValueId = builder.create<arith::ConstantOp>(
+          loc, builder.getIndexType(),
+          builder.getIndexAttr(i * regInfo->elementsPerRegister));
+      SmallVector<Value, 4> newIndices;
+      getXferIndices<vector::TransferReadOp>(
+          builder, op, *coords, {laneId, logicalValueId}, newIndices);
+
+      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
+                                                op.getSource(), newIndices);
+      result = builder.create<vector::InsertOp>(loc, el, result,
+                                                builder.getI64ArrayAttr(i));
+    }
+  } else {
+    // Transposed B operand: load each element individually.
+    if (auto vecType = loadedElType.dyn_cast<VectorType>())
+      loadedElType = vecType.getElementType();
+
+    for (int64_t i = 0; i < fragmentShape[0]; ++i) {
+      for (int64_t innerIdx = 0; innerIdx < fragmentShape[1]; ++innerIdx) {
+        Value logicalValueId = builder.create<arith::ConstantOp>(
+            loc, builder.getIndexType(),
+            builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
+        SmallVector<Value, 4> newIndices;
+        getXferIndices<vector::TransferReadOp>(
+            builder, op, *coords, {laneId, logicalValueId}, newIndices);
+        Value el = builder.create<memref::LoadOp>(loc, loadedElType,
+                                                  op.getSource(), newIndices);
+        result = builder.create<vector::InsertOp>(
+            loc, el, result, builder.getI64ArrayAttr({i, innerIdx}));
+      }
+    }
+  }
+
+  valueMapping[op.getResult()] = result;
+  return success();
+}
+
+/// Lowers a `vector.transfer_read` that feeds the mma.sync path. A single
+/// `nvgpu.ldmatrix` is used when the source memref is in memory space 3, the
+/// inferred tile width is 128 bits, and (for transposed reads) the tile
+/// geometry allows it. Otherwise the read is split into per-lane
+/// `vector.load` / `memref.load` ops. Only meant for the mma.sync lowering.
+static LogicalResult
+convertTransferReadToLoads(vector::TransferReadOp op,
+                           llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+
+  auto sourceType = op.getSource().getType().cast<MemRefType>();
+  bool canUseLdMatrix = sourceType.getMemorySpaceAsInt() == 3 &&
+                        nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
+
+  // A transposed read (e.g. of the B operand) can only use ldmatrix when it
+  // covers at least 8 rows and the transposed width spans 128 bits.
+  VectorType vecTy = op.getVectorType();
+  int64_t elementBits = vecTy.getElementType().getIntOrFloatBitWidth();
+  bool isTranspose = !op.getPermutationMap().isMinorIdentity();
+  bool transposeTooSmall =
+      vecTy.getDimSize(1) < 8 || vecTy.getDimSize(0) * elementBits < 128;
+  if (isTranspose && transposeTooSmall)
+    canUseLdMatrix = false;
+
+  if (canUseLdMatrix)
+    return creatLdMatrixCompatibleLoads(op, b, valueMapping);
+  return createNonLdMatrixLoads(op, b, valueMapping);
+}
+
+/// Lowers a `vector.transfer_write` of an mma.sync fragment into one
+/// `vector.store` per fragment register. Each store writes that register's
+/// elements at the coordinates owned by the current lane. The original
+/// transfer is erased on success.
+static LogicalResult
+convertTransferWriteToStores(vector::TransferWriteOp op,
+                             llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  Location loc = op->getLoc();
+
+  // The stored vector must already have been converted to a fragment.
+  // `lookup` yields a null Value here instead of dereferencing `end()`.
+  Value matrix = valueMapping.lookup(op.getVector());
+  if (!matrix)
+    return op->emitError()
+           << "vector operand was not converted to an mma.sync fragment";
+
+  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
+      nvgpu::getWarpMatrixInfo(op);
+  if (failed(warpMatrixInfo))
+    return failure();
+  FailureOr<nvgpu::FragmentElementInfo> regInfo =
+      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
+  if (failed(regInfo))
+    return failure();
+
+  // The (laneId, valueId) -> (row, col) map is the same for every register,
+  // so compute it once, before any IR is created.
+  FailureOr<AffineMap> coords =
+      nvgpu::getLaneIdAndValueIdToOperandCoord(loc, b, *warpMatrixInfo);
+  if (failed(coords))
+    return failure();
+
+  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
+  Value laneId = b.create<gpu::LaneIdOp>(loc);
+
+  for (int64_t i = 0, e = vectorType.getShape()[0]; i < e; ++i) {
+    Value logicalValueId = b.create<arith::ConstantOp>(
+        loc, b.getIndexType(),
+        b.getIndexAttr(i * regInfo->elementsPerRegister));
+    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
+    SmallVector<Value, 4> newIndices;
+    getXferIndices<vector::TransferWriteOp>(
+        b, op, *coords, {laneId, logicalValueId}, newIndices);
+    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
+  }
+  op->erase();
+  return success();
+}
+
 static void convertContractOp(vector::ContractionOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
   OpBuilder b(op);
@@ -397,6 +682,22 @@ static void convertContractOp(vector::ContractionOp op,
   valueMapping[op.getResult()] = matmul;
 }
 
+/// Lowers a `vector.contract` whose operands have already been converted to
+/// mma.sync fragments into a single `nvgpu.mma.sync` op. The mma shape comes
+/// from the original vector types: M and K from the LHS, and N from the
+/// leading dimension of the RHS (the RHS is expected here to be laid out as
+/// N x K). Emits a diagnostic and fails if any operand has no fragment.
+static LogicalResult
+convertContractOpToMmaSync(vector::ContractionOp op,
+                           llvm::DenseMap<Value, Value> &valueMapping) {
+  // `lookup` returns a null Value for an unconverted operand instead of
+  // dereferencing `end()` as `find(...)->second` would.
+  Value opA = valueMapping.lookup(op.getLhs());
+  Value opB = valueMapping.lookup(op.getRhs());
+  Value opC = valueMapping.lookup(op.getAcc());
+  if (!opA || !opB || !opC)
+    return op->emitError()
+           << "operands of vector.contract were not converted to mma.sync "
+              "fragments";
+
+  ArrayRef<int64_t> lhsShape =
+      op.getLhs().getType().cast<VectorType>().getShape();
+  ArrayRef<int64_t> rhsShape =
+      op.getRhs().getType().cast<VectorType>().getShape();
+  int64_t m = lhsShape[0];
+  int64_t n = rhsShape[0];
+  int64_t k = lhsShape[1];
+
+  OpBuilder b(op);
+  Value matmul = b.create<nvgpu::MmaSyncOp>(
+      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
+  valueMapping[op.getResult()] = matmul;
+  return success();
+}
+
 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
 static void convertConstantOp(arith::ConstantOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
@@ -509,13 +810,20 @@ static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
   valueMapping[op->getResult(0)] = newOp;
 }
 
-void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
-  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
-      patterns.getContext());
+/// Registers the preparation patterns for the selected lowering:
+/// `nvgpu::PrepareContractToGPUMMASync` for the mma.sync path, or
+/// `PrepareContractToGPUMMA` for the WMMA path. Both paths also register
+/// `CombineTransferReadOpTranspose`.
+void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
+                                              bool useNvGpu) {
+  MLIRContext *ctx = patterns.getContext();
+  if (useNvGpu) {
+    patterns.add<nvgpu::PrepareContractToGPUMMASync,
+                 CombineTransferReadOpTranspose>(ctx);
+    return;
+  }
+  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(ctx);
 }
 
 void mlir::convertVectorToMMAOps(Operation *rootOp) {
-  SetVector<Operation *> ops = getOpToConvert(rootOp);
+  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
   llvm::DenseMap<Value, Value> valueMapping;
   for (Operation *op : ops) {
     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
@@ -538,21 +846,71 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) {
   }
 }
 
+/// Rewrites every op in the MMA-compatible slices under `rootOp` into its
+/// nvgpu mma.sync form, visiting ops in the order returned by
+/// `getOpToConvert`. The first op that cannot be converted is reported, and
+/// the function returns failure.
+LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
+  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
+  llvm::DenseMap<Value, Value> valueMapping;
+
+  // Dispatches a single op to the conversion routine for its kind.
+  auto convertSingleOp = [&](Operation *op) -> LogicalResult {
+    return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+        .Case([&](vector::TransferReadOp readOp) {
+          return convertTransferReadToLoads(readOp, valueMapping);
+        })
+        .Case([&](vector::TransferWriteOp writeOp) {
+          return convertTransferWriteToStores(writeOp, valueMapping);
+        })
+        .Case([&](vector::ContractionOp contractOp) {
+          return convertContractOpToMmaSync(contractOp, valueMapping);
+        })
+        .Case([&](scf::ForOp loopOp) {
+          convertForOp(loopOp, valueMapping);
+          return success();
+        })
+        .Case([&](scf::YieldOp terminator) {
+          convertYieldOp(terminator, valueMapping);
+          return success();
+        })
+        .Case([&](arith::ConstantOp constantOp) {
+          return convertConstantOpMmaSync(constantOp, valueMapping);
+        })
+        .Default([](Operation *unhandled) {
+          unhandled->emitError()
+              << "unhandled vector to mma type: " << *unhandled;
+          return failure();
+        });
+  };
+
+  for (Operation *op : ops) {
+    if (succeeded(convertSingleOp(op)))
+      continue;
+    op->emitError() << "Failed to convert op " << *op;
+    return failure();
+  }
+  return success();
+}
+
 namespace {
 
+/// Pass that prepares vector contractions and lowers them either to GPU WMMA
+/// ops (default) or, when the `use-nvgpu` option is set, to nvgpu mma.sync ops.
 struct ConvertVectorToGPUPass
     : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
+
+  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
+    useNvGpu.setValue(useNvGpu_);
+  }
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populatePrepareVectorToMMAPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+
+    // The two lowerings are mutually exclusive: once the mma.sync path has
+    // run, the WMMA lowering must not be applied to the leftover vector ops.
+    if (useNvGpu.getValue()) {
+      if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
+        return signalPassFailure();
+      return;
+    }
 
-    convertVectorToMMAOps(getOperation());
+    convertVectorToMMAOps(getOperation());
   }
 };
 
 } // namespace
 
-std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
-  return std::make_unique<ConvertVectorToGPUPass>();
+std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
+  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
 }

diff  --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
new file mode 100644
index 0000000000000..be8d08be06ce6
--- /dev/null
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -0,0 +1,349 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="func.func(convert-vector-to-gpu{use-nvgpu=true})" | FileCheck %s
+
+//#########################################################
+// INT8 row-row-row
+//#########################################################
+// Lowers an m16n8k32 i8 x i8 -> i32 contraction. A is read from memory space 3
+// with nvgpu.ldmatrix, the transposed B operand is read one element at a time
+// with memref.load, and the i32 accumulator is read and written with
+// 2-element vector.load / vector.store ops.
+
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)>
+
+// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)>
+// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
+// CHECK-DAG: [[$rowB1_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 40)>
+// CHECK-DAG: [[$rowB2_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 41)>
+// CHECK-DAG: [[$rowB3_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 42)>
+// CHECK-DAG: [[$rowB4_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 55)>
+// CHECK-DAG: [[$rowB5_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 56)>
+// CHECK-DAG: [[$rowB6_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 57)>
+// CHECK-DAG: [[$rowB7_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 58)>
+
+// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)>
+// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 57)>
+
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n8k32_int8_row_row_row
+func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8, 3>, %arg2: memref<128x128xi32>) {
+  %cst_0 = arith.constant dense<0> : vector<32x8xi8>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c17 = arith.constant 17 : index  
+  %c39 = arith.constant 39 : index  
+  %c40 = arith.constant 40 : index  
+  %c49 = arith.constant 49 : index  
+  %c50 = arith.constant 50 : index  
+  %cst = arith.constant 0 : i8
+  %cst0 = arith.constant 0 : i32
+
+  // Verify that the operand A is distributed to loads correctly.
+
+  // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[{{%.+}}]
+  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
+
+  // Verify that the operand B is distributed to loads correctly. Its elements
+  // must be loaded in a non-vectorized manner to do the transpose.
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB1_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB2_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB3_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB4_map]]()[{{%.+}}]  
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB5_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB6_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB7_map]]()[{{%.+}}]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
+  // CHECK-NOT: memref.load %arg1
+
+  // Verify that the operand C is distributed to loads correctly.
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK-NOT: vector.load %arg2{{.*}}
+
+  %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
+  %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, 3>, vector<8x32xi8>
+  %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
+
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}]
+  // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
+  // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+  vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
+  return
+}
+
+// -----
+
+//#########################################################
+// f64 row-row-row
+//#########################################################
+// None of the memrefs are in memory space 3, so no ldmatrix is used: A is read
+// with 1-element vector.load ops, the transposed B operand with scalar
+// memref.load ops, and C/D with 2-element vector.load / vector.store ops.
+// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 1)>
+// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 1)>
+
+// CHECK-DAG: [[$rowb0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 39)>
+// CHECK-DAG: [[$colb0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
+
+// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)>
+// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)>
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m8n8k4_f64_row_row_row
+func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x128xf64>, %arg2: memref<128x128xf64>) {
+  %cst_0 = arith.constant dense<0.0> : vector<4x8xf64>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c17 = arith.constant 17 : index
+  %c39 = arith.constant 39 : index
+  %c40 = arith.constant 40 : index
+  %c49 = arith.constant 49 : index
+  %c50 = arith.constant 50 : index
+  %cst = arith.constant 0.0 : f64
+  %cst0 = arith.constant 0.0 : f64
+
+  // Verify that the operand A is distributed to loads correctly.
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
+  // CHECK: vector.load %arg0[[[row]], [[col]]] : memref<128x128xf64>, vector<1xf64>
+
+  // Verify that the operand B is distributed to loads correctly. Its elements
+  // must be loaded in a non-vectorized manner to do the transpose.
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowb0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colb0_map]]
+  // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xf64>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]]
+  // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>
+
+  %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64>
+  %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64>
+  %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64>
+  // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]]
+  // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>
+  vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<8x8xf64>, memref<128x128xf64>
+  return
+}
+
+// -----
+
+//#########################################################
+// FP16 row-row-row
+//#########################################################
+// A (16x16) is read with a 4-tile non-transposed ldmatrix; the B operand, read
+// through permutation map #map0, uses a 2-tile ldmatrix with transpose = true.
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<() -> (3)>
+
+// CHECK-LABEL: func @m16n8k16_fp16_row_row_row
+func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}  
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true}
+  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
+  return
+}
+
+// -----
+
+// Batched variant of the FP16 row-row-row case on rank-3 memrefs: the leading
+// (batch) index is forwarded unchanged ([[C0]] below) while the two trailing
+// indices are distributed per lane for every ldmatrix.
+// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
+// CHECK-DAG: [[$Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @batch_m16n8k16_fp16_row_row_row
+func.func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1: memref<2x20x20xf16, 3>, %arg2: memref<2x20x20xf16, 3>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16>
+  // CHECK: [[C0:%.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
+  %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
+  
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Brow_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Bcol_map]]  
+  // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
+  
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
+  // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+  %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
+  return
+}
+
+// -----
+
+//#########################################################
+// FP16 row-col-row
+//#########################################################
+// B is read without a permutation map, so unlike the row-row-row case both A
+// and B are loaded with non-transposed ldmatrix ops (4 tiles for A, 2 for B).
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+
+// CHECK: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)>
+// CHECK: [[$colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)>
+
+// CHECK-LABEL: func @m16n8k16_fp16_row_col_row
+func.func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32
+  // CHECK-SAME: transpose = false
+  
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32
+  // CHECK-SAME: transpose = false
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]   
+  // CHECK: nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32
+  // CHECK-SAME: transpose = false
+  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
+  return
+}
+
+// -----
+
+//#########################################################
+// TF32 (multiplicand) F32 (accumulator) row-row-row
+//#########################################################
+// A (16x4), B and the 16x8 result are all laid out row-major in memory. B is
+// stored as a 4x8 (KxN) tile and is transposed into the 8x4 (NxK) vector
+// operand of the contraction by the #map0 permutation map.
+
+// Transposing 2-D permutation map; used for B's vector.transfer_read.
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+// Contraction indexing maps with d0 = M, d1 = N (parallel) and d2 = K
+// (reduction): A is (M, K), B is (N, K), and the accumulator/result is (M, N).
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// Expected index maps emitted by the lowering. s0 is presumably the lane id
+// within the warp (per the commit description) -- confirm against the pass.
+// The constant offsets match the transfer_read base indices ([1, 3] for A,
+// [3, 3] for B); the result maps carry no offset, matching the write at [0, 0].
+// A: lanes 0-15 each select one row; lanes 16-31 are shifted by 4 columns.
+// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 4 + 3)>
+
+// B: row = s0 mod 4, column = s0 floordiv 4 (before the +3 offsets).
+// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 3)>
+// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 3)>
+
+// Result: row = s0 floordiv 4 (plus 8 for the second fragment row), and
+// column = 2 * (s0 mod 4), which is what `s0 * 2 - (s0 floordiv 4) * 8`
+// simplifies to.
+// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
+
+// CHECK-LABEL: func @m16n8k4_tf32_f32_row_row_row
+// Lowers a 16x4 (A) x 8x4 (B) -> 16x8 f32 vector.contract with a zero
+// accumulator to nvgpu.mma.sync with mmaShape = [16, 8, 4]:
+//  - A is loaded with a two-tile, non-transposed nvgpu.ldmatrix;
+//  - B, read through the transposing #map0, is loaded element by element with
+//    memref.load and packed into a vector<1x1xf32> fragment;
+//  - the zero accumulator becomes a vector<2x2xf32> constant fragment;
+//  - the result fragment is written back with one vector.store per fragment
+//    row (rows offset by 0 and by 8).
+func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
+
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
+  // CHECK: [[a_frag:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
+
+  // Only A goes through ldmatrix. B is loaded element-wise below, and the
+  // accumulator is the zero constant above, so it is never loaded at all.
+  // CHECK-NOT: nvgpu.ldmatrix
+
+  // NOTE(review): the memref.load operands are matched with a wildcard, so
+  // the row/col captures below are not tied to the load's indices.
+  // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
+  // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
+  // CHECK: [[b_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
+  // CHECK: [[b_frag:%.+]] = vector.insert [[b_el]], {{.*}} : f32 into vector<1x1xf32>
+
+  // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
+  // CHECK-SAME: mmaShape = [16, 8, 4]
+  // CHECK-SAME: -> vector<2x2xf32>
+  %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<16x4xf32>
+  %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x4xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x4xf32>, vector<8x4xf32> into vector<16x8xf32>
+
+  // Each row of the vector<2x2xf32> result fragment is stored separately:
+  // fragment row 0 uses rowC_map, fragment row 1 uses rowC8_map (same rows + 8).
+  // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
+  // CHECK: affine.apply [[$rowC8_map]]
+  // CHECK: affine.apply [[$colC_map]]
+  // CHECK: vector.store
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
+  return
+}


        


More information about the Mlir-commits mailing list