[Mlir-commits] [mlir] 5ec59f0 - [mlir][AVX512] Fix result type of vp2intersect
Matthias Springer
llvmlistbot at llvm.org
Sat Jan 30 19:06:29 PST 2021
Author: Matthias Springer
Date: 2021-01-31T12:03:46+09:00
New Revision: 5ec59f021ceb09cff32c0fb4c24310362d08ea63
URL: https://github.com/llvm/llvm-project/commit/5ec59f021ceb09cff32c0fb4c24310362d08ea63
DIFF: https://github.com/llvm/llvm-project/commit/5ec59f021ceb09cff32c0fb4c24310362d08ea63.diff
LOG: [mlir][AVX512] Fix result type of vp2intersect
The result values of vp2intersect are vectors of bits, i.e.,
vector<8xi1> or vector<16xi1> (instead of i8 or i16, respectively).
Differential Revision: https://reviews.llvm.org/D95678
Added:
Modified:
mlir/include/mlir/Dialect/AVX512/AVX512.td
mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
mlir/test/Dialect/AVX512/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td
index 95e4c5e886fd..7140b013967a 100644
--- a/mlir/include/mlir/Dialect/AVX512/AVX512.td
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td
@@ -100,14 +100,14 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
AllTypesMatch<["a", "b"]>,
TypesMatchWith<"k1 has the same number of bits as elements in a",
"a", "k1",
- "IntegerType::get($_self.getContext(), "
- "($_self.cast<VectorType>().getShape()[0]))">,
+ "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+ "IntegerType::get($_self.getContext(), 1))">,
TypesMatchWith<"k2 has the same number of bits as elements in b",
// Should use `b` instead of `a`, but that would require
// adding `type($b)` to assemblyFormat.
"a", "k2",
- "IntegerType::get($_self.getContext(), "
- "($_self.cast<VectorType>().getShape()[0]))">]> {
+ "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+ "IntegerType::get($_self.getContext(), 1))">]> {
let summary = "Vp2Intersect op";
let description = [{
The `vp2intersect` op is an AVX512 specific op that can lower to the proper
@@ -126,8 +126,8 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
VectorOfLengthAndType<[16, 8], [I32, I64]>:$b
);
- let results = (outs AnyTypeOf<[I16, I8]>:$k1,
- AnyTypeOf<[I16, I8]>:$k2
+ let results = (outs VectorOfLengthAndType<[16, 8], [I1, I1]>:$k1,
+ VectorOfLengthAndType<[16, 8], [I1, I1]>:$k2
);
let assemblyFormat =
"$a `,` $b attr-dict `:` type($a)";
diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
index 7862d5d46073..b6f7ad8e196e 100644
--- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
@@ -18,11 +18,11 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1
}
func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
- -> (i16, i16, i8, i8)
+ -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
{
// CHECK: llvm_avx512.vp2intersect.d.512
%0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: llvm_avx512.vp2intersect.q.512
%2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
- return %0, %1, %2, %3 : i16, i16, i8, i8
+ return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
}
diff --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir
index 6423bbfdbe54..865f9185b821 100644
--- a/mlir/test/Dialect/AVX512/roundtrip.mlir
+++ b/mlir/test/Dialect/AVX512/roundtrip.mlir
@@ -21,11 +21,11 @@ func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16,
}
func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
- -> (i16, i16, i8, i8)
+ -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
{
// CHECK: avx512.vp2intersect {{.*}} : vector<16xi32>
%0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: avx512.vp2intersect {{.*}} : vector<8xi64>
%2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
- return %0, %1, %2, %3 : i16, i16, i8, i8
+ return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
}
More information about the Mlir-commits
mailing list