[Mlir-commits] [mlir] 7ce2c3d - [mlir][Vector] Add 0-d vector support to `vector.shape_cast`

Diego Caballero llvmlistbot at llvm.org
Tue May 23 10:24:33 PDT 2023


Author: Diego Caballero
Date: 2023-05-23T17:22:55Z
New Revision: 7ce2c3d71b552310019f3dc4b0219528226a4778

URL: https://github.com/llvm/llvm-project/commit/7ce2c3d71b552310019f3dc4b0219528226a4778
DIFF: https://github.com/llvm/llvm-project/commit/7ce2c3d71b552310019f3dc4b0219528226a4778.diff

LOG: [mlir][Vector] Add 0-d vector support to `vector.shape_cast`

This patch adds support for shape casting a vector<1x1x1...1xElementType>
(all dimensions of size 1) to a 0-d vector<ElementType>, and vice versa.

Differential Revision: https://reviews.llvm.org/D151169

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 09c8d6a2b2831..d783a248b252a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2099,8 +2099,8 @@ def Vector_CompressStoreOp :
 
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [Pure]>,
-    Arguments<(ins AnyVector:$source)>,
-    Results<(outs AnyVector:$result)> {
+    Arguments<(ins AnyVectorOfAnyRank:$source)>,
+    Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
     The shape_cast operation casts between an n-D source vector shape and

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 64b64c6ae71de..aac677792e768 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4615,6 +4615,13 @@ static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
   unsigned rankB = b.size();
   assert(rankA < rankB);
 
+  auto isOne = [](int64_t v) { return v == 1; };
+
+  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
+  // casted to a 0-d vector.
+  if (rankA == 0 && llvm::all_of(b, isOne))
+    return true;
+
   unsigned i = 0;
   unsigned j = 0;
   while (i < rankA && j < rankB) {
@@ -4628,7 +4635,6 @@ static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
 
     // Handle the case when trailing dimensions are of size 1.
     // Include them into the contiguous sequence.
-    auto isOne = [](int64_t v) { return v == 1; };
     if (i < rankA && llvm::all_of(a.slice(i), isOne))
       i = rankA;
     if (j < rankB && llvm::all_of(b.slice(j), isOne))

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index aadbc14004da1..19488c5cbeda0 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -458,6 +458,18 @@ func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
   return %0, %1, %2, %3 : vector<15x2xf32>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
 }
 
+// CHECK-LABEL: @shape_cast_0d
+func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x1x1x1xf32> to vector<f32>
+  %0 = vector.shape_cast %arg0 : vector<1x1x1x1xf32> to vector<f32>
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<f32> to vector<1x1x1x1xf32>
+  %1 = vector.shape_cast %0 : vector<f32> to vector<1x1x1x1xf32>
+
+  return %1 : vector<1x1x1x1xf32>
+}
+
 // CHECK-LABEL: @bitcast
 func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xi32>,


        


More information about the Mlir-commits mailing list