[Mlir-commits] [mlir] [MLIR][ROCDL] Add Scale Convert f8 <-> F32 Support for GFX950 (PR #125564)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 3 11:53:50 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Corbin Robeck (CRobeck)

<details>
<summary>Changes</summary>

Add ROCDL dialect support for the following GFX950 instructions:

CVT_SCALE_PK_FP8_F32
CVT_SCALE_PK_BF8_F32
CVT_SCALE_SR_FP8_F32
CVT_SCALE_SR_BF8_F32
CVT_SCALE_PK_F32_FP8
CVT_SCALE_PK_F32_BF8
CVT_SCALE_F32_FP8
CVT_SCALE_F32_BF8

---
Full diff: https://github.com/llvm/llvm-project/pull/125564.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+119) 
- (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+18) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 974712c581537a..a5fbc476317218 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -736,6 +736,95 @@ def ROCDL_CvtPkRtz:
   }];
 }
 
+//===---------------------------------------------------------------------===//
+// Scaled 8-bit float to packed 32-bit float conversion intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtScalePkF32Fp8 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert packed fp8 to packed f32";
+  let description = [{
+    Scale the two fp8 values in the `wordSel`-selected 16-bit word of `src` by
+    the exponent in `scale`, then convert them to two f32 values (`vector<2xf32>`).
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+  }];
+}
+def ROCDL_CvtScalePkF32Bf8 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert packed bf8 to packed f32";
+  let description = [{
+    Scale the two bf8 values in the `wordSel`-selected 16-bit word of `src` by
+    the exponent in `scale`, then convert them to two f32 values (`vector<2xf32>`).
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+
+
+
+//===---------------------------------------------------------------------===//
+// Scaled 32-bit float to 8-bit float conversion intrinsics
+//===---------------------------------------------------------------------===//
+def ROCDL_CvtScaleF32PkFp8F32 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert two f32 values to packed fp8";
+  let description = [{
+    Scale `srcA` and `srcB` by the exponent in `scale`, convert both to fp8, and
+    pack them into the `wordSel`-selected word of `old`, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32PkBf8F32 :
+    ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
+  let summary = "Scale and convert two f32 values to packed bf8";
+  let description = [{
+    Scale `srcA` and `srcB` by the exponent in `scale`, convert both to bf8, and
+    pack them into the `wordSel`-selected word of `old`, preserving the other word.
+  }];
+  let assemblyFormat = [{
+    attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
+  }];
+}
+
+def ROCDL_CvtScaleF32SrFp8F32 :
+    ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F32:$src, I32:$seed, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert f32 to fp8 using stochastic rounding";
+  let description = [{
+    Scale `src` by the exponent in `scale`, then convert it to fp8 with stochastic
+    rounding driven by `seed`. Store it into byte `byteSel` of `old`, preserving the others.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+
+def ROCDL_CvtScaleF32SrBf8F32 :
+    ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
+    Arguments<(ins I32:$old, F32:$src, I32:$seed, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert f32 to bf8 using stochastic rounding";
+  let description = [{
+    Scale `src` by the exponent in `scale`, then convert it to bf8 with stochastic
+    rounding driven by `seed`. Store it into byte `byteSel` of `old`, preserving the others.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
+  }];
+}
+
+
+
 //===---------------------------------------------------------------------===//
 // 8-bit float intrinsics
 //===---------------------------------------------------------------------===//
@@ -751,6 +840,20 @@ def ROCDL_CvtF32Bf8Op :
   }];
 }
 
+def ROCDL_CvtScaleF32Bf8Op :
+    ROCDL_IntrOp<"cvt.scalef32.f32.bf8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert bf8 to f32";
+  let description = [{
+    Scale the 8-bit bf8 value in byte `byteSel` (0-3) of `src` by the exponent
+    in `scale`, then convert it to f32.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+
 def ROCDL_CvtF32Fp8Op :
     ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
     Arguments<(ins I32:$srcA, I32:$byteSel)> {
@@ -763,6 +866,22 @@ def ROCDL_CvtF32Fp8Op :
   }];
 }
 
+
+def ROCDL_CvtScaleF32Fp8Op :
+    ROCDL_IntrOp<"cvt.scalef32.f32.fp8", [], [], [Pure], 1>,
+    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
+  let summary = "Scale and convert fp8 to f32";
+  let description = [{
+    Scale the 8-bit fp8 value in byte `byteSel` (0-3) of `src` by the exponent
+    in `scale`, then convert it to f32.
+    This is the fp8 counterpart of `rocdl.cvt.scalef32.f32.bf8`.
+  }];
+  let assemblyFormat = [{
+    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
+  }];
+}
+
+
 def ROCDL_CvtPkBf8F32Op :
     ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
     Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 5186e43398f01b..5f99a07f7fdac5 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -754,20 +754,37 @@ llvm.func @rocdl_8bit_floats(%source: i32, %stoch: i32) -> i32 {
 // CHECK-LABEL: @rocdl_8bit_floats
 // CHECK: rocdl.cvt.f32.bf8
 // CHECK: rocdl.cvt.f32.fp8
+// CHECK: rocdl.cvt.scalef32.f32.bf8
+// CHECK: rocdl.cvt.scalef32.f32.fp8
 // CHECK: rocdl.cvt.pk.bf8.f32
+// CHECK: rocdl.cvt.scalef32.pk.bf8.f32
 // CHECK: rocdl.cvt.pk.fp8.f32
+// CHECK: rocdl.cvt.scalef32.pk.fp8.f32
 // CHECK: rocdl.cvt.sr.bf8.f32
+// CHECK: rocdl.cvt.scalef32.sr.bf8.f32
 // CHECK: rocdl.cvt.sr.fp8.f32
+// CHECK: rocdl.cvt.scalef32.sr.fp8.f32
+// CHECK: rocdl.cvt.scalef32.pk.f32.fp8
+// CHECK: rocdl.cvt.scalef32.pk.f32.bf8
   %c0 = llvm.mlir.constant(0 : i32) : i32
   %c2 = llvm.mlir.constant(2 : i32) : i32
   %c3 = llvm.mlir.constant(3 : i32) : i32
+  %scale = llvm.mlir.constant(1.0 : f32) : f32
   %false = llvm.mlir.constant(false) : i1
   %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
   %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
+  %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %scale : f32
+  %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %scale : f32
   %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
+  %source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %v1, %v2, %scale -> %source[%false] : i32
   %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
+  %source3_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %v1, %v2, %scale -> %source2[%false] : i32
   %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
+  %source4_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v1, %stoch, %scale -> %source3[%c2] : i32
   %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
+  %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %scale -> %source4[%c3] : i32
+  %vpk_fp8_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %scale : vector<2xf32>
+  %vpk_bf8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %scale : vector<2xf32>
   llvm.return %source5 : i32
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/125564


More information about the Mlir-commits mailing list