[Mlir-commits] [mlir] [mlir][linalg] Add pattern to propagate pack up through tensor.pad (PR #82035)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 16 11:56:17 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

This mirrors the existing pattern that pushes tensor.unpack down through
tensor.pad, and is restricted to cases where none of the padded dimensions
are tiled by the pack.

Additionally reformats the propagation test to make it easier to read. The
test reformatting change is kept as a separate PR for ease of review.

---

Patch is 59.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82035.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+86-3) 
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+452-386) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6a971b37cad7c5..5ceb85e7d9903b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -470,6 +470,88 @@ struct BubbleUpPackOpThroughGenericOpPattern
   ControlPropagationFn controlFn;
 };
 
+/// Propagate a tensor.pack operation up through a tensor.pad. The pack is
+/// applied to the pad source, and the pad is re-created on the packed tensor
+/// with zero low/high padding appended for each inner tile (point loop) dim.
+class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    // User controlled propagation function.
+    if (!controlFn(padOp))
+      return failure();
+
+    if (!padOp.getResult().hasOneUse())
+      return failure();
+
+    // TODO: Enable padding when the padding values are the same.
+    if (packOp.getPaddingValue())
+      return failure();
+
+    // Fail for non-constant padding values. The body of the pad could
+    // depend on the padding indices and/or properties of the padded
+    // tensor so for now we fail.
+    // TODO: Support non-constant padding values.
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal)
+      return failure();
+
+    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+      return failure();
+
+    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+    // Bail out if any padded dimension is also tiled by the pack.
+    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
+    llvm::SmallBitVector innerDims(paddedDims.size());
+    for (int64_t dim : innerDimsPos)
+      innerDims.set(dim);
+    if (paddedDims.anyCommon(innerDims))
+      return failure();
+
+    Location loc = padOp->getLoc();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(padOp);
+
+    auto empty = tensor::PackOp::createDestinationTensor(
+        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
+        outerDimsPerm);
+    Value packedSource = rewriter.create<tensor::PackOp>(
+        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
+        /*padding=*/std::nullopt, outerDimsPerm);
+
+    // With `outer_dims_perm`, permute the pad amounts to match the outer dims.
+    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
+      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
+    }
+    // The tiled dimensions were verified to be unpadded above, so here we
+    // just append 0 for the inner tile dimensions.
+    size_t pointLoopsSize = innerDimsPos.size();
+    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
+        padOp.getNofold());
+    rewriter.replaceOp(packOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.
@@ -690,7 +772,8 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation) {
-  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
-                  PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
-      patterns.getContext(), controlPackUnPackPropagation);
+  patterns
+      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
+              PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+          patterns.getContext(), controlPackUnPackPropagation);
 }
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 4c59c97aecc251..e036695a2ac9fd 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -21,28 +21,28 @@ func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>)
     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
   return %4 : tensor<?x?x8x2xf32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @dynamic_elem_pack
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:    %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG:    %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG:    %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
-// CHECK-DAG:    %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [0, 1] inner_tiles = [8, 2]
-// CHECK-SAME:     into %[[ARG0_EMPTY]]
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<?x?x8x2xf32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL:  func.func @dynamic_elem_pack
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG:      %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:      %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:      %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG:      %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG:      %[[OUTER_D0:.+]] = affine.apply #[[$MAP0]]()[%[[D0]]]
+// CHECK-DAG:      %[[OUTER_D1:.+]] = affine.apply #[[$MAP1]]()[%[[D1]]]
+// CHECK:          %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
+// CHECK:          %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:       inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME:       into %[[ARG0_EMPTY]]
+// CHECK:          %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[$MAP2]], #[[$MAP2]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:       ins(%[[PACK_ARG0]]
+// CHECK-SAME:       outs(%[[DEST]]
+// CHECK:          return %[[ELEM]] : tensor<?x?x8x2xf32>
 
 // -----
 
@@ -62,20 +62,20 @@ func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: ten
     into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
   return %pack : tensor<4x16x16x32xi32>
 }
-// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @elem_pack_transpose_inner_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [1, 0] inner_tiles = [16, 32]
-// CHECK-SAME:     into %[[ARG0_EMPTY]]
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<4x16x16x32xi32>
+// CHECK-DAG:  #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_inner_dims
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
+// CHECK:         return %[[ELEM]] : tensor<4x16x16x32xi32>
 
 // -----
 
@@ -96,20 +96,20 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: ten
     into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
   return %pack : tensor<16x4x32x16xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @elem_pack_transpose_outer_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:     into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<16x4x32x16xi32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:      into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
+// CHECK:         return %[[ELEM]] : tensor<16x4x32x16xi32>
 
 // -----
 
@@ -130,20 +130,20 @@ func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>,
     into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
   return %pack : tensor<16x4x16x32xi32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @elem_pack_transpose_inner_and_outer_dims
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
-// CHECK-SAME:     into %[[ARG0_EMPTY]]
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<16x4x16x32xi32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_inner_and_outer_dims
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
+// CHECK:         return %[[ELEM]] : tensor<16x4x16x32xi32>
 
 // -----
 
@@ -169,34 +169,34 @@ func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %d
     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
   return %4 : tensor<?x?x8x2xf32>
 }
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
-// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
-// CHECK-DAG:  #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
-// CHECK-DAG:  #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK:      func.func @dynamic_broadcast_pack
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG:    %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
-// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
-// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [8]
-// CHECK-SAME:     into %[[ARG0_EMPTY]]
-// CHECK-DAG:    %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
-// CHECK-DAG:    %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
-// CHECK:        %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
-// CHECK:        %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
-// CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [2]
-// CHECK-SAME:     into %[[ARG1_EMPTY]]
-// CHECK:        %[[ELEM:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-// CHECK-SAME:     ins(%[[PACK_ARG0]], %[[PACK_ARG0]]
-// CHECK-SAME:     outs(%[[DEST]]
-// CHECK:        return %[[ELEM]] : tensor<?x?x8x2xf32>
+// CHECK-DAG:  #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG:  #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG:  #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-DAG:  #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG:  #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @dynamic_broadcast_pack
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:     %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG:     %[[OUTER_D0:.+]] = affine.apply #[[$MAP0]]()[%[[D0]]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
+// CHECK:         %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [8]
+// CHECK-SAME:      into %[[ARG0_EMPTY]]
+// CHECK-DAG:     %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG:     %[[OUTER_D1:.+]] = affine.apply #[[$MAP1]]()[%[[D1]]]
+// CHECK:         %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
+// CHECK:         %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [2]
+// CHECK-SAME:      into %[[ARG1_EMPTY]]
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP4]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[PACK_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
+// CHECK:         return %[[ELEM]] : tensor<?x?x8x2xf32>
 
 // -----
 
@@ -215,19 +215,19 @@ func.func @elem_pack_transpose_inner_and_outer_dims2(%arg0: tensor<64xf32>, %des
   %2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
   return %2 : tensor<1x2x56x57x32xf32>
 }
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK:     func.func @elem_pack_transpose_inner_and_outer_dims2
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:       %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32>
-// CHECK:       %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:    inner_dims_pos = [0] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG0_EMPTY]]
-// CHECK:       %[[RES:.+]] = linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME:    ins(%[[PACKED_ARG0]]
-// CHECK-SAME:    outs(%[[DEST]]
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @elem_pack_transpose_inner_and_outer_dims2
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:     %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:         %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32>
+// CHECK:         %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME:    into %[[ARG0_EMPTY]]
+// CHECK:         %[[RES:.+]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME:      ins(%[[PACKED_ARG0]]
+// CHECK-SAME:      outs(%[[DEST]]
 
 // -----
 
@@ -253,27 +253,27 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
     into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
   return %4 : tensor<100x200x4x16x16x32xi32>
 }
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
-// CHECK:     func.func @transpose_pack
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:       %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
-// CHECK:       %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:    inner_dims_pos = [3, 1] inner_tiles = [16, 32]
-// CHECK-SAME:  into %[[ARG0_EMPTY]]
-// CHECK:       %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
-// CHECK:       %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
-// CHECK-SAME:    inner_dims_pos = [0] inner_tiles = [32]
-// CHECK-SAME:  into %[[ARG2_EMPTY]]
-// CHECK:       %[[RES:.+]] = linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
-// CHECK-SAME:    ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
-// CHECK-SAME:    outs(%[[DEST]]
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82035


More information about the Mlir-commits mailing list