[Mlir-commits] [mlir] [mlir][vector] Add more patterns to Vector Linearize transformation (PR #136193)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 17 13:28:01 PDT 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff HEAD~1 HEAD --extensions cpp -- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 6de5d0c5a..0c3cd18ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -282,22 +282,24 @@ private:
/// source vector using ExtractStridedSliceOp and inserting them into the
/// destination vector using InsertStridedSliceOp.
/// Following,
-/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into vector<4x4xf32>
+/// vector.insert_strided_slice %s, %d {offsets=[0, 0]}: vector<2x4xf32> into
+/// vector<4x4xf32>
/// is converted to :
-/// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
-/// %1 = vector.insert_strided_slice %0, %d {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32>
-/// %2 = vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} : vector<4xf32> from vector<8xf32>
-/// %3 = vector.insert_strided_slice %2, %1 {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
+/// %0 = vector.extract_strided_slice %s {offsets=[0], sizes=[4], strides=[1]}
+/// : vector<4xf32> from vector<8xf32> %1 = vector.insert_strided_slice %0, %d
+/// {offsets=[0], strides=[1]} : vector<4xf32> into vector<16xf32> %2 =
+/// vector.extract_strided_slice %s {offsets=[4], sizes=[4], strides=[1]} :
+/// vector<4xf32> from vector<8xf32> %3 = vector.insert_strided_slice %2, %1
+/// {offsets=[4], strides=[1]} : vector<4xf32> into vector<16xf32>
struct LinearizeVectorInsertStridedSlice final
: public OpConversionPattern<vector::InsertStridedSliceOp> {
- using OpConversionPattern<
- vector::InsertStridedSliceOp>::OpConversionPattern;
- LinearizeVectorInsertStridedSlice(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
+ LinearizeVectorInsertStridedSlice(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
@@ -345,8 +347,9 @@ struct LinearizeVectorInsertStridedSlice final
rewriter.replaceOp(op, dstValue);
return success();
}
- private:
- unsigned targetVectorBitWidth;
+
+private:
+ unsigned targetVectorBitWidth;
};
/// This pattern converts the ShuffleOp that works on nD (n > 1)
@@ -619,22 +622,22 @@ private:
/// is converted to :
/// %result = arith.constant dense<0.0> : vector<4x4xf32>
/// %slice_0 = vector.load %base[%indices] : vector<4xf32>
-/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32>
-/// %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
-/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32>
+/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into
+/// vector<4x4xf32> %slice_1 = vector.load %base[%indices + 1] : vector<4xf32>
+/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into
+/// vector<4x4xf32>
/// ...
/// This unrolls the 2D vector load into multiple 1D vector loads and inserts
/// them into the result vector. The pattern currently supports only 2D vectors
-struct LinearizeVectorLoad final
- : public OpConversionPattern<vector::LoadOp> {
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern<vector::LoadOp>::OpConversionPattern;
LinearizeVectorLoad(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
@@ -648,35 +651,33 @@ struct LinearizeVectorLoad final
}
auto unrollCount = shape[0];
auto vecSize = shape[1];
- auto newVecType =
- VectorType::get({vecSize}, vecType.getElementType());
+ auto newVecType = VectorType::get({vecSize}, vecType.getElementType());
llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
Value xBaseIndex = indices[0];
// Construct the 2D vector.
- Value resultVec = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(vecType));
+ Value resultVec =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecType));
// Emit unrolled loads for each 1D vector slice.
for (auto i = 0; i < unrollCount; i++) {
Value xIndex = xBaseIndex;
if (i) {
auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
- xIndex =
- rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+ xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
}
indices[0] = xIndex;
- auto vec = rewriter.create<vector::LoadOp>(
- loc, newVecType, adaptor.getBase(), indices);
- resultVec =
- rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
+ auto vec = rewriter.create<vector::LoadOp>(loc, newVecType,
+ adaptor.getBase(), indices);
+ resultVec = rewriter.create<vector::InsertOp>(loc, vec, resultVec, i);
}
rewriter.replaceOp(loadOp, resultVec);
return success();
}
- private:
- unsigned targetVectorBitWidth;
+
+private:
+ unsigned targetVectorBitWidth;
};
/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp
@@ -689,19 +690,19 @@ struct LinearizeVectorLoad final
/// %slice_1 = vector.extract %source[1] : vector<4xf32>
/// vector.store %slice_1, %base[%indices + 1] : vector<4xf32>
/// ...
-/// This unrolls the 2D vector store into multiple 1D vector stores by extracting
-/// slices from the source vector and storing them into the destination.
-/// The pattern currently supports only 2D vectors
+/// This unrolls the 2D vector store into multiple 1D vector stores by
+/// extracting slices from the source vector and storing them into the
+/// destination. The pattern currently supports only 2D vectors
struct LinearizeVectorStore final
: public OpConversionPattern<vector::StoreOp> {
using OpConversionPattern<vector::StoreOp>::OpConversionPattern;
LinearizeVectorStore(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
@@ -718,26 +719,26 @@ struct LinearizeVectorStore final
llvm::SmallVector<Value, 4> indices = adaptor.getIndices();
Value xBaseIndex = indices[0];
- auto vec = rewriter.create<vector::ShapeCastOp>(
- loc, vecType, adaptor.getValueToStore());
+ auto vec = rewriter.create<vector::ShapeCastOp>(loc, vecType,
+ adaptor.getValueToStore());
for (auto i = 0; i < unrollCount; i++) {
auto vecSlice = rewriter.create<vector::ExtractOp>(loc, vec, i);
Value xIndex = xBaseIndex;
if (i) {
auto increment = rewriter.create<arith::ConstantIndexOp>(loc, i);
- xIndex =
- rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
+ xIndex = rewriter.create<arith::AddIOp>(loc, xBaseIndex, increment);
}
indices[0] = xIndex;
rewriter.create<vector::StoreOp>(loc, vecSlice, adaptor.getBase(),
- indices);
+ indices);
}
rewriter.eraseOp(storeOp);
return success();
}
- private:
- unsigned targetVectorBitWidth;
+
+private:
+ unsigned targetVectorBitWidth;
};
/// This pattern converts the SplatOp to work on a linearized vector.
@@ -754,11 +755,11 @@ struct LinearizeVectorSplat final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorSplat(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
@@ -766,12 +767,13 @@ struct LinearizeVectorSplat final
auto dstTy = getTypeConverter()->convertType(splatOp.getType());
if (!dstTy)
return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
- rewriter.replaceOpWithNewOp<vector::SplatOp>(
- splatOp, adaptor.getInput(), dstTy);
+ rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
+ dstTy);
return success();
}
- private:
- unsigned targetVectorBitWidth;
+
+private:
+ unsigned targetVectorBitWidth;
};
/// This pattern converts the CreateMaskOp to work on a
@@ -789,11 +791,11 @@ struct LinearizeVectorCreateMask final
using OpConversionPattern::OpConversionPattern;
LinearizeVectorCreateMask(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
@@ -816,8 +818,9 @@ struct LinearizeVectorCreateMask final
createMaskOp, dstTy, adaptor.getOperands().back());
return success();
}
- private:
- unsigned targetVectorBitWidth;
+
+private:
+ unsigned targetVectorBitWidth;
};
/// This pattern converts operations implementing the RegionBranchOpInterface
@@ -835,15 +838,14 @@ struct LinearizeRegionBranchOp final
RegionBranchOpInterface>::OpInterfaceConversionPattern;
LinearizeRegionBranchOp(
- const TypeConverter &typeConverter, MLIRContext *context,
- unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
- PatternBenefit benefit = 1)
- : OpInterfaceConversionPattern(typeConverter, context, benefit),
- targetVectorBitWidth(targetVectBitWidth) {}
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
- matchAndRewrite(RegionBranchOpInterface op,
- ArrayRef<Value> operands,
+ matchAndRewrite(RegionBranchOpInterface op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto converter = getTypeConverter();
@@ -907,8 +909,9 @@ struct LinearizeRegionBranchOp final
rewriter.finalizeOpModification(op);
return success();
}
- private:
- unsigned targetVectorBitWidth;
+
+private:
+ unsigned targetVectorBitWidth;
};
} // namespace
@@ -939,9 +942,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
target.addLegalOp<mlir::vector::ShapeCastOp>();
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<vector::BitCastOp, vector::LoadOp,
- vector::StoreOp, vector::CreateMaskOp,
- RegionBranchOpInterface, vector::SplatOp>(op) ||
+ if ((isa<vector::BitCastOp, vector::LoadOp, vector::StoreOp,
+ vector::CreateMaskOp, RegionBranchOpInterface,
+ vector::SplatOp>(op) ||
op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
@@ -951,12 +954,11 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorSplat,
- LinearizeVectorCreateMask, LinearizeRegionBranchOp
- >(typeConverter, patterns.getContext(),
- targetBitWidth);
+ patterns
+ .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+ LinearizeVectorLoad, LinearizeVectorStore, LinearizeVectorSplat,
+ LinearizeVectorCreateMask, LinearizeRegionBranchOp>(
+ typeConverter, patterns.getContext(), targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
@@ -972,16 +974,16 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
});
target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
- [=](vector::InsertStridedSliceOp op) -> bool {
- if(isLessThanTargetBitWidth(op, targetBitWidth)) {
- auto srcTy = op.getSourceVectorType();
- auto dstTy = op.getDestVectorType();
- if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
- srcTy.hasStaticShape() && dstTy.hasStaticShape())
- return false;
- }
- return true;
- });
+ [=](vector::InsertStridedSliceOp op) -> bool {
+ if (isLessThanTargetBitWidth(op, targetBitWidth)) {
+ auto srcTy = op.getSourceVectorType();
+ auto dstTy = op.getDestVectorType();
+ if (!op.hasNonUnitStrides() && srcTy.getRank() == 2 &&
+ srcTy.hasStaticShape() && dstTy.hasStaticShape())
+ return false;
+ }
+ return true;
+ });
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
``````````
</details>
https://github.com/llvm/llvm-project/pull/136193
More information about the Mlir-commits mailing list