[Mlir-commits] [mlir] [mlir][vector] Drop innermost unit dims on transfer_write. (PR #78554)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jan 18 10:38:21 PST 2024


================
@@ -1152,8 +1152,71 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
   }
 };
 
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of innermost unit dims that can be folded away from
+/// transfer ops. Returns failure if strides and offset cannot be resolved.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+  SmallVector<int64_t> srcStrides;
+  int64_t srcOffset;
+  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+    return failure();
+
+  // According to vector.transfer_read/write semantics, the vector can be a
+  // slice of the source, i.e. have a lower rank. The vector dims line up with
+  // the trailing source dims, so indices into `srcStrides` and the source dim
+  // sizes are offset by `rankDiff`.
+  size_t result = 0;
+  int rankDiff = srcType.getRank() - vectorType.getRank();
+  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+    // Walk from the innermost dim outwards. A dim can be folded only if its
+    // size is 1 in both the memref and the vector, and its stride is 1.
+    int dim = vectorType.getRank() - i - 1;
+    if (srcStrides[dim + rankDiff] == 1 &&
+        srcType.getDimSize(dim + rankDiff) == 1 &&
+        vectorType.getDimSize(dim) == 1) {
+      result++;
+    } else {
+      break;
+    }
+  }
+  return result;
+}
+
+/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
+/// `srcType`.
+static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+                                                     MemRefType srcType,
+                                                     size_t dimsToDrop) {
+  MemRefType resultMemrefType;
+  MemRefLayoutAttrInterface layout = srcType.getLayout();
+  if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), nullptr,
+                           srcType.getMemorySpace());
+  }
+  MemRefLayoutAttrInterface updatedLayout;
+  if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+    auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+    updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+                                           strided.getOffset(), strides);
+  } else {
----------------
banach-space wrote:

[nit] I'd avoid `else` statements - things get tricky to follow once the code grows. Instead, I'd do this:
```cpp 
  if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
    auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
    updatedLayout = StridedLayoutAttr::get(strided.getContext(),
                                           strided.getOffset(), strides);
    
    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
                         srcType.getElementType(), updatedLayout,
                         srcType.getMemorySpace());                                 
  }
    
  // Non-strided layout case
  AffineMap map = srcType.getLayout().getAffineMap();
  int numSymbols = map.getNumSymbols();
  for (size_t i = 0; i < dimsToDrop; ++i) {
      int dim = srcType.getRank() - i - 1;
      map = map.replace(builder.getAffineDimExpr(dim),
                        builder.getAffineConstantExpr(0), map.getNumDims() - 1,
                        numSymbols);
  }

  return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
                         srcType.getElementType(), updatedLayout,
                         srcType.getMemorySpace());
```

I think that it would better align with https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code

https://github.com/llvm/llvm-project/pull/78554


More information about the Mlir-commits mailing list