[Mlir-commits] [mlir] 6bb4fc9 - Fix a corner case in vector.shape_cast when the trailing dimensions are of size 1.
Wen-Heng Chung
llvmlistbot at llvm.org
Mon Jun 22 20:00:56 PDT 2020
Author: Wen-Heng (Jack) Chung
Date: 2020-06-22T22:00:45-05:00
New Revision: 6bb4fc93c2fd7f63c7ed430928d1b85bfd4b3d79
URL: https://github.com/llvm/llvm-project/commit/6bb4fc93c2fd7f63c7ed430928d1b85bfd4b3d79
DIFF: https://github.com/llvm/llvm-project/commit/6bb4fc93c2fd7f63c7ed430928d1b85bfd4b3d79.diff
LOG: Fix a corner case in vector.shape_cast when the trailing dimensions are of size 1.
Differential Revision: https://reviews.llvm.org/D82304
Added:
Modified:
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 019f5fd94621..5d3a916d02ea 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1633,6 +1633,14 @@ static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
if (dimA != dimB)
break;
++i;
+
+ // Handle the case when trailing dimensions are of size 1.
+ // Include them into the contiguous sequence.
+ auto isOne = [](int64_t v) { return v == 1; };
+ if (i < rankA && llvm::all_of(a.slice(i), isOne))
+ i = rankA;
+ if (j < rankB && llvm::all_of(b.slice(j), isOne))
+ j = rankB;
}
return i == rankA && j == rankB;
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 02ee4dd3883b..4ea72864ea94 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -266,8 +266,10 @@ func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
// CHECK-LABEL: @shape_cast
func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
- %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>)
- -> (vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>) {
+ %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>,
+ %arg2 : vector<8x1xf32>,
+ %arg3 : vector<16x1x1xf32>)
+ -> (vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>) {
// CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<15x2xf32>
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
@@ -276,7 +278,16 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
%1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
tuple<vector<20x2xf32>, vector<12x2xf32>>
- return %0, %1 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x1xf32> to vector<8xf32>
+ %2 = vector.shape_cast %arg2 : vector<8x1xf32> to vector<8xf32>
+
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<16x1x1xf32> to vector<16xf32>
+ %3 = vector.shape_cast %arg3 : vector<16x1x1xf32> to vector<16xf32>
+
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<16x1x1xf32> to vector<16x1xf32>
+ %4 = vector.shape_cast %arg3 : vector<16x1x1xf32> to vector<16x1xf32>
+
+ return %0, %1, %2, %3, %4 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
}
// CHECK-LABEL: @vector_fma
More information about the Mlir-commits mailing list.