[Mlir-commits] [mlir] a0c930d - [mlir][Vector] Support 0-D vectors in `CmpIOp`
Nicolas Vasilache
llvmlistbot at llvm.org
Sun Dec 12 05:29:17 PST 2021
Author: Michal Terepeta
Date: 2021-12-12T13:28:26Z
New Revision: a0c930d312848b43f94bc13060a687c3bbe96b78
URL: https://github.com/llvm/llvm-project/commit/a0c930d312848b43f94bc13060a687c3bbe96b78
DIFF: https://github.com/llvm/llvm-project/commit/a0c930d312848b43f94bc13060a687c3bbe96b78.diff
LOG: [mlir][Vector] Support 0-D vectors in `CmpIOp`
Following the example of `VectorOfAnyRankOf`, I've made a few changes in the
`.td` files to help add support for the 0-D case gradually.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D115220
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/IR/OpBase.td
mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
mlir/test/Dialect/Arithmetic/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 57459b518bf5c..0aff766414b9e 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -116,6 +116,13 @@ class Arith_CompareOp<string mnemonic, list<OpTrait> traits = []> :
let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
}
+// Just like `Arith_CompareOp` but also admits 0-D vectors. Introduced
+// temporarily to allow gradual transition to 0-D vectors.
+class Arith_CompareOpOfAnyRank<string mnemonic, list<OpTrait> traits = []> :
+ Arith_CompareOp<mnemonic, traits> {
+ let results = (outs BoolLikeOfAnyRank:$result);
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -284,7 +291,7 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> {
def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> {
let summary = "unsigned ceil integer division operation";
let description = [{
- Unsigned integer division. Rounds towards positive infinity. Treats the
+ Unsigned integer division. Rounds towards positive infinity. Treats the
leading bit as the most significant, i.e. for `i16` given two's complement
representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
@@ -990,7 +997,7 @@ def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
// CmpIOp
//===----------------------------------------------------------------------===//
-def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
+def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> {
let summary = "integer comparison operation";
let description = [{
The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1057,8 +1064,8 @@ def Arith_CmpIOp : Arith_CompareOp<"cmpi"> {
}];
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
- SignlessIntegerLike:$lhs,
- SignlessIntegerLike:$rhs);
+ SignlessIntegerLikeOfAnyRank:$lhs,
+ SignlessIntegerLikeOfAnyRank:$rhs);
let builders = [
OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 65f8ed1028237..22f4045a0dc89 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -213,6 +213,7 @@ def IsVectorTypePred : And<[CPred<"$_self.isa<::mlir::VectorType>()">,
CPred<"$_self.cast<::mlir::VectorType>().getRank() > 0">]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
// Whether a type is a TensorType.
@@ -603,7 +604,9 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
class VectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
"::mlir::VectorType">;
+
// Temporary vector type clone that allows gradual transition to 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
class VectorOfAnyRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
"::mlir::VectorType">;
@@ -835,6 +838,14 @@ def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
TensorOf<[I1]>.predicate]>,
"bool-like">;
+// Temporary constraint to allow gradual transition to supporting 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
+def BoolLikeOfAnyRank : TypeConstraint<Or<[
+ I1.predicate,
+ VectorOfAnyRankOf<[I1]>.predicate,
+ TensorOf<[I1]>.predicate]>,
+ "bool-like">;
+
// Type constraint for signless-integer-like types: signless integers, indices,
// vectors of signless integers or indices, tensors of signless integers.
def SignlessIntegerLike : TypeConstraint<Or<[
@@ -843,6 +854,14 @@ def SignlessIntegerLike : TypeConstraint<Or<[
TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
"signless-integer-like">;
+// Temporary constraint to allow gradual transition to supporting 0-D vectors.
+// TODO: Remove this when all ops support 0-D vectors.
+def SignlessIntegerLikeOfAnyRank : TypeConstraint<Or<[
+ AnySignlessIntegerOrIndex.predicate,
+ VectorOfAnyRankOf<[AnySignlessIntegerOrIndex]>.predicate,
+ TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
+ "signless-integer-like">;
+
// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
VectorOf<[AnyFloat]>.predicate, TensorOf<[AnyFloat]>.predicate]>,
diff --git a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
index c21db125a318d..ecc173dc44f0b 100644
--- a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
@@ -352,6 +352,17 @@ func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
// -----
+// CHECK-LABEL: func @cmpi_0dvector(
+func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
+ // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
+ // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
+ // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[ARG0]], %[[ARG1]] : vector<1xi32>
+ %0 = arith.cmpi ult, %arg0, %arg1 : vector<i32>
+ std.return
+}
+
+// -----
+
// CHECK-LABEL: func @cmpi_2dvector(
func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
index 54a1014eb6e2a..6907d7b6fe535 100644
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -631,6 +631,12 @@ func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8
return %0 : vector<8xi1>
}
+// CHECK-LABEL: test_cmpi_vector_0d
+func @test_cmpi_vector_0d(%arg0 : vector<i64>, %arg1 : vector<i64>) -> vector<i1> {
+ %0 = arith.cmpi ult, %arg0, %arg1 : vector<i64>
+ return %0 : vector<i1>
+}
+
// CHECK-LABEL: test_cmpf
func @test_cmpf(%arg0 : f64, %arg1 : f64) -> i1 {
%0 = arith.cmpf oeq, %arg0, %arg1 : f64
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index a0d4c3d82974c..67a4257fa35dc 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -67,7 +67,6 @@ func @bitcast_0d() {
return
}
-
func @constant_mask_0d() {
%1 = vector.constant_mask [0] : vector<i1>
// CHECK: ( 0 )
@@ -78,6 +77,22 @@ func @constant_mask_0d() {
return
}
+func @arith_cmpi_0d(%smaller : vector<i32>, %bigger : vector<i32>) {
+ %0 = arith.cmpi ult, %smaller, %bigger : vector<i32>
+ // CHECK: ( 1 )
+ vector.print %0: vector<i1>
+
+ %1 = arith.cmpi ugt, %smaller, %bigger : vector<i32>
+ // CHECK: ( 0 )
+ vector.print %1: vector<i1>
+
+ %2 = arith.cmpi eq, %smaller, %bigger : vector<i32>
+ // CHECK: ( 0 )
+ vector.print %2: vector<i1>
+
+ return
+}
+
func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@@ -96,5 +111,9 @@ func @entry() {
call @bitcast_0d() : () -> ()
call @constant_mask_0d() : () -> ()
+ %smaller = arith.constant dense<42> : vector<i32>
+ %bigger = arith.constant dense<4242> : vector<i32>
+ call @arith_cmpi_0d(%smaller, %bigger) : (vector<i32>, vector<i32>) -> ()
+
return
}
More information about the Mlir-commits mailing list