[Mlir-commits] [mlir] b214ca8 - [mlir][vector] Rename vector type TD definitions (nfc) (#117150)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 26 06:59:43 PST 2024


Author: Andrzej Warzyński
Date: 2024-11-26T14:59:39Z
New Revision: b214ca82daeece1568268ebc0fbcc2eaa649425b

URL: https://github.com/llvm/llvm-project/commit/b214ca82daeece1568268ebc0fbcc2eaa649425b
DIFF: https://github.com/llvm/llvm-project/commit/b214ca82daeece1568268ebc0fbcc2eaa649425b.diff

LOG: [mlir][vector] Rename vector type TD definitions (nfc) (#117150)

Currently, the Vector dialect TD file includes the following "vector"
type definitions:

```mlir
def AnyVector : VectorOf<[AnyType]>;
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
def AnyFixedVector : FixedVectorOf<[AnyType]>;
def AnyScalableVector : ScalableVectorOf<[AnyType]>;
```

In short:

  * `AnyVector` _excludes_ 0-D vectors.
  * `AnyVectorOfAnyRank`, `AnyFixedVector`, and `AnyScalableVector`
    _include_ 0-D vectors.

The naming for "groups" that include 0-D vectors is inconsistent and can
be misleading, and `AnyVector` implies that 0-D vectors are included,
which is not the case.

This patch renames these definitions for clarity:

```mlir
def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>;
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;
```

Rationale:
* The updated names are more explicit about 0-D vector support.
* It becomes clearer that scalable vectors currently allow 0-D vectors -
  this might warrant a revisit.
* The renaming paves the way for adding a new group for "fixed-width
  vectors excluding 0-D vectors" (e.g., AnyFixedVector), which I plan to
  introduce in a follow-up patch.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
    mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
    mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/IR/CommonTypeConstraints.td
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 76d97f106dcb88..56fbe9cdc2d21d 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -964,7 +964,7 @@ def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
     (see [vector.transfer_read](../Vector/#vectortransfer_read-mlirvectortransferreadop)).
   }];
 
-  let results = (outs AnyVector:$result);
+  let results = (outs AnyVectorOfNonZeroRank:$result);
 
   let builders = [
     /// Builds an affine vector load op with the specified map and operands.
@@ -1031,7 +1031,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
     (see [vector.transfer_write](../Vector/#vectortransfer_write-mlirvectortransferwriteop)).
   }];
 
-  let arguments = (ins AnyVector:$value,
+  let arguments = (ins AnyVectorOfNonZeroRank:$value,
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$memref,
       Variadic<Index>:$indices,

diff  --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 9cc792093bf836..475b11f12c5f01 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -35,7 +35,7 @@ def ArmNeon_Dialect : Dialect {
 //===----------------------------------------------------------------------===//
 
 class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
-  [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
+  [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>,
   "a vector with length " # length,
   "::mlir::VectorType">;
 

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 9a058ae4fe7647..6fd992afbf0436 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -371,7 +371,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
-    Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
+    Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
@@ -444,7 +444,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
   }];
   let arguments = (ins SMETile:$valueToStore,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
-    Variadic<Index>:$indices, Optional<AnyVector>:$mask,
+    Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
   let extraClassDeclaration = [{
@@ -799,9 +799,9 @@ class OuterProductWideningBase<string mnemonic,
   ]> {
 
   let arguments = (ins
-    AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
-    Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
-    Optional<AnyVector>:$acc);
+    AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVectorOfNonZeroRank:$rhs,
+    Optional<AnyVectorOfNonZeroRank>:$lhsMask, Optional<AnyVectorOfNonZeroRank>:$rhsMask,
+    Optional<AnyVectorOfNonZeroRank>:$acc);
   let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
 
   let assemblyFormat = [{

diff  --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index d7e8b22fbd2d35..cdcf4d8752e874 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -100,11 +100,11 @@ class ScalableMaskedFOp<string mnemonic, string op_description,
     op_description # [{ on active lanes. Inactive lanes will keep the value of
     the first operand.}];
   let arguments = (ins
-          ScalableVectorOf<[I1]>:$mask,
-          ScalableVectorOf<[AnyFloat]>:$src1,
-          ScalableVectorOf<[AnyFloat]>:$src2
+          ScalableVectorOfAnyRank<[I1]>:$mask,
+          ScalableVectorOfAnyRank<[AnyFloat]>:$src1,
+          ScalableVectorOfAnyRank<[AnyFloat]>:$src2
   );
-  let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
+  let results = (outs ScalableVectorOfAnyRank<[AnyFloat]>:$res);
   let assemblyFormat =
     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
 }
@@ -123,11 +123,11 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
     op_description # [{ on active lanes. Inactive lanes will keep the value of
     the first operand.}];
   let arguments = (ins
-          ScalableVectorOf<[I1]>:$mask,
-          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
-          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+          ScalableVectorOfAnyRank<[I1]>:$mask,
+          ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src2
   );
-  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
+  let results = (outs ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$res);
   let assemblyFormat =
     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
 }
@@ -511,55 +511,55 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def UdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udot">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedAddIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"add">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedAddFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fadd">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedMulIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"mul">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedMulFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fmul">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedSubIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sub">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedSubFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fsub">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedSDivIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedUDivIIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udiv">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ScalableMaskedDivFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
 def ConvertFromSvboolIntrOp :
   ArmSVE_IntrOp<"convert.from.svbool",
@@ -581,8 +581,8 @@ def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
     /*overloadedOperands=*/[0],
     /*overloadedResults=*/[],
     /*numResults=*/2>,
-    Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
-                   Arg<AnyScalableVector, "v2">:$v2)>;
+    Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+                   Arg<AnyScalableVectorOfAnyRank, "v2">:$v2)>;
 
 // Note: This multi-vector intrinsic requires SME2.
 def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
@@ -590,10 +590,10 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
     /*overloadedOperands=*/[0],
     /*overloadedResults=*/[],
     /*numResults=*/4>,
-    Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
-                   Arg<AnyScalableVector, "v2">:$v2,
-                   Arg<AnyScalableVector, "v3">:$v3,
-                   Arg<AnyScalableVector, "v3">:$v4)>;
+    Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+                   Arg<AnyScalableVectorOfAnyRank, "v2">:$v2,
+                   Arg<AnyScalableVectorOfAnyRank, "v3">:$v3,
+                   Arg<AnyScalableVectorOfAnyRank, "v3">:$v4)>;
 
 // Note: This intrinsic requires SME or SVE2.1.
 def PselIntrOp : ArmSVE_IntrOp<"psel",

diff  --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 1f52f6b91617c1..b39f2ee594cd4a 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -255,7 +255,7 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", [
   let arguments = (ins Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$srcMemref,
                            Variadic<Index>:$indices, BoolAttr:$transpose,
                            I32Attr:$numTiles);
-  let results = (outs AnyVector:$res);
+  let results = (outs AnyVectorOfNonZeroRank:$res);
   let assemblyFormat = [{
     $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
   }];
@@ -301,13 +301,13 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
         (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
     ```
   }];
-  let arguments = (ins AnyVector:$matrixA,
-                       AnyVector:$matrixB,
-                       AnyVector:$matrixC,
+  let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
+                       AnyVectorOfNonZeroRank:$matrixB,
+                       AnyVectorOfNonZeroRank:$matrixC,
                        I64ArrayAttr:$mmaShape,
                        OptionalAttr<UnitAttr>:$tf32Enabled);
 
-  let results = (outs AnyVector:$res);
+  let results = (outs AnyVectorOfNonZeroRank:$res);
 
   let builders = [
     OpBuilder<(ins "Value":$matrixA,
@@ -357,16 +357,16 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
   ```
   }];
 
-  let arguments = (ins AnyVector:$matrixA,
-                       AnyVector:$matrixB,
-                       AnyVector:$matrixC,
+  let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
+                       AnyVectorOfNonZeroRank:$matrixB,
+                       AnyVectorOfNonZeroRank:$matrixC,
                        NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
                        I64ArrayAttr:$mmaShape,
                        DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
                        OptionalAttr<UnitAttr>:$tf32Enabled
                        );
 
-  let results = (outs AnyVector:$res);
+  let results = (outs AnyVectorOfNonZeroRank:$res);
 
   let builders = [
     OpBuilder<(ins "Value":$matrixA,
@@ -825,10 +825,10 @@ def NVGPU_RcpOp : NVGPU_Op<"rcp", [Pure,
 
     The input and output must be of the same vector type and shape.
   }];
-  let arguments = (ins VectorOf<[F32]>:$in,
+  let arguments = (ins VectorOfNonZeroRankOf<[F32]>:$in,
                        DefaultValuedAttr<RcpRoundingModeAttr, "RcpRoundingMode::APPROX">:$rounding,
                        UnitAttr:$ftz);
-  let results = (outs VectorOf<[F32]>:$out);
+  let results = (outs VectorOfNonZeroRankOf<[F32]>:$out);
   let assemblyFormat = [{
     $in `{` `rounding` `=` $rounding (`,` `ftz` $ftz^)? `}` 
     attr-dict `:` type($out)

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index a4b43d656fe43e..a6d3163d4446fa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -166,7 +166,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
 
 class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
      AnyTypeOf<types>.predicate,
-     VectorOf<types>.predicate,
+     VectorOfNonZeroRankOf<types>.predicate,
      TosaTensorOf<types>.predicate]>,
      description>;
 

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 41d7ce6610085c..88c1b94412241e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -40,7 +40,7 @@ def Vector_ContractionOp :
       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
-    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
+    Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyVectorOfNonZeroRank:$rhs, AnyType:$acc,
                ArrayAttr:$indexing_maps,
                Vector_IteratorTypeArrayAttr:$iterator_types,
                DefaultValuedAttr<Vector_CombiningKindAttr,
@@ -285,7 +285,7 @@ def Vector_MultiDimReductionOp :
      DeclareOpInterfaceMethods<VectorUnrollOpInterface,
                                ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
-                   AnyVector:$source,
+                   AnyVectorOfNonZeroRank:$source,
                    AnyType:$acc,
                    DenseI64ArrayAttr:$reduction_dims)>,
     Results<(outs AnyType:$dest)> {
@@ -417,16 +417,18 @@ def Vector_BroadcastOp :
   let hasVerifier = 1;
 }
 
-def Vector_ShuffleOp :
-  Vector_Op<"shuffle", [Pure,
-     PredOpTrait<"first operand v1 and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>,
-     PredOpTrait<"second operand v2 and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 1>>,
-     InferTypeOpAdaptor]>,
-     Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
-                    DenseI64ArrayAttr:$mask)>,
-     Results<(outs AnyVector:$vector)> {
+def Vector_ShuffleOp
+    : Vector_Op<
+          "shuffle",
+          [Pure,
+           PredOpTrait<"first operand v1 and result have same element type",
+                       TCresVTEtIsSameAsOpBase<0, 0>>,
+           PredOpTrait<"second operand v2 and result have same element type",
+                       TCresVTEtIsSameAsOpBase<0, 1>>,
+           InferTypeOpAdaptor]>,
+      Arguments<(ins AnyFixedVectorOfAnyRank:$v1, AnyFixedVectorOfAnyRank:$v2,
+          DenseI64ArrayAttr:$mask)>,
+      Results<(outs AnyVectorOfNonZeroRank:$vector)> {
   let summary = "shuffle operation";
   let description = [{
     The shuffle operation constructs a permutation (or duplication) of elements
@@ -531,7 +533,7 @@ def Vector_InterleaveOp :
   }];
 
   let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
-  let results = (outs AnyVector:$result);
+  let results = (outs AnyVectorOfNonZeroRank:$result);
 
   let assemblyFormat = [{
     $lhs `,` $rhs  attr-dict `:` type($lhs) `->` type($result)
@@ -610,8 +612,8 @@ def Vector_DeinterleaveOp :
         ```
       }];
 
-      let arguments = (ins AnyVector:$source);
-      let results = (outs AnyVector:$res1, AnyVector:$res2);
+      let arguments = (ins AnyVectorOfNonZeroRank:$source);
+      let results = (outs AnyVectorOfNonZeroRank:$res1, AnyVectorOfNonZeroRank:$res2);
 
       let assemblyFormat = [{
         $source attr-dict `:` type($source) `->` type($res1)
@@ -1048,9 +1050,9 @@ def Vector_InsertStridedSliceOp :
     PredOpTrait<"operand #0 and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
     AllTypesMatch<["dest", "res"]>]>,
-    Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
+    Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
                I64ArrayAttr:$strides)>,
-    Results<(outs AnyVector:$res)> {
+    Results<(outs AnyVectorOfNonZeroRank:$res)> {
   let summary = "strided_slice operation";
   let description = [{
     Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
@@ -1107,10 +1109,10 @@ def Vector_OuterProductOp :
     PredOpTrait<"rhs operand and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 1>>,
     DeclareOpInterfaceMethods<MaskableOpInterface>]>,
-    Arguments<(ins AnyVector:$lhs, AnyType:$rhs,
-               Optional<AnyVector>:$acc,
+    Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyType:$rhs,
+               Optional<AnyVectorOfNonZeroRank>:$acc,
                DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,
-    Results<(outs AnyVector)> {
+    Results<(outs AnyVectorOfNonZeroRank)> {
   let summary = "vector outerproduct with optional fused add";
   let description = [{
     Takes 2 1-D vectors and returns the 2-D vector containing the outer-product,
@@ -1190,9 +1192,9 @@ def Vector_ExtractStridedSliceOp :
   Vector_Op<"extract_strided_slice", [Pure,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
+    Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets,
                I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
-    Results<(outs AnyVector)> {
+    Results<(outs AnyVectorOfNonZeroRank)> {
   let summary = "extract_strided_slice operation";
   let description = [{
     Takes an n-D vector, k-D `offsets` integer array attribute, a k-sized
@@ -1254,7 +1256,7 @@ def Vector_TransferReadOp :
                    Variadic<Index>:$indices,
                    AffineMapAttr:$permutation_map,
                    AnyType:$padding,
-                   Optional<VectorOf<[I1]>>:$mask,
+                   Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
                    BoolArrayAttr:$in_bounds)>,
     Results<(outs AnyVectorOfAnyRank:$vector)> {
 
@@ -1502,7 +1504,7 @@ def Vector_TransferWriteOp :
                    AnyShaped:$source,
                    Variadic<Index>:$indices,
                    AffineMapAttr:$permutation_map,
-                   Optional<VectorOf<[I1]>>:$mask,
+                   Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
                    BoolArrayAttr:$in_bounds)>,
     Results<(outs Optional<AnyRankedTensor>:$result)> {
 
@@ -1825,9 +1827,9 @@ def Vector_MaskedLoadOp :
   Vector_Op<"maskedload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$pass_thru)>,
-    Results<(outs AnyVector:$result)> {
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$pass_thru)>,
+    Results<(outs AnyVectorOfNonZeroRank:$result)> {
 
   let summary = "loads elements from memory into a vector as defined by a mask vector";
 
@@ -1888,8 +1890,8 @@ def Vector_MaskedStoreOp :
   Vector_Op<"maskedstore">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$valueToStore)> {
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$valueToStore)> {
 
   let summary = "stores elements from a vector into memory as defined by a mask vector";
 
@@ -1951,10 +1953,10 @@ def Vector_GatherOp :
   ]>,
     Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[AnyInteger, Index]>:$index_vec,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$pass_thru)>,
-    Results<(outs AnyVector:$result)> {
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$pass_thru)>,
+    Results<(outs AnyVectorOfNonZeroRank:$result)> {
 
   let summary = [{
     gathers elements from memory or ranked tensor into a vector as defined by an
@@ -2082,9 +2084,9 @@ def Vector_ExpandLoadOp :
   Vector_Op<"expandload">,
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$pass_thru)>,
-    Results<(outs AnyVector:$result)> {
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$pass_thru)>,
+    Results<(outs AnyVectorOfNonZeroRank:$result)> {
 
   let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
 
@@ -2149,8 +2151,8 @@ def Vector_CompressStoreOp :
   Vector_Op<"compressstore">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
-               VectorOf<[I1]>:$mask,
-               AnyVector:$valueToStore)> {
+               VectorOfNonZeroRankOf<[I1]>:$mask,
+               AnyVectorOfNonZeroRank:$valueToStore)> {
 
   let summary = "writes elements selectively from a vector as defined by a mask";
 
@@ -2508,7 +2510,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
   }];
 
   // TODO: Support multiple passthru values.
-  let arguments = (ins VectorOf<[I1]>:$mask,
+  let arguments = (ins VectorOfNonZeroRankOf<[I1]>:$mask,
                    Optional<AnyType>:$passthru);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$maskRegion);
@@ -2891,11 +2893,11 @@ def Vector_ScanOp :
     AllTypesMatch<["source", "dest"]>,
     AllTypesMatch<["initial_value", "accumulated_value"]> ]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
-                   AnyVector:$source,
+                   AnyVectorOfNonZeroRank:$source,
                    AnyVectorOfAnyRank:$initial_value,
                    I64Attr:$reduction_dim,
                    BoolAttr:$inclusive)>,
-    Results<(outs AnyVector:$dest,
+    Results<(outs AnyVectorOfNonZeroRank:$dest,
                   AnyVectorOfAnyRank:$accumulated_value)> {
   let summary = "Scan operation";
   let description = [{

diff  --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 48e4c24f838652..fc4383d08422cb 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -22,15 +22,15 @@ include "mlir/IR/DialectBase.td"
 
 // Whether a type is a VectorType.
 // Explicitly disallow 0-D vectors for now until we have good enough coverage.
-def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
-                            CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
+def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+                                         CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
 def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
 
 // Whether a type is a fixed-length VectorType.
-def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                                   !::llvm::cast<VectorType>($_self).isScalable()}]>;
 
 // Whether a type is a scalable VectorType.
@@ -53,7 +53,7 @@ def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
 
 // Whether a type is a VectorType and all dimensions are scalable.
 def IsVectorTypeWithAllDimsScalablePred : And<[
-  IsVectorTypePred,
+  IsVectorOfNonZeroRankTypePred,
   CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>
 ]>;
 
@@ -428,8 +428,8 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
 
 // Vector types.
 
-class VectorOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
+class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
                       "::mlir::VectorType">;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
@@ -438,11 +438,11 @@ class VectorOfAnyRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
                       "::mlir::VectorType">;
 
-class FixedVectorOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
+class FixedVectorOfAnyRank<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
           "fixed-length vector", "::mlir::VectorType">;
 
-class ScalableVectorOf<list<Type> allowedTypes> :
+class ScalableVectorOfAnyRank<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
           "scalable vector", "::mlir::VectorType">;
 
@@ -458,7 +458,7 @@ class VectorWithTrailingDimScalableOf<list<Type> allowedTypes> :
 // Whether the number of elements of a vector is from the given
 // `allowedRanks` list
 class IsVectorOfRankPred<list<int> allowedRanks> :
-  And<[IsVectorTypePred,
+  And<[IsVectorOfNonZeroRankTypePred,
        Or<!foreach(allowedlength, allowedRanks,
                    CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                            == }]
@@ -467,7 +467,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
 // Whether the number of elements of a fixed-length vector is from the given
 // `allowedRanks` list
 class IsFixedVectorOfRankPred<list<int> allowedRanks> :
-  And<[IsFixedVectorTypePred,
+  And<[IsFixedVectorOfAnyRankTypePred,
        Or<!foreach(allowedlength, allowedRanks,
                    CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
                            == }]
@@ -501,22 +501,22 @@ class ScalableVectorOfRank<list<int> allowedRanks> : Type<
 // is from the given `allowedTypes` list
 class VectorOfRankAndType<list<int> allowedRanks,
                           list<Type> allowedTypes> : AllOfType<
-  [VectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
-  VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
+  [VectorOfNonZeroRankOf<allowedTypes>, VectorOfRank<allowedRanks>],
+   VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
   "::mlir::VectorType">;
 
 // Fixed-width vector where the rank is from the given `allowedRanks` list and
 // the type is from the given `allowedTypes` list
 class FixedVectorOfRankAndType<list<int> allowedRanks,
                           list<Type> allowedTypes> : AllOfType<
-  [FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
-  FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
+  [FixedVectorOfAnyRank<allowedTypes>, VectorOfRank<allowedRanks>],
+  FixedVectorOfAnyRank<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
   "::mlir::VectorType">;
 
 // Whether the number of elements of a vector is from the given
 // `allowedLengths` list
 class IsVectorOfLengthPred<list<int> allowedLengths> :
-  And<[IsVectorTypePred,
+  And<[IsVectorOfNonZeroRankTypePred,
        Or<!foreach(allowedlength, allowedLengths,
                    CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                            == }]
@@ -525,7 +525,7 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
 // Whether the number of elements of a fixed-length vector is from the given
 // `allowedLengths` list
 class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
-  And<[IsFixedVectorTypePred,
+  And<[IsFixedVectorOfAnyRankTypePred,
        Or<!foreach(allowedlength, allowedLengths,
                    CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
                            == }]
@@ -604,16 +604,16 @@ class ScalableVectorOfLength<list<int> allowedLengths> : Type<
 // list
 class VectorOfLengthAndType<list<int> allowedLengths,
                             list<Type> allowedTypes> : AllOfType<
-  [VectorOf<allowedTypes>, VectorOfLength<allowedLengths>],
-  VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
+  [VectorOfNonZeroRankOf<allowedTypes>, VectorOfLength<allowedLengths>],
+   VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
 // Any fixed-length vector where the number of elements is from the given
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class FixedVectorOfLengthAndType<list<int> allowedLengths,
                                  list<Type> allowedTypes> : AllOfType<
-  [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
-  FixedVectorOf<allowedTypes>.summary #
+  [FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
+  FixedVectorOfAnyRank<allowedTypes>.summary #
   FixedVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
@@ -621,8 +621,8 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
 // `allowedLengths` list and the type is from the given `allowedTypes` list
 class ScalableVectorOfLengthAndType<list<int> allowedLengths,
                                     list<Type> allowedTypes> : AllOfType<
-  [ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
-  ScalableVectorOf<allowedTypes>.summary #
+  [ScalableVectorOfAnyRank<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
+  ScalableVectorOfAnyRank<allowedTypes>.summary #
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
@@ -632,10 +632,10 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
 class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
                                            list<int> allowedLengths,
                                            list<Type> allowedTypes> : AllOfType<
-  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
+  [ScalableVectorOfRank<allowedRanks>, ScalableVectorOfAnyRank<allowedTypes>,
    ScalableVectorOfLength<allowedLengths>],
   ScalableVectorOfRank<allowedRanks>.summary #
-  ScalableVectorOf<allowedTypes>.summary #
+  ScalableVectorOfAnyRank<allowedTypes>.summary #
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
@@ -657,13 +657,14 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
    ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
   "::mlir::VectorType">;
 
-def AnyVector : VectorOf<[AnyType]>;
-// Temporary vector type clone that allows gradual transition to 0-D vectors.
+// Unlike the following definitions, this one excludes 0-D vectors
+def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>;
+
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
 
-def AnyFixedVector : FixedVectorOf<[AnyType]>;
+def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
 
-def AnyScalableVector : ScalableVectorOf<[AnyType]>;
+def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;
 
 // Shaped types.
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6752113cab8d41..239d5292180269 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2781,7 +2781,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
 //===----------------------------------------------------------------------===//
 // Test InferIntRangeInterface
 //===----------------------------------------------------------------------===//
-def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOf<[AnyInteger, Index]>]>;
+def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOfNonZeroRankOf<[AnyInteger, Index]>]>;
 
 def TestWithBoundsOp : TEST_Op<"with_bounds",
                           [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,


        


More information about the Mlir-commits mailing list