[Mlir-commits] [mlir] [mlir][amdgpu] Promote gpu.shuffle to amdgpu.dpp (PR #155158)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Aug 26 07:02:02 PDT 2025
================
@@ -96,13 +97,153 @@ struct PromoteShuffleToPermlanePattern
}
};
+static Value getLaneId(RewriterBase &rewriter, Location loc) {
+ auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+ Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
+ Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
+ NamedAttribute noundef = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
+ NamedAttribute lowRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 32)));
+ NamedAttribute highRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 64)));
+ Value mbcntLo = ROCDL::MbcntLoOp::create(
+ rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
+ /*res_attrs=*/
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
+ Value laneId = ROCDL::MbcntHiOp::create(
+ rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
+ return laneId;
+}
+
+/// Try to promote `gpu.shuffle` to `amdgpu.dpp`, width must be 64
+/// and offset must be a constant integer in the set {16, 32}.
+struct PromoteShuffleToDPPPattern : public OpRewritePattern<gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ std::optional<int64_t> width = getConstantIntValue(op.getWidth());
+ if (!width)
+ return rewriter.notifyMatchFailure(op,
+ "width must be a constant integer");
+ int64_t widthValue = *width;
+ if (widthValue != 4 && widthValue != 8 && widthValue != 12 &&
+ widthValue != 16 && widthValue != 32 && widthValue != 48 &&
+ widthValue != 64)
----------------
kuhar wrote:
Use `llvm::is_contained({4, 8, 12, 16, 32, 48, 64}, widthValue)` here instead of the chain of `!=` comparisons.
https://github.com/llvm/llvm-project/pull/155158
More information about the Mlir-commits mailing list