[Mlir-commits] [mlir] [mlir][linalg] Fix and Refactor DecomposeOuterUnitDimsUnPackOpPattern (PR #119379)
Renato Golin
llvmlistbot at llvm.org
Wed Dec 11 03:31:47 PST 2024
================
@@ -1252,64 +1252,88 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
"require the tiled outer dimensions of the result are all 1s");
}
- // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
+ // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+ // %extracted_tile = tensor.extract_slice(%unpack_op_input)
Location loc = unpackOp.getLoc();
Value source = unpackOp.getSource();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
unpackOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
- SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
- SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
- SmallVector<OpFoldResult> readSizes;
- SmallVector<int64_t> readShape;
- SmallVector<Value> dynamicDims;
+
+ // The sizes, affset and strides attributes for ExtractSliceOp.
+ SmallVector<OpFoldResult> extractSliceSizes;
+ SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
+ SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
+ // The shape for ExtractSliceOp (due to rank-reducing, this is likely !=
+ // extractSliceSizes).
+ SmallVector<int64_t> readShapeForExtractSlice;
+
+ // Shape for EmptyOp that's used as the init value for TransposeOp below.
+ // This should match tile size + transposition.
+ SmallVector<OpFoldResult> shapeForEmptyOp;
+
for (auto i : llvm::seq<unsigned>(0, destRank)) {
+ // Given the assumption that all outer tiled dims are 1, the corresponding
+ // slice size to read is also 1. As this will be rank-reducing "extract
+ // slice" (i.e. the unit dims will be "collapsed"), there's no need to
+ // update:
+ // * the output shape for ExtractSliceOp, nor
+ // * the shape for EmptyOp.
if (dimAndTileMapping.count(i)) {
- readSizes.push_back(oneIdxAttr);
+ extractSliceSizes.push_back(oneIdxAttr);
continue;
}
+ // Compute sizes attribute for ExtractSliceOp + EmptyOp
if (ShapedType::isDynamic(srcShape[i])) {
- Value dynamicDim =
+ OpFoldResult dynamicDim =
rewriter.create<tensor::DimOp>(loc, source, i).getResult();
- readSizes.push_back(dynamicDim);
- dynamicDims.push_back(dynamicDim);
+ extractSliceSizes.push_back(dynamicDim);
+ shapeForEmptyOp.push_back(dynamicDim);
} else {
- readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+ extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
+ if (srcShape[i] != 1)
+ shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
+ }
+ // Compute the output shape for ExtractSliceOp (take into account
+ // rank-reducing)
+ if (srcShape[i] != 1) {
+ readShapeForExtractSlice.push_back(srcShape[i]);
}
- if (srcShape[i] != 1)
- readShape.push_back(srcShape[i]);
}
auto mixedTiles = unpackOp.getMixedTiles();
- readSizes.append(mixedTiles.begin(), mixedTiles.end());
+ // TODO: This effectively assumes that that tile sizes match the trailing
+ // sizes for ExtractSliceOp and EmptyOp - document this.
----------------
rengolin wrote:
Perhaps add a quick check that there are only ever `1`s at the beginning and `mixedTiles` at the end?
https://github.com/llvm/llvm-project/pull/119379
More information about the Mlir-commits
mailing list