[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)
Han-Chung Wang
llvmlistbot at llvm.org
Fri Apr 18 16:25:58 PDT 2025
================
@@ -4957,9 +5030,19 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
   // Insert a cast if needed
   if (needUpdateDestType) {
     rewriter.setInsertionPointAfter(packOp);
-    auto castOp =
-        rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
-    rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+    Operation *castOp;
+    bool hasTensorSemantics = packOp.hasPureTensorSemantics();
+    if (hasTensorSemantics) {
+      castOp =
+          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+    } else {
+      castOp =
+          rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
+    }
+    rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
+  } else {
+    // TODO: support memref.cast if static shape inference is available.
+    return failure();
----------------
hanhanW wrote:
This does not make sense to me. The cast is needed only for the tensor version. I was wrong in my previous review.
Reason: the memref version does not have a return value at all. The cast is needed to keep the types consistent, e.g.,
```mlir
%pack = linalg.pack %src into %dest ... -> tensor<?x?xf32>
return %pack : tensor<?x?xf32> // some uses
```
After shape inference, the rewriter changes the result type:
```mlir
%pack = linalg.pack %casted_src into %casted_dest ... -> tensor<3x4xf32>
// This is where the types mismatch, so we need a final cast for this.
return %pack : tensor<?x?xf32>
```
This is not the case for memref, because the memref version has no return value at all.
So you only need to update the `if` condition, which becomes:
```cpp
if (needUpdateDestType && packOp.hasPureTensorSemantics()) {
  rewriter.setInsertionPointAfter(packOp);
  auto castOp =
      rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
}
```
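As a rough illustration of the reasoning above, here is a toy C++ model, not MLIR code. `ToyOp`, `ToyUse`, `refineDestType`, and `resultCastsNeeded` are hypothetical names, and types are plain strings. The model shows that refining the destination type of a value-semantics op also changes its result type, which leaves existing uses mismatched. An op with buffer semantics has no result, so no use can observe the change.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Toy model, NOT the MLIR API: an op either produces a result value
// (tensor/value semantics) or only writes into its destination operand
// (memref/buffer semantics), in which case it has no result.
struct ToyOp {
  std::string destType;                   // type of the destination operand
  std::optional<std::string> resultType;  // set only for value semantics
};

// A use records the result type it was built against, like
// `return %pack : tensor<?x?xf32>`.
struct ToyUse {
  std::string expectedType;
};

// Refines the destination type (e.g. after static shape inference).
// Under value semantics the result type follows the destination type.
void refineDestType(ToyOp &op, const std::string &newType) {
  assert(!newType.empty() && "refined type must be non-empty");
  op.destType = newType;
  if (op.resultType)
    op.resultType = newType;
}

// Counts the uses that would see a type mismatch and therefore need a
// cast back to the type they expect. An op without a result has no
// result uses, so it never needs such a cast.
int resultCastsNeeded(const ToyOp &op, const std::vector<ToyUse> &uses) {
  if (!op.resultType)
    return 0;
  int count = 0;
  for (const ToyUse &use : uses)
    if (use.expectedType != *op.resultType)
      ++count;
  return count;
}
```

For example, refining a tensor-style op from `tensor<?x?xf32>` to `tensor<3x4xf32>` while one use still expects `tensor<?x?xf32>` yields one required cast. The same refinement on a memref-style op yields none.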
https://github.com/llvm/llvm-project/pull/129036