[Mlir-commits] [mlir] 874ce9b - [mlir][vector] Add patterns to cast away leading 1-dim

Lei Zhang llvmlistbot at llvm.org
Fri Feb 5 06:02:28 PST 2021


Author: Lei Zhang
Date: 2021-02-05T09:02:15-05:00
New Revision: 874ce9b80f87ba6eb880aff38ba2a359e3cf12b2

URL: https://github.com/llvm/llvm-project/commit/874ce9b80f87ba6eb880aff38ba2a359e3cf12b2
DIFF: https://github.com/llvm/llvm-project/commit/874ce9b80f87ba6eb880aff38ba2a359e3cf12b2.diff

LOG: [mlir][vector] Add patterns to cast away leading 1-dim

This patch adds patterns to use vector.shape_cast to cast
away leading 1-dimensions from a few vector operations.
It allows exposing more canonical forms of vector.transfer_read,
vector.transfer_write, vector.extract_strided_slice, and
vector.insert_strided_slice. With this, we have more
opportunities to cancel extract/insert ops or forward
write/read ops.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D95873

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 5540a56a4043..b01aa112feef 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -35,6 +35,15 @@ void populateVectorToVectorCanonicalizationPatterns(
 void populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Collect a set of leading one dimension removal patterns.
+///
+/// These patterns insert vector.shape_cast to remove leading one dimensions
+/// to expose more canonical forms of read/write/insert/extract operations.
+/// With them, there are more chances that we can cancel out extract-insert
+/// pairs or forward write-read pairs.
+void populateCastAwayVectorLeadingOneDimPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context);
+
 /// Collect a set of vector slices transformation patterns:
 ///    ExtractSlicesOpLowering, InsertSlicesOpLowering
 /// Useful for clients that want to express all vector "slices"

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index a567ba8f37a0..6a8ee49bfb7e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2607,6 +2607,186 @@ struct TransferWriteInsertPattern
   }
 };
 
+// Trims leading one dimensions from `oldType` and returns the result type.
+// Returns `vector<1xT>` if every dimension of `oldType` is one.
+static VectorType trimLeadingOneDims(VectorType oldType) {
+  ArrayRef<int64_t> oldShape = oldType.getShape();
+  ArrayRef<int64_t> newShape =
+      oldShape.drop_while([](int64_t dim) { return dim == 1; });
+  // All dims are one: keep the last, as a vector type needs >= 1 dimension.
+  if (newShape.empty())
+    newShape = oldShape.take_back();
+  return VectorType::get(newShape, oldType.getElementType());
+}
+
+// Casts away leading one dimensions in vector.extract_strided_slice's vector
+// input by inserting vector.shape_cast.
+struct CastAwayExtractStridedSliceLeadingOneDim
+    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // vector.extract_strided_slice requires the input and output vector to have
+    // the same rank. Here we drop leading one dimensions from the input vector
+    // type to make sure we don't cause mismatch.
+    VectorType oldSrcType = extractOp.getVectorType();
+    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
+
+    if (newSrcType.getRank() == oldSrcType.getRank())
+      return failure();
+
+    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
+
+    VectorType oldDstType = extractOp.getType();
+    VectorType newDstType =
+        VectorType::get(oldDstType.getShape().drop_front(dropCount),
+                        oldDstType.getElementType());
+
+    Location loc = extractOp.getLoc();
+
+    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
+        loc, newSrcType, extractOp.vector());
+
+    // The offsets/sizes/strides attributes can have fewer elements than the
+    // input vector's rank: they are meant for the leading dimensions.
+    auto newOffsets = rewriter.getArrayAttr(
+        extractOp.offsets().getValue().drop_front(dropCount));
+    auto newSizes = rewriter.getArrayAttr(
+        extractOp.sizes().getValue().drop_front(dropCount));
+    auto newStrides = rewriter.getArrayAttr(
+        extractOp.strides().getValue().drop_front(dropCount));
+
+    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+                                                     newExtractOp);
+
+    return success();
+  }
+};
+
+// Casts away leading one dimensions in vector.insert_strided_slice's vector
+// inputs by inserting vector.shape_cast.
+struct CastAwayInsertStridedSliceLeadingOneDim
+    : public OpRewritePattern<vector::InsertStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType oldSrcType = insertOp.getSourceVectorType();
+    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
+    VectorType oldDstType = insertOp.getDestVectorType();
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
+
+    if (newSrcType.getRank() == oldSrcType.getRank() &&
+        newDstType.getRank() == oldDstType.getRank())
+      return failure();
+
+    // Trim leading one dimensions from both operands.
+    Location loc = insertOp.getLoc();
+
+    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
+        loc, newSrcType, insertOp.source());
+    Value newDstVector =
+        rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
+
+    auto newOffsets = rewriter.getArrayAttr(
+        insertOp.offsets().getValue().take_back(newDstType.getRank()));
+    auto newStrides = rewriter.getArrayAttr(
+        insertOp.strides().getValue().take_back(newSrcType.getRank()));
+
+    auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
+        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+                                                     newInsertOp);
+
+    return success();
+  }
+};
+
+// Turns vector.transfer_read on vector with leading 1 dimensions into
+// vector.transfer_read on vector without leading 1 dimensions followed by
+// vector.shape_cast back to the original vector type.
+struct CastAwayTransferReadLeadingOneDim
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    auto shapedType = read.source().getType().cast<ShapedType>();
+    if (shapedType.getElementType() != read.getVectorType().getElementType())
+      return failure();
+
+    VectorType oldType = read.getVectorType();
+    VectorType newType = trimLeadingOneDims(oldType);
+
+    if (newType == oldType)
+      return failure();
+
+    AffineMap oldMap = read.permutation_map();
+    ArrayRef<AffineExpr> newResults =
+        oldMap.getResults().take_back(newType.getRank());
+    AffineMap newMap =
+        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
+                       rewriter.getContext());
+
+    ArrayAttr mask;
+    if (read.masked())
+      mask = rewriter.getArrayAttr(
+          read.maskedAttr().getValue().take_back(newType.getRank()));
+
+    auto newRead = rewriter.create<vector::TransferReadOp>(
+        read.getLoc(), newType, read.source(), read.indices(), newMap,
+        read.padding(), mask);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+
+    return success();
+  }
+};
+
+// Turns vector.transfer_write on vector with leading 1 dimensions into
+// vector.shape_cast followed by vector.transfer_write on vector without leading
+// 1 dimensions.
+struct CastAwayTransferWriteLeadingOneDim
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto shapedType = write.source().getType().cast<ShapedType>();
+    if (shapedType.getElementType() != write.getVectorType().getElementType())
+      return failure();
+
+    VectorType oldType = write.getVectorType();
+    VectorType newType = trimLeadingOneDims(oldType);
+
+    if (newType == oldType)
+      return failure();
+
+    AffineMap oldMap = write.permutation_map();
+    ArrayRef<AffineExpr> newResults =
+        oldMap.getResults().take_back(newType.getRank());
+    AffineMap newMap =
+        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
+                       rewriter.getContext());
+
+    ArrayAttr mask;
+    if (write.masked())
+      mask = rewriter.getArrayAttr(
+          write.maskedAttr().getValue().take_back(newType.getRank()));
+
+    auto newVector = rewriter.create<vector::ShapeCastOp>(
+        write.getLoc(), newType, write.vector());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        write, newVector, write.source(), write.indices(), newMap, mask);
+
+    return success();
+  }
+};
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -2622,6 +2802,15 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
   // clang-format on
 }
 
+void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<CastAwayExtractStridedSliceLeadingOneDim,
+                  CastAwayInsertStridedSliceLeadingOneDim,
+                  CastAwayTransferReadLeadingOneDim,
+                  CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
+      context);
+}
+
 void mlir::vector::populateVectorSlicesLoweringPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index e1c79de4efae..831d2eb33d37 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -601,3 +601,73 @@ func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>,
       : vector<4x4xf32>, tensor<4x4xf32>
   return %r : tensor<4x4xf32>
 }
+
+// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
+func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
+  // CHECK:     %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
+  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
+  // CHECK:     %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x1x8xf16>
+}
+
+// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
+func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
+  // CHECK:    %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
+  // CHECK:    %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
+  // CHECK:    %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x8x8xf16>
+}
+
+// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
+func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
+  // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
+  // CHECK: vector.shape_cast %{{.+}} : vector<1x1x1xf16> to vector<1xf16>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
+  return %0: vector<1x1x1xf16>
+}
+
+// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
+func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
+  // CHECK: %[[C0:.+]] = constant 0 : index
+  %c0 = constant 0 : index
+  // CHECK: %[[F0:.+]] = constant 0.000000e+00 : f16
+  %f0 = constant 0. : f16
+  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {masked = [false]} : memref<1x4x8x16xf16>, vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<1x4xf16>
+}
+
+// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
+func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
+  %c0 = constant 0 : index
+  %f0 = constant 0. : f16
+  // CHECK: vector.shape_cast %{{.+}} : vector<1xf16> to vector<1x1xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x1x1x1xf16>, vector<1x1xf16>
+  return %0: vector<1x1xf16>
+}
+
+// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
+func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
+  // CHECK: %[[C0:.+]] = constant 0 : index
+  %c0 = constant 0 : index
+  // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
+  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {masked = [false]} : vector<4xf16>, memref<1x4x8x16xf16>
+
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+  return
+}
+
+// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
+func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
+  %c0 = constant 0 : index
+  // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16>
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index ea4b6f998db0..109a9fcb65e3 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -45,6 +45,7 @@ struct TestVectorToVectorConversion
     }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
+    populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 


        


More information about the Mlir-commits mailing list