[Mlir-commits] [mlir] [mlir][gpu][spirv] Remove rotation semantics of gpu.shuffle up/down (PR #139105)

Hsiangkai Wang llvmlistbot at llvm.org
Wed Jun 4 11:51:21 PDT 2025


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/139105

>From 632c8978a02abd93687bb95090503bdea6ebcaf1 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 9 May 2025 09:33:33 +0100
Subject: [PATCH 1/6] [mlir][gpu][spirv] Add patterns for gpu.shuffle up/down

Convert

gpu.shuffle down %val, %offset, %width

to

spirv.GroupNonUniformRotateKHR <Subgroup> %val, %offset, cluster_size(%width)

Convert

gpu.shuffle up %val, %offset, %width

to

%down_offset = arith.subi %width, %offset
spirv.GroupNonUniformRotateKHR <Subgroup> %val, %down_offset, cluster_size(%width)
---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 15 ++++-
 mlir/test/Conversion/GPUToSPIRV/shuffle.mlir  | 57 +++++++++++++++++++
 2 files changed, 70 insertions(+), 2 deletions(-)

NOTE(review): this RotateKHR-based lowering is replaced in PATCH 5/6. As
posted, the UP/DOWN cases feed `shuffleOp.getWidth()` (the pre-conversion
operand) into the new ops while every other operand comes from `adaptor`;
`adaptor.getWidth()` would be the consistent choice inside a conversion
pattern.

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 3cc64b82950b5..3d53c17eb6c07 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -450,8 +450,19 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
     result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
     break;
-  default:
-    return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
+  case gpu::ShuffleMode::DOWN:
+    result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+        loc, scope, adaptor.getValue(), adaptor.getOffset(),
+        shuffleOp.getWidth());
+    break;
+  case gpu::ShuffleMode::UP: {
+    Value offsetForShuffleDown = rewriter.create<arith::SubIOp>(
+        loc, shuffleOp.getWidth(), adaptor.getOffset());
+    result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+        loc, scope, adaptor.getValue(), offsetForShuffleDown,
+        shuffleOp.getWidth());
+    break;
+  }
   }
 
   rewriter.replaceOp(shuffleOp, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
index d3d8ec0dab40f..5d7d3c81577e3 100644
--- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
@@ -72,3 +72,60 @@ gpu.module @kernels {
 }
 
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformRotateKHR], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @shuffle_down()
+  gpu.func @shuffle_down() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
+    // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
+    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+    // CHECK: %{{.+}} = spirv.Constant true
+    // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+    %result, %valid = gpu.shuffle down %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformRotateKHR], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @shuffle_up()
+  gpu.func @shuffle_up() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
+    // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
+    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+    // CHECK: %{{.+}} = spirv.Constant true
+    // CHECK: %[[DOWN_OFFSET:.+]] = spirv.Constant 12 : i32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[DOWN_OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+    %result, %valid = gpu.shuffle up %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}

>From 86c60375c30d5cba20a248615dbafa865dcd4489 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 21 May 2025 13:54:43 +0100
Subject: [PATCH 2/6] The width argument cannot exceed the subgroup limit.

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

NOTE(review): the new condition `<= subgroupSize` is inverted relative to the
comment added above it: it makes the pattern bail out for every width that
does NOT exceed the subgroup size. PATCH 3/6 changes it to `> subgroupSize`.

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 3d53c17eb6c07..2e45d782ce0c7 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -430,10 +430,12 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   unsigned subgroupSize =
       targetEnv.getAttr().getResourceLimits().getSubgroupSize();
   IntegerAttr widthAttr;
+  // The width argument specifies the number of lanes that participate in the
+  // shuffle. The width value should not exceed the subgroup limit.
   if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
-      widthAttr.getValue().getZExtValue() != subgroupSize)
+      widthAttr.getValue().getZExtValue() <= subgroupSize)
     return rewriter.notifyMatchFailure(
-        shuffleOp, "shuffle width and target subgroup size mismatch");
+        shuffleOp, "shuffle width is larger than target subgroup size");
 
   Location loc = shuffleOp.getLoc();
   Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),

>From 62e777e26cb72a255f9ef289cc689dcf935748e5 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 21 May 2025 14:07:55 +0100
Subject: [PATCH 3/6] Fix inverted shuffle width check

---
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

NOTE(review): this flips the width comparison added in PATCH 2/6 from `<=` to
`>`, so widths up to the subgroup size are accepted again. PATCH 5/6 later
reverts the check to `!= subgroupSize`.

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 2e45d782ce0c7..c8dc1f41c7146 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -433,7 +433,7 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   // The width argument specifies the number of lanes that participate in the
   // shuffle. The width value should not exceed the subgroup limit.
   if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
-      widthAttr.getValue().getZExtValue() <= subgroupSize)
+      widthAttr.getValue().getZExtValue() > subgroupSize)
     return rewriter.notifyMatchFailure(
         shuffleOp, "shuffle width is larger than target subgroup size");
 

>From 07acb22f3d5edd9954d1804793f3390c94a9603b Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 22 May 2025 09:33:22 +0100
Subject: [PATCH 4/6] remove test for gpu.shuffle width != subgroup_size limit

---
 mlir/test/Conversion/GPUToSPIRV/shuffle.mlir | 23 --------------------
 1 file changed, 23 deletions(-)

NOTE(review): this deletes the negative test for a shuffle width that does
not match the target subgroup size. PATCH 5/6 restores it when the width
check goes back to `!= subgroupSize`.

diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
index 5d7d3c81577e3..f0bf5e110915c 100644
--- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
@@ -26,29 +26,6 @@ gpu.module @kernels {
 
 // -----
 
-module attributes {
-  gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, #spirv.resource_limits<subgroup_size = 32>>
-} {
-
-gpu.module @kernels {
-  gpu.func @shuffle_xor() kernel
-    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
-    %mask = arith.constant 8 : i32
-    %width = arith.constant 16 : i32
-    %val = arith.constant 42.0 : f32
-
-    // Cannot convert due to shuffle width and target subgroup size mismatch
-    // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}}
-    %result, %valid = gpu.shuffle xor %val, %mask, %width : f32
-    gpu.return
-  }
-}
-
-}
-
-// -----
-
 module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, #spirv.resource_limits<subgroup_size = 16>>

>From f85e4a27869b282f71cefb4aae6ae9ef28559667 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 4 Jun 2025 10:34:41 +0100
Subject: [PATCH 5/6] Remove rotation semantics in gpu.shuffle up/down

Neither SPIR-V OpGroupNonUniformShuffleUp/OpGroupNonUniformShuffleDown nor
the NVVM shfl intrinsics have rotation (wrap-around) semantics: an
out-of-range source lane does not wrap back into the group.

Refer to NVVM IR spec
https://docs.nvidia.com/cuda/archive/12.2.1/nvvm-ir-spec/index.html#data-movement

"If the computed source lane index j is in range, the returned i32 value
will be the value of %a from lane j; otherwise, it will be the the value
of %a from the current thread. If the thread corresponding to lane j is
inactive, then the returned i32 value is undefined."
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  6 ++--
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 21 ++++--------
 mlir/test/Conversion/GPUToSPIRV/shuffle.mlir  | 32 ++++++++++++++++---
 3 files changed, 38 insertions(+), 21 deletions(-)

NOTE(review): after this patch `valid` is still the constant `true` for
`up`/`down`, even though the updated op description says the value from an
out-of-bounds lane is undefined. PATCH 6/6 derives `valid` from the lane ID.

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 8d83d02e27c33..6203192bc644b 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1332,7 +1332,8 @@ def GPU_ShuffleOp : GPU_Op<
     %3, %4 = gpu.shuffle down %0, %cst1, %width : f32
     ```
 
-    For lane `k`, returns the value from lane `(k + 1) % width`.
+    For lane `k`, returns the value from lane `(k + cst1)`. The resulting value
+    is undefined if the lane is out of bounds in the subgroup.
 
     `up` example:
 
@@ -1341,7 +1342,8 @@ def GPU_ShuffleOp : GPU_Op<
     %5, %6 = gpu.shuffle up %0, %cst1, %width : f32
     ```
 
-    For lane `k`, returns the value from lane `(k - 1) % width`.
+    For lane `k`, returns the value from lane `(k - cst1)`. The resulting value
+    is undefined if the lane is out of bounds in the subgroup.
 
     `idx` example:
 
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c8dc1f41c7146..9fe2c9a3019af 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -430,12 +430,10 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   unsigned subgroupSize =
       targetEnv.getAttr().getResourceLimits().getSubgroupSize();
   IntegerAttr widthAttr;
-  // The width argument specifies the number of lanes that participate in the
-  // shuffle. The width value should not exceed the subgroup limit.
   if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
-      widthAttr.getValue().getZExtValue() > subgroupSize)
+      widthAttr.getValue().getZExtValue() != subgroupSize)
     return rewriter.notifyMatchFailure(
-        shuffleOp, "shuffle width is larger than target subgroup size");
+        shuffleOp, "shuffle width and target subgroup size mismatch");
 
   Location loc = shuffleOp.getLoc();
   Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
@@ -453,19 +451,14 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
     break;
   case gpu::ShuffleMode::DOWN:
-    result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
-        loc, scope, adaptor.getValue(), adaptor.getOffset(),
-        shuffleOp.getWidth());
+    result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
+        loc, scope, adaptor.getValue(), adaptor.getOffset());
     break;
-  case gpu::ShuffleMode::UP: {
-    Value offsetForShuffleDown = rewriter.create<arith::SubIOp>(
-        loc, shuffleOp.getWidth(), adaptor.getOffset());
-    result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
-        loc, scope, adaptor.getValue(), offsetForShuffleDown,
-        shuffleOp.getWidth());
+  case gpu::ShuffleMode::UP:
+    result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
+        loc, scope, adaptor.getValue(), adaptor.getOffset());
     break;
   }
-  }
 
   rewriter.replaceOp(shuffleOp, {result, trueVal});
   return success();
diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
index f0bf5e110915c..56877a756b7ba 100644
--- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
@@ -26,6 +26,29 @@ gpu.module @kernels {
 
 // -----
 
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, #spirv.resource_limits<subgroup_size = 32>>
+} {
+
+gpu.module @kernels {
+  gpu.func @shuffle_xor() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %mask = arith.constant 8 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // Cannot convert due to shuffle width and target subgroup size mismatch
+    // expected-error @+1 {{failed to legalize operation 'gpu.shuffle'}}
+    %result, %valid = gpu.shuffle xor %val, %mask, %width : f32
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
 module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, #spirv.resource_limits<subgroup_size = 16>>
@@ -54,7 +77,7 @@ gpu.module @kernels {
 
 module attributes {
   gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformRotateKHR], []>,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
     #spirv.resource_limits<subgroup_size = 16>>
 } {
 
@@ -70,7 +93,7 @@ gpu.module @kernels {
     // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
     // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
     // CHECK: %{{.+}} = spirv.Constant true
-    // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
     %result, %valid = gpu.shuffle down %val, %offset, %width : f32
     gpu.return
   }
@@ -82,7 +105,7 @@ gpu.module @kernels {
 
 module attributes {
   gpu.container_module,
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformRotateKHR], []>,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
     #spirv.resource_limits<subgroup_size = 16>>
 } {
 
@@ -98,8 +121,7 @@ gpu.module @kernels {
     // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
     // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
     // CHECK: %{{.+}} = spirv.Constant true
-    // CHECK: %[[DOWN_OFFSET:.+]] = spirv.Constant 12 : i32
-    // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[DOWN_OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
     %result, %valid = gpu.shuffle up %val, %offset, %width : f32
     gpu.return
   }

>From 09d2ef873edc88e9db95ba336951ff2509258c9f Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 4 Jun 2025 17:36:30 +0100
Subject: [PATCH 6/6] Refine description and set 'valid' flag according to the
 resulting lane ID

---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  9 +++--
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 37 ++++++++++++++++---
 mlir/test/Conversion/GPUToSPIRV/shuffle.mlir  | 11 ++++++
 3 files changed, 48 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6203192bc644b..adf2e85f64b85 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1332,8 +1332,9 @@ def GPU_ShuffleOp : GPU_Op<
     %3, %4 = gpu.shuffle down %0, %cst1, %width : f32
     ```
 
-    For lane `k`, returns the value from lane `(k + cst1)`. The resulting value
-    is undefined if the lane is out of bounds in the subgroup.
+    For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is
+    bigger than or equal to `width`, the value is unspecified and `valid` is
+    `false`.
 
     `up` example:
 
@@ -1342,8 +1343,8 @@ def GPU_ShuffleOp : GPU_Op<
     %5, %6 = gpu.shuffle up %0, %cst1, %width : f32
     ```
 
-    For lane `k`, returns the value from lane `(k - cst1)`. The resulting value
-    is undefined if the lane is out of bounds in the subgroup.
+    For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is
+    smaller than `0`, the value is unspecified and `valid` is `false`.
 
     `idx` example:
 
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 9fe2c9a3019af..f04031c370359 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -416,6 +416,17 @@ LogicalResult GPUBarrierConversion::matchAndRewrite(
   return success();
 }
 
+/// Returns the current invocation's lane index within its subgroup, as an i32.
+/// This is deliberately `gpu.lane_id` rather than the linearized
+/// workgroup-local thread index: the two can differ as soon as a workgroup
+/// spans more than one subgroup.
+static Value getSubgroupLaneId(OpBuilder &builder, Location loc,
+                               unsigned subgroupSize) {
+  Value laneId = builder.create<gpu::LaneIdOp>(
+      loc, /*upperBound=*/builder.getIndexAttr(subgroupSize));
+  return builder.create<arith::IndexCastOp>(loc, builder.getI32Type(), laneId);
+}
+
 //===----------------------------------------------------------------------===//
 // Shuffle
 //===----------------------------------------------------------------------===//
@@ -436,8 +447,8 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
         shuffleOp, "shuffle width and target subgroup size mismatch");
 
   Location loc = shuffleOp.getLoc();
-  Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
-                                            shuffleOp.getLoc(), rewriter);
+  Value validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
+                                             shuffleOp.getLoc(), rewriter);
   auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
   Value result;
 
@@ -450,17 +461,33 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
     result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
     break;
-  case gpu::ShuffleMode::DOWN:
+  case gpu::ShuffleMode::DOWN: {
     result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
+
+    // Lane `k` reads from lane `k + offset`, so the result is only valid
+    // while that source lane stays below the shuffle width.
+    Value laneId = getSubgroupLaneId(rewriter, loc, subgroupSize);
+    Value srcLaneId =
+        rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
+    validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                              srcLaneId, adaptor.getWidth());
     break;
-  case gpu::ShuffleMode::UP:
+  }
+  case gpu::ShuffleMode::UP: {
     result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
+
+    // Lane `k` reads from lane `k - offset`, which exists iff `k >= offset`.
+    // Comparing unsigned avoids the wrap-around of computing `k - offset`.
+    Value laneId = getSubgroupLaneId(rewriter, loc, subgroupSize);
+    validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
+                                              laneId, adaptor.getOffset());
     break;
   }
+  }
 
-  rewriter.replaceOp(shuffleOp, {result, trueVal});
+  rewriter.replaceOp(shuffleOp, {result, validVal});
   return success();
 }
 
diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
index 56877a756b7ba..396421b7585af 100644
--- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
@@ -94,6 +94,12 @@ gpu.module @kernels {
     // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
     // CHECK: %{{.+}} = spirv.Constant true
     // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
+
+    // CHECK: %[[LANE_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
+    // CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[LANE_ID_ADDR]] : i32
+    // CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32
+    // CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32
+
     %result, %valid = gpu.shuffle down %val, %offset, %width : f32
     gpu.return
   }
@@ -122,6 +128,11 @@ gpu.module @kernels {
     // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
     // CHECK: %{{.+}} = spirv.Constant true
     // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
+
+    // CHECK: %[[LANE_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
+    // CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[LANE_ID_ADDR]] : i32
+    // CHECK: %[[VALID:.+]] = spirv.UGreaterThanEqual %[[LANE_ID]], %[[OFFSET]] : i32
+
     %result, %valid = gpu.shuffle up %val, %offset, %width : f32
     gpu.return
   }



More information about the Mlir-commits mailing list