[Mlir-commits] [mlir] [mlir][vector] Disable `BreakDownVectorBitCast` for scalable vectors (PR #122725)
llvmlistbot at llvm.org
Mon Jan 13 07:26:39 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
`BreakDownVectorBitCast` is built on top of
* `vector.extract_strided_slice` + `vector.insert_strided_slice`

As these ops do not support extracting/inserting scalable sub-vectors
(i.e. a fraction of a scalable dim), the pattern now bails out when the
source vector is scalable.
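As a rough illustration of the slicing involved, here is a standalone C++ sketch, not the upstream implementation. `Vec1D`, `Slice` and `planBreakDown` are hypothetical names, element types are left out, and the sketch assumes one slice per destination element, which is the shape the existing `@bitcast_i8_to_i32` test checks for; the upstream pattern may slice differently in other cases. The scalable check mirrors the bail-out added in this PR.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a rank-1 MLIR VectorType: one dimension size
// plus a flag that marks it as scalable (e.g. vector<[8]xf16>).
struct Vec1D {
  int64_t dim;
  bool scalable;
};

// One step of the break-down: extract `extractSize` source elements at
// `extractOffset` (vector.extract_strided_slice), bitcast them to a single
// destination element, then insert that element at `insertOffset`
// (vector.insert_strided_slice).
struct Slice {
  int64_t extractOffset;
  int64_t extractSize;
  int64_t insertOffset;
};

// Hypothetical helper: plans one slice per destination element. Returns
// std::nullopt wherever this sketch declines to rewrite the bitcast.
std::optional<std::vector<Slice>> planBreakDown(Vec1D src, Vec1D dst) {
  assert(src.dim > 0 && dst.dim > 0 && "vector dims must be positive");
  // Static slice offsets cannot address a fraction of a scalable dim.
  // (The PR checks the source type only; this sketch checks both.)
  if (src.scalable || dst.scalable)
    return std::nullopt;
  // Sketch restriction: only casts to fewer, wider elements that divide
  // the source evenly.
  if (src.dim <= dst.dim || src.dim % dst.dim != 0)
    return std::nullopt;
  // A single destination element would mean a single slice: nothing to do.
  if (dst.dim == 1)
    return std::nullopt;

  const int64_t ratio = src.dim / dst.dim; // source elems per dest elem
  std::vector<Slice> plan;
  for (int64_t i = 0; i < dst.dim; ++i)
    plan.push_back({/*extractOffset=*/i * ratio, /*extractSize=*/ratio,
                    /*insertOffset=*/i});
  return plan;
}
```

For `vector<16xi8>` to `vector<4xi32>`, `planBreakDown({16, false}, {4, false})` yields four slices that extract 4 elements at offsets 0, 4, 8 and 12 and insert at offsets 0, 1, 2 and 3. For `vector<[8]xf16>` to `vector<[4]xf32>` it returns `std::nullopt`: the element count is a runtime multiple of vscale, so the static offsets taken by the strided-slice ops cannot address part of that dimension.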
---
Full diff: https://github.com/llvm/llvm-project/pull/122725.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+7)
- (modified) mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 21ec718efd6a7a..c88d8daaf44c91 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -906,6 +906,13 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
VectorType castDstType = bitcastOp.getResultVectorType();
assert(castSrcType.getRank() == castDstType.getRank());
+ // This transformation builds on top of
+ // vector.{extract|insert}_strided_slice, which do not support
+ // extracting/inserting "scalable sub-vectors". Bail out.
+ if (castSrcType.isScalable())
+ return rewriter.notifyMatchFailure(bitcastOp,
+ "Scalable vectors are not supported");
+
// Only support rank 1 case for now.
if (castSrcType.getRank() != 1)
return failure();
diff --git a/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir b/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir
index fbb2f7605e6497..173388f63ecda5 100644
--- a/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir
+++ b/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir
@@ -39,3 +39,14 @@ func.func @bitcast_i8_to_i32(%input: vector<16xi8>) -> vector<4xi32> {
// CHECK: %[[CAST3:.+]] = vector.bitcast %[[EXTRACT3]] : vector<4xi8> to vector<1xi32>
// CHECK: %[[INSERT3:.+]] = vector.insert_strided_slice %[[CAST3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
// CHECK: return %[[INSERT3]]
+
+// -----
+
+// Scalable vectors are not supported!
+
+// CHECK-LABEL: func.func @bitcast_scalable_negative
+// CHECK: vector.bitcast
+func.func @bitcast_scalable_negative(%input: vector<[8]xf16>) -> vector<[4]xf32> {
+ %0 = vector.bitcast %input : vector<[8]xf16> to vector<[4]xf32>
+ return %0: vector<[4]xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/122725