[llvm-branch-commits] [mlir] 8955e9f - [mlir][linalg] Fix bug in elementwise vectorization
Thomas Raoux via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Dec 14 10:51:49 PST 2020
Author: Thomas Raoux
Date: 2020-12-14T10:44:36-08:00
New Revision: 8955e9f6b75d436f92235531f003540401cd4b30
URL: https://github.com/llvm/llvm-project/commit/8955e9f6b75d436f92235531f003540401cd4b30
DIFF: https://github.com/llvm/llvm-project/commit/8955e9f6b75d436f92235531f003540401cd4b30.diff
LOG: [mlir][linalg] Fix bug in elementwise vectorization
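Fix a bug that caused the wrong vector size to be picked as the broadcast target
when the source vectors have different ranks. The loop searching for the largest
vector shape never updated `maxSize`, so the last vector exceeding the initial
size won rather than the largest one.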
Differential Revision: https://reviews.llvm.org/D93118
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a28b90b1d95c..2df1a9469eab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -216,6 +216,7 @@ class GenericVectorizer {
if (!vecType)
continue;
if (maxSize < vecType.getNumElements()) {
+ maxSize = vecType.getNumElements();
largestShape.assign(vecType.getShape().begin(),
vecType.getShape().end());
}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 1c3533275e49..6019dde49983 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -169,7 +169,7 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
%11 = mulf %arg5, %8 : f32
%12 = rsqrt %arg5 : f32
%13 = select %7, %arg5, %arg6 : f32
- %14 = subf %arg5, %arg6 : f32
+ %14 = subf %arg5, %arg4 : f32
%15 = tanh %arg5 : f32
linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
f32, f32, f32, f32, f32, f32, f32, f32
@@ -196,7 +196,8 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
-// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
// CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
// CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
More information about the llvm-branch-commits mailing list