[Mlir-commits] [mlir] [mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs (PR #72142)
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 15 06:11:27 PST 2023
================
@@ -277,30 +285,37 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
mlir::Location loc,
Value input) {
MemRefType inputType = cast<MemRefType>(input.getType());
- assert(inputType.hasStaticShape());
- SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
- SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
- ArrayRef<int64_t> subViewSizes = inputType.getShape();
- MemRefType resultType =
- dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
+ SmallVector<OpFoldResult> offsets(inputType.getRank(),
+ rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
+ SmallVector<OpFoldResult> strides(inputType.getRank(),
+ rewriter.getIndexAttr(1));
+ MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
+
if (canonicalizeStridedLayout(resultType) ==
canonicalizeStridedLayout(inputType))
return input;
- return rewriter.create<memref::SubViewOp>(
- loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
+ return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
+ sizes, strides);
}
/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
- SmallVector<int64_t> reducedShape;
- llvm::copy_if(shape, std::back_inserter(reducedShape),
- [](int64_t dimSize) { return dimSize != 1; });
- return reducedShape;
+/// Trims non-scalable one dimensions from `oldType` and returns the result
+/// type.
+static VectorType trimUnitDims(VectorType oldType) {
----------------
nicolasvasilache wrote:
Can we be pedantic about the name and go for `trimNonScalableUnitDims`?
https://github.com/llvm/llvm-project/pull/72142
More information about the Mlir-commits mailing list