[Mlir-commits] [mlir] 875bbce - [mlir][Vector] Prevent AVX2 lowering for non-f32 transpose ops
Diego Caballero
llvmlistbot at llvm.org
Fri Feb 25 11:30:20 PST 2022
Author: Diego Caballero
Date: 2022-02-25T19:27:32Z
New Revision: 875bbce9f7206b97a74e656dd31df7b1c7dd897d
URL: https://github.com/llvm/llvm-project/commit/875bbce9f7206b97a74e656dd31df7b1c7dd897d
DIFF: https://github.com/llvm/llvm-project/commit/875bbce9f7206b97a74e656dd31df7b1c7dd897d.diff
LOG: [mlir][Vector] Prevent AVX2 lowering for non-f32 transpose ops
The AVX2 lowering for transpose operations is only applicable to f32 vector types.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D120427
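For illustration only, the guard this patch adds can be modeled outside MLIR as a small standalone predicate. This is a simplified sketch, not the pattern's actual code: `ElemType`, `VecType`, and `isAvx2TransposeCandidate` are hypothetical names rather than MLIR APIs. It also assumes that "two dimensions greater than one" means exactly two, which the diff itself does not show.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for MLIR's element and vector types (not MLIR APIs).
enum class ElemType { F32, F64, I32 };

struct VecType {
  ElemType elem;
  std::vector<int64_t> shape;
};

// Returns true if the type passes the two checks named in the updated comment.
bool isAvx2TransposeCandidate(const VecType &ty) {
  // Added by this patch: reject any element type other than f32 up front.
  if (ty.elem != ElemType::F32)
    return false;

  // Collect the indices of the dimensions greater than one, as the pattern does.
  std::vector<int64_t> gtOneDims;
  for (int64_t i = 0, e = static_cast<int64_t>(ty.shape.size()); i < e; ++i)
    if (ty.shape[i] > 1)
      gtOneDims.push_back(i);

  // Assumption: exactly two such dimensions are required.
  return gtOneDims.size() == 2;
}
```

Under this model, the `vector<4x8xi32>` type from the new test is rejected by the element-type check, while a 4x8 f32 type passes both checks.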
Added:
Modified:
mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
index 27272d12ac4dd..065848d003d9c 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -250,8 +250,11 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     auto loc = op.getLoc();
     // Check if the source vector type is supported. AVX2 patterns can only be
-    // applied if the vector type has two dimensions greater than one.
+    // applied to f32 vector types with two dimensions greater than one.
     VectorType srcType = op.getVectorType();
+    if (!srcType.getElementType().isF32())
+      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
+
     SmallVector<int64_t> srcGtOneDims;
     for (auto &en : llvm::enumerate(srcType.getShape()))
       if (en.value() > 1)
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 44a59a2299dad..651006b07d2b7 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -548,6 +548,15 @@ func @transpose210_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> {
 // -----
+func @do_not_lower_nonf32_to_avx2(%arg0: vector<4x8xi32>) -> vector<8x4xi32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x8xi32> to vector<8x4xi32>
+  return %0 : vector<8x4xi32>
+}
+
+// AVX2-NOT: vector.shuffle
+
+// -----
+
 // AVX2-LABEL: func @transpose021_8x1x8
 func @transpose021_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<8x8x1xf32> {
   %0 = vector.transpose %arg0, [0, 2, 1] : vector<8x1x8xf32> to vector<8x8x1xf32>