[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)
Hyunsung Lee
llvmlistbot at llvm.org
Sun Mar 16 04:01:27 PDT 2025
================
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
----------------
ita9naiwa wrote:
```c++
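// Shared helper for tensor and memref collapse_shape: each reassociation
// group of consecutive source dims folds into one result dim whose size is
// the product of the group, or dynamic if any member of the group is dynamic.
template <typename ShapedTy>
ShapedTy inferCollapsedType(ShapedTy type, ArrayRef<AffineMap> reassociation) {
  static_assert(std::is_same_v<ShapedTy, RankedTensorType> ||
                    std::is_same_v<ShapedTy, MemRefType>,
                "Expected RankedTensorType or MemRefType");
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  ArrayRef<int64_t> shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    // Number of source dims collapsed into this result dim.
    unsigned dim = m.getNumResults();
    ArrayRef<int64_t> band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic)) {
      size = ShapedType::kDynamic;
    } else {
      for (int64_t extent : band)
        size *= extent;
    }
    newShape.push_back(size);
    currentDim += dim;
  }
  // Note: the result is built from shape and element type only, so a memref's
  // layout and memory space (or a tensor's encoding) are not carried over.
  return ShapedTy::get(newShape, type.getElementType());
}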
```
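To make the shape rule concrete, here is a standalone sketch that mirrors the loop above using plain `std::vector` instead of MLIR types. `collapseShape`, `kDynamicDim`, and encoding the reassociation as a list of group sizes are hypothetical stand-ins (for `ShapedType::kDynamic` and the per-map `getNumResults()`), not MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel for an unknown extent (stand-in for ShapedType::kDynamic).
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// groupSizes[i] is the number of consecutive source dims folded into result
// dim i, i.e. what m.getNumResults() yields for each reassociation map.
std::vector<int64_t> collapseShape(const std::vector<int64_t> &shape,
                                   const std::vector<unsigned> &groupSizes) {
  std::vector<int64_t> result;
  result.reserve(groupSizes.size());
  std::size_t currentDim = 0;
  for (unsigned groupSize : groupSizes) {
    assert(currentDim + groupSize <= shape.size() && "group exceeds rank");
    int64_t size = 1;
    for (unsigned d = 0; d < groupSize; ++d) {
      int64_t extent = shape[currentDim + d];
      if (extent == kDynamicDim) {
        // One unknown extent makes the whole product unknown.
        size = kDynamicDim;
        break;
      }
      size *= extent;
    }
    result.push_back(size);
    currentDim += groupSize;
  }
  assert(currentDim == shape.size() && "groups must cover every source dim");
  return result;
}
```

For example, collapsing `[2, ?, 4, 3]` with groups `{2, 2}` yields `[?, 12]`: the first group contains a dynamic extent, and the second folds 4 * 3.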
https://github.com/llvm/llvm-project/pull/129036