[Mlir-commits] [mlir] 7cde516 - [mlir][vector] NFC, move some vector patterns in a separate file

Thomas Raoux llvmlistbot at llvm.org
Fri Nov 19 10:40:11 PST 2021


Author: Thomas Raoux
Date: 2021-11-19T10:39:29-08:00
New Revision: 7cde5165131f1268a8506066275ef7938c58d156

URL: https://github.com/llvm/llvm-project/commit/7cde5165131f1268a8506066275ef7938c58d156
DIFF: https://github.com/llvm/llvm-project/commit/7cde5165131f1268a8506066275ef7938c58d156.diff

LOG: [mlir][vector] NFC, move some vector patterns into a separate file

Move the patterns related to dropping leading unit dimensions into their own file.

Differential Revision: https://reviews.llvm.org/D114265

Added: 
    mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp

Modified: 
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index abd961626e59..8f01eda3de4f 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRVector
+  VectorDropLeadUnitDim.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
   VectorMultiDimReductionTransforms.cpp
   VectorOps.cpp

diff  --git a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
new file mode 100644
index 000000000000..f00a4f808c69
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
@@ -0,0 +1,259 @@
+//===- VectorDropLeadUnitDim.cpp - Patterns dropping leading unit dims ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#define DEBUG_TYPE "vector-drop-unit-dim"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+// Returns `oldType` with its leading unit (size-1) dimensions removed, e.g.
+// vector<1x1x4x1xf32> -> vector<4x1xf32>. An all-ones shape such as
+// vector<1x1xf32> becomes vector<1xf32>.
+static VectorType trimLeadingOneDims(VectorType oldType) {
+  ArrayRef<int64_t> fullShape = oldType.getShape();
+  auto isUnitDim = [](int64_t extent) { return extent == 1; };
+  ArrayRef<int64_t> kept = fullShape.drop_while(isUnitDim);
+  // A VectorType must have at least one dimension, so keep the innermost one.
+  ArrayRef<int64_t> newShape = kept.empty() ? fullShape.take_back() : kept;
+  return VectorType::get(newShape, oldType.getElementType());
+}
+
+/// Returns `rank` zeros, e.g. the all-zero position operand of vector.extract.
+static SmallVector<int64_t> splatZero(int64_t rank) {
+  return SmallVector<int64_t>(static_cast<size_t>(rank), int64_t{0});
+}
+namespace {
+
+// Casts away leading unit dims of vector.extract_strided_slice's source: the
+// inner vector is taken with vector.extract, sliced with a lower-rank
+// extract_strided_slice, and the result is broadcast back to the old type.
+struct CastAwayExtractStridedSliceLeadingOneDim
+    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // vector.extract_strided_slice requires the input and output vector to have
+    // the same rank. Here we drop leading one dimensions from the input vector
+    // type to make sure we don't cause mismatch.
+    VectorType oldSrcType = extractOp.getVectorType();
+    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
+
+    if (newSrcType.getRank() == oldSrcType.getRank())
+      return failure();
+
+    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
+
+    VectorType oldDstType = extractOp.getType();
+    VectorType newDstType =
+        VectorType::get(oldDstType.getShape().drop_front(dropCount),
+                        oldDstType.getElementType());
+
+    Location loc = extractOp.getLoc();
+    Value newSrcVector = rewriter.create<vector::ExtractOp>(
+        loc, extractOp.vector(), splatZero(dropCount));
+
+    // offsets/sizes/strides may have fewer entries than the source rank (they
+    // describe leading dims only), so drop at most as many entries as exist;
+    // ArrayRef::drop_front asserts when asked to drop more than its size.
+    auto dropLeading = [&](ArrayAttr attr) {
+      ArrayRef<Attribute> vals = attr.getValue();
+      if (vals.size() <= static_cast<size_t>(dropCount))
+        return rewriter.getArrayAttr({});
+      return rewriter.getArrayAttr(vals.drop_front(dropCount));
+    };
+    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, newDstType, newSrcVector, dropLeading(extractOp.offsets()),
+        dropLeading(extractOp.sizes()), dropLeading(extractOp.strides()));
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
+                                                     newExtractOp);
+    return success();
+  }
+};
+
+// Casts away leading unit dims of vector.insert_strided_slice's operands: the
+// source and dest are narrowed with vector.extract, the insertion happens on
+// the trimmed types, and the result is broadcast back to the dest type.
+struct CastAwayInsertStridedSliceLeadingOneDim
+    : public OpRewritePattern<vector::InsertStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType dstType = insertOp.getDestVectorType();
+    VectorType trimmedSrcType = trimLeadingOneDims(srcType);
+    VectorType trimmedDstType = trimLeadingOneDims(dstType);
+    int64_t srcDropped = srcType.getRank() - trimmedSrcType.getRank();
+    int64_t dstDropped = dstType.getRank() - trimmedDstType.getRank();
+
+    // Only fire when at least one operand actually has leading unit dims.
+    if (srcDropped == 0 && dstDropped == 0)
+      return failure();
+
+    // Keep the trailing offsets matching the trimmed dest rank and the
+    // trailing strides matching the trimmed source rank.
+    ArrayAttr trimmedOffsets = rewriter.getArrayAttr(
+        insertOp.offsets().getValue().take_back(trimmedDstType.getRank()));
+    ArrayAttr trimmedStrides = rewriter.getArrayAttr(
+        insertOp.strides().getValue().take_back(trimmedSrcType.getRank()));
+
+    Location loc = insertOp.getLoc();
+    Value narrowSrc = rewriter.create<vector::ExtractOp>(
+        loc, insertOp.source(), splatZero(srcDropped));
+    Value narrowDst = rewriter.create<vector::ExtractOp>(
+        loc, insertOp.dest(), splatZero(dstDropped));
+    Value trimmedInsert = rewriter.create<vector::InsertStridedSliceOp>(
+        loc, trimmedDstType, narrowSrc, narrowDst, trimmedOffsets,
+        trimmedStrides);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, dstType,
+                                                     trimmedInsert);
+    return success();
+  }
+};
+
+// Rewrites a mask-less vector.transfer_read whose vector type has leading
+// unit dims into a transfer_read of the trimmed vector type followed by a
+// vector.broadcast back to the original type.
+struct CastAwayTransferReadLeadingOneDim
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    // Masked reads are left alone.
+    if (read.mask())
+      return failure();
+
+    VectorType vecType = read.getVectorType();
+    auto sourceType = read.source().getType().cast<ShapedType>();
+    if (sourceType.getElementType() != vecType.getElementType())
+      return failure();
+
+    VectorType trimmedType = trimLeadingOneDims(vecType);
+    if (trimmedType == vecType)
+      return failure();
+    int64_t keptRank = trimmedType.getRank();
+
+    // Keep only the permutation-map results and in_bounds entries that
+    // belong to the trailing, kept vector dims.
+    AffineMap permMap = read.permutation_map();
+    AffineMap trimmedMap = AffineMap::get(
+        permMap.getNumDims(), permMap.getNumSymbols(),
+        permMap.getResults().take_back(keptRank), rewriter.getContext());
+    ArrayAttr trimmedInBounds;
+    if (read.in_bounds())
+      trimmedInBounds = rewriter.getArrayAttr(
+          read.in_boundsAttr().getValue().take_back(keptRank));
+
+    Value trimmedRead = rewriter.create<vector::TransferReadOp>(
+        read.getLoc(), trimmedType, read.source(), read.indices(), trimmedMap,
+        read.padding(), trimmedInBounds);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, vecType,
+                                                     trimmedRead);
+    return success();
+  }
+};
+
+// Rewrites a mask-less vector.transfer_write whose vector type has leading
+// unit dims: the vector is narrowed with vector.extract and written by a
+// transfer_write of the trimmed type (permutation map/in_bounds trimmed too).
+struct CastAwayTransferWriteLeadingOneDim
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    if (write.mask())
+      return failure();
+
+    // cast (not dyn_cast): the result is used unchecked on the next line.
+    auto shapedType = write.source().getType().cast<ShapedType>();
+    if (shapedType.getElementType() != write.getVectorType().getElementType())
+      return failure();
+
+    VectorType oldType = write.getVectorType();
+    VectorType newType = trimLeadingOneDims(oldType);
+    if (newType == oldType)
+      return failure();
+    int64_t dropDim = oldType.getRank() - newType.getRank();
+
+    AffineMap oldMap = write.permutation_map();
+    ArrayRef<AffineExpr> newResults =
+        oldMap.getResults().take_back(newType.getRank());
+    AffineMap newMap =
+        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
+                       rewriter.getContext());
+
+    ArrayAttr inBounds;
+    if (write.in_bounds())
+      inBounds = rewriter.getArrayAttr(
+          write.in_boundsAttr().getValue().take_back(newType.getRank()));
+
+    auto newVector = rewriter.create<vector::ExtractOp>(
+        write.getLoc(), write.vector(), splatZero(dropDim));
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        write, newVector, write.source(), write.indices(), newMap, inBounds);
+    return success();
+  }
+};
+
+// Re-creates a single-result elementwise op on vectors without leading unit
+// dims (vector operands narrowed via vector.extract), then broadcasts back.
+class CastAwayElementwiseLeadingOneDim : public RewritePattern {
+public:
+  explicit CastAwayElementwiseLeadingOneDim(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure();
+    auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
+    if (!vecType)
+      return failure();
+    VectorType newVecType = trimLeadingOneDims(vecType);
+    if (newVecType == vecType)
+      return failure();
+    int64_t dropDim = vecType.getRank() - newVecType.getRank();
+    SmallVector<Value, 4> newOperands;
+    for (Value operand : op->getOperands()) {
+      if (operand.getType().isa<VectorType>())
+        operand = rewriter.create<vector::ExtractOp>(op->getLoc(), operand,
+                                                     splatZero(dropDim));
+      newOperands.push_back(operand);
+    }
+    OperationState state(op->getLoc(), op->getName());
+    state.addAttributes(op->getAttrs());
+    state.addOperands(newOperands);
+    state.addTypes(newVecType);
+    Operation *newOp = rewriter.createOperation(state);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
+                                                     newOp->getResult(0));
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<CastAwayExtractStridedSliceLeadingOneDim>(patterns.getContext());
+  patterns.add<CastAwayInsertStridedSliceLeadingOneDim>(patterns.getContext());
+  patterns.add<CastAwayTransferReadLeadingOneDim>(patterns.getContext());
+  patterns.add<CastAwayTransferWriteLeadingOneDim>(patterns.getContext());
+  patterns.add<CastAwayElementwiseLeadingOneDim>(patterns.getContext());
+  populateShapeCastFoldingPatterns(patterns);
+}

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 8a77f171c61b..37f3c31e6a48 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2931,234 +2931,6 @@ struct TransferWriteToVectorStoreLowering
   llvm::Optional<unsigned> maxTransferRank;
 };
 
-// Trims leading one dimensions from `oldType` and returns the result type.
-// Returns `vector<1xT>` if `oldType` only has one element.
-static VectorType trimLeadingOneDims(VectorType oldType) {
-  ArrayRef<int64_t> oldShape = oldType.getShape();
-  ArrayRef<int64_t> newShape =
-      oldShape.drop_while([](int64_t dim) { return dim == 1; });
-  // Make sure we have at least 1 dimension per vector type requirements.
-  if (newShape.empty())
-    newShape = oldShape.take_back();
-  return VectorType::get(newShape, oldType.getElementType());
-}
-
-/// Return a smallVector of size `rank` containing all zeros.
-static SmallVector<int64_t> splatZero(int64_t rank) {
-  return SmallVector<int64_t>(rank, 0);
-}
-
-// Casts away leading one dimensions in vector.extract_strided_slice's vector
-// input by inserting vector.shape_cast.
-struct CastAwayExtractStridedSliceLeadingOneDim
-    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // vector.extract_strided_slice requires the input and output vector to have
-    // the same rank. Here we drop leading one dimensions from the input vector
-    // type to make sure we don't cause mismatch.
-    VectorType oldSrcType = extractOp.getVectorType();
-    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
-
-    if (newSrcType.getRank() == oldSrcType.getRank())
-      return failure();
-
-    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
-
-    VectorType oldDstType = extractOp.getType();
-    VectorType newDstType =
-        VectorType::get(oldDstType.getShape().drop_front(dropCount),
-                        oldDstType.getElementType());
-
-    Location loc = extractOp.getLoc();
-
-    Value newSrcVector = rewriter.create<vector::ExtractOp>(
-        loc, extractOp.vector(), splatZero(dropCount));
-
-    // The offsets/sizes/strides attribute can have a less number of elements
-    // than the input vector's rank: it is meant for the leading dimensions.
-    auto newOffsets = rewriter.getArrayAttr(
-        extractOp.offsets().getValue().drop_front(dropCount));
-    auto newSizes = rewriter.getArrayAttr(
-        extractOp.sizes().getValue().drop_front(dropCount));
-    auto newStrides = rewriter.getArrayAttr(
-        extractOp.strides().getValue().drop_front(dropCount));
-
-    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
-        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
-
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
-                                                     newExtractOp);
-
-    return success();
-  }
-};
-
-// Casts away leading one dimensions in vector.extract_strided_slice's vector
-// inputs by inserting vector.shape_cast.
-struct CastAwayInsertStridedSliceLeadingOneDim
-    : public OpRewritePattern<vector::InsertStridedSliceOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    VectorType oldSrcType = insertOp.getSourceVectorType();
-    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
-    VectorType oldDstType = insertOp.getDestVectorType();
-    VectorType newDstType = trimLeadingOneDims(oldDstType);
-
-    int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
-    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
-    if (srcDropCount == 0 && dstDropCount == 0)
-      return failure();
-
-    // Trim leading one dimensions from both operands.
-    Location loc = insertOp.getLoc();
-
-    Value newSrcVector = rewriter.create<vector::ExtractOp>(
-        loc, insertOp.source(), splatZero(srcDropCount));
-    Value newDstVector = rewriter.create<vector::ExtractOp>(
-        loc, insertOp.dest(), splatZero(dstDropCount));
-
-    auto newOffsets = rewriter.getArrayAttr(
-        insertOp.offsets().getValue().take_back(newDstType.getRank()));
-    auto newStrides = rewriter.getArrayAttr(
-        insertOp.strides().getValue().take_back(newSrcType.getRank()));
-
-    auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
-        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
-
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
-                                                     newInsertOp);
-
-    return success();
-  }
-};
-
-// Turns vector.transfer_read on vector with leading 1 dimensions into
-// vector.shape_cast followed by vector.transfer_read on vector without leading
-// 1 dimensions.
-struct CastAwayTransferReadLeadingOneDim
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp read,
-                                PatternRewriter &rewriter) const override {
-    if (read.mask())
-      return failure();
-
-    auto shapedType = read.source().getType().cast<ShapedType>();
-    if (shapedType.getElementType() != read.getVectorType().getElementType())
-      return failure();
-
-    VectorType oldType = read.getVectorType();
-    VectorType newType = trimLeadingOneDims(oldType);
-
-    if (newType == oldType)
-      return failure();
-
-    AffineMap oldMap = read.permutation_map();
-    ArrayRef<AffineExpr> newResults =
-        oldMap.getResults().take_back(newType.getRank());
-    AffineMap newMap =
-        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
-                       rewriter.getContext());
-
-    ArrayAttr inBounds;
-    if (read.in_bounds())
-      inBounds = rewriter.getArrayAttr(
-          read.in_boundsAttr().getValue().take_back(newType.getRank()));
-
-    auto newRead = rewriter.create<vector::TransferReadOp>(
-        read.getLoc(), newType, read.source(), read.indices(), newMap,
-        read.padding(), inBounds);
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
-
-    return success();
-  }
-};
-
-// Turns vector.transfer_write on vector with leading 1 dimensions into
-// vector.shape_cast followed by vector.transfer_write on vector without leading
-// 1 dimensions.
-struct CastAwayTransferWriteLeadingOneDim
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
-                                PatternRewriter &rewriter) const override {
-    if (write.mask())
-      return failure();
-
-    auto shapedType = write.source().getType().dyn_cast<ShapedType>();
-    if (shapedType.getElementType() != write.getVectorType().getElementType())
-      return failure();
-
-    VectorType oldType = write.getVectorType();
-    VectorType newType = trimLeadingOneDims(oldType);
-    if (newType == oldType)
-      return failure();
-    int64_t dropDim = oldType.getRank() - newType.getRank();
-
-    AffineMap oldMap = write.permutation_map();
-    ArrayRef<AffineExpr> newResults =
-        oldMap.getResults().take_back(newType.getRank());
-    AffineMap newMap =
-        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
-                       rewriter.getContext());
-
-    ArrayAttr inBounds;
-    if (write.in_bounds())
-      inBounds = rewriter.getArrayAttr(
-          write.in_boundsAttr().getValue().take_back(newType.getRank()));
-
-    auto newVector = rewriter.create<vector::ExtractOp>(
-        write.getLoc(), write.vector(), splatZero(dropDim));
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        write, newVector, write.source(), write.indices(), newMap, inBounds);
-
-    return success();
-  }
-};
-
-class CastAwayElementwiseLeadingOneDim : public RewritePattern {
-public:
-  CastAwayElementwiseLeadingOneDim(MLIRContext *context)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
-
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
-      return failure();
-    auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
-    if (!vecType)
-      return failure();
-    VectorType newVecType = trimLeadingOneDims(vecType);
-    if (newVecType == vecType)
-      return failure();
-    int64_t dropDim = vecType.getRank() - newVecType.getRank();
-    SmallVector<Value, 4> newOperands;
-    for (Value operand : op->getOperands()) {
-      if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
-        newOperands.push_back(rewriter.create<vector::ExtractOp>(
-            op->getLoc(), operand, splatZero(dropDim)));
-      } else {
-        newOperands.push_back(operand);
-      }
-    }
-    OperationState state(op->getLoc(), op->getName());
-    state.addAttributes(op->getAttrs());
-    state.addOperands(newOperands);
-    state.addTypes(newVecType);
-    Operation *newOp = rewriter.createOperation(state);
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
-                                                     newOp->getResult(0));
-    return success();
-  }
-};
-
 // Returns the values in `arrayAttr` as an integer vector.
 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(
@@ -3638,16 +3410,6 @@ void mlir::vector::populateShapeCastFoldingPatterns(
   patterns.add<ShapeCastOpFolder>(patterns.getContext());
 }
 
-void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
-               CastAwayInsertStridedSliceLeadingOneDim,
-               CastAwayTransferReadLeadingOneDim,
-               CastAwayTransferWriteLeadingOneDim,
-               CastAwayElementwiseLeadingOneDim>(patterns.getContext());
-  populateShapeCastFoldingPatterns(patterns);
-}
-
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
     RewritePatternSet &patterns) {
   patterns.add<BubbleDownVectorBitCastForExtract,


        


More information about the Mlir-commits mailing list