[Mlir-commits] [mlir] 0f297ca - [mlir][tensor][linalg] Introduce DataLayoutPropagation pass.
Hanhan Wang
llvmlistbot at llvm.org
Tue Dec 6 15:00:16 PST 2022
Author: Hanhan Wang
Date: 2022-12-06T15:00:07-08:00
New Revision: 0f297cad4d5b5ccb69c2e8610d7da4891cbf1f6b
URL: https://github.com/llvm/llvm-project/commit/0f297cad4d5b5ccb69c2e8610d7da4891cbf1f6b
DIFF: https://github.com/llvm/llvm-project/commit/0f297cad4d5b5ccb69c2e8610d7da4891cbf1f6b.diff
LOG: [mlir][tensor][linalg] Introduce DataLayoutPropagation pass.
It introduces a pattern that swaps `linalg.generic + tensor.pack` into
`tensor.pack + linalg.generic`. The pattern requires all iterator types to be
parallel and the indexing map of the output operand to be identity; both
restrictions can be relaxed in the future.
The user can decide whether the propagation should be applied by passing a
control function.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D138882
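A minimal sketch of how the new populateDataLayoutPropagationPatterns entry
point can be driven, mirroring the TestDataLayoutPropagation.cpp pass added
below; the pass name here is a hypothetical placeholder, and only the
populate call comes from this commit:

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Pass/Pass.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Hypothetical driver pass; everything except the populate call is standard
  // pass boilerplate.
  struct MyDataLayoutPropagationPass
      : public mlir::PassWrapper<MyDataLayoutPropagationPass,
                                 mlir::OperationPass<>> {
    void runOnOperation() override {
      mlir::MLIRContext *context = &getContext();
      mlir::RewritePatternSet patterns(context);
      // Collect the pack-op bubble-up pattern introduced by this commit.
      mlir::linalg::populateDataLayoutPropagationPatterns(patterns);
      // Apply the patterns greedily until a fixpoint is reached.
      if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
              getOperation(), std::move(patterns))))
        return signalPassFailure();
    }
  };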
Added:
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir
mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f8473151c009e..a58c9dc23c1fc 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -81,6 +81,9 @@ void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
+/// Patterns to bubble up or down data layout ops across other operations.
+void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns);
+
/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is effectively DCE for a linalg op.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index ee44862ba72db..f41bd7d7c563f 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1776,6 +1776,10 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
static ShapedType inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
+
+ static Value createDestinationTensor(OpBuilder &b, Location loc,
+ Value source, ArrayRef<OpFoldResult> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
}];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ca13f44c16022..4ca9f617adc3e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ConstantFold.cpp
+ DataLayoutPropagation.cpp
DecomposeLinalgOps.cpp
Detensorize.cpp
DropUnitDims.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
new file mode 100644
index 0000000000000..47145e36c55cf
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -0,0 +1,248 @@
+//===- DataLayoutPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-data-layout-propagation"
+
+namespace {
+
+/// Returns a tuple of the packed operand and the updated indexing map, under
+/// the following assumptions:
+/// 1) The generic op is the producer of the pack op.
+/// 2) The generic op has only one result.
+/// 3) The indexing map of the output operand is identity.
+/// If the operand is a scalar or none of the packed dimensions are relevant
+/// to it, the original operand is returned together with the updated indexing
+/// map. Otherwise, a newly packed operand and the updated indexing map are
+/// returned. E.g.,
+///
+/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
+/// #map1 = affine_map<(d0, d1) -> (d0)>
+/// #map2 = affine_map<(d0, d1) -> (d1)>
+/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
+/// iterator_types = ["parallel", "parallel"]}
+/// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+/// outs(%init : tensor<?x?xf32>) {
+/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+/// %4 = arith.addf %arg3, %arg4 : f32
+/// linalg.yield %4 : f32
+/// } -> tensor<?x?xf32>
+/// %1 = tensor.pack %0
+/// inner_dims_pos = [0, 1]
+/// inner_tiles = [8, 2]
+/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+///
+/// Taking the first input operand as an example, the inner tile size of d0 is
+/// 8. Thus, the below operation and
+/// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
+///
+/// %pack = tensor.pack %arg0
+/// inner_dims_pos = [0]
+/// inner_tiles = [8]
+/// into %init : tensor<?xf32> -> tensor<?x8xf32>
+static std::tuple<Value, AffineMap>
+getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc,
+ tensor::PackOp packOp, GenericOp genericOp,
+ OpOperand *opOperand) {
+ int numOrigLoops = genericOp.getNumLoops();
+ int64_t numInnerLoops = packOp.getInnerDimsPos().size();
+ int64_t numLoops = numOrigLoops + numInnerLoops;
+ AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
+ SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
+
+ if (genericOp.isScalar(opOperand))
+ return std::make_tuple(
+ opOperand->get(),
+ AffineMap::get(numLoops, 0, exprs, packOp.getContext()));
+
+ llvm::SetVector<int64_t> innerDimsPosSet(packOp.getInnerDimsPos().begin(),
+ packOp.getInnerDimsPos().end());
+  // Mapping from AffineDimExpr of indexing maps to the operand shape dimension.
+ DenseMap<int64_t, int64_t> iterMapToDim;
+ for (auto [index, expr] : llvm::enumerate(origIndexingMap.getResults())) {
+ int64_t dimPos = expr.cast<AffineDimExpr>().getPosition();
+ if (!innerDimsPosSet.contains(dimPos))
+ continue;
+ iterMapToDim[dimPos] = index;
+ }
+
+ // Construct the information of packing data dimensions and new indexing maps
+ // for the operand.
+ SmallVector<int64_t> innerDimsPos;
+ SmallVector<OpFoldResult> innerTileSizes;
+ for (auto [index, value] : llvm::enumerate(
+ llvm::zip(packOp.getInnerDimsPos(), packOp.getMixedTiles()))) {
+ int64_t dimPos = std::get<0>(value);
+ if (!iterMapToDim.count(dimPos))
+ continue;
+ innerDimsPos.push_back(iterMapToDim[dimPos]);
+ innerTileSizes.push_back(std::get<1>(value));
+ exprs.push_back(b.getAffineDimExpr(numOrigLoops + index));
+ }
+ auto indexingMap = AffineMap::get(numLoops, 0, exprs, packOp.getContext());
+
+ SmallVector<int64_t> outerDimsPerm;
+ for (auto outDim : packOp.getOuterDimsPerm()) {
+ if (!iterMapToDim.count(outDim))
+ continue;
+ outerDimsPerm.push_back(iterMapToDim[outDim]);
+ }
+
+  // The operand does not have any dimension that relates to the pack op.
+ if (innerDimsPos.empty() && outerDimsPerm.empty())
+ return std::make_tuple(opOperand->get(), indexingMap);
+
+ auto empty = tensor::PackOp::createDestinationTensor(
+ b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
+ auto packedOperand = b.create<tensor::PackOp>(
+ loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
+ packOp.getPaddingValue(), outerDimsPerm);
+ return std::make_tuple(packedOperand, indexingMap);
+}
+
+/// Bubbles up a tensor.pack op through an elementwise generic op, i.e., it
+/// swaps pack(generic) into generic(pack). The new generic op works on the
+/// packed domain; pack ops are created for the input operands, and the pack
+/// op's destination is reused as the output operand. E.g.,
+///
+/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
+/// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+/// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+/// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+/// %3 = linalg.generic {indexing_maps = [#map0, #map0],
+/// iterator_types = ["parallel", "parallel"]}
+/// ins(%arg0 : tensor<?x?xf32>)
+/// outs(%2 : tensor<?x?xf32>) {
+/// ^bb0(%arg3: f32, %arg4: f32):
+/// %4 = arith.addf %arg3, %arg3 : f32
+/// linalg.yield %4 : f32
+/// } -> tensor<?x?xf32>
+/// %4 = tensor.pack %3
+/// inner_dims_pos = [0, 1]
+/// inner_tiles = [8, 2]
+/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+///
+/// will be converted to
+///
+/// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
+/// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
+/// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+/// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+/// %0 = affine.apply #map()[%dim]
+/// %1 = affine.apply #map1()[%dim_0]
+/// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
+/// %pack = tensor.pack %arg0
+/// inner_dims_pos = [0, 1]
+/// inner_tiles = [8, 2]
+/// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+/// %3 = linalg.generic {indexing_maps = [#map2, #map2],
+/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+/// ins(%pack : tensor<?x?x8x2xf32>)
+///    outs(%dest : tensor<?x?x8x2xf32>) {
+/// ^bb0(%in: f32, %out: f32):
+/// %4 = arith.addf %in, %in : f32
+/// linalg.yield %4 : f32
+/// } -> tensor<?x?x8x2xf32>
+static FailureOr<GenericOp>
+bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
+ tensor::PackOp packOp) {
+ auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
+ if (!genericOp)
+ return failure();
+
+ if (!isElementwise(genericOp))
+ return failure();
+
+  // TODO: Relax the restriction. We should be able to bubble up the pack op
+  // through a multi-result generic op; it just needs more work.
+ if (genericOp.getNumResults() != 1)
+ return failure();
+
+  // TODO: Add an option for allowing padding values. It could introduce
+  // undefined behavior if we unconditionally propagate the pack op through
+  // all the ops. E.g., if the padding value is zero and the generic op
+  // contains division ops, some values in the padding area could be NaN (0/0).
+ if (packOp.getPaddingValue())
+ return failure();
+
+ OpOperand *opOperand = genericOp.getDpsInitOperand(0);
+ // TODO: Add support for all permutation indexing maps.
+ if (!genericOp.getMatchingIndexingMap(opOperand).isIdentity())
+ return rewriter.notifyMatchFailure(
+ packOp, "the result of generic op does not have identity indexing_map");
+
+ Location loc = packOp.getLoc();
+ SmallVector<Value> inputOperands;
+ SmallVector<AffineMap> indexingMaps;
+ for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+ auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
+ rewriter, loc, packOp, genericOp, inputOperand);
+ inputOperands.push_back(packedOperand);
+ indexingMaps.push_back(packedIndexingMap);
+ }
+
+ int64_t numLoops = genericOp.getNumLoops();
+ int64_t numInnerLoops = packOp.getInnerDimsPos().size();
+ int64_t newNumLoops = numLoops + numInnerLoops;
+ SmallVector<utils::IteratorType> iterTypes =
+ genericOp.getIteratorTypesArray();
+ iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
+
+ SmallVector<AffineExpr> outExprs(
+ genericOp.getMatchingIndexingMap(opOperand).getResults());
+ for (int i = 0; i < numInnerLoops; ++i)
+ outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
+ indexingMaps.push_back(
+ AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext()));
+
+ auto newGenericOp = rewriter.create<linalg::GenericOp>(
+ loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,
+ iterTypes, /*bodyBuild=*/nullptr,
+ linalg::getPrunedAttributeList(genericOp));
+ rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+ newGenericOp.getRegion().begin());
+ return newGenericOp;
+}
+
+// Wrapper pattern that applies the bubbleUpPackOpThroughElemGenericOp method.
+struct BubbleUpPackOpThroughElemGenericOpPattern
+ : public OpRewritePattern<tensor::PackOp> {
+ using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ auto genericOp = bubbleUpPackOpThroughElemGenericOp(rewriter, packOp);
+ if (failed(genericOp))
+ return failure();
+ rewriter.replaceOp(packOp, genericOp.value().getResults());
+ return success();
+ }
+};
+} // namespace
+
+void mlir::linalg::populateDataLayoutPropagationPatterns(
+ RewritePatternSet &patterns) {
+ patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern>(
+ patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0b24149a3e98a..8faf6cc2e1c8d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3349,6 +3349,37 @@ ShapedType PackOp::inferPackedType(ShapedType sourceType,
return RankedTensorType::get(resultShape, sourceType.getElementType());
}
+Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
+ ArrayRef<OpFoldResult> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm) {
+ AffineExpr dim0, dim1;
+ bindDims(b.getContext(), dim0, dim1);
+ auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
+ return makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1), {v1, v2});
+ };
+
+ SmallVector<OpFoldResult> mixedSizes;
+ for (auto [index, value] :
+ llvm::enumerate(source.getType().cast<RankedTensorType>().getShape())) {
+ if (ShapedType::isDynamic(value))
+ mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
+ else
+ mixedSizes.push_back(b.getIndexAttr(value));
+ }
+ for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
+ int64_t dimPos = std::get<0>(it);
+ OpFoldResult tileSize = std::get<1>(it);
+ mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
+ }
+ if (!outerDimsPerm.empty())
+ applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
+
+ mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
+ auto elemType = source.getType().cast<ShapedType>().getElementType();
+ return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
+}
+
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
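For intuition about the new PackOp::createDestinationTensor helper above, a
small worked example with assumed static shapes: given a source of type
tensor<20x30xf32>, innerDimsPos = [0, 1], innerTileSizes = [8, 2], and an empty
outerDimsPerm, the outer sizes fold to ceilDiv(20, 8) = 3 and
ceilDiv(30, 2) = 15, the inner tile sizes are appended, and the helper creates
a tensor.empty() of type tensor<3x15x8x2xf32>.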
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
new file mode 100644
index 0000000000000..a5488d28b20c9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -0,0 +1,230 @@
+// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
+{
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+ %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %4 = arith.addf %arg3, %arg3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x?xf32>
+ %4 = tensor.pack %3
+ inner_dims_pos = [0, 1]
+ inner_tiles = [8, 2]
+ into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+ return %4 : tensor<?x?x8x2xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @dynamic_elem_pack
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<?x?x8x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{
+ %init = tensor.empty() : tensor<128x256xi32>
+ %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<128x256xi32>)
+ outs(%init : tensor<128x256xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32):
+ %4 = arith.addi %arg3, %arg3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<128x256xi32>
+ %pack = tensor.pack %elem
+ inner_dims_pos = [1, 0]
+ inner_tiles = [16, 32]
+ into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
+ return %pack : tensor<4x16x16x32xi32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_inner_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<4x16x16x32xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{
+ %init = tensor.empty() : tensor<128x256xi32>
+ %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<128x256xi32>)
+ outs(%init : tensor<128x256xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32):
+ %4 = arith.addi %arg3, %arg3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<128x256xi32>
+ %pack = tensor.pack %elem
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [32, 16]
+ into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+ return %pack : tensor<16x4x32x16xi32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<16x4x32x16xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{
+ %init = tensor.empty() : tensor<128x256xi32>
+ %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<128x256xi32>)
+ outs(%init : tensor<128x256xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32):
+ %4 = arith.addi %arg3, %arg3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<128x256xi32>
+ %pack = tensor.pack %elem
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [1, 0]
+ inner_tiles = [16, 32]
+ into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
+ return %pack : tensor<16x4x16x32xi32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<16x4x16x32xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
+{
+ %c0 = arith.constant 0 : index
+ %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %1 = tensor.dim %arg1, %c0 : tensor<?xf32>
+ %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+ %3 = linalg.generic {indexing_maps = [#map1, #map2, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+ outs(%2 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %4 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x?xf32>
+ %4 = tensor.pack %3
+ inner_dims_pos = [0, 1]
+ inner_tiles = [8, 2]
+ into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+ return %4 : tensor<?x?x8x2xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @dynamic_broadcast_pack
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
+// CHECK: %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [2]
+// CHECK-SAME: into %[[ARG1_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[PACK_ARG0]], %[[PACK_ARG1]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<?x?x8x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
+{
+ %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
+ %transpose = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
+ outs(%init_transpose : tensor<100x200x128x256xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+ %0 = arith.addi %b0, %b1 : i32
+ %1 = arith.addi %0, %b2 : i32
+ linalg.yield %1 : i32
+ } -> tensor<100x200x128x256xi32>
+ %4 = tensor.pack %transpose
+ inner_dims_pos = [3, 2]
+ inner_tiles = [16, 32]
+ into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
+ return %4 : tensor<100x200x4x16x16x32xi32>
+}
+// CHECK: func.func @transpose_pack
+// CHECK: linalg.generic
+// CHECK: tensor.pack
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 9a62e8a6eaf49..4640a2c10d229 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRLinalgTestPasses
+ TestDataLayoutPropagation.cpp
TestLinalgDecomposeOps.cpp
TestLinalgElementwiseFusion.cpp
TestLinalgFusionTransforms.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
new file mode 100644
index 0000000000000..b4d6d42ab76af
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -0,0 +1,49 @@
+//===- TestDataLayoutPropagation.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestDataLayoutPropagationPass
+ : public PassWrapper<TestDataLayoutPropagationPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry
+ .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+ }
+
+ StringRef getArgument() const final {
+ return "test-linalg-data-layout-propagation";
+ }
+ StringRef getDescription() const final {
+ return "Test data layout propagation";
+ }
+
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ RewritePatternSet patterns(context);
+ linalg::populateDataLayoutPropagationPatterns(patterns);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestDataLayoutPropagation() {
+ PassRegistration<TestDataLayoutPropagationPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3295ad22dbaa7..e9200b7bf9724 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@ void registerTestConstantFold();
void registerTestControlFlowSink();
void registerTestGpuSerializeToCubinPass();
void registerTestGpuSerializeToHsacoPass();
+void registerTestDataLayoutPropagation();
void registerTestDataLayoutQuery();
void registerTestDeadCodeAnalysisPass();
void registerTestDecomposeCallGraphTypes();
@@ -181,6 +182,7 @@ void registerTestPasses() {
mlir::test::registerTestGpuSerializeToHsacoPass();
#endif
mlir::test::registerTestDecomposeCallGraphTypes();
+ mlir::test::registerTestDataLayoutPropagation();
mlir::test::registerTestDataLayoutQuery();
mlir::test::registerTestDeadCodeAnalysisPass();
mlir::test::registerTestDominancePass();