[Mlir-commits] [mlir] ddcc507 - [mlir][linalg] Expose lowerPack and lowerUnPack utils.
Hanhan Wang
llvmlistbot at llvm.org
Fri Apr 21 15:23:28 PDT 2023
Author: Hanhan Wang
Date: 2023-04-21T15:23:16-07:00
New Revision: ddcc50721a04996f52038066c693866daefa7d19
URL: https://github.com/llvm/llvm-project/commit/ddcc50721a04996f52038066c693866daefa7d19
DIFF: https://github.com/llvm/llvm-project/commit/ddcc50721a04996f52038066c693866daefa7d19.diff
LOG: [mlir][linalg] Expose lowerPack and lowerUnPack utils.
Reviewed By: qcolombet
Differential Revision: https://reviews.llvm.org/D148867
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Removed:
################################################################################
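The practical effect of this change is that the pack/unpack lowering logic, previously private to LinalgTransformOps.cpp, can now be driven from any rewrite through the linalg::lowerPack and linalg::lowerUnPack entry points declared below. A minimal sketch of doing so (the pattern names are hypothetical, not part of this commit):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

namespace {
using namespace mlir;

struct MyLowerPackPattern : OpRewritePattern<tensor::PackOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    // lowerPack replaces packOp on success and signals a match failure on
    // the NYI cases it rejects (outer_dims_perm, non-static shapes).
    return success(succeeded(linalg::lowerPack(rewriter, packOp)));
  }
};

struct MyLowerUnPackPattern : OpRewritePattern<tensor::UnPackOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    return success(succeeded(linalg::lowerUnPack(rewriter, unPackOp)));
  }
};
} // namespace

Since both utilities replace the matched op themselves, the patterns only need to forward success or failure.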
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 52982c3fd537f..73e830dfc86cb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -907,6 +907,27 @@ FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);
+struct LowerPackResult {
+ tensor::PadOp padOp;
+ tensor::ExpandShapeOp expandShapeOp;
+ linalg::TransposeOp transposeOp;
+};
+
+/// Rewrite pack as pad + reshape + transpose.
+FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
+ tensor::PackOp packOp);
+
+struct LowerUnPackOpResult {
+ tensor::EmptyOp emptyOp;
+ linalg::TransposeOp transposeOp;
+ tensor::CollapseShapeOp collapseShapeOp;
+ tensor::ExtractSliceOp extractSliceOp;
+};
+
+/// Rewrite unpack as empty + transpose + reshape + extract_slice.
+FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
+ tensor::UnPackOp unPackOp);
+
/// Struct to hold the result of a `pack` call.
struct PackResult {
SmallVector<tensor::PackOp> packOps;
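A short sketch, assuming the declarations above (the helper name is hypothetical): individual members of the result structs may be null when the rewrite degenerates, so callers should test them before chaining further transformations onto the produced ops.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

static void inspectLoweredPack(mlir::RewriterBase &rewriter,
                               mlir::tensor::PackOp packOp) {
  mlir::FailureOr<mlir::linalg::LowerPackResult> res =
      mlir::linalg::lowerPack(rewriter, packOp);
  if (mlir::failed(res))
    return;
  if (!res->transposeOp) {
    // packOp was "like a pad": it lowered to pad + insert_slice only, so
    // expandShapeOp and transposeOp are both null. Symmetrically, the
    // like-unpad case of lowerUnPack populates only extractSliceOp.
  }
}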
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dbadabd66cba0..f113d3af7b44a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -748,126 +748,6 @@ LogicalResult transform::InterchangeOp::verify() {
// LowerPackOp
//===----------------------------------------------------------------------===//
-struct LowerPackResult {
- tensor::PadOp padOp;
- tensor::ExpandShapeOp expandShapeOp;
- linalg::TransposeOp transposeOp;
-};
-
-/// Rewrite pack as pad + reshape + transpose.
-static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
- tensor::PackOp packOp) {
- // 1. Filter out NYI cases.
- if (!packOp.getOuterDimsPerm().empty())
- return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
-
- auto packedTensorType =
- packOp->getResultTypes().front().cast<RankedTensorType>();
- if (!packedTensorType.hasStaticShape()) {
- return rewriter.notifyMatchFailure(
- packOp,
- "non-static shape NYI, needs a more powerful tensor.expand_shape op");
- }
-
- Location loc = packOp->getLoc();
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(packOp);
-
- // 2. Compute the permutation vector to move the last `numPackedDims` into the
- // `innerPosDims` of a shape of rank `packedRank`.
- int64_t numPackedDims = packOp.getInnerDimsPos().size();
- int64_t packedRank = packedTensorType.getRank();
- auto lastDims = llvm::to_vector(
- llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
- PackingMetadata packingMetadata = computePackingMetadata(
- packedTensorType.getRank(), packOp.getInnerDimsPos());
- SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
- packedRank, lastDims, packingMetadata.insertPositions);
-
- // 3. Compute the stripMinedShape: this is the packed shape before any outer
- // or inner permutations have been applied.
- SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
- applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
-
- // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
- RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
- packingMetadata.reassociations);
- Value paddingValue = packOp.getPaddingValue();
- if (!paddingValue) {
- paddingValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
- }
- auto padOp =
- tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
- /*nofold=*/false, loc, rewriter);
-
- LLVM_DEBUG(
- DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
- DBGS() << "insertPositions: ");
- DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
- DBGS() << "packedShape: ");
- DBGSNL();
- llvm::interleaveComma(lastDimsToInsertPositionsPerm,
- DBGS() << "lastDimsToInsertPositionsPerm: ");
- DBGSNL(); llvm::interleaveComma(
- packingMetadata.reassociations, DBGS() << "reassociations: ",
- [&](ReassociationIndices ri) {
- llvm::interleaveComma(ri, llvm::dbgs() << "|");
- });
- DBGSNL();
- llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
- DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
-
- if (packOp.isLikePad()) {
- // This pack is just a plain pad.
- // Just insert the pad in the higher ranked tensor.
- auto emptyOp =
- rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
- // Offsets.
- SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
- // Strides.
- SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> sizes =
- getMixedDimensions(rewriter, loc, packOp.getDest());
-
- auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, /*source=*/padOp, /*dest=*/emptyOp,
- /*offsets=*/zeros, sizes,
- /*strides=*/ones);
-
- LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
-
- rewriter.replaceOp(packOp, insertSliceOp->getResults());
-
- return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
- /*transposeOp=*/nullptr};
- }
- // 5. Expand from the padded result to the stripMinedShape.
- auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
- loc,
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
- padOp.getResult(), packingMetadata.reassociations);
-
- // 6. Transpose stripMinedShape to packedShape.
- SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
- packedRank, packingMetadata.insertPositions, lastDims);
- auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, reshapeOp.getResult(), packOp.getDest(),
- insertPositionsToLastDimsPerm);
-
- LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
- DBGS() << "reshape op: " << reshapeOp; DBGSNL();
- llvm::interleaveComma(insertPositionsToLastDimsPerm,
- DBGS() << "insertPositionsToLastDimsPerm: ");
- DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
-
- // 7. Replace packOp by transposeOp.
- rewriter.replaceOp(packOp, transposeOp->getResults());
-
- return LowerPackResult{padOp, reshapeOp, transposeOp};
-}
-
DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
tensor::PackOp target, transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
@@ -889,115 +769,6 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
// LowerUnPackOp
//===----------------------------------------------------------------------===//
-struct LowerUnPackOpResult {
- tensor::EmptyOp emptyOp;
- linalg::TransposeOp transposeOp;
- tensor::CollapseShapeOp collapseShapeOp;
- tensor::ExtractSliceOp extractSliceOp;
-};
-
-/// Rewrite pack as empty + transpose + reshape + extract_slice.
-static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
- tensor::UnPackOp unPackOp) {
- // 1. Filter out NYI cases.
- if (!unPackOp.getOuterDimsPerm().empty())
- return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
-
- RankedTensorType packedTensorType = unPackOp.getSourceType();
- if (!packedTensorType.hasStaticShape()) {
- return rewriter.notifyMatchFailure(
- unPackOp,
- "non-static shape NYI, needs a more powerful tensor.expand_shape op");
- }
-
- Location loc = unPackOp->getLoc();
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(unPackOp);
-
- int64_t packedRank = packedTensorType.getRank();
-
- OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
- auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
- if (unPackOp.isLikeUnPad()) {
- // This unpack is just a plain unpad.
- // Just extract the slice from the higher ranked tensor.
- ArrayRef<int64_t> destShape = destTensorType.getShape();
- // The inner dimensions stay the same as the destination tensor, but the
- // outer ones are additional 1s.
- SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
- sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
-
- auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
- loc, destTensorType, unPackOp.getSource(),
- SmallVector<OpFoldResult>(packedRank, zero), sizes,
- SmallVector<OpFoldResult>(packedRank, one));
-
- rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
-
- return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
- /*reshapeOp=*/nullptr, extractSliceOp};
- }
- // 2. Compute the permutation vector to move the last `numPackedDims` into
- // the `innerPosDims` of a shape of rank `packedRank`.
- int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
- auto lastDims = llvm::to_vector(
- llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
- PackingMetadata packingMetadata =
- computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
- SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
- packedRank, lastDims, packingMetadata.insertPositions);
-
- // 3. Compute the stripMinedShape: this is the packed shape without outer and
- // inner permutations.
- SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
- applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
-
- // 4. Transpose packedShape to stripMinedShape.
- RankedTensorType stripMinedTensorType =
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMinedTensorType, packingMetadata.reassociations);
- auto emptyOp =
- rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
- auto transposeOp = rewriter.create<linalg::TransposeOp>(
- loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
-
- LLVM_DEBUG(
- DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
- DBGS() << "insertPositions: ");
- DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
- DBGS() << "packedShape: ");
- DBGSNL();
- llvm::interleaveComma(lastDimsToInsertPositionsPerm,
- DBGS() << "lastDimsToInsertPositionsPerm: ");
- DBGSNL(); llvm::interleaveComma(
- packingMetadata.reassociations, DBGS() << "reassociations: ",
- [&](ReassociationIndices ri) {
- llvm::interleaveComma(ri, llvm::dbgs() << "|");
- });
- DBGSNL();
- llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
- DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
-
- // 5. Collapse from the stripMinedShape to the padded result.
- auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
- loc, collapsedType, transposeOp->getResult(0),
- packingMetadata.reassociations);
-
- // 6. ExtractSlice
- int64_t destRank = destTensorType.getRank();
- auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
- loc, destTensorType, reshapeOp->getResult(0),
- SmallVector<OpFoldResult>(destRank, zero),
- tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
- SmallVector<OpFoldResult>(destRank, one));
-
- // 7. Replace unPackOp by extractSliceOp.
- rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
-
- return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
-}
-
DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
transform::TransformState &state) {
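Both of the removed bodies above (relocated to Transforms.cpp below) hinge on two permutation vectors computed in steps 2 and 6. A standalone sketch working through a concrete case (plain C++, no MLIR dependencies; the shapes are assumed, and makePermutation is an illustrative stand-in for computePermutationVector):

// Packing tensor<9x5> with inner_dims_pos = [0], inner_tiles = [4] gives a
// packed tensor<3x5x4>; the strip-mined (pre-transpose) shape is <3x4x5>.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Builds the permutation that places srcDims[i] at destDims[i] and keeps
// all remaining dims in their original relative order.
static std::vector<int64_t>
makePermutation(int64_t rank, const std::vector<int64_t> &srcDims,
                const std::vector<int64_t> &destDims) {
  std::vector<int64_t> perm(rank, -1);
  for (size_t i = 0; i < srcDims.size(); ++i)
    perm[destDims[i]] = srcDims[i];
  int64_t next = 0;
  for (int64_t &p : perm) {
    if (p != -1)
      continue;
    while (std::find(srcDims.begin(), srcDims.end(), next) != srcDims.end())
      ++next;
    p = next++;
  }
  return perm;
}

int main() {
  std::vector<int64_t> packedShape = {3, 5, 4}; // ceil(9/4) x 5 x tile(4)
  std::vector<int64_t> lastDims = {2};          // trailing tile dim
  std::vector<int64_t> insertPositions = {1};   // tile sits after its outer dim
  // Step 2: packed -> strip-mined permutation, here [0, 2, 1].
  std::vector<int64_t> toStrip = makePermutation(3, lastDims, insertPositions);
  // Step 3: apply it to the packed shape to get the strip-mined shape.
  std::vector<int64_t> stripMined(3);
  for (int i = 0; i < 3; ++i)
    stripMined[i] = packedShape[toStrip[i]];
  for (int64_t d : stripMined)
    std::cout << d << ' '; // prints: 3 4 5
  std::cout << '\n';
  // Step 6 uses the inverse mapping, also [0, 2, 1] in this example.
  std::vector<int64_t> toPacked = makePermutation(3, insertPositions, lastDims);
  for (int64_t p : toPacked)
    std::cout << p << ' '; // prints: 0 2 1
  std::cout << '\n';
}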
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 166f42637523f..4d5ef0edc9f8a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -477,6 +477,220 @@ struct PackedOperandsDimList {
} // namespace
+FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
+ tensor::PackOp packOp) {
+ // 1. Filter out NYI cases.
+ if (!packOp.getOuterDimsPerm().empty())
+ return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
+
+ auto packedTensorType =
+ packOp->getResultTypes().front().cast<RankedTensorType>();
+ if (!packedTensorType.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ packOp,
+ "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+ }
+
+ Location loc = packOp->getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(packOp);
+
+ // 2. Compute the permutation vector to move the last `numPackedDims` into the
+ // `innerPosDims` of a shape of rank `packedRank`.
+ int64_t numPackedDims = packOp.getInnerDimsPos().size();
+ int64_t packedRank = packedTensorType.getRank();
+ auto lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+ PackingMetadata packingMetadata = computePackingMetadata(
+ packedTensorType.getRank(), packOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+ packedRank, lastDims, packingMetadata.insertPositions);
+
+ // 3. Compute the stripMinedShape: this is the packed shape before any outer
+ // or inner permutations have been applied.
+ SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+ applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+ // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+ RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+ packingMetadata.reassociations);
+ Value paddingValue = packOp.getPaddingValue();
+ if (!paddingValue) {
+ paddingValue = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
+ }
+ auto padOp =
+ tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
+ /*nofold=*/false, loc, rewriter);
+
+ LLVM_DEBUG(
+ DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+ DBGS() << "insertPositions: ");
+ DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+ DBGS() << "packedShape: ");
+ DBGSNL();
+ llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+ DBGS() << "lastDimsToInsertPositionsPerm: ");
+ DBGSNL(); llvm::interleaveComma(
+ packingMetadata.reassociations, DBGS() << "reassociations: ",
+ [&](ReassociationIndices ri) {
+ llvm::interleaveComma(ri, llvm::dbgs() << "|");
+ });
+ DBGSNL();
+ llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+ DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+
+ if (packOp.isLikePad()) {
+ // This pack is just a plain pad.
+ // Just insert the pad in the higher ranked tensor.
+ auto emptyOp =
+ rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
+ // Offsets.
+ SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
+ // Strides.
+ SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> sizes =
+ getMixedDimensions(rewriter, loc, packOp.getDest());
+
+ auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+ loc, /*source=*/padOp, /*dest=*/emptyOp,
+ /*offsets=*/zeros, sizes,
+ /*strides=*/ones);
+
+ LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
+
+ rewriter.replaceOp(packOp, insertSliceOp->getResults());
+
+ return LowerPackResult{padOp, /*expandShapeOp=*/nullptr,
+ /*transposeOp=*/nullptr};
+ }
+ // 5. Expand from the padded result to the stripMinedShape.
+ auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc,
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+ padOp.getResult(), packingMetadata.reassociations);
+
+ // 6. Transpose stripMinedShape to packedShape.
+ SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
+ packedRank, packingMetadata.insertPositions, lastDims);
+ auto transposeOp = rewriter.create<linalg::TransposeOp>(
+ loc, reshapeOp.getResult(), packOp.getDest(),
+ insertPositionsToLastDimsPerm);
+
+ LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
+ DBGS() << "reshape op: " << reshapeOp; DBGSNL();
+ llvm::interleaveComma(insertPositionsToLastDimsPerm,
+ DBGS() << "insertPositionsToLastDimsPerm: ");
+ DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+
+ // 7. Replace packOp by transposeOp.
+ rewriter.replaceOp(packOp, transposeOp->getResults());
+
+ return LowerPackResult{padOp, reshapeOp, transposeOp};
+}
+
+FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
+ tensor::UnPackOp unPackOp) {
+ // 1. Filter out NYI cases.
+ if (!unPackOp.getOuterDimsPerm().empty())
+ return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
+
+ RankedTensorType packedTensorType = unPackOp.getSourceType();
+ if (!packedTensorType.hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ unPackOp,
+ "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+ }
+
+ Location loc = unPackOp->getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unPackOp);
+
+ int64_t packedRank = packedTensorType.getRank();
+
+ OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
+ auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+ if (unPackOp.isLikeUnPad()) {
+ // This unpack is just a plain unpad.
+ // Just extract the slice from the higher ranked tensor.
+ ArrayRef<int64_t> destShape = destTensorType.getShape();
+ // The inner dimensions stay the same as the destination tensor, but the
+ // outer ones are additional 1s.
+ SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
+ sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
+
+ auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, destTensorType, unPackOp.getSource(),
+ SmallVector<OpFoldResult>(packedRank, zero), sizes,
+ SmallVector<OpFoldResult>(packedRank, one));
+
+ rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+ return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
+ /*collapseShapeOp=*/nullptr, extractSliceOp};
+ }
+ // 2. Compute the permutation vector to move the last `numPackedDims` into
+ // the `innerPosDims` of a shape of rank `packedRank`.
+ int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
+ auto lastDims = llvm::to_vector(
+ llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+ PackingMetadata packingMetadata =
+ computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
+ SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+ packedRank, lastDims, packingMetadata.insertPositions);
+
+ // 3. Compute the stripMinedShape: this is the packed shape without outer and
+ // inner permutations.
+ SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+ applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+ // 4. Transpose packedShape to stripMinedShape.
+ RankedTensorType stripMinedTensorType =
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
+ RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ stripMinedTensorType, packingMetadata.reassociations);
+ auto emptyOp =
+ rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+ auto transposeOp = rewriter.create<linalg::TransposeOp>(
+ loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+
+ LLVM_DEBUG(
+ DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+ DBGS() << "insertPositions: ");
+ DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+ DBGS() << "packedShape: ");
+ DBGSNL();
+ llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+ DBGS() << "lastDimsToInsertPositionsPerm: ");
+ DBGSNL(); llvm::interleaveComma(
+ packingMetadata.reassociations, DBGS() << "reassociations: ",
+ [&](ReassociationIndices ri) {
+ llvm::interleaveComma(ri, llvm::dbgs() << "|");
+ });
+ DBGSNL();
+ llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+ DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
+
+ // 5. Collapse from the stripMinedShape to the padded result.
+ auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
+ loc, collapsedType, transposeOp->getResult(0),
+ packingMetadata.reassociations);
+
+ // 6. ExtractSlice
+ int64_t destRank = destTensorType.getRank();
+ auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, destTensorType, reshapeOp->getResult(0),
+ SmallVector<OpFoldResult>(destRank, zero),
+ tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+ SmallVector<OpFoldResult>(destRank, one));
+
+ // 7. Replace unPackOp by extractSliceOp.
+ rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
+
+ return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
+}
+
SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
SmallVector<int64_t> res;
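To make the unpack path above concrete, a standalone sketch (plain C++; the shapes are assumed for illustration) tracing the shape flow of steps 2-6 when un-packing tensor<3x5x4> with inner_dims_pos = [0] and inner_tiles = [4] back into tensor<9x5>:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> packedShape = {3, 5, 4};
  // Steps 2-3: lastDimsToInsertPositionsPerm = [0, 2, 1] moves the trailing
  // tile dim next to its outer dim; the strip-mined shape is {3, 4, 5}.
  std::vector<int64_t> perm = {0, 2, 1};
  std::vector<int64_t> stripMined(3);
  for (int i = 0; i < 3; ++i)
    stripMined[i] = packedShape[perm[i]];
  // Step 4: linalg.transpose writes the packed data into an empty tensor of
  // the strip-mined shape.
  // Step 5: tensor.collapse_shape with reassociation [[0, 1], [2]] merges
  // each outer dim with its tile: {3 * 4, 5} = {12, 5}.
  std::vector<int64_t> collapsed = {stripMined[0] * stripMined[1],
                                    stripMined[2]};
  // Step 6: tensor.extract_slice trims the padding back down to the
  // destination shape {9, 5}.
  std::cout << collapsed[0] << 'x' << collapsed[1] << '\n'; // prints: 12x5
}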