[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalizer for folding casts into gather_to_lds (PR #150503)
Alan Li
llvmlistbot at llvm.org
Thu Jul 24 12:48:10 PDT 2025
================
@@ -546,6 +550,58 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
+namespace {
+/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
+/// information or changes layout, the cast can be skipped.
+struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
+ using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+ PatternRewriter &rewriter) const override {
+ bool modified = false;
+
+ // Check source.
+ if (auto castOp = gatherOp.getSrc().getDefiningOp<memref::CastOp>()) {
+ auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
+
+ if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType &&
+ fromType.getElementType() == toType.getElementType()) {
+ rewriter.modifyOpInPlace(gatherOp, [&] {
+ gatherOp.getSrcMutable().assign(castOp.getSource());
+ });
+ modified = true;
+ }
+ }
+
+ // Check target.
+ if (auto castOp = gatherOp.getDst().getDefiningOp<memref::CastOp>()) {
+ auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
----------------
lialan wrote:
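Same issue as in the source check above: `toType` should be computed from `castOp.getDest().getType()`, not `castOp.getSource().getType()`. Otherwise `fromType` and `toType` are always identical and the element-type comparison is meaningless.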
https://github.com/llvm/llvm-project/pull/150503
More information about the Mlir-commits
mailing list