[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)

Han-Chung Wang llvmlistbot at llvm.org
Fri Apr 18 16:25:58 PDT 2025


================
@@ -4957,9 +5030,19 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     // Insert a cast if needed
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
-      auto castOp =
-          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
-      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+      Operation *castOp;
+      bool hasTensorSemantics = packOp.hasPureTensorSemantics();
+      if (hasTensorSemantics) {
+        castOp =
+            rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+      } else {
+        castOp =
+            rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
+      }
+      rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
+    } else {
+      // TODO: support memref.cast if static shape inference is available.
+      return failure();
----------------
hanhanW wrote:

This does not make sense to me. This cast is only needed for the tensor version. I was wrong in the previous review.

Reason: the memref form of the op does not have a return value at all. The cast is only needed to keep types consistent for the users of the result. E.g.,

```mlir
%pack = linalg.pack %src into %dest ... -> tensor<?x?xf32>
return %pack : tensor<?x?xf32> // some uses
```

After shape inference, the rewriter changes the result type:

```mlir
%pack = linalg.pack %casted_src into %casted_dest ... -> tensor<3x4xf32>
// This is where the types mismatch, so we need a final cast for this.
return %pack : tensor<?x?xf32>
```

This is not the case for memref, because the memref form of the op has no results, so there are no uses whose types could mismatch.

So you only need to update the `if` condition, which becomes:

```cpp
if (needUpdateDestType && packOp.hasPureTensorSemantics()) {
  rewriter.setInsertionPointAfter(packOp);
  auto castOp =
      rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
}
```
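To make the reasoning concrete, here is a toy model in plain C++. It is not MLIR code: `ToyOp` and `refineAndMaybeCast` are hypothetical names, and only the result-side cast discussed above is modeled (operand casts such as `%casted_src` are ignored). An op with a result gets a cast back to the original type after refinement; an op with no results, like the memref form, needs nothing.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for an op (not the MLIR API). The tensor form of
// pack has one result type; the memref form writes into its destination
// buffer and has no results.
struct ToyOp {
  std::string name;
  std::vector<std::string> resultTypes;
};

// Models static shape inference refining the op's result type. If the op
// has a result and its type actually changed, a cast back to the original
// type is appended so that existing users (e.g. a `return`) still see the
// type they expect. Returns the resulting ops in program order.
std::vector<ToyOp> refineAndMaybeCast(ToyOp op, const std::string &refinedType) {
  std::vector<ToyOp> ops;
  if (op.resultTypes.empty()) {
    // Memref form: no SSA result, so no user can observe a type change.
    ops.push_back(op);
    return ops;
  }
  std::string originalType = op.resultTypes[0];
  op.resultTypes[0] = refinedType;
  ops.push_back(op);
  if (originalType != refinedType) {
    // Tensor form: restore the original type for the existing users.
    ops.push_back(ToyOp{"tensor.cast", {originalType}});
  }
  return ops;
}
```

With a tensor-typed op refined from `tensor<?x?xf32>` to `tensor<3x4xf32>`, the model emits the refined op followed by a `tensor.cast` back to `tensor<?x?xf32>`; with the memref form (no results) it emits only the op, which mirrors why the real pattern guards the cast with `hasPureTensorSemantics()`.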

https://github.com/llvm/llvm-project/pull/129036
