[Mlir-commits] [mlir] d0453a8 - [mlir][vector] Extend the pattern that trims leading unit dimensions to SplatOp

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 7 13:56:05 PDT 2021


Author: thomasraoux
Date: 2021-05-07T13:54:41-07:00
New Revision: d0453a8933a14c9441b2d89e6f934bd1bc243200

URL: https://github.com/llvm/llvm-project/commit/d0453a8933a14c9441b2d89e6f934bd1bc243200
DIFF: https://github.com/llvm/llvm-project/commit/d0453a8933a14c9441b2d89e6f934bd1bc243200.diff

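LOG: [mlir][vector] Extend the pattern that trims leading unit dimensions to SplatOp

CastAwayBroadcastLeadingOneDim is now templated on the op type and registered for both vector::BroadcastOp and SplatOp. For example, `splat %x : vector<1x1x4xf32>` becomes `splat %x : vector<4xf32>` followed by a `vector.shape_cast` back to `vector<1x1x4xf32>`.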

Differential Revision: https://reviews.llvm.org/D102091

Added: 
    (none)

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transforms.mlir

Removed: 
    (none)


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 15a211bcec66..83d9ae1832fc 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3175,27 +3175,31 @@ struct CastAwayTransferWriteLeadingOneDim
   }
 };
 
-struct CastAwayBroadcastLeadingOneDim
-    : public OpRewritePattern<vector::BroadcastOp> {
-  using OpRewritePattern::OpRewritePattern;
+template <typename BroadCastType>
+struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern<BroadCastType> {
+  using OpRewritePattern<BroadCastType>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+  LogicalResult matchAndRewrite(BroadCastType broadcastOp,
                                 PatternRewriter &rewriter) const override {
-    VectorType newDstType = trimLeadingOneDims(broadcastOp.getVectorType());
-    if (newDstType == broadcastOp.getVectorType())
+    VectorType dstType =
+        broadcastOp.getResult().getType().template dyn_cast<VectorType>();
+    if (!dstType)
+      return failure();
+    VectorType newDstType = trimLeadingOneDims(dstType);
+    if (newDstType == dstType)
       return failure();
     Location loc = broadcastOp.getLoc();
-    VectorType srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
+    Value source = broadcastOp->getOperand(0);
+    VectorType srcVecType = source.getType().template dyn_cast<VectorType>();
     if (srcVecType)
       srcVecType = trimLeadingOneDims(srcVecType);
-    Value source = broadcastOp.source();
-    if (srcVecType && srcVecType != broadcastOp.getSourceType()) {
+    if (srcVecType && srcVecType != source.getType()) {
       source = rewriter.create<vector::ShapeCastOp>(loc, srcVecType, source);
     }
     Value newBroadcastOp =
-        rewriter.create<vector::BroadcastOp>(loc, newDstType, source);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        broadcastOp, broadcastOp.getVectorType(), newBroadcastOp);
+        rewriter.create<BroadCastType>(loc, newDstType, source);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcastOp, dstType,
+                                                     newBroadcastOp);
     return success();
   }
 };
@@ -3833,13 +3837,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
 
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<CastAwayExtractStridedSliceLeadingOneDim,
-           CastAwayInsertStridedSliceLeadingOneDim,
-           CastAwayTransferReadLeadingOneDim,
-           CastAwayTransferWriteLeadingOneDim, CastAwayBroadcastLeadingOneDim,
-           CastAwayElementwiseLeadingOneDim, ShapeCastOpFolder>(
-          patterns.getContext());
+  patterns.add<
+      CastAwayExtractStridedSliceLeadingOneDim,
+      CastAwayInsertStridedSliceLeadingOneDim,
+      CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
+      CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
+      CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
+      ShapeCastOpFolder>(patterns.getContext());
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index abfed46e82f1..101afe1a4729 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -675,7 +675,7 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
 // CHECK-LABEL: func @cast_away_broadcast_leading_one_dims
 func @cast_away_broadcast_leading_one_dims(
   %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
-  (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>) {
+  (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
   // CHECK:  vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32>
   // CHECK:  vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
   %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
@@ -686,7 +686,10 @@ func @cast_away_broadcast_leading_one_dims(
   // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<3x4xf32>
   // CHECK:  vector.shape_cast %{{.*}} : vector<3x4xf32> to vector<1x3x4xf32>
   %2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32>
-  return %0, %1, %2: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>
+  // CHECK:  splat %{{.*}} : vector<4xf32>
+  // CHECK:  vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
+  %3 = splat %arg1 : vector<1x1x4xf32>
+  return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>
 }
 
 // CHECK-LABEL: func @cast_away_elementwise_leading_one_dims


        


More information about the Mlir-commits mailing list