[Mlir-commits] [mlir] e54236d - [mlir][vector] Cast away leading one dims for insert ops

Lei Zhang llvmlistbot at llvm.org
Thu Apr 14 06:00:51 PDT 2022


Author: Lei Zhang
Date: 2022-04-14T08:57:32-04:00
New Revision: e54236dfb5982bc8358bad62a27e6048f06a0272

URL: https://github.com/llvm/llvm-project/commit/e54236dfb5982bc8358bad62a27e6048f06a0272
DIFF: https://github.com/llvm/llvm-project/commit/e54236dfb5982bc8358bad62a27e6048f06a0272.diff

LOG: [mlir][vector] Cast away leading one dims for insert ops

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D123621

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index d555c60439f71..0688a405491ab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
@@ -37,7 +38,7 @@ static SmallVector<int64_t> splatZero(int64_t rank) {
 namespace {
 
 // Casts away leading one dimensions in vector.extract_strided_slice's vector
-// input by inserting vector.shape_cast.
+// input by inserting vector.broadcast.
 struct CastAwayExtractStridedSliceLeadingOneDim
     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -84,8 +85,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
   }
 };
 
-// Casts away leading one dimensions in vector.extract_strided_slice's vector
-// inputs by inserting vector.shape_cast.
+// Casts away leading one dimensions in vector.insert_strided_slice's vector
+// inputs by inserting vector.broadcast.
 struct CastAwayInsertStridedSliceLeadingOneDim
     : public OpRewritePattern<vector::InsertStridedSliceOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -125,6 +126,61 @@ struct CastAwayInsertStridedSliceLeadingOneDim
   }
 };
 
+// Casts away leading one dimensions in vector.insert's vector inputs by
+// trimming both operands with vector.extract, inserting on the trimmed types,
+// and restoring the original destination type with vector.broadcast.
+struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // A scalar source has rank 0 and has nothing to trim.
+    Type oldSrcType = insertOp.getSourceType();
+    int64_t oldSrcRank = 0, newSrcRank = 0;
+    if (auto type = oldSrcType.dyn_cast<VectorType>()) {
+      oldSrcRank = type.getRank();
+      newSrcRank = trimLeadingOneDims(type).getRank();
+    }
+
+    VectorType oldDstType = insertOp.getDestVectorType();
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
+
+    int64_t srcDropCount = oldSrcRank - newSrcRank;
+    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
+    if (srcDropCount == 0 && dstDropCount == 0)
+      return failure();
+
+    // Trim leading one dimensions from both operands.
+    Location loc = insertOp.getLoc();
+    Value newSrcVector = insertOp.getSource();
+    if (oldSrcRank != 0) {
+      newSrcVector = rewriter.create<vector::ExtractOp>(
+          loc, insertOp.getSource(), splatZero(srcDropCount));
+    }
+    Value newDstVector = rewriter.create<vector::ExtractOp>(
+        loc, insertOp.getDest(), splatZero(dstDropCount));
+
+    // New position: (1) drop leading entries that index the trimmed dest dims,
+    // then (2) append zeros for the unit dims trimmed from the source. E.g.
+    // vector<1x4xf32> into vector<1x8x1x4xf32> at [0, 5] becomes
+    // vector<4xf32> into vector<8x1x4xf32> at [5, 0].
+    ArrayRef<Attribute> oldPositions = insertOp.getPosition().getValue();
+    int64_t oldPosRank = oldPositions.size();
+    int64_t keptPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
+    SmallVector<Attribute> newPositions =
+        llvm::to_vector(oldPositions.take_back(keptPosRank));
+    newPositions.resize(newDstType.getRank() - newSrcRank,
+                        rewriter.getZeroAttr(rewriter.getI64Type()));
+
+    auto newInsertOp = rewriter.create<vector::InsertOp>(
+        loc, newDstType, newSrcVector, newDstVector,
+        rewriter.getArrayAttr(newPositions));
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
+                                                     newInsertOp);
+    return success();
+  }
+};
+
 // Turns vector.transfer_read on vector with leading 1 dimensions into
 // vector.shape_cast followed by vector.transfer_read on vector without leading
 // 1 dimensions.
@@ -383,7 +439,7 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<CastAwayExtractStridedSliceLeadingOneDim,
-           CastAwayInsertStridedSliceLeadingOneDim,
+           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
            CastAwayTransferReadLeadingOneDim,
            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
            CastAwayContractionLeadingOneDim>(patterns.getContext());

diff  --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 70beb0fe43c5f..9115eda439cf4 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -265,3 +265,42 @@ func @cast_away_elementwise_leading_one_dims(
   return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
 }
 
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
+//  CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
+//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<1x1x4xf32>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<4xf32>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[BCAST]]
+func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+  %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x4xf32>
+  return %0: vector<1x1x4xf32>
+}
+// Source already has the trimmed dest shape: only a broadcast of it remains.
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
+//  CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[BCAST]]
+func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+  %0 = vector.insert %s, %v [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0: vector<1x1x4xf32>
+}
+// Both operands have leading unit dims: the source is extracted, then broadcast.
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
+//  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
+//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[BCAST]]
+func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+  %0 = vector.insert %s, %v [0] : vector<1x4xf32> into vector<1x1x4xf32>
+  return %0: vector<1x1x4xf32>
+}
+// Dest has no leading unit dims: only the source is trimmed; [5] -> [5, 0].
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
+//  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
+//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
+//       CHECK:   return %[[INSERT]]
+func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %v: vector<8x1x4xf32>) -> vector<8x1x4xf32> {
+  %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
+  return %0: vector<8x1x4xf32>
+}


        


More information about the Mlir-commits mailing list