[Mlir-commits] [mlir] [mlir][linalg] Fix empty outer dim case for packing reshape op (PR #96732)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 25 22:31:00 PDT 2024


https://github.com/yifeizh2 created https://github.com/llvm/llvm-project/pull/96732

(No pull request description was provided.)

From 52b58dc5bb5562e0bb77614991b4a14d6d375859 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Tue, 25 Jun 2024 22:22:07 -0700
Subject: [PATCH] [mlir][linalg] Fix empty outer dim case for packing reshape
 op

---
 .../Linalg/Transforms/DataLayoutPropagation.cpp | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e51ae2264a36a..699bf56f96581 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -641,7 +641,14 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                    PatternRewriter &rewriter) {
   SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
-  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+  auto numOuterDims =
+      dyn_cast<RankedTensorType>(packOp.getDpsInputs()[0].getType())
+          .getShape()
+          .size();
+  SmallVector<int64_t> outerDimsPerm =
+      packOp.getOuterDimsPerm().empty()
+          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
+          : SmallVector<int64_t>(packOp.getOuterDimsPerm());
 
   ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
   SmallVector<ReassociationIndices> reassocIndices =
@@ -885,8 +892,12 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                    PatternRewriter &rewriter) {
   SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
   ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
-  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
-
+  auto numOuterDims =
+      dyn_cast<RankedTensorType>(unPackOp.getType()).getShape().size();
+  SmallVector<int64_t> outerDimsPerm =
+      unPackOp.getOuterDimsPerm().empty()
+          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
+          : SmallVector<int64_t>(unPackOp.getOuterDimsPerm());
   auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
   if (!expandTy)
     return failure();



More information about the Mlir-commits mailing list