[Mlir-commits] [mlir] [mlir][vector] Improve flattening vector.transfer_write ops. (PR #94051)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Jun 3 14:37:27 PDT 2024
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/94051
>From 0d76610415d1e53b054f6273f21ccdee8718e61e Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 31 May 2024 14:40:56 -0700
Subject: [PATCH 1/2] [mlir][vector] Improve flattening vector.transfer_write
ops.
The transfer ops can be flattened even when the collapsed indices are not
zero, because the collapsed (linearized) index can be computed from them.
This is already supported for vector.transfer_read; this revision refactors
that logic and reuses it for vector.transfer_write.
---
.../Transforms/VectorTransferOpTransforms.cpp | 136 +++++++++---------
.../Vector/vector-transfer-flatten.mlir | 18 +--
2 files changed, 75 insertions(+), 79 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 997b56a1ce142..ae47d7fc28811 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -510,20 +510,59 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
/// TODO: Extract the logic that writes to outIndices so that this method
/// simply checks one pre-condition.
-static LogicalResult
-checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
- SmallVector<Value> &outIndices) {
- int64_t rank = indices.size();
- if (firstDimToCollapse >= rank)
- return failure();
- for (int64_t i = firstDimToCollapse; i < rank; ++i) {
- std::optional<int64_t> cst = getConstantIntValue(indices[i]);
- if (!cst || cst.value() != 0)
- return failure();
+static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
+ Location loc,
+ ArrayRef<int64_t> shape,
+ ValueRange indices,
+ int64_t firstDimToCollapse) {
+ assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
+
+ // If all the collapsed indices are zero then no extra logic is needed.
+ // Otherwise, a new offset/index has to be computed.
+ SmallVector<Value> collapsedIndices(indices.begin(),
+ indices.begin() + firstDimToCollapse);
+ SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
+ indices.end());
+ if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
+ collapsedIndices.push_back(indicesToCollapse[0]);
+ return collapsedIndices;
+ }
+
+ // Compute the remaining trailing index/offset required for reading from
+ // (or writing to) the collapsed memref:
+ //
+ // offset = 0
+ // for (i = firstDimToCollapse; i < shape.size(); ++i)
+ // offset += indices[i] * product(shape[i+1 .. shape.size()-1])
+ //
+ // For this example:
+ // %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
+ // memref<1x43x2xi32>, vector<1x2xi32>
+ // which would be collapsed to:
+ // %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
+ // memref<1x86xi32>, vector<2xi32>
+ // one would get the following offset:
+ // %offset = %arg0 * 2
+ OpFoldResult collapsedOffset =
+ rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+ auto collapsedStrides = computeSuffixProduct(
+ ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
+
+ // Compute the collapsed offset.
+ auto &&[collapsedExpr, collapsedVals] =
+ computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
+ collapsedOffset = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, collapsedExpr, collapsedVals);
+
+ if (collapsedOffset.is<Value>()) {
+ collapsedIndices.push_back(collapsedOffset.get<Value>());
+ } else {
+ collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+ loc, *getConstantIntValue(collapsedOffset)));
}
- outIndices = indices;
- outIndices.resize(firstDimToCollapse + 1);
- return success();
+
+ return collapsedIndices;
}
namespace {
@@ -594,54 +633,9 @@ class FlattenContiguousRowMajorTransferReadPattern
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
// 2.2 New indices
- // If all the collapsed indices are zero then no extra logic is needed.
- // Otherwise, a new offset/index has to be computed.
- SmallVector<Value> collapsedIndices;
- if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
- firstDimToCollapse,
- collapsedIndices))) {
- // Copy all the leading indices.
- SmallVector<Value> indices = transferReadOp.getIndices();
- collapsedIndices.append(indices.begin(),
- indices.begin() + firstDimToCollapse);
-
- // Compute the remaining trailing index/offset required for reading from
- // the collapsed memref:
- //
- // offset = 0
- // for (i = firstDimToCollapse; i < outputRank; ++i)
- // offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
- //
- // For this example:
- // %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
- // memref<1x43x2xi32>, vector<1x2xi32>
- // which would be collapsed to:
- // %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
- // memref<1x86xi32>, vector<2xi32>
- // one would get the following offset:
- // %offset = %arg0 * 43
- OpFoldResult collapsedOffset =
- rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
-
- auto sourceShape = sourceType.getShape();
- auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
- sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
-
- // Compute the collapsed offset.
- ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
- indices.end());
- auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
- collapsedOffset, collapsedStrides, indicesToCollapse);
- collapsedOffset = affine::makeComposedFoldedAffineApply(
- rewriter, loc, collapsedExpr, collapsedVals);
-
- if (collapsedOffset.is<Value>()) {
- collapsedIndices.push_back(collapsedOffset.get<Value>());
- } else {
- collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
- loc, *getConstantIntValue(collapsedOffset)));
- }
- }
+ SmallVector<Value> collapsedIndices =
+ getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+ transferReadOp.getIndices(), firstDimToCollapse);
// 3. Create new vector.transfer_read that reads from the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
@@ -697,8 +691,7 @@ class FlattenContiguousRowMajorTransferWritePattern
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
- int64_t firstContiguousInnerDim =
- sourceType.getRank() - vectorType.getRank();
+ int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
@@ -706,22 +699,23 @@ class FlattenContiguousRowMajorTransferWritePattern
return failure();
if (transferWriteOp.getMask())
return failure();
- SmallVector<Value> collapsedIndices;
- if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
- firstContiguousInnerDim,
- collapsedIndices)))
- return failure();
+
+ SmallVector<Value> collapsedIndices =
+ getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+ transferWriteOp.getIndices(), firstDimToCollapse);
Value collapsedSource =
- collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+ collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
- assert(collapsedRank == firstContiguousInnerDim + 1);
+ assert(collapsedRank == firstDimToCollapse + 1);
+
SmallVector<AffineExpr, 1> dimExprs{
- getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+ getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
Value flatVector =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..b5f29b2ac958b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -471,16 +471,16 @@ func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, str
}
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
-// CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
// CHECK-128B-LABEL: func @regression_non_contiguous_dim_read(
// CHECK-128B: memref.collapse_shape
// -----
-func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+func.func @regression_non_contiguous_dim_write(%value : vector<2x2xf32>,
%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
%idx0 : index, %idx1 : index) {
%c0 = arith.constant 0 : index
@@ -488,8 +488,10 @@ func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
return
}
-// CHECK-LABEL: func.func @unsupported_non_contiguous_dim_write(
-// CHECK-NOT: memref.collapse_shape
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func.func @regression_non_contiguous_dim_write(
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
-// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-LABEL: func @regression_non_contiguous_dim_write(
+// CHECK-128B: memref.collapse_shape
>From 0bb49f1ef3b5859df1fb84f2f07a1da904f21974 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 3 Jun 2024 14:37:14 -0700
Subject: [PATCH 2/2] address comments
---
.../Transforms/VectorTransferOpTransforms.cpp | 21 +++++++---------
.../Vector/vector-transfer-flatten.mlir | 25 +++++++++++++++++++
2 files changed, 34 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index ae47d7fc28811..c131fde517f80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -505,11 +505,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}
-/// Checks that the indices corresponding to dimensions starting at
-/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
-/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
-/// TODO: Extract the logic that writes to outIndices so that this method
-/// simply checks one pre-condition.
+/// Returns the new indices that collapse the inner dimensions starting from
+/// the `firstDimToCollapse` dimension into a single linearized index.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
Location loc,
ArrayRef<int64_t> shape,
@@ -519,13 +516,13 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
// If all the collapsed indices are zero then no extra logic is needed.
// Otherwise, a new offset/index has to be computed.
- SmallVector<Value> collapsedIndices(indices.begin(),
- indices.begin() + firstDimToCollapse);
+ SmallVector<Value> indicesAfterCollapsing(
+ indices.begin(), indices.begin() + firstDimToCollapse);
SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
indices.end());
if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
- collapsedIndices.push_back(indicesToCollapse[0]);
- return collapsedIndices;
+ indicesAfterCollapsing.push_back(indicesToCollapse[0]);
+ return indicesAfterCollapsing;
}
// Compute the remaining trailing index/offset required for reading from
@@ -556,13 +553,13 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
rewriter, loc, collapsedExpr, collapsedVals);
if (collapsedOffset.is<Value>()) {
- collapsedIndices.push_back(collapsedOffset.get<Value>());
+ indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
} else {
- collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+ indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
loc, *getConstantIntValue(collapsedOffset)));
}
- return collapsedIndices;
+ return indicesAfterCollapsing;
}
namespace {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index b5f29b2ac958b..65bf0b9335d28 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -495,3 +495,28 @@ func.func @regression_non_contiguous_dim_write(%value : vector<2x2xf32>,
// CHECK-128B-LABEL: func @regression_non_contiguous_dim_write(
// CHECK-128B: memref.collapse_shape
+
+// -----
+
+func.func @negative_out_of_bound_transfer_read(
+ %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} :
+ memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
+}
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read(
+// CHECK-NOT: memref.collapse_shape
+
+// -----
+
+func.func @negative_out_of_bound_transfer_write(
+ %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} :
+ vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write(
+// CHECK-NOT: memref.collapse_shape
More information about the Mlir-commits
mailing list