[Mlir-commits] [mlir] 5b2f7a1 - [mlir][linalg] Support lowering unpack with outer_dims_perm (#94477)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 7 06:00:31 PDT 2024


Author: Ryan Holt
Date: 2024-06-07T09:00:28-04:00
New Revision: 5b2f7a19711710df1469ab8a0096fe2f8052c8b1

URL: https://github.com/llvm/llvm-project/commit/5b2f7a19711710df1469ab8a0096fe2f8052c8b1
DIFF: https://github.com/llvm/llvm-project/commit/5b2f7a19711710df1469ab8a0096fe2f8052c8b1.diff

LOG: [mlir][linalg] Support lowering unpack with outer_dims_perm (#94477)

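This commit adds support for lowering `tensor.unpack` with a
non-identity `outer_dims_perm`. Previously, `linalg::lowerUnPack`
rejected this case as not yet implemented (NYI). The transpose
permutation is now computed with `tensor::getUnPackInverseSrcPerm`
instead of being derived from `inner_dims_pos` alone.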

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 91dfac802ad67..f30ef235e9cd3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 
 FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
                                                    tensor::UnPackOp unPackOp) {
-  // 1. Filter out NYI cases.
-  if (!unPackOp.getOuterDimsPerm().empty() &&
-      !isIdentityPermutation(unPackOp.getOuterDimsPerm())) {
-    return rewriter.notifyMatchFailure(unPackOp,
-                                       "non-identity outer dims perm NYI");
-  }
-
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -391,36 +384,33 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
     return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                                /*reshapeOp=*/nullptr, extractSliceOp};
   }
-  // 2. Compute the permutation vector to move the last `numPackedDims` into
-  // the `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata =
-      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
-
-  // 3. Compute the stripMinedShape: this is the packed shape without outer and
+
+  // 1. Compute the permutation vector to shuffle the packed shape into the
+  // shape before any outer or inner permutations have been applied.
+  PackingMetadata packingMetadata;
+  SmallVector<int64_t> packedToStripMinedShapePerm =
+      tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);
+
+  // 2. Compute the stripMinedShape: this is the packed shape without outer and
   // inner permutations.
   SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
-  // 4. Transpose packedShape to stripMinedShape.
+  // 3. Transpose packedShape to stripMinedShape.
   RankedTensorType stripMinedTensorType =
       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMinedTensorType, packingMetadata.reassociations);
 
-  // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
+  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
   // permutation.
   SmallVector<OpFoldResult, 4> dims =
       tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
-  applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
+  applyPermutationToVector(dims, packedToStripMinedShapePerm);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(
       loc, dims, stripMinedTensorType.getElementType());
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
 
   LLVM_DEBUG(
       DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
@@ -428,8 +418,8 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
       DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                       DBGS() << "packedShape: ");
       DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      llvm::interleaveComma(packedToStripMinedShapePerm,
+                            DBGS() << "packedToStripMinedShapePerm: ");
       DBGSNL(); llvm::interleaveComma(
           packingMetadata.reassociations, DBGS() << "reassociations: ",
           [&](ReassociationIndices ri) {
@@ -439,12 +429,12 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
       llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
       DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
 
-  // 5. Collapse from the stripMinedShape to the padded result.
+  // 4. Collapse from the stripMinedShape to the padded result.
   auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
       loc, collapsedType, transposeOp->getResult(0),
       packingMetadata.reassociations);
 
-  // 6. ExtractSlice.
+  // 5. ExtractSlice.
   int64_t destRank = destTensorType.getRank();
   auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
       loc, destTensorType, reshapeOp->getResult(0),
@@ -452,11 +442,11 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
       tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
       SmallVector<OpFoldResult>(destRank, one));
 
-  // 7. Inject a copy to preserve DPS.
+  // 6. Inject a copy to preserve DPS.
   auto copyOp = rewriter.create<linalg::CopyOp>(
       loc, extractSliceOp->getResult(0), unPackOp.getDest());
 
-  // 8. Replace unPackOp by extractSliceOp.
+  // 7. Replace unPackOp by copyOp.
   rewriter.replaceOp(unPackOp, copyOp->getResults());
 
   return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};

diff  --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 926969bfc7388..f34ef4f961483 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -622,9 +622,20 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// At the moment, we cannot lower tensor.unpack with outer_dims_perm.
-func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
-  // expected-note @below {{target payload op}}
+// CHECK-LABEL: @unpack_with_outer_dims_perm
+//  CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
+//       CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
+//       CHECK: %[[TRAN:.*]] = linalg.transpose 
+//  CHECK-SAME:   ins(%[[ARG1]] : tensor<2x4x32x8xf32>) 
+//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<4x8x2x32xf32>) 
+//  CHECK-SAME:   permutation = [1, 3, 0, 2]
+//       CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+//  CHECK-SAME:   : tensor<4x8x2x32xf32> into tensor<32x64xf32>
+//       CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [32, 64] [1, 1]
+//  CHECK-SAME:   : tensor<32x64xf32> to tensor<32x64xf32>
+//       CHECK: linalg.copy ins(%[[SLICE]]
+//  CHECK-SAME:   : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
+func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
   %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0] 
     inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
   return %unpack : tensor<32x64xf32>
@@ -634,7 +645,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
       : (!transform.any_op) -> !transform.op<"tensor.unpack">
-    // expected-error @below {{cannot lower to transpose + collapse + extract}} 
     transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
       -> (!transform.op<"tensor.empty">,
           !transform.op<"linalg.transpose">,


        


More information about the Mlir-commits mailing list