[Mlir-commits] [mlir] 6bb0ab0 - [MLIR] Propagate unpack through element-wise ops
Author: Lorenzo Chelini
Date: 2023-02-01T09:11:28+01:00
New Revision: 6bb0ab0de039ddbcc70eafc549b50da7867fb617
URL: https://github.com/llvm/llvm-project/commit/6bb0ab0de039ddbcc70eafc549b50da7867fb617
DIFF: https://github.com/llvm/llvm-project/commit/6bb0ab0de039ddbcc70eafc549b50da7867fb617.diff
LOG: [MLIR] Propagate unpack through element-wise ops
Introduce `pushDownUnPackOpThroughElemGenericOp` to propagate a producer
tensor.unpack operation through an element-wise linalg.generic operation.
This pattern complements `BubbleUpPackOpThroughElemGenericOp`. The general
idea is to bubble up tensor.pack as far as possible while pushing down
tensor.unpack as far as possible, so that symmetrical tensor.pack and
tensor.unpack pairs meet and can be canonicalized away.
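For instance (a hypothetical sketch with made-up shapes, not IR taken from
this patch), once a pushed-down tensor.unpack meets a bubbled-up tensor.pack
with matching metadata, the pair folds away:

  %u = tensor.unpack %p inner_dims_pos = [1] inner_tiles = [32]
      into %e0 : tensor<4x8x32xf32> -> tensor<4x256xf32>
  %r = tensor.pack %u inner_dims_pos = [1] inner_tiles = [32]
      into %e1 : tensor<4x256xf32> -> tensor<4x8x32xf32>
  // Identical inner_dims_pos and inner_tiles: %r canonicalizes to %p,
  // and both operations disappear.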
Currently, `pushDownUnPackOpThroughElemGenericOp` expects a single
tensor.unpack operation as the producer of one of the linalg.generic's
operands.
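As an illustration of the supported shape (hypothetical types and names, not
IR from this patch), a single tensor.unpack feeding one operand of an
element-wise generic qualifies; a generic with several unpacked operands is
left untouched for now:

  #map = affine_map<(d0, d1) -> (d0, d1)>
  %1 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32]
      into %0 : tensor<4x8x32xf32> -> tensor<4x256xf32>
  %2 = linalg.generic {indexing_maps = [#map, #map],
      iterator_types = ["parallel", "parallel"]}
      ins(%1 : tensor<4x256xf32>) outs(%init : tensor<4x256xf32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %in : f32
      linalg.yield %3 : f32
  } -> tensor<4x256xf32>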
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D142523
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index df2b15dca65a8..1b6d1d247a2d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -46,19 +46,22 @@ struct PackInfo {
SmallVector<int64_t> outerDimsOnDomainPerm;
};
-static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
- tensor::PackOp packOp) {
+template <typename OpTy>
+static PackInfo getPackingInfoFromOperand(AffineMap indexingMap,
+ OpTy packOrUnPackOp) {
+ static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
+ "applies to only pack or unpack operations");
LLVM_DEBUG(
- { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
+ { llvm::dbgs() << "--- Construct PackInfo From An Operand ---\n"; });
PackInfo packInfo;
int64_t origNumDims = indexingMap.getNumDims();
SmallVector<AffineExpr> exprs(indexingMap.getResults());
- ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
for (auto [index, innerDimPos, tileSize] :
llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
- innerDimsPos, packOp.getMixedTiles())) {
+ innerDimsPos, packOrUnPackOp.getMixedTiles())) {
int64_t domainDimPos =
- exprs[innerDimPos].cast<AffineDimExpr>().getPosition();
+ exprs[innerDimPos].template cast<AffineDimExpr>().getPosition();
packInfo.tiledDimsPos.push_back(domainDimPos);
packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
@@ -71,7 +74,7 @@ static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
});
}
- for (auto dim : packOp.getOuterDimsPerm())
+ for (auto dim : packOrUnPackOp.getOuterDimsPerm())
packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
if (!packInfo.outerDimsOnDomainPerm.empty()) {
LLVM_DEBUG({
@@ -209,6 +212,35 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
return std::make_tuple(packedOperand, indexingMap);
}
+/// Pack an element-wise genericOp and return it.
+static GenericOp packElementWiseOp(RewriterBase &rewriter, GenericOp genericOp,
+ Value dest, AffineMap packedOutIndexingMap,
+ const PackInfo &packInfo) {
+ Location loc = genericOp.getLoc();
+ SmallVector<Value> inputOperands;
+ SmallVector<AffineMap> indexingMaps;
+ for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+ auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
+ rewriter, loc, packInfo, genericOp, inputOperand);
+ inputOperands.push_back(packedOperand);
+ indexingMaps.push_back(packedIndexingMap);
+ }
+
+ int64_t numInnerLoops = packInfo.getNumTiledLoops();
+ SmallVector<utils::IteratorType> iterTypes =
+ genericOp.getIteratorTypesArray();
+ iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
+
+ indexingMaps.push_back(packedOutIndexingMap);
+
+ auto newGenericOp = rewriter.create<linalg::GenericOp>(
+ loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+ rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+ newGenericOp.getRegion().begin());
+ return newGenericOp;
+}
+
/// Bubbles up tensor.pack op through an elementwise generic op. This
/// swaps pack(generic) to generic(pack). The new generic op works on the
/// packed domain; pack ops are created for input and output operands. E.g.,
@@ -275,29 +307,13 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
return failure();
OpOperand *opOperand = genericOp.getDpsInitOperand(0);
- auto packInfo = getPackingInfoFromConsumer(
+ auto packInfo = getPackingInfoFromOperand(
genericOp.getMatchingIndexingMap(opOperand), packOp);
- Location loc = packOp.getLoc();
- SmallVector<Value> inputOperands;
- SmallVector<AffineMap> indexingMaps;
- for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
- auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
- rewriter, loc, packInfo, genericOp, inputOperand);
- inputOperands.push_back(packedOperand);
- indexingMaps.push_back(packedIndexingMap);
- }
-
- int64_t numInnerLoops = packInfo.getNumTiledLoops();
- SmallVector<utils::IteratorType> iterTypes =
- genericOp.getIteratorTypesArray();
- iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
-
// Rebuild the indexing map for the corresponding init operand.
auto [packedOutOperand, packedOutIndexingMap] =
- getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
- opOperand);
- indexingMaps.push_back(packedOutIndexingMap);
+ getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo,
+ genericOp, opOperand);
// We'll replace the init operand with the destination of pack op if the init
// operand has no users in the body of the linalg.generic (pure elementwise).
@@ -306,15 +322,12 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
? packOp.getDest()
: packedOutOperand;
- auto newGenericOp = rewriter.create<linalg::GenericOp>(
- loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
- /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
- rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
- newGenericOp.getRegion().begin());
- return newGenericOp;
+
+ return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap,
+ packInfo);
}
-// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
+/// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp.
struct BubbleUpPackOpThroughElemGenericOpPattern
: public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
@@ -328,10 +341,134 @@ struct BubbleUpPackOpThroughElemGenericOpPattern
return success();
}
};
+
+// TODO: Relax this restriction. We should be able to push an unpack through
+// an elementwise op even when multiple operands are produced by unpack ops.
+/// Return the unpacked operand, if present, for the current generic op.
+static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
+ OpOperand *unPackedOperand = nullptr;
+ for (OpOperand &operand : genericOp->getOpOperands()) {
+ auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
+ if (!unPackOp)
+ continue;
+ if (unPackedOperand)
+ return failure();
+ unPackedOperand = &operand;
+ }
+ if (!unPackedOperand)
+ return failure();
+ return unPackedOperand;
+}
+
+/// Push down a tensor.unpack op through an elementwise generic op.
+/// The new generic op works on the packed domain; pack ops are created for
+/// input and output operands. A tensor.unpack op is inserted right after the
+/// packed generic. E.g.
+///
+/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+///
+/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
+///
+/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
+/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
+/// inner_dims_pos = [3] inner_tiles = [32] into %0
+/// %2 = linalg.generic {indexing_maps = [#map],
+/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+/// outs(%1 : tensor<12x56x56x64xf32>) {
+/// ^bb0(%out : f32):
+/// linalg.yield %out : f32
+/// } -> tensor<12x56x56x64xf32>
+///
+/// will be converted to
+///
+/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+///
+/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
+/// %1 = linalg.generic {indexing_maps = [#map],
+/// iterator_types = ["parallel", "parallel", "parallel",
+/// "parallel", "parallel"]}
+/// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
+/// ^bb0(%out : f32):
+/// linalg.yield %out : f32
+/// } -> tensor<12x2x56x56x32xf32>
+/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
+/// inner_dims_pos = [3] inner_tiles = [32] into %0
+///
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownUnPackOpThroughElemGenericOp(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ if (!isElementwise(genericOp))
+ return failure();
+ if (genericOp.getNumResults() != 1)
+ return failure();
+
+ // Collect the unPacked operand, if present.
+ auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
+ if (failed(maybeUnPackedOperand))
+ return failure();
+ OpOperand *unPackedOperand = *(maybeUnPackedOperand);
+
+ // Extract packing information.
+ tensor::UnPackOp producerUnPackOp =
+ unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
+ assert(producerUnPackOp && "expect a valid UnPackOp");
+ auto packInfo = getPackingInfoFromOperand(
+ genericOp.getMatchingIndexingMap(unPackedOperand), producerUnPackOp);
+
+ // Rebuild the indexing map for the corresponding init operand.
+ auto [packedOutOperand, packedOutIndexingMap] =
+ getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo,
+ genericOp, genericOp.getDpsInitOperand(0));
+
+ // If the dps init operand of the generic is a tensor.empty, do not pack it
+ // and forward the new tensor.empty as a destination.
+ Value dest = packedOutOperand;
+ if (auto initTensor = genericOp.getDpsInitOperand(0)
+ ->get()
+ .getDefiningOp<tensor::EmptyOp>()) {
+ if (auto packOp = packedOutOperand.getDefiningOp<tensor::PackOp>())
+ dest = packOp.getDest();
+ }
+
+ // Pack the genericOp.
+ GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest,
+ packedOutIndexingMap, packInfo);
+
+ auto unPackOp = unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
+ // Insert an unPackOp right after the packed generic.
+ Value unPackOpRes =
+ rewriter
+ .create<tensor::UnPackOp>(
+ genericOp.getLoc(),
+ newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+ unPackOp.getDest(), producerUnPackOp.getInnerDimsPos(),
+ producerUnPackOp.getMixedTiles(),
+ producerUnPackOp.getOuterDimsPerm())
+ .getResult();
+
+ return std::make_tuple(newGenericOp, unPackOpRes);
+}
+
+/// Wrapper pattern that applies pushDownUnPackOpThroughElemGenericOp.
+struct PushDownUnPackOpThroughElemGenericOp
+ : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ auto genericAndRepl =
+ pushDownUnPackOpThroughElemGenericOp(rewriter, genericOp);
+ if (failed(genericAndRepl))
+ return failure();
+ rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+ return success();
+ }
+};
+
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns) {
- patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern>(
- patterns.getContext());
+ patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern,
+ PushDownUnPackOpThroughElemGenericOp>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cd9d3acc7a635..b699b3d19cee6 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -352,15 +352,123 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
// CHECK: func.func @elem_pack_transpose_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[ARG0_EMPTY]]
// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK: %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[ARG1_EMPTY]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK-SAME: outs(%[[PACKED_ARG1]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
+ %0 = tensor.empty() : tensor<12x56x56x64xf32>
+ %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+ %2 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%1 : tensor<12x56x56x64xf32>) {
+ ^bb0(%out: f32):
+ %3 = arith.addf %out, %out : f32
+ linalg.yield %3 : f32
+ } -> tensor<12x56x56x64xf32>
+ return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_on_output
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
+// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY_PACK]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%[[PACKED_ARG0]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_EMPTY_UNPACK]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> {
+ %0 = tensor.empty() : tensor<12x56x56x64xf32>
+ %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+ %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %3 = arith.addf %in, %out : f32
+ linalg.yield %3 : f32
+ } -> tensor<12x56x56x64xf32>
+ return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_on_input
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: ins(%[[ARG0_PACK]]
+// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
+ %init = tensor.empty() : tensor<12x56x56x64xf32>
+ %0 = tensor.empty() : tensor<12x56x56x64xf32>
+ %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+ %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %3 = arith.addf %in, %in : f32
+ linalg.yield %3 : f32
+ } -> tensor<12x56x56x64xf32>
+ return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @forward_tensor_empty
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]