[Mlir-commits] [mlir] 28fef90 - [mlir][VectorOps] Fix folding of vector.extract from stretch vector.broadcast
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Aug 4 06:33:10 PDT 2023
Author: Benjamin Maxwell
Date: 2023-08-04T13:32:33Z
New Revision: 28fef902fc9f872fcdcd95237f425f58c9609ab4
URL: https://github.com/llvm/llvm-project/commit/28fef902fc9f872fcdcd95237f425f58c9609ab4
DIFF: https://github.com/llvm/llvm-project/commit/28fef902fc9f872fcdcd95237f425f58c9609ab4.diff
LOG: [mlir][VectorOps] Fix folding of vector.extract from stretch vector.broadcast
Previously, foldExtractFromBroadcast() would incorrectly fold:
func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 {
%0 = vector.broadcast %src : vector<3x1x2xf32> to vector<3x4x2xf32>
%1 = vector.extract %0[0, 2, 0] : vector<3x4x2xf32>
return %1: f32
}
to the following, which is invalid because it extracts index 2 from a
dimension of size 1:
func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 {
%0 = vector.extract %src[0, 2, 0] : vector<3x1x2xf32>
return %0: f32
}
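This was caused by using the wrong starting offset when zeroing the extract
positions of "dim-1" broadcasted dims. The old offset was the source rank
minus the extract result rank. The correct offset is the rank difference
across the broadcast (destination rank minus source rank). Only the
dimensions from that offset onward correspond to source dimensions, and only
those can have been stretched from size 1.

In the example above, the old offset was 3 (source rank 3 minus scalar
result rank 0). The loop therefore zeroed nothing. The new offset is 0
(destination rank 3 minus source rank 3), so position 1, the stretched
dimension, is correctly reset to 0.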
Reviewed By: awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D157003
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 20bd3f32fac91c..094d555a0e8fc9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1467,18 +1467,21 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return Value();
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
- int64_t rankDiff = broadcastSrcRank - extractResultRank;
+ int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
+
// Detect all the positions that come from "dim-1" broadcasting.
// These dimensions correspond to "dim-1" broadcasted dims; set the mathching
// extract position to `0` when extracting from the source operand.
llvm::SetVector<int64_t> broadcastedUnitDims =
broadcastOp.computeBroadcastedUnitDims();
SmallVector<int64_t> extractPos(extractOp.getPosition());
- for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
+ int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
+ for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
if (broadcastedUnitDims.contains(i))
extractPos[i] = 0;
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
// matching extract position when extracting from the source operand.
+ int64_t rankDiff = broadcastSrcRank - extractResultRank;
extractPos.erase(extractPos.begin(),
std::next(extractPos.begin(), extractPos.size() - rankDiff));
// OpBuilder is only used as a helper to build an I64ArrayAttr.
@@ -4953,7 +4956,8 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
- if (auto attr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
+ if (auto attr =
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
if (attr.isSplat())
return attr.reshape(getResultVectorType());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d715f9acbb3c6d..126d8dbc4c1999 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2104,6 +2104,15 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
return %1: vector<1xf32>
}
+// CHECK-LABEL: func.func @extract_from_stretch_broadcast
+func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 {
+ // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0, 0] : vector<3x1x2xf32>
+ // CHECK-NEXT: return %0 : f32
+ %0 = vector.broadcast %src : vector<3x1x2xf32> to vector<3x4x2xf32>
+ %1 = vector.extract %0[0, 2, 0] : vector<3x4x2xf32>
+ return %1: f32
+}
+
// -----
// CHECK-LABEL: func.func @extract_strided_slice_of_constant_mask
func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{
More information about the Mlir-commits
mailing list