[Mlir-commits] [mlir] [mlir][Vector] Support vector.insert in bubbling bitcast patterns (PR #82843)
Diego Caballero
llvmlistbot at llvm.org
Fri Feb 23 15:32:33 PST 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/82843
>From f9779adacfd6a9007de7bf95f9bf832a6c7642e1 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 23 Feb 2024 23:17:21 +0000
Subject: [PATCH 1/2] [mlir][Vector] Support vector.insert in bubbling bitcast
patterns
This PR adds support for `vector.insert` to the patterns that bubble up
and down `vector.bitcast` ops across
`vector.extract/extract_strided_slice/insert_strided_slice` ops.
---
.../Vector/Transforms/VectorTransforms.cpp | 82 ++++++++++++++++++-
.../Dialect/Vector/vector-transforms.mlir | 33 ++++++++
2 files changed, 113 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 74dd1dfaca0da8..278f02bb498291 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -710,6 +710,84 @@ struct BubbleDownBitCastForStridedSliceExtract
}
};
+// Shuffles vector.bitcast op before vector.insert op.
+//
+// This transforms IR like:
+//   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
+//   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+// Into:
+//   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
+//   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
+//   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
+//
+// The pattern only fires when the inserted value is a non-scalable vector of
+// rank >= 1 and the trailing dimensions of the bitcast differ by an integer
+// ratio; otherwise it fails to match and leaves the IR untouched.
+//
+struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+                                PatternRewriter &rewriter) const override {
+    // Match the producer first so unrelated bitcasts are rejected cheaply.
+    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
+    if (!insertOp)
+      return failure();
+
+    // Only vector sources with at least one dimension are supported for now:
+    // the rewrite rescales the trailing dimension of the inserted value.
+    // Scalable sources are skipped because the new source type below is
+    // rebuilt from the static shape only and would lose its scalable flags.
+    // TODO: Propagate scalable dims to support scalable sources.
+    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
+    if (!insertSrcType || insertSrcType.getRank() == 0 ||
+        insertSrcType.isScalable())
+      return failure();
+
+    // castSrcType is the insert destination type, so its rank is >= the
+    // source rank (>= 1) and getShape().back() below is well defined.
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    // vector.bitcast only requires the trailing dimensions to have the same
+    // total bit width, so their sizes need not be multiples of each other
+    // (e.g. vector<3xi16> -> vector<2xi24>). Bail out in that case instead of
+    // asserting.
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    bool isShrink = castSrcLastDim >= castDstLastDim;
+    int64_t largerDim = isShrink ? castSrcLastDim : castDstLastDim;
+    int64_t smallerDim = isShrink ? castDstLastDim : castSrcLastDim;
+    if (largerDim % smallerDim != 0)
+      return failure();
+    int64_t ratio = largerDim / smallerDim;
+
+    // Requires that the trailing dimension of the insert source can be
+    // rescaled by the same ratio as the full vector.
+    if (isShrink && insertSrcType.getShape().back() % ratio != 0)
+      return failure();
+
+    // Bitcast the source.
+    SmallVector<int64_t> srcDims =
+        llvm::to_vector<4>(insertSrcType.getShape());
+    srcDims.back() = isShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
+    VectorType newCastSrcType =
+        VectorType::get(srcDims, castDstType.getElementType());
+    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
+
+    // Bitcast the destination. The destination has the same type as the
+    // original bitcast source, so its new type is exactly castDstType (which
+    // also keeps any scalable leading dimensions intact).
+    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), castDstType, insertOp.getDest());
+
+    // Generate the new insert. The positions are reused unchanged: they only
+    // index leading dimensions, never the rescaled trailing one, because the
+    // source has rank >= 1.
+    rewriter.replaceOpWithNewOp<vector::InsertOp>(
+        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
+    return success();
+  }
+};
+
// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
@@ -1782,8 +1860,8 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<BubbleDownVectorBitCastForExtract,
BubbleDownBitCastForStridedSliceExtract,
- BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(),
- benefit);
+ BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index ea10bd56390c78..f10feaf7654c53 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -339,6 +339,39 @@ func.func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4
return %0: vector<3xf16>
}
+// Shrinking element count (i4 -> i8): the trailing dimension of both the
+// inserted value and the destination is halved (32 -> 16).
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i4_i8(
+// CHECK-SAME:    %[[VAL:.*]]: vector<32xi4>,
+// CHECK-SAME:    %[[DST:.*]]: vector<8x32xi4>) -> vector<8x16xi8> {
+func.func @bubble_up_bitcast_in_insert_i4_i8(%val: vector<32xi4>, %dst: vector<8x32xi4>) -> vector<8x16xi8> {
+// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<32xi4> to vector<16xi8>
+// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xi8> into vector<8x16xi8>
+  %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
+  %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+  return %1 : vector<8x16xi8>
+}
+
+// Expanding element count (i8 -> i4): the trailing dimension of both the
+// inserted value and the destination is doubled (16 -> 32).
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i8_i4(
+// CHECK-SAME:    %[[VAL:.*]]: vector<16xi8>,
+// CHECK-SAME:    %[[DST:.*]]: vector<8x16xi8>) -> vector<8x32xi4> {
+func.func @bubble_up_bitcast_in_insert_i8_i4(%val: vector<16xi8>, %dst: vector<8x16xi8>) -> vector<8x32xi4> {
+// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi8> to vector<32xi4>
+// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi8> to vector<8x32xi4>
+// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<32xi4> into vector<8x32xi4>
+  %0 = vector.insert %val, %dst[4] : vector<16xi8> into vector<8x16xi8>
+  %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
+  return %1 : vector<8x32xi4>
+}
+
+// Negative test: scalar insert sources are not supported, so the bitcast must
+// remain after the insert.
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_scalar(
+func.func @bubble_up_bitcast_in_insert_scalar(%val: i8, %dst: vector<8x16xi8>) -> vector<8x32xi4> {
+// CHECK: vector.insert
+// CHECK-NEXT: vector.bitcast
+  %0 = vector.insert %val, %dst[4, 8] : i8 into vector<8x16xi8>
+  %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
+  return %1 : vector<8x32xi4>
+}
+
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
// CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
func.func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
>From 2b69d35e245ebb96bd8160927d7721ff49d3b861 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 23 Feb 2024 23:32:16 +0000
Subject: [PATCH 2/2] Clang format
---
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 278f02bb498291..ad295928cb12b1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -763,8 +763,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
return failure();
// Bitcast the source.
- SmallVector<int64_t> srcDims =
- llvm::to_vector<4>(insertSrcType.getShape());
+ SmallVector<int64_t> srcDims = llvm::to_vector<4>(insertSrcType.getShape());
srcDims.back() = isShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
VectorType newCastSrcType =
VectorType::get(srcDims, castDstType.getElementType());
More information about the Mlir-commits
mailing list