[Mlir-commits] [mlir] 77d2c81 - [MLIR][GPU] Add GPU ops nvvm.mma.sync, nvvm.mma.ldmatrix, lane_id

Thomas Raoux llvmlistbot at llvm.org
Wed Apr 13 15:50:25 PDT 2022


Author: Christopher Bate
Date: 2022-04-13T22:50:07Z
New Revision: 77d2c815f50b20d18f1207e4f442e2cf8eb8cec0

URL: https://github.com/llvm/llvm-project/commit/77d2c815f50b20d18f1207e4f442e2cf8eb8cec0
DIFF: https://github.com/llvm/llvm-project/commit/77d2c815f50b20d18f1207e4f442e2cf8eb8cec0.diff

LOG: [MLIR][GPU] Add GPU ops nvvm.mma.sync, nvvm.mma.ldmatrix, lane_id

This change adds three new operations to the GPU dialect: gpu.mma.sync,
gpu.mma.ldmatrix, and gpu.lane_id. The first two are meant to target
the lower-level nvvm.mma.sync and nvvm.ldmatrix instructions, respectively.
Lowerings to NVVM are added for all three new GPU operations.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D123647

Added: 
    mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 9b42ecfaae215..eaee42bb8e319 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -96,6 +96,19 @@ def GPU_ThreadIdOp : GPU_IndexOp<"thread_id"> {
   }];
 }
 
+def GPU_LaneIdOp : GPU_Op<"lane_id", [NoSideEffect]>,
+    Arguments<(ins)>, Results<(outs Index:$result)> {
+  let description = [{
+    Returns the id of the current lane within its subgroup (warp/wave).
+
+    Example:
+    ```mlir
+    %laneId = gpu.lane_id
+    ```
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
 def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [NoSideEffect]>,
     Arguments<(ins)>, Results<(outs Index:$result)> {
   let description = [{
@@ -1354,4 +1367,58 @@ def GPU_DeviceAsyncWaitOp : GPU_Op<"device_async_wait", []> {
   }];
 }
 
+def GPU_MmaLdMatrixOp : GPU_Op<"mma.ldmatrix",
+                                [MemoryEffects<[MemRead]>]> {
+  let description = [{
+  The `gpu.mma.ldmatrix` op loads a matrix fragment from `srcMemref` at
+  `indices`. The source and result types must be compatible with lowering to
+  the `nvvm.ldmatrix` instruction: `numTiles` is forwarded as its `num`, and
+  `transpose` selects its `col` (true) or `row` (false) layout. It represents
+  a distributed `vector.transfer_read` on the way to `nvvm.ldmatrix`.
+
+  Example:
+
+  ```mlir
+  gpu.mma.ldmatrix %shm_buffer[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<16x16xf16, 3> -> vector<4x2xf16>
+  ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
+                      Variadic<Index>:$indices, BoolAttr:$transpose,
+                      I32Attr:$numTiles);
+  let results = (outs AnyVector:$res);
+  let assemblyFormat = [{
+    $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
+  }];
+}
+
+def GPU_MmaSyncOp : GPU_Op<"mma.sync", [NoSideEffect]> {
+  let description = [{
+  The `gpu.mma.sync` op represents the distributed form of a collective
+  matrix-multiply-and-accumulate (mma) operation that is compatible with
+  `nvvm.mma.sync`. The operands and results are fragments of the full matrix
+  operands. The full shape of the distributed mma operation is given by the
+  required `mmaShape` attribute as a list of dimensions `[m, n, k]`.
+
+  This operation is meant to be lowered to the `nvvm.mma.sync` instruction
+  (currently for i8, f16, or f64 multiplicands), and is an intermediate point
+  between lowering from `vector.contract` to `nvvm.mma.sync`.
+
+  Example:
+
+  ```mlir
+  gpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  ```
+  }];
+  let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
+                       AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
+
+  let results = (outs AnyVector:$res);
+
+  let assemblyFormat = [{
+    `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
+    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
+  }];
+}
+
 #endif // GPU_OPS

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 9efdc8a832a35..e5145f6513fdf 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -208,6 +208,314 @@ struct GPUAsyncWaitLowering
   }
 };
 
+struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp> {
+  using ConvertOpToLLVMPattern<gpu::MmaLdMatrixOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::MmaLdMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = getContext();
+    Location loc = op->getLoc();
+
+    // ldmatrix returns a single i32 when one 32-bit register is loaded and a
+    // struct of i32 registers otherwise. The GPU op's result must be a 2-D
+    // vector of shape (NumRegisters, VectorRegister) whose rows are expected
+    // to be 32 bits wide (not verified here); each i32 register is bitcast to
+    // one such row. Results that are not 2-D vectors are rejected because the
+    // row shape below is read from dimension 1.
+    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
+    if (!vectorResultType || vectorResultType.getRank() != 2) {
+      return rewriter.notifyMatchFailure(op, "expected a 2-D vector result");
+    }
+    Type innerVectorType = LLVM::getFixedVectorType(
+        vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+
+    int64_t num32BitRegs = vectorResultType.getDimSize(0);
+
+    Type ldMatrixResultType;
+    if (num32BitRegs > 1) {
+      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
+          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
+    } else {
+      ldMatrixResultType = rewriter.getI32Type();
+    }
+
+    auto srcMemrefType = op.srcMemref().getType().cast<MemRefType>();
+    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(),
+                                        adaptor.indices(), rewriter);
+    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
+        loc, ldMatrixResultType, srcPtr,
+        /*num=*/op.numTiles(),
+        /*layout=*/op.transpose() ? NVVM::MMALayout::col
+                                  : NVVM::MMALayout::row);
+
+    // The ldmatrix operation returns either a single i32 value or a struct of
+    // i32 values. Here we unpack those values, bitcast each back to its 32-bit
+    // row vector type, and insert the rows into the converted result (an LLVM
+    // array of rows).
+    Type finalResultType = typeConverter->convertType(vectorResultType);
+    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+    for (int64_t i = 0; i < num32BitRegs; i++) {
+      Value i32Register = num32BitRegs > 1
+                              ? rewriter.create<LLVM::ExtractValueOp>(
+                                    loc, rewriter.getI32Type(), ldMatrixResult,
+                                    rewriter.getI64ArrayAttr(i))
+                              : ldMatrixResult;
+      Value casted =
+          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i));
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Checks if all the operands of the op being lowered are of LLVM Types. The
+/// types are expected to be converted by the `LLVMTypeConverter` before the
+/// op is actually lowered. If the type of an operand is not already converted,
+/// this hints at a missing type conversion, and a match failure is returned
+/// in that case.
+static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                     ConversionPatternRewriter &rewriter) {
+  if (!llvm::all_of(operands, [](Value value) {
+        return LLVM::isCompatibleType(value.getType());
+      })) {
+    return rewriter.notifyMatchFailure(
+        op, "cannot convert if operands aren't of LLVM type.");
+  }
+
+  return success();
+}
+
+/// Returns the type for the intrinsic given the converted result type of the
+/// `gpu.mma.sync` operation (an LLVM array of row vectors). Any other type,
+/// or an array with an unrecognized row type, is returned unchanged.
+static Type inferIntrinsicResultType(Type vectorResultType) {
+  MLIRContext *ctx = vectorResultType.getContext();
+  auto a = vectorResultType.dyn_cast<LLVM::LLVMArrayType>();
+  if (!a)
+    return vectorResultType;
+  Type f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+  Type i32Ty = IntegerType::get(ctx, 32);
+  Type f64Ty = Float64Type::get(ctx);
+  // f16 results are returned as one <2 x f16> struct member per row.
+  if (a.getElementType() == f16x2Ty)
+    return LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
+  // i32 and f64 results are returned as two scalar struct members per row.
+  for (Type scalarTy : {i32Ty, f64Ty}) {
+    if (a.getElementType() == LLVM::getFixedVectorType(scalarTy, 2)) {
+      return LLVM::LLVMStructType::getLiteral(
+          ctx, SmallVector<Type>(2 * a.getNumElements(), scalarTy));
+    }
+  }
+  return vectorResultType;
+}
+
+/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (an LLVM
+/// struct) into a fragment compatible with the vector type of this operation
+/// by extracting the struct members and packing them, one row per element,
+/// into an LLVM array. The value is returned unchanged unless it is a struct
+/// and `resultType` is an array.
+static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
+                                    Type resultType, Value intrinsicResult,
+                                    RewriterBase &rewriter) {
+  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
+  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
+  Type i32Ty = rewriter.getI32Type();
+  Type f64Ty = rewriter.getF64Type();
+  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
+  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+
+  // Builds an i32 constant used as an `llvm.insertelement` position.
+  auto makeConst = [&](int32_t index) -> Value {
+    return rewriter.create<LLVM::ConstantOp>(loc, i32Ty,
+                                             rewriter.getI32IntegerAttr(index));
+  };
+
+  if (arrayType && structType) {
+    SmallVector<Value, 4> elements;
+    // f16 rows are returned directly as <2 x f16> struct members.
+    if (arrayType.getElementType() == f16x2Ty) {
+      for (unsigned i = 0; i < structType.getBody().size(); i++) {
+        elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
+            loc, structType.getBody()[i], intrinsicResult,
+            rewriter.getI64ArrayAttr(i)));
+      }
+    }
+
+    // The intrinsic returns i32 and f64 values as individual scalars, two per
+    // row; each pair is packed into a fresh vector (one vector per row).
+    if (arrayType.getElementType() == i32x2Ty ||
+        arrayType.getElementType() == f64x2Ty) {
+      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
+        Value vec =
+            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
+        Value x1 = rewriter.create<LLVM::ExtractValueOp>(
+            loc, structType.getBody()[i * 2], intrinsicResult,
+            rewriter.getI64ArrayAttr(i * 2));
+        Value x2 = rewriter.create<LLVM::ExtractValueOp>(
+            loc, structType.getBody()[i * 2 + 1], intrinsicResult,
+            rewriter.getI64ArrayAttr(i * 2 + 1));
+        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+                                                     x1, makeConst(0));
+        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
+                                                     x2, makeConst(1));
+        elements.push_back(vec);
+      }
+    }
+
+    // Create the final vectorized result.
+    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
+    for (const auto &el : llvm::enumerate(elements)) {
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, arrayType, result, el.value(),
+          rewriter.getI64ArrayAttr(el.index()));
+    }
+    return result;
+  }
+
+  return intrinsicResult;
+}
+
+/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
+/// given as 2D `vectors` (already converted to LLVM arrays of 32b or 64b wide
+/// rows). The `nvvm.mma.sync` op expects these arguments to be given as a long
+/// list of scalars of certain types. This function unpacks the rows and casts
+/// them to the types expected by `nvvm.mma.sync`.
+static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
+                                              Location loc, Value operand) {
+  SmallVector<Value> result;
+  Type i32Ty = rewriter.getI32Type();
+  Type f64Ty = rewriter.getF64Type();
+  Type i8Ty = rewriter.getI8Type();
+  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
+
+  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
+    Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+        loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i));
+
+    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
+    // scalar types.
+    if (arrayTy.getElementType() == i8x4Ty) {
+      result.push_back(
+          rewriter.create<LLVM::BitcastOp>(loc, i32Ty, toUse));
+      continue;
+    }
+
+    // For some element types (i32, f64), we need to unpack the inner
+    // vector/array type as well because the intrinsic expects individual
+    // scalars to be provided.
+    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
+    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
+                         innerArrayTy.getElementType() == f64Ty)) {
+      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
+           idx < innerSize; idx++) {
+        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
+            loc, toUse,
+            rewriter.create<LLVM::ConstantOp>(
+                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+      }
+      continue;
+    }
+    result.push_back(toUse);
+  }
+  return result;
+}
+
+struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<gpu::MmaSyncOp> {
+  using ConvertOpToLLVMPattern<gpu::MmaSyncOp>::ConvertOpToLLVMPattern;
+  /// Lowers to `nvvm.mma.sync` using row-major A and column-major B layouts.
+  LogicalResult
+  matchAndRewrite(gpu::MmaSyncOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) {
+      return failure();
+    }
+
+    // The `mmaShape` attribute gives the [m, n, k] shape of the intrinsic, and
+    // the element type of operand A selects the PTX multiplicand types.
+    auto aType = op.matrixA().getType().cast<VectorType>();
+
+    int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
+    int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
+    int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
+    std::array<int64_t, 3> gemmShape{m, n, k};
+
+    SmallVector<Value> matA =
+        unpackOperandVector(rewriter, loc, adaptor.matrixA());
+    SmallVector<Value> matB =
+        unpackOperandVector(rewriter, loc, adaptor.matrixB());
+    SmallVector<Value> matC =
+        unpackOperandVector(rewriter, loc, adaptor.matrixC());
+
+    NVVM::MMATypes ptxTypeA;
+    NVVM::MMATypes ptxTypeB;
+    Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
+    if (aType.getElementType().isInteger(8)) {
+      ptxTypeA = NVVM::MMATypes::s8;
+      ptxTypeB = NVVM::MMATypes::s8;
+      overflow = NVVM::MMAIntOverflow::satfinite;
+    } else if (aType.getElementType().isF16()) {
+      ptxTypeA = NVVM::MMATypes::f16;
+      ptxTypeB = NVVM::MMATypes::f16;
+    } else if (aType.getElementType().isF64()) {
+      ptxTypeA = NVVM::MMATypes::f64;
+      ptxTypeB = NVVM::MMATypes::f64;
+    } else {
+      return op->emitError("could not deduce operand PTX types");
+    }
+
+    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
+    if (!desiredRetTy)
+      return rewriter.notifyMatchFailure(op, "failed to convert result type");
+    Type intrinsicResTy = inferIntrinsicResultType(desiredRetTy);
+    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
+        loc, intrinsicResTy, matA, matB, matC,
+        /*shape=*/gemmShape,
+        /*b1Op=*/llvm::None,
+        /*intOverflow=*/overflow,
+        /*multiplicandPtxTypes=*/
+        std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
+        /*multiplicandLayouts=*/
+        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
+                                       NVVM::MMALayout::col});
+    rewriter.replaceOp(op, convertIntrinsicResult(loc, intrinsicResTy,
+                                                  desiredRetTy, intrinsicResult,
+                                                  rewriter));
+    return success();
+  }
+};
+
+struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
+  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
+
+  /// Rewrites `gpu.lane_id` into the NVVM lane id register read, which yields
+  /// an i32, and then resizes that value to the index bitwidth configured on
+  /// the LLVMTypeConverter (sign-extend when wider, truncate when narrower).
+  LogicalResult
+  matchAndRewrite(gpu::LaneIdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value laneId = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
+    const unsigned indexWidth = getTypeConverter()->getIndexTypeBitwidth();
+    Type indexType = rewriter.getIntegerType(indexWidth);
+    if (indexWidth > 32) {
+      laneId = rewriter.create<LLVM::SExtOp>(loc, indexType, laneId);
+    } else if (indexWidth < 32) {
+      laneId = rewriter.create<LLVM::TruncOp>(loc, indexType, laneId);
+    }
+    // When the index type is already i32 the register value is used as is.
+    rewriter.replaceOp(op, laneId);
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -303,7 +611,8 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                        NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
            GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                        NVVM::GridDimYOp, NVVM::GridDimZOp>,
-           GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
+           GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering,
+           MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter);
 
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default

diff  --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 973166b2b5b4d..771e2f10d3bdb 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -6,7 +6,8 @@ gpu.module @test_module {
   // CHECK32-LABEL: func @gpu_index_ops()
   func.func @gpu_index_ops()
       -> (index, index, index, index, index, index,
-          index, index, index, index, index, index) {
+          index, index, index, index, index, index,
+          index) {
     // CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64
 
     // CHECK: = nvvm.read.ptx.sreg.tid.x : i32
@@ -49,10 +50,17 @@ gpu.module @test_module {
     // CHECK: = llvm.sext %{{.*}} : i32 to i64
     %gDimZ = gpu.grid_dim z
 
+
+    // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %laneId = gpu.lane_id
+
     func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
-               %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ
+               %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
+               %laneId
         : index, index, index, index, index, index,
-          index, index, index, index, index, index
+          index, index, index, index, index, index,
+          index
   }
 }
 

diff  --git a/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
new file mode 100644
index 0000000000000..2f70a15badb7d
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s
+
+gpu.module @test_module {
+  // CHECK-LABEL: @m16n8k16_fp16
+  func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+    // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<4 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[2] : !llvm.array<4 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[3] : !llvm.array<4 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg1[1] : !llvm.array<2 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-NOT: llvm.extractvalue
+    // CHECK: [[d:%.+]] = nvvm.mma.sync
+    // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
+    return %d : vector<2x2xf16>
+  }
+
+  // CHECK-LABEL: @m16n8k8_fp16
+  func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+    // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<2 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<1 x vector<2xf16>>
+
+    // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-NOT: llvm.extractvalue
+    // CHECK: [[d:%.+]] = nvvm.mma.sync
+    // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>>
+    return %d : vector<2x2xf16>
+  }
+
+  // CHECK-LABEL: @m16n8k32_int8
+  func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+    // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+    // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+
+    // CHECK: [[d:%.+]] = nvvm.mma.sync
+    // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+    // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
+    // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
+    // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+
+    // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xi32>>
+    return %d : vector<2x2xi32>
+  }
+
+  // CHECK-LABEL: @m8n8k4_f64
+  func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+    // CHECK: llvm.extractvalue %arg0
+    // CHECK: llvm.extractvalue %arg1
+    // CHECK: llvm.extractvalue %arg2
+
+    // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
+    // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
+    %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+    // CHECK: llvm.mlir.undef : vector<2xf64>
+    // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
+    // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
+    // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64>
+    // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>>
+    // CHECK: llvm.return {{%.+}} : !llvm.array<1 x vector<2xf64>>
+    return %d : vector<1x2xf64>
+  }
+
+  // CHECK-LABEL: @ldmatrix_x4
+  func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
+    %c0  = arith.constant 0 : index
+    // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
+    %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    // CHECK: llvm.extractvalue
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    return %a : vector<4x2xf16>
+  }
+
+  // CHECK-LABEL: @ldmatrix_x1
+  func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
+    %c0  = arith.constant 0 : index
+    // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+    %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
+    // CHECK: llvm.bitcast
+    // CHECK: llvm.insertvalue
+    return %a : vector<1x2xf16>
+  }
+}


        


More information about the Mlir-commits mailing list