[Mlir-commits] [mlir] [MLIR][Linalg] pack, unpack to take memref inputs (PR #129036)

Hyunsung Lee llvmlistbot at llvm.org
Sun Mar 16 04:01:27 PDT 2025


================
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
                          srcType.getMemorySpace());
 }
 
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
----------------
ita9naiwa wrote:

```c++
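template <typename ShapedTy>
ShapedTy inferCollapsedType(ShapedTy type, ArrayRef<AffineMap> reassociation) {
  static_assert(std::is_same_v<ShapedTy, RankedTensorType> ||
                    std::is_same_v<ShapedTy, MemRefType>,
                "Expected RankedTensorType or MemRefType");
  assert(isReassociationValid(reassociation) && "invalid reassociation");

  ArrayRef<int64_t> shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Each reassociation map folds a contiguous band of source dims (as many as
  // the map has results) into a single dim of the collapsed type.
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    ArrayRef<int64_t> band = shape.slice(currentDim, dim);
    // A band containing a dynamic dim collapses to a dynamic dim; otherwise
    // the collapsed size is the product of the band's static sizes.
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic)) {
      size = ShapedType::kDynamic;
    } else {
      for (int64_t s : band)
        size *= s;
    }
    newShape.push_back(size);
    currentDim += dim;
  }

  // Keep the source memory space for memrefs, as computeCollapsedType does
  // (the layout is left as the identity); tensors carry no memory space.
  if constexpr (std::is_same_v<ShapedTy, MemRefType>)
    return MemRefType::get(newShape, type.getElementType(),
                           MemRefLayoutAttrInterface(), type.getMemorySpace());
  else
    return RankedTensorType::get(newShape, type.getElementType());
}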
```



https://github.com/llvm/llvm-project/pull/129036

