[Mlir-commits] [mlir] [mlir][tensor] Generalize/restrict `GeneralizeOuterUnitDimsPackOpPattern` (PR #114315)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Nov 6 01:00:09 PST 2024
================
@@ -1148,69 +1172,104 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp, "require the tiled outer dimensions of the result are all 1s");
}
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
- // outer dims.
+ Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+ Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
+
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
auto inputShape = packOp.getSourceType().getShape();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
- Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
- Attribute oneIdxAttr = rewriter.getIndexAttr(1);
int64_t srcRank = packOp.getSourceRank();
+
+ int64_t destRank = packOp.getDestRank();
+ size_t numTiles = destRank - srcRank;
+
+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+ // %extracted_tile = tensor.extract_slice(%pack_op_input)
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
- SmallVector<OpFoldResult> readSizes;
- SmallVector<OpFoldResult> transShapeForEmpty;
- SmallVector<int64_t> readShapeForExtractSlice;
+
+ // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+ // all outer dims are 1.
+ SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+ // The shape of the output for ExtractSliceOp. All leading unit dims are
+ // effectively rank-reduced, hence skipped.
+ SmallVector<int64_t> outputShapeForExtractSlice;
+
+ // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
+ // be equal to the inner tile sizes.
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
- readShapeForExtractSlice.push_back(
- getConstantIntValue(dimAndTileMapping[i])
- .value_or(ShapedType::kDynamic));
- readSizes.push_back(dimAndTileMapping[i]);
- transShapeForEmpty.push_back(dimAndTileMapping[i]);
- continue;
- }
- if (ShapedType::isDynamic(inputShape[i])) {
- readSizes.push_back(
- rewriter.create<tensor::DimOp>(loc, input, i).getResult());
- } else {
- readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
- }
- if (inputShape[i] != 1) {
- readShapeForExtractSlice.push_back(inputShape[i]);
- transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+ auto [tileSize, tileSizeOfr] =
+ getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
+ extractSliceSizes.push_back(tileSizeOfr);
+ outputShapeForExtractSlice.push_back(tileSize);
}
}
Type elemType = packOp.getSourceType().getElementType();
- auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+ auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
Value tile = rewriter.create<tensor::ExtractSliceOp>(
- loc, readType, input, readOffsets, readSizes, readStrides);
+ loc, readType, input, readOffsets, extractSliceSizes, readStrides);
- // 2. Transpose the tile to match the inner tile order.
+ // 2. Transpose the tile to match the inner tile order:
+ // %init = tensor.empty()
+ // %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+ // NOTE: Outer dims are 1 and hence effectively ignored.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
- applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
+ // 2.1 Create tensor.empty (init value for TransposeOp)
+ SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
+ SmallVector<int64_t> transShapeForEmptyOpStatic;
+
+ // Acquire tensor shape required to create EmptyOp. This will match the inner
+ // tile sizes, but the actual data format will depend on whether the tile
+ // sizes are static or dynamic (each case leads to a different builder for
+ // EmptyOp). Conservatively, prepare for both scenarios.
+ size_t idx = numTiles;
+ while (idx != 0) {
+ transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
+ transShapeForEmptyOpStatic.push_back(
+ outputShapeForExtractSlice[numTiles - idx]);
+ idx--;
+ }
- Value empty =
- rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+ applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
+ applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
+
+ Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
----------------
banach-space wrote:
> Shouldn't the builder for the static case always produce the same result as the dynamic case? Can we just keep the dynamic path?
Great point!
It turns out that [EmptyOp::build](https://github.com/llvm/llvm-project/blob/08411c855f77bd7416725c280ad3dccdc00b7dd6/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L887-L894) already supports the necessary "magic" via [dispatchIndexOpFoldResults](https://github.com/llvm/llvm-project/blob/08411c855f77bd7416725c280ad3dccdc00b7dd6/mlir/lib/Dialect/Utils/StaticValueUtils.cpp#L61-L66) :)
https://github.com/llvm/llvm-project/pull/114315
More information about the Mlir-commits
mailing list