[Mlir-commits] [mlir] 886294a - [mlir][linalg] Add pattern to propagate pack up through tensor.pad (#82035)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 18 08:20:19 PST 2024


Author: Quinn Dawkins
Date: 2024-02-18T11:20:15-05:00
New Revision: 886294a2fe5928ecf34299e02526e17be19910c6

URL: https://github.com/llvm/llvm-project/commit/886294a2fe5928ecf34299e02526e17be19910c6
DIFF: https://github.com/llvm/llvm-project/commit/886294a2fe5928ecf34299e02526e17be19910c6.diff

LOG: [mlir][linalg] Add pattern to propagate pack up through tensor.pad (#82035)

This mirrors the existing pattern for pushing unpack down through
padding, restricting to cases where the padded dimensions aren't tiled
by the pack.

Additionally reformats the propagation test to make it easier to read: the CHECK lines now use CHECK-LABEL and `$`-prefixed map captures, and are indented consistently.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6a971b37cad7c5..5ceb85e7d9903b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -470,6 +470,88 @@ struct BubbleUpPackOpThroughGenericOpPattern
   ControlPropagationFn controlFn;
 };
 
+/// Propagate a tensor.pack up through its tensor.pad producer: the pad's
+/// source is packed first, then the pad is rebuilt on the packed tensor with
+/// zero low/high padding appended for every inner tile (point loop) dimension.
+class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    // Let the user-provided control function veto the propagation.
+    if (!controlFn(padOp))
+      return failure();
+
+    if (!padOp.getResult().hasOneUse())
+      return failure();
+
+    // TODO: Support a pack padding value that matches the pad's value.
+    if (packOp.getPaddingValue())
+      return failure();
+
+    // Fail for non-constant padding values. The body of the pad could
+    // depend on the padding indices and/or properties of the padded
+    // tensor, so for now we fail.
+    // TODO: Support non-constant padding values.
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal)
+      return failure();
+
+    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+      return failure();
+
+    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+    // Bail out if any padded dimension is also tiled by the pack.
+    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
+    llvm::SmallBitVector innerDims(paddedDims.size());
+    for (int64_t dim : innerDimsPos)
+      innerDims.flip(dim);
+    if (paddedDims.anyCommon(innerDims))
+      return failure();
+
+    Location loc = padOp->getLoc();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(padOp);
+    // Pack the pad's source with the same tiling and no padding value.
+    auto empty = tensor::PackOp::createDestinationTensor(
+        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
+        outerDimsPerm);
+    Value packedSource = rewriter.create<tensor::PackOp>(
+        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
+        /*padding=*/std::nullopt, outerDimsPerm);
+
+    // With an `outer_dims_perm`, permute the pad amounts into the packed order.
+    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
+      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
+    }
+    // The tiled dimensions were verified to be unpadded above, so here we
+    // just append 0 for the inner tile dimensions.
+    size_t pointLoopsSize = innerDimsPos.size();
+    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+    // Re-pad the packed tensor and replace the original pack with it.
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
+        padOp.getNofold());
+    rewriter.replaceOp(packOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
@@ -690,7 +772,8 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation) {
-  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
-                  PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
-      patterns.getContext(), controlPackUnPackPropagation);
+  patterns
+      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
+              PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+          patterns.getContext(), controlPackUnPackPropagation);
 }

diff  --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 4c59c97aecc251..e036695a2ac9fd 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -21,28 +21,28 @@ func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>)
     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
   return %4 : tensor<?x?x8x2xf32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @dynamic_elem_pack
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:    %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG:    %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG:    %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
-// CHECK-DAG:    %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [0, 1] inner_tiles = [8, 2]
-// CHECK-SAME:     into %[[ARG0_EMPTY]]
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<?x?x8x2xf32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL:  func.func @dynamic_elem_pack
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG:      %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:      %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:      %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG:      %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG:      %[[OUTER_D0:.+]] = affine.apply #[[$MAP0]]()[%[[D0]]]
+// CHECK-DAG:      %[[OUTER_D1:.+]] = affine.apply #[[$MAP1]]()[%[[D1]]]
+// CHECK:          %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
+// CHECK:          %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:       inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME:       into %[[ARG0_EMPTY]]
+// CHECK:          %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:       ins(%[[PACK_ARG0]]
+// CHECK-SAME:       outs(%[[DEST]]
+// CHECK:          return %[[ELEM]] : tensor<?x?x8x2xf32>
 
 // -----
 
@@ -62,20 +62,20 @@ func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: ten
     into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
   return %pack : tensor<4x16x16x32xi32>
 }
-// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @elem_pack_transpose_inner_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [1, 0] inner_tiles = [16, 32]
-// CHECK-SAME:     into %[[ARG0_EMPTY]]
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<4x16x16x32xi32>
+// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_inner_dims
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
+// CHECK:         return %[[ELEM]] : tensor<4x16x16x32xi32>
 
 // -----
 
@@ -96,20 +96,20 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: ten
     into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
   return %pack : tensor<16x4x32x16xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @elem_pack_transpose_outer_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:     into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<16x4x32x16xi32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:      into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
+// CHECK:         return %[[ELEM]] : tensor<16x4x32x16xi32>
 
 // -----
 
@@ -130,20 +130,20 @@ func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>,
     into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
   return %pack : tensor<16x4x16x32xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @elem_pack_transpose_inner_and_outer_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
-// CHECK-SAME:     into %[[ARG0_EMPTY]]
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<16x4x16x32xi32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_inner_and_outer_dims
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
+// CHECK:         return %[[ELEM]] : tensor<16x4x16x32xi32>
 
 // -----
 
@@ -169,34 +169,34 @@ func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %d
     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
   return %4 : tensor<?x?x8x2xf32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
-// CHECK-DAG:  #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
-// CHECK-DAG:  #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @dynamic_broadcast_pack
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG:    %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [8]
-// CHECK-SAME:     into %[[ARG0_EMPTY]]
-// CHECK-DAG:    %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
-// CHECK-DAG:    %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
-// CHECK:        %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
-// CHECK:        %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
-// CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [2]
-// CHECK-SAME:     into %[[ARG1_EMPTY]]
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]], %[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<?x?x8x2xf32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-DAG:  #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG:  #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @dynamic_broadcast_pack
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG:     %[[OUTER_D0:.+]] = affine.apply #[[$MAP0]]()[%[[D0]]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [8]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK-DAG:     %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG:     %[[OUTER_D1:.+]] = affine.apply #[[$MAP1]]()[%[[D1]]]
+// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
+// CHECK:         %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [2]
+// CHECK-SAME:      into %[[ARG1_EMPTY]]
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP4]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[PACK_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
+// CHECK:         return %[[ELEM]] : tensor<?x?x8x2xf32>
 
 // -----
 
@@ -215,19 +215,19 @@ func.func @elem_pack_transpose_inner_and_outer_dims2(%arg0: tensor<64xf32>, %des
   %2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
   return %2 : tensor<1x2x56x57x32xf32>
 }
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK:     func.func @elem_pack_transpose_inner_and_outer_dims2
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:       %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32>
-// CHECK:       %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:    inner_dims_pos = [0] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG0_EMPTY]]
-// CHECK:       %[[RES:.+]] = linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME:    ins(%[[PACKED_ARG0]]
-// CHECK-SAME:    outs(%[[DEST]]
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @elem_pack_transpose_inner_and_outer_dims2
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32>
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:    into %[[ARG0_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
 
 // -----
 
@@ -253,27 +253,27 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
     into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
   return %4 : tensor<100x200x4x16x16x32xi32>
 }
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
-// CHECK:     func.func @transpose_pack
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:       %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
-// CHECK:       %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:    inner_dims_pos = [3, 1] inner_tiles = [16, 32]
-// CHECK-SAME:  into %[[ARG0_EMPTY]]
-// CHECK:       %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
-// CHECK:       %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
-// CHECK-SAME:    inner_dims_pos = [0] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG2_EMPTY]]
-// CHECK:       %[[RES:.+]] = linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
-// CHECK-SAME:    ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
-// CHECK-SAME:    outs(%[[DEST]]
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
+// CHECK-LABEL: func.func @transpose_pack
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME:    into %[[ARG0_EMPTY]]
+// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
+// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:    into %[[ARG2_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME:      outs(%[[DEST]]
 
 // -----
 
@@ -299,27 +299,27 @@ func.func @affine_constant_expr_pack(%arg0: tensor<100x128x200x256xi32>, %arg1:
     into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
   return %4 : tensor<100x200x4x16x16x32xi32>
 }
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, 0, 0, 0)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (0, d1, 0, 0, d5)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
-// CHECK:     func.func @affine_constant_expr_pack
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:       %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
-// CHECK:       %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:    inner_dims_pos = [3, 1] inner_tiles = [16, 32]
-// CHECK-SAME:  into %[[ARG0_EMPTY]]
-// CHECK:       %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<1x4x1x1x32xi32>
-// CHECK:       %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
-// CHECK-SAME:    inner_dims_pos = [1] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG2_EMPTY]]
-// CHECK:       %[[RES:.+]] = linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
-// CHECK-SAME:    ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
-// CHECK-SAME:    outs(%[[DEST]]
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, 0, 0, 0)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (0, d1, 0, 0, d5)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
+// CHECK-LABEL: func.func @affine_constant_expr_pack
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME:    into %[[ARG0_EMPTY]]
+// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<1x4x1x1x32xi32>
+// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME:    into %[[ARG2_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME:      outs(%[[DEST]]
 
 // -----
 
@@ -347,26 +347,26 @@ func.func @transpose_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
   return %4 : tensor<200x4x16x100x16x32xi32>
 }
 
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
-// CHECK:     func.func @transpose_pack_with_outer_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<200x4x16x100x16x32xi32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:  outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
-// CHECK-SAME:  into %[[ARG0_EMPTY]]
-// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
-// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
-// CHECK-SAME:  inner_dims_pos = [0] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG2_EMPTY]]
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP]]]
-// CHECK-SAME:  ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
-// CHECK-SAME:  outs(%[[DEST]]
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
+// CHECK-LABEL: func.func @transpose_pack_with_outer_dims
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<200x4x16x100x16x32xi32>
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
+// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG2_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME:      outs(%[[DEST]]
 
 // -----
 
@@ -388,22 +388,22 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
   return %pack : tensor<16x4x32x16xi32>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK: func.func @elem_pack_transpose_outer_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
-// CHECK: %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
-// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG1_EMPTY]]
-// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG0_EMPTY]]
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
-// CHECK-SAME:  ins(%[[PACKED_ARG0]]
-// CHECK-SAME:  outs(%[[PACKED_ARG1]]
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:         %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:      into %[[ARG1_EMPTY]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      outs(%[[PACKED_ARG1]]
 
 // -----
 
@@ -420,23 +420,23 @@ func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56
   return %2 : tensor<12x56x56x64xf32>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK: func.func @unpack_on_output
-// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG0_EMPTY_UNPACK]]
-// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG0_EMPTY_PACK]]
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[MAP]]]
-// CHECK-SAME:  outs(%[[PACKED_ARG0]]
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG0_EMPTY_UNPACK]]
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @unpack_on_output
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG0_EMPTY_UNPACK]]
+// CHECK:         %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG0_EMPTY_PACK]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]]]
+// CHECK-SAME:      outs(%[[PACKED_ARG0]]
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG0_EMPTY_UNPACK]]
 
 // -----
 
@@ -453,29 +453,29 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
   return %2 : tensor<12x56x56x64xf32>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK: func.func @unpack_on_input
-// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]] 
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[ARG0_PACK_EMPTY]]
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
-// CHECK-SAME:  ins(%[[ARG0_PACK]]
-// CHECK-SAME:  outs(%[[ARG1_PACK]]
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]] 
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @unpack_on_input
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
+// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
+// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:      ins(%[[ARG0_PACK]]
+// CHECK-SAME:      outs(%[[ARG1_PACK]]
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
 
@@ -492,30 +492,30 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
   return %2 : tensor<12x56x56x64xf16>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK: func.func @unpack_element_type_change
-// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG1_PACK_EMPTY]]
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG0_PACK_EMPTY]]
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
-// CHECK-SAME:  ins(%[[ARG0_PACK]]
-// CHECK-SAME:  outs(%[[ARG1_PACK]]
-// CHECK: %[[ARG0_NEW_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf16>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG0_NEW_EMPTY_UNPACK]]
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @unpack_element_type_change
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
+// CHECK:         %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
+// CHECK:         %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG1_PACK_EMPTY]]
+// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK:         %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:      ins(%[[ARG0_PACK]]
+// CHECK-SAME:      outs(%[[ARG1_PACK]]
+// CHECK:         %[[ARG0_NEW_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf16>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG0_NEW_EMPTY_UNPACK]]
 
 // -----
 
@@ -533,29 +533,29 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
   return %2 : tensor<12x56x56x64xf32>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK: func.func @forward_tensor_empty
-// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
-// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[ARG0_PACK_EMPTY]]
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
-// CHECK-SAME:  ins(%[[PACKED_ARG0]]
-// CHECK-SAME:  outs(%[[DEST]]
-// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @forward_tensor_empty
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
+// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK:         %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      into %[[ARG0_PACK_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
+// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      into %[[ARG0_UNPACK_EMPTY]]
 
 // -----
 
-func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
+func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<1x56x56x64xf32>
   %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
@@ -566,18 +566,18 @@ func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58
   return %padded : tensor<1x58x58x64xf32>
 }
 
-// CHECK: func.func @pad_valid_propagation(
-// CHECK-SAME:  %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
-// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]] 
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
+// CHECK-LABEL: func.func @pad_valid_unpack_propagation(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
 
 // -----
 
-func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
+func.func @pad_valid_unpack_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<1x56x56x64xf32>
   %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
@@ -588,14 +588,14 @@ func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58
   return %padded : tensor<2x58x58x64xf32>
 }
 
-// CHECK: func.func @pad_valid_propagation(
-// CHECK-SAME:  %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
-// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME:  into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>
+// CHECK-LABEL: func.func @pad_valid_unpack_propagation(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:      into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>
 
 // -----
 
@@ -610,14 +610,80 @@ func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x5
   return %padded : tensor<1x58x58x66xf32>
 }
 
-// CHECK: func.func @pad_along_unpacked_dim(
-// CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
-// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] 
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
-// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
+// CHECK-LABEL: func.func @pad_along_unpacked_dim(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:      into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
+
+// -----
+// The pad touches only dims 2 and 3 while the pack tiles dim 1, so the pack is expected to be bubbled up above the pad.
+func.func @pad_valid_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x2x58x58x32xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
+  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
+  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
+  return %1 : tensor<1x2x58x58x32xf32>
+}
+
+// CHECK-LABEL: func.func @pad_valid_pack_propagation(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x56x56x32xf32>
+// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x2x56x56x32xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK:         return %[[PADDED]]
+
+// -----
+// Only untiled dims 2 and 3 are padded, but the pack also has outer_dims_perm = [0, 3, 2, 1]; the bubbled-up pad is expected to use the permuted positions (low[0, 1, 1, 0, 0]).
+func.func @pad_valid_outer_dims_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x58x58x2x32xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
+  %0 = tensor.empty() : tensor<1x58x58x2x32xf32>
+  %1 = tensor.pack %padded outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x58x58x2x32xf32>
+  return %1 : tensor<1x58x58x2x32xf32>
+}
+
+// CHECK-LABEL: func.func @pad_valid_outer_dims_pack_propagation(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x64x56x56xf32>)
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x2x32xf32>
+// CHECK:         %[[PACKED:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 2, 1] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x56x56xf32> -> tensor<1x56x56x2x32xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[PACKED]] low[0, 1, 1, 0, 0] high[0, 1, 1, 0, 0]
+// CHECK:         return %[[PADDED]]
+
+// -----
+// Negative case: the pad touches dim 1, which the pack tiles, so the pack is expected to stay below the pad.
+func.func @pad_along_packed_dim(%arg0: tensor<1x60x56x56xf32>) -> tensor<1x2x58x58x32xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[0, 2, 1, 1] high[0, 2, 1, 1] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x60x56x56xf32> to tensor<1x64x58x58xf32>
+  %0 = tensor.empty() : tensor<1x2x58x58x32xf32>
+  %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
+  return %1 : tensor<1x2x58x58x32xf32>
+}
+
+// CHECK-LABEL: func.func @pad_along_packed_dim(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<1x60x56x56xf32>)
+// CHECK:         %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 2, 1, 1] high[0, 2, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x58x58x32xf32>
+// CHECK:         tensor.pack %[[PADDED]] inner_dims_pos = [1] inner_tiles = [32]
+// CHECK-SAME:      into %[[EMPTY]] : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32>
 
 // -----
 
@@ -639,16 +705,16 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
   return %pack : tensor<4x16x16x32xi32>
 }
 
-// CHECK: func.func @would_break_dominance(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<128x256xi32>)
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x256xi32>
-// CHECK-NEXT: %[[GEN:.+]] = linalg.generic
-// CHECK-SAME:  ins(%[[ARG0]]
-// CHECK-SAME:  outs(%[[EMPTY]]
-// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
-// CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
-// CHECK-SAME:  inner_dims_pos = [1, 0] inner_tiles = [16, 32] 
-// CHECK-SAME:  into %[[ALLOC]]
+// CHECK-LABEL: func.func @would_break_dominance(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<128x256xi32>)
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x256xi32>
+// CHECK-NEXT:    %[[GEN:.+]] = linalg.generic
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
+// CHECK:         %[[ALLOC:.+]] = bufferization.alloc_tensor() : tensor<4x16x16x32xi32>
+// CHECK-NEXT:    %{{.+}} = tensor.pack %[[GEN]]
+// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32] 
+// CHECK-SAME:      into %[[ALLOC]]
 
 // -----
 
@@ -666,16 +732,16 @@ func.func @scalar_tensor(%arg0 : tensor<f32>) -> tensor<1x32x7x7x32xf32> {
   return %pack : tensor<1x32x7x7x32xf32>
 }
 
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK: func.func @scalar_tensor
-// CHECK-SAME: %[[ARG0:.+]]: tensor<f32>)
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x7x7x32xf32>
-// CHECK: linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @scalar_tensor
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<f32>)
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x7x7x32xf32>
+// CHECK:         linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP1]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[ARG0]]
+// CHECK-SAME:      outs(%[[EMPTY]]
 
 // -----
 
@@ -692,15 +758,15 @@ func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x
   return %2 : tensor<12x56x56x64xf32>
 }
 
-// CHECK: func.func @unpack_empty_inner_dims
-// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME:  ins(%[[PACKED_ARG0]]
-// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
-// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK-LABEL: func.func @unpack_empty_inner_dims
+// CHECK:         %[[UNPACKED_ARG0:.+]] = tensor.unpack
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK:         %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:      outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] 
 
 // -----
 
@@ -722,25 +788,25 @@ func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>,
     into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
   return %pack : tensor<4x16x16x32xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
-// CHECK:      func.func @reduction_pack_transpose_inner_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
-// CHECK:        %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
-// CHECK-SME:     inner_dims_pos = [1, 0] inner_tiles = [16, 32]
-// CHECK-SAME:    into %[[ARG1_EMPTY]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16x32xi32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [1, 0] inner_tiles = [16, 32]
-// CHECK-SAME:     into %[[ARG0_EMPTY]]
-// CHECK:        %[[RED:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[PACK_ARG1]]
-// CHECK:        return %[[RED]] : tensor<4x16x16x32xi32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+// CHECK-LABEL: func.func @reduction_pack_transpose_inner_dims
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK:         %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:      into %[[ARG1_EMPTY]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16x32xi32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[RED:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]]
+// CHECK-SAME:      outs(%[[PACK_ARG1]]
+// CHECK:         return %[[RED]] : tensor<4x16x16x32xi32>
 
 // -----
 
@@ -770,31 +836,31 @@ func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %a
   return %4 : tensor<4x16x100x16x32xi32>
 }
 
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d5)>
-// CHECK:     func.func @reduction_pack_with_outer_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG3_EMPTY:.+]] = tensor.empty() : tensor<4x16x100x16x32xi32>
-// CHECK: %[[PACKED_ARG3:.+]] = tensor.pack %[[ARG3]]
-// CHECK-SAME:  outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 32]
-// CHECK-SAME:  into %[[ARG3_EMPTY]]
-// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x200x100x16x32xi32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:  outer_dims_perm = [1, 3, 2, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
-// CHECK-SAME:  into %[[ARG0_EMPTY]]
-// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
-// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
-// CHECK-SAME:  inner_dims_pos = [0] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG2_EMPTY]]
-// CHECK: %[[RES:.+]] = linalg.generic
-// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
-// CHECK-SAME:  ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
-// CHECK-SAME:  outs(%[[PACKED_ARG3]]
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5)>
+// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d5)>
+// CHECK-LABEL: func.func @reduction_pack_with_outer_dims
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG3_EMPTY:.+]] = tensor.empty() : tensor<4x16x100x16x32xi32>
+// CHECK:         %[[PACKED_ARG3:.+]] = tensor.pack %[[ARG3]]
+// CHECK-SAME:      outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 32]
+// CHECK-SAME:      into %[[ARG3_EMPTY]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x200x100x16x32xi32>
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [1, 3, 2, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
+// CHECK:         %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:      into %[[ARG2_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME:      outs(%[[PACKED_ARG3]]
 
 // -----
 
@@ -818,24 +884,24 @@ func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32
   } -> tensor<16x540x960xi32>
   return %pool : tensor<16x540x960xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5, d6)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5)>
-// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d3, d6)>
-// CHECK:      func.func @unpack_different_destination_shape
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK:        %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
-// CHECK:        %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack
-// CHECK-SAME:     inner_dims_pos = [1] inner_tiles = [16]
-// CHECK-SAME:     into %[[PACK_EMPTY]]
-// CHECK:        %[[POOL:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]], %[[ARG1]]
-// CHECK-SAME:     outs(%[[INIT]]
-// CHECK:        %[[UNPACK_NEW_DEST:.+]] = tensor.empty() : tensor<16x540x960xi32>
-// CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[POOL]]
-// CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [16]
-// CHECK-SAME:     into %[[UNPACK_NEW_DEST]]
-// CHECK:        return %[[UNPACK]] : tensor<16x540x960xi32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5, d6)>
+// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5)>
+// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d3, d6)>
+// CHECK-LABEL: func.func @unpack_different_destination_shape
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
+// CHECK:         %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack
+// CHECK-SAME:      inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME:      into %[[PACK_EMPTY]]
+// CHECK:         %[[POOL:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[ARG1]]
+// CHECK-SAME:      outs(%[[INIT]]
+// CHECK:         %[[UNPACK_NEW_DEST:.+]] = tensor.empty() : tensor<16x540x960xi32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[POOL]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [16]
+// CHECK-SAME:      into %[[UNPACK_NEW_DEST]]
+// CHECK:         return %[[UNPACK]] : tensor<16x540x960xi32>


        


More information about the Mlir-commits mailing list