[Mlir-commits] [mlir] [mlir][Vector] Support vector.insert in bubbling bitcast patterns (PR #82843)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Feb 23 16:25:07 PST 2024
================
@@ -710,6 +710,83 @@ struct BubbleDownBitCastForStridedSliceExtract
}
};
+// Shuffles vector.bitcast op before vector.insert op.
+//
+// This transforms IR like:
+// %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
+// %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
+// Into:
+// %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
+// %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
+// %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
+//
+struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+ PatternRewriter &rewriter) const override {
+ VectorType castSrcType = bitcastOp.getSourceVectorType();
+ VectorType castDstType = bitcastOp.getResultVectorType();
+ assert(castSrcType.getRank() == castDstType.getRank());
+ // Skip 0-D vector which will not from InsertStridedSliceOp.
+ if (castSrcType.getRank() == 0)
+ return failure();
+
+ int64_t castSrcLastDim = castSrcType.getShape().back();
+ int64_t castDstLastDim = castDstType.getShape().back();
+ bool isShrink = castSrcLastDim >= castDstLastDim;
+ int64_t ratio;
+ if (isShrink) {
+ assert(castSrcLastDim % castDstLastDim == 0);
+ ratio = castSrcLastDim / castDstLastDim;
+ } else {
+ assert(castDstLastDim % castSrcLastDim == 0);
+ ratio = castDstLastDim / castSrcLastDim;
+ }
+
+ auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
+ if (!insertOp)
+ return failure();
+
+ // Only vector sources are supported for now.
+ auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
+ if (!insertSrcType)
+ return failure();
+
+ // Requires that shape of insert op src is castable to dstType.
+ unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
+ unsigned destinationWidth =
+ castDstType.getElementType().getIntOrFloatBitWidth();
+ unsigned numElements = isShrink ? destinationWidth / sourceWidth
+ : sourceWidth / destinationWidth;
+ if (insertSrcType.getNumElements() % numElements != 0)
+ return failure();
+
+ // Bitcast the source.
+ SmallVector<int64_t> srcDims = llvm::to_vector<4>(insertSrcType.getShape());
+ srcDims.back() = isShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
+ VectorType newCastSrcType =
+ VectorType::get(srcDims, castDstType.getElementType());
+ auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
+ bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());
+
+ SmallVector<int64_t> dstDims =
+ llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
+ dstDims.back() = isShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
+ VectorType newCastDstType =
+ VectorType::get(dstDims, castDstType.getElementType());
----------------
MacDue wrote:
This rewrite does not check for scalable vectors, and the `VectorType::get(dims, elementType)` calls here drop the scalable dims.
https://github.com/llvm/llvm-project/pull/82843
More information about the Mlir-commits
mailing list