[Mlir-commits] [mlir] [mlir][vector] add support for linearizing vector.bitcast in VectorLinearize (PR #123110)
Chao Chen
llvmlistbot at llvm.org
Mon Jan 27 07:25:48 PST 2025
https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/123110
>From d067a5ab18604996b9290e98701e3d8ac04efe7b Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 15 Jan 2025 19:10:38 +0000
Subject: [PATCH 1/5] add linearize pattern for bitcast
---
.../Vector/Transforms/VectorLinearize.cpp | 37 +++++++++++++++++--
mlir/test/Dialect/Vector/linearize.mlir | 21 ++++++++++-
2 files changed, 54 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 68535ae5a7a5c6..b450ea91fef651 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
auto resType =
getTypeConverter()->convertType<VectorType>(constOp.getType());
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
return rewriter.notifyMatchFailure(
loc,
"Cannot linearize a constant scalable vector that's not a splat");
- if (!resType)
- return rewriter.notifyMatchFailure(loc, "can't convert return type");
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
@@ -459,6 +460,35 @@ struct LinearizeVectorInsert final
private:
unsigned targetVectorBitWidth;
};
+
+struct LinearizeVectorBitCast final
+ : public OpConversionPattern<vector::BitCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorBitCast(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+ LogicalResult
+ matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = castOp.getLoc();
+ auto resType = getTypeConverter()->convertType(castOp.getType());
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type.");
+
+ if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ loc, "Can't flatten since targetBitWidth <= OpSize");
+
+ rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType, adaptor.getSource());
+ return mlir::success();
+ }
+private:
+ unsigned targetVectorBitWidth;
+};
+
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -486,6 +516,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
if ((isa<arith::ConstantOp>(op) ||
+ isa<vector::BitCastOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? typeConverter.isLegal(op)
@@ -494,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns.add<LinearizeConstant, LinearizeVectorizable>(
+ patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..0358c2637f72b2 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
// ALL-LABEL: func.func @test_extract_strided_slice_1_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
// ALL-NOT: vector.shuffle
// ALL-NOT: vector.shape_cast
// ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
@@ -318,3 +318,22 @@ func.func @test_vector_extract_scalar() {
%0 = vector.extract %cst[0] : i32 from vector<4xi32>
return
}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>)
+func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> {
+
+ // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
+ // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
+ // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+
+ // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
+ // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
+ // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+
+ // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4x2xf16>
+ %1 = vector.bitcast %arg0 : vector<4x1xf32> to vector<4x2xf16>
+ return %1 : vector<4x2xf16>
+}
>From 5f358831c7f25bc385a3d00f8a340766adbf1170 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 15 Jan 2025 19:36:22 +0000
Subject: [PATCH 2/5] code format
---
.../Dialect/Vector/Transforms/VectorLinearize.cpp | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b450ea91fef651..a89d2872e2434a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -482,9 +482,11 @@ struct LinearizeVectorBitCast final
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
- rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType, adaptor.getSource());
+ rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType,
+ adaptor.getSource());
return mlir::success();
}
+
private:
unsigned targetVectorBitWidth;
};
@@ -515,8 +517,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
- if ((isa<arith::ConstantOp>(op) ||
- isa<vector::BitCastOp>(op) ||
+ if ((isa<arith::ConstantOp>(op) || isa<vector::BitCastOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? typeConverter.isLegal(op)
@@ -525,8 +526,9 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
- typeConverter, patterns.getContext(), targetBitWidth);
+ patterns
+ .add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
+ typeConverter, patterns.getContext(), targetBitWidth);
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
>From f0b3ae0d45ebaea3659fed486a8cc973673814f6 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 16 Jan 2025 15:18:09 +0000
Subject: [PATCH 3/5] update tests
---
mlir/test/Dialect/Vector/linearize.mlir | 37 +++++++++++++++++--------
1 file changed, 25 insertions(+), 12 deletions(-)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 0358c2637f72b2..bab5c6c15ed8f1 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -322,18 +322,31 @@ func.func @test_vector_extract_scalar() {
// -----
// ALL-LABEL: test_vector_bitcast
-// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>)
-func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> {
-
- // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
- // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
- // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
+ // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x4xf32> to vector<16xf32>
+ // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<16xf32> to vector<32xf16>
+ // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<32xf16> to vector<4x8xf16>
+
+ // BW-128: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16>
+ // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16>
+ %1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
+ return %1 : vector<4x8xf16>
+}
- // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
- // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
- // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+// -----
- // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4x2xf16>
- %1 = vector.bitcast %arg0 : vector<4x1xf32> to vector<4x2xf16>
- return %1 : vector<4x2xf16>
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
+ // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32>
+ // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16>
+ // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16>
+ // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32>
+ // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16>
+ // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16>
+
+ // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x2xf32> to vector<4x4xf16>
+ %1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
+ return %1 : vector<4x4xf16>
}
>From 25a8f39e7841d28c459a390175ea50cebf737b74 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 16 Jan 2025 19:34:32 +0000
Subject: [PATCH 4/5] add test for scalable vector
---
.../Vector/Transforms/VectorLinearize.cpp | 8 +++++
mlir/test/Dialect/Vector/linearize.mlir | 33 +++++++++++++++++++
2 files changed, 41 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index a89d2872e2434a..3ecd585c5a26d5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -461,6 +461,14 @@ struct LinearizeVectorInsert final
unsigned targetVectorBitWidth;
};
+/// This pattern converts a BitCastOp that operates on n-D (n > 1)
+/// vectors into a BitCastOp that operates on the linearized vectors.
+/// For example,
+///   vector.bitcast %v1: vector<4x2xf32> to vector<4x4xf16>
+/// is converted to:
+///   %v1_1d = vector.shape_cast %v1: vector<4x2xf32> to vector<8xf32>
+///   %out_1d = vector.bitcast %v1_1d: vector<8xf32> to vector<16xf16>
+///   %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
struct LinearizeVectorBitCast final
: public OpConversionPattern<vector::BitCastOp> {
using OpConversionPattern::OpConversionPattern;
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index bab5c6c15ed8f1..de757fb9e4c1a5 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -350,3 +350,36 @@ func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
%1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
return %1 : vector<4x4xf16>
}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x[2]xf32>
+func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
+ // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32>
+ // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
+ // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16>
+ // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32>
+ // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
+ // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16>
+
+ // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<4x[4]xf16>
+ %1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16>
+ return %1 : vector<4x[4]xf16>
+}
+
+// -----
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: %[[ORIG_ARG:.*]]: vector<[4]x2xf32>
+func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
+ // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32>
+ // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
+ // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16>
+ // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32>
+ // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
+ // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16>
+
+ // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[4]x4xf16>
+ %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
+ return %1 : vector<[4]x4xf16>
+}
>From ea2e518149180e77c2652fff69b2c140b9cec95b Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 27 Jan 2025 15:14:26 +0000
Subject: [PATCH 5/5] fix naming in tests
---
mlir/test/Dialect/Vector/linearize.mlir | 58 ++++++++++++-------------
1 file changed, 29 insertions(+), 29 deletions(-)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index de757fb9e4c1a5..8279aac07245de 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -324,12 +324,12 @@ func.func @test_vector_extract_scalar() {
// ALL-LABEL: test_vector_bitcast
// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32>
func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
- // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x4xf32> to vector<16xf32>
- // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<16xf32> to vector<32xf16>
- // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<32xf16> to vector<4x8xf16>
+ // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x4xf32> to vector<16xf32>
+ // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<16xf32> to vector<32xf16>
+ // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<32xf16> to vector<4x8xf16>
- // BW-128: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16>
- // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16>
+ // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16>
+ // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x4xf32> to vector<4x8xf16>
%1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
return %1 : vector<4x8xf16>
}
@@ -339,14 +339,14 @@ func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
// ALL-LABEL: test_vector_bitcast
// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x2xf32>
func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
- // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32>
- // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16>
- // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16>
- // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32>
- // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<8xf32> to vector<16xf16>
- // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<16xf16> to vector<4x4xf16>
-
- // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x2xf32> to vector<4x4xf16>
+ // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32>
+ // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
+ // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
+ // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x2xf32> to vector<8xf32>
+ // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
+ // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
+
+ // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x2xf32> to vector<4x4xf16>
%1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
return %1 : vector<4x4xf16>
}
@@ -356,14 +356,14 @@ func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
// ALL-LABEL: test_vector_bitcast
// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x[2]xf32>
func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
- // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32>
- // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
- // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16>
- // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32>
- // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
- // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<4x[4]xf16>
-
- // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<4x[4]xf16>
+ // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32>
+ // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+ // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
+ // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<[8]xf32>
+ // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+ // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
+
+ // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x[2]xf32> to vector<4x[4]xf16>
%1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16>
return %1 : vector<4x[4]xf16>
}
@@ -372,14 +372,14 @@ func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
// ALL-LABEL: test_vector_bitcast
// ALL-SAME: %[[ORIG_ARG:.*]]: vector<[4]x2xf32>
func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
- // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32>
- // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
- // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16>
- // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32>
- // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<[8]xf32> to vector<[16]xf16>
- // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<[16]xf16> to vector<[4]x4xf16>
-
- // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[4]x4xf16>
+ // DEFAULT: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32>
+ // DEFAULT: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+ // DEFAULT: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
+ // BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[8]xf32>
+ // BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
+ // BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
+
+ // BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<[4]x2xf32> to vector<[4]x4xf16>
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
return %1 : vector<[4]x4xf16>
}
More information about the Mlir-commits
mailing list