[Mlir-commits] [mlir] b214ca8 - [mlir][vector] Rename vector type TD definitions (nfc) (#117150)
llvmlistbot at llvm.org
Tue Nov 26 06:59:43 PST 2024
Author: Andrzej Warzyński
Date: 2024-11-26T14:59:39Z
New Revision: b214ca82daeece1568268ebc0fbcc2eaa649425b
URL: https://github.com/llvm/llvm-project/commit/b214ca82daeece1568268ebc0fbcc2eaa649425b
DIFF: https://github.com/llvm/llvm-project/commit/b214ca82daeece1568268ebc0fbcc2eaa649425b.diff
LOG: [mlir][vector] Rename vector type TD definitions (nfc) (#117150)
Currently, the Vector dialect TD file includes the following "vector"
type definitions:
```tablegen
def AnyVector : VectorOf<[AnyType]>;
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
def AnyFixedVector : FixedVectorOf<[AnyType]>;
def AnyScalableVector : ScalableVectorOf<[AnyType]>;
```
In short:
* `AnyVector` _excludes_ 0-D vectors.
* `AnyVectorOfAnyRank`, `AnyFixedVector`, and `AnyScalableVector`
_include_ 0-D vectors.
The naming for "groups" that include 0-D vectors is inconsistent and can
be misleading, and `AnyVector` implies that 0-D vectors are included,
which is not the case.
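For reference, here is a minimal MLIR sketch of the vector kinds being
distinguished (the constants and SSA names are purely illustrative):
```mlir
// 0-D vector: rejected by AnyVector, which requires rank > 0, but
// accepted by AnyVectorOfAnyRank and AnyFixedVector.
%cst0 = arith.constant dense<1.0> : vector<f32>
// Rank-1 fixed-width vector: accepted by AnyVector and AnyFixedVector.
%cst1 = arith.constant dense<1.0> : vector<4xf32>
// Rank-1 scalable vector: accepted by AnyVector and AnyScalableVector.
%cst2 = arith.constant dense<1.0> : vector<[4]xf32>
```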
This patch renames these definitions for clarity:
```tablegen
def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>;
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;
```
Rationale:
* The updated names are more explicit about 0-D vector support.
* It becomes clearer that scalable vectors currently allow 0-D vectors;
this might warrant a revisit.
* The renaming paves the way for adding a new group for "fixed-width
vectors excluding 0-D vectors" (e.g., AnyFixedVector), which I plan to
introduce in a follow-up patch (see the usage sketch below).
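The renamed constraints are used exactly like the old ones. Below is a
rough, hypothetical TableGen sketch (the dialect, op, and operand names
are made up; only the constraint names come from this patch):
```tablegen
// Hypothetical op definition illustrating the renamed constraints.
def MyDialect_ExampleOp : MyDialect_Op<"example"> {
  let arguments = (ins
    // Rejects 0-D vectors, e.g. vector<f32>.
    AnyVectorOfNonZeroRank:$input,
    // Accepts vectors of any rank, including 0-D, or no mask at all.
    Optional<AnyVectorOfAnyRank>:$mask);
  // Fixed-width (non-scalable) result; 0-D is still permitted.
  let results = (outs AnyFixedVectorOfAnyRank:$result);
}
```
Since this is an NFC rename, existing op definitions only need the
textual substitutions shown in the diff below.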
Added:
Modified:
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/IR/CommonTypeConstraints.td
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 76d97f106dcb88..56fbe9cdc2d21d 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -964,7 +964,7 @@ def AffineVectorLoadOp : AffineLoadOpBase<"vector_load"> {
(see [vector.transfer_read](../Vector/#vectortransfer_read-mlirvectortransferreadop)).
}];
- let results = (outs AnyVector:$result);
+ let results = (outs AnyVectorOfNonZeroRank:$result);
let builders = [
/// Builds an affine vector load op with the specified map and operands.
@@ -1031,7 +1031,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
(see [vector.transfer_write](../Vector/#vectortransfer_write-mlirvectortransferwriteop)).
}];
- let arguments = (ins AnyVector:$value,
+ let arguments = (ins AnyVectorOfNonZeroRank:$value,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref,
Variadic<Index>:$indices,
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 9cc792093bf836..475b11f12c5f01 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -35,7 +35,7 @@ def ArmNeon_Dialect : Dialect {
//===----------------------------------------------------------------------===//
class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
- [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
+ [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>,
"a vector with length " # length,
"::mlir::VectorType">;
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 9a058ae4fe7647..6fd992afbf0436 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -371,7 +371,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
- Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
+ Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);
@@ -444,7 +444,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
}];
let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
- Variadic<Index>:$indices, Optional<AnyVector>:$mask,
+ Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let extraClassDeclaration = [{
@@ -799,9 +799,9 @@ class OuterProductWideningBase<string mnemonic,
]> {
let arguments = (ins
- AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
- Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
- Optional<AnyVector>:$acc);
+ AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVectorOfNonZeroRank:$rhs,
+ Optional<AnyVectorOfNonZeroRank>:$lhsMask, Optional<AnyVectorOfNonZeroRank>:$rhsMask,
+ Optional<AnyVectorOfNonZeroRank>:$acc);
let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index d7e8b22fbd2d35..cdcf4d8752e874 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -100,11 +100,11 @@ class ScalableMaskedFOp<string mnemonic, string op_description,
op_description # [{ on active lanes. Inactive lanes will keep the value of
the first operand.}];
let arguments = (ins
- ScalableVectorOf<[I1]>:$mask,
- ScalableVectorOf<[AnyFloat]>:$src1,
- ScalableVectorOf<[AnyFloat]>:$src2
+ ScalableVectorOfAnyRank<[I1]>:$mask,
+ ScalableVectorOfAnyRank<[AnyFloat]>:$src1,
+ ScalableVectorOfAnyRank<[AnyFloat]>:$src2
);
- let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
+ let results = (outs ScalableVectorOfAnyRank<[AnyFloat]>:$res);
let assemblyFormat =
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
}
@@ -123,11 +123,11 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
op_description # [{ on active lanes. Inactive lanes will keep the value of
the first operand.}];
let arguments = (ins
- ScalableVectorOf<[I1]>:$mask,
- ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
- ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+ ScalableVectorOfAnyRank<[I1]>:$mask,
+ ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src1,
+ ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src2
);
- let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
+ let results = (outs ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$res);
let assemblyFormat =
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
}
@@ -511,55 +511,55 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def UdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"udot">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedAddIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"add">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedAddFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fadd">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedMulIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"mul">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedMulFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fmul">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedSubIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sub">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedSubFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fsub">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedSDivIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedUDivIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"udiv">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedDivFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ConvertFromSvboolIntrOp :
ArmSVE_IntrOp<"convert.from.svbool",
@@ -581,8 +581,8 @@ def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
/*overloadedOperands=*/[0],
/*overloadedResults=*/[],
/*numResults=*/2>,
- Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
- Arg<AnyScalableVector, "v2">:$v2)>;
+ Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+ Arg<AnyScalableVectorOfAnyRank, "v2">:$v2)>;
// Note: This multi-vector intrinsic requires SME2.
def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
@@ -590,10 +590,10 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
/*overloadedOperands=*/[0],
/*overloadedResults=*/[],
/*numResults=*/4>,
- Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
- Arg<AnyScalableVector, "v2">:$v2,
- Arg<AnyScalableVector, "v3">:$v3,
- Arg<AnyScalableVector, "v3">:$v4)>;
+ Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+ Arg<AnyScalableVectorOfAnyRank, "v2">:$v2,
+ Arg<AnyScalableVectorOfAnyRank, "v3">:$v3,
+ Arg<AnyScalableVectorOfAnyRank, "v3">:$v4)>;
// Note: This intrinsic requires SME or SVE2.1.
def PselIntrOp : ArmSVE_IntrOp<"psel",
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 1f52f6b91617c1..b39f2ee594cd4a 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -255,7 +255,7 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", [
let arguments = (ins Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$srcMemref,
Variadic<Index>:$indices, BoolAttr:$transpose,
I32Attr:$numTiles);
- let results = (outs AnyVector:$res);
+ let results = (outs AnyVectorOfNonZeroRank:$res);
let assemblyFormat = [{
$srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
}];
@@ -301,13 +301,13 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
```
}];
- let arguments = (ins AnyVector:$matrixA,
- AnyVector:$matrixB,
- AnyVector:$matrixC,
+ let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
+ AnyVectorOfNonZeroRank:$matrixB,
+ AnyVectorOfNonZeroRank:$matrixC,
I64ArrayAttr:$mmaShape,
OptionalAttr<UnitAttr>:$tf32Enabled);
- let results = (outs AnyVector:$res);
+ let results = (outs AnyVectorOfNonZeroRank:$res);
let builders = [
OpBuilder<(ins "Value":$matrixA,
@@ -357,16 +357,16 @@ def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
```
}];
- let arguments = (ins AnyVector:$matrixA,
- AnyVector:$matrixB,
- AnyVector:$matrixC,
+ let arguments = (ins AnyVectorOfNonZeroRank:$matrixA,
+ AnyVectorOfNonZeroRank:$matrixB,
+ AnyVectorOfNonZeroRank:$matrixC,
NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
I64ArrayAttr:$mmaShape,
DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
OptionalAttr<UnitAttr>:$tf32Enabled
);
- let results = (outs AnyVector:$res);
+ let results = (outs AnyVectorOfNonZeroRank:$res);
let builders = [
OpBuilder<(ins "Value":$matrixA,
@@ -825,10 +825,10 @@ def NVGPU_RcpOp : NVGPU_Op<"rcp", [Pure,
The input and output must be of the same vector type and shape.
}];
- let arguments = (ins VectorOf<[F32]>:$in,
+ let arguments = (ins VectorOfNonZeroRankOf<[F32]>:$in,
DefaultValuedAttr<RcpRoundingModeAttr, "RcpRoundingMode::APPROX">:$rounding,
UnitAttr:$ftz);
- let results = (outs VectorOf<[F32]>:$out);
+ let results = (outs VectorOfNonZeroRankOf<[F32]>:$out);
let assemblyFormat = [{
$in `{` `rounding` `=` $rounding (`,` `ftz` $ftz^)? `}`
attr-dict `:` type($out)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index a4b43d656fe43e..a6d3163d4446fa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -166,7 +166,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
AnyTypeOf<types>.predicate,
- VectorOf<types>.predicate,
+ VectorOfNonZeroRankOf<types>.predicate,
TosaTensorOf<types>.predicate]>,
description>;
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 41d7ce6610085c..88c1b94412241e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -40,7 +40,7 @@ def Vector_ContractionOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
- Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
+ Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyVectorOfNonZeroRank:$rhs, AnyType:$acc,
ArrayAttr:$indexing_maps,
Vector_IteratorTypeArrayAttr:$iterator_types,
DefaultValuedAttr<Vector_CombiningKindAttr,
@@ -285,7 +285,7 @@ def Vector_MultiDimReductionOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
["getShapeForUnroll"]>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
- AnyVector:$source,
+ AnyVectorOfNonZeroRank:$source,
AnyType:$acc,
DenseI64ArrayAttr:$reduction_dims)>,
Results<(outs AnyType:$dest)> {
@@ -417,16 +417,18 @@ def Vector_BroadcastOp :
let hasVerifier = 1;
}
-def Vector_ShuffleOp :
- Vector_Op<"shuffle", [Pure,
- PredOpTrait<"first operand v1 and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>,
- PredOpTrait<"second operand v2 and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 1>>,
- InferTypeOpAdaptor]>,
- Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
- DenseI64ArrayAttr:$mask)>,
- Results<(outs AnyVector:$vector)> {
+def Vector_ShuffleOp
+ : Vector_Op<
+ "shuffle",
+ [Pure,
+ PredOpTrait<"first operand v1 and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ PredOpTrait<"second operand v2 and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 1>>,
+ InferTypeOpAdaptor]>,
+ Arguments<(ins AnyFixedVectorOfAnyRank:$v1, AnyFixedVectorOfAnyRank:$v2,
+ DenseI64ArrayAttr:$mask)>,
+ Results<(outs AnyVectorOfNonZeroRank:$vector)> {
let summary = "shuffle operation";
let description = [{
The shuffle operation constructs a permutation (or duplication) of elements
@@ -531,7 +533,7 @@ def Vector_InterleaveOp :
}];
let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
- let results = (outs AnyVector:$result);
+ let results = (outs AnyVectorOfNonZeroRank:$result);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)
@@ -610,8 +612,8 @@ def Vector_DeinterleaveOp :
```
}];
- let arguments = (ins AnyVector:$source);
- let results = (outs AnyVector:$res1, AnyVector:$res2);
+ let arguments = (ins AnyVectorOfNonZeroRank:$source);
+ let results = (outs AnyVectorOfNonZeroRank:$res1, AnyVectorOfNonZeroRank:$res2);
let assemblyFormat = [{
$source attr-dict `:` type($source) `->` type($res1)
@@ -1048,9 +1050,9 @@ def Vector_InsertStridedSliceOp :
PredOpTrait<"operand #0 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
- Arguments<(ins AnyVector:$source, AnyVector:$dest, I64ArrayAttr:$offsets,
+ Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
I64ArrayAttr:$strides)>,
- Results<(outs AnyVector:$res)> {
+ Results<(outs AnyVectorOfNonZeroRank:$res)> {
let summary = "strided_slice operation";
let description = [{
Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
@@ -1107,10 +1109,10 @@ def Vector_OuterProductOp :
PredOpTrait<"rhs operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
DeclareOpInterfaceMethods<MaskableOpInterface>]>,
- Arguments<(ins AnyVector:$lhs, AnyType:$rhs,
- Optional<AnyVector>:$acc,
+ Arguments<(ins AnyVectorOfNonZeroRank:$lhs, AnyType:$rhs,
+ Optional<AnyVectorOfNonZeroRank>:$acc,
DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,
- Results<(outs AnyVector)> {
+ Results<(outs AnyVectorOfNonZeroRank)> {
let summary = "vector outerproduct with optional fused add";
let description = [{
Takes 2 1-D vectors and returns the 2-D vector containing the outer-product,
@@ -1190,9 +1192,9 @@ def Vector_ExtractStridedSliceOp :
Vector_Op<"extract_strided_slice", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVector:$vector, I64ArrayAttr:$offsets,
+ Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets,
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
- Results<(outs AnyVector)> {
+ Results<(outs AnyVectorOfNonZeroRank)> {
let summary = "extract_strided_slice operation";
let description = [{
Takes an n-D vector, k-D `offsets` integer array attribute, a k-sized
@@ -1254,7 +1256,7 @@ def Vector_TransferReadOp :
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
AnyType:$padding,
- Optional<VectorOf<[I1]>>:$mask,
+ Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
BoolArrayAttr:$in_bounds)>,
Results<(outs AnyVectorOfAnyRank:$vector)> {
@@ -1502,7 +1504,7 @@ def Vector_TransferWriteOp :
AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
- Optional<VectorOf<[I1]>>:$mask,
+ Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
BoolArrayAttr:$in_bounds)>,
Results<(outs Optional<AnyRankedTensor>:$result)> {
@@ -1825,9 +1827,9 @@ def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOf<[I1]>:$mask,
- AnyVector:$pass_thru)>,
- Results<(outs AnyVector:$result)> {
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$pass_thru)>,
+ Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = "loads elements from memory into a vector as defined by a mask vector";
@@ -1888,8 +1890,8 @@ def Vector_MaskedStoreOp :
Vector_Op<"maskedstore">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOf<[I1]>:$mask,
- AnyVector:$valueToStore)> {
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$valueToStore)> {
let summary = "stores elements from a vector into memory as defined by a mask vector";
@@ -1951,10 +1953,10 @@ def Vector_GatherOp :
]>,
Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOf<[AnyInteger, Index]>:$index_vec,
- VectorOf<[I1]>:$mask,
- AnyVector:$pass_thru)>,
- Results<(outs AnyVector:$result)> {
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$pass_thru)>,
+ Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = [{
gathers elements from memory or ranked tensor into a vector as defined by an
@@ -2082,9 +2084,9 @@ def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOf<[I1]>:$mask,
- AnyVector:$pass_thru)>,
- Results<(outs AnyVector:$result)> {
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$pass_thru)>,
+ Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2149,8 +2151,8 @@ def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOf<[I1]>:$mask,
- AnyVector:$valueToStore)> {
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$valueToStore)> {
let summary = "writes elements selectively from a vector as defined by a mask";
@@ -2508,7 +2510,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
}];
// TODO: Support multiple passthru values.
- let arguments = (ins VectorOf<[I1]>:$mask,
+ let arguments = (ins VectorOfNonZeroRankOf<[I1]>:$mask,
Optional<AnyType>:$passthru);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$maskRegion);
@@ -2891,11 +2893,11 @@ def Vector_ScanOp :
AllTypesMatch<["source", "dest"]>,
AllTypesMatch<["initial_value", "accumulated_value"]> ]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
- AnyVector:$source,
+ AnyVectorOfNonZeroRank:$source,
AnyVectorOfAnyRank:$initial_value,
I64Attr:$reduction_dim,
BoolAttr:$inclusive)>,
- Results<(outs AnyVector:$dest,
+ Results<(outs AnyVectorOfNonZeroRank:$dest,
AnyVectorOfAnyRank:$accumulated_value)> {
let summary = "Scan operation";
let description = [{
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 48e4c24f838652..fc4383d08422cb 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -22,15 +22,15 @@ include "mlir/IR/DialectBase.td"
// Whether a type is a VectorType.
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
-def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
- CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
+def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
// Whether a type is a fixed-length VectorType.
-def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
!::llvm::cast<VectorType>($_self).isScalable()}]>;
// Whether a type is a scalable VectorType.
@@ -53,7 +53,7 @@ def IsVectorTypeWithOnlyTrailingDimScalablePred : And<[
// Whether a type is a VectorType and all dimensions are scalable.
def IsVectorTypeWithAllDimsScalablePred : And<[
- IsVectorTypePred,
+ IsVectorOfNonZeroRankTypePred,
CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>
]>;
@@ -428,8 +428,8 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
// Vector types.
-class VectorOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
+class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
"::mlir::VectorType">;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
@@ -438,11 +438,11 @@ class VectorOfAnyRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
"::mlir::VectorType">;
-class FixedVectorOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
+class FixedVectorOfAnyRank<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
"fixed-length vector", "::mlir::VectorType">;
-class ScalableVectorOf<list<Type> allowedTypes> :
+class ScalableVectorOfAnyRank<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
"scalable vector", "::mlir::VectorType">;
@@ -458,7 +458,7 @@ class VectorWithTrailingDimScalableOf<list<Type> allowedTypes> :
// Whether the number of elements of a vector is from the given
// `allowedRanks` list
class IsVectorOfRankPred<list<int> allowedRanks> :
- And<[IsVectorTypePred,
+ And<[IsVectorOfNonZeroRankTypePred,
Or<!foreach(allowedlength, allowedRanks,
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
== }]
@@ -467,7 +467,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
// Whether the number of elements of a fixed-length vector is from the given
// `allowedRanks` list
class IsFixedVectorOfRankPred<list<int> allowedRanks> :
- And<[IsFixedVectorTypePred,
+ And<[IsFixedVectorOfAnyRankTypePred,
Or<!foreach(allowedlength, allowedRanks,
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
== }]
@@ -501,22 +501,22 @@ class ScalableVectorOfRank<list<int> allowedRanks> : Type<
// is from the given `allowedTypes` list
class VectorOfRankAndType<list<int> allowedRanks,
list<Type> allowedTypes> : AllOfType<
- [VectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
- VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
+ [VectorOfNonZeroRankOf<allowedTypes>, VectorOfRank<allowedRanks>],
+ VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;
// Fixed-width vector where the rank is from the given `allowedRanks` list and
// the type is from the given `allowedTypes` list
class FixedVectorOfRankAndType<list<int> allowedRanks,
list<Type> allowedTypes> : AllOfType<
- [FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
- FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
+ [FixedVectorOfAnyRank<allowedTypes>, VectorOfRank<allowedRanks>],
+ FixedVectorOfAnyRank<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;
// Whether the number of elements of a vector is from the given
// `allowedLengths` list
class IsVectorOfLengthPred<list<int> allowedLengths> :
- And<[IsVectorTypePred,
+ And<[IsVectorOfNonZeroRankTypePred,
Or<!foreach(allowedlength, allowedLengths,
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
== }]
@@ -525,7 +525,7 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
// Whether the number of elements of a fixed-length vector is from the given
// `allowedLengths` list
class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
- And<[IsFixedVectorTypePred,
+ And<[IsFixedVectorOfAnyRankTypePred,
Or<!foreach(allowedlength, allowedLengths,
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
== }]
@@ -604,16 +604,16 @@ class ScalableVectorOfLength<list<int> allowedLengths> : Type<
// list
class VectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
- [VectorOf<allowedTypes>, VectorOfLength<allowedLengths>],
- VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
+ [VectorOfNonZeroRankOf<allowedTypes>, VectorOfLength<allowedLengths>],
+ VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
// Any fixed-length vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class FixedVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
- [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
- FixedVectorOf<allowedTypes>.summary #
+ [FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
+ FixedVectorOfAnyRank<allowedTypes>.summary #
FixedVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
@@ -621,8 +621,8 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
// `allowedLengths` list and the type is from the given `allowedTypes` list
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
- [ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
- ScalableVectorOf<allowedTypes>.summary #
+ [ScalableVectorOfAnyRank<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
+ ScalableVectorOfAnyRank<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
@@ -632,10 +632,10 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
- [ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
+ [ScalableVectorOfRank<allowedRanks>, ScalableVectorOfAnyRank<allowedTypes>,
ScalableVectorOfLength<allowedLengths>],
ScalableVectorOfRank<allowedRanks>.summary #
- ScalableVectorOf<allowedTypes>.summary #
+ ScalableVectorOfAnyRank<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
@@ -657,13 +657,14 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
"::mlir::VectorType">;
-def AnyVector : VectorOf<[AnyType]>;
-// Temporary vector type clone that allows gradual transition to 0-D vectors.
+// Unlike the following definitions, this one excludes 0-D vectors
+def AnyVectorOfNonZeroRank : VectorOfNonZeroRankOf<[AnyType]>;
+
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
-def AnyFixedVector : FixedVectorOf<[AnyType]>;
+def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
-def AnyScalableVector : ScalableVectorOf<[AnyType]>;
+def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;
// Shaped types.
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6752113cab8d41..239d5292180269 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2781,7 +2781,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
//===----------------------------------------------------------------------===//
// Test InferIntRangeInterface
//===----------------------------------------------------------------------===//
-def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOf<[AnyInteger, Index]>]>;
+def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOfNonZeroRankOf<[AnyInteger, Index]>]>;
def TestWithBoundsOp : TEST_Op<"with_bounds",
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,