[Mlir-commits] [mlir] 942b403 - [mlir] Fix casting of leading unit dims for vector.insert

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 31 05:12:48 PDT 2023


Author: tyb0807
Date: 2023-03-31T12:12:35Z
New Revision: 942b403ff1a412778c9fb97bd53b44e35b544b0e

URL: https://github.com/llvm/llvm-project/commit/942b403ff1a412778c9fb97bd53b44e35b544b0e
DIFF: https://github.com/llvm/llvm-project/commit/942b403ff1a412778c9fb97bd53b44e35b544b0e.diff

LOG: [mlir] Fix casting of leading unit dims for vector.insert

When dropping leading unit dims of vector.insert's operands and creating
a new vector.insert, the new op's position rank should be computed
explicitly in two steps: first based on the number of leading unit dims
dropped from the vector.insert's destination, then based on the number
of leading unit dims dropped from its source.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D147280

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 3715cf0cdcbd4..58fe63687b4fd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -161,13 +161,17 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
     Value newDstVector = rewriter.create<vector::ExtractOp>(
         loc, insertOp.getDest(), splatZero(dstDropCount));
 
+    // New position rank needs to be computed in two steps: (1) if destination
+    // type has leading unit dims, we also trim the position array accordingly,
+    // then (2) if source type also has leading unit dims, we need to append
+    // zeroes to the position array accordingly.
     unsigned oldPosRank = insertOp.getPosition().getValue().size();
-    unsigned newPosRank = newDstType.getRank() - newSrcRank;
+    unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
     SmallVector<Attribute> newPositions = llvm::to_vector(
         insertOp.getPosition().getValue().take_back(newPosRank));
-    if (newPosRank > oldPosRank) {
+    if (srcDropCount >= dstDropCount) {
       auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
-      newPositions.resize(newPosRank, zeroAttr);
+      newPositions.resize(newPosRank + srcDropCount, zeroAttr);
     }
 
     auto newInsertOp = rewriter.create<vector::InsertOp>(

diff  --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 309ca33e9fdd1..587633ac20349 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -295,6 +295,18 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect
   return %0: vector<1x1x4xf32>
 }
 
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
+//  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
+//       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
+//       CHECK:   %[[EXTRACTV:.+]] = vector.extract %[[V]][0] : vector<1x2x1x4xf32>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
+//       CHECK:   return %[[BCAST]]
+func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, %v: vector<1x2x1x4xf32>) -> vector<1x2x1x4xf32> {
+  %0 = vector.insert %s, %v [0, 1] : vector<1x4xf32> into vector<1x2x1x4xf32>
+  return %0: vector<1x2x1x4xf32>
+}
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
 //       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>


        


More information about the Mlir-commits mailing list