[Mlir-commits] [mlir] [mlir][Vector] Support vector.insert in bubbling bitcast patterns (PR #82843)

Diego Caballero llvmlistbot at llvm.org
Tue Feb 27 15:24:59 PST 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/82843

>From f9779adacfd6a9007de7bf95f9bf832a6c7642e1 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 23 Feb 2024 23:17:21 +0000
Subject: [PATCH 1/4] [mlir][Vector] Support vector.insert in bubbling bitcast
 patterns

This PR adds support for `vector.insert` to the patterns that bubble up
and down `vector.bitcast` ops across
`vector.extract/extract_strided_slice/insert_strided_slice` ops.
---
 .../Vector/Transforms/VectorTransforms.cpp    | 82 ++++++++++++++++++-
 .../Dialect/Vector/vector-transforms.mlir     | 33 ++++++++
 2 files changed, 113 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 74dd1dfaca0da8..278f02bb498291 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -710,6 +710,84 @@ struct BubbleDownBitCastForStridedSliceExtract
   }
 };
 
+// Shuffles vector.bitcast op before vector.insert op.
+//
+// This transforms IR like:
+//   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
+//   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+// Into:
+//   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
+//   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
+//   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
+//
+struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+    // Skip 0-D vector which will not from InsertStridedSliceOp.
+    if (castSrcType.getRank() == 0)
+      return failure();
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    bool isShrink = castSrcLastDim >= castDstLastDim;
+    int64_t ratio;
+    if (isShrink) {
+      assert(castSrcLastDim % castDstLastDim == 0);
+      ratio = castSrcLastDim / castDstLastDim;
+    } else {
+      assert(castDstLastDim % castSrcLastDim == 0);
+      ratio = castDstLastDim / castSrcLastDim;
+    }
+
+    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
+    if (!insertOp)
+      return failure();
+
+    // Vector sources only. NOTE(review): a 0-D source would hit .back() UB.
+    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
+    if (!insertSrcType)
+      return failure();
+
+    // Requires that shape of insert op src is castable to dstType.
+    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
+    unsigned destinationWidth =
+        castDstType.getElementType().getIntOrFloatBitWidth();
+    unsigned numElements = isShrink ? destinationWidth / sourceWidth
+                                    : sourceWidth / destinationWidth;
+    if (insertSrcType.getNumElements() % numElements != 0)
+      return failure();
+
+    // Bitcast the source.
+    SmallVector<int64_t> srcDims =
+        llvm::to_vector<4>(insertSrcType.getShape());
+    srcDims.back() = isShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
+    VectorType newCastSrcType =
+        VectorType::get(srcDims, castDstType.getElementType());
+    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
+
+    SmallVector<int64_t> dstDims =
+        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
+    dstDims.back() = isShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
+    VectorType newCastDstType =
+        VectorType::get(dstDims, castDstType.getElementType());
+
+    // Bitcast the destination.
+    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
+
+    // Generate new insert.
+    rewriter.replaceOpWithNewOp<vector::InsertOp>(
+        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
+    return success();
+  }
+};
+
 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
 //
 // This transforms IR like:
@@ -1782,8 +1860,8 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<BubbleDownVectorBitCastForExtract,
                BubbleDownBitCastForStridedSliceExtract,
-               BubbleUpBitCastForStridedSliceInsert>(patterns.getContext(),
-                                                     benefit);
+               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index ea10bd56390c78..f10feaf7654c53 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -339,6 +339,39 @@ func.func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4
   return %0: vector<3xf16>
 }
 
+// CHECK-LABEL:   func.func @bubble_up_bitcast_in_insert_i4_i8(
+// CHECK-SAME:                                                 %[[VAL:.*]]: vector<32xi4>,
+// CHECK-SAME:                                                 %[[DST:.*]]: vector<8x32xi4>) -> vector<8x16xi8> {
+func.func @bubble_up_bitcast_in_insert_i4_i8(%val: vector<32xi4>, %src: vector<8x32xi4>) -> vector<8x16xi8> {
+// CHECK:           %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<32xi4> to vector<16xi8>
+// CHECK:           %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xi8> into vector<8x16xi8>
+  %0 = vector.insert %val, %src[4] : vector<32xi4> into vector<8x32xi4>
+  %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+  return %1 : vector<8x16xi8>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_bitcast_in_insert_i8_i4(
+// CHECK-SAME:                                                 %[[VAL:.*]]: vector<16xi8>,
+// CHECK-SAME:                                                 %[[DST:.*]]: vector<8x16xi8>) -> vector<8x32xi4> {
+func.func @bubble_up_bitcast_in_insert_i8_i4(%val: vector<16xi8>, %src: vector<8x16xi8>) -> vector<8x32xi4> {
+// CHECK:           %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi8> to vector<32xi4>
+// CHECK:           %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi8> to vector<8x32xi4>
+// CHECK:           vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<32xi4> into vector<8x32xi4>
+  %0 = vector.insert %val, %src[4] : vector<16xi8> into vector<8x16xi8>
+  %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
+  return %1 : vector<8x32xi4>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_bitcast_in_insert_scalar(
+func.func @bubble_up_bitcast_in_insert_scalar(%val: i8, %src: vector<8x16xi8>) -> vector<8x32xi4> {
+// CHECK:           vector.insert
+// CHECK-NEXT:      vector.bitcast
+  %0 = vector.insert %val, %src[4, 8] : i8 into vector<8x16xi8>
+  %1 = vector.bitcast %0 : vector<8x16xi8> to vector<8x32xi4>
+  return %1 : vector<8x32xi4>
+}
+
 // CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
 //  CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
 func.func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {

>From 2b69d35e245ebb96bd8160927d7721ff49d3b861 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 23 Feb 2024 23:32:16 +0000
Subject: [PATCH 2/4] Clang format

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 278f02bb498291..ad295928cb12b1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -763,8 +763,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
       return failure();
 
     // Bitcast the source.
-    SmallVector<int64_t> srcDims =
-        llvm::to_vector<4>(insertSrcType.getShape());
+    SmallVector<int64_t> srcDims = llvm::to_vector<4>(insertSrcType.getShape());
     srcDims.back() = isShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
     VectorType newCastSrcType =
         VectorType::get(srcDims, castDstType.getElementType());

>From 8c98b9bb65cd8dc710bc04b0f2946988cbb68611 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Sat, 24 Feb 2024 01:19:50 +0000
Subject: [PATCH 3/4] Feedback

---
 .../Vector/Transforms/VectorTransforms.cpp    | 32 ++++++++-----------
 1 file changed, 14 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ad295928cb12b1..b85ef050d6cfbd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -728,15 +728,17 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
     VectorType castSrcType = bitcastOp.getSourceVectorType();
     VectorType castDstType = bitcastOp.getResultVectorType();
     assert(castSrcType.getRank() == castDstType.getRank());
-    // Skip 0-D vector which will not from InsertStridedSliceOp.
-    if (castSrcType.getRank() == 0)
+
+    // 0-D and scalable vectors are not supported yet.
+    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
+        castDstType.isScalable())
       return failure();
 
     int64_t castSrcLastDim = castSrcType.getShape().back();
     int64_t castDstLastDim = castDstType.getShape().back();
-    bool isShrink = castSrcLastDim >= castDstLastDim;
+    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
     int64_t ratio;
-    if (isShrink) {
+    if (isNumElemsShrink) {
       assert(castSrcLastDim % castDstLastDim == 0);
       ratio = castSrcLastDim / castDstLastDim;
     } else {
@@ -753,26 +755,20 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
     if (!insertSrcType)
       return failure();
 
-    // Requires that shape of insert op src is castable to dstType.
-    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
-    unsigned destinationWidth =
-        castDstType.getElementType().getIntOrFloatBitWidth();
-    unsigned numElements = isShrink ? destinationWidth / sourceWidth
-                                    : sourceWidth / destinationWidth;
-    if (insertSrcType.getNumElements() % numElements != 0)
-      return failure();
-
     // Bitcast the source.
-    SmallVector<int64_t> srcDims = llvm::to_vector<4>(insertSrcType.getShape());
-    srcDims.back() = isShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
+    auto insertSrcShape = insertSrcType.getShape();
+    SmallVector<int64_t> srcDims(insertSrcShape.begin(), insertSrcShape.end());
+    srcDims.back() =
+        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
     VectorType newCastSrcType =
         VectorType::get(srcDims, castDstType.getElementType());
     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
 
-    SmallVector<int64_t> dstDims =
-        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
-    dstDims.back() = isShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
+    auto dstShape = insertOp.getDestVectorType().getShape();
+    SmallVector<int64_t> dstDims(dstShape.begin(), dstShape.end());
+    dstDims.back() =
+        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
     VectorType newCastDstType =
         VectorType::get(dstDims, castDstType.getElementType());
 

>From 36a4d4c6823a4ff2726e721687bed07ea5f27cba Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Tue, 27 Feb 2024 23:24:02 +0000
Subject: [PATCH 4/4] Addressed Hanhan's feedback

---
 .../Dialect/Vector/Transforms/VectorTransforms.cpp   |  7 ++-----
 mlir/test/Dialect/Vector/vector-transforms.mlir      | 12 ++++++++++++
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b85ef050d6cfbd..a2d4e216633181 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -727,7 +727,6 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
                                 PatternRewriter &rewriter) const override {
     VectorType castSrcType = bitcastOp.getSourceVectorType();
     VectorType castDstType = bitcastOp.getResultVectorType();
-    assert(castSrcType.getRank() == castDstType.getRank());
 
     // 0-D and scalable vectors are not supported yet.
     if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
@@ -756,8 +755,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
       return failure();
 
     // Bitcast the source.
-    auto insertSrcShape = insertSrcType.getShape();
-    SmallVector<int64_t> srcDims(insertSrcShape.begin(), insertSrcShape.end());
+    SmallVector<int64_t> srcDims(insertSrcType.getShape());
     srcDims.back() =
         isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
     VectorType newCastSrcType =
@@ -765,8 +763,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
     auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
         bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
 
-    auto dstShape = insertOp.getDestVectorType().getShape();
-    SmallVector<int64_t> dstDims(dstShape.begin(), dstShape.end());
+    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
     dstDims.back() =
         isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
     VectorType newCastDstType =
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index f10feaf7654c53..eda6a5cc40d999 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -363,6 +363,18 @@ func.func @bubble_up_bitcast_in_insert_i8_i4(%val: vector<16xi8>, %src: vector<8
   return %1 : vector<8x32xi4>
 }
 
+// CHECK-LABEL:   func.func @bubble_up_bitcast_in_insert_i32_f32(
+// CHECK-SAME:                                                 %[[VAL:.*]]: vector<16xi32>,
+// CHECK-SAME:                                                 %[[DST:.*]]: vector<8x16xi32>) -> vector<8x16xf32> {
+func.func @bubble_up_bitcast_in_insert_i32_f32(%val: vector<16xi32>, %src: vector<8x16xi32>) -> vector<8x16xf32> {
+// CHECK:           %[[BC_VAL:.*]] = vector.bitcast %[[VAL]] : vector<16xi32> to vector<16xf32>
+// CHECK:           %[[BC_DST:.*]] = vector.bitcast %[[DST]] : vector<8x16xi32> to vector<8x16xf32>
+// CHECK:           vector.insert %[[BC_VAL]], %[[BC_DST]] [4] : vector<16xf32> into vector<8x16xf32>
+  %0 = vector.insert %val, %src[4] : vector<16xi32> into vector<8x16xi32>
+  %1 = vector.bitcast %0 : vector<8x16xi32> to vector<8x16xf32>
+  return %1 : vector<8x16xf32>
+}
+
 // CHECK-LABEL:   func.func @bubble_up_bitcast_in_insert_scalar(
 func.func @bubble_up_bitcast_in_insert_scalar(%val: i8, %src: vector<8x16xi8>) -> vector<8x32xi4> {
 // CHECK:           vector.insert



More information about the Mlir-commits mailing list