[Mlir-commits] [mlir] [mlir][Vector] Support vector.insert in bubbling bitcast patterns (PR #82843)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Feb 23 16:25:07 PST 2024


================
@@ -710,6 +710,83 @@ struct BubbleDownBitCastForStridedSliceExtract
   }
 };
 
+// Shuffles vector.bitcast op before vector.insert op.
+//
+// This transforms IR like:
+//   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
+//   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+// Into:
+//   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
+//   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
+//   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
+//
+struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+    // Skip 0-D vector which will not from InsertStridedSliceOp.
+    if (castSrcType.getRank() == 0)
+      return failure();
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    bool isShrink = castSrcLastDim >= castDstLastDim;
+    int64_t ratio;
+    if (isShrink) {
+      assert(castSrcLastDim % castDstLastDim == 0);
+      ratio = castSrcLastDim / castDstLastDim;
+    } else {
+      assert(castDstLastDim % castSrcLastDim == 0);
+      ratio = castDstLastDim / castSrcLastDim;
+    }
+
+    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
+    if (!insertOp)
+      return failure();
+
+    // Only vector sources are supported for now.
+    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
+    if (!insertSrcType)
+      return failure();
+
+    // Requires that shape of insert op src is castable to dstType.
+    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
+    unsigned destinationWidth =
+        castDstType.getElementType().getIntOrFloatBitWidth();
+    unsigned numElements = isShrink ? destinationWidth / sourceWidth
+                                    : sourceWidth / destinationWidth;
+    if (insertSrcType.getNumElements() % numElements != 0)
+      return failure();
----------------
MacDue wrote:

For a 'grow' the source number of elements would not have to be divisible by the ratio, right? 

i.e. `vector<1xi8> to vector<2xi4>` is fine?

https://github.com/llvm/llvm-project/pull/82843


More information about the Mlir-commits mailing list