[Mlir-commits] [mlir] eaaf7a6 - [MLIR][GPU][NVVM] Add conversion of warp synchronous matrix-multiply accumulate GPU ops

Uday Bondhugula llvmlistbot at llvm.org
Fri May 21 08:51:52 PDT 2021


Author: Navdeep Kumar
Date: 2021-05-21T21:20:33+05:30
New Revision: eaaf7a6a09da905cc314201f93e2be11773726a0

URL: https://github.com/llvm/llvm-project/commit/eaaf7a6a09da905cc314201f93e2be11773726a0
DIFF: https://github.com/llvm/llvm-project/commit/eaaf7a6a09da905cc314201f93e2be11773726a0.diff

LOG: [MLIR][GPU][NVVM] Add conversion of warp synchronous matrix-multiply accumulate GPU ops

Add conversion of warp synchronous matrix-multiply accumulate GPU ops to
NVVM ops. The following conversions are added:
  1.) subgroup_mma_load_matrix -> wmma.m16n16k16.load.[a,b,c].[f16,f32].row.stride
  2.) subgroup_mma_store_matrix -> wmma.m16n16k16.store.d.[f16,f32].row.stride
  3.) subgroup_mma_compute -> wmma.m16n16k16.mma.row.row.[f16,f32].[f16,f32]

Reviewed By: bondhugula, ftynse

Differential Revision: https://reviews.llvm.org/D95331

Added: 
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
    mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir

Modified: 
    mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
    mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index a005fb50226f5..e291a77e3a9be 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -31,6 +31,10 @@ void configureGpuToNVVMConversionLegality(ConversionTarget &target);
 void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns);
 
+/// Collect a set of patterns to convert WMMA ops (subgroup_mma_load_matrix,
+/// subgroup_mma_store_matrix and subgroup_mma_compute) from the GPU dialect to
+/// NVVM. The patterns only match once the op operands already have LLVM
+/// compatible types, so `converter` must provide a conversion for
+/// `!gpu.mma_matrix` types. The load and store patterns additionally require
+/// the index bitwidth of `converter` to be 32.
+void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
+                                             RewritePatternSet &patterns);
+
 /// Creates a pass that lowers GPU dialect operations to NVVM counterparts. The
 /// index bitwidth used for the lowering of the device side index computations
 /// is configurable.

diff  --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
index e2647103ba22d..040b220135180 100644
--- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
@@ -4,6 +4,7 @@ add_public_tablegen_target(MLIRGPUToNVVMIncGen)
 
 add_mlir_conversion_library(MLIRGPUToNVVMTransforms
   LowerGpuOpsToNVVMOps.cpp
+  WmmaOpsToNvvm.cpp
 
   DEPENDS
   MLIRConversionPassIncGen

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 8471de0a04152..42e64e5fb3c6a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -126,6 +126,38 @@ struct LowerGpuOpsToNVVMOpsPass
       return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
     });
 
+    // Lowering for MMAMatrixType.
+    converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+      // The number of items in structToReturn are dependent on the the dataType
+      // and the MMA operand that this operation is associated with.
+      llvm::DenseMap<StringRef, int64_t> numElemsPerThreadF16,
+          numElemsPerThreadF32;
+      numElemsPerThreadF16["AOp"] = 8;
+      numElemsPerThreadF16["BOp"] = 8;
+      numElemsPerThreadF16["COp"] = 4;
+      numElemsPerThreadF16["DOp"] = 4;
+      numElemsPerThreadF32["AOp"] = 8;
+      numElemsPerThreadF32["BOp"] = 8;
+      numElemsPerThreadF32["COp"] = 8;
+      numElemsPerThreadF32["DOp"] = 8;
+      Type structToReturn;
+      if (type.getElementType().isF16()) {
+        // Number of f16's in 32-bit.
+        unsigned vecSize = 2;
+        Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext()));
+        unsigned size = numElemsPerThreadF16[type.getOperand()];
+        SmallVector<Type> elements(size, vec);
+        structToReturn =
+            LLVM::LLVMStructType::getLiteral(&getContext(), elements);
+      } else if (type.getElementType().isF32()) {
+        unsigned size = numElemsPerThreadF32[type.getOperand()];
+        SmallVector<Type> elements(size, FloatType::getF32(&getContext()));
+        structToReturn =
+            LLVM::LLVMStructType::getLiteral(&getContext(), elements);
+      }
+      return structToReturn;
+    });
+
     RewritePatternSet patterns(m.getContext());
     RewritePatternSet llvmPatterns(m.getContext());
 
@@ -137,6 +169,7 @@ struct LowerGpuOpsToNVVMOpsPass
 
     populateStdToLLVMConversionPatterns(converter, llvmPatterns);
     populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
+    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
     LLVMConversionTarget target(getContext());
     configureGpuToNVVMConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))

diff  --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
new file mode 100644
index 0000000000000..c4458aa05f96d
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -0,0 +1,451 @@
+//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// NVVM Dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Contains all the common LLVM types which are used across the lowerings of
+/// GPU subgroup ops to NVVM dialect.
+struct CommonLLVMAndBuiltInMLIRTypes {
+  explicit CommonLLVMAndBuiltInMLIRTypes(MLIRContext *context) {
+    // Number of vector<2xf16> (32-bit) registers each thread holds per
+    // operand for f16 fragments.
+    numHalfsInOpFrags.resize(4);
+    numHalfsInOpFrags[A] = 8;
+    numHalfsInOpFrags[B] = 8;
+    numHalfsInOpFrags[C] = 4;
+    numHalfsInOpFrags[D] = 4;
+    i32Ty = IntegerType::get(context, 32);
+    f16Ty = FloatType::getF16(context);
+    f32Ty = FloatType::getF32(context);
+    f16x2Ty = VectorType::get(2, f16Ty);
+    // The f16 fragment structs are sized from the counts above so that the
+    // struct types and the unpacking loops that use the counts stay in sync.
+    fragArrayABTy = LLVM::LLVMStructType::getLiteral(
+        context, SmallVector<Type>(numHalfsInOpFrags[A], f16x2Ty));
+    fragArrayCDTy = LLVM::LLVMStructType::getLiteral(
+        context, SmallVector<Type>(numHalfsInOpFrags[C], f16x2Ty));
+    fragArrayCDF32Ty =
+        LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+  }
+
+  Type i32Ty;
+  Type f16Ty;
+  Type f32Ty;
+  /// vector<2xf16>: two halves packed into one 32-bit register.
+  Type f16x2Ty;
+  /// Type for the fragment of A and B operands that a single thread holds for
+  /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
+  /// (beta*C): a struct of 8 x vector<2xf16>.
+  Type fragArrayABTy;
+  /// Type for the fragment of C and D operands that a single thread holds for
+  /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
+  /// (beta*C): a struct of 4 x vector<2xf16>.
+  Type fragArrayCDTy;
+  /// Type for the fragment of C and D operands that a single thread holds for
+  /// fp32 data type in a WMMA operation of the form D = (alpha*(A*B)) +
+  /// (beta*C): a struct of 8 x f32.
+  Type fragArrayCDF32Ty;
+  /// Number of vector<2xf16> elements (not individual f16 values) a single
+  /// thread holds for each operand of an f16 WMMA operation of the form
+  /// D = (alpha*(A*B)) + (beta*C), indexed by `OperandMap`.
+  SmallVector<unsigned, 4> numHalfsInOpFrags;
+  /// Represents the operands of a MMA operation of the form D = (alpha*(A*B)) +
+  /// (beta*C).
+  enum OperandMap { A, B, C, D };
+};
+
+/// Returns success when every value in `operands` already has an LLVM
+/// compatible type. Operand types are expected to have been converted by the
+/// `LLVMTypeConverter` before the op itself is lowered, so an unconverted
+/// operand points at a missing type conversion; in that case a match failure
+/// is reported on `op` through `rewriter`.
+static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                     ConversionPatternRewriter &rewriter) {
+  auto hasLLVMType = [](Value value) {
+    return LLVM::isCompatibleType(value.getType());
+  };
+  if (llvm::all_of(operands, hasLLVMType))
+    return success();
+  return rewriter.notifyMatchFailure(
+      op, "cannot convert if operands aren't of LLVM type.");
+}
+
+/// Diagnostic attached to match failures for WMMA shapes other than
+/// m16n16k16.
+static constexpr StringRef kInvalidCaseStr =
+    "Unimplemented WMMA variant, Only M16N16K16 version implemented.";
+
+/// This class implements the conversion of GPU MMA loadOp to wmma.load op
+/// in the NVVM dialect. The conversion computes the address of the first
+/// element to load from the source memref, emits the NVVM load op on it and
+/// replaces the original op with the per-thread fragment struct it yields.
+struct WmmaLoadOpToNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>,
+      private CommonLLVMAndBuiltInMLIRTypes {
+public:
+  explicit WmmaLoadOpToNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>(typeConverter),
+        CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
+  }
+
+  /// Lowers `subgroupMmaLoadMatrixOp` using the converted `operands` (memref
+  /// descriptor followed by the indices). All support checks happen before
+  /// any IR is emitted.
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    unsigned indexTypeBitwidth =
+        this->getTypeConverter()->getIndexTypeBitwidth();
+
+    // The corresponding intrinsics expects leadDimension to be a 32-bit
+    // integer, so all the calculations of linearizing the load address
+    // must also follow this restriction.
+    if (indexTypeBitwidth != 32)
+      return rewriter.notifyMatchFailure(
+          op, "Expected indices to the memref to be 32-bit wide.");
+
+    // Source memref of the original op.
+    MemRefType srcMemrefType =
+        subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>();
+    Type srcElementType = srcMemrefType.getElementType();
+
+    // Get the shape of the MMAMatrix type being returned. The shape will
+    // choose which intrinsic this op will be lowered to; only 16x16 fragments
+    // (m16n16k16) are implemented.
+    gpu::MMAMatrixType retType =
+        subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
+    ArrayRef<int64_t> retTypeShape = retType.getShape();
+    if (retTypeShape[0] != 16 || retTypeShape[1] != 16)
+      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+
+    // Pick the fragment struct type. A/B fragments are always packed f16;
+    // C/D fragments depend on the element type of the source memref.
+    Type resType;
+    StringRef operandStr = retType.getOperand();
+    if (operandStr.equals("AOp") || operandStr.equals("BOp"))
+      resType = fragArrayABTy;
+    else if (srcElementType.isF16())
+      resType = fragArrayCDTy;
+    else if (srcElementType.isF32())
+      resType = fragArrayCDF32Ty;
+    else
+      return failure();
+
+    Location loc = op->getLoc();
+    auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
+
+    // Use the converted operands: the memref descriptor (to extract
+    // alignedPtr and offset) and the already-lowered i32 indices.
+    gpu::SubgroupMmaLoadMatrixOpAdaptor adaptor(operands);
+    MemRefDescriptor promotedSrcOp(adaptor.srcMemref());
+
+    // Emit ops which compute the load offset using `srcOffsetI`,
+    // `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
+    // ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are
+    // assumed to be normalized and hence the simple conversion works.
+    ValueRange indices = adaptor.indices();
+    Value srcOffsetIVal = indices[0];
+    Value srcOffsetJVal = indices[1];
+    Value leadingDim32 =
+        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+    Value numElemsLeadDim =
+        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, srcOffsetIVal);
+    Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
+                                                    srcOffsetJVal);
+
+    Value memrefOffset = promotedSrcOp.offset(rewriter, loc);
+    Value actualOffset =
+        rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset, memrefOffset);
+
+    // The GEP yields the same pointer type (element type and address space)
+    // as the aligned pointer it indexes. Hard-coding f16 and the memref memory
+    // space would be ill-typed for f32 accumulators and for memrefs whose
+    // memory space the type converter remaps.
+    Value alignedPtr = promotedSrcOp.alignedPtr(rewriter, loc);
+    Value loadAddress = rewriter.create<LLVM::GEPOp>(
+        loc, alignedPtr.getType(), alignedPtr, ArrayRef<Value>{actualOffset});
+
+    // Bitcast the load address to an i32 pointer in the same address space,
+    // So that values are read in chunks of 32-bits and semantics match with
+    // the intrinsic exposed by NVPTX backend.
+    unsigned addressSpace =
+        alignedPtr.getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+    Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMPointerType::get(i32Ty, addressSpace), loadAddress);
+
+    // Create nvvm.mma_load op according to the operand types.
+    SmallVector<Value, 2> loadOpOperands({loadAddressCasted, leadingDim32});
+    Value fragment;
+    if (operandStr.equals("AOp"))
+      fragment = rewriter
+                     .create<NVVM::WMMALoadAM16N16K16Op>(loc, resType,
+                                                         loadOpOperands)
+                     .getResult();
+    else if (operandStr.equals("BOp"))
+      fragment = rewriter
+                     .create<NVVM::WMMALoadBM16N16K16Op>(loc, resType,
+                                                         loadOpOperands)
+                     .getResult();
+    else if (srcElementType.isF16())
+      fragment = rewriter
+                     .create<NVVM::WMMALoadCF16M16N16K16Op>(loc, resType,
+                                                            loadOpOperands)
+                     .getResult();
+    else
+      fragment = rewriter
+                     .create<NVVM::WMMALoadCF32M16N16K16Op>(loc, resType,
+                                                            loadOpOperands)
+                     .getResult();
+    rewriter.replaceOp(op, fragment);
+    return success();
+  }
+};
+
+/// This class implements the conversion of GPU MMA storeOp to wmma.store op
+/// in the NVVM dialect. The conversion not only emits the NVVM op but also
+/// emits code that is necessary to unpack the data in the source and
+/// convert the data in the format that is needed by the NVVM op.
+struct WmmaStoreOpToNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>,
+      private CommonLLVMAndBuiltInMLIRTypes {
+public:
+  explicit WmmaStoreOpToNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>(typeConverter),
+        CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
+  }
+
+  /// Lowers `subgroupMmaStoreMatrixOp` using the converted `operands`. All
+  /// support checks happen before any IR is emitted.
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    unsigned indexTypeBitwidth =
+        this->getTypeConverter()->getIndexTypeBitwidth();
+    // The corresponding intrinsics expects leadDimension to be a 32-bit
+    // integer, so all the calculations of linearizing the store address
+    // must also follow this restriction.
+    if (indexTypeBitwidth != 32)
+      return rewriter.notifyMatchFailure(
+          op, "expected indices to the memref to be 32-bit wide.");
+
+    // Get the shape and element type of the MMAMatrix type being stored. They
+    // choose which intrinsic this op will be lowered to: f16 or f32 elements,
+    // 16x16 (m16n16k16) fragments only.
+    gpu::MMAMatrixType srcType =
+        subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
+    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
+    Type srcElementType = srcType.getElementType();
+    if (srcElementType != f16Ty && srcElementType != f32Ty)
+      return failure();
+    if (srcTypeShape[0] != 16 || srcTypeShape[1] != 16)
+      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+    bool isF16 = srcElementType == f16Ty;
+
+    Location loc = op->getLoc();
+
+    // Converted operands: the fragment struct, the destination memref
+    // descriptor (to extract alignedPtr and offset) and the i32 indices.
+    gpu::SubgroupMmaStoreMatrixOpAdaptor adaptor(operands);
+    MemRefDescriptor promotedDstOp(adaptor.dstMemref());
+
+    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
+
+    // Emit ops which compute the store offset using `dstOffsetI`,
+    // `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr +
+    // ((leadDimension * dstOffsetI) + dstOffsetJ)).
+    ValueRange indices = adaptor.indices();
+    Value dstOffsetIVal = indices[0];
+    Value dstOffsetJVal = indices[1];
+    Value leadingDim32 =
+        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+    Value numElemsLeadDim =
+        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, dstOffsetIVal);
+    Value storeOffset = rewriter.create<LLVM::AddOp>(
+        loc, i32Ty, numElemsLeadDim, dstOffsetJVal);
+
+    Value memrefOffset = promotedDstOp.offset(rewriter, loc);
+    Value actualOffset =
+        rewriter.create<LLVM::AddOp>(loc, i32Ty, storeOffset, memrefOffset);
+
+    // The GEP yields the same pointer type (element type and address space)
+    // as the aligned pointer it indexes. Hard-coding f16 and the memref memory
+    // space would be ill-typed for f32 memrefs and for memrefs whose memory
+    // space the type converter remaps.
+    Value alignedPtr = promotedDstOp.alignedPtr(rewriter, loc);
+    Value storeAddress = rewriter.create<LLVM::GEPOp>(
+        loc, alignedPtr.getType(), alignedPtr, ArrayRef<Value>{actualOffset});
+
+    // Bitcast the store address to an i32 pointer in the same address space,
+    // So that values can be stored in chunks of 32-bits and semantics match
+    // with the intrinsic exposed by NVPTX backend.
+    unsigned addressSpace =
+        alignedPtr.getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMPointerType::get(i32Ty, addressSpace), storeAddress);
+
+    SmallVector<Value, 10> storeOpOperands;
+    storeOpOperands.push_back(storeAddressCasted);
+
+    // Unpack the per-thread fragment: vector<2xf16> registers for f16 data,
+    // individual f32 values for f32 data. The f32 count is taken from the
+    // fragment struct type rather than repeated as a literal.
+    Type fragElemType = isF16 ? f16x2Ty : f32Ty;
+    unsigned numFragElems =
+        isF16
+            ? numHalfsInOpFrags[D]
+            : static_cast<unsigned>(fragArrayCDF32Ty.cast<LLVM::LLVMStructType>()
+                                        .getBody()
+                                        .size());
+    Value fragment = adaptor.src();
+    for (unsigned i = 0; i < numFragElems; ++i) {
+      Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+          loc, fragElemType, fragment, rewriter.getI32ArrayAttr(i));
+      storeOpOperands.push_back(toUse);
+    }
+    storeOpOperands.push_back(leadingDim32);
+
+    // Create nvvm.mma_store op.
+    if (isF16)
+      rewriter.create<NVVM::WMMAStoreF16M16N16K16Op>(loc, storeOpOperands);
+    else
+      rewriter.create<NVVM::WMMAStoreF32M16N16K16Op>(loc, storeOpOperands);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
+/// in the NVVM dialect. The A, B and C fragment structs are unpacked into
+/// individual values, which the intrinsic takes as separate operands.
+struct WmmaMmaOpToNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>,
+      private CommonLLVMAndBuiltInMLIRTypes {
+  explicit WmmaMmaOpToNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>(typeConverter),
+        CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
+  }
+
+  /// Lowers `subgroupMmaComputeOp` using the converted `operands`. The
+  /// accumulator (C) element type selects the f16 or f32 intrinsic, and all
+  /// three operands must be 16x16 fragments. Support is checked before any IR
+  /// is emitted.
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaComputeOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    Type cElementType = subgroupMmaComputeOp.opC()
+                            .getType()
+                            .cast<gpu::MMAMatrixType>()
+                            .getElementType();
+    bool isF16Acc = cElementType == f16Ty;
+    if (!isF16Acc && cElementType != f32Ty)
+      return failure();
+
+    // The shapes of the A, B and C MMAMatrix types choose the intrinsic; only
+    // m16n16k16 is implemented. Each operand's own type must be checked.
+    auto is16x16Fragment = [](Value value) {
+      ArrayRef<int64_t> shape =
+          value.getType().cast<gpu::MMAMatrixType>().getShape();
+      return shape[0] == 16 && shape[1] == 16;
+    };
+    if (!is16x16Fragment(subgroupMmaComputeOp.opA()) ||
+        !is16x16Fragment(subgroupMmaComputeOp.opB()) ||
+        !is16x16Fragment(subgroupMmaComputeOp.opC()))
+      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+
+    Location loc = op->getLoc();
+
+    // The wmma.mma intrinsic in llvm requires the operands as individual
+    // values. So individual elements from the lowered fragment structs are
+    // extracted, in A, B, C order, and then passed on to the intrinsic call.
+    SmallVector<Value> unpackedOps;
+    auto unpackOp = [&](Value operand, unsigned numElems, Type elemType) {
+      for (unsigned i = 0; i < numElems; ++i) {
+        Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+            loc, elemType, operand, rewriter.getI32ArrayAttr(i));
+        unpackedOps.push_back(toUse);
+      }
+    };
+
+    gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands);
+    unpackOp(transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty);
+    unpackOp(transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
+
+    if (isF16Acc) {
+      unpackOp(transformedOperands.opC(), numHalfsInOpFrags[C], f16x2Ty);
+      // Create nvvm.wmma.mma op.
+      NVVM::WMMAMmaF16F16M16N16K16Op wmmaMmaOp =
+          rewriter.create<NVVM::WMMAMmaF16F16M16N16K16Op>(loc, fragArrayCDTy,
+                                                          unpackedOps);
+      rewriter.replaceOp(op, wmmaMmaOp.getResult());
+      return success();
+    }
+
+    // f32 accumulator: its element count comes from the fragment struct type.
+    unsigned numF32InCFrag = static_cast<unsigned>(
+        fragArrayCDF32Ty.cast<LLVM::LLVMStructType>().getBody().size());
+    unpackOp(transformedOperands.opC(), numF32InCFrag, f32Ty);
+    // Create nvvm.wmma.mma op.
+    NVVM::WMMAMmaF32F32M16N16K16Op wmmaMmaOp =
+        rewriter.create<NVVM::WMMAMmaF32F32M16N16K16Op>(loc, fragArrayCDF32Ty,
+                                                        unpackedOps);
+    rewriter.replaceOp(op, wmmaMmaOp.getResult());
+    return success();
+  }
+};
+
+} // anonymous namespace
+
+namespace mlir {
+/// Adds the WMMA load, compute and store lowerings to `patterns`. All three
+/// patterns are constructed with the same `converter`.
+void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
+                                             RewritePatternSet &patterns) {
+  patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
+                  WmmaStoreOpToNVVMLowering>(converter);
+}
+} // namespace mlir

diff  --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
new file mode 100644
index 0000000000000..c44d7f0cfa301
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck %s
+
+gpu.module @test_module {
+
+  // Lowering of a 16x16 f16 "AOp" load from a memory-space-3 memref: the
+  // address is leadDimension * i + j plus the memref offset, applied to the
+  // aligned pointer, bitcast to an i32 pointer and passed with the leading
+  // dimension to the wmma load op.
+  // CHECK-LABEL: func @gpu_wmma_load_op() ->
+  // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> {
+  func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = constant 16 : index
+    %j = constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i32
+    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
+    // CHECK:  %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i32
+    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK:  %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+
+  // Lowering of a 16x16 f16 "DOp" store: same address computation as the
+  // load, then the four vector<2xf16> registers of the fragment are extracted
+  // and passed, followed by the leading dimension, to the wmma store op.
+  // The memref descriptor is matched by pattern rather than by a fixed SSA
+  // number so the test does not depend on value numbering.
+  // CHECK-LABEL: func @gpu_wmma_store_op
+  // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
+  func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+    %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
+    %i = constant 16 : index
+    %j = constant 16 : index
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
+    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:  %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]]  : i32
+    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
+    // CHECK:  %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK:  %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]]  : i32
+    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK:  %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+    // CHECK:  llvm.return
+    return
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+
+  // Lowering of an all-f16 16x16 compute op: eight vector<2xf16> values each
+  // from A and B and four from C are extracted, in that order, and passed as
+  // individual operands to the f16.f16 wmma mma op.
+  // CHECK-LABEL: func @gpu_wmma_mma_op
+  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
+  func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    // CHECK:  %[[A1:.*]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[A2:.*]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[A3:.*]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[A4:.*]] = llvm.extractvalue %[[A]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[A5:.*]] = llvm.extractvalue %[[A]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[A6:.*]] = llvm.extractvalue %[[A]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[A7:.*]] = llvm.extractvalue %[[A]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[A8:.*]] = llvm.extractvalue %[[A]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[B1:.*]] = llvm.extractvalue %[[B]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[B2:.*]] = llvm.extractvalue %[[B]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[B3:.*]] = llvm.extractvalue %[[B]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[B4:.*]] = llvm.extractvalue %[[B]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[B5:.*]] = llvm.extractvalue %[[B]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[B6:.*]] = llvm.extractvalue %[[B]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[B7:.*]] = llvm.extractvalue %[[B]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[B8:.*]] = llvm.extractvalue %[[B]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[C1:.*]] = llvm.extractvalue %[[C]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[C2:.*]] = llvm.extractvalue %[[C]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[C3:.*]] = llvm.extractvalue %[[C]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %[[C4:.*]] = llvm.extractvalue %[[C]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  %{{.*}} = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  llvm.return
+    return
+  }
+}


        


More information about the Mlir-commits mailing list