[Mlir-commits] [mlir] 009c053 - [mlir][linalg] Allow outer dims perm and untiled dims in pack/unpack generalization
Quinn Dawkins
llvmlistbot at llvm.org
Tue May 2 09:29:29 PDT 2023
Author: Quinn Dawkins
Date: 2023-05-02T12:26:45-04:00
New Revision: 009c053e3f822d0df556c6b39f632e31594373de
URL: https://github.com/llvm/llvm-project/commit/009c053e3f822d0df556c6b39f632e31594373de
DIFF: https://github.com/llvm/llvm-project/commit/009c053e3f822d0df556c6b39f632e31594373de.diff
LOG: [mlir][linalg] Allow outer dims perm and untiled dims in pack/unpack generalization
Extends the pack/unpack generalization patterns to work for any pack or
unpack op with only full tiles, i.e. whose tiled outer dimensions are all
1. This produces a combination of rank-reduced extract/insert slice ops
paired with a transpose on the reduced shape, similar to what the pattern
already produces for fully tiled pack/unpack ops. Note that only the outer
dims are rank-reduced by this pattern; the shape of the inner tile is left
intact.
Differential Revision: https://reviews.llvm.org/D147555
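
For illustration, the new @simple_KCRS_to_KRSCsr test added below exercises both
an outer_dims_perm and a non-unit untiled outer dim. Reconstructed from its CHECK
lines (the SSA value names are illustrative, not part of the actual output), the
pack

  %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2]
      inner_tiles = [8, 32] into %arg1
      : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>

generalizes to roughly

  %tile = tensor.extract_slice %arg0[0, 0, 0, 0] [3, 1, 32, 8] [1, 1, 1, 1]
      : tensor<3x1x32x8xf32> to tensor<3x32x8xf32>
  %empty = tensor.empty() : tensor<3x8x32xf32>
  %transposed = linalg.transpose ins(%tile : tensor<3x32x8xf32>)
      outs(%empty : tensor<3x8x32xf32>) permutation = [0, 2, 1]
  %packed = tensor.insert_slice %transposed into %arg1[0, 0, 0, 0, 0, 0]
      [3, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
      : tensor<3x8x32xf32> into tensor<3x1x1x1x8x32xf32>

The extract_slice rank-reduces only the unit outer dim (C), the permutation
[0, 2, 1] computed by the new getPackUnpackRankReducedPerm helper reorders the
remaining (K, R, S) tile to (K, S, R) so it lines up with the trailing 8x32
inner tile of the destination, and the insert_slice writes it back at full rank.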
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4f3f2dc0c734b..4a5c69c5bc061 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1241,66 +1241,124 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
/*nofold=*/false, loc, builder);
}
+// Normalizes a permutation on a higher rank space to its actual size, e.g.
+// perm = [1, 4, 2]
+// becomes
+// norm = [0, 2, 1]
static SmallVector<int64_t>
-getPackUnpackNormalizedInnerPerm(int rank, ArrayRef<int64_t> innerDimsPos) {
+getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
constexpr int64_t kNonTiledMarker = -1;
SmallVector<int64_t> vec(rank, kNonTiledMarker);
- for (auto [index, value] : llvm::enumerate(innerDimsPos))
+ for (auto [index, value] : llvm::enumerate(perm))
vec[value] = index;
- SmallVector<int64_t> perm = llvm::to_vector(llvm::make_filter_range(
+ SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
vec, [&](int64_t v) { return v != kNonTiledMarker; }));
+ // This inverts the permutation in addition to normalizing so invert back.
+ return invertPermutationVector(normalizedPerm);
+}
+
+// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
+// assuming rank reduction of unit outer dims.
+static SmallVector<int64_t>
+getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> outerDimsPerm) {
+ SmallVector<int64_t> rankReducedOuterDimsPerm;
+ SmallVector<int64_t> outerDims;
+ SmallVector<int64_t> innerDims;
+ int64_t dim = 0;
+ int64_t unpackedRank = shape.size();
+ for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
+ if (llvm::is_contained(innerDimsPos, i)) {
+ innerDims.push_back(dim++);
+ continue;
+ }
+ if (shape[i] == 1)
+ continue;
+ outerDims.push_back(dim++);
+ if (!outerDimsPerm.empty())
+ rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
+ }
+
+ // Get the position of the inner dims after permutation.
+ SmallVector<int64_t> innerPerm =
+ getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
+ applyPermutationToVector<int64_t>(innerDims, innerPerm);
+
+ // Ditto for the outer dims.
+ SmallVector<int64_t> perm = outerDims;
+
+ rankReducedOuterDimsPerm =
+ getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
+ if (!rankReducedOuterDimsPerm.empty())
+ applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
+
+ // The tile always ends up as the inner most dims after packing.
+ perm.append(innerDims);
+
return perm;
}
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
- // TODO: support the case that outer dimensions are not all 1s A
- // tensor.expand_shape will be generated in this case.
- int64_t srcRank = packOp.getSourceRank();
- if (llvm::any_of(packOp.getDestType().getShape().take_front(srcRank),
- [](int64_t val) { return val != 1; })) {
- return rewriter.notifyMatchFailure(
- packOp, "require the outer dimension of the result are all 1s");
- }
-
if (llvm::any_of(packOp.getMixedTiles(),
[](OpFoldResult tile) { return tile.is<Value>(); })) {
return rewriter.notifyMatchFailure(packOp,
"require inner tile sizes being static");
}
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+ // TODO: support the case that outer dimensions are not all 1s. A
+ // tensor.expand_shape will be generated in this case.
+ auto innerDimsPos = packOp.getInnerDimsPos();
+ int64_t srcRank = packOp.getSourceRank();
+ auto destShape = packOp.getDestType().getShape();
+ if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
+ return destShape[index] != 1;
+ })) {
+ return rewriter.notifyMatchFailure(
+ packOp, "require the tiled outer dimensions of the result are all 1s");
+ }
+
+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
+ // outer dims.
Location loc = packOp.getLoc();
+ Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
+ auto inputShape = packOp.getSourceType().getShape();
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+ packOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
SmallVector<int64_t> readShape;
- DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
- packOp.getDimAndTileMapping();
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
- if (!dimAndTileMapping.count(i)) {
- readSizes.push_back(oneIdxAttr);
+ if (dimAndTileMapping.count(i)) {
+ readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
+ .value_or(ShapedType::kDynamic));
+ readSizes.push_back(dimAndTileMapping[i]);
continue;
}
- readSizes.push_back(dimAndTileMapping[i]);
- readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
- .value_or(ShapedType::kDynamic));
+ if (ShapedType::isDynamic(inputShape[i])) {
+ readSizes.push_back(
+ rewriter.create<tensor::DimOp>(loc, input, i).getResult());
+ } else {
+ readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
+ }
+ if (inputShape[i] != 1)
+ readShape.push_back(inputShape[i]);
}
+
Type elemType = packOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShape, elemType);
- Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
Value tile = rewriter.create<tensor::ExtractSliceOp>(
loc, readType, input, readOffsets, readSizes, readStrides);
// 2. Transpose the tile to match the inner tile order.
- SmallVector<int64_t> perm =
- getPackUnpackNormalizedInnerPerm(srcRank, packOp.getInnerDimsPos());
- // The permutation is inverted when normalizing so invert back to match the
- // ordering in the pack op.
- perm = invertPermutationVector(perm);
+
+ SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
+ inputShape, innerDimsPos, packOp.getOuterDimsPerm());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1316,9 +1374,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
int64_t destRank = packOp.getDestRank();
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
- SmallVector<OpFoldResult> writeSizes(srcRank, oneIdxAttr);
- for (auto size : transpShape)
- writeSizes.push_back(rewriter.getIndexAttr(size));
+ SmallVector<OpFoldResult> writeSizes =
+ tensor::getMixedSizes(rewriter, loc, packOp.getDest());
auto insert = rewriter.create<tensor::InsertSliceOp>(
loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
@@ -1333,35 +1390,59 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
int64_t srcRank = unpackOp.getSourceRank();
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
- if (llvm::any_of(srcShape.take_front(destRank),
- [](int64_t val) { return val != 1; })) {
+ ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
+ if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
+ return srcShape[index] != 1;
+ })) {
return rewriter.notifyMatchFailure(
- unpackOp, "require the outer dimension of the result are all 1s");
+ unpackOp,
+ "require the tiled outer dimensions of the result are all 1s");
}
// 1. Use rank-reduced tensor.extract_slice op to extract the tile.
Location loc = unpackOp.getLoc();
+ Value source = unpackOp.getSource();
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+ unpackOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+ SmallVector<OpFoldResult> readSizes;
+ SmallVector<int64_t> readShape;
+ for (auto i : llvm::seq<unsigned>(0, destRank)) {
+ if (dimAndTileMapping.count(i)) {
+ readSizes.push_back(oneIdxAttr);
+ continue;
+ }
+ if (ShapedType::isDynamic(srcShape[i])) {
+ readSizes.push_back(
+ rewriter.create<tensor::DimOp>(loc, source, i).getResult());
+ } else {
+ readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+ }
+ if (srcShape[i] != 1)
+ readShape.push_back(srcShape[i]);
+ }
auto mixedTiles = unpackOp.getMixedTiles();
- SmallVector<OpFoldResult> readSizes(destRank, oneIdxAttr);
readSizes.append(mixedTiles.begin(), mixedTiles.end());
// Explicitly create the type for extract_slice op because the inner tile
// size could be 1. We want to represent the whole inner tile in this case.
- ArrayRef<int64_t> readShape = srcShape.drop_front(destRank);
+ auto tileShape = srcShape.drop_front(destRank);
+ // Append the inner tile shape to the permuted and rank-reduced outer shape.
+ readShape.append(tileShape.begin(), tileShape.end());
Type elemType = unpackOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShape, elemType);
Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
// 2. Transpose the tile to match the outer corresponding tile order.
- ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
- SmallVector<int64_t> perm =
- getPackUnpackNormalizedInnerPerm(srcRank, innerDimsPos);
+ SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
+ srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
+ // Unpack is a transition out of packed space so we invert the permutation.
+ perm = invertPermutationVector(perm);
SmallVector<int64_t> transpShape(readShape);
applyPermutationToVector<int64_t>(transpShape, perm);
@@ -1375,11 +1456,13 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
SmallVector<OpFoldResult> tileSizes;
- for (int dim : innerDimsPos)
- tileSizes.push_back(getAsOpFoldResult(
- rewriter.createOrFold<tensor::DimOp>(loc, unpackOp.getDest(), dim)));
+ ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
+ for (auto i : llvm::seq<unsigned>(0, destRank)) {
+ if (dimAndTileMapping.count(i) || destShape[i] != 1)
+ tileSizes.push_back(getAsOpFoldResult(
+ rewriter.createOrFold<tensor::DimOp>(loc, unpackOp.getDest(), i)));
+ }
- applyPermutationToVector<OpFoldResult>(tileSizes, perm);
auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
@@ -1387,10 +1470,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> writeSizes;
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
- DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
- unpackOp.getDimAndTileMapping();
for (int i = 0, idx = 0; i < destRank; ++i) {
- if (dimAndTileMapping.count(i))
+ if (dimAndTileMapping.count(i) || destShape[i] != 1)
writeSizes.push_back(tileSizes[idx++]);
else
writeSizes.push_back(oneIdxAttr);
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 8e9b77ed6f679..283cb43e2997b 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -76,3 +76,22 @@ func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 5, 7, 3] [1, 1, 1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>) -> tensor<3x1x1x1x8x32xf32> {
+ %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
+ return %0 : tensor<3x1x1x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [3, 1, 32, 8] [1, 1, 1, 1]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x8x32xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]] : tensor<3x32x8xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<3x8x32xf32>)
+// CHECK-SAME: permutation = [0, 2, 1]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0, 0] [3, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK: return %[[INSERT]]
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
index cc734c24d4f56..a596690c2e4fd 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir
@@ -55,3 +55,42 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
// They have the same type, so the insert_slice op is folded
// away.
// CHECK: return %[[TRANSP]]
+
+// -----
+
+func.func @simple_NCHWc_to_NCHW(%arg0: tensor<2x1x16x8x32xf32>, %arg1: tensor<2x32x16x8xf32>) -> tensor<2x32x16x8xf32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %arg1 : tensor<2x1x16x8x32xf32> -> tensor<2x32x16x8xf32>
+ return %0 : tensor<2x32x16x8xf32>
+}
+// CHECK-LABEL: func.func @simple_NCHWc_to_NCHW
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0] [2, 1, 16, 8, 32] [1, 1, 1, 1, 1]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x32x16x8xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]] : tensor<2x16x8x32xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x32x16x8xf32>)
+// CHECK-SAME: permutation = [0, 3, 1, 2]
+// They have the same type, so the insert_slice op is folded
+// away.
+// CHECK: return %[[TRANSP]]
+
+
+// -----
+
+func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x16x8xf32>) -> tensor<1x32x16x8xf32> {
+ %0 = tensor.unpack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [] inner_tiles = [] into %arg1 : tensor<1x16x8x32xf32> -> tensor<1x32x16x8xf32>
+ return %0 : tensor<1x32x16x8xf32>
+}
+// CHECK-LABEL: func.func @simple_NHWC_to_NCHW
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 16, 8, 32] [1, 1, 1, 1]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x8xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]] : tensor<16x8x32xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x16x8xf32>)
+// CHECK-SAME: permutation = [2, 0, 1]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1]
+// CHECK: return %[[INSERT]]