[Mlir-commits] [mlir] 4ca52c6 - [mlir][Linalg] Add a transform.structured.lower_pack op

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jan 31 10:06:16 PST 2023


Author: Nicolas Vasilache
Date: 2023-01-31T10:06:08-08:00
New Revision: 4ca52c6e7ed938ac20fa9845af9dbb7d46136226

URL: https://github.com/llvm/llvm-project/commit/4ca52c6e7ed938ac20fa9845af9dbb7d46136226
DIFF: https://github.com/llvm/llvm-project/commit/4ca52c6e7ed938ac20fa9845af9dbb7d46136226.diff

LOG: [mlir][Linalg] Add a transform.structured.lower_pack op

This revision introduces `transform.structured.lower_pack`, which rewrites
a `tensor.pack` into `tensor.pad` + `tensor.expand_shape` + `linalg.transpose`.

The implementation is currently limited to pack ops with a static result shape and an empty `outer_dims_perm`.

Differential Revision: https://reviews.llvm.org/D142881

Added: 
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index ae7c5cb5a91f6..5bd5c034e1c5b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -17,11 +17,16 @@
 namespace mlir {
 class TilingInterface;
 class RewriterBase;
+
 namespace linalg {
 class GenericOp;
 class LinalgOp;
 } // namespace linalg
 
+namespace tensor {
+class PackOp;
+} // namespace tensor
+
 namespace transform {
 class TransformHandleTypeInterface;
 // Types needed for builders.

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 043dc778b00fb..2aee320b81936 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -215,6 +215,43 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LowerPackOp
+//===----------------------------------------------------------------------===//
+def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
+                         FunctionalStyleTransformOpTrait,
+                         MemoryEffectsOpInterface,
+                         TransformEachOpTrait,
+                         TransformOpInterface]> {
+  let description = [{
+    Rewrite a tensor.pack into tensor.pad + tensor.expand_shape + linalg.transpose.
+
+    #### Return modes
+
+    The `target` handle is typed to refer to `tensor.pack` ops only.
+    This operation produces a silenceableFailure if the lowering fails for any
+    reason (e.g. a non-empty outer_dims_perm or a non-static packed shape).
+    If all the operations referred to by the `target` are rewritten, the
+    transform succeeds.
+    Return handles to the newly produced pad, expand_shape and transpose ops.
+  }];
+
+  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
+  let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
+                      Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
+                      Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::tensor::PackOp target,
+        ::mlir::transform::ApplyToEachResultList &transformResults,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MatchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6cb73bacb0e06..7fe9dfe2fbf5f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1142,12 +1142,17 @@ FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
     GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter);
 
+/// Struct to hold the result of a `pack` call.
+struct PackResult {
+  SmallVector<tensor::PackOp> packOps;
+  linalg::LinalgOp packedLinalgOp;
+  SmallVector<tensor::UnPackOp> unPackOps;
+};
 /// Implement packing of a single LinalgOp by `packedSizes`.
 /// There must be one packedSizes entry per `linalgOp` iterator.
 /// Return the packed Linalg op on success, failure otherwise.
-FailureOr<linalg::LinalgOp> pack(RewriterBase &rewriter,
-                                 linalg::LinalgOp linalgOp,
-                                 ArrayRef<OpFoldResult> packedSizes);
+FailureOr<PackResult> pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
+                           ArrayRef<OpFoldResult> packedSizes);
 
 /// Struct to hold the result of a `packTranspose` call.
 struct PackTransposeResult {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b055e2ccba9af..cc8bbd570ef66 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -115,7 +115,7 @@ FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold);
 
-/// Returns a GenericOp that tansposes `inputTensor` into `outputTensor` using
+/// Returns a GenericOp that transposes `inputTensor` into `outputTensor` using
 /// `transposeVector` to permute the `inputTensor` dimensions.
 GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                           Value outputTensor,

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 209e70f53123b..9328b1f4b2b83 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1113,7 +1113,13 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
     }]>
   ];
 
-  let extraClassDeclaration = commonExtraClassDeclaration;
+  let extraClassDeclaration = commonExtraClassDeclaration # [{  
+    static RankedTensorType
+    inferCollapsedType(RankedTensorType type, ArrayRef<AffineMap> reassociation);
+    static RankedTensorType
+    inferCollapsedType(RankedTensorType type, 
+                       SmallVector<ReassociationIndices> reassociation);
+  }];
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2060dd7649781..9952bb1cc5ac8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -14,10 +14,12 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -29,12 +31,13 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/TilingInterface.h"
-#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Debug.h"
 
@@ -131,6 +134,81 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Return a permutation vector of size permSize that would result in moving
+/// positions into desiredPositions.
+///
+/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
+/// would result in a {4, 2, 0, 1, 3} permutation vector.
+static SmallVector<int64_t>
+computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
+                         ArrayRef<int64_t> desiredPositions) {
+  SmallVector<int64_t> res(permSize, -1);
+  DenseSet<int64_t> seen;
+  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
+    res[desiredPos] = pos;
+    seen.insert(pos);
+  }
+  int64_t nextPos = 0;
+  for (int64_t &entry : res) {
+    if (entry != -1)
+      continue;
+    while (seen.contains(nextPos))
+      ++nextPos;
+    entry = nextPos;
+    ++nextPos;
+  }
+  return res;
+}
+
+struct PackingMetadata {
+  SmallVector<int64_t> insertPositions;
+  SmallVector<ReassociationIndices> reassociations;
+};
+/// Given a vector of `innerDimPos` indices representing desired packing
+/// insertion points into a target vector (i.e. pack/unpack.inner_dims_pos),
+/// compute the final positions in the target shape as well as the reshape
+/// reassociations for a shape of rank `packedRank`.
+// Note: This should not be called with a large `innerDimPos` array (or the
+// implementation needs an O(N log N) sort instead of the O(N^2) counts).
+static PackingMetadata computePackingMetadata(int64_t packedRank,
+                                              ArrayRef<int64_t> innerDimPos) {
+  PackingMetadata res;
+  res.insertPositions.reserve(innerDimPos.size());
+  // The pack insert position is the position + the number of inner dim
+  // positions smaller than it (i.e. inserted before it) + offset.
+  // The offset controls whether the packing dimension is the first or last.
+  //
+  // Example
+  // =======
+  // Consider packing from a hypothetical ABCD layout to ABCDba whose
+  // pack.inner_dims_pos is [1, 0]. The first step consists of undoing the
+  // permutation and producing AaBbCD. This is achieved purely by computing the
+  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
+  // possibility is to produce insert positions [2, 0]; this would result in an
+  // aAbBCD layout (i.e. offset 0). The other possibility is to produce insert
+  // positions [3, 1]; this would result in an AaBbCD layout (i.e. offset 1).
+  // The latter is what we expect from packing.
+  int64_t offset = 1;
+  for (int64_t pos : innerDimPos) {
+    int64_t numInsertedBefore = llvm::count_if(
+        innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
+    res.insertPositions.push_back(pos + numInsertedBefore + offset);
+  }
+
+  DenseSet<int64_t> posSet(res.insertPositions.begin(),
+                           res.insertPositions.end());
+  res.reassociations.reserve(packedRank);
+  for (int64_t i = 1; i <= packedRank; ++i) {
+    if (!posSet.contains(i)) {
+      res.reassociations.push_back(ReassociationIndices{i - 1});
+      continue;
+    }
+    res.reassociations.push_back(ReassociationIndices{i - 1, i});
+    ++i;
+  }
+  return res;
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
@@ -323,7 +401,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
                                              Diagnostic &diag,
                                              Operation *producerOp,
                                              Operation *containingOp) {
-  LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n");
+  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
   if (!tileableProducer) {
     diag.attachNote(producerOp->getLoc())
@@ -354,7 +432,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
   // Tile the producer.
   int64_t resultNumber =
       sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
-  LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
 
   FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
       rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
@@ -364,7 +442,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
         << "failed to tile producer op: " << *tileableProducer;
     return nullptr;
   }
-  LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
+  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");
 
   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
@@ -388,8 +466,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
 static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
     RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
     Operation *containingOp) {
-  LLVM_DEBUG(
-      llvm::dbgs() << "Try to fuse an extract use through block argument\n");
+  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
 
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
   if (!tileableProducer) {
@@ -442,7 +519,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   // Replace the use in the tileableProducer before tiling: clone, replace and
   // then tile.
   int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
-  LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
 
   // Gather destination tensors.
   SmallVector<Value> destinationTensors;
@@ -471,7 +548,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
         << "failed to tile producer op: " << *tileableProducer;
     return nullptr;
   }
-  LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
+  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledProducer << "\n");
 
   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
@@ -496,7 +573,7 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
                                        Operation *producerOp,
                                        Operation *containingOp) {
-  LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n");
+  LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
 
   // Gather all uses inside the containing op.
   SmallVector<OpOperand *> uses;
@@ -530,7 +607,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
   assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
          "Parallel insert slice is not a valid clone destination");
   unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
-  LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
 
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(use->getOwner());
@@ -607,8 +684,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     Operation *tiled =
         tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
     if (tiled) {
-      LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n"
-                              << *containingOp);
+      LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
       fusedOps.push_back(tiled);
       continue;
     }
@@ -617,9 +693,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
         tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
             rewriter, diag, producerOp, containingOp);
     if (tiledContainingOpOperand) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "\nFused an extract use through block argument\n"
-                 << *containingOp);
+      LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
+                        << *containingOp);
       fusedOps.push_back(tiledContainingOpOperand);
       continue;
     }
@@ -627,8 +702,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     Operation *cloned =
         cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
     if (cloned) {
-      LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n"
-                              << *containingOp);
+      LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
       fusedOps.push_back(cloned);
       continue;
     }
@@ -697,6 +771,123 @@ LogicalResult transform::InterchangeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LowerPackOp
+//===----------------------------------------------------------------------===//
+
+struct LowerPackResult {
+  tensor::PadOp padOp;
+  tensor::ExpandShapeOp expandShapeOp;
+  linalg::TransposeOp transposeOp;
+};
+
+/// Rewrite pack as pad + reshape + transpose.
+static FailureOr<LowerPackResult> rewriteLowerPack(RewriterBase &rewriter,
+                                                   tensor::PackOp packOp) {
+  // 1. Filter out NYI cases.
+  if (!packOp.getOuterDimsPerm().empty())
+    return rewriter.notifyMatchFailure(packOp, "outer dims perm NYI");
+
+  auto packedTensorType =
+      packOp->getResultTypes().front().cast<RankedTensorType>();
+  if (!packedTensorType.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(
+        packOp,
+        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
+  }
+
+  Location loc = packOp->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  // 2. Compute the permutation vector to move the last `numPackedDims` into
+  // the insert positions computed from `innerDimsPos` (rank `packedRank`).
+  int64_t numPackedDims = packOp.getInnerDimsPos().size();
+  int64_t packedRank = packedTensorType.getRank();
+  auto lastDims = llvm::to_vector(
+      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
+  PackingMetadata packingMetadata = computePackingMetadata(
+      packedTensorType.getRank(), packOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
+      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 3. Compute the stripMinedShape: this is the packed shape before any outer
+  // or inner permutations have been applied.
+  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
+  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+
+  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+      packingMetadata.reassociations);
+  Value paddingValue = packOp.getPaddingValue();
+  if (!paddingValue) {
+    paddingValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
+  }
+  auto padOp =
+      tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
+                              /*nofold=*/false, loc, rewriter);
+
+  LLVM_DEBUG(
+      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
+                                                DBGS() << "insertPositions: ");
+      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
+                                      DBGS() << "packedShape: ");
+      DBGSNL();
+      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
+                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      DBGSNL(); llvm::interleaveComma(
+          packingMetadata.reassociations, DBGS() << "reassociations: ",
+          [&](ReassociationIndices ri) {
+            llvm::interleaveComma(ri, llvm::dbgs() << "|");
+          });
+      DBGSNL();
+      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
+      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
+
+  // 5. Expand from the padded result to the stripMinedShape.
+  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
+      loc,
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
+      padOp.getResult(), packingMetadata.reassociations);
+
+  // 6. Transpose stripMinedShape to packedShape.
+  SmallVector<int64_t> insertPositionsToLastDimsPerm = computePermutationVector(
+      packedRank, packingMetadata.insertPositions, lastDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, reshapeOp.getResult(), packOp.getDest(),
+      insertPositionsToLastDimsPerm);
+
+  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
+             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
+             llvm::interleaveComma(insertPositionsToLastDimsPerm,
+                                   DBGS() << "insertPositionsToLastDimsPerm: ");
+             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
+
+  // 7. Replace packOp by transposeOp.
+  rewriter.replaceOp(packOp, transposeOp->getResults());
+
+  return LowerPackResult{padOp, reshapeOp, transposeOp};
+}
+
+DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
+    tensor::PackOp target, transform::ApplyToEachResultList &transformResults,
+    transform::TransformState &state) {
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<LowerPackResult> res = rewriteLowerPack(rewriter, target);
+  if (failed(res)) {
+    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+    diag << "cannot lower to pad + expand + transpose";
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  }
+  transformResults.push_back(res->padOp);
+  transformResults.push_back(res->expandShapeOp);
+  transformResults.push_back(res->transposeOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===---------------------------------------------------------------------===//
 // MatchOp
 //===---------------------------------------------------------------------===//
@@ -931,12 +1122,12 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
 
   IRRewriter rewriter(linalgOp->getContext());
   rewriter.setInsertionPoint(linalgOp);
-  FailureOr<LinalgOp> maybeResult = pack(rewriter, linalgOp, packedSizes);
+  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
   if (failed(maybeResult))
     return emitDefiniteFailure("data tiling failed");
 
   transformResults.set(getPackedOp().cast<OpResult>(),
-                       maybeResult->getOperation());
+                       maybeResult->packedLinalgOp.getOperation());
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -1045,35 +1236,8 @@ static FailureOr<GemmDimsForPacking> getGemmDims(LinalgOp linalgOp) {
   return GemmDimsForPacking{*ac.begin(), *bc.begin(), *ra.begin()};
 }
 
-/// Return a permutation vector of size permSize that would result in moving
-/// positions into desiredPositions.
-///
-/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
-/// would result in a {4, 2, 0, 1, 3} permutation vector.
-static SmallVector<int64_t>
-computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
-                         ArrayRef<int64_t> desiredPositions) {
-  SmallVector<int64_t> res(permSize, -1);
-  DenseSet<int64_t> seen;
-  for (auto [pos, desiredPos] : llvm::zip(positions, desiredPositions)) {
-    res[desiredPos] = pos;
-    seen.insert(pos);
-  }
-  int64_t nextPos = 0;
-  for (int64_t &entry : res) {
-    if (entry != -1)
-      continue;
-    while (seen.contains(nextPos))
-      ++nextPos;
-    entry = nextPos;
-    ++nextPos;
-  }
-  return res;
-}
-
-/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k)
-/// where m and n are proper parallel dimensions and k is a proper reduction
-/// dimension.
+/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k) where m and
+/// n are proper parallel dimensions and k is a proper reduction dimension.
 /// Packing occurs by rewriting the op as a linalg.generic and calling
 /// linalg::pack by `mnkPackedSizes`.
 /// The order of the packed dimensions is customizable: the `mnkOrder` is a
@@ -1081,7 +1245,7 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
 /// forms.
 /// The outer dimensions of the operands are not permuted at this time, this is
 /// left for future work.
-static FailureOr<LinalgOp>
+static FailureOr<PackResult>
 packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                  ArrayRef<OpFoldResult> mnkPackedSizes,
                  ArrayRef<int64_t> mnkOrder) {
@@ -1182,13 +1346,13 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
     rewriter.setInsertionPointAfter(linalgOp);
     // Failing to pack greedily is perfectly fine.
     // In the future we will want to order packings according to some metric.
-    FailureOr<LinalgOp> gemm = packGemmGreedily(
+    FailureOr<PackResult> packResult = packGemmGreedily(
         /*rewriter=*/rewriter,
         /*linalgOp=*/linalgOp,
         /*mnkPackedSizes=*/getMixedGemmPackedSizes(),
         /*mnkOrder=*/getGemmInnerDimsOrder());
-    if (succeeded(gemm)) {
-      results.push_back(*gemm);
+    if (succeeded(packResult)) {
+      results.push_back(packResult->packedLinalgOp);
       continue;
     }
     results.push_back(linalgOp);
@@ -1235,9 +1399,9 @@ namespace {
 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
 } // namespace
 
-/// Return true if `permutation` is a valid permutation of the `outer_dims_perm`
-/// (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner)
-/// of the `tensor.pack` or `tensor.unpack` `op.
+/// Return true if `permutation` is a valid permutation of the
+/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
+/// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op.
 /// This is the case when the `permutation` rank matches the rank expected by
 /// `op` and `permutation` is itself a permutation vector.
 /// Return true if either `op` or `permutation` are empty to allow a simpler
@@ -1281,10 +1445,10 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
   // Step 2. Bunch of runtime sanity check and error messages.
   // Step 2.1. Fail on multi-op handles.
   if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
-    return emitSilenceableError()
-           << "requires target to map to exactly 1 packing op and 1 packed op ("
-           << "got " << packOrUnpackOps.size() << " and " << linalgOps.size()
-           << ")";
+    return emitSilenceableError() << "requires target to map to exactly 1 "
+                                     "packing op and 1 packed op ("
+                                  << "got " << packOrUnpackOps.size() << " and "
+                                  << linalgOps.size() << ")";
   }
 
   // Step 2.2. Fail on wrong type.
@@ -1311,7 +1475,8 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
     return emitSilenceableError() << errorMsg;
   }
 
-  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical PackOp.
+  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
+  // PackOp.
   if (unPackOp) {
     assert(!packOp && "packOp must be null on entry when unPackOp is not null");
     OpOperand *packUse = linalgOp.getDpsInitOperand(
@@ -1700,9 +1865,9 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
   }
 
   if (second.size() != first.size() && !second.empty()) {
-    auto diag =
-        emitSilenceableError()
-        << "splitting does not produce the second part for a subset of targets";
+    auto diag = emitSilenceableError()
+                << "splitting does not produce the second part for a subset "
+                   "of targets";
     diag.attachNote() << "expected splitting to produce the second part of all "
                          "or none of the targets";
     diag.attachNote(noSecondPart->getLoc())
@@ -1965,7 +2130,8 @@ void transform::TileOp::build(OpBuilder &builder, OperationState &result,
                               Value target,
                               ArrayRef<OpFoldResult> mixedTileSizes,
                               ArrayRef<int64_t> interchange) {
-  // Loop types are automaticaly splat by the callee, setting up one is enough.
+  // Loop types are automatically splat by the callee, setting up one is
+  // enough.
   SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
   build(builder, result, loopTypes, target, mixedTileSizes, interchange);
 }
@@ -1978,8 +2144,8 @@ void transform::TileOp::build(OpBuilder &builder, OperationState &result,
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
   // Call the default builder which sets up the proper operands segment sizes
-  // attributes for multiple variadic operands. In the absence of this, horrible
-  // bugs ensue.
+  // attributes for multiple variadic operands. In the absence of this,
+  // horrible bugs ensue.
   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
   unsigned numExpectedLoops =
       staticTileSizes.size() - llvm::count(staticTileSizes, 0);
@@ -2247,8 +2413,8 @@ void transform::TileToForeachThreadOp::build(
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
   // Call the default builder which sets up the proper operands segment sizes
-  // attributes for multiple variadic operands. In the absence of this, horrible
-  // bugs ensue.
+  // attributes for multiple variadic operands. In the absence of this,
+  // horrible bugs ensue.
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
@@ -2284,8 +2450,8 @@ void transform::TileToForeachThreadOp::build(
   dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
                              staticNumThreads);
   // Call the default builder which sets up the proper operands segment sizes
-  // attributes for multiple variadic operands. In the absence of this, horrible
-  // bugs ensue.
+  // attributes for multiple variadic operands. In the absence of this,
+  // horrible bugs ensue.
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
@@ -2415,8 +2581,8 @@ LogicalResult TileToForeachThreadOp::verify() {
     return emitOpError(
         "tile_sizes and packed_tile_sizes are mutually exclusive");
   if (numThreadsSpec == 0 && tileSizesSpec == 0)
-    return emitOpError(
-        "either (packed_)num_threads or (packed_)tile_sizes must be specified");
+    return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
+                       "must be specified");
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 722bd9f2d8a36..f1f92d329ebb4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1052,9 +1052,9 @@ PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
 /// Implement packing of a single LinalgOp by performing packing by
 /// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
 /// Return the packed Linalg op on success, failure otherwise.
-FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
-                                         linalg::LinalgOp linalgOp,
-                                         ArrayRef<OpFoldResult> packedSizes) {
+FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
+                                   linalg::LinalgOp linalgOp,
+                                   ArrayRef<OpFoldResult> packedSizes) {
   if (packedSizes.size() != linalgOp.getNumLoops()) {
     return rewriter.notifyMatchFailure(linalgOp,
                                        "incorrect number of pack sizes");
@@ -1069,6 +1069,8 @@ FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
              llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
              DBGSNL(););
 
+  SmallVector<tensor::PackOp> packOps;
+  SmallVector<tensor::UnPackOp> unPackOps;
   // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
   PackedOperandsDimList listOfPackedOperandsDim;
   for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
@@ -1124,8 +1126,9 @@ FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
       Attribute zeroAttr =
           rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
       Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
-      inputsAndInits.push_back(rewriter.create<tensor::PackOp>(
+      packOps.push_back(rewriter.create<tensor::PackOp>(
           loc, operand, dest, innerPos, innerPackSizes, zero));
+      inputsAndInits.push_back(packOps.back());
     }
   }
 
@@ -1149,16 +1152,19 @@ FailureOr<linalg::LinalgOp> linalg::pack(RewriterBase &rewriter,
       continue;
     }
     // Build the symmetrical UnPackOp to the existing PackOp.
-    results.push_back(rewriter.create<tensor::UnPackOp>(
+    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
         packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
         maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
+    results.push_back(unPackOps.back());
   }
 
   // Step 5. Replace `linalgOp`.
   rewriter.replaceOp(linalgOp, results);
 
   // Return packedLinalgOp.
-  return cast<linalg::LinalgOp>(packedLinalgOp.getOperation());
+  return PackResult{packOps,
+                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
+                    unPackOps};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4515d711f72bf..7960b64fd7151 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1272,11 +1272,18 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
                                             getReassociationIndices());
 }
 
+RankedTensorType CollapseShapeOp::inferCollapsedType(
+    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
+  return inferCollapsedType(
+      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+                type.getContext(), reassociation)));
+}
+
 /// Compute the RankedTensorType obtained by applying `reassociation` to
 /// `type`.
-static RankedTensorType
-computeTensorReshapeCollapsedType(RankedTensorType type,
-                                  ArrayRef<AffineMap> reassociation) {
+RankedTensorType
+CollapseShapeOp::inferCollapsedType(RankedTensorType type,
+                                    ArrayRef<AffineMap> reassociation) {
   auto shape = type.getShape();
   SmallVector<int64_t, 4> newShape;
   newShape.reserve(reassociation.size());
@@ -1304,7 +1311,7 @@ computeTensorReshapeCollapsedType(RankedTensorType type,
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<ReassociationIndices> reassociation,
                             ArrayRef<NamedAttribute> attrs) {
-  auto resultType = computeTensorReshapeCollapsedType(
+  auto resultType = inferCollapsedType(
       src.getType().cast<RankedTensorType>(),
       getSymbolLessAffineMaps(
           convertReassociationIndicesToExprs(b.getContext(), reassociation)));
@@ -1336,7 +1343,7 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 
   auto maps = op.getReassociationMaps();
   RankedTensorType expectedType =
-      computeTensorReshapeCollapsedType(expandedType, maps);
+      CollapseShapeOp::inferCollapsedType(expandedType, maps);
   if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
     return op.emitOpError("expected collapsed type to be ")
            << expectedType << ", but got " << collapsedType;
@@ -1436,7 +1443,7 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
 
     RankedTensorType srcType =
         castOp.getSource().getType().cast<RankedTensorType>();
-    RankedTensorType newResultType = computeTensorReshapeCollapsedType(
+    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
         srcType, collapseShapeOp.getReassociationMaps());
 
     if (newResultType == collapseShapeOp.getResultType()) {

diff  --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
new file mode 100644
index 0000000000000..1b87903bba59f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s
+
+func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
+  %cst_0 = arith.constant 0.0 : f32
+
+  // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
+  //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+  //      CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  //      CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
+  //      CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
+  // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<17x8x2x32x16x16xf32>
+  //      CHECK: linalg.transpose
+  // CHECK-SAME:   ins(%{{.*}} : tensor<17x8x2x32x16x16xf32>)
+  // CHECK-SAME:   outs(%{{.*}} : tensor<17x2x16x16x32x8xf32>)
+  // CHECK-SAME:   permutation = [0, 2, 4, 5, 3, 1]
+  %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
+    : tensor<129x47x16x16xf32> -> tensor<17x2x16x16x32x8xf32>
+  return %pack : tensor<17x2x16x16x32x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op 
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">) 
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7f4d34a4b916b..39649d7d45d30 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8341,12 +8341,14 @@ cc_library(
         ":LinalgDialect",
         ":LinalgTransformOpsIncGen",
         ":LinalgTransforms",
+        ":LinalgUtils",
         ":PDLDialect",
         ":Parser",
         ":SCFTransforms",
         ":SideEffectInterfaces",
         ":Support",
         ":TensorDialect",
+        ":TensorUtils",
         ":TilingInterface",
         ":TransformDialect",
         ":TransformDialectUtils",


        


More information about the Mlir-commits mailing list