[Mlir-commits] [mlir] 0f297ca - [mlir][tensor][linalg] Introduce DataLayoutPropagation pass.

Hanhan Wang llvmlistbot at llvm.org
Tue Dec 6 15:00:16 PST 2022


Author: Hanhan Wang
Date: 2022-12-06T15:00:07-08:00
New Revision: 0f297cad4d5b5ccb69c2e8610d7da4891cbf1f6b

URL: https://github.com/llvm/llvm-project/commit/0f297cad4d5b5ccb69c2e8610d7da4891cbf1f6b
DIFF: https://github.com/llvm/llvm-project/commit/0f297cad4d5b5ccb69c2e8610d7da4891cbf1f6b.diff

LOG: [mlir][tensor][linalg] Introduce DataLayoutPropagation pass.

It introduces a pattern that swaps `linalg.generic + tensor.pack` to
`tensor.pack + linalg.generic`. It requires all the iterator types to be
parallel and the indexing map of the output operand to be the identity.
Both restrictions can be relaxed in the future.

The user can decide whether the propagation should be applied by passing a
control function.
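
As a usage sketch, the patterns can be driven greedily in the same way as
the TestDataLayoutPropagation.cpp pass added below. This is a minimal
example, not part of this commit, and the pass name is a placeholder:

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Pass/Pass.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Hypothetical driver pass; it mirrors the test pass added below.
  struct PropagateDataLayoutPass
      : mlir::PassWrapper<PropagateDataLayoutPass, mlir::OperationPass<>> {
    MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PropagateDataLayoutPass)

    void runOnOperation() override {
      mlir::RewritePatternSet patterns(&getContext());
      // Currently this registers only the pack(generic) -> generic(pack)
      // bubbling pattern.
      mlir::linalg::populateDataLayoutPropagationPatterns(patterns);
      if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
              getOperation(), std::move(patterns))))
        signalPassFailure();
    }
  };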

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D138882

Added: 
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/test/Dialect/Linalg/data-layout-propagation.mlir
    mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f8473151c009e..a58c9dc23c1fc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -81,6 +81,9 @@ void populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
     const ControlFusionFn &controlElementwiseOpFusion);
 
+/// Patterns to bubble up or down data layout ops across other operations.
+void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns);
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is effectively DCE for a linalg op.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index ee44862ba72db..f41bd7d7c563f 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1776,6 +1776,10 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     static ShapedType inferPackedType(ShapedType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
+
+    static Value createDestinationTensor(OpBuilder &b, Location loc,
+        Value source, ArrayRef<OpFoldResult> innerTileSizes,
+        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
   }];
 }
 

diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ca13f44c16022..4ca9f617adc3e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ConstantFold.cpp
+  DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp
   DropUnitDims.cpp

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
new file mode 100644
index 0000000000000..47145e36c55cf
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -0,0 +1,248 @@
+//===- DataLayoutPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-data-layout-propagation"
+
+namespace {
+
+/// Returns a tuple of the packed operand and the updated indexing map, with
+/// the following assumptions:
+///   1) The generic op is the producer of the pack op.
+///   2) The generic op has only one result.
+///   3) The indexing map of the output operand is identity.
+/// If the operand is a scalar, or the packing dimensions are all irrelevant
+/// to the operand, the original operand and the updated indexing map will be
+/// returned.
+/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
+///
+///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
+///   #map1 = affine_map<(d0, d1) -> (d0)>
+///   #map2 = affine_map<(d0, d1) -> (d1)>
+///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
+///                        iterator_types = ["parallel", "parallel"]}
+///      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+///      outs(%init : tensor<?x?xf32>) {
+///    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+///      %4 = arith.addf %arg3, %arg4 : f32
+///      linalg.yield %4 : f32
+///  } -> tensor<?x?xf32>
+///  %1 = tensor.pack %0
+///    inner_dims_pos = [0, 1]
+///    inner_tiles = [8, 2]
+///    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+///
+///  Taking the first input operand as an example: the inner tile size of d0
+///  is 8. Thus, the operation below and the indexing map
+///  `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
+///
+///  %pack = tensor.pack %arg0
+///    inner_dims_pos = [0]
+///    inner_tiles = [8]
+///    into %init : tensor<?xf32> -> tensor<?x8xf32>
+static std::tuple<Value, AffineMap>
+getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc,
+                               tensor::PackOp packOp, GenericOp genericOp,
+                               OpOperand *opOperand) {
+  int numOrigLoops = genericOp.getNumLoops();
+  int64_t numInnerLoops = packOp.getInnerDimsPos().size();
+  int64_t numLoops = numOrigLoops + numInnerLoops;
+  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
+  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
+
+  if (genericOp.isScalar(opOperand))
+    return std::make_tuple(
+        opOperand->get(),
+        AffineMap::get(numLoops, 0, exprs, packOp.getContext()));
+
+  llvm::SetVector<int64_t> innerDimsPosSet(packOp.getInnerDimsPos().begin(),
+                                           packOp.getInnerDimsPos().end());
+  // Mapping from AffineDimExpr of indexing maps to the operand shape dimension.
+  DenseMap<int64_t, int64_t> iterMapToDim;
+  for (auto [index, expr] : llvm::enumerate(origIndexingMap.getResults())) {
+    int64_t dimPos = expr.cast<AffineDimExpr>().getPosition();
+    if (!innerDimsPosSet.contains(dimPos))
+      continue;
+    iterMapToDim[dimPos] = index;
+  }
+
+  // Construct the information of packing data dimensions and new indexing maps
+  // for the operand.
+  SmallVector<int64_t> innerDimsPos;
+  SmallVector<OpFoldResult> innerTileSizes;
+  for (auto [index, value] : llvm::enumerate(
+           llvm::zip(packOp.getInnerDimsPos(), packOp.getMixedTiles()))) {
+    int64_t dimPos = std::get<0>(value);
+    if (!iterMapToDim.count(dimPos))
+      continue;
+    innerDimsPos.push_back(iterMapToDim[dimPos]);
+    innerTileSizes.push_back(std::get<1>(value));
+    exprs.push_back(b.getAffineDimExpr(numOrigLoops + index));
+  }
+  auto indexingMap = AffineMap::get(numLoops, 0, exprs, packOp.getContext());
+
+  SmallVector<int64_t> outerDimsPerm;
+  for (auto outDim : packOp.getOuterDimsPerm()) {
+    if (!iterMapToDim.count(outDim))
+      continue;
+    outerDimsPerm.push_back(iterMapToDim[outDim]);
+  }
+
+  // The operand does not have dimensions that relate to the pack op.
+  if (innerDimsPos.empty() && outerDimsPerm.empty())
+    return std::make_tuple(opOperand->get(), indexingMap);
+
+  auto empty = tensor::PackOp::createDestinationTensor(
+      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
+  auto packedOperand = b.create<tensor::PackOp>(
+      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
+      packOp.getPaddingValue(), outerDimsPerm);
+  return std::make_tuple(packedOperand, indexingMap);
+}
+
+/// Bubbles up a tensor.pack op through an elementwise generic op. This
+/// swaps pack(generic) to generic(pack). The new generic op works on the
+/// packed domain; pack ops are created for the input operands. E.g.,
+///
+///     #map0 = affine_map<(d0, d1) -> (d0, d1)>
+///     %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///     %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+///     %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+///     %3 = linalg.generic {indexing_maps = [#map0, #map0],
+///                          iterator_types = ["parallel", "parallel"]}
+///         ins(%arg0 : tensor<?x?xf32>)
+///         outs(%2 : tensor<?x?xf32>) {
+///       ^bb0(%arg3: f32, %arg4: f32):
+///         %4 = arith.addf %arg3, %arg3 : f32
+///         linalg.yield %4 : f32
+///     } -> tensor<?x?xf32>
+///     %4 = tensor.pack %3
+///       inner_dims_pos = [0, 1]
+///       inner_tiles = [8, 2]
+///       into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+///
+/// will be converted to
+///
+///     #map = affine_map<()[s0] -> (s0 ceildiv 8)>
+///     #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
+///     #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+///     %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///     %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+///     %0 = affine.apply #map()[%dim]
+///     %1 = affine.apply #map1()[%dim_0]
+///     %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
+///     %pack = tensor.pack %arg0
+///       inner_dims_pos = [0, 1]
+///       inner_tiles = [8, 2]
+///       into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+///     %3 = linalg.generic {indexing_maps = [#map2, #map2],
+///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+///       ins(%pack : tensor<?x?x8x2xf32>)
+///       outs(%dest : tensor<?x?x8x2xf32>) {
+///     ^bb0(%in: f32, %out: f32):
+///       %4 = arith.addf %in, %in : f32
+///       linalg.yield %4 : f32
+///     } -> tensor<?x?x8x2xf32>
+static FailureOr<GenericOp>
+bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
+                                   tensor::PackOp packOp) {
+  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
+  if (!genericOp)
+    return failure();
+
+  if (!isElementwise(genericOp))
+    return failure();
+
+  // TODO: Relax the restriction. We are able to bubble up the pack op through
+  // a multi-result generic op; it just needs more work.
+  if (genericOp.getNumResults() != 1)
+    return failure();
+
+  // TODO: Add an option to allow padding values. It could introduce undefined
+  // behavior if we unconditionally propagate the pack op through all the ops.
+  // E.g., if the padding value is zero and there are division ops in a
+  // generic op, some values in the padding area could be NaN (0/0).
+  if (packOp.getPaddingValue())
+    return failure();
+
+  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
+  // TODO: Add support for all permutation indexing maps.
+  if (!genericOp.getMatchingIndexingMap(opOperand).isIdentity())
+    return rewriter.notifyMatchFailure(
+        packOp, "the result of generic op does not have identity indexing_map");
+
+  Location loc = packOp.getLoc();
+  SmallVector<Value> inputOperands;
+  SmallVector<AffineMap> indexingMaps;
+  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
+        rewriter, loc, packOp, genericOp, inputOperand);
+    inputOperands.push_back(packedOperand);
+    indexingMaps.push_back(packedIndexingMap);
+  }
+
+  int64_t numLoops = genericOp.getNumLoops();
+  int64_t numInnerLoops = packOp.getInnerDimsPos().size();
+  int64_t newNumLoops = numLoops + numInnerLoops;
+  SmallVector<utils::IteratorType> iterTypes =
+      genericOp.getIteratorTypesArray();
+  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
+
+  SmallVector<AffineExpr> outExprs(
+      genericOp.getMatchingIndexingMap(opOperand).getResults());
+  for (int i = 0; i < numInnerLoops; ++i)
+    outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
+  indexingMaps.push_back(
+      AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext()));
+
+  auto newGenericOp = rewriter.create<linalg::GenericOp>(
+      loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,
+      iterTypes, /*bodyBuild=*/nullptr,
+      linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+  return newGenericOp;
+}
+
+// Wrapper pattern that applies the bubbleUpPackOpThroughElemGenericOp method.
+struct BubbleUpPackOpThroughElemGenericOpPattern
+    : public OpRewritePattern<tensor::PackOp> {
+  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericOp = bubbleUpPackOpThroughElemGenericOp(rewriter, packOp);
+    if (failed(genericOp))
+      return failure();
+    rewriter.replaceOp(packOp, genericOp.value().getResults());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateDataLayoutPropagationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern>(
+      patterns.getContext());
+}

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0b24149a3e98a..8faf6cc2e1c8d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3349,6 +3349,37 @@ ShapedType PackOp::inferPackedType(ShapedType sourceType,
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }
 
+Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
+                                      ArrayRef<OpFoldResult> innerTileSizes,
+                                      ArrayRef<int64_t> innerDimsPos,
+                                      ArrayRef<int64_t> outerDimsPerm) {
+  AffineExpr dim0, dim1;
+  bindDims(b.getContext(), dim0, dim1);
+  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
+    return makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1), {v1, v2});
+  };
+
+  SmallVector<OpFoldResult> mixedSizes;
+  for (auto [index, value] :
+       llvm::enumerate(source.getType().cast<RankedTensorType>().getShape())) {
+    if (ShapedType::isDynamic(value))
+      mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
+    else
+      mixedSizes.push_back(b.getIndexAttr(value));
+  }
+  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
+    int64_t dimPos = std::get<0>(it);
+    OpFoldResult tileSize = std::get<1>(it);
+    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
+  }
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
+
+  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
+  auto elemType = source.getType().cast<ShapedType>().getElementType();
+  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
+}
+
 /// Returns true if the tiles and the tiled dims are constant.
 template <typename OpTy>
 bool areTilesAndTiledDimsAllConstant(OpTy op) {

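A quick illustration of how the new createDestinationTensor helper composes
the destination shape (a hypothetical call sketch; `b`, `loc`, and `source`
are assumed to be an OpBuilder, a Location, and a tensor<?x?xf32> value in
scope, with the mlir and mlir::tensor namespaces visible):

  // Pack with inner_dims_pos = [0, 1] and inner_tiles = [8, 2]: the outer
  // sizes become ceildiv(d0, 8) and ceildiv(d1, 2), so this creates
  //   tensor.empty(%o0, %o1) : tensor<?x?x8x2xf32>
  SmallVector<OpFoldResult> tiles = {b.getIndexAttr(8), b.getIndexAttr(2)};
  Value dest = tensor::PackOp::createDestinationTensor(
      b, loc, source, tiles, /*innerDimsPos=*/{0, 1}, /*outerDimsPerm=*/{});
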
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
new file mode 100644
index 0000000000000..a5488d28b20c9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -0,0 +1,230 @@
+// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = arith.addf %arg3, %arg3 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?x?xf32>
+  %4 = tensor.pack %3
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 2]
+    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %4 : tensor<?x?x8x2xf32>
+}
+// CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK:      func.func @dynamic_elem_pack
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:    %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG:    %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG:    %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
+// CHECK-DAG:    %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
+// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:     inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME:     into %[[ARG0_EMPTY]]
+// CHECK:        %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[PACK_ARG0]]
+// CHECK-SAME:     outs(%[[DEST]]
+// CHECK:        return %[[ELEM]] : tensor<?x?x8x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{
+  %init = tensor.empty() : tensor<128x256xi32>
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %pack = tensor.pack %elem
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 32]
+    into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
+  return %pack : tensor<4x16x16x32xi32>
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK:      func.func @elem_pack_transpose_inner_dims
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:     inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:     into %[[ARG0_EMPTY]]
+// CHECK:        %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[PACK_ARG0]]
+// CHECK-SAME:     outs(%[[DEST]]
+// CHECK:        return %[[ELEM]] : tensor<4x16x16x32xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{
+  %init = tensor.empty() : tensor<128x256xi32>
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %pack = tensor.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK:      func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:     into %[[ARG0_EMPTY]]
+// CHECK:        %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[PACK_ARG0]]
+// CHECK-SAME:     outs(%[[DEST]]
+// CHECK:        return %[[ELEM]] : tensor<16x4x32x16xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{
+  %init = tensor.empty() : tensor<128x256xi32>
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %pack = tensor.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 32]
+    into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
+  return %pack : tensor<16x4x16x32xi32>
+}
+// CHECK-DAG:  #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK:      func.func @elem_pack_transpose_inner_and_outer_dims
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
+// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:     outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME:     into %[[ARG0_EMPTY]]
+// CHECK:        %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[PACK_ARG0]]
+// CHECK-SAME:     outs(%[[DEST]]
+// CHECK:        return %[[ELEM]] : tensor<16x4x16x32xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = tensor.dim %arg1, %c0 : tensor<?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [#map1, #map2, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+      outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %4 = arith.addf %arg3, %arg4 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?x?xf32>
+  %4 = tensor.pack %3
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 2]
+    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %4 : tensor<?x?x8x2xf32>
+}
+// CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG:  #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG:  #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-DAG:  #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG:  #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK:      func.func @dynamic_broadcast_pack
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:    %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG:    %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
+// CHECK:        %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
+// CHECK:        %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [8]
+// CHECK-SAME:     into %[[ARG0_EMPTY]]
+// CHECK-DAG:    %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG:    %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+// CHECK:        %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
+// CHECK:        %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:     inner_dims_pos = [0] inner_tiles = [2]
+// CHECK-SAME:     into %[[ARG1_EMPTY]]
+// CHECK:        %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]]
+// CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[PACK_ARG0]], %[[PACK_ARG1]]
+// CHECK-SAME:     outs(%[[DEST]]
+// CHECK:        return %[[ELEM]] : tensor<?x?x8x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
+{
+  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
+  %transpose = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
+      outs(%init_transpose : tensor<100x200x128x256xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %0 = arith.addi %b0, %b1 : i32
+      %1 = arith.addi %0, %b2 : i32
+      linalg.yield %1 : i32
+    } -> tensor<100x200x128x256xi32>
+  %4 = tensor.pack %transpose
+    inner_dims_pos = [3, 2]
+    inner_tiles = [16, 32]
+    into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
+  return %4 : tensor<100x200x4x16x16x32xi32>
+}
+// CHECK: func.func @transpose_pack
+// CHECK:   linalg.generic
+// CHECK:   tensor.pack

diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 9a62e8a6eaf49..4640a2c10d229 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
+  TestDataLayoutPropagation.cpp
   TestLinalgDecomposeOps.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp

diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
new file mode 100644
index 0000000000000..b4d6d42ab76af
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -0,0 +1,49 @@
+//===- TestDataLayoutPropagation.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestDataLayoutPropagationPass
+    : public PassWrapper<TestDataLayoutPropagationPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-linalg-data-layout-propagation";
+  }
+  StringRef getDescription() const final {
+    return "Test data layout propagation";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    linalg::populateDataLayoutPropagationPatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestDataLayoutPropagation() {
+  PassRegistration<TestDataLayoutPropagationPass>();
+}
+} // namespace test
+} // namespace mlir

diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3295ad22dbaa7..e9200b7bf9724 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@ void registerTestConstantFold();
 void registerTestControlFlowSink();
 void registerTestGpuSerializeToCubinPass();
 void registerTestGpuSerializeToHsacoPass();
+void registerTestDataLayoutPropagation();
 void registerTestDataLayoutQuery();
 void registerTestDeadCodeAnalysisPass();
 void registerTestDecomposeCallGraphTypes();
@@ -181,6 +182,7 @@ void registerTestPasses() {
   mlir::test::registerTestGpuSerializeToHsacoPass();
 #endif
   mlir::test::registerTestDecomposeCallGraphTypes();
+  mlir::test::registerTestDataLayoutPropagation();
   mlir::test::registerTestDataLayoutQuery();
   mlir::test::registerTestDeadCodeAnalysisPass();
   mlir::test::registerTestDominancePass();

