[Mlir-commits] [mlir] [mlir][vector] Add more patterns to Vector Linearize transformation (PR #136193)
Ivan Butygin
llvmlistbot at llvm.org
Thu Apr 17 13:47:26 PDT 2025
================
@@ -531,12 +615,312 @@ struct LinearizeVectorBitCast final
unsigned targetVectorBitWidth;
};
+/// This pattern unrolls a 2-D vector.load into a series of 1-D vector.loads
+/// whose results are inserted, row by row, into the 2-D result vector.
+/// For example,
+///   %v = vector.load %base[%i, %j] : memref<?x?xf32>, vector<4x4xf32>
+/// is converted to:
+///   %init = arith.constant dense<0.0> : vector<4x4xf32>
+///   %slice_0 = vector.load %base[%i, %j] : memref<?x?xf32>, vector<4xf32>
+///   %res_0 = vector.insert %slice_0, %init [0]
+///              : vector<4xf32> into vector<4x4xf32>
+///   %i_1 = arith.addi %i, %c1 : index
+///   %slice_1 = vector.load %base[%i_1, %j] : memref<?x?xf32>, vector<4xf32>
+///   %res_1 = vector.insert %slice_1, %res_0 [1]
+///              : vector<4xf32> into vector<4x4xf32>
+///   ...
+/// The per-row offset is added to the second-to-last index of the base,
+/// because the vector occupies the innermost memref dimensions.
+/// Only fixed-size (non-scalable) 2-D vectors loaded from memrefs with a
+/// scalar element type are supported.
+///
+/// NOTE(review): `targetVectorBitWidth` is stored but not consulted here.
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
+
+  LinearizeVectorLoad(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = loadOp->getLoc();
+    auto vecType = loadOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    // A scalable dimension only gives a lower bound on its size, so neither
+    // the unroll count nor the 1-D row type can be derived from `shape`.
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "Scalable vectors are not supported.");
+    // With a vector element type the loaded type must equal the element type,
+    // so splitting into 1-D row loads would produce invalid IR.
+    if (isa<VectorType>(loadOp.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          loc, "Memrefs with vector element type are not supported.");
+
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    // The vector's leading (row) dimension maps onto the second-to-last
+    // dimension of the base, so that is the index advanced per row.
+    if (indices.size() < 2)
+      return rewriter.notifyMatchFailure(
+          loc, "Base must have rank >= 2 to unroll a 2D vector load.");
+    const size_t rowDim = indices.size() - 2;
+    Value rowBaseIndex = indices[rowDim];
+
+    const int64_t unrollCount = shape[0];
+    auto rowVecType = VectorType::get({shape[1]}, vecType.getElementType());
+
+    // Start from a zero-filled 2-D vector and insert each loaded row into it.
+    Value resultVec =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
+    for (int64_t i = 0; i < unrollCount; ++i) {
+      Value rowIndex = rowBaseIndex;
+      if (i != 0) {
+        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        rowIndex = rewriter.create<arith::AddIOp>(loc, rowBaseIndex, offset);
+      }
+      indices[rowDim] = rowIndex;
+      Value row = rewriter.create<vector::LoadOp>(loc, rowVecType,
+                                                  adaptor.getBase(), indices);
+      resultVec = rewriter.create<vector::InsertOp>(loc, row, resultVec, i);
+    }
+
+    rewriter.replaceOp(loadOp, resultVec);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern unrolls a 2-D vector.store into a series of 1-D
+/// vector.stores, one per row of the stored vector.
+/// For example,
+///   vector.store %source, %base[%i, %j] : memref<?x?xf32>, vector<4x4xf32>
+/// is converted to:
+///   %slice_0 = vector.extract %source[0] : vector<4xf32> from vector<4x4xf32>
+///   vector.store %slice_0, %base[%i, %j] : memref<?x?xf32>, vector<4xf32>
+///   %i_1 = arith.addi %i, %c1 : index
+///   %slice_1 = vector.extract %source[1] : vector<4xf32> from vector<4x4xf32>
+///   vector.store %slice_1, %base[%i_1, %j] : memref<?x?xf32>, vector<4xf32>
+///   ...
+/// The value to store is first shape_cast to the original 2-D type so that
+/// rows can be extracted from it. The per-row offset is added to the
+/// second-to-last index of the base, because the vector occupies the
+/// innermost memref dimensions. Only fixed-size (non-scalable) 2-D vectors
+/// stored to memrefs with a scalar element type are supported.
+///
+/// NOTE(review): `targetVectorBitWidth` is stored but not consulted here.
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
+
+  LinearizeVectorStore(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = storeOp->getLoc();
+    auto vecType = storeOp.getVectorType();
+    auto shape = vecType.getShape();
+
+    if (shape.size() != 2)
+      return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors.");
+    // A scalable leading dimension has no fixed unroll count.
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(loc,
+                                         "Scalable vectors are not supported.");
+    // With a vector element type the stored type must equal the element type,
+    // so splitting into 1-D row stores would produce invalid IR.
+    if (isa<VectorType>(storeOp.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          loc, "Memrefs with vector element type are not supported.");
+
+    llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
+    // The vector's leading (row) dimension maps onto the second-to-last
+    // dimension of the base, so that is the index advanced per row.
+    if (indices.size() < 2)
+      return rewriter.notifyMatchFailure(
+          loc, "Base must have rank >= 2 to unroll a 2D vector store.");
+    const size_t rowDim = indices.size() - 2;
+    Value rowBaseIndex = indices[rowDim];
+
+    // The adaptor's value may already be in converted (linearized) form; view
+    // it as the original 2-D type so individual rows can be extracted.
+    Value vec2D = rewriter.create<vector::ShapeCastOp>(
+        loc, vecType, adaptor.getValueToStore());
+
+    const int64_t unrollCount = shape[0];
+    for (int64_t i = 0; i < unrollCount; ++i) {
+      Value row = rewriter.create<vector::ExtractOp>(loc, vec2D, i);
+      Value rowIndex = rowBaseIndex;
+      if (i != 0) {
+        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        rowIndex = rewriter.create<arith::AddIOp>(loc, rowBaseIndex, offset);
+      }
+      indices[rowDim] = rowIndex;
+      rewriter.create<vector::StoreOp>(loc, row, adaptor.getBase(), indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern rewrites a vector.splat so that it directly produces the type
+/// returned by the type converter (the linearized vector type).
+/// For example,
+///   %out_nd = vector.splat %value : vector<4x4xf32>
+/// becomes
+///   %out_1d = vector.splat %value : vector<16xf32>
+/// The pattern itself creates only the new SplatOp; it emits no
+/// vector.shape_cast back to vector<4x4xf32>. Any such cast for remaining
+/// users presumably comes from the type converter's materializations.
+/// The match fails if the result type cannot be converted.
+///
+/// NOTE(review): `targetVectorBitWidth` is stored but not consulted; no
+/// target-bit-width check is performed by this pattern.
+struct LinearizeVectorSplat final
+    : public OpConversionPattern<vector::SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorSplat(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstTy = getTypeConverter()->convertType(splatOp.getType());
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
+    rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+                                                 dstTy);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern rewrites a 2-D vector.create_mask with a unit outer dimension
+/// into a 1-D create_mask of the type returned by the type converter.
+/// Only 2-D masks whose outer dimension is 1 are supported.
+/// For example,
+///   %m = vector.create_mask %d0, %d1 : vector<1x4xi1>
+/// is converted to:
+///   %zero = arith.constant 0 : index
+///   %row_on = arith.cmpi sgt, %d0, %zero : index
+///   %bound = arith.select %row_on, %d1, %zero : index
+///   %m_1d = vector.create_mask %bound : vector<4xi1>
+/// The outer bound must be honored. When %d0 <= 0 the single row is disabled
+/// and every lane must be false, so the 1-D bound becomes 0 in that case.
+/// The pattern itself emits no vector.shape_cast back to the 2-D type.
+///
+/// NOTE(review): `targetVectorBitWidth` is stored but not consulted; no
+/// target-bit-width check is performed by this pattern.
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // The single outer row is enabled iff the outer bound is positive. Fold
+    // that into the inner bound: bound = (outer > 0) ? inner : 0. Dropping the
+    // outer operand would turn e.g. `create_mask %c0, %c3` into 3 set lanes.
+    auto loc = createMaskOp.getLoc();
+    Value outerBound = adaptor.getOperands().front();
+    Value innerBound = adaptor.getOperands().back();
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value rowEnabled = rewriter.createOrFold<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sgt, outerBound, zero);
+    Value bound = rewriter.createOrFold<arith::SelectOp>(loc, rowEnabled,
+                                                         innerBound, zero);
+
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(createMaskOp, dstTy,
+                                                      bound);
+    return success();
+  }
+
+private:
+  unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts operations implementing the RegionBranchOpInterface
+/// to ensure compatibility with linearized vector types. It updates the
+/// operands, result types, and region types (block arguments and yields) to
+/// match the converted types. Additionally, it processes yields within each
+/// region to ensure that the types of yielded values are compatible with the
+/// target vector bit width. If the result types of the operation are updated,
+/// shape cast operations are inserted to maintain compatibility with the
+/// original types. This pattern ensures that operations with regions are
+/// properly linearized and remain valid after type conversion.
+struct LinearizeRegionBranchOp final
----------------
Hardcode84 wrote:
Instead of adding this pattern, can you test `populateSCFStructuralTypeConversionsAndLegality`?
https://github.com/llvm/llvm-project/pull/136193
More information about the Mlir-commits
mailing list