[Mlir-commits] [mlir] 0b17336 - [mlir][Vector] Make vector.shape_cast based size-1 foldings opt-in and separate.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 15 13:21:48 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-15T21:17:57Z
New Revision: 0b17336f793108a7b10c3fa913039144ef1d0f61

URL: https://github.com/llvm/llvm-project/commit/0b17336f793108a7b10c3fa913039144ef1d0f61
DIFF: https://github.com/llvm/llvm-project/commit/0b17336f793108a7b10c3fa913039144ef1d0f61.diff

LOG: [mlir][Vector] Make vector.shape_cast based size-1 foldings opt-in and separate.

This is in preparation for dropping them altogether and using insert/extract-based patterns instead.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D113928

Added: 
    mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 1671cc708110e..0ed4b8370cb1d 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -56,6 +56,9 @@ isBroadcastableTo(Type srcType, VectorType dstVectorType,
 void populateVectorToVectorCanonicalizationPatterns(
     RewritePatternSet &patterns);
 
+/// Collect a set of vector.shape_cast folding patterns.
+void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of leading one dimension removal patterns.
 ///
 /// These patterns insert vector.shape_cast to remove leading one dimensions

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 80b4e606c6ff2..5b4f1abc570ed 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1215,29 +1215,12 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
 
 namespace {
 
-// If extractOp is only removing unit dimensions it can be transformed to a
-// shapecast.
-class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern<ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
-    if (!dstVecType || extractOp.getVectorType().getNumElements() !=
-                           dstVecType.getNumElements())
-      return failure();
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
-                                             extractOp.vector());
-    return success();
-  }
-};
-
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractToShapeCast>(context);
+  // ExtractToShapeCast is not a default canonicalization, it is opt-in by
+  // calling `populateCastAwayVectorLeadingOneDimPatterns`
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -1401,27 +1384,6 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
 
 namespace {
 
-// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
-// the degenerated case where the broadcast only adds dimensions of size 1 it
-// can be replaced by a ShapeCastOp. This canonicalization checks if the total
-// number of elements is the same before and after the broadcast to detect if
-// the only change in the vector type are new dimensions of size 1.
-class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
-public:
-  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
-    if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
-                           srcVecType.getNumElements())
-      return failure();
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(
-        broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
-    return success();
-  }
-};
-
 // Fold broadcast1(broadcast2(x)) into broadcast1(x).
 struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
@@ -1440,7 +1402,9 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<BroadcastToShapeCast, BroadcastFolder>(context);
+  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
+  // calling `populateCastAwayVectorLeadingOneDimPatterns`
+  results.add<BroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1605,31 +1569,10 @@ static LogicalResult verify(InsertOp op) {
   return success();
 }
 
-namespace {
-
-// If insertOp is only inserting unit dimensions it can be transformed to a
-// shapecast.
-class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern<InsertOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
-    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
-                           srcVecType.getNumElements())
-      return failure();
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(
-        insertOp, insertOp.getDestVectorType(), insertOp.source());
-    return success();
-  }
-};
-
-} // namespace
-
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToShapeCast>(context);
+  // InsertToShapeCast is not a default canonicalization, it is opt-in by
+  // calling `populateCastAwayVectorLeadingOneDimPatterns`
 }
 
 // Eliminates insert operations that produce values identical to their source

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 3fb6d4c50e9b4..cf40e4f272609 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1113,11 +1113,18 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     Location loc = op.getLoc();
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-    // Intended 2D/1D lowerings with better implementations.
+
+    // Special case 2D/1D lowerings with better implementations.
+    // TODO: make it ND/1D to allow generic ND->1D->MD.
     int64_t srcRank = sourceVectorType.getRank();
     int64_t resRank = resultVectorType.getRank();
     if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
       return failure();
+
+    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
+    // extract/insert chains.
+    // TODO: consider evolving the semantics to only allow 1D source or dest and
+    // drop this potentially very expensive lowering.
     // Compute number of elements involved in the reshape.
     int64_t numElts = 1;
     for (int64_t r = 0; r < srcRank; r++)
@@ -3177,6 +3184,63 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
   }
 };
 
+// If extractOp is only removing unit dimensions it can be transformed to a
+// shapecast.
+class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
+    if (!dstVecType || extractOp.getVectorType().getNumElements() !=
+                           dstVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
+                                             extractOp.vector());
+    return success();
+  }
+};
+
+// If insertOp is only inserting unit dimensions it can be transformed to a
+// shapecast.
+class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
+                           srcVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        insertOp, insertOp.getDestVectorType(), insertOp.source());
+    return success();
+  }
+};
+
+// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
+// the degenerated case where the broadcast only adds dimensions of size 1 it
+// can be replaced by a ShapeCastOp. This canonicalization checks if the total
+// number of elements is the same before and after the broadcast to detect if
+// the only change in the vector type are new dimensions of size 1.
+class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
+public:
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
+                           srcVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
+    return success();
+  }
+};
+
 // Returns the values in `arrayAttr` as an integer vector.
 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(
@@ -3651,16 +3715,21 @@ void mlir::vector::populatePropagateVectorDistributionPatterns(
       patterns.getContext());
 }
 
+void mlir::vector::populateShapeCastFoldingPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ShapeCastOpFolder>(patterns.getContext());
+}
+
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
-               CastAwayInsertStridedSliceLeadingOneDim,
-               CastAwayTransferReadLeadingOneDim,
-               CastAwayTransferWriteLeadingOneDim,
-               CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
-               CastAwayBroadcastLeadingOneDim<SplatOp>,
-               CastAwayElementwiseLeadingOneDim, ShapeCastOpFolder>(
-      patterns.getContext());
+  patterns.add<
+      BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim,
+      CastAwayInsertStridedSliceLeadingOneDim,
+      CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
+      CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
+      CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
+      ExtractToShapeCast, InsertToShapeCast>(patterns.getContext());
+  populateShapeCastFoldingPatterns(patterns);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index cf05308e8129b..3d60745b8ccd7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -717,16 +717,6 @@ func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
 
 // -----
 
-// CHECK-LABEL: broadcast_to_shapecast
-//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
-//  CHECK-NEXT:   return %[[C]] : vector<1x4x4xf16>
-func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
-  %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
-  return %0 : vector<1x4x4xf16>
-}
-
-// -----
-
 // CHECK-LABEL: func @dead_transfer_op
 //   CHECK-NOT:   vector.transfer_read
 //   CHECK-NOT:   vector.transfer_write
@@ -971,20 +961,6 @@ func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @insert_extract_to_shapecast
-//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
-//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
-//       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
-  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
-  %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
-  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
-  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @transfer_read_of_extract_slice(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
 //   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index

diff --git a/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir b/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
new file mode 100644
index 0000000000000..5dd44d38ccb7e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
+
+// CHECK-LABEL: broadcast_to_shapecast
+//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
+//  CHECK-NEXT:   return %[[C]] : vector<1x4x4xf16>
+func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
+  %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
+  return %0 : vector<1x4x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_shapecast
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
+  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+  %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
+  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}

More information about the Mlir-commits mailing list