[Mlir-commits] [mlir] [mlir][AMDGPU] Update `PermlaneSwapOp` lowering to select the correct value (PR #157586)

Gaurav Verma llvmlistbot at llvm.org
Tue Sep 9 10:17:26 PDT 2025


https://github.com/xintin updated https://github.com/llvm/llvm-project/pull/157586

>From 8c5cfa39b184842426f4531b7e061a7ca93862c3 Mon Sep 17 00:00:00 2001
From: Matthew Wong <Matthew.Wong2 at amd.com>
Date: Tue, 9 Sep 2025 00:08:56 +0000
Subject: [PATCH 1/2] Select the swapped value from both permlane swap results instead of always using result 0

Signed-off-by: Matthew Wong <Matthew.Wong2 at amd.com>
---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 203790ed95153..0078eed8b7a67 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1915,7 +1915,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
       else
         llvm_unreachable("unsupported row length");
 
-      Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+      const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+      const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
+
+      const Value isEqual =
+          rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
+
+      // Per `permlane(16|32)` semantics: if the first extracted element equals
+      // 'v', the result is the second element; otherwise it is the first.
+      Value vdstNew =
+          rewriter.create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
       permuted.emplace_back(vdstNew);
     }
 

>From b42c42ad1d197d0eb787aab2945a5725a95c61ae Mon Sep 17 00:00:00 2001
From: Matthew Wong <Matthew.Wong2 at amd.com>
Date: Tue, 9 Sep 2025 00:53:37 +0000
Subject: [PATCH 2/2] Update permlane and gpu-to-rocdl lit tests for the new result selection

Signed-off-by: Matthew Wong <Matthew.Wong2 at amd.com>
---
 .../Conversion/AMDGPUToROCDL/permlane.mlir    | 97 ++++++++++++++-----
 .../Conversion/GPUToROCDL/gpu-to-rocdl.mlir   | 14 ++-
 2 files changed, 81 insertions(+), 30 deletions(-)
 mode change 100644 => 100755 mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
 mode change 100644 => 100755 mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
old mode 100644
new mode 100755
index aae2b1d0fd90c..a92321da8f357
--- a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
@@ -4,8 +4,11 @@
 // CHECK-SAME: (%[[ARG0:.*]]: i32)
 func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
 // CHECK:  %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:  %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK:  return %[[RES]] : i32
+// CHECK:  %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
+// CHECK:  %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK:  return %[[SEL]] : i32
   %0 = amdgpu.permlane_swap %arg0 16 : i32
   return %0 : i32
 }
@@ -14,8 +17,11 @@ func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
 // CHECK-SAME: (%[[ARG0:.*]]: i32)
 func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
 // CHECK:  %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], true, true : (i32, i32) -> <(i32, i32)>
-// CHECK:  %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK:  return %[[RES]] : i32
+// CHECK:  %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
+// CHECK:  %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK:  return %[[SEL]] : i32
   %0 = amdgpu.permlane_swap %arg0 16 { fetch_inactive = true, bound_ctrl = true }  : i32
   return %0 : i32
 }
@@ -24,8 +30,11 @@ func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
 // CHECK-SAME: (%[[ARG0:.*]]: i32)
 func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
 // CHECK:  %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:  %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK:  return %[[RES]] : i32
+// CHECK:  %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
+// CHECK:  %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK:  return %[[SEL]] : i32
   %0 = amdgpu.permlane_swap %arg0 32 : i32
   return %0 : i32
 }
@@ -35,8 +44,11 @@ func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
 func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
 // CHECK:  %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
 // CHECK:  %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:  %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK:  %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32
+// CHECK:  %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32
 // CHECK:  return %[[RES_CAST]] : f32
   %0 = amdgpu.permlane_swap %arg0 16 : f32
   return %0 : f32
@@ -47,8 +59,11 @@ func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
 func.func @test_permlane32_f32(%arg0 : f32) -> f32 {
 // CHECK:  %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
 // CHECK:  %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:  %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK:  %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32
+// CHECK:  %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32
 // CHECK:  return %[[RES_CAST]] : f32
   %0 = amdgpu.permlane_swap %arg0 32 : f32
   return %0 : f32
@@ -60,8 +75,11 @@ func.func @test_permlane16_f16(%arg0 : f16) -> f16 {
 // CHECK:  %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
 // CHECK:  %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
 // CHECK:  %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:  %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK:  %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK:  %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32
+// CHECK:  %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK:  %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16
 // CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
 // CHECK:  return %[[RES_CAST]] : f16
   %0 = amdgpu.permlane_swap %arg0 16 : f16
@@ -74,8 +92,11 @@ func.func @test_permlane32_f16(%arg0 : f16) -> f16 {
 // CHECK:  %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
 // CHECK:  %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
 // CHECK:  %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:  %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK:  %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK:  %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK:  %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32
+// CHECK:  %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK:  %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16
 // CHECK:  %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
 // CHECK:  return %[[RES_CAST]] : f16
   %0 = amdgpu.permlane_swap %arg0 32 : f16
@@ -90,10 +111,16 @@ func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
 // CHECK-DAG:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK:      %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
 // CHECK:      %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
-// CHECK:      %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:      %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
-// CHECK:      %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:      %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[T0:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK:      %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32
+// CHECK:      %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32
+// CHECK:      %[[T1:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK:      %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32
+// CHECK:      %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32
 // CHECK:      %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
 // CHECK:      %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
 // CHECK:      return %[[VEC_INSERT1]] : vector<2xi32>
@@ -109,10 +136,16 @@ func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
 // CHECK-DAG:  %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
 // CHECK:      %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
 // CHECK:      %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
-// CHECK:      %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:      %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
-// CHECK:      %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:      %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[T0:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK:      %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32
+// CHECK:      %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32
+// CHECK:      %[[T1:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK:      %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32
+// CHECK:      %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32
 // CHECK:      %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
 // CHECK:      %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
 // CHECK:      return %[[VEC_INSERT1]] : vector<2xi32>
@@ -130,9 +163,15 @@ func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
 // CHECK:      %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
 // CHECK:      %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
 // CHECK:      %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:      %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32
+// CHECK:      %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32
 // CHECK:      %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:      %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32
+// CHECK:      %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32
 // CHECK:      %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
 // CHECK:      %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
 // CHECK:      %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
@@ -151,9 +190,15 @@ func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
 // CHECK:      %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
 // CHECK:      %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
 // CHECK:      %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:      %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32
+// CHECK:      %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32
 // CHECK:      %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK:      %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK:      %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32
+// CHECK:      %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32
 // CHECK:      %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
 // CHECK:      %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
 // CHECK:      %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
old mode 100644
new mode 100755
index c6261b37ef8f2..ef631ce8a12e5
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -749,13 +749,19 @@ gpu.module @test_module {
     %shfl1, %pred1 = gpu.shuffle xor %arg0, %arg1, %arg4 : f32
     // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
     // CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
-    // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
-    // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+    // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+    // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
+    // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
+    // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
+    // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
     %shfl2, %pred2 = gpu.shuffle xor %arg0, %arg2, %arg4 : f32
     // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
     // CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
-    // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
-    // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+    // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+    // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
+    // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
+    // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
+    // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
     %shfl3, %pred3 = gpu.shuffle xor  %arg0, %arg3, %arg4 : f32
     func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
   }



More information about the Mlir-commits mailing list