[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalizer for folding casts into gather_to_lds (PR #150503)

Alan Li llvmlistbot at llvm.org
Thu Jul 24 12:48:10 PDT 2025


================
@@ -546,6 +550,58 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+namespace {
+/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
+/// information or changes layout, the cast can be skipped.
+struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+
+    // Check source.
+    if (auto castOp = gatherOp.getSrc().getDefiningOp<memref::CastOp>()) {
+      auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
+      auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
+
+      if (memref::CastOp::canFoldIntoConsumerOp(castOp) && fromType && toType &&
+          fromType.getElementType() == toType.getElementType()) {
+        rewriter.modifyOpInPlace(gatherOp, [&] {
+          gatherOp.getSrcMutable().assign(castOp.getSource());
+        });
+        modified = true;
+      }
+    }
+
+    // Check target.
+    if (auto castOp = gatherOp.getDst().getDefiningOp<memref::CastOp>()) {
+      auto fromType = dyn_cast<MemRefType>(castOp.getSource().getType());
+      auto toType = dyn_cast<MemRefType>(castOp.getSource().getType());
----------------
lialan wrote:

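Same issue here: use `getDest()` for `toType`. As written, both `fromType` and `toType` are read from `castOp.getSource()`, so they are always identical and the element-type comparison can never fail.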

https://github.com/llvm/llvm-project/pull/150503


More information about the Mlir-commits mailing list