[Mlir-commits] [mlir] 4623c11 - [mlir][Vector] Support vector.insert in bubbling bitcast patterns (#82843)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 28 08:15:51 PST 2024
Author: Diego Caballero
Date: 2024-02-28T08:15:47-08:00
New Revision: 4623c114fb51e48b5cbb352adecb60d76077cb0a
URL: https://github.com/llvm/llvm-project/commit/4623c114fb51e48b5cbb352adecb60d76077cb0a
DIFF: https://github.com/llvm/llvm-project/commit/4623c114fb51e48b5cbb352adecb60d76077cb0a.diff
LOG: [mlir][Vector] Support vector.insert in bubbling bitcast patterns (#82843)
This PR adds support for `vector.insert` to the patterns that bubble up and down `vector.bitcast` ops across `vector.extract/extract_slice/insert_slice` ops.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 74dd1dfaca0da8..a2d4e216633181 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -710,6 +710,76 @@ struct BubbleDownBitCastForStridedSliceExtract
}
};
+// Shuffles vector.bitcast op before vector.insert op.
+//
+// This transforms IR like:
+// %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
+// %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+// Into:
+// %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
+// %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
+// %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
+//
+struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+ PatternRewriter &rewriter) const override {
+ VectorType castSrcType = bitcastOp.getSourceVectorType();
+ VectorType castDstType = bitcastOp.getResultVectorType();
+
+ // 0-D and scalable vectors are not supported yet.
+ if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
+ castDstType.isScalable())
+ return failure();
+
+ int64_t castSrcLastDim = castSrcType.getShape().back();
+ int64_t castDstLastDim = castDstType.getShape().back();
+ bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
+ int64_t ratio;
+ if (isNumElemsShrink) {
+ assert(castSrcLastDim % castDstLastDim == 0);
+ ratio = castSrcLastDim / castDstLastDim;
+ } else {
+ assert(castDstLastDim % castSrcLastDim == 0);
+ ratio = castDstLastDim / castSrcLastDim;
+ }
+
+ auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
+ if (!insertOp)
+ return failure();
+
+ // Only vector sources are supported for now.
+ auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
+ if (!insertSrcType)
+ return failure();
+
+ // Bitcast the source.
+ SmallVector<int64_t> srcDims(insertSrcType.getShape());
+ srcDims.back() =
+ isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
+ VectorType newCastSrcType =
+ VectorType::get(srcDims, castDstType.getElementType());
+ auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
+ bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
+
+ SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
+ dstDims.back() =
+ isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
+ VectorType newCastDstType =
+ VectorType::get(dstDims, castDstType.getElementType());
+
+ // Bitcast the destination.
+ auto newCastDstOp = rewriter.create<vector::BitCastOp>(
+ bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
+
+ // Generate new insert.
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
+ return success();
+ }
+};
+
// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
@@ -1782,8 +1852,8 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<BubbleDownVectorBitCastForExtract,
BubbleDownBitCastForStridedSliceExtract,
- BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(),
- benefit);
+ BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index ea10bd56390c78..eda6a5cc40d999 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -339,6 +339,51 @@ func.func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4
return %0: vector<3xf16>
}
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i4_i8(
+// CHECK-SAME: %[[VAL:.*]]: vector<32xi4>,
+// CHECK-SAME: %[[DST:.*]]: vector<8x32xi4>) -> vector<8x16xi8> {
+func.func @bubble_up_bitcast_in_insert_i4_i8(%val: vector<32xi4>, %src: vector<8x32xi4>) -> vector<8x16xi8> {
+// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<32xi4> to vector<16xi8>
+// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xi8> into vector<8x16xi8>
+ %0 = vector.insert %val, %src[4] : vector<32xi4> into vector<8x32xi4>
+ %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+ return %1 : vector<8x16xi8>
+}
+
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i8_i4(
+// CHECK-SAME: %[[VAL:.*]]: vector<16xi8>,
+// CHECK-SAME: %[[DST:.*]]: vector<8x16xi8>) -> vector<8x32xi4> {
+func.func @bubble_up_bitcast_in_insert_i8_i4(%val: vector<16xi8>, %src: vector<8x16xi8>) -> vector<8x32xi4> {
+// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi8> to vector<32xi4>
+// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi8> to vector<8x32xi4>
+// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<32xi4> into vector<8x32xi4>
+ %0 = vector.insert %val, %src[4] : vector<16xi8> into vector<8x16xi8>
+ %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
+ return %1 : vector<8x32xi4>
+}
+
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i32_f32(
+// CHECK-SAME: %[[VAL:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[DST:.*]]: vector<8x16xi32>) -> vector<8x16xf32> {
+func.func @bubble_up_bitcast_in_insert_i32_f32(%val: vector<16xi32>, %src: vector<8x16xi32>) -> vector<8x16xf32> {
+// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi32> to vector<16xf32>
+// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi32> to vector<8x16xf32>
+// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xf32> into vector<8x16xf32>
+ %0 = vector.insert %val, %src[4] : vector<16xi32> into vector<8x16xi32>
+ %1 = vector.bitcast %0 : vector<8x16xi32> to vector<8x16xf32>
+ return %1 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_scalar(
+func.func @bubble_up_bitcast_in_insert_scalar(%val: i8, %src: vector<8x16xi8>) -> vector<8x32xi4> {
+// CHECK: vector.insert
+// CHECK-NEXT: vector.bitcast
+ %0 = vector.insert %val, %src[4, 8] : i8 into vector<8x16xi8>
+ %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
+ return %1 : vector<8x32xi4>
+}
+
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
// CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
func.func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
More information about the Mlir-commits
mailing list