[Mlir-commits] [mlir] ddcc507 - [mlir][linalg] Expose lowerPack and lowerUnPack utils.

Hanhan Wang llvmlistbot at llvm.org
Fri Apr 21 15:23:28 PDT 2023


Author: Hanhan Wang
Date: 2023-04-21T15:23:16-07:00
New Revision: ddcc50721a04996f52038066c693866daefa7d19

URL: https://github.com/llvm/llvm-project/commit/ddcc50721a04996f52038066c693866daefa7d19
DIFF: https://github.com/llvm/llvm-project/commit/ddcc50721a04996f52038066c693866daefa7d19.diff

LOG: [mlir][linalg] Expose lowerPack and lowerUnPack utils.

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D148867

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 52982c3fd537f..73e830dfc86cb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -907,6 +907,27 @@ FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
     GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter);
 
+struct LowerPackResult {
+  tensor::PadOp padOp;
+  tensor::ExpandShapeOp expandShapeOp;
+  linalg::TransposeOp transposeOp;
+};
+
+/// Rewrite pack as pad + reshape + transpose.
+FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
+                                     tensor::PackOp packOp);
+
+struct LowerUnPackOpResult {
+  tensor::EmptyOp emptyOp;
+  linalg::TransposeOp transposeOp;
+  tensor::CollapseShapeOp collapseShapeOp;
+  tensor::ExtractSliceOp extractSliceOp;
+};
+
+/// Rewrite unpack as empty + transpose + reshape + extract_slice.
+FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
+                                           tensor::UnPackOp unPackOp);
+
 /// Struct to hold the result of a `pack` call.
 struct PackResult {
   SmallVector<tensor::PackOp> packOps;

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dbadabd66cba0..f113d3af7b44a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -748,126 +748,6 @@ LogicalResult transform::InterchangeOp::verify() {
 // LowerPackOp
 //===----------------------------------------------------------------------===//
 
-struct LowerPackResult {
-  tensor::PadOp padOp;
-  tensor::ExpandShapeOp expandShapeOp;
-  linalg::TransposeOp transposeOp;
-};
-
-/// Rewrite pack as pad + reshape + transpose.
-static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
-                                            tensor::PackOp packOp) {
-  // 1. Filter out NYI cases.
-  if (!packOp.getOuterDimsPerm().empty())
-    return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
-
-  auto packedTensorType =
-      packOp->getResultTypes().front().cast<RankedTensorType>();
-  if (!packedTensorType.hasStaticShape()) {
-    return rewriter.notifyMatchFailure(
-        packOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
-
-  Location loc = packOp->getLoc();
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
-  // 2. Compute the permutation vector to move the last `numPackedDims` into the
-  // `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = packOp.getInnerDimsPos().size();
-  int64_t packedRank = packedTensorType.getRank();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata = computePackingMetadata(
-      packedTensorType.getRank(), packOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  // 3. Compute the stripMinedShape: this is the packed shape before any outer
-  // or inner permutations have been applied.
-  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
-
-  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
-  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      packingMetadata.reassociations);
-  Value paddingValue = packOp.getPaddingValue();
-  if (!paddingValue) {
-    paddingValue = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
-  }
-  auto padOp =
-      tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
-                              /*nofold=*/false, loc, rewriter);
-
-  LLVM_DEBUG(
-      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
-                                                DBGS() << "insertPositions: ");
-      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
-                                      DBGS() << "packedShape: ");
-      DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
-      DBGSNL(); llvm::interleaveComma(
-          packingMetadata.reassociations, DBGS() << "reassociations: ",
-          [&](ReassociationIndices ri) {
-            llvm::interleaveComma(ri, llvm::dbgs() << "|");
-          });
-      DBGSNL();
-      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
-      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
-
-  if (packOp.isLikePad()) {
-    // This pack is just a plain pad.
-    // Just insert the pad in the higher ranked tensor.
-    auto emptyOp =
-        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
-    // Offsets.
-    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
-    // Strides.
-    SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes =
-        getMixedDimensions(rewriter, loc, packOp.getDest());
-
-    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, /*source=*/padOp, /*dest=*/emptyOp,
-        /*offsets=*/zeros, sizes,
-        /*strides=*/ones);
-
-    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
-
-    rewriter.replaceOp(packOp, insertSliceOp->getResults());
-
-    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
-                           /*transposeOp=*/nullptr};
-  }
-  // 5. Expand from the padded result to the stripMinedShape.
-  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      padOp.getResult(), packingMetadata.reassociations);
-
-  // 6. Transpose stripMinedShape to packedShape.
-  SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
-      packedRank, packingMetadata.insertPositions, lastDims);
-  auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, reshapeOp.getResult(), packOp.getDest(),
-      insertPositionsToLastDimsPerm);
-
-  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
-             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
-             llvm::interleaveComma(insertPositionsToLastDimsPerm,
-                                   DBGS() << "insertPositionsToLastDimsPerm: ");
-             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
-
-  // 7. Replace packOp by transposeOp.
-  rewriter.replaceOp(packOp, transposeOp->getResults());
-
-  return LowerPackResult{padOp, reshapeOp, transposeOp};
-}
-
 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     tensor::PackOp target, transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
@@ -889,115 +769,6 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
 // LowerUnPackOp
 //===----------------------------------------------------------------------===//
 
-struct LowerUnPackOpResult {
-  tensor::EmptyOp emptyOp;
-  linalg::TransposeOp transposeOp;
-  tensor::CollapseShapeOp collapseShapeOp;
-  tensor::ExtractSliceOp extractSliceOp;
-};
-
-/// Rewrite pack as empty + transpose + reshape + extract_slice.
-static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
-                                                  tensor::UnPackOp unPackOp) {
-  // 1. Filter out NYI cases.
-  if (!unPackOp.getOuterDimsPerm().empty())
-    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
-
-  RankedTensorType packedTensorType = unPackOp.getSourceType();
-  if (!packedTensorType.hasStaticShape()) {
-    return rewriter.notifyMatchFailure(
-        unPackOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
-
-  Location loc = unPackOp->getLoc();
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(unPackOp);
-
-  int64_t packedRank = packedTensorType.getRank();
-
-  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
-  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
-  if (unPackOp.isLikeUnPad()) {
-    // This unpack is just a plain unpad.
-    // Just extract the slice from the higher ranked tensor.
-    ArrayRef<int64_t> destShape = destTensorType.getShape();
-    // The inner dimensions stay the same as the destination tensor, but the
-    // outer ones are additional 1s.
-    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
-    sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
-
-    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-        loc, destTensorType, unPackOp.getSource(),
-        SmallVector<OpFoldResult>(packedRank, zero), sizes,
-        SmallVector<OpFoldResult>(packedRank, one));
-
-    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
-
-    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
-                               /*reshapeOp=*/nullptr, extractSliceOp};
-  }
-  // 2. Compute the permutation vector to move the last `numPackedDims` into
-  // the `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata =
-      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  // 3. Compute the stripMinedShape: this is the packed shape without outer and
-  // inner permutations.
-  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
-
-  // 4. Transpose packedShape to stripMinedShape.
-  RankedTensorType stripMinedTensorType =
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMinedTensorType, packingMetadata.reassociations);
-  auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
-  auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
-
-  LLVM_DEBUG(
-      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
-                                                DBGS() << "insertPositions: ");
-      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
-                                      DBGS() << "packedShape: ");
-      DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
-      DBGSNL(); llvm::interleaveComma(
-          packingMetadata.reassociations, DBGS() << "reassociations: ",
-          [&](ReassociationIndices ri) {
-            llvm::interleaveComma(ri, llvm::dbgs() << "|");
-          });
-      DBGSNL();
-      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
-      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
-
-  // 5. Collapse from the stripMinedShape to the padded result.
-  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
-      loc, collapsedType, transposeOp->getResult(0),
-      packingMetadata.reassociations);
-
-  // 6. ExtractSlice
-  int64_t destRank = destTensorType.getRank();
-  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
-      loc, destTensorType, reshapeOp->getResult(0),
-      SmallVector<OpFoldResult>(destRank, zero),
-      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
-      SmallVector<OpFoldResult>(destRank, one));
-
-  // 7. Replace unPackOp by extractSliceOp.
-  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
-
-  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
-}
-
 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 166f42637523f..4d5ef0edc9f8a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -477,6 +477,220 @@ struct PackedOperandsDimList {
 
 } // namespace
 
+FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
+                                             tensor::PackOp packOp) {
+  // 1. Filter out NYI cases.
+  if (!packOp.getOuterDimsPerm().empty())
+    return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
+
+  auto packedTensorType =
+      packOp->getResultTypes().front().cast<RankedTensorType>();
+  if (!packedTensorType.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(
+        packOp,
+        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+  }
+
+  Location loc = packOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  // 2. Compute the permutation vector to move the last `numPackedDims` into the
+  // `innerPosDims` of a shape of rank `packedRank`.
+  int64_t numPackedDims = packOp.getInnerDimsPos().size();
+  int64_t packedRank = packedTensorType.getRank();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata = computePackingMetadata(
+      packedTensorType.getRank(), packOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 3. Compute the stripMinedShape: this is the packed shape before any outer
+  // or inner permutations have been applied.
+  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+      packingMetadata.reassociations);
+  Value paddingValue = packOp.getPaddingValue();
+  if (!paddingValue) {
+    paddingValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
+  }
+  auto padOp =
+      tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
+                              /*nofold=*/false, loc, rewriter);
+
+  LLVM_DEBUG(
+      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+                                                DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                      DBGS() << "packedShape: ");
+      DBGSNL();
+      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      DBGSNL(); llvm::interleaveComma(
+          packingMetadata.reassociations, DBGS() << "reassociations: ",
+          [&](ReassociationIndices ri) {
+            llvm::interleaveComma(ri, llvm::dbgs() << "|");
+          });
+      DBGSNL();
+      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+
+  if (packOp.isLikePad()) {
+    // This pack is just a plain pad.
+    // Just insert the pad in the higher ranked tensor.
+    auto emptyOp =
+        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
+    // Offsets.
+    SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+    // Strides.
+    SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes =
+        getMixedDimensions(rewriter, loc, packOp.getDest());
+
+    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, /*source=*/padOp, /*dest=*/emptyOp,
+        /*offsets=*/zeros, sizes,
+        /*strides=*/ones);
+
+    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+
+    rewriter.replaceOp(packOp, insertSliceOp->getResults());
+
+    return LowerPackResult{padOp, /*expandShapeOp=*/nullptr,
+                           /*transposeOp=*/nullptr};
+  }
+  // 5. Expand from the padded result to the stripMinedShape.
+  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+      loc,
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+      padOp.getResult(), packingMetadata.reassociations);
+
+  // 6. Transpose stripMinedShape to packedShape.
+  SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
+      packedRank, packingMetadata.insertPositions, lastDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, reshapeOp.getResult(), packOp.getDest(),
+      insertPositionsToLastDimsPerm);
+
+  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
+             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
+             llvm::interleaveComma(insertPositionsToLastDimsPerm,
+                                   DBGS() << "insertPositionsToLastDimsPerm: ");
+             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+
+  // 7. Replace packOp by transposeOp.
+  rewriter.replaceOp(packOp, transposeOp->getResults());
+
+  return LowerPackResult{padOp, reshapeOp, transposeOp};
+}
+
+FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
+                                                   tensor::UnPackOp unPackOp) {
+  // 1. Filter out NYI cases.
+  if (!unPackOp.getOuterDimsPerm().empty())
+    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
+
+  RankedTensorType packedTensorType = unPackOp.getSourceType();
+  if (!packedTensorType.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(
+        unPackOp,
+        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+  }
+
+  Location loc = unPackOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unPackOp);
+
+  int64_t packedRank = packedTensorType.getRank();
+
+  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
+  auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+  if (unPackOp.isLikeUnPad()) {
+    // This unpack is just a plain unpad.
+    // Just extract the slice from the higher ranked tensor.
+    ArrayRef<int64_t> destShape = destTensorType.getShape();
+    // The inner dimensions stay the same as the destination tensor, but the
+    // outer ones are additional 1s.
+    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
+    sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
+
+    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, destTensorType, unPackOp.getSource(),
+        SmallVector<OpFoldResult>(packedRank, zero), sizes,
+        SmallVector<OpFoldResult>(packedRank, one));
+
+    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
+                               /*collapseShapeOp=*/nullptr, extractSliceOp};
+  }
+  // 2. Compute the permutation vector to move the last `numPackedDims` into
+  // the `innerPosDims` of a shape of rank `packedRank`.
+  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata =
+      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 3. Compute the stripMinedShape: this is the packed shape without outer and
+  // inner permutations.
+  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+  // 4. Transpose packedShape to stripMinedShape.
+  RankedTensorType stripMinedTensorType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMinedTensorType, packingMetadata.reassociations);
+  auto emptyOp =
+      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+
+  LLVM_DEBUG(
+      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+                                                DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                      DBGS() << "packedShape: ");
+      DBGSNL();
+      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      DBGSNL(); llvm::interleaveComma(
+          packingMetadata.reassociations, DBGS() << "reassociations: ",
+          [&](ReassociationIndices ri) {
+            llvm::interleaveComma(ri, llvm::dbgs() << "|");
+          });
+      DBGSNL();
+      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+
+  // 5. Collapse from the stripMinedShape to the padded result.
+  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
+      loc, collapsedType, transposeOp->getResult(0),
+      packingMetadata.reassociations);
+
+  // 6. Extract the destination-shaped slice from the collapsed result.
+  int64_t destRank = destTensorType.getRank();
+  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, destTensorType, reshapeOp->getResult(0),
+      SmallVector<OpFoldResult>(destRank, zero),
+      tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+      SmallVector<OpFoldResult>(destRank, one));
+
+  // 7. Replace unPackOp by extractSliceOp.
+  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
+}
+
 SmallVector<int64_t>
 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
   SmallVector<int64_t> res;


        


More information about the Mlir-commits mailing list