[Mlir-commits] [mlir] [mlir][Vector] Support vector.insert in bubbling bitcast patterns (PR #82843)
Diego Caballero
llvmlistbot at llvm.org
Tue Feb 27 15:24:59 PST 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/82843
>From f9779adacfd6a9007de7bf95f9bf832a6c7642e1 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 23 Feb 2024 23:17:21 +0000
Subject: [PATCH 1/4] [mlir][Vector] Support vector.insert in bubbling bitcast
patterns
This PR adds support for `vector.insert` to the patterns that bubble
vector.bitcast ops up and down across `vector.extract/extract_strided_slice/insert_strided_slice`
ops.
---
.../Vector/Transforms/VectorTransforms.cpp | 82 ++++++++++++++++++-
.../Dialect/Vector/vector-transforms.mlir | 33 ++++++++
2 files changed, 113 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 74dd1dfaca0da8..278f02bb498291 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -710,6 +710,84 @@ struct BubbleDownBitCastForStridedSliceExtract
}
};
+// Shuffles vector.bitcast op before vector.insert op.
+//
+// This transforms IR like:
+// %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
+// %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+// Into:
+// %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
+// %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
+// %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
+// 0-D bitcasts and inserts of scalar (non-vector) values are not rewritten.
+struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+ PatternRewriter &rewriter) const override {
+ VectorType castSrcType = bitcastOp.getSourceVectorType();
+ VectorType castDstType = bitcastOp.getResultVectorType();
+ assert(castSrcType.getRank() == castDstType.getRank());
+ // Skip 0-D vector which will not from InsertStridedSliceOp.
+ if (castSrcType.getRank() == 0)
+ return failure();
+
+ int64_t castSrcLastDim = castSrcType.getShape().back();
+ int64_t castDstLastDim = castDstType.getShape().back();
+ bool isShrink = castSrcLastDim >= castDstLastDim;
+ int64_t ratio;
+ if (isShrink) {
+ assert(castSrcLastDim % castDstLastDim == 0);
+ ratio = castSrcLastDim / castDstLastDim;
+ } else {
+ assert(castDstLastDim % castSrcLastDim == 0);
+ ratio = castDstLastDim / castSrcLastDim;
+ }
+
+ auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
+ if (!insertOp)
+ return failure();
+
+ // Only vector sources are supported for now.
+ auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
+ if (!insertSrcType)
+ return failure();
+
+ // Requires that shape of insert op src is castable to dstType.
+ unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
+ unsigned destinationWidth =
+ castDstType.getElementType().getIntOrFloatBitWidth();
+ unsigned numElements = isShrink ? destinationWidth / sourceWidth
+ : sourceWidth / destinationWidth;
+ if (insertSrcType.getNumElements() % numElements != 0)
+ return failure();
+
+ // Bitcast the source.
+ SmallVector<int64_t> srcDims =
+ llvm::to_vector<4>(insertSrcType.getShape());
+ srcDims.back() = isShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
+ VectorType newCastSrcType =
+ VectorType::get(srcDims, castDstType.getElementType());
+ auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
+ bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
+
+ SmallVector<int64_t> dstDims =
+ llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
+ dstDims.back() = isShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
+ VectorType newCastDstType =
+ VectorType::get(dstDims, castDstType.getElementType());
+
+ // Bitcast the destination to the rescaled destination shape computed above.
+ auto newCastDstOp = rewriter.create<vector::BitCastOp>(
+ bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
+
+ // Rebuild the insert on the bitcast operands, reusing the original position.
+ rewriter.replaceOpWithNewOp<vector::InsertOp>(
+ bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
+ return success();
+ }
+};
+
// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
@@ -1782,8 +1860,8 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<BubbleDownVectorBitCastForExtract,
BubbleDownBitCastForStridedSliceExtract,
- BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(),
- benefit);
+ BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index ea10bd56390c78..f10feaf7654c53 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -339,6 +339,39 @@ func.func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4
return %0: vector<3xf16>
}
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i4_i8(
+// CHECK-SAME: %[[VAL:.*]]: vector<32xi4>,
+// CHECK-SAME: %[[DST:.*]]: vector<8x32xi4>) -> vector<8x16xi8> {
+func.func @bubble_up_bitcast_in_insert_i4_i8(%val: vector<32xi4>, %src: vector<8x32xi4>) -> vector<8x16xi8> {
+// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<32xi4> to vector<16xi8>
+// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xi8> into vector<8x16xi8>
+ %0 = vector.insert %val, %src[4] : vector<32xi4> into vector<8x32xi4>
+ %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+ return %1 : vector<8x16xi8>
+}
+
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i8_i4(
+// CHECK-SAME: %[[VAL:.*]]: vector<16xi8>,
+// CHECK-SAME: %[[DST:.*]]: vector<8x16xi8>) -> vector<8x32xi4> {
+func.func @bubble_up_bitcast_in_insert_i8_i4(%val: vector<16xi8>, %src: vector<8x16xi8>) -> vector<8x32xi4> {
+// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi8> to vector<32xi4>
+// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi8> to vector<8x32xi4>
+// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<32xi4> into vector<8x32xi4>
+ %0 = vector.insert %val, %src[4] : vector<16xi8> into vector<8x16xi8>
+ %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
+ return %1 : vector<8x32xi4>
+}
+
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_scalar(
+func.func @bubble_up_bitcast_in_insert_scalar(%val: i8, %src: vector<8x16xi8>) -> vector<8x32xi4> {
+// CHECK: vector.insert
+// CHECK-NEXT: vector.bitcast
+ %0 = vector.insert %val, %src[4, 8] : i8 into vector<8x16xi8>
+ %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
+ return %1 : vector<8x32xi4>
+}
+
// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
// CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
func.func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
>From 2b69d35e245ebb96bd8160927d7721ff49d3b861 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 23 Feb 2024 23:32:16 +0000
Subject: [PATCH 2/4] Clang format
---
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 278f02bb498291..ad295928cb12b1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -763,8 +763,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
return failure();
// Bitcast the source.
- SmallVector<int64_t> srcDims =
- llvm::to_vector<4>(insertSrcType.getShape());
+ SmallVector<int64_t> srcDims = llvm::to_vector<4>(insertSrcType.getShape());
srcDims.back() = isShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
VectorType newCastSrcType =
VectorType::get(srcDims, castDstType.getElementType());
>From 8c98b9bb65cd8dc710bc04b0f2946988cbb68611 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Sat, 24 Feb 2024 01:19:50 +0000
Subject: [PATCH 3/4] Feedback
---
.../Vector/Transforms/VectorTransforms.cpp | 32 ++++++++-----------
1 file changed, 14 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ad295928cb12b1..b85ef050d6cfbd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -728,15 +728,17 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
VectorType castSrcType = bitcastOp.getSourceVectorType();
VectorType castDstType = bitcastOp.getResultVectorType();
assert(castSrcType.getRank() == castDstType.getRank());
- // Skip 0-D vector which will not from InsertStridedSliceOp.
- if (castSrcType.getRank() == 0)
+
+ // 0-D and scalable vectors are not supported yet.
+ if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
+ castDstType.isScalable())
return failure();
int64_t castSrcLastDim = castSrcType.getShape().back();
int64_t castDstLastDim = castDstType.getShape().back();
- bool isShrink = castSrcLastDim >= castDstLastDim;
+ bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
int64_t ratio;
- if (isShrink) {
+ if (isNumElemsShrink) {
assert(castSrcLastDim % castDstLastDim == 0);
ratio = castSrcLastDim / castDstLastDim;
} else {
@@ -753,26 +755,20 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
if (!insertSrcType)
return failure();
- // Requires that shape of insert op src is castable to dstType.
- unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
- unsigned destinationWidth =
- castDstType.getElementType().getIntOrFloatBitWidth();
- unsigned numElements = isShrink ? destinationWidth / sourceWidth
- : sourceWidth / destinationWidth;
- if (insertSrcType.getNumElements() % numElements != 0)
- return failure();
-
// Bitcast the source.
- SmallVector<int64_t> srcDims = llvm::to_vector<4>(insertSrcType.getShape());
- srcDims.back() = isShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
+ auto insertSrcShape = insertSrcType.getShape();
+ SmallVector<int64_t> srcDims(insertSrcShape.begin(), insertSrcShape.end());
+ srcDims.back() =
+ isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
VectorType newCastSrcType =
VectorType::get(srcDims, castDstType.getElementType());
auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
- SmallVector<int64_t> dstDims =
- llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
- dstDims.back() = isShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
+ auto dstShape = insertOp.getDestVectorType().getShape();
+ SmallVector<int64_t> dstDims(dstShape.begin(), dstShape.end());
+ dstDims.back() =
+ isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
VectorType newCastDstType =
VectorType::get(dstDims, castDstType.getElementType());
>From 36a4d4c6823a4ff2726e721687bed07ea5f27cba Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Tue, 27 Feb 2024 23:24:02 +0000
Subject: [PATCH 4/4] Addressed Hanhan's feedback
---
.../Dialect/Vector/Transforms/VectorTransforms.cpp | 7 ++-----
mlir/test/Dialect/Vector/vector-transforms.mlir | 12 ++++++++++++
2 files changed, 14 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b85ef050d6cfbd..a2d4e216633181 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -727,7 +727,6 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
PatternRewriter &rewriter) const override {
VectorType castSrcType = bitcastOp.getSourceVectorType();
VectorType castDstType = bitcastOp.getResultVectorType();
- assert(castSrcType.getRank() == castDstType.getRank());
// 0-D and scalable vectors are not supported yet.
if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
@@ -756,8 +755,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
return failure();
// Bitcast the source.
- auto insertSrcShape = insertSrcType.getShape();
- SmallVector<int64_t> srcDims(insertSrcShape.begin(), insertSrcShape.end());
+ SmallVector<int64_t> srcDims(insertSrcType.getShape());
srcDims.back() =
isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
VectorType newCastSrcType =
@@ -765,8 +763,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
- auto dstShape = insertOp.getDestVectorType().getShape();
- SmallVector<int64_t> dstDims(dstShape.begin(), dstShape.end());
+ SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
dstDims.back() =
isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
VectorType newCastDstType =
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index f10feaf7654c53..eda6a5cc40d999 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -363,6 +363,18 @@ func.func @bubble_up_bitcast_in_insert_i8_i4(%val: vector<16xi8>, %src: vector<8
return %1 : vector<8x32xi4>
}
+// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_i32_f32(
+// CHECK-SAME: %[[VAL:.*]]: vector<16xi32>,
+// CHECK-SAME: %[[DST:.*]]: vector<8x16xi32>) -> vector<8x16xf32> {
+func.func @bubble_up_bitcast_in_insert_i32_f32(%val: vector<16xi32>, %src: vector<8x16xi32>) -> vector<8x16xf32> {
+// CHECK: %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi32> to vector<16xf32>
+// CHECK: %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi32> to vector<8x16xf32>
+// CHECK: vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xf32> into vector<8x16xf32>
+ %0 = vector.insert %val, %src[4] : vector<16xi32> into vector<8x16xi32>
+ %1 = vector.bitcast %0 : vector<8x16xi32> to vector<8x16xf32>
+ return %1 : vector<8x16xf32>
+}
+
// CHECK-LABEL: func.func @bubble_up_bitcast_in_insert_scalar(
func.func @bubble_up_bitcast_in_insert_scalar(%val: i8, %src: vector<8x16xi8>) -> vector<8x32xi4> {
// CHECK: vector.insert
More information about the Mlir-commits
mailing list