[Mlir-commits] [mlir] [mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs (PR #72142)

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 15 06:11:28 PST 2023


================
@@ -260,14 +260,22 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
+/// Returns a copy of `shape` without unit dims.
+static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> reducedShape;
+  llvm::copy_if(shape, std::back_inserter(reducedShape),
+                [](int64_t dimSize) { return dimSize != 1; });
+  return reducedShape;
+}
+
 /// Drops unit dimensions from the input MemRefType.
-static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
-                               ArrayRef<int64_t> sizes,
-                               ArrayRef<int64_t> strides) {
-  SmallVector<int64_t> targetShape = llvm::to_vector(
-      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
+static MemRefType dropUnitDims(MemRefType inputType,
+                               ArrayRef<OpFoldResult> offsets,
+                               ArrayRef<OpFoldResult> sizes,
+                               ArrayRef<OpFoldResult> strides) {
   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
-      targetShape, inputType, offsets, sizes, strides);
+      getReducedShape(inputType.getShape()), inputType, offsets, sizes,
----------------
nicolasvasilache wrote:

Something changed here: the reduced shape used to be computed from `sizes`, but `getReducedShape` is now called on `inputType.getShape()`.
I understand that in the existing cases `sizes` is `inputType.getShape()`, but the "targetShape" seems logically tied to `sizes`, not to `inputType`.
Can we evolve `getReducedShape` to take `sizes` instead?
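
A minimal, self-contained sketch of what this suggestion could look like, for illustration only. It does not use the MLIR API: `MixedSize` is a hypothetical stand-in for `OpFoldResult` (a static integer or a runtime value), and `kDynamicDim` is a hypothetical sentinel for a dynamic dimension. The assumption here is that only sizes statically known to be 1 are dropped, since a dynamic size cannot be proven to be a unit dim.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for MLIR's OpFoldResult: a held value is a static
// size, std::nullopt models a size that is only known at runtime.
using MixedSize = std::optional<int64_t>;

// Hypothetical sentinel marking a dynamic dimension in the result shape.
constexpr int64_t kDynamicDim = -1;

// Returns the shape described by `sizes` without static unit dims.
// Dynamic sizes are kept, because they cannot be proven to be 1.
std::vector<int64_t> getReducedShape(const std::vector<MixedSize> &sizes) {
  std::vector<int64_t> reducedShape;
  for (const MixedSize &size : sizes) {
    if (!size) {
      reducedShape.push_back(kDynamicDim);
      continue;
    }
    if (*size != 1)
      reducedShape.push_back(*size);
  }
  return reducedShape;
}
```

With this shape of helper, the target shape follows `sizes` directly, so it stays correct even if `sizes` ever differs from `inputType.getShape()`.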

https://github.com/llvm/llvm-project/pull/72142

