[Mlir-commits] [mlir] 0ed8e66 - [MLIR][NVVM] Extend NVVM mma ops to support fp64 (#165380)

llvmlistbot at llvm.org
Fri Oct 31 06:19:09 PDT 2025


Author: Giacomo Castiglioni
Date: 2025-10-31T18:49:05+05:30
New Revision: 0ed8e66f88b689c152245d6b968a06fa8e67b51f

URL: https://github.com/llvm/llvm-project/commit/0ed8e66f88b689c152245d6b968a06fa8e67b51f
DIFF: https://github.com/llvm/llvm-project/commit/0ed8e66f88b689c152245d6b968a06fa8e67b51f.diff

LOG: [MLIR][NVVM] Extend NVVM mma ops to support fp64 (#165380)

This PR extends the `nvvm.wmma` ops to support the fp64 type.
The extension requires special handling of the return type for the load ops
of fragments `a` and `b`, since they return a scalar instead of a struct.
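
Concretely (a sketch mirroring the tests added below), an f64 load of
fragment `a` now produces a plain `f64`, while a load of the accumulator
fragment `c` keeps the struct result:

    %a = nvvm.wmma.load %ptr, %stride
      {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32,
       layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
      : (!llvm.ptr) -> f64
    %c = nvvm.wmma.load %ptr, %stride
      {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<c>, k = 4 : i32,
       layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
      : (!llvm.ptr) -> !llvm.struct<(f64, f64)>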

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index b572ef9c1d07b..ba5e48e4ec9ba 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1999,6 +1999,9 @@ class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
 // llvm supports and can be extended as needed.
 class NVVM_MMA_OPS {
   // "wmma" operations
+  list<list<WMMA_REGS>> fp64_wmma_ops = MMA_OPS<
+            [GEOM<8, 8, 4>],
+            ["f64"], [], ["f64"], []>.ret;
   list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
             [GEOM<16, 16, 8>],
             ["tf32"], [], ["f32"], []>.ret;
@@ -2009,6 +2012,7 @@ class NVVM_MMA_OPS {
             [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
             ["s8","u8"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
+            fp64_wmma_ops,
             tf32_wmma_ops,
             fp_wmma_ops,
             i8_wmma_ops);
@@ -2025,9 +2029,17 @@ class NVVM_MMA_OPS {
   list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
             [GEOM<16, 16, 8>],
             ["c", "d"], ["f32"]>.ret;
+  list<WMMA_REGS> ldst_f64_ab_ops = MMA_LDST_OPS<
+            [GEOM<8, 8, 4>],
+            ["a", "b"], ["f64"]>.ret;
+  list<WMMA_REGS> ldst_f64_cd_ops = MMA_LDST_OPS<
+            [GEOM<8, 8, 4>],
+            ["c", "d"], ["f64"]>.ret;
   list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
                                              ldst_tf32_ab_ops,
-                                             ldst_tf32_cd_ops);
+                                             ldst_tf32_cd_ops, 
+                                             ldst_f64_ab_ops,
+                                             ldst_f64_cd_ops);
   // Separate A/B/C fragments (loads) from D (stores).
   list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
   list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
@@ -2334,7 +2346,7 @@ def MMAFragAttr : EnumAttr<NVVM_Dialect, MMAFrag, "mma_frag"> {
 }
 
 def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">,
-  Results<(outs LLVM_AnyStruct:$res)>,
+  Results<(outs AnyTypeOf<[LLVM_AnyStruct, F64]>:$res)>,
   Arguments<(ins LLVM_AnyPointer: $ptr, I32: $stride, I32Attr:$m,
              I32Attr:$n, I32Attr:$k, MMALayoutAttr:$layout,
              MMATypesAttr:$eltype, MMAFragAttr:$frag)> {

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 53a6f43c0bbcf..a5ffb9e77fa9d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
   } else if (type == NVVM::MMATypes::f32) {
     elementType = builder.getF32Type();
     numberElements = 8;
+  } else if (type == NVVM::MMATypes::f64) {
+    elementType = builder.getF64Type();
+    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
+      numberElements = 1;
+    else
+      numberElements = 2;
   } else if (type == NVVM::MMATypes::tf32) {
     elementType = builder.getI32Type();
     numberElements = 4;
@@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() {
     return emitOpError() << "invalid attribute combination";
   std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
       getEltype(), getFrag(), getM(), getN(), getK(), getContext());
+  // Special case for f64 fragments
+  Type f64Ty = Float64Type::get(getContext());
+  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+    if (getType() != f64Ty)
+      return emitOpError("expected destination type to be f64");
+    return success();
+  }
+  // Everything else is a struct
   Type dstType = LLVM::LLVMStructType::getLiteral(
       getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
   if (getType() != dstType)

diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 09b8f593154b5..42aa2210eae1a 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -621,3 +621,14 @@ func.func @invalid_range_equal_bounds() {
   %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32
   return
 }
+
+// -----
+
+// Test the return type check for a wmma.load of fragment a with f64
+llvm.func @nvvm_wmma_load_a_f64(%arg0: !llvm.ptr, %arg1 : i32) {
+  // expected-error @below {{'nvvm.wmma.load' op expected destination type to be f64}}
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : (!llvm.ptr) -> !llvm.struct<(f64)>
+  llvm.return
+}

diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 594ae4849e3eb..9115de65ff0e8 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -463,6 +463,43 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_wmma_load_a_f64
+llvm.func @nvvm_wmma_load_a_f64(%arg0: !llvm.ptr, %arg1 : i32) {
+  // CHECK: call double @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64.p0(ptr %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : (!llvm.ptr) -> f64
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_wmma_load_c_f64
+llvm.func @nvvm_wmma_load_c_f64(%arg0: !llvm.ptr, %arg1 : i32) {
+  // CHECK: call { double, double } @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64.p0(ptr %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.wmma.load %arg0, %arg1
+    {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<c>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : (!llvm.ptr) -> !llvm.struct<(f64, f64)>
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_wmma_mma_f64
+llvm.func @nvvm_wmma_mma_f64(%0 : f64, %1 : f64, %2 : f64, %3 : f64) {
+  // CHECK: { double, double } @llvm.nvvm.wmma.m8n8k4.mma.row.col.f64(double %{{.*}}, double %{{.*}}, double %{{.*}}, double %{{.*}})
+  %r = nvvm.wmma.mma %0, %1, %2, %3
+    {eltypeA = #nvvm.mma_type<f64>, eltypeB = #nvvm.mma_type<f64>, k = 4 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, m = 8 : i32, n = 8 : i32}
+    : (f64, f64, f64, f64)
+    -> !llvm.struct<(f64, f64)>
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_wmma_store_d_f64
+llvm.func @nvvm_wmma_store_d_f64(%arg0: !llvm.ptr, %arg1 : i32, %arg2 : f64, %arg3 : f64) {
+  // CHECK: call void @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64.p0(ptr %{{.*}}, double %{{.*}}, double %{{.*}}, i32 %{{.*}})
+  nvvm.wmma.store %arg0, %arg1, %arg2, %arg3
+    {eltype = #nvvm.mma_type<f64>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
+    : !llvm.ptr, f64, f64
+  llvm.return
+}
+
 // CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
   // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})


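For reference, a minimal end-to-end sketch of how the new pieces compose
(not part of the commit; the function name and the column layout chosen for
the `b` fragment are illustrative): load the f64 fragments, run the
multiply-accumulate, and store the two-element result.

    llvm.func @wmma_f64_sketch(%ptr : !llvm.ptr, %stride : i32) {
      // a/b fragments are scalars for f64.
      %a = nvvm.wmma.load %ptr, %stride
        {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32,
         layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
        : (!llvm.ptr) -> f64
      %b = nvvm.wmma.load %ptr, %stride
        {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<b>, k = 4 : i32,
         layout = #nvvm.mma_layout<col>, m = 8 : i32, n = 8 : i32}
        : (!llvm.ptr) -> f64
      // The accumulator is still a two-element struct.
      %c = nvvm.wmma.load %ptr, %stride
        {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<c>, k = 4 : i32,
         layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
        : (!llvm.ptr) -> !llvm.struct<(f64, f64)>
      %c0 = llvm.extractvalue %c[0] : !llvm.struct<(f64, f64)>
      %c1 = llvm.extractvalue %c[1] : !llvm.struct<(f64, f64)>
      %d = nvvm.wmma.mma %a, %b, %c0, %c1
        {eltypeA = #nvvm.mma_type<f64>, eltypeB = #nvvm.mma_type<f64>,
         k = 4 : i32, layoutA = #nvvm.mma_layout<row>,
         layoutB = #nvvm.mma_layout<col>, m = 8 : i32, n = 8 : i32}
        : (f64, f64, f64, f64) -> !llvm.struct<(f64, f64)>
      %d0 = llvm.extractvalue %d[0] : !llvm.struct<(f64, f64)>
      %d1 = llvm.extractvalue %d[1] : !llvm.struct<(f64, f64)>
      nvvm.wmma.store %ptr, %stride, %d0, %d1
        {eltype = #nvvm.mma_type<f64>, k = 4 : i32,
         layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32}
        : !llvm.ptr, f64, f64
      llvm.return
    }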