[Mlir-commits] [mlir] [ROCDL] Added `rocdl.cvt.scale.pk8` ops (PR #161411)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 2 04:40:11 PDT 2025


https://github.com/ravil-mobile updated https://github.com/llvm/llvm-project/pull/161411

>From cf200fb88659e5e88ff5a07eea38692df97104f6 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Tue, 30 Sep 2025 17:35:40 +0000
Subject: [PATCH] [ROCDL] Added `rocdl.cvt.scale.pk8` ops

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 22 ++++++++++++--
 mlir/test/Dialect/LLVMIR/rocdl.mlir          | 32 ++++++++++++++++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 28 +++++++++++++++++
 3 files changed, 80 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 8b687a7f29bef..29001e26eaaaf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -985,7 +985,6 @@ class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
 //===---------------------------------------------------------------------===//
 // Scaled {fp4,bf8,fp8} to {bf16,f16,f32} conversion intrinsics
 //===---------------------------------------------------------------------===//
-
 foreach smallT = [
   ScaleArgInfo<I32, "Fp4">,
   ScaleArgInfo<ROCDL_V2I32Type, "Fp8">,
@@ -996,6 +995,8 @@ foreach smallT = [
     ScaleArgInfo<ROCDL_V8BF16Type, "Bf16">,
     ScaleArgInfo<ROCDL_V8F32Type, "F32">,
   ] in {
+
+    // Up-scaling
     def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op :
           ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name,
           [Pure], 1, [2], ["scaleSel"]>,
@@ -1010,13 +1011,30 @@ foreach smallT = [
         attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
       }];
     }
+
+    // Down-scaling (large -> small). NOTE(review): "multiplying" or "dividing"?
+    def ROCDL_CvtScaleF32Pk8 # smallT.nameForOp # largeT.nameForOp # Op :
+        ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk8." # smallT.name # "." # largeT.name,
+          [Pure], 1>,
+        Arguments<(ins largeT.type:$src, F32:$scale)> {
+      let results = (outs smallT.type:$res);
+      let summary = "Scale and convert packed "
+        # largeT.name # " to packed " # smallT.name;
+      let description = [{
+        Convert 8 packed }] # largeT.name # [{ values to packed }]
+        # smallT.name # [{, multiplying by the exponent part of `scale`
+        before doing so. This op is for gfx1250+ arch.
+      }];
+      let assemblyFormat = [{
+        attr-dict $src `,` $scale `:` type($res)
+      }];
+    }
   } // foreach largeT
 } // foreach smallTOp
 
 //===---------------------------------------------------------------------===//
 // Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics
 //===---------------------------------------------------------------------===//
-
 foreach smallT = [
   ScaleArgInfo<ROCDL_V3I32Type, "Fp6">,
   ScaleArgInfo<ROCDL_V3I32Type, "Bf6">
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 0bad151570029..6134695e9ced6 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1068,6 +1068,38 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
 
 // -----
 
+// CHECK-LABEL: rocdl.cvt.scalef32.pk8
+llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>,
+                                  %v8xf16: vector<8xf16>,
+                                  %v8xbf16: vector<8xbf16>,
+                                  %scale: f32) {
+  // One case per {fp8,bf8,fp4} result format, grouped by source element type.
+  // CHECK: rocdl.cvt.scalef32.pk8.fp8.f32
+  %0 =      rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.bf8.f32
+  %1 =      rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.fp4.f32
+  %2 =      rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32
+
+  // CHECK: rocdl.cvt.scalef32.pk8.fp8.f16
+  %3 =      rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.bf8.f16
+  %4 =      rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.fp4.f16
+  %5 =      rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32
+
+  // CHECK: rocdl.cvt.scalef32.pk8.fp8.bf16
+  %6 =      rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.bf8.bf16
+  %7 =      rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32>
+  // CHECK: rocdl.cvt.scalef32.pk8.fp4.bf16
+  %8 =      rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32
+
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: rocdl.cvt.scale.pk16
 llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index e043a8c533d05..00ee6b795c43a 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1340,6 +1340,34 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
   llvm.return
 }
 
+// CHECK-LABEL: @rocdl.cvt.scalef32.pk8
+// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>, %v8xbf16: vector<8xbf16>, %scale: f32) {
+  // Each op is expected to translate to the same-named llvm.amdgcn intrinsic call.
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+  %0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+  %1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32>
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+  %2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32
+
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+  %3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+  %4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32>
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+  %5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32
+
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+  %6 = rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+  %7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32>
+  // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+  %8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32
+
+  llvm.return
+}
+
 // CHECK-LABEL: @rocdl.cvt.scale.pk16
 // CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
 llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {



More information about the Mlir-commits mailing list