[Mlir-commits] [mlir] [mlir][vector] Support scalable vectors when unrolling vector.bitcast (PR #94197)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Jun 21 06:37:57 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/94197

>From 6f3b57f57493202b2688d67cb93bce2d39e52e1d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 11 Jun 2024 10:21:59 +0000
Subject: [PATCH 1/3] [mlir][vector] Support scalable vectors when unrolling
 vector.bitcast

Follow-up to #94064.
---
 .../mlir/Dialect/Utils/IndexingUtils.h        |  5 ++++
 .../Vector/Transforms/LowerVectorBitCast.cpp  | 17 +++++--------
 .../vector-bitcast-lowering-transforms.mlir   | 24 ++++++++++++++++++-
 3 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 9892253df2bff..b774359552aa5 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -287,6 +287,8 @@ class TileOffsetRangeImpl {
       return getDynamicTileOffsets(linearIndex);
   }
 
+  size_t getRank() const { return tileShape.size(); }
+
 private:
   /// The sub-shape that divides the larger outer shape (which is provided to
   /// the constructor).
@@ -388,6 +390,9 @@ class StaticTileOffsetRange {
   /// Returns the total number of tiles that fit in the larger shape.
   size_t size() const { return params.getMaxLinearIndex(); }
 
+  /// Returns the rank of the iterator's shape.
+  size_t getRank() const { return params.getRank(); }
+
 private:
   const ParamsTy params;
   IteratorTy beginValue;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
index 092ec927c92ae..e5f11d82f277f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -56,17 +56,12 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
     if (!unrollIterator)
       return failure();
 
-    // TODO: Support the scalable vector cases. It is not supported because
-    // the final rank could be values other than `targetRank`. It makes creating
-    // the result type of new vector.bitcast ops much harder.
-    if (resultType.isScalable()) {
-      return rewriter.notifyMatchFailure(op,
-                                         "unrolling vector.bitcast on scalable "
-                                         "vectors is not yet implemented");
-    }
-
-    ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank);
-    auto bitcastResType = VectorType::get(shape, resultType.getElementType());
+    auto unrollRank = unrollIterator->getRank();
+    ArrayRef<int64_t> shape = resultType.getShape().drop_front(unrollRank);
+    ArrayRef<bool> scalableDims =
+        resultType.getScalableDims().drop_front(unrollRank);
+    auto bitcastResType =
+        VectorType::get(shape, resultType.getElementType(), scalableDims);
 
     Location loc = op.getLoc();
     Value result = rewriter.create<arith::ConstantOp>(
diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
index 23fece208c561..91abe3e200fa2 100644
--- a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -38,8 +38,30 @@ func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) ->
   return %0 : vector<1x2x[3]x8xi32>
 }
 // CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
-// CHECK:         vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
+// CHECK:         %[[INIT:.+]] = arith.constant dense<0> : vector<1x2x[3]x8xi32>
+// CHECK:         %[[V1:.+]] = vector.extract %[[IN]][0, 0] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
+// CHECK:         %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[3]x4xi64> to vector<[3]x8xi32>
+// CHECK:         %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0, 0] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
+// CHECK:         %[[V2:.+]] = vector.extract %[[IN]][0, 1] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
+// CHECK:         %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[3]x4xi64> to vector<[3]x8xi32>
+// CHECK:         %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [0, 1] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
+// CHECK:         return %[[R2]] : vector<1x2x[3]x8xi32>
 
+func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) -> vector<2x[4]xi32> {
+  %0 = vector.bitcast %arg0 : vector<2x[2]xi64> to vector<2x[4]xi32>
+  return %0 : vector<2x[4]xi32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_2d_trailing_scalable_dim
+// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
+// CHECK:         %[[INIT:.+]] = arith.constant dense<0> : vector<2x[4]xi32>
+// CHECK:         %[[V1:.+]] = vector.extract %[[IN]][0] : vector<[2]xi64> from vector<2x[2]xi64>
+// CHECK:         %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK:         %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] : vector<[4]xi32> into vector<2x[4]xi32>
+// CHECK:         %[[V2:.+]] = vector.extract %[[IN]][1] : vector<[2]xi64> from vector<2x[2]xi64>
+// CHECK:         %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[2]xi64> to vector<[4]xi32>
+// CHECK:         %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1] : vector<[4]xi32> into vector<2x[4]xi32>
+// CHECK:         return %[[R2]] : vector<2x[4]xi32>
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

>From d6162ea9b6a4dc5ff5dfec650e4609dd3b3aed0b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 12 Jun 2024 10:58:59 +0000
Subject: [PATCH 2/3] Add leading scalable dim test

---
 .../Vector/vector-bitcast-lowering-transforms.mlir    | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
index 91abe3e200fa2..68acbd78a1918 100644
--- a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -62,6 +62,17 @@ func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) ->
 // CHECK:         %[[B2:.+]] = vector.bitcast %[[V2]] : vector<[2]xi64> to vector<[4]xi32>
 // CHECK:         %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1] : vector<[4]xi32> into vector<2x[4]xi32>
 // CHECK:         return %[[R2]] : vector<2x[4]xi32>
+
+func.func @vector_bitcast_2d_leading_scalable_dim(%arg0: vector<[2]x2xi64>) -> vector<[2]x4xi32>
+{
+  %0 = vector.bitcast %arg0 : vector<[2]x2xi64> to vector<[2]x4xi32>
+  return %0 : vector<[2]x4xi32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_2d_leading_scalable_dim
+// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
+// CHECK:         %[[RES:.+]] = vector.bitcast %[[IN]] : vector<[2]x2xi64> to vector<[2]x4xi32>
+// CHECK:         return %[[RES]] : vector<[2]x4xi32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

>From 25f27ede13bc27609bddc4255179e3e0e3d594d4 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 21 Jun 2024 13:29:57 +0000
Subject: [PATCH 3/3] Fixup: mark leading scalable dim bitcast test as negative

---
 .../Vector/vector-bitcast-lowering-transforms.mlir       | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
index 68acbd78a1918..346291019451c 100644
--- a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -63,15 +63,14 @@ func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) ->
 // CHECK:         %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1] : vector<[4]xi32> into vector<2x[4]xi32>
 // CHECK:         return %[[R2]] : vector<2x[4]xi32>
 
-func.func @vector_bitcast_2d_leading_scalable_dim(%arg0: vector<[2]x2xi64>) -> vector<[2]x4xi32>
+func.func @negative_vector_bitcast_2d_leading_scalable_dim(%arg0: vector<[2]x2xi64>) -> vector<[2]x4xi32>
 {
   %0 = vector.bitcast %arg0 : vector<[2]x2xi64> to vector<[2]x4xi32>
   return %0 : vector<[2]x4xi32>
 }
-// CHECK-LABEL: func.func @vector_bitcast_2d_leading_scalable_dim
-// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
-// CHECK:         %[[RES:.+]] = vector.bitcast %[[IN]] : vector<[2]x2xi64> to vector<[2]x4xi32>
-// CHECK:         return %[[RES]] : vector<[2]x4xi32>
+// CHECK-LABEL: func.func @negative_vector_bitcast_2d_leading_scalable_dim
+// CHECK-NOT: vector.extract
+// CHECK-NOT: vector.insert
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {



More information about the Mlir-commits mailing list