[Mlir-commits] [mlir] [mlir][linalg] Fix empty outer dim case for packing reshape op (PR #96732)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 25 22:31:51 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: None (yifeizh2)
<details>
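<summary>Changes</summary>
When the pack/unpack op has an empty `outer_dims_perm`, `bubbleUpPackOpThroughCollapseShape` and `pushDownUnPackOpThroughExpandShape` now use the identity permutation over the outer dimensions instead of the empty array.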
---
Full diff: https://github.com/llvm/llvm-project/pull/96732.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+14-3)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e51ae2264a36a..699bf56f96581 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -641,7 +641,14 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
PatternRewriter &rewriter) {
SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
- ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+ auto numOuterDims =
+ dyn_cast<RankedTensorType>(packOp.getDpsInputs()[0].getType())
+ .getShape()
+ .size();
+ SmallVector<int64_t> outerDimsPerm =
+ packOp.getOuterDimsPerm().empty()
+ ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
+ : SmallVector<int64_t>(packOp.getOuterDimsPerm());
ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
SmallVector<ReassociationIndices> reassocIndices =
@@ -885,8 +892,12 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
PatternRewriter &rewriter) {
SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
- ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
-
+ auto numOuterDims =
+ dyn_cast<RankedTensorType>(unPackOp.getType()).getShape().size();
+ SmallVector<int64_t> outerDimsPerm =
+ unPackOp.getOuterDimsPerm().empty()
+ ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
+ : SmallVector<int64_t>(unPackOp.getOuterDimsPerm());
auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
if (!expandTy)
return failure();
``````````
</details>
https://github.com/llvm/llvm-project/pull/96732
More information about the Mlir-commits mailing list