[Mlir-commits] [mlir] [mlir][gpu][spirv] Remove rotation semantics of gpu.shuffle up/down (PR #139105)

Hsiangkai Wang llvmlistbot at llvm.org
Mon Jun 9 02:18:17 PDT 2025


================
@@ -436,25 +436,53 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
         shuffleOp, "shuffle width and target subgroup size mismatch");
 
   Location loc = shuffleOp.getLoc();
-  Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
-                                            shuffleOp.getLoc(), rewriter);
   auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
   Value result;
+  Value validVal;
 
   switch (shuffleOp.getMode()) {
-  case gpu::ShuffleMode::XOR:
+  case gpu::ShuffleMode::XOR: {
     result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
+    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
+                                         shuffleOp.getLoc(), rewriter);
     break;
-  case gpu::ShuffleMode::IDX:
+  }
+  case gpu::ShuffleMode::IDX: {
     result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
+    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
+                                         shuffleOp.getLoc(), rewriter);
+    break;
+  }
+  case gpu::ShuffleMode::DOWN: {
+    result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
+        loc, scope, adaptor.getValue(), adaptor.getOffset());
+
+    Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+    Value resultLandId =
+        rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
+    validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                              resultLandId, adaptor.getWidth());
+    break;
+  }
+  case gpu::ShuffleMode::UP: {
+    result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
+        loc, scope, adaptor.getValue(), adaptor.getOffset());
+
+    Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+    Value resultLandId =
+        rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
----------------
Hsiangkai wrote:

Added

https://github.com/llvm/llvm-project/pull/139105


More information about the Mlir-commits mailing list