[Mlir-commits] [mlir] 5ec59f0 - [mlir][AVX512] Fix result type of vp2intersect

Matthias Springer llvmlistbot at llvm.org
Sat Jan 30 19:06:29 PST 2021


Author: Matthias Springer
Date: 2021-01-31T12:03:46+09:00
New Revision: 5ec59f021ceb09cff32c0fb4c24310362d08ea63

URL: https://github.com/llvm/llvm-project/commit/5ec59f021ceb09cff32c0fb4c24310362d08ea63
DIFF: https://github.com/llvm/llvm-project/commit/5ec59f021ceb09cff32c0fb4c24310362d08ea63.diff

LOG: [mlir][AVX512] Fix result type of vp2intersect

The result values of vp2intersect are vectors of bits, i.e.,
vector<8xi1> or vector<16xi1> (instead of i8 or i16).

Differential Revision: https://reviews.llvm.org/D95678

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/AVX512/AVX512.td
    mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
    mlir/test/Dialect/AVX512/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td
index 95e4c5e886fd..7140b013967a 100644
--- a/mlir/include/mlir/Dialect/AVX512/AVX512.td
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td
@@ -100,14 +100,14 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
   AllTypesMatch<["a", "b"]>,
   TypesMatchWith<"k1 has the same number of bits as elements in a",
                  "a", "k1",
-                 "IntegerType::get($_self.getContext(), "
-                 "($_self.cast<VectorType>().getShape()[0]))">,
+                 "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+                 "IntegerType::get($_self.getContext(), 1))">,
   TypesMatchWith<"k2 has the same number of bits as elements in b",
                  // Should use `b` instead of `a`, but that would require
                  // adding `type($b)` to assemblyFormat.
                  "a", "k2",
-                 "IntegerType::get($_self.getContext(), "
-                 "($_self.cast<VectorType>().getShape()[0]))">]> {
+                 "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+                 "IntegerType::get($_self.getContext(), 1))">]> {
   let summary = "Vp2Intersect op";
   let description = [{
     The `vp2intersect` op is an AVX512 specific op that can lower to the proper
@@ -126,8 +126,8 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
   let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
                    VectorOfLengthAndType<[16, 8], [I32, I64]>:$b
                    );
-  let results = (outs AnyTypeOf<[I16, I8]>:$k1,
-                 AnyTypeOf<[I16, I8]>:$k2
+  let results = (outs VectorOfLengthAndType<[16, 8], [I1, I1]>:$k1,
+                 VectorOfLengthAndType<[16, 8], [I1, I1]>:$k2
                  );
   let assemblyFormat =
     "$a `,` $b attr-dict `:` type($a)";

diff  --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
index 7862d5d46073..b6f7ad8e196e 100644
--- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
@@ -18,11 +18,11 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1
 }
 
 func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-  -> (i16, i16, i8, i8)
+  -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
 {
   // CHECK: llvm_avx512.vp2intersect.d.512
   %0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
   // CHECK: llvm_avx512.vp2intersect.q.512
   %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
-  return %0, %1, %2, %3 : i16, i16, i8, i8
+  return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
 }

diff  --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir
index 6423bbfdbe54..865f9185b821 100644
--- a/mlir/test/Dialect/AVX512/roundtrip.mlir
+++ b/mlir/test/Dialect/AVX512/roundtrip.mlir
@@ -21,11 +21,11 @@ func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16,
 }
 
 func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-  -> (i16, i16, i8, i8)
+  -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
 {
   // CHECK: avx512.vp2intersect {{.*}} : vector<16xi32>
   %0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
   // CHECK: avx512.vp2intersect {{.*}} : vector<8xi64>
   %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
-  return %0, %1, %2, %3 : i16, i16, i8, i8
+  return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
 }


        


More information about the Mlir-commits mailing list