[Mlir-commits] [mlir] [MLIR][GPU] subgroup_mma fp64 extension (PR #165873)
Giacomo Castiglioni
llvmlistbot at llvm.org
Fri Oct 31 09:13:52 PDT 2025
https://github.com/castigli updated https://github.com/llvm/llvm-project/pull/165873
>From 1799e6a9448a382dd76bdecf2457e9c78c7c1087 Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Fri, 31 Oct 2025 16:05:46 +0100
Subject: [PATCH 1/3] GPU mma fp64 extension
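Extend gpu.subgroup_mma_load_matrix / _store_matrix / _compute and the
GPUToNVVM lowering to f64. A minimal sketch of the IR this enables, using
the 8x4 * 4x8 -> 8x8 fragment shapes exercised by the tests added below
(the integration test targets sm_80):

  // Sketch only; shapes and leading dimensions follow the added tests.
  %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index}
         : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
  %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index}
         : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
  %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index}
         : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
  %D = gpu.subgroup_mma_compute %A, %B, %C
         : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp">
         -> !gpu.mma_matrix<8x8xf64, "COp">
  gpu.subgroup_mma_store_matrix %D, %d[%c0, %c0] {leadDimension = 8 : index}
    : !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>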
---
.../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h | 2 +-
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td | 2 +-
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 8 +--
.../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 52 +++++++++++---
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 4 +-
.../GPUToNVVM/wmma-ops-to-nvvm.mlir | 22 ++++++
.../GPU/CUDA/TensorCore/wmma-matmul-f64.mlir | 73 +++++++++++++++++++
7 files changed, 145 insertions(+), 18 deletions(-)
create mode 100644 mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 4c8abea680b66..48982ac6efe7c 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -27,7 +27,7 @@ class MMAMatrixType;
#define GEN_PASS_DECL_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
-LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
+Type convertMMAToLLVMType(gpu::MMAMatrixType type);
/// Configure target to convert from the GPU dialect to NVVM.
void configureGpuToNVVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 860f893367203..2c29bb8a01a41 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -114,7 +114,7 @@ def GPU_MMAMatrix : DialectType<
GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
// Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
-def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
+def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, F64, VectorOfRankAndType<[1], [I8, I32, F16, F32, F64]>]>;
class MMAMatrixOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index a6c6038e1e224..5c7df25c58cde 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1872,7 +1872,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
```
}];
- let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
+ let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32, F64]>>:$src,
Arg<GPU_MMAMemRef, "",[MemWriteAt<0, FullEffect>]>:$dstMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension,
@@ -1919,9 +1919,9 @@ def GPU_SubgroupMmaComputeOp
```
}];
- let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
- Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
- Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
+ let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opA,
+ Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opB,
+ Arg<MMAMatrixOf<[I32, F16, F32, F64]>>:$opC,
OptionalAttr<UnitAttr>:$a_transpose,
OptionalAttr<UnitAttr>:$b_transpose);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 99c059cb03299..fb1a37a03fe4d 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
using namespace mlir;
@@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
if (type.getElementType().isF32())
return type.getOperand() == "COp" ? NVVM::MMATypes::f32
: NVVM::MMATypes::tf32;
-
+ if (type.getElementType().isF64())
+ return NVVM::MMATypes::f64;
if (type.getElementType().isSignedInteger(8))
return NVVM::MMATypes::s8;
if (type.getElementType().isUnsignedInteger(8))
@@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering
// then passed on to the intrinsic call. Emit llvm ops to extract individual
// values form lowered memrefs.
SmallVector<Value> unpackedOps;
-
auto unpackOp = [&](Value operand) {
+ // f64 a and b fragments are not structs but scalars.
+ if (!isa<LLVM::LLVMStructType>(operand.getType())) {
+ unpackedOps.push_back(operand);
+ return;
+ }
+ // Every other type is lowered to an LLVM struct; extract the values.
auto structType = cast<LLVM::LLVMStructType>(operand.getType());
for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
@@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering
return failure();
Location loc = subgroupMmaConstantOp.getLoc();
Value cst = adaptor.getOperands()[0];
- LLVM::LLVMStructType type = convertMMAToLLVMType(
+ Type type = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
+ // If the converted type is not a struct, it is a scalar f64.
+ LLVM::LLVMStructType structType = dyn_cast<LLVM::LLVMStructType>(type);
+ if (!structType) {
+ rewriter.replaceOp(subgroupMmaConstantOp, cst);
+ return success();
+ }
// If the element type is a vector create a vector from the operand.
- if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
+ if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
Value idx = LLVM::ConstantOp::create(rewriter, loc,
@@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering
}
cst = vecCst;
}
- Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
- for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
+ for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
matrixStruct =
LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
}
@@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering
return failure();
Location loc = subgroupMmaElementwiseOp.getLoc();
size_t numOperands = adaptor.getOperands().size();
- LLVM::LLVMStructType destType = convertMMAToLLVMType(
+ Type destType = convertMMAToLLVMType(
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
- Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
- for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
+
+ // If the converted type is not a struct, it is a scalar f64.
+ LLVM::LLVMStructType structDestTy = dyn_cast<LLVM::LLVMStructType>(destType);
+ if (!structDestTy) {
+ SmallVector<Value> operands;
+ for (auto operand : adaptor.getOperands()) {
+ operands.push_back(operand);
+ }
+ Value element =
+ createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
+ operands);
+ rewriter.replaceOp(subgroupMmaElementwiseOp, element);
+ return success();
+ }
+ Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
+ for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
extractedOperands.push_back(LLVM::ExtractValueOp::create(
@@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering
} // namespace
/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
+Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
NVVM::MMAFrag frag = convertOperand(type.getOperand());
NVVM::MMATypes eltType = getElementType(type);
auto nRow = type.getShape()[0];
auto nCol = type.getShape()[1];
std::pair<Type, unsigned> typeInfo =
NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
+ // Special handling for f64 A and B fragments, which lower to a single scalar f64.
+ Type f64Ty = Float64Type::get(type.getContext());
+ if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+ return f64Ty;
+ }
return LLVM::LLVMStructType::getLiteral(
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2bad55d..61a630aa88960 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
bool MMAMatrixType::isValidElementType(Type elementType) {
- return elementType.isF16() || elementType.isF32() ||
+ return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
elementType.isInteger(32);
}
@@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
if (!MMAMatrixType::isValidElementType(elementType))
return emitError()
- << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
+ << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
return success();
}
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index b479467efc208..83b5fb5e6ea54 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -79,6 +79,28 @@ gpu.module @test_module {
// -----
+gpu.module @test_module {
+
+ // CHECK-LABEL: func @gpu_wmma_f64_load_op() ->
+ // CHECK-SAME: f64
+ // CHECK32-LABEL: func @gpu_wmma_f64_load_op() ->
+ func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) {
+ %wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3>
+ %i = arith.constant 16 : index
+ %j = arith.constant 16 : index
+ %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp">
+ return %0 : !gpu.mma_matrix<8x4xf64, "AOp">
+ // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64
+ // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+ // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64
+ // CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32
+ // CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64
+ // CHECK: llvm.return %[[LOAD]] : f64
+ }
+}
+
+// -----
+
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_store_op
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
new file mode 100644
index 0000000000000..a016a60022699
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @main() {
+ %a = memref.alloc() : memref<8x4xf64>
+ %b = memref.alloc() : memref<4x8xf64>
+ %c = memref.alloc() : memref<8x8xf64>
+ %d = memref.alloc() : memref<8x8xf64>
+
+ %f1 = arith.constant 1.0e+00 : f64
+ %fcst = arith.constant 3.14e+00 : f64
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+
+ // Initialize the input matrices with ones.
+ scf.for %arg0 = %c0 to %c8 step %c1 {
+ scf.for %arg1 = %c0 to %c4 step %c1 {
+ memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64>
+ memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64>
+ }
+ }
+ // Initialize the accumulator matrix with a constant.
+ scf.for %arg0 = %c0 to %c8 step %c1 {
+ scf.for %arg1 = %c0 to %c8 step %c1 {
+ memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64>
+ }
+ }
+
+ %2 = memref.cast %a : memref<8x4xf64> to memref<*xf64>
+ %20 = memref.cast %b : memref<4x8xf64> to memref<*xf64>
+ %33 = memref.cast %c : memref<8x8xf64> to memref<*xf64>
+ %34 = memref.cast %d : memref<8x8xf64> to memref<*xf64>
+
+ gpu.host_register %2 : memref<*xf64>
+ gpu.host_register %20 : memref<*xf64>
+ gpu.host_register %33 : memref<*xf64>
+ gpu.host_register %34 : memref<*xf64>
+
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+ %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
+ %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
+ %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+ %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+ gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>
+ gpu.terminator
+ }
+ // Print the memref after computation.
+ call @printMemrefF64(%34) : (memref<*xf64>) -> ()
+ // CHECK: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14],
+ // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14]
+ return
+}
+
+func.func private @printMemrefF64(memref<*xf64>)
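A note on the expected output above: A and B are filled with ones, so each
element of the product is 1.0 * 1.0 summed over k = 4, and adding the 3.14
accumulator gives 4.0 + 3.14 = 7.14, which is what the CHECK lines verify.
At the NVVM level the f64 A/B fragments are a single scalar rather than an
!llvm.struct, which is the case the non-struct paths in WmmaOpsToNvvm.cpp
handle. A sketch of the converted A-fragment load, following the CHECK lines
in wmma-ops-to-nvvm.mlir (illustrative; %gep and %stride are placeholder
names, not a verbatim dump):

  // Illustrative only: the f64 "AOp" load yields a plain f64, not a struct.
  %frag = nvvm.wmma.load %gep, %stride {eltype = #nvvm.mma_type<f64>,
            frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>,
            m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64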
>From 4cee36a2d674a28f741eb4f77160ce878a3cfd93 Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Fri, 31 Oct 2025 16:24:03 +0100
Subject: [PATCH 2/3] format
---
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index fb1a37a03fe4d..13bb2231b13ca 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -371,15 +371,15 @@ struct WmmaElementwiseOpToNVVMLowering
cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
// If the converted type is not a struct, it is a scalar f64.
- LLVM::LLVMStructType structDestTy = dyn_cast<LLVM::LLVMStructType>(destType);
+ LLVM::LLVMStructType structDestTy =
+ dyn_cast<LLVM::LLVMStructType>(destType);
if (!structDestTy) {
SmallVector<Value> operands;
for (auto operand : adaptor.getOperands()) {
operands.push_back(operand);
}
- Value element =
- createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
- operands);
+ Value element = createScalarOp(
+ rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands);
rewriter.replaceOp(subgroupMmaElementwiseOp, element);
return success();
}
>From 094264fa3f398cee597c739c0de8d61ba00bc618 Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Fri, 31 Oct 2025 17:12:48 +0100
Subject: [PATCH 3/3] fix invalid IR test
---
mlir/test/Dialect/GPU/invalid.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 35381dab7b200..26bcf948bc85d 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -688,7 +688,7 @@ func.func @mmamatrix_operand_type(){
func.func @mmamatrix_invalid_element_type(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
- // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}}
+ // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp">
return
}
@@ -708,7 +708,7 @@ func.func @mmaLoadOp_identity_layout(){
// -----
func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
- // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}}
+ // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float values of ranks 1 values}}
%0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
return
}