[Mlir-commits] [mlir] c409da2 - [mlir][ROCDL] Add permlanex16 op to allow subgroup reductions on gfx10+ (#135983)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 16 14:53:20 PDT 2025


Author: Muzammil
Date: 2025-04-16T16:53:17-05:00
New Revision: c409da2223ad910d2a0fb491c19623c679e8aae3

URL: https://github.com/llvm/llvm-project/commit/c409da2223ad910d2a0fb491c19623c679e8aae3
DIFF: https://github.com/llvm/llvm-project/commit/c409da2223ad910d2a0fb491c19623c679e8aae3.diff

LOG: [mlir][ROCDL] Add permlanex16 op to allow subgroup reductions on gfx10+ (#135983)

Add `PermlaneX16Op` to the ROCDL dialect so that subgroup reductions can be
lowered to DPP ops on gfx10+ devices.
See [this PR](https://github.com/llvm/llvm-project/pull/133204).

---------

Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/test/Dialect/LLVMIR/rocdl.mlir
    mlir/test/Target/LLVMIR/rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 900155c274b4d..186a4f53f93cb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -668,6 +668,22 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
   }];
 }
 
+// PermLaneX16 intrinsic operation
+def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
+    [AllTypesMatch<["res", "old", "src0"]>, AllTypesMatch<["src1", "src2"]>], 1, 0, 0,
+    [4, 5], ["fi", "boundControl"]>,
+  Arguments<(ins LLVM_Type:$old, LLVM_Type:$src0, LLVM_Type:$src1, LLVM_Type:$src2,
+             I1Attr:$fi, I1Attr:$boundControl)> {
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = [{
+    attr-dict $old `,` $src0 `,` $src1 `,` $src2 `,` $fi `,` $boundControl `:` type($src0) `,` type($src1)
+  }];
+  let description = [{
+    Performs a `permlanex16` operation with the given operands, applying the
+    permutation specified by $fi to the provided inputs.
+  }];
+}
+
 def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
                         BuildableType<"::mlir::VectorType::get("
                           "{2},$_builder.getI16Type())">;

diff  --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 828fd58544597..cda1a9ca5f1f6 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -889,6 +889,16 @@ llvm.func @rocdl.readlane(%src : f32) -> f32 {
 
 // -----
 
+llvm.func @rocdl.permlanex16(%src : f32) -> f32 {
+  %cst0 = llvm.mlir.constant(-1 : i32) : i32
+  // CHECK-LABEL: rocdl.permlanex16
+  // CHECK: rocdl.permlanex16 %{{.*}} %{{.*}}
+  %ret = rocdl.permlanex16 %src, %src, %cst0, %cst0, 0, -1 : f32, i32
+  llvm.return %ret : f32
+}
+
+// -----
+
 // expected-error@below {{attribute attached to unexpected op}}
 func.func private @expected_llvm_func() attributes { rocdl.kernel }
 

diff  --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index e70617bfff99e..3db1f7b2b6427 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -872,6 +872,20 @@ llvm.func @rocdl.make.buffer.rsrc.p7.p1(%ptr : !llvm.ptr<1>,
   llvm.return %rsrc : !llvm.ptr<7>
 }
 
+llvm.func @rocdl.permlanex16(%src0 : f32, %src1 : i32, %src2 : vector<2 x f32>, %src3 : vector<2 x i32>) -> f32 {
+  %cst0 = llvm.mlir.constant(-1 : i32) : i32
+  // CHECK-LABEL: rocdl.permlanex16
+  // CHECK: call float @llvm.amdgcn.permlanex16.f32(float %{{.*}}, float %{{.*}}, i32 -1, i32 -1, i1 false, i1 true)
+  %ret0 = rocdl.permlanex16 %src0, %src0, %cst0, %cst0, 0, -1 : f32, i32
+  // CHECK: call i32 @llvm.amdgcn.permlanex16.i32(i32 %{{.*}}, i32 %{{.*}}, i32 -1, i32 -1, i1 false, i1 true)
+  %ret1 = rocdl.permlanex16 %src1, %src1, %cst0, %cst0, 0, -1 : i32, i32
+  // CHECK: call <2 x float> @llvm.amdgcn.permlanex16.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, i32 -1, i32 -1, i1 false, i1 true)
+  %ret2 = rocdl.permlanex16 %src2, %src2, %cst0, %cst0, 0, -1 : vector<2 x f32>, i32
+  // CHECK: call <2 x i32> @llvm.amdgcn.permlanex16.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, i32 -1, i32 -1, i1 false, i1 true)
+  %ret3 = rocdl.permlanex16 %src3, %src3, %cst0, %cst0, 0, -1 : vector<2 x i32>, i32
+  llvm.return %ret0 : f32
+}
+
 llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
   // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
   %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>


        


More information about the Mlir-commits mailing list