[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)

Hyunsung Lee llvmlistbot at llvm.org
Wed Apr 2 00:11:23 PDT 2025


================
@@ -4956,9 +5036,15 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     // Insert a cast if needed
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
-      auto castOp =
-          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
-      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      if (hasTensorSemantics) {
+        auto castOp =
+            rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+        rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      } else {
+        auto castOp =
+            rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
+        rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      }
----------------
ita9naiwa wrote:

I think it should be done like this:
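```cpp
  // Create the dialect-appropriate cast; both op types convert to Operation *.
  Operation *castOp;
  if (hasTensorSemantics) {
    castOp = rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
  } else {
    castOp = rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
  }
  // Redirect all uses of packOp to the cast, except the cast's own operand.
  rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
```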

Is this correct?
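The proposed shape relies on the fact that the `OpTy` returned by `rewriter.create<OpTy>(...)` implicitly converts to `Operation *`, so each branch only differs in which op it builds and the `replaceAllUsesExcept` call is written once. The toy model below is a self-contained sketch of that structure only. `Value`, `Operation`, `TensorCastOp`, `MemRefCastOp`, `ToyRewriter` and `insertCast` are hypothetical stand-ins, not the MLIR API.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an SSA value: records which op produced it.
struct Value {
  std::string producer;
};

// Hypothetical stand-in for mlir::Operation (NOT the real MLIR class).
struct Operation {
  explicit Operation(std::string opName)
      : name(std::move(opName)), result{name} {}
  virtual ~Operation() = default;
  Value getResult(unsigned /*index*/) const { return result; }
  std::string name;
  Value result;
};

// Hypothetical stand-ins for the two cast ops chosen between in the branch.
struct TensorCastOp : Operation {
  TensorCastOp() : Operation("tensor.cast") {}
};
struct MemRefCastOp : Operation {
  MemRefCastOp() : Operation("memref.cast") {}
};

// Hypothetical toy rewriter: owns the ops it creates and logs replacements.
struct ToyRewriter {
  template <typename OpT>
  OpT *create() {
    ops.push_back(std::make_unique<OpT>());
    return static_cast<OpT *>(ops.back().get());
  }
  void replaceAllUsesExcept(const std::string &from, Value to,
                            Operation *exceptedUser) {
    assert(exceptedUser && "excepted user must be non-null");
    log.push_back(from + " -> " + to.producer + " (except " +
                  exceptedUser->name + ")");
  }
  std::vector<std::unique_ptr<Operation>> ops;
  std::vector<std::string> log;
};

// Same shape as the proposal: only op construction differs per branch; the
// shared replacement is done once through the common Operation * handle.
std::string insertCast(ToyRewriter &rewriter, bool hasTensorSemantics) {
  Operation *castOp;
  if (hasTensorSemantics)
    castOp = rewriter.create<TensorCastOp>();
  else
    castOp = rewriter.create<MemRefCastOp>();
  rewriter.replaceAllUsesExcept("pack", castOp->getResult(0), castOp);
  return rewriter.log.back();
}
```

For example, on a fresh `ToyRewriter`, `insertCast(rewriter, false)` returns `"pack -> memref.cast (except memref.cast)"`, while passing `true` selects the `tensor.cast` stand-in instead.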

https://github.com/llvm/llvm-project/pull/129036

