[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)
Hyunsung Lee
llvmlistbot at llvm.org
Mon Mar 24 22:52:49 PDT 2025
================
@@ -396,18 +397,30 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 3. Transpose packedShape to stripMinedShape.
- RankedTensorType stripMinedTensorType =
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMinedTensorType, packingMetadata.reassociations);
+ ShapedType stripMinedType;
+ if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+ stripMinedType =
+ RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+ } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+ stripMinedType =
+ MemRefType::get(stripMinedShape, memrefType.getElementType());
+ }
+ ShapedType collapsedType;
+ if (stripMinedType.isa<TensorType>()) {
+ collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
+ } else if (stripMinedType.isa<MemRefType>()) {
+ collapsedType = memref::CollapseShapeOp::computeCollapsedType(
+ cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
+ }
----------------
ita9naiwa wrote:
Sorry, I'll revert this change.
https://github.com/llvm/llvm-project/pull/129036
More information about the Mlir-commits
mailing list