[Mlir-commits] [mlir] 4b3eaee - [mlir][AMDGPU] Define wrappers for WMMA matrix ops

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Jul 20 11:38:41 PDT 2023


Author: Giuseppe Rossini
Date: 2023-07-20T18:38:35Z
New Revision: 4b3eaee2701afa2549c9cd0f78692598e9bfd44e

URL: https://github.com/llvm/llvm-project/commit/4b3eaee2701afa2549c9cd0f78692598e9bfd44e
DIFF: https://github.com/llvm/llvm-project/commit/4b3eaee2701afa2549c9cd0f78692598e9bfd44e.diff

LOG: [mlir][AMDGPU] Define wrappers for WMMA matrix ops

Wave Matrix Multiply Accumulate (WMMA) instructions accelerate matrix
multiplication on RDNA3 architectures. LLVM already provides a set of
intrinsics to generate wmma instructions; this change uses those
intrinsics to enable the feature in MLIR.
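
For example (a sketch; the SSA names are illustrative and the types are
taken from the tests added below), the new op reads as:

  %d = amdgpu.wmma %a * %b + %c : vector<16xf16>, vector<16xf16>, vector<8xf32>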

Reviewed By: krzysz00

Differential Revision: https://reviews.llvm.org/D152451

Added: 
    mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
    mlir/test/Dialect/AMDGPU/invalid.mlir
    mlir/test/Dialect/AMDGPU/ops.mlir
    mlir/test/Target/LLVMIR/rocdl.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 5607fae575500a3..b41fc54e2c078f1 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -380,6 +380,10 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                               VectorOfLengthAndType<[4, 16, 32], [F32]>,
                               VectorOfLengthAndType<[4, 16, 32], [I32]>,
                               VectorOfLengthAndType<[4], [F64]>]>;
+// wmma
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
+                              VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
 
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
@@ -438,4 +442,41 @@ def AMDGPU_MFMAOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_WMMAOp :
+    AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
+                       AllTypesMatch<["sourceA", "sourceB"]>,
+                        Pure]>,
+    Arguments<(ins
+                   WMMAInTypes:$sourceA,
+                   WMMAInTypes:$sourceB,
+                   WMMAOutTypes:$destC,
+                   DefaultValuedAttr<ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>, "0">:$subwordOffset,
+                   UnitAttr:$unsignedA,
+                   UnitAttr:$unsignedB,
+                   UnitAttr:$clamp)>,
+    Results<(outs WMMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let description = [{
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
+    for various `wmma` instructions in the RDNA3 architecture, which perform
+    a 16x16 matrix multiplication for different data types.
+
+    When emitting f16->f16 (or bf16->bf16) wmma, the output is a 16xf16 (or 16xbf16) vector
+    containing only 8 valid values:
+      - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
+      - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
+
+    `unsignedA` and `unsignedB` flag that the `int8` inputs passed to LLVM are unsigned.
+
+    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    in case of overflow.
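+
+    Example (a sketch; the value names are illustrative): an f16 -> f16 wmma
+    that stores its result at the odd-numbered vector indices:
+
+    ```mlir
+    %0 = amdgpu.wmma %matA * %matB + %matC {subwordOffset = 1 : i32}
+      : vector<16xf16>, vector<16xf16>, vector<16xf16>
+    ```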
+  }];
+  let assemblyFormat = [{
+    $sourceA `*` $sourceB `+` $destC
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // AMDGPU

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 3187ce1615ff93f..1de2031a64c0307 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -124,6 +124,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
   let assemblyFormat = "attr-dict";
 }
 
+
 //===---------------------------------------------------------------------===//
 // Xdlops intrinsics
 
@@ -182,6 +183,26 @@ def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.f
 def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
 def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
 
+//===---------------------------------------------------------------------===//
+// WMMA intrinsics
+class ROCDL_Wmma_IntrOp<string mnemonic, list<Trait> traits = []> :
+  LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
+                  "amdgcn_" # !subst(".","_", mnemonic),
+                  [0], [], traits, 1>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+  let assemblyFormat =
+    "$args attr-dict `:` functional-type($args, $res)";
+}
+
+// Available on RDNA3
+def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16">;
+def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16">;
+def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16">;
+def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16">;
+def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8">;
+def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4">;
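+
+// A sketch of using one of these wrappers (SSA names illustrative; see the
+// rocdl.mlir target test below):
+//   %r = rocdl.wmma.f32.16x16x16.f16 %a, %b, %c
+//          : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>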
+
+
 //===---------------------------------------------------------------------===//
 // Vector buffer load/store intrinsics
 

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 0fc1aaef0c7db41..1de583c17809329 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -32,6 +32,12 @@ static Value createI32Constant(ConversionPatternRewriter &rewriter,
   return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
 }
 
+static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
+                              bool value) {
+  Type llvmI1 = rewriter.getI1Type();
+  return rewriter.createOrFold<LLVM::ConstantOp>(loc, llvmI1, value);
+}
+
 namespace {
 /// Define lowering patterns for raw buffer ops
 template <typename GpuOp, typename Intrinsic>
@@ -334,6 +340,64 @@ static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
   return input;
 }
 
+/// Push an input operand. If it is a float type, there is nothing to do. If
+/// it is an integer type, we also need to push its signedness (1 for signed,
+/// 0 for unsigned) and pack the input 16xi8 vector into a 4xi32 vector.
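+/// For example, a vector<16xi8> input is pushed as the operand pair
+/// (i1 sign, vector<4xi32>).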
+static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
+                                 Location loc, TypeConverter *typeConverter,
+                                 bool isUnsigned, Value llvmInput,
+                                 SmallVector<Value, 4> &operands) {
+  Type inputType = llvmInput.getType();
+  auto vectorType = inputType.dyn_cast<VectorType>();
+  Type elemType = vectorType.getElementType();
+
+  if (!elemType.isInteger(8)) {
+    operands.push_back(llvmInput);
+    return;
+  }
+
+  int64_t numBytes = vectorType.getNumElements();
+  Type i32 = rewriter.getI32Type();
+  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
+  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
+
+  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
+      loc, llvmVectorType32bits, llvmInput);
+
+  // If the element type is explicitly signed or unsigned i8, ignore the
+  // isUnsigned flag.
+  bool localIsUnsigned = isUnsigned;
+  if (elemType.isUnsignedInteger(8)) {
+    localIsUnsigned = true;
+  } else if (elemType.isSignedInteger(8)) {
+    localIsUnsigned = false;
+  }
+  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+  operands.push_back(sign);
+  operands.push_back(result);
+}
+
+/// Push the output operand. In most cases this simply pushes the output in
+/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
+/// since the same number of VGPRs is used for input and output, we need to
+/// decide whether to store the result in the upper or the lower 16 bits of
+/// each VGPR. Setting subwordOffset to 1 stores the result in the upper half
+/// of each VGPR; otherwise it is stored in the lower half.
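+/// For example, an f16 or bf16 output is pushed as (vector, i1 subwordOffset)
+/// and an i32 output as (vector, i1 clamp).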
+static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypeConverter *typeConverter,
+                                  Value output, int32_t subwordOffset,
+                                  bool clamp, SmallVector<Value, 4> &operands) {
+  Type inputType = output.getType();
+  auto vectorType = inputType.dyn_cast<VectorType>();
+  Type elemType = vectorType.getElementType();
+  operands.push_back(output);
+  if (elemType.isF16() || elemType.isBF16()) {
+    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+  } else if (elemType.isInteger(32)) {
+    operands.push_back(createI1Constant(rewriter, loc, clamp));
+  }
+}
+
 /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -471,6 +535,31 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   return std::nullopt;
 }
 
+/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// if one exists. This includes checking to ensure the intrinsic is supported
+/// on the architecture you are compiling for.
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+                                                  Chipset chipset) {
+
+  auto sourceVectorType = wmma.getSourceA().getType().dyn_cast<VectorType>();
+  auto destVectorType = wmma.getDestC().getType().dyn_cast<VectorType>();
+  auto elemSourceType = sourceVectorType.getElementType();
+  auto elemDestType = destVectorType.getElementType();
+
+  if (elemSourceType.isF16() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+  } else if (elemSourceType.isBF16() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+  } else if (elemSourceType.isF16() && elemDestType.isF16()) {
+    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+  } else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
+    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+  } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
+    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+  }
+  return std::nullopt;
+}
+
 namespace {
 struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
   MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
@@ -510,6 +599,45 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
   }
 };
 
+struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
+  WMMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type outType = typeConverter->convertType(op.getDestD().getType());
+
+    if (chipset.majorVersion != 11)
+      return op->emitOpError("WMMA only supported on gfx11");
+
+    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
+
+    if (!maybeIntrinsic.has_value())
+      return op.emitOpError("no intrinsic matching WMMA on the given chipset");
+
+    OperationState loweredOp(loc, *maybeIntrinsic);
+    loweredOp.addTypes(outType);
+
+    SmallVector<Value, 4> operands;
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
+                         adaptor.getSourceA(), operands);
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
+                         adaptor.getSourceB(), operands);
+    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
+                          op.getSubwordOffset(), op.getClamp(), operands);
+
+    loweredOp.addOperands(operands);
+    Operation *lowered = rewriter.create(loweredOp);
+    rewriter.replaceOp(op, lowered->getResults());
+
+    return success();
+  }
+};
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
   ConvertAMDGPUToROCDLPass() = default;
@@ -549,7 +677,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
       RawBufferOpLowering<RawBufferAtomicUminOp, ROCDL::RawBufferAtomicUMinOp>,
       RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                           ROCDL::RawBufferAtomicCmpSwap>,
-      MFMAOpLowering>(converter, chipset);
+      MFMAOpLowering, WMMAOpLowering>(converter, chipset);
 }
 
 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 105535b05de859c..ac34acc8307485c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -205,6 +205,34 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
       context);
 }
 
+//===----------------------------------------------------------------------===//
+// WMMAOp
+//===----------------------------------------------------------------------===//
+LogicalResult WMMAOp::verify() {
+  Type sourceAType = getSourceA().getType();
+  Type destType = getDestC().getType();
+
+  VectorType sourceVectorAType = sourceAType.dyn_cast<VectorType>();
+  VectorType destVectorType = destType.dyn_cast<VectorType>();
+
+  Type sourceAElemType = sourceVectorAType.getElementType();
+  Type destElemType = destVectorType.getElementType();
+
+  bool isDestFloat =
+      (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
+  bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+
+  if (isDestFloat && !isSrcFloat) {
+    return emitOpError("Expected float sources with float destination");
+  }
+
+  if (!isDestFloat && isSrcFloat) {
+    return emitOpError("Expected int sources with int destination");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MFMAOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
new file mode 100644
index 000000000000000..407519605d42405
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
+                         %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
+                         %arg6 : vector<16xi8>, %arg7 : vector<4xi32>, %arg8 : vector<8xi32>,
+                         %arg9 : vector<16xui8>){
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<16xbf16>, i1) -> vector<16xbf16>
+  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>, i1) -> vector<8xbf16>
+  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<8xi32>
+  func.return
+}

diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 82d7af2c6dfa4f9..142224e59a95d7a 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -103,3 +103,11 @@ func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
     abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, f32, vector<32xf32>
   func.return %d : vector<32xf32>
 }
+
+// -----
+
+func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op Expected int sources with int destination}}
+  %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}

diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index a612fba56a7634c..4088c6750c91b8d 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -94,3 +94,10 @@ func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
   %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, f32, vector<32xf32>
   func.return %0 : vector<32xf32>
 }
+
+// CHECK-LABEL: func @wmma
+func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+  // CHECK: amdgpu.wmma
+  %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+  func.return %0 : vector<8xf16>
+}

diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index e3f942e8bd785c7..03de137f2252233 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -215,6 +215,66 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
   llvm.return %r0 : vector<32 x f32>
 }
 
+llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : vector<16 x i16>, %arg3 : vector<8 x i32>,
+                      %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>) -> vector<8xf32> {
+  %zero = llvm.mlir.constant(false) : i1
+
+  // ---- Wave32 -----
+
+  // f16 -> f32
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x float> %{{.*}})
+  %r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+
+  // bf16 -> f32
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x float> %{{.*}})
+  %r1 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg0 : (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
+
+  // f16 -> f16 (OPSEL = {0,1})
+  // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x half> %{{.*}}, i1 {{.*}})
+  %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1, %zero : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+
+  // bf16 -> bf16 (OPSEL = {0,1})
+  // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}, i1 {{.*}})
+  %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2, %zero : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+
+  // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
+  %r5 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg3, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+
+  // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
+  %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+
+  // ---- Wave64 -----
+
+  // f16 -> f32
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}})
+  %r7 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg6 : (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
+
+  // bf16 -> f32
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <4 x float> %{{.*}})
+  %r8 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg6 : (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
+
+  // f16 -> f16 (OPSEL = {0,1})
+  // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x half> %{{.*}}, i1 {{.*}})
+  %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7, %zero : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+
+  // bf16 -> bf16 (OPSEL = {0,1})
+  // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x i16> %{{.*}}, i1 {{.*}})
+  %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8, %zero : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+
+  // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
+  %r12 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg5, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+
+  // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
+  %r13 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg5, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<4xi32>, i1) -> vector<4xi32>
+
+  llvm.return %r0 : vector<8xf32>
+}
+
+
 llvm.func @rocdl.mubuf(%rsrc : vector<4xi32>, %vindex : i32,
                        %offset : i32, %vdata1 : vector<1xf32>,
                        %vdata2 : vector<2xf32>, %vdata4 : vector<4xf32>) {


        

