[Mlir-commits] [mlir] a0c930d - [mlir][Vector] Support 0-D vectors in `CmpIOp`

Nicolas Vasilache llvmlistbot at llvm.org
Sun Dec 12 05:29:17 PST 2021


Author: Michal Terepeta
Date: 2021-12-12T13:28:26Z
New Revision: a0c930d312848b43f94bc13060a687c3bbe96b78

URL: https://github.com/llvm/llvm-project/commit/a0c930d312848b43f94bc13060a687c3bbe96b78
DIFF: https://github.com/llvm/llvm-project/commit/a0c930d312848b43f94bc13060a687c3bbe96b78.diff

LOG: [mlir][Vector] Support 0-D vectors in `CmpIOp`

Following the example of `VectorOfAnyRankOf`, I've made a few changes in the
`.td` files to help add support for the 0-D case gradually.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D115220

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
    mlir/test/Dialect/Arithmetic/ops.mlir
    mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 57459b518bf5c..0aff766414b9e 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -116,6 +116,13 @@ class Arith_CompareOp<string mnemonic, list<OpTrait> traits = []> :
   let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
 }
 
+// Just like `Arith_CompareOp` but also admits 0-D vectors. Introduced
+// temporarily to allow gradual transition to 0-D vectors.
+class Arith_CompareOpOfAnyRank<string mnemonic, list<OpTrait> traits = []> :
+    Arith_CompareOp<mnemonic, traits> {
+  let results = (outs BoolLikeOfAnyRank:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -284,7 +291,7 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> {
 def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> {
   let summary = "unsigned ceil integer division operation";
   let description = [{
-    Unsigned integer division. Rounds towards positive infinity. Treats the 
+    Unsigned integer division. Rounds towards positive infinity. Treats the
     leading bit as the most significant, i.e. for `i16` given two's complement
     representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
 
@@ -990,7 +997,7 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
+def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> {
   let summary = "integer comparison operation";
   let description = [{
     The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1057,8 +1064,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
   }];
 
   let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
-                       SignlessIntegerLike:$lhs,
-                       SignlessIntegerLike:$rhs);
+                       SignlessIntegerLikeOfAnyRank:$lhs,
+                       SignlessIntegerLikeOfAnyRank:$rhs);
 
   let builders = [
     OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 65f8ed1028237..22f4045a0dc89 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -213,6 +213,7 @@ def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">,
                             CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
 def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
 
 // Whether a type is a TensorType.
@@ -603,7 +604,9 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
 class VectorOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
                       "::mlir::VectorType">;
+
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
 class VectorOfAnyRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
                       "::mlir::VectorType">;
@@ -835,6 +838,14 @@ def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
                                   TensorOf<[I1]>.predicate]>,
     "bool-like">;
 
+// Temporary constraint to allow gradual transition to supporting 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
+def BoolLikeOfAnyRank : TypeConstraint<Or<[
+        I1.predicate,
+        VectorOfAnyRankOf<[I1]>.predicate,
+        TensorOf<[I1]>.predicate]>,
+    "bool-like">;
+
 // Type constraint for signless-integer-like types: signless integers, indices,
 // vectors of signless integers or indices, tensors of signless integers.
 def SignlessIntegerLike : TypeConstraint<Or<[
@@ -843,6 +854,14 @@ def SignlessIntegerLike : TypeConstraint<Or<[
         TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
     "signless-integer-like">;
 
+// Temporary constraint to allow gradual transition to supporting 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
+def SignlessIntegerLikeOfAnyRank : TypeConstraint<Or<[
+        AnySignlessIntegerOrIndex.predicate,
+        VectorOfAnyRankOf<[AnySignlessIntegerOrIndex]>.predicate,
+        TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
+    "signless-integer-like">;
+
 // Type constraint for float-like types: floats, vectors or tensors thereof.
 def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
         VectorOf<[AnyFloat]>.predicate, TensorOf<[AnyFloat]>.predicate]>,

diff  --git a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
index c21db125a318d..ecc173dc44f0b 100644
--- a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
@@ -352,6 +352,17 @@ func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @cmpi_0dvector(
+func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
+  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[ARG0]], %[[ARG1]] : vector<1xi32>
+  %0 = arith.cmpi ult, %arg0, %arg1 : vector<i32>
+  std.return
+}
+
+// -----
+
 // CHECK-LABEL: func @cmpi_2dvector(
 func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
   // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast

diff  --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
index 54a1014eb6e2a..6907d7b6fe535 100644
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -631,6 +631,12 @@ func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
   return %0 : vector<8xi1>
 }
 
+// CHECK-LABEL: test_cmpi_vector_0d
+func @test_cmpi_vector_0d(%arg0 : vector<i64>, %arg1 : vector<i64>) -> vector<i1> {
+  %0 = arith.cmpi ult, %arg0, %arg1 : vector<i64>
+  return %0 : vector<i1>
+}
+
 // CHECK-LABEL: test_cmpf
 func @test_cmpf(%arg0 : f64, %arg1 : f64) -> i1 {
   %0 = arith.cmpf oeq, %arg0, %arg1 : f64

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index a0d4c3d82974c..67a4257fa35dc 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -67,7 +67,6 @@ func @bitcast_0d() {
   return
 }
 
-
 func @constant_mask_0d() {
   %1 = vector.constant_mask [0] : vector<i1>
   // CHECK: ( 0 )
@@ -78,6 +77,22 @@ func @constant_mask_0d() {
   return
 }
 
+func @arith_cmpi_0d(%smaller : vector<i32>, %bigger : vector<i32>) {
+  %0 = arith.cmpi ult, %smaller, %bigger : vector<i32>
+  // CHECK: ( 1 )
+  vector.print %0: vector<i1>
+
+  %1 = arith.cmpi ugt, %smaller, %bigger : vector<i32>
+  // CHECK: ( 0 )
+  vector.print %1: vector<i1>
+
+  %2 = arith.cmpi eq, %smaller, %bigger : vector<i32>
+  // CHECK: ( 0 )
+  vector.print %2: vector<i1>
+
+  return
+}
+
 func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -96,5 +111,9 @@ func @entry() {
   call  @bitcast_0d() : () -> ()
   call  @constant_mask_0d() : () -> ()
 
+  %smaller = arith.constant dense<42> : vector<i32>
+  %bigger = arith.constant dense<4242> : vector<i32>
+  call  @arith_cmpi_0d(%smaller, %bigger) : (vector<i32>, vector<i32>) -> ()
+
   return
 }


        


More information about the Mlir-commits mailing list