[llvm-branch-commits] [mlir] 7418688 - [mlir][vector] Add more vector Ops canonicalization

Thomas Raoux via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Dec 23 11:30:17 PST 2020


Author: Thomas Raoux
Date: 2020-12-23T11:25:01-08:00
New Revision: 74186880ba99b37c0375e9d87df818beee8b4ff2

URL: https://github.com/llvm/llvm-project/commit/74186880ba99b37c0375e9d87df818beee8b4ff2
DIFF: https://github.com/llvm/llvm-project/commit/74186880ba99b37c0375e9d87df818beee8b4ff2.diff

LOG: [mlir][vector] Add more vector Ops canonicalization

Add canonicalization for BroadcastOp, ExtractStridedSliceOp and ShapeCastOp

Differential Revision: https://reviews.llvm.org/D93120

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 13aba2076ee9..e031f87cfb8e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -271,6 +271,7 @@ def Vector_BroadcastOp :
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1; // See BroadcastOp::getCanonicalizationPatterns.
 }
 
 def Vector_ShuffleOp :

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index a3ad355d30b2..539e00d58dbf 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1110,6 +1110,36 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+namespace {
+
+// BroadcastOp can only add dimensions or stretch a dimension from 1 to N. In
+// the degenerate case where it only adds dimensions of size 1, it can be
+// replaced by a ShapeCastOp. This case is detected by comparing the element
+// counts, since stretching a dimension to N > 1 increases the element count.
+// Scalar (non-vector) sources are not rewritten.
+class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
+public:
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
+                           srcVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
+void BroadcastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<BroadcastToShapeCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ShuffleOp
 //===----------------------------------------------------------------------===//
@@ -1768,7 +1798,8 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
 
 namespace {
 
-// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
+// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to a
+// ConstantMaskOp.
 class StridedSliceConstantMaskFolder final
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
@@ -1847,14 +1878,79 @@ class StridedSliceConstantFolder final
   }
 };
 
+// Helper that returns a subset of `arrayAttr` as a vector of int64_t, dropping
+// the first `dropFront` and the last `dropBack` elements. The result may be
+// empty.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+                                              unsigned dropFront = 0,
+                                              unsigned dropBack = 0) {
+  assert(arrayAttr.size() >= dropFront + dropBack && "Out of bounds");
+  auto range = arrayAttr.getAsRange<IntegerAttr>();
+  SmallVector<int64_t, 4> res;
+  res.reserve(arrayAttr.size() - dropFront - dropBack);
+  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+       it != eit; ++it)
+    res.push_back((*it).getValue().getSExtValue());
+  return res;
+}
+
+// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
+// BroadcastOp(ExtractStridedSliceOp), or directly to a BroadcastOp when the
+// slice keeps every element of the broadcast source.
+class StridedSliceBroadcast final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
+    if (!broadcast)
+      return failure();
+    Value source = broadcast.source();
+    // A scalar source has no dimensions to slice: broadcast it directly.
+    auto srcVecType = source.getType().dyn_cast<VectorType>();
+    if (srcVecType) {
+      unsigned srcRank = srcVecType.getRank();
+      unsigned dstRank = op.getType().cast<VectorType>().getRank();
+      unsigned rankDiff = dstRank - srcRank;
+      // `offsets`, `sizes` and `strides` may cover only the leading
+      // dimensions; the remaining dimensions are taken whole.
+      SmallVector<int64_t, 4> offsets = getI64SubArray(op.offsets());
+      SmallVector<int64_t, 4> sizes = getI64SubArray(op.sizes());
+      SmallVector<int64_t, 4> strides = getI64SubArray(op.strides());
+      // Slice of the broadcast source that feeds the result. Source dimensions
+      // of size 1 are stretched by the broadcast, so they are kept whole here
+      // and stretched again by the new broadcast.
+      SmallVector<int64_t, 4> srcOffsets, srcSizes, srcStrides;
+      bool slicesSource = false;
+      for (unsigned i = 0; i < srcRank; ++i) {
+        int64_t srcDim = srcVecType.getDimSize(i);
+        unsigned dstIdx = i + rankDiff;
+        bool sliced = srcDim != 1 && dstIdx < sizes.size();
+        srcOffsets.push_back(sliced ? offsets[dstIdx] : 0);
+        srcSizes.push_back(sliced ? sizes[dstIdx] : srcDim);
+        srcStrides.push_back(sliced ? strides[dstIdx] : 1);
+        slicesSource |= srcSizes.back() != srcDim;
+      }
+      // Only extract from the source when the slice actually narrows it.
+      if (slicesSource)
+        source = rewriter.create<ExtractStridedSliceOp>(
+            op->getLoc(), source, srcOffsets, srcSizes, srcStrides);
+    }
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
-      context);
+  results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
+                 StridedSliceBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2652,10 +2748,29 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
     return source();
 
   // Canceling shape casts.
-  if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
+  if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
     if (result().getType() == otherOp.source().getType())
       return otherOp.source();
 
+    // Only fold A -> B -> C into A -> C when the direct cast is itself a pure
+    // expansion or collapse accepted by isValidShapeCast; otherwise keep both
+    // casts. Equal-rank casts and non-vector types are not folded.
+    auto srcType = otherOp.source().getType().dyn_cast<VectorType>();
+    auto resultType = result().getType().dyn_cast<VectorType>();
+    if (!srcType || !resultType)
+      return {};
+    if (srcType.getRank() < resultType.getRank()) {
+      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
+        return {};
+    } else if (srcType.getRank() > resultType.getRank()) {
+      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
+        return {};
+    } else {
+      return {};
+    }
+    setOperand(otherOp.source());
+    return getResult();
+  }
   return {};
 }
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f07285d7d98c..f94c3bcce5be 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -613,4 +613,90 @@ func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
   return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
 }
 
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast
+//       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : vector<4xf16> to vector<2x4xf16>
+//  CHECK-NEXT:   return %[[B]] : vector<2x4xf16>
+func @extract_strided_broadcast(%arg0: vector<4xf16>) -> vector<2x4xf16> {
+  %0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} :
+    vector<16x4xf16> to vector<2x4xf16>
+  return %1 : vector<2x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast2
+//       CHECK:   %[[E:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf16> to vector<2xf16>
+//  CHECK-NEXT:   %[[B:.*]] = vector.broadcast %[[E]] : vector<2xf16> to vector<2x2xf16>
+//  CHECK-NEXT:   return %[[B]] : vector<2x2xf16>
+func @extract_strided_broadcast2(%arg0: vector<4xf16>) -> vector<2x2xf16> {
+  %0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
+    vector<16x4xf16> to vector<2x2xf16>
+  return %1 : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast3
+//       CHECK:   %[[E:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf16> to vector<2x1xf16>
+//  CHECK-NEXT:   %[[B:.*]] = vector.broadcast %[[E]] : vector<2x1xf16> to vector<2x4xf16>
+//  CHECK-NEXT:   return %[[B]] : vector<2x4xf16>
+func @extract_strided_broadcast3(%arg0: vector<4x1xf16>) -> vector<2x4xf16> {
+  %0 = vector.broadcast %arg0 : vector<4x1xf16> to vector<4x8xf16>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} :
+    vector<4x8xf16> to vector<2x4xf16>
+  return %1 : vector<2x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast4
+//       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : vector<4x1xf16> to vector<1x4x8xf16>
+//  CHECK-NEXT:   return %[[B]] : vector<1x4x8xf16>
+func @extract_strided_broadcast4(%arg0: vector<4x1xf16>) -> vector<1x4x8xf16> {
+  %0 = vector.broadcast %arg0 : vector<4x1xf16> to vector<2x4x8xf16>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [1], sizes = [1], strides = [1]} :
+    vector<2x4x8xf16> to vector<1x4x8xf16>
+  return %1 : vector<1x4x8xf16>
+}
+
+// -----
+
+// CHECK-LABEL: consecutive_shape_cast
+//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
+//  CHECK-NEXT:   return %[[C]] : vector<4x4xf16>
+func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+  %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
+  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
+  return %1 : vector<4x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: dont_fold_consecutive_shape_cast
+//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<2x6xf16> to vector<12xf16>
+//  CHECK-NEXT:   %[[D:.*]] = vector.shape_cast %[[C]] : vector<12xf16> to vector<3x2x2xf16>
+//  CHECK-NEXT:   return %[[D]] : vector<3x2x2xf16>
+func @dont_fold_consecutive_shape_cast(%arg0: vector<2x6xf16>) -> vector<3x2x2xf16> {
+  %0 = vector.shape_cast %arg0 : vector<2x6xf16> to vector<12xf16>
+  %1 = vector.shape_cast %0 : vector<12xf16> to vector<3x2x2xf16>
+  return %1 : vector<3x2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_to_shapecast
+//       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
+//  CHECK-NEXT:   return %[[C]] : vector<1x4x4xf16>
+func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
+  %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
+  return %0 : vector<1x4x4xf16>
+}
 


        


More information about the llvm-branch-commits mailing list