[Mlir-commits] [mlir] [mlir][tensor] Update `GeneralizeOuterUnitDimsPackOpPattern` (PR #115312)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Nov 11 13:19:03 PST 2024
================
@@ -1153,71 +1153,65 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
Location loc = packOp.getLoc();
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
- auto inputShape = packOp.getSourceType().getShape();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
-
int64_t destRank = packOp.getDestRank();
- size_t numTiles = destRank - srcRank;
-
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
- // %extracted_tile = tensor.extract_slice(%pack_op_input)
- SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
- SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
+ int64_t numTiles = destRank - srcRank;
- // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
- // all outer dims are 1.
- SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
- // The shape of the output for ExtractSliceOp. All leading unit dims are
- // effectively rank-reduced, hence skipped.
- SmallVector<int64_t> outputShapeForExtractSlice;
+ if (!llvm::all_of(packOp.getInnerDimsPos(),
+ [&srcRank, &numTiles](int64_t dimPos) {
+ return dimPos >= (srcRank - numTiles - 1);
+ }))
+ return rewriter.notifyMatchFailure(
+ packOp, "Attempting to tile non-trailing source dims!");
- // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
- // be equal to the inner tile sizes.
+ // 1. Extract the inner tile sizes.
+ // Where possible, values are replaced with constant attributes (to match the
+ // behaviour of `getPackOpSourceOrPaddedSource`).
+ SmallVector<OpFoldResult> tileSizes;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
- auto [tileSize, tileSizeOfr] =
+ // Rather than taking the tile size as is, extract the actual constant
+ // value Attribute where possible, e.g.:
+ // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
+ auto [_, tileSize] =
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
- extractSliceSizes.push_back(tileSizeOfr);
- outputShapeForExtractSlice.push_back(tileSize);
+ tileSizes.push_back(tileSize);
}
}
- Type elemType = packOp.getSourceType().getElementType();
- auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
-
- Value tile = rewriter.create<tensor::ExtractSliceOp>(
- loc, readType, input, readOffsets, extractSliceSizes, readStrides);
-
- // 2. Transpose the tile to match the inner tile order:
+ // 2. Transpose the input to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
----------------
hanhanW wrote:
This comment needs to be updated. There is no `%extracted_tile` anymore, is there?
https://github.com/llvm/llvm-project/pull/115312
More information about the Mlir-commits
mailing list