[Mlir-commits] [mlir] [mlir][amdgpu] Promote gpu.shuffle to amdgpu.dpp (PR #155158)

Jakub Kuderski llvmlistbot at llvm.org
Tue Aug 26 07:02:02 PDT 2025


================
@@ -96,13 +97,153 @@ struct PromoteShuffleToPermlanePattern
   }
 };
 
+/// Returns the current lane index as an i32 value, materialized with
+/// `rocdl.mbcnt.lo` followed by `rocdl.mbcnt.hi`, both fed an all-ones (-1)
+/// mask. Each result is tagged `noundef` together with a known value range:
+/// [0, 32) for the intermediate low count and [0, 64) for the final lane id.
+static Value getLaneId(RewriterBase &rewriter, Location loc) {
+  auto *ctx = rewriter.getContext();
+  auto i32Ty = IntegerType::get(ctx, 32);
+
+  // Builds the single-result attribute list `[{noundef, range = [0, upper)}]`
+  // shared by both mbcnt ops; only the upper bound differs between them.
+  auto makeResAttrs = [&](uint64_t upper) {
+    NamedAttribute noUndefAttr = rewriter.getNamedAttr(
+        LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
+    NamedAttribute rangeAttr = rewriter.getNamedAttr(
+        LLVM::LLVMDialect::getRangeAttrName(),
+        LLVM::ConstantRangeAttr::get(ctx, APInt::getZero(32),
+                                     APInt(32, upper)));
+    return rewriter.getArrayAttr(
+        rewriter.getDictionaryAttr({noUndefAttr, rangeAttr}));
+  };
+
+  // Operand constants: the 0 is emitted before the all-ones mask.
+  Value zeroCst = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
+  Value allOnes = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
+
+  Value lowCount = ROCDL::MbcntLoOp::create(
+      rewriter, loc, i32Ty, allOnes, zeroCst, /*arg_attrs=*/{},
+      /*res_attrs=*/makeResAttrs(32));
+  return ROCDL::MbcntHiOp::create(rewriter, loc, i32Ty, allOnes, lowCount,
+                                  /*arg_attrs=*/{},
+                                  /*res_attrs=*/makeResAttrs(64));
+}
+
+/// Try to promote `gpu.shuffle` to `amdgpu.dpp`, width must be 64
+/// and offset must be a constant integer in the set {16, 32}.
+struct PromoteShuffleToDPPPattern : public OpRewritePattern<gpu::ShuffleOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    std::optional<int64_t> width = getConstantIntValue(op.getWidth());
+    if (!width)
+      return rewriter.notifyMatchFailure(op,
+                                         "width must be a constant integer");
+    int64_t widthValue = *width;
+    if (widthValue != 4 && widthValue != 8 && widthValue != 12 &&
+        widthValue != 16 && widthValue != 32 && widthValue != 48 &&
+        widthValue != 64)
----------------
kuhar wrote:

Use `llvm::is_contained({4, 8, 12, 16, 32, 48, 64}, widthValue)` instead of the chain of `!=` comparisons.

https://github.com/llvm/llvm-project/pull/155158


More information about the Mlir-commits mailing list