[Mlir-commits] [mlir] [mlir][tensor] Generalize/restrict `GeneralizeOuterUnitDimsPackOpPattern` (PR #114315)
Quinn Dawkins
llvmlistbot at llvm.org
Tue Nov 5 13:38:15 PST 2024
================
@@ -1148,69 +1172,104 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp, "require the tiled outer dimensions of the result are all 1s");
}
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
- // outer dims.
+ Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+ Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
+
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
auto inputShape = packOp.getSourceType().getShape();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
- Attribute oneIdxAttr = rewriter.getIndexAttr(1);
int64_t srcRank = packOp.getSourceRank();
+
+ int64_t destRank = packOp.getDestRank();
+ size_t numTiles = destRank - srcRank;
+
+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+ // %extracted_tile = tensor.extract_slice(%pack_op_input)
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
- SmallVector<OpFoldResult> readSizes;
- SmallVector<OpFoldResult> transShapeForEmpty;
- SmallVector<int64_t> readShapeForExtractSlice;
+
+ // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+ // all outer dims are 1.
+ SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+ // The shape of the output for ExtractSliceOp. All leading unit dims are
+ // effectively rank-reduced, hence skipped.
+ SmallVector<int64_t> outputShapeForExtractSlice;
+
+ // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
+ // be equal to the inner tile sizes.
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
- readShapeForExtractSlice.push_back(
- getConstantIntValue(dimAndTileMapping[i])
- .value_or(ShapedType::kDynamic));
- readSizes.push_back(dimAndTileMapping[i]);
- transShapeForEmpty.push_back(dimAndTileMapping[i]);
- continue;
- }
- if (ShapedType::isDynamic(inputShape[i])) {
- readSizes.push_back(
- rewriter.create<tensor::DimOp>(loc, input, i).getResult());
- } else {
- readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
- }
- if (inputShape[i] != 1) {
- readShapeForExtractSlice.push_back(inputShape[i]);
- transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+ auto [tileSize, tileSizeOfr] =
+ getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
+ extractSliceSizes.push_back(tileSizeOfr);
+ outputShapeForExtractSlice.push_back(tileSize);
}
}
Type elemType = packOp.getSourceType().getElementType();
- auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+ auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
Value tile = rewriter.create<tensor::ExtractSliceOp>(
- loc, readType, input, readOffsets, readSizes, readStrides);
+ loc, readType, input, readOffsets, extractSliceSizes, readStrides);
- // 2. Transpose the tile to match the inner tile order.
+ // 2. Transpose the tile to match the inner tile order:
+ // %init = tensor.empty()
+ // %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+ // NOTE: Outer dims are 1 and hence effectively ignored.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
- applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
+ // 2.1 Create tensor.empty (init value for TransposeOp)
+ SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
+ SmallVector<int64_t> transShapeForEmptyOpStatic;
+
+ // Acquire tensor shape required to create EmptyOp. This will match the inner
+ // tile sizes, but the actual data format will depend on whether the tile
+ // sizes are static or dynamic (each case leads to a different builder for
+ // EmptyOp). Conservatively, prepare for both scenarios.
+ size_t idx = numTiles;
+ while (idx != 0) {
+ transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
+ transShapeForEmptyOpStatic.push_back(
+ outputShapeForExtractSlice[numTiles - idx]);
+ idx--;
+ }
- Value empty =
- rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+ applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
+ applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
+
+ Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
----------------
qedawkins wrote:
Shouldn't the builder for the static case always produce the same result as the dynamic case? Can we just keep the dynamic path?
I'm thinking that in any case where the static builder was needed, we could equally have had an additional dynamic dim that would send it down the dynamic path, and that path should still handle the static dims the same way.
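For reference, a minimal sketch of what keeping only the dynamic path could look like (an illustration of the suggestion, not the patch as written; it reuses `extractSliceSizes`, `numTiles`, `perm`, and `elemType` from the hunk above, and assumes the `OpFoldResult`-based `tensor::EmptyOp` builder, which splits attribute sizes into the static shape by itself):

```cpp
// Gather the trailing tile sizes as OpFoldResults; attribute and Value
// sizes can be mixed freely here.
SmallVector<OpFoldResult> transShapeForEmptyOp(
    extractSliceSizes.end() - numTiles, extractSliceSizes.end());
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
// The OpFoldResult builder dispatches attribute sizes into the result
// type, so an all-attribute shape still yields a fully static tensor.empty.
Value empty =
    rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
```

If that holds, the separate static shape vector and the `isDynamicShape` branch become unnecessary.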
https://github.com/llvm/llvm-project/pull/114315