[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) (PR #73523)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Dec 4 02:36:46 PST 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/73523
From ee5e3550e3ccb3aedb7b728d601e62ff7ce02818 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 25 Nov 2023 16:51:42 +0000
Subject: [PATCH 1/2] [mlir][Vector] Update patterns for flattening vector.xfer
Ops (2/N)
Updates patterns for flattening vector.transfer_read by relaxing the
requirement that the "collapsed" indices are all zero. This enables
collapsing cases like this one:
```mlir
%2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
```
Previously, only the following case was considered for collapsing:
```mlir
%2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
```
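For reference, based on the CHECK lines in the updated test added by this patch (after the `makeComposedFoldedAffineApply` fixup in the second commit), the relaxed case flattens to roughly the IR below. This is a sketch rather than verbatim pass output: the SSA names are illustrative and the usual `%c0`/`%c0_i32` index and padding constants are assumed:
```mlir
// Sketch of the flattened form (names illustrative):
%collapsed = memref.collapse_shape %arg4 [[0], [1, 2, 3]]
    : memref<1x43x4x6xi32> into memref<1x1032xi32>
// The non-zero collapsed indices are linearized into a single offset.
%offset = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>()[%arg1, %arg0]
%flat = vector.transfer_read %collapsed[%c0, %offset], %c0_i32 {in_bounds = [true]}
    : memref<1x1032xi32>, vector<12xi32>
%2 = vector.shape_cast %flat : vector<12xi32> to vector<1x2x6xi32>
```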
The pattern itself, `FlattenContiguousRowMajorTransferReadPattern`, was
also lightly refactored:
* added comments,
* renamed `firstContiguousInnerDim` to `firstDimToCollapse` (the
latter better matches the meaning and is already used consistently
in the relevant helper methods).
A similar update for `vector.transfer_write` will be implemented in a
follow-up patch.
---
.../Transforms/VectorTransferOpTransforms.cpp | 72 ++++++++++++++++---
.../Vector/vector-transfer-flatten.mlir | 32 +++++++++
2 files changed, 94 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index aab7075006031..015c0cc011a30 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -511,6 +511,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
SmallVector<Value> &outIndices) {
@@ -544,16 +546,16 @@ class FlattenContiguousRowMajorTransferReadPattern
VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferReadOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+ // 0. Check pre-conditions
// The contiguity check only applies to memrefs, so bail out on tensors.
if (!sourceType)
return failure();
+ // If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
- // Already 0D/1D, nothing to do.
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
- int64_t firstContiguousInnerDim =
- sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
return failure();
@@ -561,26 +563,76 @@ class FlattenContiguousRowMajorTransferReadPattern
return failure();
if (transferReadOp.getMask())
return failure();
+
SmallVector<Value> collapsedIndices;
- if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
- firstContiguousInnerDim,
- collapsedIndices)))
- return failure();
+ int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+ // 1. Collapse the source memref
Value collapsedSource =
- collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+ collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
dyn_cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
- assert(collapsedRank == firstContiguousInnerDim + 1);
+ assert(collapsedRank == firstDimToCollapse + 1);
+
+ // 2. Generate input args for a new vector.transfer_read that will read
+ // from the collapsed memref.
+ // 2.1. New dim exprs + affine map
SmallVector<AffineExpr, 1> dimExprs{
- getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+ getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+ // 2.2 New indices
+ // If all the collapsed indices are zero then no extra logic is needed.
+ // Otherwise, a new offset/index has to be computed.
+ if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+ firstDimToCollapse,
+ collapsedIndices))) {
+ // Copy all the leading indices
+ collapsedIndices = transferReadOp.getIndices();
+ collapsedIndices.resize(firstDimToCollapse);
+
+ // Compute the remaining trailing index/offset required for reading from
+ // the collapsed memref:
+ //
+ // offset = 0
+ // for (i = firstDimToCollapse; i < outputRank; ++i)
+ // offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+ //
+ // For this example:
+ // %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+ // memref<1x43x2xi32>, vector<1x2xi32>
+ // which would be collapsed to:
+ // %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+ // memref<1x86xi32>, vector<2xi32>
+ // one would get the following offset:
+ // %offset = %arg0 * 43
+ int64_t outputRank = transferReadOp.getIndices().size();
+ Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+ Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ auto sourceDimSize =
+ rewriter.create<memref::DimOp>(loc, source, dimIdx);
+
+ offset = rewriter.create<arith::AddIOp>(
+ loc,
+ rewriter.create<arith::MulIOp>(loc, transferReadOp.getIndices()[i],
+ sourceDimSize),
+ offset);
+ }
+ collapsedIndices.push_back(offset);
+ }
+
+ // 3. Create new vector.transfer_read that reads from the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+ // 4. Replace the old transfer_read with the new one reading from the
+ // collapsed shape
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
transferReadOp, cast<VectorType>(vector.getType()), flatRead);
return success();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 2ffe85bf3bfa6..a882ea1f4291c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -41,6 +41,38 @@ func.func @transfer_read_dims_mismatch_contiguous(
// -----
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+ %idx_1: index,
+ %idx_2: index,
+ %m_in: memref<1x43x4x6xi32>,
+ %m_out: memref<1x2x6xi32>) {
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ memref<1x43x4x6xi32>, vector<1x2x6xi32>
+ vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ vector<1x2x6xi32>, memref<1x2x6xi32>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME: %[[VAL_3:.*]]: memref<1x2x6xi32>) {
+// CHECK: %[[VAL_4:.*]] = arith.constant 43 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_0]], %[[VAL_4]] : index
+// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_1]], %[[VAL_5]] : index
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_9]] : index
+// CHECK: %[[VAL_12:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_7]], %[[VAL_11]]], %[[VAL_6]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK: %[[VAL_13:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK: vector.transfer_write %[[VAL_12]], %[[VAL_13]]{{\[}}%[[VAL_7]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+
+// -----
+
func.func @transfer_read_dims_mismatch_non_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
%c0 = arith.constant 0 : index
From b27c49d60759cf890ffe99da3a94f1a281c8e537 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 4 Dec 2023 10:32:35 +0000
Subject: [PATCH 2/2] fixup! [mlir][Vector] Update patterns for flattening
vector.xfer Ops (2/N)
Refactor to use makeComposedFoldedAffineApply
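For context, `affine::makeComposedFoldedAffineApply` composes the given expression with any `affine.apply` ops defining its operands and constant-folds where possible, returning an `OpFoldResult` (either a `Value` or an `Attribute`). The sketch below condenses the hunk that follows into a hypothetical standalone helper; the helper name and signature are illustrative, and the usual MLIR headers (Affine, Arith) are assumed:
```cpp
// Hypothetical helper condensing the index linearization from the patch.
static OpFoldResult linearizeCollapsedIndices(PatternRewriter &rewriter,
                                              Location loc,
                                              ShapedType sourceType,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  AffineExpr offsetE, idx;
  bindSymbols(rewriter.getContext(), offsetE, idx);

  // Start from 0 and accumulate: offset += dimSize(i) * indices[i].
  OpFoldResult offset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
  for (int64_t i = firstDimToCollapse, e = indices.size(); i < e; ++i) {
    int64_t dim = sourceType.getDimSize(i);
    // Composes with producing affine.apply ops and constant-folds, so
    // all-zero trailing indices fold back to a plain constant rather
    // than a chain of arith.muli/arith.addi ops.
    offset = affine::makeComposedFoldedAffineApply(
        rewriter, loc, offsetE + dim * idx, {offset, indices[i]});
  }
  return offset;
}
```
When the result folds to an `Attribute`, the patch re-materializes it as an `arith.constant` before using it as an index, which is what the `offset.is<Value>()` branch in the hunk below handles.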
---
.../Transforms/VectorTransferOpTransforms.cpp | 32 +++++++++++--------
.../Vector/vector-transfer-flatten.mlir | 20 ++++++------
.../Dialect/Vector/TestVectorTransforms.cpp | 1 +
3 files changed, 28 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 015c0cc011a30..a404307d6a8b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -544,7 +544,7 @@ class FlattenContiguousRowMajorTransferReadPattern
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
- Value source = transferReadOp.getSource();
+ auto source = transferReadOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// 0. Check pre-conditions
@@ -602,26 +602,30 @@ class FlattenContiguousRowMajorTransferReadPattern
//
// For this example:
// %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
- // memref<1x43x2xi32>, vector<1x2xi32>
+ // memref<1x43x2xi32>, vector<1x2xi32>
// which would be collapsed to:
// %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
- // memref<1x86xi32>, vector<2xi32>
+ // memref<1x86xi32>, vector<2xi32>
// one would get the following offset:
// %offset = %arg0 * 43
+ AffineExpr offsetE, idx;
+ bindSymbols(rewriter.getContext(), offsetE, idx);
+
int64_t outputRank = transferReadOp.getIndices().size();
- Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ OpFoldResult offset =
+ rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
- Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
- auto sourceDimSize =
- rewriter.create<memref::DimOp>(loc, source, dimIdx);
-
- offset = rewriter.create<arith::AddIOp>(
- loc,
- rewriter.create<arith::MulIOp>(loc, transferReadOp.getIndices()[i],
- sourceDimSize),
- offset);
+ int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+ offset = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, offsetE + dim * idx,
+ {offset, transferReadOp.getIndices()[i]});
+ }
+ if (offset.is<Value>()) {
+ collapsedIndices.push_back(offset.get<Value>());
+ } else {
+ collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+ loc, *getConstantIntValue(offset)));
}
- collapsedIndices.push_back(offset);
}
// 3. Create new vector.transfer_read that reads from the collapsed memref
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index a882ea1f4291c..8ce96bde8e2e1 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -55,21 +55,19 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
return
}
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
// CHECK-SAME: %[[VAL_2:.*]]: memref<1x43x4x6xi32>,
// CHECK-SAME: %[[VAL_3:.*]]: memref<1x2x6xi32>) {
-// CHECK: %[[VAL_4:.*]] = arith.constant 43 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
-// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_0]], %[[VAL_4]] : index
-// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_1]], %[[VAL_5]] : index
-// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_9]] : index
-// CHECK: %[[VAL_12:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_7]], %[[VAL_11]]], %[[VAL_6]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
-// CHECK: %[[VAL_13:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
-// CHECK: vector.transfer_write %[[VAL_12]], %[[VAL_13]]{{\[}}%[[VAL_7]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_1]], %[[VAL_0]]]
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_5]], %[[VAL_7]]], %[[VAL_4]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK: %[[VAL_9:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK: vector.transfer_write %[[VAL_8]], %[[VAL_9]]{{\[}}%[[VAL_5]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
// -----
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index feb716cdbf404..86b8d5f9b0995 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -454,6 +454,7 @@ struct TestFlattenVectorTransferPatterns
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
+ registry.insert<affine::AffineDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());