[Mlir-commits] [mlir] [mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs (PR #72142)

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 15 06:11:27 PST 2023


================
@@ -277,30 +285,37 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                  mlir::Location loc,
                                                  Value input) {
   MemRefType inputType = cast<MemRefType>(input.getType());
-  assert(inputType.hasStaticShape());
-  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
-  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
-  ArrayRef<int64_t> subViewSizes = inputType.getShape();
-  MemRefType resultType =
-      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
+  SmallVector<OpFoldResult> offsets(inputType.getRank(),
+                                    rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
+  SmallVector<OpFoldResult> strides(inputType.getRank(),
+                                    rewriter.getIndexAttr(1));
+  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
+
   if (canonicalizeStridedLayout(resultType) ==
       canonicalizeStridedLayout(inputType))
     return input;
-  return rewriter.create<memref::SubViewOp>(
-      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
+  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
+                                            sizes, strides);
 }
 
 /// Returns the number of dims that aren't unit dims.
 static int getReducedRank(ArrayRef<int64_t> shape) {
   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
 }
 
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> reducedShape;
-  llvm::copy_if(shape, std::back_inserter(reducedShape),
-                [](int64_t dimSize) { return dimSize != 1; });
-  return reducedShape;
+/// Trims non-scalable one dimensions from `oldType` and returns the result
+/// type.
+static VectorType trimUnitDims(VectorType oldType) {
----------------
nicolasvasilache wrote:

Can we be pedantic about the name and go for `trimNonScalableUnitDims`?

https://github.com/llvm/llvm-project/pull/72142


More information about the Mlir-commits mailing list