[Mlir-commits] [mlir] 37a867a - [vector] When trimming leading insertion dimensions, base the final result on the ranks
Benjamin Kramer
llvmlistbot at llvm.org
Tue Apr 18 09:50:31 PDT 2023
Author: Benjamin Kramer
Date: 2023-04-18T18:49:29+02:00
New Revision: 37a867a5a88871f937256d6cf1248eddabd8925e
URL: https://github.com/llvm/llvm-project/commit/37a867a5a88871f937256d6cf1248eddabd8925e
DIFF: https://github.com/llvm/llvm-project/commit/37a867a5a88871f937256d6cf1248eddabd8925e.diff
LOG: [vector] When trimming leading insertion dimensions, base the final result on the ranks
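This was incorrect when the number of leading unit dims dropped from the
source was smaller than the number dropped from the destination. Zero
indices must still be appended to the insertion position whenever anything
is dropped from the source, so the position length is now derived from the
difference between the new destination and source ranks.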
Differential Revision: https://reviews.llvm.org/D148636
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 58fe63687b4fd..849e0442bc7e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -169,10 +169,8 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
SmallVector<Attribute> newPositions = llvm::to_vector(
insertOp.getPosition().getValue().take_back(newPosRank));
- if (srcDropCount >= dstDropCount) {
- auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
- newPositions.resize(newPosRank + srcDropCount, zeroAttr);
- }
+ newPositions.resize(newDstType.getRank() - newSrcRank,
+ rewriter.getI64IntegerAttr(0));
auto newInsertOp = rewriter.create<vector::InsertOp>(
loc, newDstType, newSrcVector, newDstVector,
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 587633ac20349..0ee006e3df632 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -316,3 +316,15 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %
%0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
return %0: vector<8x1x4xf32>
}
+
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
+// CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
+// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x8xi1>
+// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0, 0] : vector<1x1x8x1x8xi1>
+// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<8xi1> into vector<8x1x8xi1>
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1>
+// CHECK: return %[[BCAST]]
+func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v: vector<1x1x8x1x8xi1>) -> vector<1x1x8x1x8xi1> {
+ %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1>
+ return %0: vector<1x1x8x1x8xi1>
+}
More information about the Mlir-commits
mailing list