[Mlir-commits] [mlir] 6bb0ab0 - [MLIR] Propagate unpack through element-wise ops

Lorenzo Chelini llvmlistbot at llvm.org
Wed Feb 1 00:11:34 PST 2023


Author: Lorenzo Chelini
Date: 2023-02-01T09:11:28+01:00
New Revision: 6bb0ab0de039ddbcc70eafc549b50da7867fb617

URL: https://github.com/llvm/llvm-project/commit/6bb0ab0de039ddbcc70eafc549b50da7867fb617
DIFF: https://github.com/llvm/llvm-project/commit/6bb0ab0de039ddbcc70eafc549b50da7867fb617.diff

LOG: [MLIR] Propagate unpack through element-wise ops

Introduce `pushDownUnPackOpThroughElemGenericOp` to propagate a producer
tensor.unpack operation through an element-wise linalg.generic operation. This
pattern complements `bubbleUpPackOpThroughElemGenericOp`. The general
idea is to bubble up tensor.pack as much as possible while pushing down
tensor.unpack as much as possible, and canonicalize away symmetrical
tensor.pack and tensor.unpack operations.

Currently, `pushDownUnPackOpThroughElemGenericOp` requires exactly one of the
linalg.generic's operands (input or init) to be produced by a tensor.unpack
operation; it bails out if there are none or several.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D142523

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index df2b15dca65a8..1b6d1d247a2d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -46,19 +46,22 @@ struct PackInfo {
   SmallVector<int64_t> outerDimsOnDomainPerm;
 };
 
-static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
-                                           tensor::PackOp packOp) {
+template <typename OpTy>
+static PackInfo getPackingInfoFromOperand(AffineMap indexingMap,
+                                          OpTy packOrUnPackOp) {
+  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
+                "applies to only pack or unpack operations");
   LLVM_DEBUG(
-      { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
+      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
   PackInfo packInfo;
   int64_t origNumDims = indexingMap.getNumDims();
   SmallVector<AffineExpr> exprs(indexingMap.getResults());
-  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
   for (auto [index, innerDimPos, tileSize] :
        llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
-                       innerDimsPos, packOp.getMixedTiles())) {
+                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
     int64_t domainDimPos =
-        exprs[innerDimPos].cast<AffineDimExpr>().getPosition();
+        exprs[innerDimPos].template cast<AffineDimExpr>().getPosition();
     packInfo.tiledDimsPos.push_back(domainDimPos);
     packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
     packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
@@ -71,7 +74,7 @@ static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
     });
   }
 
-  for (auto dim : packOp.getOuterDimsPerm())
+  for (auto dim : packOrUnPackOp.getOuterDimsPerm())
     packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
   if (!packInfo.outerDimsOnDomainPerm.empty()) {
     LLVM_DEBUG({
@@ -209,6 +212,35 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   return std::make_tuple(packedOperand, indexingMap);
 }
 
+/// Build a new linalg.generic that performs `genericOp` on the packed domain
+/// and return it. Each input is replaced by its packed view (built with
+/// getOrCreatePackedViewOfOperand), `dest` becomes the single init operand
+/// indexed by `packedOutIndexingMap`, and one parallel loop is appended per
+/// tiled loop recorded in `packInfo`. The original region is cloned as is.
+static GenericOp packElementWiseOp(RewriterBase &rewriter, GenericOp genericOp,
+                                   Value dest, AffineMap packedOutIndexingMap,
+                                   const PackInfo &packInfo) {
+  Location loc = genericOp.getLoc();
+  SmallVector<Value> packedInputs;
+  SmallVector<AffineMap> packedMaps;
+  for (OpOperand *input : genericOp.getDpsInputOperands()) {
+    auto [packedInput, packedMap] = getOrCreatePackedViewOfOperand(
+        rewriter, loc, packInfo, genericOp, input);
+    packedInputs.push_back(packedInput);
+    packedMaps.push_back(packedMap);
+  }
+  // The init operand's map goes last, matching the ins/outs operand order.
+  packedMaps.push_back(packedOutIndexingMap);
+
+  // The new point loops of an element-wise op are all parallel.
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  iteratorTypes.append(packInfo.getNumTiledLoops(),
+                       utils::IteratorType::parallel);
+
+  auto packedGenericOp = rewriter.create<linalg::GenericOp>(
+      loc, dest.getType(), packedInputs, dest, packedMaps, iteratorTypes,
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  Region &packedRegion = packedGenericOp.getRegion();
+  rewriter.cloneRegionBefore(genericOp.getRegion(), packedRegion,
+                             packedRegion.begin());
+  return packedGenericOp;
+}
+
 /// Bubbles up tensor.pack op through elementwise generic op. This
 /// swap pack(generic) to generic(pack). The new generic op works on packed
 /// domain; pack ops are created for input and output operands. E.g.,
@@ -275,29 +307,13 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
     return failure();
 
   OpOperand *opOperand = genericOp.getDpsInitOperand(0);
-  auto packInfo = getPackingInfoFromConsumer(
+  auto packInfo = getPackingInfoFromOperand(
       genericOp.getMatchingIndexingMap(opOperand), packOp);
 
-  Location loc = packOp.getLoc();
-  SmallVector<Value> inputOperands;
-  SmallVector<AffineMap> indexingMaps;
-  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
-    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
-        rewriter, loc, packInfo, genericOp, inputOperand);
-    inputOperands.push_back(packedOperand);
-    indexingMaps.push_back(packedIndexingMap);
-  }
-
-  int64_t numInnerLoops = packInfo.getNumTiledLoops();
-  SmallVector<utils::IteratorType> iterTypes =
-      genericOp.getIteratorTypesArray();
-  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
-
   // Rebuild the indexing map for the corresponding init operand.
   auto [packedOutOperand, packedOutIndexingMap] =
-      getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
-                                     opOperand);
-  indexingMaps.push_back(packedOutIndexingMap);
+      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo,
+                                     genericOp, opOperand);
 
   // We'll replace the init operand with the destination of pack op if the init
   // operand has not users in the body of the linalg.generic (pure elementwise).
@@ -306,15 +322,12 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
   Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
                    ? packOp.getDest()
                    : packedOutOperand;
-  auto newGenericOp = rewriter.create<linalg::GenericOp>(
-      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
-      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
-  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
-                             newGenericOp.getRegion().begin());
-  return newGenericOp;
+
+  return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap,
+                           packInfo);
 }
 
-// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
+/// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
 struct BubbleUpPackOpThroughElemGenericOpPattern
     : public OpRewritePattern<tensor::PackOp> {
   using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
@@ -328,10 +341,134 @@ struct BubbleUpPackOpThroughElemGenericOpPattern
     return success();
   }
 };
+
+/// Return the single operand of `genericOp` (input or init) whose value is
+/// defined by a tensor.unpack op. Fail when no operand, or more than one
+/// operand, has a tensor.unpack producer.
+static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
+  // TODO: Relax this restriction. We should unpack an elementwise also
+  // in the presence of multiple unpack ops as producers.
+  OpOperand *match = nullptr;
+  for (OpOperand &candidate : genericOp->getOpOperands()) {
+    if (!candidate.get().getDefiningOp<tensor::UnPackOp>())
+      continue;
+    // A second unpack producer makes the choice ambiguous: bail out.
+    if (match != nullptr)
+      return failure();
+    match = &candidate;
+  }
+  if (match == nullptr)
+    return failure();
+  return match;
+}
+
+/// Push down a tensor.unpack op through an elementwise generic op.
+/// The new generic op works on the packed domain; pack ops are created for
+/// the input and output operands, and a tensor.unpack op is inserted right
+/// after the packed generic. E.g.
+///
+/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+///
+/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
+///
+/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
+/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
+///                          inner_dims_pos = [3] inner_tiles = [32] into %0
+/// %2 = linalg.generic {indexing_maps = [#map],
+///      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+///      outs(%1 : tensor<12x56x56x64xf32>) {
+///      ^bb0(%out : f32):
+///         linalg.yield %out : f32
+///      } -> tensor<12x56x56x64xf32>
+///
+/// will be converted to (once the tensor.pack(tensor.unpack(%arg0)) pair that
+/// the rewrite creates for the operand is folded away)
+///
+/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+///
+/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
+/// %1 = linalg.generic {indexing_maps = [#map],
+///      iterator_types = ["parallel", "parallel", "parallel",
+///                        "parallel", "parallel"]}
+///      outs(%arg0 : tensor<12x2x56x56x32xf32>) {
+///      ^bb0(%out : f32):
+///         linalg.yield %out : f32
+///      } -> tensor<12x2x56x56x32xf32>
+/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
+///                       inner_dims_pos = [3] inner_tiles = [32] into %0
+///
+/// The trailing tensor.unpack uses the layout (inner_dims_pos, inner_tiles,
+/// outer_dims_perm) of the pack built for the init operand, which is expressed
+/// in the result's dimensions even when the init's indexing map differs from
+/// the unpacked operand's. Its destination is the producer unpack's
+/// destination when that is usable for the result (the unpacked operand is
+/// the init itself, or both have the same static type); otherwise the
+/// generic's original init value is used, since it has the result's type.
+///
+/// Returns the packed generic and the value replacing the original result.
+/// Fails if the op is not elementwise, does not have exactly one result, or
+/// does not have exactly one operand produced by a tensor.unpack.
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownUnPackOpThroughElemGenericOp(RewriterBase &rewriter,
+                                     GenericOp genericOp) {
+  if (!isElementwise(genericOp))
+    return failure();
+  if (genericOp.getNumResults() != 1)
+    return failure();
+
+  // Collect the unPacked operand, if present.
+  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
+  if (failed(maybeUnPackedOperand))
+    return failure();
+  OpOperand *unPackedOperand = *maybeUnPackedOperand;
+
+  // Extract packing information.
+  tensor::UnPackOp producerUnPackOp =
+      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
+  assert(producerUnPackOp && "expect a valid UnPackOp");
+  auto packInfo = getPackingInfoFromOperand(
+      genericOp.getMatchingIndexingMap(unPackedOperand), producerUnPackOp);
+
+  // Rebuild the indexing map for the corresponding init operand.
+  Location loc = genericOp.getLoc();
+  OpOperand *initOperand = genericOp.getDpsInitOperand(0);
+  auto [packedOutOperand, packedOutIndexingMap] =
+      getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
+                                     initOperand);
+  // Null when the init operand needed no packing and was forwarded unchanged.
+  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
+
+  // If the dps init operand of the generic is a tensor.empty, do not pack it
+  // and forward the destination of the new pack op instead.
+  Value dest = packedOutOperand;
+  if (destPack && initOperand->get().getDefiningOp<tensor::EmptyOp>())
+    dest = destPack.getDest();
+
+  // Pack the genericOp.
+  GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest,
+                                             packedOutIndexingMap, packInfo);
+  Value newResult =
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
+
+  // The init was not packed, so the result is already in the original layout
+  // and must not be unpacked.
+  if (!destPack)
+    return std::make_tuple(newGenericOp, newResult);
+
+  // The producer's destination only fits the result if it has the result's
+  // type and sizes. Element type or rank may differ (e.g. a broadcast input),
+  // and equal dynamic types say nothing about equal sizes; the original init
+  // value always matches the result.
+  Value init = initOperand->get();
+  Value unPackDest = producerUnPackOp.getDest();
+  bool canReuseProducerDest =
+      unPackedOperand == initOperand ||
+      (unPackDest.getType() == init.getType() &&
+       init.getType().cast<ShapedType>().hasStaticShape());
+  if (!canReuseProducerDest)
+    unPackDest = init;
+
+  // Insert an unPackOp right after the packed generic, using the layout of
+  // the packed init rather than the producer's: they differ when the init's
+  // indexing map permutes dimensions relative to the unpacked operand's.
+  Value unPackOpRes =
+      rewriter
+          .create<tensor::UnPackOp>(loc, newResult, unPackDest,
+                                    destPack.getInnerDimsPos(),
+                                    destPack.getMixedTiles(),
+                                    destPack.getOuterDimsPerm())
+          .getResult();
+
+  return std::make_tuple(newGenericOp, unPackOpRes);
+}
+
+/// Wrapper pattern that applies pushDownUnPackOpThroughElemGenericOp method.
+/// On success, the original generic is replaced by the result of the
+/// tensor.unpack that follows the packed generic.
+struct PushDownUnPackOpThroughElemGenericOp
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<std::tuple<GenericOp, Value>> rewritten =
+        pushDownUnPackOpThroughElemGenericOp(rewriter, genericOp);
+    if (succeeded(rewritten)) {
+      Value replacement = std::get<1>(*rewritten);
+      rewriter.replaceOp(genericOp, replacement);
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern>(
-      patterns.getContext());
+  // tensor.pack ops are bubbled up, and tensor.unpack ops are pushed down,
+  // through element-wise linalg.generic ops.
+  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern,
+                  PushDownUnPackOpThroughElemGenericOp>(patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cd9d3acc7a635..b699b3d19cee6 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -352,15 +352,123 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
 // CHECK: func.func @elem_pack_transpose_outer_dims
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
-// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
-// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG0_EMPTY]]
 // CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
 // CHECK: %[[PACKED_ARG1:.+]] = tensor.pack %[[ARG1]]
 // CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[ARG1_EMPTY]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG0_EMPTY]]
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
 // CHECK-SAME:  ins(%[[PACKED_ARG0]]
 // CHECK-SAME:  outs(%[[PACKED_ARG1]]
+
+// -----
+
+// The only operand of the element-wise generic is its init, produced by a
+// tensor.unpack: the generic is rewritten on the packed domain and followed
+// by a new tensor.unpack of its result.
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%1 : tensor<12x56x56x64xf32>) {
+    ^bb0(%out: f32):
+      %3 = arith.addf %out, %out : f32
+      linalg.yield %3 : f32
+  } -> tensor<12x56x56x64xf32>
+  return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_on_output
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY_UNPACK:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG0_EMPTY_UNPACK]]
+// CHECK: %[[ARG0_EMPTY_PACK:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG0_EMPTY_PACK]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]]]
+// CHECK-SAME:  outs(%[[PACKED_ARG0]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[ARG0_EMPTY_UNPACK]]
+
+// -----
+
+// The input is produced by a tensor.unpack while the init is a function
+// argument: both operands are packed (the init through a new tensor.pack)
+// before the packed generic, whose result is then unpacked.
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> {
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %3 = arith.addf %in, %out : f32
+      linalg.yield %3 : f32
+  } -> tensor<12x56x56x64xf32>
+  return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @unpack_on_input
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG1_PACK:.+]] = tensor.pack %[[ARG1]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG0_PACK:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:  ins(%[[ARG0_PACK]]
+// CHECK-SAME:  outs(%[[ARG1_PACK]]
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[RES]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+// The init is a tensor.empty: rather than writing into a packed copy of
+// %init, the packed generic writes directly into a new packed tensor.empty.
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
+  %init = tensor.empty() : tensor<12x56x56x64xf32>
+  %0 = tensor.empty() : tensor<12x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %3 = arith.addf %in, %in : f32
+      linalg.yield %3 : f32
+  } -> tensor<12x56x56x64xf32>
+  return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @forward_tensor_empty
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_UNPACK_EMPTY:.+]] = tensor.empty() : tensor<12x56x56x64xf32>
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack %[[ARG0]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[ARG0_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[ARG0_PACK_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:  ins(%[[PACKED_ARG0]]
+// CHECK-SAME:  outs(%[[DEST]]
+// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]] 


        


More information about the Mlir-commits mailing list