[Mlir-commits] [mlir] 894a591 - [mlir][nvgpu] Move mma.sync and ldmatrix in nvgpu dialect

Thomas Raoux llvmlistbot at llvm.org
Thu Apr 14 16:46:25 PDT 2022


Author: Thomas Raoux
Date: 2022-04-14T23:44:52Z
New Revision: 894a591cf6fc542e6fc5d84222c839495a3d832f

URL: https://github.com/llvm/llvm-project/commit/894a591cf6fc542e6fc5d84222c839495a3d832f
DIFF: https://github.com/llvm/llvm-project/commit/894a591cf6fc542e6fc5d84222c839495a3d832f.diff

LOG: [mlir][nvgpu] Move mma.sync and ldmatrix in nvgpu dialect

Move the gpu operations mma.sync and ldmatrix into the nvgpu dialect, as they
are specific to the NVIDIA target.

Differential Revision: https://reviews.llvm.org/D123824

Added: 
    mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
    mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/Dialect/NVGPU/NVGPU.td
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/test/Dialect/NVGPU/roundtrip.mlir

Removed: 
    mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir


################################################################################
diff  --git a/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h b/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
new file mode 100644
index 0000000000000..05dd975322622
--- /dev/null
+++ b/mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
@@ -0,0 +1,29 @@
+//===- NVGPUToNVVM.h - Convert NVGPU to NVVM dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVM_H_
+#define MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVM_H_
+
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+/// Populates `patterns` with the NVGPU-to-NVVM conversion patterns, using
+/// `converter` to convert operand and result types.
+void populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
+                                           RewritePatternSet &patterns);
+
+/// Creates the `convert-nvgpu-to-nvvm` pass.
+std::unique_ptr<Pass> createConvertNVGPUToNVVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVM_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ae05446b7b236..0e9b4fc9f233e 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -35,6 +35,7 @@
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5dcf9d7415964..d42e7dd4be871 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -506,6 +506,22 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPUToNVVM
+//===----------------------------------------------------------------------===//
+
+def ConvertNVGPUToNVVM : Pass<"convert-nvgpu-to-nvvm"> {
+  let summary = "Convert NVGPU dialect to NVVM dialect";
+  let description = [{
+    This pass converts supported NVGPU ops, such as `nvgpu.ldmatrix` and
+    `nvgpu.mma.sync`, to NVVM dialect intrinsics.
+  }];
+  let constructor = "mlir::createConvertNVGPUToNVVMPass()";
+  let dependentDialects = [
+    "NVVM::NVVMDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // OpenACCToSCF
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 078a66dc821b2..f93a32384becc 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1369,58 +1369,4 @@ def GPU_DeviceAsyncWaitOp : GPU_Op<"device_async_wait", []> {
   }];
 }
 
-def GPU_MmaLdMatrixOp : GPU_Op<"mma.ldmatrix", 
-                                [MemoryEffects<[MemRead]>]> {
-  let description = [{
-  The `gpu.mma.ldmatrix` op represents loading a matrix fragment from
-  memory. The load source and result type must be compatible with lowering 
-  to the `nvvm.ldmatrix` instruction. This op is meant to represent 
-  the distributed version of a `vector.transfer_read` as an intermediate 
-  step between lowering from `vector.transfer_read` to `nvvm.ldmatrix`.
-  
-  Example:
-
-  ```mlir
-  gpu.mma.ldmatrix %shm_buffer[%c0, %c0] : memref<16x16xf16, 3> -> vector<4x2xf16>
-  ```
-  }];                                
-
-  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
-                      Variadic<Index>:$indices, BoolAttr:$transpose,
-                      I32Attr:$numTiles);
-  let results = (outs AnyVector:$res);
-  let assemblyFormat = [{
-    $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
-  }];                      
-}
-
-def GPU_MmaSyncOp : GPU_Op<"mma.sync", [NoSideEffect]> {
-  let description = [{
-  The `gpu.mma.sync` op represents the distributed form of a collective
-  matrix-multiply-and-accumulate (mma) operation that is compatible with
-  `nvvm.mma.sync`. The operands and results are fragments of the full matrix 
-  operands. The full shape of the distributed mma operation is given by the
-  `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`.  
-
-  This operation is meant to be lowered to the `nvvm.mma.sync` instruction, and
-  is an intermediate point between lowering from `vector.contract` to
-  `nvvm.mma.sync`.
-  
-  Example:
-
-  ```mlir
-  gpu.mma.sync (%a, %b, %c) : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
-  ```
-  }];   
-  let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB, AnyVector:$matrixC,
-    I64ArrayAttr:$mmaShape);
-
-  let results = (outs AnyVector:$res);
-
-  let assemblyFormat = [{
-    `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
-    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
-  }];
-}
-
 #endif // GPU_OPS

diff  --git a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
index 9ed34ace009b6..28147e3fe4aae 100644
--- a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/NVGPU.td
@@ -69,4 +69,37 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix",
   }];
 }
 
+def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
+  let description = [{
+  The `nvgpu.mma.sync` op represents the distributed form of a collective
+  matrix-multiply-and-accumulate (mma) operation that is compatible with
+  `nvvm.mma.sync`. The operands and results are fragments of the full matrix
+  operands. The full shape of the distributed mma operation is given by the
+  `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`.
+
+  This operation is meant to be lowered to the `nvvm.mma.sync` instruction, and
+  is an intermediate point between lowering from `vector.contract` to
+  `nvvm.mma.sync`.
+
+  This operation follows the semantics described here:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma
+
+  Example:
+
+  ```mlir
+  nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} :
+    (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  ```
+  }];
+  let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
+                       AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
+
+  let results = (outs AnyVector:$res);
+
+  let assemblyFormat = [{
+    `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
+    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
+  }];
+}
+
 #endif // NVGPU

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 5ef84273cbe0c..533d61acea414 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@ add_subdirectory(MathToLLVM)
 add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
+add_subdirectory(NVGPUToNVVM)
 add_subdirectory(OpenACCToLLVM)
 add_subdirectory(OpenACCToSCF)
 add_subdirectory(OpenMPToLLVM)

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index e5145f6513fdf..6f0e585365e29 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -208,290 +208,6 @@ struct GPUAsyncWaitLowering
   }
 };
 
-struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp> {
-  using ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::MmaLdMatrixOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *ctx = getContext();
-    Location loc = op->getLoc();
-
-    // The result type of ldmatrix will always be a struct of 32bit integer
-    // registers if more than one 32bit value is returned. Otherwise, the result
-    // is a single i32. The result type of the GPU operation is always a vector
-    // of shape (NumRegisters, VectorRegister) where VectorRegister is the
-    // vector type of the result and always 32 bits long. We bitcast the result
-    // of the NVVM::LdMatrix to this vector type.
-    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
-    if (!vectorResultType) {
-      return failure();
-    }
-    Type innerVectorType = LLVM::getFixedVectorType(
-        vectorResultType.getElementType(), vectorResultType.getDimSize(1));
-
-    int64_t num32BitRegs = vectorResultType.getDimSize(0);
-
-    Type ldMatrixResultType;
-    if (num32BitRegs > 1) {
-      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
-          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
-    } else {
-      ldMatrixResultType = rewriter.getI32Type();
-    }
-
-    auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
-    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
-                                        adaptor.indices(), rewriter);
-    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
-        loc, ldMatrixResultType, srcPtr,
-        /*num=*/op.numTiles(),
-        /*layout=*/op.transpose() ? NVVM::MMALayout::col
-                                  : NVVM::MMALayout::row);
-
-    // The ldmatrix operation returns either a single i32 value or a struct of
-    // i32 values. Here we unpack those values and cast them back to their
-    // actual vector type (still of width 32b) and repack them into a result
-    // struct.
-    Type finalResultType = typeConverter->convertType(vectorResultType);
-    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
-    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
-      Value i32Register = num32BitRegs > 1
-                              ? rewriter.create<LLVM::ExtractValueOp>(
-                                    loc, rewriter.getI32Type(), ldMatrixResult,
-                                    rewriter.getI64ArrayAttr(i))
-                              : ldMatrixResult;
-      Value casted =
-          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
-      result = rewriter.create<LLVM::InsertValueOp>(
-          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
-    }
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-/// Checks if all the operands of the op being lowered are of LLVM Types. The
-/// types are expected to be converted by the `LLVMTypeConverter` before the
-/// op is actually lowered. If the type of an operands is not already
-/// converted it hints a missing typeConversion and failure is returned in
-/// that case.
-LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
-                              ConversionPatternRewriter &rewriter) {
-  if (!llvm::all_of(operands, [](Value value) {
-        return LLVM::isCompatibleType(value.getType());
-      })) {
-    return rewriter.notifyMatchFailure(
-        op, "cannot convert if operands aren't of LLVM type.");
-  }
-
-  return success();
-}
-
-/// Returns the type for the intrinsic given the vectorResultType of the
-/// `gpu.mma.sync` operation.
-Type inferIntrinsicResultType(Type vectorResultType) {
-  MLIRContext *ctx = vectorResultType.getContext();
-  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
-  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
-  auto i32Ty = IntegerType::get(ctx, 32);
-  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
-  Type f64Ty = Float64Type::get(ctx);
-  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
-  if (a.getElementType() == f16x2Ty) {
-    return LLVM::LLVMStructType::getLiteral(
-        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
-  }
-  if (a.getElementType() == i32x2Ty) {
-    return LLVM::LLVMStructType::getLiteral(
-        ctx,
-        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
-  }
-  if (a.getElementType() == f64x2Ty) {
-    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
-  }
-  return vectorResultType;
-}
-
-/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
-/// always an LLVM struct) into a fragment that is compatible with the vector
-/// type of this operation. This involves extracting elements from the struct
-/// and inserting them into an LLVM array. These extra data-movement
-/// operations should be canonicalized away by the LLVM backend.
-Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
-                             Type resultType, Value intrinsicResult,
-                             RewriterBase &rewriter) {
-  MLIRContext *ctx = rewriter.getContext();
-  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
-  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
-  Type i32Ty = rewriter.getI32Type();
-  Type f64Ty = rewriter.getF64Type();
-  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
-  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
-  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
-
-  auto makeConst = [&](int32_t index) -> Value {
-    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
-                                             rewriter.getI32IntegerAttr(index));
-  };
-
-  if (arrayType) {
-    SmallVector<Value, 4> elements;
-
-    if (arrayType.getElementType() == f16x2Ty) {
-      for (unsigned i = 0; i < structType.getBody().size(); i++) {
-        elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
-            loc, structType.getBody()[i], intrinsicResult,
-            rewriter.getI64ArrayAttr(i)));
-      }
-    }
-
-    // The intrinsic returns i32 and f64 values as individual scalars. We need
-    // to extract them from the struct and pack them into vectors.
-    if (arrayType.getElementType() == i32x2Ty ||
-        arrayType.getElementType() == f64x2Ty) {
-      Value vec =
-          rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
-      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
-        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
-            loc, structType.getBody()[i * 2], intrinsicResult,
-            rewriter.getI64ArrayAttr(i * 2));
-        Value x2 = rewriter.create<LLVM::ExtractValueOp>(
-            loc, structType.getBody()[i * 2 + 1], intrinsicResult,
-            rewriter.getI64ArrayAttr(i * 2 + 1));
-        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
-                                                     x1, makeConst(0));
-        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
-                                                     x2, makeConst(1));
-      }
-      elements.push_back(vec);
-    }
-
-    // Create the final vectorized result.
-    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
-    for (const auto &el : llvm::enumerate(elements)) {
-      result = rewriter.create<LLVM::InsertValueOp>(
-          loc, arrayType, result, el.value(),
-          rewriter.getI64ArrayAttr(el.index()));
-    }
-    return result;
-  }
-
-  return intrinsicResult;
-}
-
-/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
-/// given as 2D `vectors` where the rows are 32b or 64b wide. The
-/// `nvvm.mma.sync` op expects these argments to be a given in a long list of
-/// scalars of certain types. This function helps unpack the `vector` arguments
-/// and cast them to the types expected by `nvvm.mma.sync`.
-SmallVector<Value> unpackOperandVector(RewriterBase &rewriter, Location loc,
-                                       Value operand) {
-  SmallVector<Value> result;
-  Type i32Ty = rewriter.getI32Type();
-  Type f64Ty = rewriter.getF64Type();
-  Type i8Ty = rewriter.getI8Type();
-  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
-  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
-
-  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
-    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
-        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
-
-    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
-    // scalar types.
-    if (arrayTy.getElementType() == i8x4Ty) {
-      result.push_back(
-          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
-      continue;
-    }
-
-    // For some element types (i32, f64), we need to unpack the inner
-    // vector/array type as well because the intrinsic expects individual
-    // scalars to be provided.
-    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
-    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
-                         innerArrayTy.getElementType() == f64Ty)) {
-      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
-           idx < innerSize; idx++) {
-        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
-            loc, toUse,
-            rewriter.create<LLVM::ConstantOp>(
-                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
-      }
-      continue;
-    }
-    result.push_back(toUse);
-  }
-  return result;
-}
-
-struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<gpu::MmaSyncOp> {
-  using ConvertOpToLLVMPattern<gpu::MmaSyncOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::MmaSyncOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) {
-      return failure();
-    }
-
-    // Get the shapes of the MMAMatrix type being used. The shapes will
-    // choose which intrinsic this op will be lowered to.
-    auto aType = op.matrixA().getType().cast<VectorType>();
-
-    int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
-    int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
-    int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
-    std::array<int64_t, 3> gemmShape{m, n, k};
-
-    SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.matrixA());
-    SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.matrixB());
-    SmallVector<Value> matC =
-        unpackOperandVector(rewriter, loc, adaptor.matrixC());
-
-    NVVM::MMATypes ptxTypeA;
-    NVVM::MMATypes ptxTypeB;
-    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
-    if (aType.getElementType().isInteger(8)) {
-      ptxTypeA = NVVM::MMATypes::s8;
-      ptxTypeB = NVVM::MMATypes::s8;
-      overflow = NVVM::MMAIntOverflow::satfinite;
-
-    } else if (aType.getElementType().isF16()) {
-      ptxTypeA = NVVM::MMATypes::f16;
-      ptxTypeB = NVVM::MMATypes::f16;
-    } else if (aType.getElementType().isF64()) {
-      ptxTypeA = NVVM::MMATypes::f64;
-      ptxTypeB = NVVM::MMATypes::f64;
-    } else {
-      return op->emitError("could not deduce operand PTX types");
-    }
-
-    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
-    Type intrinsicResTy = inferIntrinsicResultType(
-        typeConverter->convertType(op->getResultTypes()[0]));
-    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
-        op.getLoc(), intrinsicResTy, matA, matB, matC,
-        /*shape=*/gemmShape,
-        /*b1Op=*/llvm::None,
-        /*intOverflow=*/overflow,
-        /*multiplicandPtxTypes=*/
-        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
-        /*multiplicandLayouts=*/
-        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
-                                       NVVM::MMALayout::col});
-    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
-                                                  desiredRetTy, intrinsicResult,
-                                                  rewriter));
-    return success();
-  }
-};
-
 struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
 
@@ -611,8 +327,8 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                        NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
            GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                        NVVM::GridDimYOp, NVVM::GridDimZOp>,
-           GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering,
-           MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter);
+           GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
+          converter);
 
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default

diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
new file mode 100644
index 0000000000000..50750e478095b
--- /dev/null
+++ b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Conversion library lowering the NVGPU dialect to the NVVM dialect.
+add_mlir_conversion_library(MLIRNVGPUToNVVM
+  NVGPUToNVVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/NVGPUToNVVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMCommonConversion
+  MLIRLLVMIR
+  MLIRNVGPU
+  MLIRNVVMIR
+  MLIRPass
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
new file mode 100644
index 0000000000000..a304e49e58387
--- /dev/null
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -0,0 +1,308 @@
+//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
+#include "../PassDetail.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+
+using namespace mlir;
+
+/// Returns the LLVM struct type produced by the `nvvm.mma.sync` intrinsic for
+/// `vectorResultType`, the converted (LLVM array) result of `nvgpu.mma.sync`.
+/// f16x2 rows become f16x2 members, i32x2 rows are flattened into i32 members,
+/// and an f64x2 row becomes two f64 members. Arrays with any other element
+/// type are returned unchanged; non-array types fail the `cast` below.
+static Type inferIntrinsicResultType(Type vectorResultType) {
+  MLIRContext *ctx = vectorResultType.getContext();
+  auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
+  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+  auto i32Ty = IntegerType::get(ctx, 32);
+  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+  Type f64Ty = Float64Type::get(ctx);
+  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+  if (a.getElementType() == f16x2Ty)
+    return LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
+  if (a.getElementType() == i32x2Ty)
+    return LLVM::LLVMStructType::getLiteral(
+        ctx,
+        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
+  if (a.getElementType() == f64x2Ty)
+    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
+  return vectorResultType;
+}
+
+/// Converts `intrinsicResult`, the struct-typed SSA result of `nvvm.mma.sync`
+/// (of type `intrinsicResultType`), into a value of `resultType`, the converted
+/// LLVM array type of the `nvgpu.mma.sync` result, by extracting struct members
+/// and inserting them into an array. A non-array `resultType` returns
+/// `intrinsicResult` unchanged. The extra data movement should fold in LLVM.
+static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
+                                    Type resultType, Value intrinsicResult,
+                                    RewriterBase &rewriter) {
+  MLIRContext *ctx = rewriter.getContext();
+  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
+  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
+  Type i32Ty = rewriter.getI32Type();
+  Type f64Ty = rewriter.getF64Type();
+  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
+  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+
+  auto makeConst = [&](int32_t index) -> Value {
+    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
+                                             rewriter.getI32IntegerAttr(index));
+  };
+
+  if (arrayType) {
+    SmallVector<Value, 4> elements;
+
+    if (arrayType.getElementType() == f16x2Ty) {
+      for (unsigned i = 0; i < structType.getBody().size(); i++) {
+        elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
+            loc, structType.getBody()[i], intrinsicResult,
+            rewriter.getI64ArrayAttr(i)));
+      }
+    }
+
+    // The intrinsic returns i32 and f64 values as individual scalars. Pack each
+    // consecutive pair into a fresh 2-element vector, one per array element.
+    if (arrayType.getElementType() == i32x2Ty ||
+        arrayType.getElementType() == f64x2Ty) {
+      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
+        Value vec =
+            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
+        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
+            loc, structType.getBody()[i * 2], intrinsicResult,
+            rewriter.getI64ArrayAttr(i * 2));
+        Value x2 = rewriter.create<LLVM::ExtractValueOp>(
+            loc, structType.getBody()[i * 2 + 1], intrinsicResult,
+            rewriter.getI64ArrayAttr(i * 2 + 1));
+        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+                                                     x1, makeConst(0));
+        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+                                                     x2, makeConst(1));
+        elements.push_back(vec);
+      }
+    }
+
+    // Create the final vectorized result.
+    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
+    for (const auto &el : llvm::enumerate(elements)) {
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, arrayType, result, el.value(),
+          rewriter.getI64ArrayAttr(el.index()));
+    }
+    return result;
+  }
+
+  return intrinsicResult;
+}
+
+/// Unpacks an `nvgpu.mma.sync` matrix fragment operand for `nvvm.mma.sync`.
+/// `operand` is the converted (LLVM array) form of a 2-D vector whose rows are
+/// 32b or 64b wide, whereas `nvvm.mma.sync` expects a flat list of scalars.
+/// Rows of 4xi8 are bitcast to i32, rows of i32 or f64 elements are split into
+/// individual scalars, and any other row is passed through unchanged.
+/// Asserts (via `cast`) if `operand` is not of LLVM array type.
+static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
+                                              Location loc, Value operand) {
+  SmallVector<Value> result;
+  Type i32Ty = rewriter.getI32Type();
+  Type f64Ty = rewriter.getF64Type();
+  Type i8Ty = rewriter.getI8Type();
+  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
+
+  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
+    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
+
+    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
+    // scalar types.
+    if (arrayTy.getElementType() == i8x4Ty) {
+      result.push_back(rewriter.create<LLVM::BitcastOp>(loc, i32Ty, toUse));
+      continue;
+    }
+
+    // For some element types (i32, f64), we need to unpack the inner
+    // vector/array type as well because the intrinsic expects individual
+    // scalars to be provided.
+    VectorType innerVectorTy = arrayTy.getElementType().dyn_cast<VectorType>();
+    if (innerVectorTy && (innerVectorTy.getElementType() == i32Ty ||
+                          innerVectorTy.getElementType() == f64Ty)) {
+      for (unsigned idx = 0, innerSize = innerVectorTy.getNumElements();
+           idx < innerSize; idx++) {
+        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
+            loc, toUse,
+            rewriter.create<LLVM::ConstantOp>(
+                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+      }
+      continue;
+    }
+    result.push_back(toUse);
+  }
+  return result;
+}
+
+namespace {
+
+/// Lowers `nvgpu.ldmatrix` to `nvvm.ldmatrix`, bitcasting each returned i32
+/// register into one row of the op's 2-D vector result.
+struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
+  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = getContext();
+    Location loc = op->getLoc();
+
+    // The result type of ldmatrix will always be a struct of 32bit integer
+    // registers if more than one 32bit value is returned. Otherwise, the result
+    // is a single i32. The result of the NVGPU op is a 2-D vector of shape
+    // (NumRegisters, VectorRegister) whose rows are 32 bits wide; each NVVM
+    // register is bitcast to that row type. Reject other ranks up front rather
+    // than asserting in getDimSize(1).
+    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
+    if (!vectorResultType || vectorResultType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "expected a 2-D vector result");
+    Type innerVectorType = LLVM::getFixedVectorType(
+        vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+
+    int64_t num32BitRegs = vectorResultType.getDimSize(0);
+
+    Type ldMatrixResultType;
+    if (num32BitRegs > 1) {
+      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
+          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
+    } else {
+      ldMatrixResultType = rewriter.getI32Type();
+    }
+
+    auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
+    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
+                                        adaptor.indices(), rewriter);
+    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
+        loc, ldMatrixResultType, srcPtr,
+        /*num=*/op.numTiles(),
+        /*layout=*/op.transpose() ? NVVM::MMALayout::col
+                                  : NVVM::MMALayout::row);
+
+    // ldmatrix returns either a single i32 or a struct of i32s. Unpack them,
+    // bitcast each back to its 32-bit row vector type, and repack the rows
+    // into the converted result type.
+    Type finalResultType = typeConverter->convertType(vectorResultType);
+    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+    for (int64_t i = 0; i < num32BitRegs; i++) {
+      Value i32Register = num32BitRegs > 1
+                              ? rewriter.create<LLVM::ExtractValueOp>(
+                                    loc, rewriter.getI32Type(), ldMatrixResult,
+                                    rewriter.getI64ArrayAttr(i))
+                              : ldMatrixResult;
+      Value casted =
+          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Lowers nvgpu.mma.sync to nvvm.mma.sync.
+///
+/// The A, B and C operands are unpacked from their LLVM-converted form into
+/// individual register values, and the PTX multiplicand types are chosen
+/// from the element type of A. The intrinsic's struct result is converted
+/// back to the LLVM type of the op's result. The multiplicand layouts are
+/// fixed to row for A and col for B.
+struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
+  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    // Get the shapes of the MMAMatrix type being used. The shapes will
+    // choose which intrinsic this op will be lowered to.
+    auto aType = op.matrixA().getType().cast<VectorType>();
+
+    int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
+    int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
+    int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
+    std::array<int64_t, 3> gemmShape{m, n, k};
+
+    // Deduce the PTX multiplicand types before creating any IR. This way an
+    // unsupported element type fails the pattern without first emitting
+    // operand-unpacking ops. B is assumed to share A's element type.
+    NVVM::MMATypes ptxTypeA;
+    NVVM::MMATypes ptxTypeB;
+    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
+    Type aElementType = aType.getElementType();
+    if (aElementType.isInteger(8)) {
+      ptxTypeA = NVVM::MMATypes::s8;
+      ptxTypeB = NVVM::MMATypes::s8;
+      // Integer MMA is emitted with the satfinite overflow mode.
+      overflow = NVVM::MMAIntOverflow::satfinite;
+    } else if (aElementType.isF16()) {
+      ptxTypeA = NVVM::MMATypes::f16;
+      ptxTypeB = NVVM::MMATypes::f16;
+    } else if (aElementType.isF64()) {
+      ptxTypeA = NVVM::MMATypes::f64;
+      ptxTypeB = NVVM::MMATypes::f64;
+    } else {
+      return op->emitError("could not deduce operand PTX types");
+    }
+
+    // Convert the result type once. It is both the type the op is replaced
+    // with and the input from which the intrinsic's result type is inferred.
+    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
+    if (!desiredRetTy)
+      return failure();
+    Type intrinsicResTy = inferIntrinsicResultType(desiredRetTy);
+
+    SmallVector<Value> matA =
+        unpackOperandVector(rewriter, loc, adaptor.matrixA());
+    SmallVector<Value> matB =
+        unpackOperandVector(rewriter, loc, adaptor.matrixB());
+    SmallVector<Value> matC =
+        unpackOperandVector(rewriter, loc, adaptor.matrixC());
+
+    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
+        loc, intrinsicResTy, matA, matB, matC,
+        /*shape=*/gemmShape,
+        /*b1Op=*/llvm::None,
+        /*intOverflow=*/overflow,
+        /*multiplicandPtxTypes=*/
+        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
+        /*multiplicandLayouts=*/
+        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
+                                       NVVM::MMALayout::col});
+    rewriter.replaceOp(op, convertIntrinsicResult(loc, intrinsicResTy,
+                                                  desiredRetTy, intrinsicResult,
+                                                  rewriter));
+    return success();
+  }
+};
+
+/// Pass that converts NVGPU dialect operations to the LLVM and NVVM dialects.
+/// It uses the patterns registered by populateNVGPUToNVVMConversionPatterns
+/// and applies them as a partial conversion.
+struct ConvertNVGPUToNVVMPass
+    : public ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
+  ConvertNVGPUToNVVMPass() = default;
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+
+    RewritePatternSet patterns(&context);
+    LLVMTypeConverter typeConverter(&context);
+    populateNVGPUToNVVMConversionPatterns(typeConverter, patterns);
+
+    // Ops produced by the patterns live in the LLVM and NVVM dialects.
+    LLVMConversionTarget target(context);
+    target.addLegalDialect<::mlir::LLVM::LLVMDialect,
+                           ::mlir::NVVM::NVVMDialect>();
+
+    LogicalResult conversionStatus =
+        applyPartialConversion(getOperation(), target, std::move(patterns));
+    if (failed(conversionStatus))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+/// Adds the patterns that lower nvgpu.mma.sync and nvgpu.ldmatrix to the NVVM
+/// dialect. Both patterns are constructed with `converter`.
+void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
+                                                 RewritePatternSet &patterns) {
+  patterns.add<MmaSyncOptoNVVM>(converter);
+  patterns.add<MmaLdMatrixOpToNVVM>(converter);
+}
+
+/// Creates an instance of the NVGPU-to-NVVM conversion pass.
+std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<ConvertNVGPUToNVVMPass>();
+  return pass;
+}

diff  --git a/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
deleted file mode 100644
index 2f70a15badb7d..0000000000000
--- a/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
+++ /dev/null
@@ -1,129 +0,0 @@
-// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
-
-gpu.module @test_module {
-  // CHECK-LABEL: @m16n8k16_fp16
-  func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
-    // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<2xf16>>
-    // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<4 x vector<2xf16>>
-    // CHECK: llvm.extractvalue %arg0[2] : !llvm.array<4 x vector<2xf16>>
-    // CHECK: llvm.extractvalue %arg0[3] : !llvm.array<4 x vector<2xf16>>
-
-    // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<2 x vector<2xf16>>
-    // CHECK: llvm.extractvalue %arg1[1] : !llvm.array<2 x vector<2xf16>>
-
-    // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
-    // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
-    // CHECK-NOT llvm.extractvalue
-    // CHECK: [[d:%.+]] = nvvm.mma.sync
-    // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n  = 8 : i32}
-    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
-    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>    
-    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-    // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
-    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
-    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>      
-    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
-    return %d : vector<2x2xf16>
-  }
-
-  // CHECK-LABEL: @m16n8k8_fp16
-  func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
-    // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<2 x vector<2xf16>>
-    // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<2 x vector<2xf16>>
-
-    // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<1 x vector<2xf16>>
-
-    // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
-    // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
-    // CHECK-NOT llvm.extractvalue
-    // CHECK: [[d:%.+]] = nvvm.mma.sync
-    // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n  = 8 : i32}
-    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
-    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>    
-    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-    // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
-    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
-    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>      
-    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
-    return %d : vector<2x2xf16>
-  }
-
-  // CHECK-LABEL: @m16n8k32_int8
-  func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
-
-    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
-    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
-    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
-    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
-    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
-    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
-    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
-    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
-
-    // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
-    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
-    // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
-    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
-
-    // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
-    // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
-
-    // CHECK: [[d:%.+]] = nvvm.mma.sync
-    // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
-    // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
-    // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
-    // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
-    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
-
-    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xi32>>
-    return %d : vector<2x2xi32>
-  }
-
-  // CHECK-LABEL: @m8n8k4_f64
-  func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
-    // CHECK: llvm.extractvalue %arg0
-    // CHECK: llvm.extractvalue %arg1
-    // CHECK: llvm.extractvalue %arg2
-
-    // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
-    // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
-    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>    
-    // CHECK: llvm.mlir.undef : vector<2xf64>
-    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>    
-    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
-    // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64>
-    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>>
-    // CHECK: llvm.return {{%.+}} : !llvm.array<1 x vector<2xf64>>
-    return %d : vector<1x2xf64>
-  }
-
-  // CHECK-LABEL: @ldmatrix_x4
-  func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
-    %c0  = arith.constant 0 : index
-    // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
-    %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
-    // CHECK: llvm.extractvalue
-    // CHECK: llvm.bitcast
-    // CHECK: llvm.insertvalue
-    // CHECK: llvm.extractvalue
-    // CHECK: llvm.bitcast
-    // CHECK: llvm.insertvalue
-    // CHECK: llvm.extractvalue
-    // CHECK: llvm.bitcast
-    // CHECK: llvm.insertvalue
-    // CHECK: llvm.extractvalue
-    // CHECK: llvm.bitcast
-    // CHECK: llvm.insertvalue
-    return %a : vector<4x2xf16>
-  }
-
-  // CHECK-LABEL: @ldmatrix_x1
-  func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
-    %c0  = arith.constant 0 : index
-    // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
-    %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>    
-    // CHECK: llvm.bitcast
-    // CHECK: llvm.insertvalue    
-    return %a : vector<1x2xf16>
-  }
-}

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
new file mode 100644
index 0000000000000..e095fecdfa87e
--- /dev/null
+++ b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
@@ -0,0 +1,127 @@
+// RUN: mlir-opt --convert-nvgpu-to-nvvm --split-input-file %s | FileCheck %s
+
+// f16 mma.sync with shape m16n8k16. The 4 + 2 + 2 vector<2xf16> registers of
+// A, B and C are extracted one by one and passed straight to the intrinsic.
+// Its struct result is then repacked into the converted result array.
+// CHECK-LABEL: @m16n8k16_fp16
+func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+  // CHECK-NOT: llvm.extractvalue
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+// f16 mma.sync with shape m16n8k8. The 2 + 1 + 2 vector<2xf16> registers of
+// A, B and C are extracted one by one and passed straight to the intrinsic.
+// Its struct result is then repacked into the converted result array.
+// CHECK-LABEL: @m16n8k8_fp16
+func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<1 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+  // CHECK-NOT: llvm.extractvalue
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: return
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+
+// i8 mma.sync with shape m16n8k32. Every vector<4xi8> register of A and B is
+// extracted and bitcast to i32, and both vector<2xi32> registers of C are
+// extracted. The intrinsic is emitted with s8 multiplicand types and the
+// satfinite integer overflow mode.
+// CHECK-LABEL: @m16n8k32_int8
+func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+  // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
+  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
+  // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+
+// -----
+
+// f64 mma.sync with shape m8n8k4. One register each of A and B and two of C
+// are passed to the intrinsic. Its struct<(f64, f64)> result is repacked
+// element by element into a vector<2xf64>, which is then inserted into the
+// converted result array.
+// CHECK-LABEL: @m8n8k4_f64
+func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
+  // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  // CHECK: llvm.mlir.undef : vector<2xf64>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>    
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
+  // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>>
+  // CHECK: return
+  return %d : vector<1x2xf64>
+}
+
+// -----
+
+
+// ldmatrix with numTiles = 4 and transpose = false. The NVVM op uses the row
+// layout and returns a struct of four i32. Each i32 is extracted, bitcast to
+// vector<2xf16>, and inserted into the converted result aggregate.
+// CHECK-LABEL: @ldmatrix_x4
+func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
+  %c0  = arith.constant 0 : index
+  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue
+  return %a : vector<4x2xf16>
+}
+
+// -----
+
+// ldmatrix with numTiles = 1. The NVVM op returns a bare i32, so no struct
+// extraction is needed: the value is bitcast and inserted directly into the
+// converted result aggregate.
+// CHECK-LABEL: @ldmatrix_x1
+func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
+  %c0  = arith.constant 0 : index
+  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>    
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue    
+  return %a : vector<1x2xf16>
+}

diff  --git a/mlir/test/Dialect/NVGPU/roundtrip.mlir b/mlir/test/Dialect/NVGPU/roundtrip.mlir
index 8a52180676445..5a35d39f1acca 100644
--- a/mlir/test/Dialect/NVGPU/roundtrip.mlir
+++ b/mlir/test/Dialect/NVGPU/roundtrip.mlir
@@ -8,3 +8,13 @@ func @ldmatrix(%arg0: memref<?x?xf16, 3>, %x: index, %y: index) {
     memref<?x?xf16, 3> -> vector<4x2xf16>
   return
 }
+
+// Round-trip (parse then print) of nvgpu.mma.sync with an m16n8k16 shape on
+// f16 operands.
+// CHECK-LABEL: func @mma_sync(
+func @mma_sync(%arg0: vector<4x2xf16>,
+               %arg1: vector<2x2xf16>,
+               %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+//       CHECK: nvgpu.mma.sync(%{{.*}}, %{{.*}}, %{{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} :
+    (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  return %d : vector<2x2xf16>
+}


        


More information about the Mlir-commits mailing list