[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) (PR #73523)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Dec 4 13:37:18 PST 2023
================
@@ -542,45 +544,99 @@ class FlattenContiguousRowMajorTransferReadPattern
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
- Value source = transferReadOp.getSource();
+ auto source = transferReadOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+ // 0. Check pre-conditions
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
+ // If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
- // Already 0D/1D, nothing to do.
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
- int64_t firstContiguousInnerDim =
- sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
return failure();
if (!transferReadOp.getPermutationMap().isMinorIdentity())
return failure();
if (transferReadOp.getMask())
return failure();
+
SmallVector<Value> collapsedIndices;
- if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
- firstContiguousInnerDim,
- collapsedIndices)))
- return failure();
+ int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+ // 1. Collapse the source memref
Value collapsedSource =
- collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+ collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
dyn_cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
- assert(collapsedRank == firstContiguousInnerDim + 1);
+ assert(collapsedRank == firstDimToCollapse + 1);
+
+ // 2. Generate input args for a new vector.transfer_read that will read
+ // from the collapsed memref.
+ // 2.1. New dim exprs + affine map
SmallVector<AffineExpr, 1> dimExprs{
- getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+ getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+ // 2.2 New indices
+ // If all the collapsed indices are zero then no extra logic is needed.
+ // Otherwise, a new offset/index has to be computed.
+ if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+ firstDimToCollapse,
+ collapsedIndices))) {
+ // Copy all the leading indices
+ collapsedIndices = transferReadOp.getIndices();
+ collapsedIndices.resize(firstDimToCollapse);
+
+ // Compute the remaining trailing index/offset required for reading from
+ // the collapsed memref:
+ //
+ // offset = 0
+ // for (i = firstDimToCollapse; i < outputRank; ++i)
+ // offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+ //
+ // For this example:
+ // %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+ // memref<1x43x2xi32>, vector<1x2xi32>
+ // which would be collapsed to:
+ // %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+ // memref<1x86xi32>, vector<2xi32>
+ // one would get the following offset:
+ // %offset = %arg0 * 43
+ AffineExpr offsetE, idx;
+ bindSymbols(rewriter.getContext(), offsetE, idx);
+
+ int64_t outputRank = transferReadOp.getIndices().size();
+ OpFoldResult offset =
+ rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+ for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+ int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
----------------
banach-space wrote:
> What happens if we have dynamic shapes in the source type?
There are two separate cases:
**Case 1**
One of the leading dims is dynamic:
```mlir
%2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<?x43x4x6xi32>, vector<1x2x6xi32>
```
In this case, the dynamic dimension is not used in the offset calculation (it is irrelevant — only the trailing dims are considered).
**Case 2**
One of the trailing dims is dynamic:
```mlir
%2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x?x4x6xi32>, vector<1x2x6xi32>
```
This should actually be fine, but I'd rather disable it for now — see the updates to `vector::isContiguousSlice` (I am about to send them and will also add a test). I'd rather do this carefully than in a rush past my bedtime 😅.
https://github.com/llvm/llvm-project/pull/73523
More information about the Mlir-commits
mailing list