[Mlir-commits] [mlir] 7ce2c3d - [mlir][Vector] Add 0-d vector support to `vector.shape_cast`
Diego Caballero
llvmlistbot at llvm.org
Tue May 23 10:24:33 PDT 2023
Author: Diego Caballero
Date: 2023-05-23T17:22:55Z
New Revision: 7ce2c3d71b552310019f3dc4b0219528226a4778
URL: https://github.com/llvm/llvm-project/commit/7ce2c3d71b552310019f3dc4b0219528226a4778
DIFF: https://github.com/llvm/llvm-project/commit/7ce2c3d71b552310019f3dc4b0219528226a4778.diff
LOG: [mlir][Vector] Add 0-d vector support to `vector.shape_cast`
This patch adds support for shape casting a vector<1x1x...x1xElementType> to
a 0-d vector<ElementType>, and the other way around.
Differential Revision: https://reviews.llvm.org/D151169
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 09c8d6a2b2831..d783a248b252a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2099,8 +2099,8 @@ def Vector_CompressStoreOp :
def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [Pure]>,
- Arguments<(ins AnyVector:$source)>,
- Results<(outs AnyVector:$result)> {
+ Arguments<(ins AnyVectorOfAnyRank:$source)>,
+ Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "shape_cast casts between vector shapes";
let description = [{
The shape_cast operation casts between an n-D source vector shape and
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 64b64c6ae71de..aac677792e768 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4615,6 +4615,13 @@ static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
unsigned rankB = b.size();
assert(rankA < rankB);
+ auto isOne = [](int64_t v) { return v == 1; };
+
+ // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
+ // casted to a 0-d vector.
+ if (rankA == 0 && llvm::all_of(b, isOne))
+ return true;
+
unsigned i = 0;
unsigned j = 0;
while (i < rankA && j < rankB) {
@@ -4628,7 +4635,6 @@ static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
// Handle the case when trailing dimensions are of size 1.
// Include them into the contiguous sequence.
- auto isOne = [](int64_t v) { return v == 1; };
if (i < rankA && llvm::all_of(a.slice(i), isOne))
i = rankA;
if (j < rankB && llvm::all_of(b.slice(j), isOne))
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index aadbc14004da1..19488c5cbeda0 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -458,6 +458,18 @@ func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
return %0, %1, %2, %3 : vector<15x2xf32>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
}
+// CHECK-LABEL: @shape_cast_0d
+func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
+
+ // CHECK: vector.shape_cast %{{.*}} : vector<1x1x1x1xf32> to vector<f32>
+ %0 = vector.shape_cast %arg0 : vector<1x1x1x1xf32> to vector<f32>
+
+ // CHECK: vector.shape_cast %{{.*}} : vector<f32> to vector<1x1x1x1xf32>
+ %1 = vector.shape_cast %0 : vector<f32> to vector<1x1x1x1xf32>
+
+ return %1 : vector<1x1x1x1xf32>
+}
+
// CHECK-LABEL: @bitcast
func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x1xi32>,
More information about the Mlir-commits
mailing list