[Mlir-commits] [mlir] cc83c24 - [mlir][vector] Add canonicalization extract + splat

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 13 08:08:52 PDT 2021


Author: thomasraoux
Date: 2021-10-13T08:08:46-07:00
New Revision: cc83c2444f8a3bdb7f4b92f52a14b53792346057

URL: https://github.com/llvm/llvm-project/commit/cc83c2444f8a3bdb7f4b92f52a14b53792346057
DIFF: https://github.com/llvm/llvm-project/commit/cc83c2444f8a3bdb7f4b92f52a14b53792346057.diff

LOG: [mlir][vector] Add canonicalization extract + splat

Make the canonicalizations that work on BroadcastOp also work on SplatOp.

Differential Revision: https://reviews.llvm.org/D111690

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index b1168d084b802..343bc3e3e5100 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1094,17 +1094,18 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
   return Value();
 }
 
-/// Fold extractOp with scalar result coming from BroadcastOp.
+/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
-  if (!broadcastOp)
+  Operation *defOp = extractOp.vector().getDefiningOp();
+  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
-  if (extractOp.getType() == broadcastOp.getSourceType())
-    return broadcastOp.source();
+  Value source = defOp->getOperand(0);
+  if (extractOp.getType() == source.getType())
+    return source;
   auto getRank = [](Type type) {
     return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
   };
-  unsigned broadcasrSrcRank = getRank(broadcastOp.getSourceType());
+  unsigned broadcasrSrcRank = getRank(source.getType());
   unsigned extractResultRank = getRank(extractOp.getType());
   if (extractResultRank < broadcasrSrcRank) {
     auto extractPos = extractVector<int64_t>(extractOp.position());
@@ -1112,7 +1113,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     extractPos.erase(
         extractPos.begin(),
         std::next(extractPos.begin(), extractPos.size() - rankDiff));
-    extractOp.setOperand(broadcastOp.source());
+    extractOp.setOperand(source);
     // OpBuilder is only used as a helper to build an I64ArrayAttr.
     OpBuilder b(extractOp.getContext());
     extractOp->setAttr(ExtractOp::getPositionAttrName(),
@@ -2259,6 +2260,21 @@ class StridedSliceBroadcast final
   }
 };
 
+/// Rewrites extract_strided_slice(splat(x)) as splat(x) of the result type.
+class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (auto splatOp = op.vector().getDefiningOp<SplatOp>()) {
+      rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splatOp.input());
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // end anonymous namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
@@ -2266,7 +2282,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
   results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
-              StridedSliceBroadcast>(context);
+              StridedSliceBroadcast, StridedSliceSplat>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index d6e0a16ca4f69..6c232cc0f7ab5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -462,6 +462,17 @@ func @fold_extract_broadcast(%a : f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: fold_extract_splat
+//  CHECK-SAME:   %[[A:.*]]: f32
+//       CHECK:   return %[[A]] : f32
+func @fold_extract_splat(%scalar : f32) -> f32 {
+  %vec = splat %scalar : vector<1x2x4xf32>
+  %elem = vector.extract %vec[0, 1, 2] : vector<1x2x4xf32>
+  return %elem : f32
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_vector
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
 //       CHECK:   return %[[A]] : vector<4xf32>
@@ -1047,3 +1058,16 @@ func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<
   // CHECK: return %[[SOURCE]]
   return %0: vector<16x16xf16>
 }
+
+// -----
+
+// CHECK-LABEL: extract_strided_splat
+//  CHECK-SAME:   %[[A:.*]]: f16
+//       CHECK:   %[[B:.*]] = splat %[[A]] : vector<2x4xf16>
+//  CHECK-NEXT:   return %[[B]] : vector<2x4xf16>
+func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
+  %0 = splat %arg0 : vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0 {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
+    : vector<16x4xf16> to vector<2x4xf16>
+  return %1 : vector<2x4xf16>
+}


        


More information about the Mlir-commits mailing list