[Mlir-commits] [mlir] [mlir][vector] Enable transfer op hoisting with dynamic indices (PR #68500)
Lei Zhang
llvmlistbot at llvm.org
Sat Oct 7 15:55:20 PDT 2023
https://github.com/antiagainst updated https://github.com/llvm/llvm-project/pull/68500
>From 66b22bc865ecd21579bae1b1e568d2174f114d17 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sat, 7 Oct 2023 15:12:58 -0700
Subject: [PATCH 1/2] [mlir][vector] Enable transfer op hoisting with dynamic
indices
Recent changes (https://github.com/llvm/llvm-project/pull/66930)
disabled hoisting of vector transfer ops through view-like intermediate
ops. The recommended approach is to fold subview ops into transfer
op indices before invoking hoisting. As a result, transfer op indices
now involve dynamic values, whereas previously (with subview ops in
between) they were static constants. Since disjointness was only proven
for constant indices, hoisting no longer kicks in. This breaks
downstream users.
To fix this, this commit enables hoisting transfer ops with dynamic
indices by using `ValueBoundsConstraintSet` to prove that ranges are
disjoint in `isDisjointTransferIndices`. Since that utility is used in
many places, including op folders, the new analysis is guarded by an
opt-in flag (`testDynamicValueUsingBounds`), which is only enabled for
"heavy" transforms: hoisting, dead-store elimination, and store-to-load
forwarding.
---
.../mlir/Dialect/Vector/IR/VectorOps.h | 19 +++-
.../mlir/Interfaces/ValueBoundsOpInterface.h | 10 ++
.../Dialect/Linalg/Transforms/Hoisting.cpp | 12 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 51 ++++++---
.../Transforms/VectorTransferOpTransforms.cpp | 6 +-
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 29 +++--
mlir/test/Dialect/Linalg/hoisting.mlir | 90 +++++++++++++++
.../Dialect/Vector/vector-transferop-opt.mlir | 103 ++++++++++++++++++
8 files changed, 283 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index fc0c80036ff79ad..9ab20e20d975429 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -105,16 +105,23 @@ bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
/// op.
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
-/// Same behavior as `isDisjointTransferSet` but doesn't require the operations
-/// to have the same tensor/memref. This allows comparing operations accessing
-/// different tensors.
+/// Return true if we can prove that the transfer operations access disjoint
+/// memory, without requiring the accessed tensor/memref to be the same.
+///
+/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
+/// via ValueBoundsOpInterface.
bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
- VectorTransferOpInterface transferB);
+ VectorTransferOpInterface transferB,
+ bool testDynamicValueUsingBounds = false);
/// Return true if we can prove that the transfer operations access disjoint
-/// memory.
+/// memory, requiring the operations to access the same tensor/memref.
+///
+/// If `testDynamicValueUsingBounds` is true, tries to test dynamic values
+/// via ValueBoundsOpInterface.
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
- VectorTransferOpInterface transferB);
+ VectorTransferOpInterface transferB,
+ bool testDynamicValueUsingBounds = false);
/// Return the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 2687d79aec68ebb..b79c31c5998a62a 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -176,6 +176,16 @@ class ValueBoundsConstraintSet {
presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
StopConditionFn stopCondition = nullptr, bool closedUB = false);
+  /// Compute the constant signed difference `value1 - value2`. Return "failure"
+  /// if such a constant difference could not be determined.
+ ///
+ /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
+ /// index-typed.
+ static FailureOr<int64_t>
+ computeConstantDistance(Value value1, Value value2,
+ std::optional<int64_t> dim1 = std::nullopt,
+ std::optional<int64_t> dim2 = std::nullopt);
+
/// Compute whether the given values/dimensions are equal. Return "failure" if
/// equality could not be determined.
///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 221bec713b38aa3..cbb2c507de69f9e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -173,16 +173,16 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
if (auto transferWriteUse =
dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
if (!vector::isDisjointTransferSet(
- cast<VectorTransferOpInterface>(transferWrite.getOperation()),
- cast<VectorTransferOpInterface>(
- transferWriteUse.getOperation())))
+ cast<VectorTransferOpInterface>(*transferWrite),
+ cast<VectorTransferOpInterface>(*transferWriteUse),
+ /*testDynamicValueUsingBounds=*/true))
return WalkResult::advance();
} else if (auto transferReadUse =
dyn_cast<vector::TransferReadOp>(use.getOwner())) {
if (!vector::isDisjointTransferSet(
- cast<VectorTransferOpInterface>(transferWrite.getOperation()),
- cast<VectorTransferOpInterface>(
- transferReadUse.getOperation())))
+ cast<VectorTransferOpInterface>(*transferWrite),
+ cast<VectorTransferOpInterface>(*transferReadUse),
+ /*testDynamicValueUsingBounds=*/true))
return WalkResult::advance();
} else {
// Unknown use, we cannot prove that it doesn't alias with the
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 044b6cc07d3d629..34810497ebefa7c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@@ -168,39 +169,63 @@ bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
}
bool mlir::vector::isDisjointTransferIndices(
- VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
+ VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
+ bool testDynamicValueUsingBounds) {
// For simplicity only look at transfer of same type.
if (transferA.getVectorType() != transferB.getVectorType())
return false;
unsigned rankOffset = transferA.getLeadingShapedRank();
for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
- auto indexA = getConstantIntValue(transferA.indices()[i]);
- auto indexB = getConstantIntValue(transferB.indices()[i]);
- // If any of the indices are dynamic we cannot prove anything.
- if (!indexA.has_value() || !indexB.has_value())
- continue;
+ Value indexA = transferA.indices()[i];
+ Value indexB = transferB.indices()[i];
+ std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
+ std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
if (i < rankOffset) {
// For leading dimensions, if we can prove that index are different we
// know we are accessing disjoint slices.
- if (*indexA != *indexB)
- return true;
+ if (cstIndexA.has_value() && cstIndexB.has_value()) {
+ if (*cstIndexA != *cstIndexB)
+ return true;
+ continue;
+ }
+ if (testDynamicValueUsingBounds) {
+ FailureOr<bool> testEqual =
+ ValueBoundsConstraintSet::areEqual(indexA, indexB);
+ if (succeeded(testEqual) && !testEqual.value())
+ return true;
+ }
} else {
// For this dimension, we slice a part of the memref we need to make sure
// the intervals accessed don't overlap.
- int64_t distance = std::abs(*indexA - *indexB);
- if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
- return true;
+ if (cstIndexA.has_value() && cstIndexB.has_value()) {
+ int64_t distance = std::abs(*cstIndexA - *cstIndexB);
+ if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
+ return true;
+ continue;
+ }
+ if (testDynamicValueUsingBounds) {
+ FailureOr<int64_t> computeDist =
+ ValueBoundsConstraintSet::computeConstantDistance(indexA, indexB);
+
+ if (succeeded(computeDist)) {
+ int64_t distance = std::abs(computeDist.value());
+ if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
+ return true;
+ }
+ }
}
}
return false;
}
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
- VectorTransferOpInterface transferB) {
+ VectorTransferOpInterface transferB,
+ bool testDynamicValueUsingBounds) {
if (transferA.source() != transferB.source())
return false;
- return isDisjointTransferIndices(transferA, transferB);
+ return isDisjointTransferIndices(transferA, transferB,
+ testDynamicValueUsingBounds);
}
// Helper to iterate over n-D vector slice elements. Calculate the next
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 603b88f11c8e007..a5f1b28152b9bde 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -142,7 +142,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
// Don't need to consider disjoint accesses.
if (vector::isDisjointTransferSet(
cast<VectorTransferOpInterface>(write.getOperation()),
- cast<VectorTransferOpInterface>(transferOp.getOperation())))
+ cast<VectorTransferOpInterface>(transferOp.getOperation()),
+ /*testDynamicValueUsingBounds=*/true))
continue;
}
blockingAccesses.push_back(user);
@@ -217,7 +218,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
// the write.
if (vector::isDisjointTransferSet(
cast<VectorTransferOpInterface>(write.getOperation()),
- cast<VectorTransferOpInterface>(read.getOperation())))
+ cast<VectorTransferOpInterface>(read.getOperation()),
+ /*testDynamicValueUsingBounds=*/true))
continue;
if (write.getSource() == read.getSource() &&
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index c00ee0315a9639a..de09e4dd35ecb2c 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -484,25 +484,34 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
return failure();
}
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
- std::optional<int64_t> dim1,
- std::optional<int64_t> dim2) {
+FailureOr<int64_t>
+ValueBoundsConstraintSet::computeConstantDistance(Value value1, Value value2,
+ std::optional<int64_t> dim1,
+ std::optional<int64_t> dim2) {
#ifndef NDEBUG
assertValidValueDim(value1, dim1);
assertValidValueDim(value2, dim2);
#endif // NDEBUG
- // Subtract the two values/dimensions from each other. If the result is 0,
- // both are equal.
Builder b(value1.getContext());
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
- FailureOr<int64_t> bound = computeConstantBound(
- presburger::BoundType::EQ, map, {{value1, dim1}, {value2, dim2}});
- if (failed(bound))
+ map = simplifyAffineMap(map);
+ return computeConstantBound(presburger::BoundType::EQ, map,
+ {{value1, dim1}, {value2, dim2}});
+}
+
+FailureOr<bool>
+ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+ std::optional<int64_t> dim1,
+ std::optional<int64_t> dim2) {
+ // Subtract the two values/dimensions from each other. If the result is 0,
+ // both are equal.
+ FailureOr<int64_t> distance =
+ computeConstantDistance(value1, value2, dim1, dim2);
+ if (failed(distance))
return failure();
- return *bound == 0;
+ return *distance == 0;
}
ValueBoundsConstraintSet::BoundBuilder &
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 7d0c3648c344b1d..efbbf19c2fb8df4 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -872,3 +872,93 @@ transform.sequence failures(propagate) {
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
}
+
+// -----
+
+// Test that we can hoist out read-write pairs whose indices are dynamic values.
+
+// CHECK: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>
+
+// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
+// CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)
+
+// CHECK: %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
+// CHECK: %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
+// CHECK: %2 = vector.transfer_read %[[BUFFER]][%[[I0]], %[[I0]]]
+// CHECK: %3 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[I0]]]
+// CHECK: %4 = vector.transfer_read %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
+// CHECK-COUNT-2: scf.for %{{.+}} = {{.+}} -> (vector<4xf32>, vector<4xf32>, vector<4xf32>)
+// CHECK-COUNT-3: "some_use"
+// CHECK-COUNT-2: scf.yield {{.+}} : vector<4xf32>, vector<4xf32>, vector<4xf32>
+// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[PLUS4]]]
+// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[PLUS1]], %[[I0]]]
+// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]
+
+func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
+ %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+ %cst = arith.constant 0.0 : f32
+ %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
+ %i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
+
+ scf.for %i = %lb to %ub step %step {
+ scf.for %j = %lb to %ub step %step {
+ %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+ // Disjoint leading dim
+ %r1 = vector.transfer_read %buffer[%i1, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+ // Non-overlap trailing dim
+ %r2 = vector.transfer_read %buffer[%i1, %i2], %cst: memref<?x?xf32>, vector<4xf32>
+ %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
+ %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
+ %u2 = "some_use"(%r2) : (vector<4xf32>) -> vector<4xf32>
+ vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
+ vector.transfer_write %u1, %buffer[%i1, %i0] : vector<4xf32>, memref<?x?xf32>
+ vector.transfer_write %u2, %buffer[%i1, %i2] : vector<4xf32>, memref<?x?xf32>
+ }
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+// Test that we cannot hoist out read-write pairs whose indices are overlapping.
+
+// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_overlapping_dynamic
+// CHECK-COUNT-2: scf.for
+// CHECK-COUNT-2: vector.transfer_read
+// CHECK-COUNT-2: vector.transfer_write
+
+func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
+ %buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
+ %cst = arith.constant 0.0 : f32
+ %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
+
+ scf.for %i = %lb to %ub step %step {
+ scf.for %j = %lb to %ub step %step {
+ %r0 = vector.transfer_read %buffer[%i0, %i0], %cst: memref<?x?xf32>, vector<4xf32>
+ // Overlapping range with the above
+ %r1 = vector.transfer_read %buffer[%i0, %i1], %cst: memref<?x?xf32>, vector<4xf32>
+ %u0 = "some_use"(%r0) : (vector<4xf32>) -> vector<4xf32>
+ %u1 = "some_use"(%r1) : (vector<4xf32>) -> vector<4xf32>
+ vector.transfer_write %u0, %buffer[%i0, %i0] : vector<4xf32>, memref<?x?xf32>
+ vector.transfer_write %u1, %buffer[%i0, %i1] : vector<4xf32>, memref<?x?xf32>
+ }
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index f43367ab4aeba7d..d0daa53667fab46 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -256,3 +256,106 @@ func.func @collapse_shape(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
}
return
}
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_same_index
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_store_dynamic_same_index(
+ %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i : index) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ vector.transfer_write %v0, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+ %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+ %1 = arith.addf %acc, %acc : vector<4xf32>
+ scf.yield %1 : vector<4xf32>
+ }
+ vector.transfer_write %x, %buffer[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+ return
+}
+
+// CHECK-LABEL: func @dont_forward_dead_store_dynamic_overlap
+// CHECK-COUNT-2: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @dont_forward_dead_store_dynamic_overlap(
+ %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
+ vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+ // The following transfer op writes to an overlapping range so we cannot forward.
+ vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+ %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+ %1 = arith.addf %acc, %acc : vector<4xf32>
+ scf.yield %1 : vector<4xf32>
+ }
+ vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+ return
+}
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_leading_dim
+// CHECK: vector.transfer_write
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_store_dynamic_non_overlap_leading_dim(
+ %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
+ vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to a non-overlapping range so we can forward.
+ vector.transfer_write %v0, %buffer[%i1, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+ %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+ %1 = arith.addf %acc, %acc : vector<4xf32>
+ scf.yield %1 : vector<4xf32>
+ }
+ vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+ return
+}
+
+// CHECK-LABEL: func @forward_dead_store_dynamic_non_overlap_trailing_dim
+// CHECK: vector.transfer_write
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
+ %buffer : memref<?x?xf32>, %v0 : vector<4xf32>, %v1 : vector<4xf32>, %i0 : index) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %i1 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
+ vector.transfer_write %v0, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+  // The following transfer op writes to a non-overlapping range so we can forward.
+ vector.transfer_write %v0, %buffer[%i0, %i1] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+ %0 = vector.transfer_read %buffer[%i0, %i0], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<4xf32>
+ %x = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<4xf32>) {
+ %1 = arith.addf %acc, %acc : vector<4xf32>
+ scf.yield %1 : vector<4xf32>
+ }
+ vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
+ return
+}
>From d6fb718f0560ddacfd4261ba3467062dc7e28f33 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sat, 7 Oct 2023 15:54:31 -0700
Subject: [PATCH 2/2] Add ValueBoundsOpInterface as dependency in BUILD
---
mlir/lib/Dialect/Vector/IR/CMakeLists.txt | 1 +
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 +
2 files changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 9ec919423b3428f..5954ae7557db89c 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -22,5 +22,6 @@ add_mlir_dialect_library(MLIRVectorDialect
MLIRMemRefDialect
MLIRSideEffectInterfaces
MLIRTensorDialect
+ MLIRValueBoundsOpInterface
MLIRVectorInterfaces
)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 1dfba7de465a5ae..3729f523f7d1419 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4335,6 +4335,7 @@ cc_library(
":SideEffectInterfaces",
":Support",
":TensorDialect",
+ ":ValueBoundsOpInterface",
":VectorInterfaces",
":VectorAttributesIncGen",
":VectorDialectIncGen",
More information about the Mlir-commits
mailing list