[Mlir-commits] [mlir] dc5d541 - [mlir][vector] Support scalable vectors when unrolling vector.bitcast (#94197)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 21 06:38:23 PDT 2024
Author: Benjamin Maxwell
Date: 2024-06-21T14:38:19+01:00
New Revision: dc5d541081381e3dca80b982097596546e0619fc
URL: https://github.com/llvm/llvm-project/commit/dc5d541081381e3dca80b982097596546e0619fc
DIFF: https://github.com/llvm/llvm-project/commit/dc5d541081381e3dca80b982097596546e0619fc.diff
LOG: [mlir][vector] Support scalable vectors when unrolling vector.bitcast (#94197)
Follow up to #94064.
Added:
Modified:
mlir/include/mlir/Dialect/Utils/IndexingUtils.h
mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 9892253df2bff..b774359552aa5 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -287,6 +287,8 @@ class TileOffsetRangeImpl {
return getDynamicTileOffsets(linearIndex);
}
+ size_t getRank() const { return tileShape.size(); }
+
private:
/// The sub-shape that divides the larger outer shape (which is provided to
/// the constructor).
@@ -388,6 +390,9 @@ class StaticTileOffsetRange {
/// Returns the total number of tiles that fit in the larger shape.
size_t size() const { return params.getMaxLinearIndex(); }
+ /// Returns rank of the iterator's shape.
+ size_t getRank() const { return params.getRank(); }
+
private:
const ParamsTy params;
IteratorTy beginValue;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
index 092ec927c92ae..e5f11d82f277f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -56,17 +56,12 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
if (!unrollIterator)
return failure();
- // TODO: Support the scalable vector cases. It is not supported because
- // the final rank could be values other than `targetRank`. It makes creating
- // the result type of new vector.bitcast ops much harder.
- if (resultType.isScalable()) {
- return rewriter.notifyMatchFailure(op,
- "unrolling vector.bitcast on scalable "
- "vectors is not yet implemented");
- }
-
- ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank);
- auto bitcastResType = VectorType::get(shape, resultType.getElementType());
+ auto unrollRank = unrollIterator->getRank();
+ ArrayRef<int64_t> shape = resultType.getShape().drop_front(unrollRank);
+ ArrayRef<bool> scalableDims =
+ resultType.getScalableDims().drop_front(unrollRank);
+ auto bitcastResType =
+ VectorType::get(shape, resultType.getElementType(), scalableDims);
Location loc = op.getLoc();
Value result = rewriter.create<arith::ConstantOp>(
diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
index 23fece208c561..346291019451c 100644
--- a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -38,7 +38,39 @@ func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) ->
return %0 : vector<1x2x[3]x8xi32>
}
// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
-// CHECK: vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<1x2x[3]x8xi32>
+// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0, 0] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
+// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[3]x4xi64> to vector<[3]x8xi32>
+// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0, 0] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
+// CHECK: %[[V2:.+]] = vector.extract %[[IN]][0, 1] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
+// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[3]x4xi64> to vector<[3]x8xi32>
+// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [0, 1] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
+// CHECK: return %[[R2]] : vector<1x2x[3]x8xi32>
+
+func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) -> vector<2x[4]xi32> {
+ %0 = vector.bitcast %arg0 : vector<2x[2]xi64> to vector<2x[4]xi32>
+ return %0 : vector<2x[4]xi32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_2d_trailing_scalable_dim
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<2x[4]xi32>
+// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<[2]xi64> from vector<2x[2]xi64>
+// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] : vector<[4]xi32> into vector<2x[4]xi32>
+// CHECK: %[[V2:.+]] = vector.extract %[[IN]][1] : vector<[2]xi64> from vector<2x[2]xi64>
+// CHECK: %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK: %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1] : vector<[4]xi32> into vector<2x[4]xi32>
+// CHECK: return %[[R2]] : vector<2x[4]xi32>
+
+func.func @negative_vector_bitcast_2d_leading_scalable_dim(%arg0: vector<[2]x2xi64>) -> vector<[2]x4xi32>
+{
+ %0 = vector.bitcast %arg0 : vector<[2]x2xi64> to vector<[2]x4xi32>
+ return %0 : vector<[2]x4xi32>
+}
+// CHECK-LABEL: func.func @negative_vector_bitcast_2d_leading_scalable_dim
+// CHECK-NOT: vector.extract
+// CHECK-NOT: vector.insert
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
More information about the Mlir-commits mailing list