[Mlir-commits] [mlir] 942b403 - [mlir] Fix casting of leading unit dims for vector.insert
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 31 05:12:48 PDT 2023
Author: tyb0807
Date: 2023-03-31T12:12:35Z
New Revision: 942b403ff1a412778c9fb97bd53b44e35b544b0e
URL: https://github.com/llvm/llvm-project/commit/942b403ff1a412778c9fb97bd53b44e35b544b0e
DIFF: https://github.com/llvm/llvm-project/commit/942b403ff1a412778c9fb97bd53b44e35b544b0e.diff
LOG: [mlir] Fix casting of leading unit dims for vector.insert
When dropping the leading unit dims of vector.insert's operands and creating
a new vector.insert, the position rank of the new vector.insert should be
computed explicitly in two steps: first based on the number of leading unit
dims dropped from the original vector.insert's destination (the position is
trimmed accordingly), then based on the number of leading unit dims dropped
from its source (zeroes are appended to the position accordingly).
Reviewed By: pifon2a
Differential Revision: https://reviews.llvm.org/D147280
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 3715cf0cdcbd4..58fe63687b4fd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -161,13 +161,17 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
Value newDstVector = rewriter.create<vector::ExtractOp>(
loc, insertOp.getDest(), splatZero(dstDropCount));
+ // New position rank needs to be computed in two steps: (1) if destination
+ // type has leading unit dims, we also trim the position array accordingly,
+ // then (2) if source type also has leading unit dims, we need to append
+ // zeroes to the position array accordingly.
unsigned oldPosRank = insertOp.getPosition().getValue().size();
- unsigned newPosRank = newDstType.getRank() - newSrcRank;
+ unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
SmallVector<Attribute> newPositions = llvm::to_vector(
insertOp.getPosition().getValue().take_back(newPosRank));
- if (newPosRank > oldPosRank) {
+ if (srcDropCount >= dstDropCount) {
auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
- newPositions.resize(newPosRank, zeroAttr);
+ newPositions.resize(newPosRank + srcDropCount, zeroAttr);
}
auto newInsertOp = rewriter.create<vector::InsertOp>(
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 309ca33e9fdd1..587633ac20349 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -295,6 +295,18 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect
return %0: vector<1x1x4xf32>
}
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
+// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
+// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
+// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0] : vector<1x2x1x4xf32>
+// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
+// CHECK: return %[[BCAST]]
+func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, %v: vector<1x2x1x4xf32>) -> vector<1x2x1x4xf32> {
+ %0 = vector.insert %s, %v [0, 1] : vector<1x4xf32> into vector<1x2x1x4xf32>
+ return %0: vector<1x2x1x4xf32>
+}
+
// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
More information about the Mlir-commits mailing list