[Mlir-commits] [mlir] bd5d361 - [mlir][vector] add support for linearizing vector.bitcast in VectorLinearize (#123110)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 27 12:41:36 PST 2025
Author: Chao Chen
Date: 2025-01-27T14:41:33-06:00
New Revision: bd5d361c059814435bab24189e79e01d94c7039d
URL: https://github.com/llvm/llvm-project/commit/bd5d361c059814435bab24189e79e01d94c7039d
DIFF: https://github.com/llvm/llvm-project/commit/bd5d361c059814435bab24189e79e01d94c7039d.diff
LOG: [mlir][vector] add support for linearizing vector.bitcast in VectorLinearize (#123110)
This PR adds support for converting a vector::BitCastOp that operates on
n-D (n > 1) vectors into the same op operating on linearized (1-D) vectors.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
mlir/test/Dialect/Vector/linearize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 68535ae5a7a5c6..3ecd585c5a26d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
auto resType =
getTypeConverter()->convertType<VectorType>(constOp.getType());
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
return rewriter.notifyMatchFailure(
loc,
"Cannot linearize a constant scalable vector that's not a splat");
- if (!resType)
- return rewriter.notifyMatchFailure(loc, "can't convert return type");
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
@@ -459,6 +460,45 @@ struct LinearizeVectorInsert final
private:
unsigned targetVectorBitWidth;
};
+
+/// Conversion pattern that rewrites a vector.bitcast on n-D (n > 1) vectors
+/// into a vector.bitcast on linearized (1-D) vectors.
+/// For example,
+///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
+/// is converted to:
+///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
+///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
+///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
+struct LinearizeVectorBitCast final
+ : public OpConversionPattern<vector::BitCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorBitCast(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+ LogicalResult
+ matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = castOp.getLoc();
+ auto resType = getTypeConverter()->convertType(castOp.getType());
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type.");
+
+ if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ loc, "Can't flatten since targetBitWidth <= OpSize");
+
+ rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
+ adaptor.getSource());
+ return mlir::success();
+ }
+
+private:
+ unsigned targetVectorBitWidth; // bound passed to isLessThanTargetBitWidth
+};
+
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -485,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<arith::ConstantOp>(op) ||
+ if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? typeConverter.isLegal(op)
@@ -494,8 +534,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns.add<LinearizeConstant, LinearizeVectorizable>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ patterns
+ .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
+ typeConverter, patterns.getContext(), targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..99b1bbab1eede9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
// ALL-LABEL: func.func @test_extract_strided_slice_1_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
// ALL-NOT: vector.shuffle
// ALL-NOT: vector.shape_cast
// ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
@@ -318,3 +318,68 @@ func.func @test_vector_extract_scalar() {
%0 = vector.extract %cst[0] : i32 from vector<4xi32>
return
}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
+ // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x4xf32> to vector<16xf32>
+ // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16>
+ // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<32xf16> to vector<4x8xf16>
+
+ // BW-128: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
+ // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
+ %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+ return %1 : vector<4x8xf16>
+}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<4x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
+ // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32>
+ // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
+ // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
+ // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32>
+ // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
+ // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
+
+ // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x2xf32> to vector<4x4xf16>
+ %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
+ return %1 : vector<4x4xf16>
+}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
+ // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
+ // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+ // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
+ // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
+ // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+ // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
+
+ // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x[2]xf32> to vector<4x[4]xf16>
+ %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16>
+ return %1 : vector<4x[4]xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
+ // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32>
+ // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+ // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
+ // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32>
+ // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+ // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
+
+ // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<[4]x2xf32> to vector<[4]x4xf16>
+ %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
+ return %1 : vector<[4]x4xf16>
+}
More information about the Mlir-commits
mailing list