[Mlir-commits] [mlir] b4563ee - [mlir][linalg] Enable propagation of pack/unpack ops through non-elementwise
Quinn Dawkins
llvmlistbot at llvm.org
Tue Apr 11 09:04:26 PDT 2023
Author: Quinn Dawkins
Date: 2023-04-11T11:59:26-04:00
New Revision: b4563ee17ce45728a323c2708e549627b0a8ee9c
URL: https://github.com/llvm/llvm-project/commit/b4563ee17ce45728a323c2708e549627b0a8ee9c
DIFF: https://github.com/llvm/llvm-project/commit/b4563ee17ce45728a323c2708e549627b0a8ee9c.diff
LOG: [mlir][linalg] Enable propagation of pack/unpack ops through non-elementwise
Allows pack propagation through non-elementwise generics as long as all
tiled dimensions have parallel iterator types and every operand that
indexes a tiled dimension does so with a plain affine dim expression.
This enables unpack propagation cases where the result type differs
from the current unpack destination tensor, which motivates a helper,
analogous to the existing one for pack, that creates a destination
tensor from the pack information.
Outer dim permutations are allowed to permute reduction dims; however,
permuting indexing map results that are not plain affine dim expressions
remains unsupported. Additionally, ops with gather semantics (generics whose
body contains tensor.extract or linalg.index) now explicitly prohibit
propagation.
Pack/unpack propagation through reductions may not always be beneficial,
so users can control propagation decisions through a control function,
similar to the one used for fusion.
Differential Revision: https://reviews.llvm.org/D147508
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir
mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05f10909aec9b..7eaa2f7168bdf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1326,8 +1326,14 @@ void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
+/// Function type which is used to control propagation of tensor.pack/unpack
+/// ops.
+using ControlPropagationFn = std::function<bool(Operation *op)>;
+
/// Patterns to bubble up or down data layout ops across other operations.
-void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns);
+void populateDataLayoutPropagationPatterns(
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation);
/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is effectively DCE for a linalg op.
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index d106f1285dfda..6d701d2ea44a0 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1874,6 +1874,10 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
+ static Value createDestinationTensor(OpBuilder &b, Location loc,
+ Value source, ArrayRef<OpFoldResult> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+
/// Build and return a new UnPackOp that is a clone of the current UnPackOp
/// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
/// innerPermutation (resp. outerPermutation).
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e1b929c8b121e..62a00c047775e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -32,6 +32,13 @@ using namespace mlir::linalg;
namespace {
+static bool hasGatherSemantics(linalg::GenericOp genericOp) {
+ for (Operation &op : genericOp.getBody()->getOperations())
+ if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
+ return true;
+ return false;
+}
+
// The struct contains the infomation about mapping packing information to
// the iteration domain of Linalg ops.
struct PackInfo {
@@ -48,12 +55,19 @@ struct PackInfo {
};
template <typename OpTy>
-static PackInfo getPackingInfoFromOperand(AffineMap indexingMap,
- OpTy packOrUnPackOp) {
+static FailureOr<PackInfo>
+getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
+ OpTy packOrUnPackOp) {
static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
"applies to only pack or unpack operations");
LLVM_DEBUG(
{ llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
+
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ SmallVector<utils::IteratorType> iterators =
+ genericOp.getIteratorTypesArray();
+
PackInfo packInfo;
int64_t origNumDims = indexingMap.getNumDims();
SmallVector<AffineExpr> exprs(indexingMap.getResults());
@@ -61,8 +75,13 @@ static PackInfo getPackingInfoFromOperand(AffineMap indexingMap,
for (auto [index, innerDimPos, tileSize] :
llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
innerDimsPos, packOrUnPackOp.getMixedTiles())) {
+ auto expr = exprs[innerDimPos];
+ if (!expr.template isa<AffineDimExpr>())
+ return failure();
int64_t domainDimPos =
exprs[innerDimPos].template cast<AffineDimExpr>().getPosition();
+ if (!isParallelIterator(iterators[domainDimPos]))
+ return failure();
packInfo.tiledDimsPos.push_back(domainDimPos);
packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
@@ -75,9 +94,57 @@ static PackInfo getPackingInfoFromOperand(AffineMap indexingMap,
});
}
- for (auto dim : packOrUnPackOp.getOuterDimsPerm())
- packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
- if (!packInfo.outerDimsOnDomainPerm.empty()) {
+ // Bail out if a tiled dimension is present in a map but not as an affine dim
+ // expression.
+ auto areAllAffineDimExpr = [&](int dim) {
+ for (AffineMap map : indexingMaps) {
+ if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
+ return expr.isFunctionOfDim(dim) && !expr.isa<AffineDimExpr>();
+ })) {
+ return false;
+ }
+ }
+ return true;
+ };
+ for (int64_t i : packInfo.tiledDimsPos)
+ if (!areAllAffineDimExpr(i))
+ return failure();
+
+ // Get the outer dims perm on the iteration domain. Start by identifying the
+ // set of domain dims affected by the outer permutation along with the
+ // permuted ordering for those dims. Then the full outer dims permutation can
+ // be constructed by replacing the affected dims with the permuted result in a
+ // numLoops-rank identity. e.g.
+ // outerDimsPerm = [1, 2, 0]
+ // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
+ //
+ // permutedOuterDims = [4, 3, 1]
+ // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
+ //
+ // Non-affine dim expressions must not be permuted by the outer dims
+ // permutation.
+ SmallVector<int64_t> permutedOuterDims;
+ for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
+ auto permutedExpr = indexingMap.getResult(dim);
+ if (auto dimExpr = permutedExpr.template dyn_cast<AffineDimExpr>()) {
+ permutedOuterDims.push_back(dimExpr.getPosition());
+ continue;
+ }
+
+ // TODO: Allow propagation with transposes on non affine dim expressions,
+ // e.g. d0 + d1 which implies transposing both dims simultaneously while
+ // maintaining the relative position between them.
+ if (static_cast<int64_t>(index) != dim)
+ return failure();
+ }
+ if (!permutedOuterDims.empty()) {
+ int64_t outerDimIndex = 0;
+ llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
+ permutedOuterDims.end());
+ for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
+ packInfo.outerDimsOnDomainPerm.push_back(
+ permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
+ : i);
LLVM_DEBUG({
llvm::dbgs() << "map outer dimsDimsPerm to ";
for (auto dim : packInfo.outerDimsOnDomainPerm)
@@ -107,8 +174,13 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
SmallVector<int64_t> outerDimsPerm;
DenseMap<int64_t, int64_t> currentPositionTileLoops;
for (auto [pos, expr] : llvm::enumerate(exprs)) {
- unsigned posInDomain = expr.cast<AffineDimExpr>().getPosition();
- currentPositionTileLoops[posInDomain] = pos;
+ // Here we rely on the assumption that the outer dims permutation
+ // when propagating currently requires that non-affine dim expressions
+ // are not permuted, thus allowing the identity assignment below.
+ if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+ currentPositionTileLoops[dimExpr.getPosition()] = pos;
+ else
+ currentPositionTileLoops[pos] = pos;
}
for (int64_t loopIdx : perm) {
if (currentPositionTileLoops.count(loopIdx))
@@ -169,8 +241,6 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
domainDimToOperandDim[dimPos] = index;
continue;
}
- assert(expr.isa<AffineConstantExpr>() &&
- "Found non-constant and non-affine dim expression");
}
SmallVector<int64_t> innerDimsPos;
SmallVector<OpFoldResult> innerTileSizes;
@@ -212,7 +282,7 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
// The operand does not have dimensions that relates to pack op.
- if (innerDimsPos.empty())
+ if (innerDimsPos.empty() && outerDimsPerm.empty())
return std::make_tuple(opOperand->get(), indexingMap);
auto empty = tensor::PackOp::createDestinationTensor(
@@ -252,7 +322,7 @@ static GenericOp packElementWiseOp(RewriterBase &rewriter, GenericOp genericOp,
return newGenericOp;
}
-/// Bubbles up tensor.pack op through elementwise generic op. This
+/// Bubbles up tensor.pack op through a producer generic op. This
/// swap pack(generic) to generic(pack). The new generic op works on packed
/// domain; pack ops are created for input and output operands. E.g.,
///
@@ -296,10 +366,20 @@ static GenericOp packElementWiseOp(RewriterBase &rewriter, GenericOp genericOp,
/// linalg.yield %4 : f32
/// } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
-bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
- tensor::PackOp packOp) {
+bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
+ ControlPropagationFn controlFn) {
auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
- if (!genericOp || !isElementwise(genericOp))
+ if (!genericOp)
+ return failure();
+
+ // User controlled propagation function.
+ if (!controlFn(genericOp))
+ return failure();
+
+ // TODO: Enable propagation in the presence of linalg.index and
+ // tensor.extract, likely as a separate pattern as the pack information and
+ // propagation decision needs to be inferred from the region of the generic.
+ if (hasGatherSemantics(genericOp))
return failure();
// TODO: Relax the restriction. We are able to bubble up the pack op through
@@ -309,6 +389,8 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
// Bail-out if the result of the generic has multiple uses, as bubbling up
// creates recomputation if the generic has multiple users.
+ // TODO: Enable the case where every use is an identical pack op as no
+ // recomputation is needed in that case.
if (!genericOp->getResult(0).hasOneUse())
return failure();
@@ -343,12 +425,13 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
return failure();
OpOperand *opOperand = genericOp.getDpsInitOperand(0);
- auto packInfo = getPackingInfoFromOperand(
- genericOp.getMatchingIndexingMap(opOperand), packOp);
+ auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
+ if (failed(packInfo))
+ return failure();
// Rebuild the indexing map for the corresponding init operand.
auto [packedOutOperand, packedOutIndexingMap] =
- getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo,
+ getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, opOperand);
// We'll replace the init operand with the destination of pack op if the init
@@ -360,22 +443,29 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
: packedOutOperand;
return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap,
- packInfo);
+ *packInfo);
}
-/// Wrapper pattern that applies bubbleUpPackOpThroughElemGenericOp method.
-struct BubbleUpPackOpThroughElemGenericOpPattern
+/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
+struct BubbleUpPackOpThroughGenericOpPattern
: public OpRewritePattern<tensor::PackOp> {
- using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+public:
+ BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
+ ControlPropagationFn fun)
+ : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
LogicalResult matchAndRewrite(tensor::PackOp packOp,
PatternRewriter &rewriter) const override {
- auto genericOp = bubbleUpPackOpThroughElemGenericOp(rewriter, packOp);
+ auto genericOp =
+ bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
if (failed(genericOp))
return failure();
rewriter.replaceOp(packOp, genericOp->getResults());
return success();
}
+
+private:
+ ControlPropagationFn controlFn;
};
// TODO: Relax this restriction. We should unpack an elementwise also
@@ -431,13 +521,13 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
/// inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
-pushDownUnPackOpThroughElemGenericOp(RewriterBase &rewriter,
- GenericOp genericOp) {
- if (!isElementwise(genericOp))
- return failure();
+pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
if (genericOp.getNumResults() != 1)
return failure();
+ if (hasGatherSemantics(genericOp))
+ return failure();
+
// Collect the unPacked operand, if present.
auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
if (failed(maybeUnPackedOperand))
@@ -448,13 +538,16 @@ pushDownUnPackOpThroughElemGenericOp(RewriterBase &rewriter,
tensor::UnPackOp producerUnPackOp =
unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
assert(producerUnPackOp && "expect a valid UnPackOp");
- auto packInfo = getPackingInfoFromOperand(
- genericOp.getMatchingIndexingMap(unPackedOperand), producerUnPackOp);
+ auto packInfo =
+ getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
+ if (failed(packInfo))
+ return failure();
// Rebuild the indexing map for the corresponding init operand.
auto [packedOutOperand, packedOutIndexingMap] =
- getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), packInfo,
+ getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, genericOp.getDpsInitOperand(0));
+ auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
// If the dps init operand of the generic is a tensor.empty, do not pack it
// and forward the new tensor.empty as a destination.
@@ -462,66 +555,76 @@ pushDownUnPackOpThroughElemGenericOp(RewriterBase &rewriter,
if (auto initTensor = genericOp.getDpsInitOperand(0)
->get()
.getDefiningOp<tensor::EmptyOp>()) {
- if (auto packOp = packedOutOperand.getDefiningOp<tensor::PackOp>())
- dest = packOp.getDest();
+ if (destPack)
+ dest = destPack.getDest();
}
// Pack the genericOp.
GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest,
- packedOutIndexingMap, packInfo);
+ packedOutIndexingMap, *packInfo);
+ Value newResult =
+ newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
+
+ // If the output is unaffected, no need to unpack.
+ if (!destPack)
+ return std::make_tuple(newGenericOp, newResult);
- // If the output element type for the generic differs from the source
- // unpack op, we need to create a new destination tensor.
+ auto mixedTiles = destPack.getMixedTiles();
+ auto innerDimsPos = destPack.getInnerDimsPos();
+ auto outerDimsPerm = destPack.getOuterDimsPerm();
+
+ // If the output type for the generic differs from the source
+ // unpack op, we need to create a new destination tensor. In the
+ // dynamic case we always need a new destination.
auto loc = genericOp.getLoc();
Value unPackDest = producerUnPackOp.getDest();
- auto genericOutElementType = getElementTypeOrSelf(genericOp.getResult(0));
- if (producerUnPackOp.getDestType().getElementType() !=
- genericOutElementType) {
- SmallVector<OpFoldResult> unPackMixedSizes;
- if (auto unPackEmpty = unPackDest.getDefiningOp<tensor::EmptyOp>())
- unPackMixedSizes = unPackEmpty.getMixedSizes();
- else
- unPackMixedSizes = tensor::getMixedSizes(rewriter, loc, unPackDest);
-
- unPackDest = rewriter.create<tensor::EmptyOp>(loc, unPackMixedSizes,
- genericOutElementType);
+ auto genericOutType =
+ genericOp.getDpsInitOperand(0)->get().getType().cast<RankedTensorType>();
+ if (producerUnPackOp.getDestType() != genericOutType ||
+ !genericOutType.hasStaticShape()) {
+ unPackDest = tensor::UnPackOp::createDestinationTensor(
+ rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
}
// Insert an unPackOp right after the packed generic.
Value unPackOpRes =
rewriter
- .create<tensor::UnPackOp>(
- loc,
- newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
- unPackDest, producerUnPackOp.getInnerDimsPos(),
- producerUnPackOp.getMixedTiles(),
- producerUnPackOp.getOuterDimsPerm())
+ .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
+ mixedTiles, outerDimsPerm)
.getResult();
return std::make_tuple(newGenericOp, unPackOpRes);
}
-// Wrapper pattern that applies pushDownUnPackOpThroughElemGenericOp method.
-struct PushDownUnPackOpThroughElemGenericOp
- : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
+// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
+struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
+public:
+ PushDownUnPackOpThroughGenericOp(MLIRContext *context,
+ ControlPropagationFn fun)
+ : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- auto genericAndRepl =
- pushDownUnPackOpThroughElemGenericOp(rewriter, genericOp);
+ if (!controlFn(genericOp))
+ return failure();
+
+ auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
if (failed(genericAndRepl))
return failure();
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
return success();
}
+
+private:
+ ControlPropagationFn controlFn;
};
/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// add as many zero padding dimensions in `high` and `low` based on the number
/// of point loops.
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
- using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+ PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
+ : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
@@ -530,6 +633,9 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
if (!unpackOp)
return failure();
+ if (!controlFn(padOp))
+ return failure();
+
Location loc = padOp.getLoc();
// Bail out if one of the padded dimension is a tiled one.
llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
@@ -572,14 +678,17 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
rewriter.replaceOp(padOp, replacement);
return success();
}
+
+private:
+ ControlPropagationFn controlFn;
};
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
- RewritePatternSet &patterns) {
- patterns
- .insert<BubbleUpPackOpThroughElemGenericOpPattern,
- PushDownUnPackOpThroughElemGenericOp, PushDownUnPackThroughPadOp>(
- patterns.getContext());
+ RewritePatternSet &patterns,
+ const ControlPropagationFn &controlPackUnPackPropagation) {
+ patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
+ PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+ patterns.getContext(), controlPackUnPackPropagation);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ccea6dd854af3..e092c6ea0f4a1 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3765,6 +3765,38 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
builder.getDenseI64ArrayAttr(staticTileSizes));
}
+Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
+ Value source,
+ ArrayRef<OpFoldResult> innerTileSizes,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm) {
+ AffineExpr sym0, sym1;
+ bindSymbols(b.getContext(), sym0, sym1);
+ auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
+ return makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
+ };
+
+ SmallVector<OpFoldResult> mixedSizes;
+ auto srcType = source.getType().cast<RankedTensorType>();
+ for (auto i :
+ llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
+ if (srcType.isDynamicDim(i))
+ mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
+ else
+ mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
+ }
+ if (!outerDimsPerm.empty()) {
+ applyPermutationToVector<OpFoldResult>(
+ mixedSizes, invertPermutationVector(outerDimsPerm));
+ }
+
+ for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
+ mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
+
+ auto elemType = srcType.getElementType();
+ return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
+}
+
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
Value transposedSource,
ArrayRef<int64_t> innerPermutation,
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index f9d62754ea647..266fe2dd29a45 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -323,9 +323,6 @@ func.func @affine_constant_expr_pack(%arg0: tensor<100x128x200x256xi32>, %arg1:
// -----
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d0)>
-#map2 = affine_map<(d0, d1) -> (d1)>
func.func @transpose_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
{
%init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
@@ -679,3 +676,164 @@ func.func @scalar_tensor(%arg0 : tensor<f32>) -> tensor<1x32x7x7x32xf32> {
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]]
// CHECK-SAME: outs(%[[EMPTY]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @unpack_empty_inner_dims(%arg0: tensor<12x64x56x56xf32>) -> tensor<12x56x56x64xf32> {
+ %init = tensor.empty() : tensor<12x56x56x64xf32>
+ %0 = tensor.empty() : tensor<12x56x56x64xf32>
+ %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = [] into %0 : tensor<12x64x56x56xf32> -> tensor<12x56x56x64xf32>
+ %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %3 = arith.addf %in, %in : f32
+ linalg.yield %3 : f32
+ } -> tensor<12x56x56x64xf32>
+ return %2 : tensor<12x56x56x64xf32>
+}
+
+// CHECK: func.func @unpack_empty_inner_dims
+// CHECK: %[[UNPACKED_ARG0:.+]] = tensor.unpack
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[UNPACKED_ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [] inner_tiles = []
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @reduction_pack_transpose_inner_dims(%arg0: tensor<128x256x32xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{
+ %init = tensor.empty() : tensor<128x256xi32>
+ %elem = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0 : tensor<128x256x32xi32>)
+ outs(%init : tensor<128x256xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32):
+ %4 = arith.addi %arg3, %arg4 : i32
+ linalg.yield %4 : i32
+ } -> tensor<128x256xi32>
+ %pack = tensor.pack %elem
+ inner_dims_pos = [1, 0]
+ inner_tiles = [16, 32]
+ into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
+ return %pack : tensor<4x16x16x32xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+// CHECK: func.func @reduction_pack_transpose_inner_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ORIG_INIT:.+]] = tensor.empty() : tensor<128x256xi32>
+// CHECK: %[[INIT_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK: %[[PACK_INIT:.+]] = tensor.pack %[[ORIG_INIT]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16x32xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[RED:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[PACK_INIT]]
+// CHECK: return %[[RED]] : tensor<4x16x16x32xi32>
+
+// -----
+
+func.func @reduction_pack_with_outer_dims(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>) -> tensor<4x16x100x16x32xi32>
+{
+ %init_reduction = tensor.empty() : tensor<100x128x256xi32>
+ %reduction = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+ ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
+ outs(%init_reduction : tensor<100x128x256xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+ %0 = arith.addi %b0, %b1 : i32
+ %1 = arith.addi %0, %b2 : i32
+ %2 = arith.addi %1, %b3 : i32
+ linalg.yield %2 : i32
+ } -> tensor<100x128x256xi32>
+ %init_pack = tensor.empty() : tensor<4x16x100x16x32xi32>
+ %4 = tensor.pack %reduction
+ outer_dims_perm = [1, 2, 0]
+ inner_dims_pos = [2, 1]
+ inner_tiles = [16, 32]
+ into %init_pack : tensor<100x128x256xi32> -> tensor<4x16x100x16x32xi32>
+ return %4 : tensor<4x16x100x16x32xi32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d5)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d5)>
+// CHECK: func.func @reduction_pack_with_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<100x128x256xi32>
+// CHECK: %[[INIT_EMPTY:.+]] = tensor.empty() : tensor<4x16x100x16x32xi32>
+// CHECK: %[[PACKED_INIT:.+]] = tensor.pack %[[INIT]]
+// CHECK-SAME: outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[INIT_EMPTY]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x200x100x16x32xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 3, 2, 0] inner_dims_pos = [3, 1] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
+// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME: into %[[ARG2_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
+// CHECK-SAME: outs(%[[PACKED_INIT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d3)>
+func.func @unpack_different_destination_shape(%arg0: tensor<1x1x1080x1920x16xi32>) -> tensor<16x540x960xi32>{
+ %init = tensor.empty() : tensor<16x540x960xi32>
+ %filter = tensor.empty() : tensor<2x2xi32>
+ %empty = tensor.empty() : tensor<1x16x1080x1920xi32>
+ %unpack = tensor.unpack %arg0
+ inner_dims_pos = [1]
+ inner_tiles = [16]
+ into %empty : tensor<1x1x1080x1920x16xi32> -> tensor<1x16x1080x1920xi32>
+ %pool = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+ ins(%unpack, %filter : tensor<1x16x1080x1920xi32>, tensor<2x2xi32>)
+ outs(%init : tensor<16x540x960xi32>) {
+ ^bb0(%in: i32, %in_1: i32, %out: i32):
+ %max = arith.maxui %in, %out : i32
+ linalg.yield %max : i32
+ } -> tensor<16x540x960xi32>
+ return %pool : tensor<16x540x960xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d3, d6)>
+// CHECK: func.func @unpack_different_destination_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[FILTER:.+]] = tensor.empty() : tensor<2x2xi32>
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x540x960x16xi32>
+// CHECK: %[[PACK_EMPTY:.+]] = tensor.empty() : tensor<1x1x1080x1920x16xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [16]
+// CHECK-SAME: into %[[PACK_EMPTY]]
+// CHECK: %[[POOL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]], %[[FILTER]]
+// CHECK-SAME: outs(%[[INIT]]
+// CHECK: %[[UNPACK_NEW_DEST:.+]] = tensor.empty() : tensor<16x540x960xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[POOL]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
+// CHECK-SAME: into %[[UNPACK_NEW_DEST]]
+// CHECK: return %[[UNPACK]] : tensor<16x540x960xi32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index b4d6d42ab76af..3f4fed0fd47a0 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -32,7 +32,8 @@ struct TestDataLayoutPropagationPass
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
- linalg::populateDataLayoutPropagationPatterns(patterns);
+ linalg::populateDataLayoutPropagationPatterns(
+ patterns, [](Operation *op) { return true; });
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
More information about the Mlir-commits
mailing list