[Mlir-commits] [mlir] 7630520 - [mlir][vector] Add pattern to shuffle bitcast ops

Lei Zhang llvmlistbot at llvm.org
Fri Feb 5 14:53:10 PST 2021


Author: Lei Zhang
Date: 2021-02-05T17:52:49-05:00
New Revision: 7630520ae3c5af3f3536a81740cf316d3a21304e

URL: https://github.com/llvm/llvm-project/commit/7630520ae3c5af3f3536a81740cf316d3a21304e
DIFF: https://github.com/llvm/llvm-project/commit/7630520ae3c5af3f3536a81740cf316d3a21304e.diff

LOG: [mlir][vector] Add pattern to shuffle bitcast ops

These patterns move vector.bitcast ops to be before
insert ops or after extract ops where suitable.
As a result, bitcasts operate on smaller vectors
and there are more chances to share extract/insert
ops.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D96040

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transforms.mlir
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index b01aa112feef..afc55c1911ba 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -44,6 +44,14 @@ void populateVectorToVectorTransformationPatterns(
 void populateCastAwayVectorLeadingOneDimPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Collect a set of patterns that bubble up/down bitcast ops.
+///
+/// These patterns move vector.bitcast ops to be after vector.extract and
+/// vector.extract_strided_slice ops, and before vector.insert_strided_slice
+/// ops, where suitable. As a result, bitcasts operate on smaller vectors and
+/// there are more chances to share extract/insert ops.
+///
+/// The patterns are added to `patterns`; each one is constructed with
+/// `context`.
+void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
+                                           MLIRContext *context);
+
 /// Collect a set of vector slices transformation patterns:
 ///    ExtractSlicesOpLowering, InsertSlicesOpLowering
 /// Useful for clients that want to express all vector "slices"

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6a8ee49bfb7e..765eb0800a42 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2787,6 +2787,244 @@ struct CastAwayTransferWriteLeadingOneDim
   }
 };
 
+// Returns the values in `arrayAttr` as an integer vector, in order.
+//
+// Each element is cast to IntegerAttr and read with getInt(), so every
+// element of `arrayAttr` must be an integer attribute.
+static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
+                      [](IntegerAttr attr) { return attr.getInt(); }));
+}
+
+// Shuffles vector.bitcast op after vector.extract op.
+//
+// This transforms IR like:
+//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
+//   %1 = vector.extract %0[3] : vector<8xf16>
+// Into:
+//   %0 = vector.extract %src[1] : vector<4xf32>
+//   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
+//   %2 = vector.extract %1[1] : vector<2xf16>
+//
+// Only rank-1 extracts from a bitcast whose element count grows by an exact
+// integer factor (e.g. 4 x f32 -> 8 x f16) are rewritten.
+struct BubbleDownVectorBitCastForExtract
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Only support extracting scalars for now.
+    if (extractOp.getVectorType().getRank() != 1)
+      return failure();
+
+    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
+    if (!castOp)
+      return failure();
+
+    VectorType castSrcType = castOp.getSourceVectorType();
+    VectorType castDstType = castOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    // Fail to match if we only have one element in the cast op source.
+    // This is to avoid infinite loop given that this pattern can generate
+    // such cases.
+    if (castSrcType.getNumElements() == 1)
+      return failure();
+
+    // Only support casting to a larger number of elements for now.
+    // E.g., vector<4xf32> -> vector<8xf16>.
+    if (castSrcType.getNumElements() > castDstType.getNumElements())
+      return failure();
+
+    // The element count must grow by an exact integer factor. A bitcast such
+    // as vector<3xi32> -> vector<4xi24> is valid (96 bits on both sides), but
+    // integer division would give expandRatio == 1 and the rewrite below would
+    // then extract out of range and bitcast between mismatched bit widths.
+    if (castDstType.getNumElements() % castSrcType.getNumElements() != 0)
+      return failure();
+
+    unsigned expandRatio =
+        castDstType.getNumElements() / castSrcType.getNumElements();
+
+    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
+      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
+    };
+
+    // Position of the requested scalar within the bitcast result.
+    uint64_t index = getFirstIntValue(extractOp.position());
+
+    // Get the single scalar (as a vector) in the source value that packs the
+    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
+    // NOTE(review): a single-position vector.extract on a rank-1 vector
+    // otherwise yields a scalar (e.g. `vector.extract %0[3] : vector<8xf16>`
+    // returns f16), so forcing a vector<1xT> result type here may not survive
+    // a print/parse round trip. Consider extracting the scalar and
+    // broadcasting it to vector<1xT>; the FileCheck tests would need updating.
+    VectorType oneScalarType =
+        VectorType::get({1}, castSrcType.getElementType());
+    Value packedValue = rewriter.create<vector::ExtractOp>(
+        extractOp.getLoc(), oneScalarType, castOp.source(),
+        rewriter.getI64ArrayAttr(index / expandRatio));
+
+    // Cast it to a vector with the desired scalar's type.
+    // E.g. vector<1xf32> -> vector<2xf16>
+    VectorType packedType =
+        VectorType::get({expandRatio}, castDstType.getElementType());
+    Value castedValue = rewriter.create<vector::BitCastOp>(
+        extractOp.getLoc(), packedType, packedValue);
+
+    // Finally extract the desired scalar.
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, extractOp.getType(), castedValue,
+        rewriter.getI64ArrayAttr(index % expandRatio));
+
+    return success();
+  }
+};
+
+// Shuffles vector.bitcast op after vector.extract_strided_slice op.
+//
+// This transforms IR like:
+//   %cast = vector.bitcast %src : vector<4xf32> to vector<8xf16>
+//   %0 = vector.extract_strided_slice %cast {
+//          offsets = [4], sizes = [4], strides = [1]
+//        } : vector<8xf16> to vector<4xf16>
+// Into:
+//   %0 = vector.extract_strided_slice %src {
+//          offsets = [2], sizes = [2], strides = [1]
+//        } : vector<4xf32> to vector<2xf32>
+//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
+//
+// The rewrite requires all-one strides and a last dimension that grows by an
+// exact integer factor. When the last dimension's offset/size are given, they
+// must be multiples of that factor.
+struct BubbleDownBitCastForStridedSliceExtract
+    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
+    if (!castOp)
+      return failure();
+
+    VectorType castSrcType = castOp.getSourceVectorType();
+    VectorType castDstType = castOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to more elements for now; other cases to be implemented.
+    if (castSrcLastDim > castDstLastDim)
+      return failure();
+    // The last dimension must grow by an exact integer factor. A bitcast such
+    // as vector<3xi32> -> vector<4xi24> is valid but has no integral ratio;
+    // bail out instead of asserting, since in release builds the assert is
+    // compiled out and the truncated ratio would produce wrong IR.
+    if (castDstLastDim % castSrcLastDim != 0)
+      return failure();
+    int64_t expandRatio = castDstLastDim / castSrcLastDim;
+
+    // Only accept all one strides for now.
+    if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
+                     [](const APInt &val) { return !val.isOneValue(); }))
+      return failure();
+
+    unsigned rank = extractOp.getVectorType().getRank();
+
+    // If there are fewer offsets than the rank, the last (bitcasted) dimension
+    // is implicitly selected in full and the other dimensions are unaffected.
+    // Otherwise, scale down the last dimension's offset, since the source has
+    // fewer elements along that dimension.
+    ArrayAttr newOffsets = extractOp.offsets();
+    if (newOffsets.size() == rank) {
+      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
+      if (offsets.back() % expandRatio != 0)
+        return failure();
+      offsets.back() = offsets.back() / expandRatio;
+      newOffsets = rewriter.getI64ArrayAttr(offsets);
+    }
+
+    // Similarly for sizes.
+    ArrayAttr newSizes = extractOp.sizes();
+    if (newSizes.size() == rank) {
+      SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
+      if (sizes.back() % expandRatio != 0)
+        return failure();
+      sizes.back() = sizes.back() / expandRatio;
+      newSizes = rewriter.getI64ArrayAttr(sizes);
+    }
+
+    // The new slice has the source element type and a last dimension shrunk
+    // by the expand ratio.
+    SmallVector<int64_t, 4> dims =
+        llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
+    dims.back() = dims.back() / expandRatio;
+    VectorType newExtractType =
+        VectorType::get(dims, castSrcType.getElementType());
+
+    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
+        newSizes, extractOp.strides());
+
+    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
+        extractOp, extractOp.getType(), newExtractOp);
+
+    return success();
+  }
+};
+
+// Shuffles vector.bitcast op before vector.insert_strided_slice op.
+//
+// This transforms IR like:
+//   %0 = vector.insert_strided_slice %src, %dst {
+//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
+//   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+// Into:
+//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
+//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
+//   %2 = vector.insert_strided_slice %0, %1 {
+//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//
+// The rewrite requires all-one strides, equal source/destination ranks, and a
+// last dimension that shrinks by an exact integer factor. Both the inserted
+// slice's last-dimension offset and its last-dimension size must be multiples
+// of that factor.
+struct BubbleUpBitCastForStridedSliceInsert
+    : public OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to fewer elements for now; other cases to be
+    // implemented.
+    if (castSrcLastDim < castDstLastDim)
+      return failure();
+    // The last dimension must shrink by an exact integer factor. A bitcast such
+    // as vector<4xi24> -> vector<3xi32> is valid but has no integral ratio;
+    // bail out instead of asserting, since in release builds the assert is
+    // compiled out and the truncated ratio would produce wrong IR.
+    if (castSrcLastDim % castDstLastDim != 0)
+      return failure();
+    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
+
+    auto insertOp =
+        bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
+    if (!insertOp)
+      return failure();
+
+    // Only accept all one strides for now.
+    if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
+                     [](const APInt &val) { return !val.isOneValue(); }))
+      return failure();
+
+    unsigned rank = insertOp.getSourceVectorType().getRank();
+    // Require insert op to have the same rank for the source and destination
+    // vector; other cases to be implemented.
+    if (rank != insertOp.getDestVectorType().getRank())
+      return failure();
+
+    ArrayAttr oldOffsets = insertOp.offsets();
+    assert(oldOffsets.size() == rank);
+    SmallVector<int64_t, 4> offsets = getIntValueVector(oldOffsets);
+    if (offsets.back() % shrinkRatio != 0)
+      return failure();
+
+    // The inserted slice must also cover whole elements of the new type.
+    // E.g. inserting vector<3xf16> and casting to f32 would otherwise need a
+    // bitcast from vector<3xf16> (48 bits) to vector<1xf32> (32 bits), which
+    // is invalid. Check this before creating any new ops.
+    SmallVector<int64_t, 4> srcDims =
+        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
+    if (srcDims.back() % shrinkRatio != 0)
+      return failure();
+
+    // All checks passed; build the rewritten IR.
+    offsets.back() = offsets.back() / shrinkRatio;
+    ArrayAttr newOffsets = rewriter.getI64ArrayAttr(offsets);
+
+    srcDims.back() = srcDims.back() / shrinkRatio;
+    VectorType newCastSrcType =
+        VectorType::get(srcDims, castDstType.getElementType());
+
+    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), newCastSrcType, insertOp.source());
+
+    SmallVector<int64_t, 4> dstDims =
+        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
+    dstDims.back() = dstDims.back() / shrinkRatio;
+    VectorType newCastDstType =
+        VectorType::get(dstDims, castDstType.getElementType());
+
+    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), newCastDstType, insertOp.dest());
+
+    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
+        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
+        insertOp.strides());
+
+    return success();
+  }
+};
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -2811,6 +3049,13 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
       context);
 }
 
+// Registers the vector.bitcast bubbling patterns of this file into
+// `patterns`, constructing each one with `context`.
+void mlir::vector::populateBubbleVectorBitCastOpPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  // Patterns that sink the bitcast below extract-like ops.
+  patterns.insert<BubbleDownVectorBitCastForExtract>(context);
+  patterns.insert<BubbleDownBitCastForStridedSliceExtract>(context);
+  // Pattern that hoists the bitcast above insert-like ops.
+  patterns.insert<BubbleUpBitCastForStridedSliceInsert>(context);
+}
+
 void mlir::vector::populateVectorSlicesLoweringPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 831d2eb33d37..20c91882871d 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -671,3 +671,92 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16>
   return
 }
+
+// Two scalar extracts from one bitcast: each becomes an extract of the f32
+// element that packs the wanted f16, a bitcast to vector<2xf16>, and a final
+// extract of the f16 within that pair.
+// CHECK-LABEL: func @bubble_down_bitcast_in_extract
+//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
+func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
+  %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
+  // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : vector<4xf32>
+  // CHECK:    %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<1xf32> to vector<2xf16>
+  // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : vector<2xf16>
+  %1 = vector.extract %0[3] : vector<8xf16>
+  // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : vector<4xf32>
+  // CHECK:    %[[CAST2:.+]] = vector.bitcast %[[EXTRACT3]] : vector<1xf32> to vector<2xf16>
+  // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : vector<2xf16>
+  %2 = vector.extract %0[4] : vector<8xf16>
+  // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]
+  return %1, %2: f16, f16
+}
+
+// Offset 4 and size 4 (in f16) are multiples of the expand ratio 2, so the
+// slice is taken on the f32 source (offset 2, size 2) and then bitcast.
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract
+//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
+func @bubble_down_bitcast_in_strided_slice_extract(%arg0: vector<4xf32>) -> vector<4xf16> {
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2xf32> to vector<4xf16>
+  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<4xf16>
+}
+
+// Only the leading dimension is sliced (one offset for a rank-2 vector), so
+// the bitcasted last dimension is taken in full and offsets/sizes carry over
+// unchanged to the f32 source.
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim
+//  CHECK-SAME: %[[SRC:.+]]: vector<4x2xf32>
+func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim(%arg0: vector<4x2xf32>) -> vector<2x4xf16> {
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [1], sizes = [2], strides = [1]} : vector<4x2xf32> to vector<2x2xf32>
+  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2x2xf32> to vector<2x4xf16>
+  %cast = vector.bitcast %arg0: vector<4x2xf32> to vector<4x4xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [1], sizes = [2], strides = [1]} : vector<4x4xf16> to vector<2x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<2x4xf16>
+}
+
+// Negative test: offset 3 is not a multiple of the expand ratio 2, so the
+// IR must stay unchanged (bitcast still followed by the slice).
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_offset
+func @bubble_down_bitcast_in_strided_slice_extract_odd_offset(%arg0: vector<4xf32>) -> vector<4xf16> {
+  // CHECK: vector.bitcast
+  // CHECK-NEXT: vector.extract_strided_slice
+  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [3], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+  return %0: vector<4xf16>
+}
+
+// Negative test: size 3 is not a multiple of the expand ratio 2, so the IR
+// must stay unchanged.
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_size
+func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4xf32>) -> vector<3xf16> {
+  // CHECK: vector.bitcast
+  // CHECK-NEXT: vector.extract_strided_slice
+  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [0], sizes = [3], strides = [1]} : vector<8xf16> to vector<3xf16>
+  return %0: vector<3xf16>
+}
+
+// Two chained inserts at f16 offsets 0 and 4: the final bitcast is bubbled up
+// through both, yielding f32 inserts at offsets 0 and 2 on bitcast operands.
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
+//  CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
+func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
+  // CHECK-DAG: %[[CAST_SRC1:.+]] = vector.bitcast %[[SRC1]] : vector<4xf16> to vector<2xf32>
+  // CHECK-DAG: %[[CAST_SRC2:.+]] = vector.bitcast %[[SRC2]] : vector<4xf16> to vector<2xf32>
+  // CHECK-DAG: %[[CAST_DST:.+]] = vector.bitcast %[[DST]] : vector<8xf16> to vector<4xf32>
+  // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST_SRC1]], %[[CAST_DST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+  // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST_SRC2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+  %0 = vector.insert_strided_slice %src1, %dst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
+  %1 = vector.insert_strided_slice %src2, %0   {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
+  %cast = vector.bitcast %1: vector<8xf16> to vector<4xf32>
+  // CHECK: return %[[INSERT2]]
+  return %cast: vector<4xf32>
+}
+
+// Negative test: insert offset 3 is not a multiple of the shrink ratio 2, so
+// the bitcast must stay after the insert.
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_offset
+func @bubble_up_bitcast_in_strided_slice_insert_odd_offset(%dst: vector<8xf16>, %src: vector<4xf16>) -> vector<4xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [3], strides = [1]} : vector<4xf16> into vector<8xf16>
+  %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+  return %cast: vector<4xf32>
+}
+
+// Negative test: the inserted vector (rank 2) has a lower rank than the
+// destination (rank 3), which the pattern does not support yet, so the
+// bitcast must stay after the insert.
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_different_rank
+func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector<16x4x8xf16>, %src: vector<2x4xf16>) -> vector<16x4x4xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [0, 0, 2], strides = [1, 1]} : vector<2x4xf16> into vector<16x4x8xf16>
+  %cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
+  return %cast: vector<16x4x4xf32>
+}

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 109a9fcb65e3..61b17178ef59 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -45,6 +45,7 @@ struct TestVectorToVectorConversion
     }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
+    populateBubbleVectorBitCastOpPatterns(patterns, ctx);
     populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }


        


More information about the Mlir-commits mailing list