[Mlir-commits] [mlir] Extend GPU and NVVM mma ops to support fp64 (PR #165380)

Giacomo Castiglioni llvmlistbot at llvm.org
Wed Oct 29 05:55:57 PDT 2025


https://github.com/castigli updated https://github.com/llvm/llvm-project/pull/165380

From f03a9cb37158af54f53f67b1bb00272126cca66c Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Tue, 28 Oct 2025 12:05:48 +0100
Subject: [PATCH 1/2] Extend gpu and nvvm wmma to support fp64
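
For illustration, a minimal sketch of the intended gpu-dialect surface for the
new 8x8x4 f64 fragments, mirroring the integration test added below (sm_80 or
newer assumed; %a, %b, %c, %d and %c0 are the buffers and index constant from
that test):

  %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index}
         : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
  %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index}
         : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
  %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index}
         : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
  %R = gpu.subgroup_mma_compute %A, %B, %C
         : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp">
         -> !gpu.mma_matrix<8x8xf64, "COp">
  gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}
         : !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>

During GPUToNVVM lowering the a and b fragments become a single f64 value
rather than an LLVM struct, while the c and d fragments lower to
!llvm.struct<(f64, f64)>.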

---
 .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h |  2 +-
 mlir/include/mlir/Dialect/GPU/IR/GPUBase.td   |  2 +-
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  8 +--
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 16 ++++-
 .../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp    | 52 +++++++++++---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  4 +-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 14 ++++
 .../GPUToNVVM/wmma-ops-to-nvvm.mlir           | 22 ++++++
 .../GPU/CUDA/TensorCore/wmma-matmul-f64.mlir  | 70 +++++++++++++++++++
 mlir/test/Target/LLVMIR/nvvmir.mlir           | 37 ++++++++++
 10 files changed, 207 insertions(+), 20 deletions(-)
 create mode 100644 mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir

diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 4c8abea680b66..48982ac6efe7c 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -27,7 +27,7 @@ class MMAMatrixType;
 #define GEN_PASS_DECL_CONVERTGPUOPSTONVVMOPS
 #include "mlir/Conversion/Passes.h.inc"
 
-LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
+Type convertMMAToLLVMType(gpu::MMAMatrixType type);
 
 /// Configure target to convert from the GPU dialect to NVVM.
 void configureGpuToNVVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 860f893367203..2c29bb8a01a41 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -114,7 +114,7 @@ def GPU_MMAMatrix : DialectType<
   GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
 
 // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
-def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
+def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, F64, VectorOfRankAndType<[1], [I8, I32, F16, F32, F64]>]>;
 
 class MMAMatrixOf<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index a6c6038e1e224..5c7df25c58cde 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1872,7 +1872,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32, F64]>>:$src,
                   Arg<GPU_MMAMemRef, "",[MemWriteAt<0, FullEffect>]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension,
@@ -1919,9 +1919,9 @@ def GPU_SubgroupMmaComputeOp
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
-                  Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
-                  Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opA,
+                  Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opB,
+                  Arg<MMAMatrixOf<[I32, F16, F32, F64]>>:$opC,
                   OptionalAttr<UnitAttr>:$a_transpose,
                   OptionalAttr<UnitAttr>:$b_transpose);
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 4f483859ac18d..cccdc2c368d6d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2014,6 +2014,9 @@ class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
 // llvm supports and can be extended as needed.
 class NVVM_MMA_OPS {
   // "wmma" operations
+  list<list<WMMA_REGS>> fp64_wmma_ops = MMA_OPS<
+            [GEOM<8, 8, 4>],
+            ["f64"], [], ["f64"], []>.ret;
   list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
             [GEOM<16, 16, 8>],
             ["tf32"], [], ["f32"], []>.ret;
@@ -2024,6 +2027,7 @@ class NVVM_MMA_OPS {
             [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
             ["s8","u8"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
+            fp64_wmma_ops,
             tf32_wmma_ops,
             fp_wmma_ops,
             i8_wmma_ops);
@@ -2040,9 +2044,17 @@ class NVVM_MMA_OPS {
   list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
             [GEOM<16, 16, 8>],
             ["c", "d"], ["f32"]>.ret;
+  list<WMMA_REGS> ldst_f64_ab_ops = MMA_LDST_OPS<
+            [GEOM<8, 8, 4>],
+            ["a", "b"], ["f64"]>.ret;
+  list<WMMA_REGS> ldst_f64_cd_ops = MMA_LDST_OPS<
+            [GEOM<8, 8, 4>],
+            ["c", "d"], ["f64"]>.ret;
   list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
                                              ldst_tf32_ab_ops,
-                                             ldst_tf32_cd_ops);
+                                             ldst_tf32_cd_ops,
+                                             ldst_f64_ab_ops,
+                                             ldst_f64_cd_ops);
   // Separate A/B/C fragments (loads) from D (stores).
   list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
   list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
@@ -2349,7 +2361,7 @@ def MMAFragAttr : EnumAttr<NVVM_Dialect, MMAFrag, "mma_frag"> {
 }
 
 def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">,
-  Results<(outs LLVM_AnyStruct:$res)>,
+  Results<(outs AnyTypeOf<[LLVM_AnyStruct, F64]>:$res)>,
   Arguments<(ins LLVM_AnyPointer: $ptr, I32: $stride, I32Attr:$m,
              I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layout,
              MMATypesAttr:$eltype, MMAFragAttr:$frag)> {
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 99c059cb03299..fb1a37a03fe4d 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
 
 using namespace mlir;
 
@@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
   if (type.getElementType().isF32())
     return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                       : NVVM::MMATypes::tf32;
-
+  if (type.getElementType().isF64())
+    return NVVM::MMATypes::f64;
   if (type.getElementType().isSignedInteger(8))
     return NVVM::MMATypes::s8;
   if (type.getElementType().isUnsignedInteger(8))
@@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering
     // then passed on to the intrinsic call. Emit llvm ops to extract individual
     // values form lowered memrefs.
     SmallVector<Value> unpackedOps;
-
     auto unpackOp = [&](Value operand) {
+      // f64 a and b fragments are not structs but scalars.
+      if (!isa<LLVM::LLVMStructType>(operand.getType())) {
+        unpackedOps.push_back(operand);
+        return;
+      }
+      // Every other type is lowered to an LLVM struct; extract the values.
       auto structType = cast<LLVM::LLVMStructType>(operand.getType());
       for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
         Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
@@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering
       return failure();
     Location loc = subgroupMmaConstantOp.getLoc();
     Value cst = adaptor.getOperands()[0];
-    LLVM::LLVMStructType type = convertMMAToLLVMType(
+    Type type = convertMMAToLLVMType(
         cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
+    // If the converted type is not a struct, it is a scalar f64.
+    LLVM::LLVMStructType structType = dyn_cast<LLVM::LLVMStructType>(type);
+    if (!structType) {
+      rewriter.replaceOp(subgroupMmaConstantOp, cst);
+      return success();
+    }
     // If the element type is a vector create a vector from the operand.
-    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
+    if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
       Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
       for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
         Value idx = LLVM::ConstantOp::create(rewriter, loc,
@@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering
       }
       cst = vecCst;
     }
-    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
-    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
+    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
+    for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
       matrixStruct =
           LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
     }
@@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering
       return failure();
     Location loc = subgroupMmaElementwiseOp.getLoc();
     size_t numOperands = adaptor.getOperands().size();
-    LLVM::LLVMStructType destType = convertMMAToLLVMType(
+    Type destType = convertMMAToLLVMType(
         cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
-    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
-    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
+
+    // If the converted type is not a struct, it is a scalar f64.
+    LLVM::LLVMStructType structDestTy = dyn_cast<LLVM::LLVMStructType>(destType);
+    if (!structDestTy) {
+      SmallVector<Value> operands;
+      for (auto operand : adaptor.getOperands()) {
+        operands.push_back(operand);
+      }
+      Value element =
+          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
+                         operands);
+      rewriter.replaceOp(subgroupMmaElementwiseOp, element);
+      return success();
+    }
+    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
+    for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
       SmallVector<Value> extractedOperands;
       for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
         extractedOperands.push_back(LLVM::ExtractValueOp::create(
@@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering
 } // namespace
 
 /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
+Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
   NVVM::MMAFrag frag = convertOperand(type.getOperand());
   NVVM::MMATypes eltType = getElementType(type);
   auto nRow = type.getShape()[0];
   auto nCol = type.getShape()[1];
   std::pair<Type, unsigned> typeInfo =
       NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
+  // Special handling for f64 a and b fragments
+  Type f64Ty = Float64Type::get(type.getContext());
+  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+    return f64Ty;
+  }
   return LLVM::LLVMStructType::getLiteral(
       type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
 }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index b5f8ddaadacdf..0dd9b6adea954 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
 
 bool MMAMatrixType::isValidElementType(Type elementType) {
-  return elementType.isF16() || elementType.isF32() ||
+  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
          elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
          elementType.isInteger(32);
 }
@@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
 
   if (!MMAMatrixType::isValidElementType(elementType))
     return emitError()
-           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
+           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
 
   return success();
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f0de4dbcc1d4b..c665cc4201049 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
   } else if (type == NVVM::MMATypes::f32) {
     elementType = builder.getF32Type();
     numberElements = 8;
+  } else if (type == NVVM::MMATypes::f64) {
+    elementType = builder.getF64Type();
+    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
+      numberElements = 1;
+    else
+      numberElements = 2;
   } else if (type == NVVM::MMATypes::tf32) {
     elementType = builder.getI32Type();
     numberElements = 4;
@@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() {
     return emitOpError() << "invalid attribute combination";
   std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
       getEltype(), getFrag(), getM(), getN(), getK(), getContext());
+  // Special case: f64 a and b fragments are a single f64, not a struct.
+  Type f64Ty = Float64Type::get(getContext());
+  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+    if (getType() != f64Ty)
+      return emitOpError("expected destination type to be f64");
+    return success();
+  }
+  // Everything else is an LLVM struct.
   Type dstType = LLVM::LLVMStructType::getLiteral(
       getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
   if (getType() != dstType)
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index b479467efc208..83b5fb5e6ea54 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -79,6 +79,28 @@ gpu.module @test_module {
 
 // -----
 
+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_f64_load_op() ->
+  // CHECK-SAME: f64
+  // CHECK32-LABEL: func @gpu_wmma_f64_load_op() ->
+  func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp">
+    return %0 : !gpu.mma_matrix<8x4xf64, "AOp">
+    // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64
+    // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+    // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64
+    // CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64
+    // CHECK: llvm.return %[[LOAD]] : f64
+  }
+}
+
+// -----
+
 gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_store_op
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
new file mode 100644
index 0000000000000..a37af5def8c42
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+func.func @main() {
+  %a = memref.alloc() : memref<8x4xf64>
+  %b = memref.alloc() : memref<4x8xf64>
+  %c = memref.alloc() : memref<8x8xf64>
+  %d = memref.alloc() : memref<8x8xf64>
+
+  %f1 = arith.constant 1.0e+00 : f64
+  %fcst = arith.constant 3.14e+00 : f64
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %c4 = arith.constant 4 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+
+  // Initialize the input matrices with ones.
+  scf.for %arg0 = %c0 to %c8 step %c1 {
+    scf.for %arg1 = %c0 to %c4 step %c1 {
+      memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64>
+      memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64>
+    }
+  }
+  // Initialize the accumulator matrix with the constant 3.14.
+  scf.for %arg0 = %c0 to %c8 step %c1 {
+    scf.for %arg1 = %c0 to %c8 step %c1 {
+      memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64>
+    }
+  }
+
+  %2 = memref.cast %a : memref<8x4xf64> to memref<*xf64>
+  %20 = memref.cast %b : memref<4x8xf64> to memref<*xf64>
+  %33 = memref.cast %c : memref<8x8xf64> to memref<*xf64>
+  %34 = memref.cast %d : memref<8x8xf64> to memref<*xf64>
+
+  gpu.host_register %2 : memref<*xf64>
+  gpu.host_register %20 : memref<*xf64>
+  gpu.host_register %33 : memref<*xf64>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>
+    gpu.terminator
+  }
+  // Print the memref after computation.
+  call @printMemrefF64(%34) : (memref<*xf64>) -> ()
+  // CHECK: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14]
+  return
+}
+
+func.func private @printMemrefF64(memref<*xf64>)
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 594ae4849e3eb..388ebe859a249 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -463,6 +463,43 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_wmma_load_a_f64
+llvm.func @nvvm_wmma_load_a_f64(%arg0: !llvm.ptr, %arg1 : i32) {
+  // CHECK: call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p0(ptr %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : (!llvm.ptr) -> f64
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_wmma_load_c_f64
+llvm.func @nvvm_wmma_load_c_f64(%arg0: !llvm.ptr, %arg1 : i32) {
+  // CHECK: call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p0(ptr %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<c>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : (!llvm.ptr) -> !llvm.struct<(f64, f64)>
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_wmma_mma_f64
+llvm.func @nvvm_wmma_mma_f64(%0 : f64, %1 : f64, %2 : f64, %3 : f64) {
+  // CHECK: { double, double } @llvm.nvvm.wmma.mm8n8k4.mma.row.col.f64(f64 %{{.*}}, f64 %{{.*}}, f64 %{{.*}}, f64 %{{.*}})
+  %r = nvvm.wmma.mma %0, %1, %2, %3
+    {eltypeA = #nvvm.mma_type<f64>, eltypeB = #nvvm.mma_type<f64>, k = 4 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, m = 8 : i32, n = 8 : i32}
+    : (f64, f64, f64, f64)
+    -> !llvm.struct<(f64, f64)>
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_wmma_store_d_f64
+llvm.func @nvvm_wmma_store_d_f64(%arg0: !llvm.ptr, %arg1 : i32, %arg2 : f64, %arg3 : f64) {
+  // CHECK: call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p0(ptr %{{.*}}, f64 %{{.*}}, f64 %{{.*}}, i32 %{{.*}})
+  nvvm.wmma.store %arg0, %arg1, %arg2, %arg3
+    {eltype = #nvvm.mma_type<f64>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : !llvm.ptr, f64, f64
+  llvm.return
+}
+
 // CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
   // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})

From 5b4b363ddd3dbafa5b3bee29a6d9b4cada03c452 Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Wed, 29 Oct 2025 13:55:07 +0100
Subject: [PATCH 2/2] Keep only nvvm changes
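
The gpu-dialect and GPUToNVVM changes are dropped from this PR for now; only
the NVVM-level f64 WMMA support remains. A minimal sketch of the resulting op
forms, following the nvvmir.mlir tests kept by this patch (%ptr, %stride, %b,
%c0, %c1, %d0 and %d1 are placeholder SSA values):

  // a/b fragments load as a scalar f64; c/d fragments are two-element structs.
  %a = nvvm.wmma.load %ptr, %stride
         {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32,
          layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
         : (!llvm.ptr) -> f64
  %acc = nvvm.wmma.mma %a, %b, %c0, %c1
         {eltypeA = #nvvm.mma_type<f64>, eltypeB = #nvvm.mma_type<f64>,
          k = 4 : i32, layoutA = #nvvm.mma_layout<row>,
          layoutB = #nvvm.mma_layout<col>, m = 8 : i32, n = 8 : i32}
         : (f64, f64, f64, f64) -> !llvm.struct<(f64, f64)>
  nvvm.wmma.store %ptr, %stride, %d0, %d1
         {eltype = #nvvm.mma_type<f64>, k = 4 : i32,
          layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
         : !llvm.ptr, f64, f64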

---
 .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h |  2 +-
 mlir/include/mlir/Dialect/GPU/IR/GPUBase.td   |  2 +-
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  8 +--
 .../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp    | 52 +++-----------
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  4 +-
 .../GPUToNVVM/wmma-ops-to-nvvm.mlir           | 22 ------
 .../GPU/CUDA/TensorCore/wmma-matmul-f64.mlir  | 70 -------------------
 mlir/test/Target/LLVMIR/nvvmir.mlir           |  4 +-
 8 files changed, 20 insertions(+), 144 deletions(-)
 delete mode 100644 mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir

diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 48982ac6efe7c..4c8abea680b66 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -27,7 +27,7 @@ class MMAMatrixType;
 #define GEN_PASS_DECL_CONVERTGPUOPSTONVVMOPS
 #include "mlir/Conversion/Passes.h.inc"
 
-Type convertMMAToLLVMType(gpu::MMAMatrixType type);
+LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
 
 /// Configure target to convert from the GPU dialect to NVVM.
 void configureGpuToNVVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 2c29bb8a01a41..860f893367203 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -114,7 +114,7 @@ def GPU_MMAMatrix : DialectType<
   GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
 
 // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
-def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, F64, VectorOfRankAndType<[1], [I8, I32, F16, F32, F64]>]>;
+def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
 
 class MMAMatrixOf<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 5c7df25c58cde..a6c6038e1e224 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1872,7 +1872,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32, F64]>>:$src,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
                   Arg<GPU_MMAMemRef, "",[MemWriteAt<0, FullEffect>]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension,
@@ -1919,9 +1919,9 @@ def GPU_SubgroupMmaComputeOp
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opA,
-                  Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opB,
-                  Arg<MMAMatrixOf<[I32, F16, F32, F64]>>:$opC,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
+                  Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
+                  Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
                   OptionalAttr<UnitAttr>:$a_transpose,
                   OptionalAttr<UnitAttr>:$b_transpose);
 
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index fb1a37a03fe4d..99c059cb03299 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/Types.h"
 
 using namespace mlir;
 
@@ -58,8 +57,7 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
   if (type.getElementType().isF32())
     return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                       : NVVM::MMATypes::tf32;
-  if (type.getElementType().isF64())
-    return NVVM::MMATypes::f64;
+
   if (type.getElementType().isSignedInteger(8))
     return NVVM::MMATypes::s8;
   if (type.getElementType().isUnsignedInteger(8))
@@ -214,13 +212,8 @@ struct WmmaMmaOpToNVVMLowering
     // then passed on to the intrinsic call. Emit llvm ops to extract individual
     // values form lowered memrefs.
     SmallVector<Value> unpackedOps;
+
     auto unpackOp = [&](Value operand) {
-      // f64 a and b fragments are not structs but scalars.
-      if (!isa<LLVM::LLVMStructType>(operand.getType())) {
-        unpackedOps.push_back(operand);
-        return;
-      }
-      // Every other type is lowered to an LLVM struct; extract the values.
       auto structType = cast<LLVM::LLVMStructType>(operand.getType());
       for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
         Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
@@ -283,16 +276,10 @@ struct WmmaConstantOpToNVVMLowering
       return failure();
     Location loc = subgroupMmaConstantOp.getLoc();
     Value cst = adaptor.getOperands()[0];
-    Type type = convertMMAToLLVMType(
+    LLVM::LLVMStructType type = convertMMAToLLVMType(
         cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
-    // If the converted type is not a struct, it is a scalar f64.
-    LLVM::LLVMStructType structType = dyn_cast<LLVM::LLVMStructType>(type);
-    if (!structType) {
-      rewriter.replaceOp(subgroupMmaConstantOp, cst);
-      return success();
-    }
     // If the element type is a vector create a vector from the operand.
-    if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
+    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
       Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
       for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
         Value idx = LLVM::ConstantOp::create(rewriter, loc,
@@ -302,8 +289,8 @@ struct WmmaConstantOpToNVVMLowering
       }
       cst = vecCst;
     }
-    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
-    for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
+    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
+    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
       matrixStruct =
           LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
     }
@@ -367,24 +354,10 @@ struct WmmaElementwiseOpToNVVMLowering
       return failure();
     Location loc = subgroupMmaElementwiseOp.getLoc();
     size_t numOperands = adaptor.getOperands().size();
-    Type destType = convertMMAToLLVMType(
+    LLVM::LLVMStructType destType = convertMMAToLLVMType(
         cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
-
-    // If the converted type is not a struct, it is a scalar f64.
-    LLVM::LLVMStructType structDestTy = dyn_cast<LLVM::LLVMStructType>(destType);
-    if (!structDestTy) {
-      SmallVector<Value> operands;
-      for (auto operand : adaptor.getOperands()) {
-        operands.push_back(operand);
-      }
-      Value element =
-          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
-                         operands);
-      rewriter.replaceOp(subgroupMmaElementwiseOp, element);
-      return success();
-    }
-    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
-    for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
+    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
+    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
       SmallVector<Value> extractedOperands;
       for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
         extractedOperands.push_back(LLVM::ExtractValueOp::create(
@@ -404,18 +377,13 @@ struct WmmaElementwiseOpToNVVMLowering
 } // namespace
 
 /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
+LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
   NVVM::MMAFrag frag = convertOperand(type.getOperand());
   NVVM::MMATypes eltType = getElementType(type);
   auto nRow = type.getShape()[0];
   auto nCol = type.getShape()[1];
   std::pair<Type, unsigned> typeInfo =
       NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
-  // Special handling for f64 a and b fragments
-  Type f64Ty = Float64Type::get(type.getContext());
-  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
-    return f64Ty;
-  }
   return LLVM::LLVMStructType::getLiteral(
       type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
 }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 0dd9b6adea954..b5f8ddaadacdf 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
 
 bool MMAMatrixType::isValidElementType(Type elementType) {
-  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
+  return elementType.isF16() || elementType.isF32() ||
          elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
          elementType.isInteger(32);
 }
@@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
 
   if (!MMAMatrixType::isValidElementType(elementType))
     return emitError()
-           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
+           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
 
   return success();
 }
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index 83b5fb5e6ea54..b479467efc208 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -79,28 +79,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-
-  // CHECK-LABEL: func @gpu_wmma_f64_load_op() ->
-  // CHECK-SAME: f64
-  // CHECK32-LABEL: func @gpu_wmma_f64_load_op() ->
-  func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) {
-    %wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3>
-    %i = arith.constant 16 : index
-    %j = arith.constant 16 : index
-    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp">
-    return %0 : !gpu.mma_matrix<8x4xf64, "AOp">
-    // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64
-    // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
-    // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64
-    // CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32
-    // CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64
-    // CHECK: llvm.return %[[LOAD]] : f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_store_op
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
deleted file mode 100644
index a37af5def8c42..0000000000000
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
+++ /dev/null
@@ -1,70 +0,0 @@
-// RUN: mlir-opt %s \
-// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-format=%gpu_compilation_format" \
-// RUN: | mlir-runner \
-// RUN:   --shared-libs=%mlir_cuda_runtime \
-// RUN:   --shared-libs=%mlir_runner_utils \
-// RUN:   --entry-point-result=void \
-// RUN: | FileCheck %s
-
-func.func @main() {
-  %a = memref.alloc() : memref<8x4xf64>
-  %b = memref.alloc() : memref<4x8xf64>
-  %c = memref.alloc() : memref<8x8xf64>
-  %d = memref.alloc() : memref<8x8xf64>
-
-  %f1 = arith.constant 1.0e+00 : f64
-  %fcst = arith.constant 3.14e+00 : f64
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %c4 = arith.constant 4 : index
-  %c1 = arith.constant 1 : index
-  %c32 = arith.constant 32 : index
-
-  // Initialize the input matrices with ones.
-  scf.for %arg0 = %c0 to %c8 step %c1 {
-    scf.for %arg1 = %c0 to %c4 step %c1 {
-      memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64>
-      memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64>
-    }
-  }
-  // Initialize the accumulator matrix with the constant 3.14.
-  scf.for %arg0 = %c0 to %c8 step %c1 {
-    scf.for %arg1 = %c0 to %c8 step %c1 {
-      memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64>
-    }
-  }
-
-  %2 = memref.cast %a : memref<8x4xf64> to memref<*xf64>
-  %20 = memref.cast %b : memref<4x8xf64> to memref<*xf64>
-  %33 = memref.cast %c : memref<8x8xf64> to memref<*xf64>
-  %34 = memref.cast %d : memref<8x8xf64> to memref<*xf64>
-
-  gpu.host_register %2 : memref<*xf64>
-  gpu.host_register %20 : memref<*xf64>
-  gpu.host_register %33 : memref<*xf64>
-
-  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
-             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
-    %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
-    %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
-    %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
-
-    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp">
-
-    gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>
-    gpu.terminator
-  }
-  // Print the memref after computation.
-  call @printMemrefF64(%34) : (memref<*xf64>) -> ()
-  // CHECK: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
-  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
-  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
-  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
-  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
-  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
-  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
-  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14]
-  return
-}
-
-func.func private @printMemrefF64(memref<*xf64>)
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 388ebe859a249..9115de65ff0e8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -483,7 +483,7 @@ llvm.func @nvvm_wmma_load_c_f64(%arg0: !llvm.ptr, %arg1 : i32) {
 
 // CHECK-LABEL: @nvvm_wmma_mma_f64
 llvm.func @nvvm_wmma_mma_f64(%0 : f64, %1 : f64, %2 : f64, %3 : f64) {
-  // CHECK: { double, double } @llvm.nvvm.wmma.mm8n8k4.mma.row.col.f64(f64 %{{.*}}, f64 %{{.*}}, f64 %{{.*}}, f64 %{{.*}})
+  // CHECK: { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.col.f64(double %{{.*}}, double %{{.*}}, double %{{.*}}, double %{{.*}})
   %r = nvvm.wmma.mma %0, %1, %2, %3
     {eltypeA = #nvvm.mma_type<f64>, eltypeB = #nvvm.mma_type<f64>, k = 4 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, m = 8 : i32, n = 8 : i32}
     : (f64, f64, f64, f64)
@@ -493,7 +493,7 @@ llvm.func @nvvm_wmma_mma_f64(%0 : f64, %1 : f64, %2 : f64, %3 : f64) {
 
 // CHECK-LABEL: @nvvm_wmma_store_d_f64
 llvm.func @nvvm_wmma_store_d_f64(%arg0: !llvm.ptr, %arg1 : i32, %arg2 : f64, %arg3 : f64) {
-  // CHECK: call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p0(ptr %{{.*}}, f64 %{{.*}}, f64 %{{.*}}, i32 %{{.*}})
+  // CHECK: call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p0(ptr %{{.*}}, double %{{.*}}, double %{{.*}}, i32 %{{.*}})
   nvvm.wmma.store %arg0, %arg1, %arg2, %arg3
     {eltype = #nvvm.mma_type<f64>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
     : !llvm.ptr, f64, f64


