[Mlir-commits] [mlir] 43a50de - [MLIR][ROCDL] Add GFX940 SMFMAC (2:4 sparsity) instructions to the ROCDL dialect (#124435)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 27 02:58:29 PST 2025


Author: Samuel Ginzburg
Date: 2025-01-27T11:58:26+01:00
New Revision: 43a50deb63453cd3c800f097514d500536f9d436

URL: https://github.com/llvm/llvm-project/commit/43a50deb63453cd3c800f097514d500536f9d436
DIFF: https://github.com/llvm/llvm-project/commit/43a50deb63453cd3c800f097514d500536f9d436.diff

LOG: [MLIR][ROCDL] Add GFX940 SMFMAC (2:4 sparsity) instructions to the ROCDL dialect (#124435)

# Overview

This PR adds 2:4 structured sparsity (sparse A, dense B) matrix multiply
instructions to ROCDL.

# Testing

I've added tests to mlir/test/Dialect/LLVMIR/rocdl.mlir (dialect round-trip) and mlir/test/Target/LLVMIR/rocdl.mlir (translation to LLVM IR).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/test/Dialect/LLVMIR/rocdl.mlir
    mlir/test/Target/LLVMIR/rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 95fbe7ed66a434..974712c581537a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -408,6 +408,24 @@ def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8">;
 def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16">;
 def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", [0,1]>;
 def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", [0,1]>;
+
+// 2:4 Sparsity ops (GFX940)
+def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.f16">;
+def ROCDL_smfmac_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.f16">;
+def ROCDL_smfmac_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.bf16">;
+def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.bf16">;
+def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x64.i8">;
+def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x32.i8">;
+def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">;
+def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">;
+def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.bf8">;
+def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">;
+def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">;
+def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
+def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
+def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
+
+
 //===---------------------------------------------------------------------===//
 // WMMA intrinsics
 class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,

diff  --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 712f8c2a1caf66..5186e43398f01b 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -258,6 +258,93 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
   llvm.return
 }
 
+
+llvm.func @rocdl.smfmac(%arg0 : i32,
+                   %arg1 : vector<4 x f16>,
+                   %arg2 : vector<8 x f16>,
+                   %arg3 : vector<4 x f32>,
+                   %arg4 : vector<16 x f32>,
+                   %arg5 : vector<4 x i16>,
+                   %arg6 : vector<8 x i16>,
+                   %arg7 : vector<2xi32>,
+                   %arg8 : vector<4xi32>,
+                   %arg9 : vector<16xi32>) -> vector<4 x f32> {
+  %csti32 = llvm.mlir.constant(42 : i32) : i32
+
+  // CHECK-LABEL: rocdl.smfmac
+  // CHECK: rocdl.smfmac.f32.16x16x32.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xf16>, vector<8xf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x16.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xf16>, vector<8xf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x32.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi16>, vector<8xi16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x16.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi16>, vector<8xi16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.i32.16x16x64.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xi32>,
+                                 i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: rocdl.smfmac.i32.32x32x32.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xi32>,
+                                 i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  llvm.return %r0 : vector<4 x f32>
+}
+
 llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
                    %arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
                    %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) {

diff  --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index b74edb62106837..326bd3ae6b6f84 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -398,6 +398,95 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
   llvm.return %r0 : vector<32 x f32>
 }
 
+llvm.func @rocdl.smfmac(%arg0 : i32,
+                   %arg1 : vector<4 x f16>,
+                   %arg2 : vector<8 x f16>,
+                   %arg3 : vector<4 x f32>,
+                   %arg4 : vector<16 x f32>,
+                   %arg5 : vector<4 x i16>,
+                   %arg6 : vector<8 x i16>,
+                   %arg7 : vector<2xi32>,
+                   %arg8 : vector<4xi32>,
+                   %arg9 : vector<16xi32>) -> vector<4 x f32> {
+  %csti32 = llvm.mlir.constant(42 : i32) : i32
+
+  // CHECK-LABEL: rocdl.smfmac
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xf16>, vector<8xf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xf16>, vector<8xf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi16>, vector<8xi16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi16>, vector<8xi16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+  %r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xi32>,
+                                 i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+  %r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xi32>,
+                                 i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<2xi32>, vector<4xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  llvm.return %r0 : vector<4 x f32>
+}
+
+
 llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
                    %arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
                    %arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<16 x f32> {


        


More information about the Mlir-commits mailing list