[Mlir-commits] [mlir] [MLIR] Merge AnyVector and AnyVectorOfAnyRank type constraints. (PR #112937)
Harrison Hao
llvmlistbot at llvm.org
Fri Oct 18 10:12:54 PDT 2024
https://github.com/harrisonGPU created https://github.com/llvm/llvm-project/pull/112937
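Merge the AnyVector and AnyVectorOfAnyRank type constraints: AnyVector is redefined as VectorOfAnyRankOf<[AnyType]>, the temporary AnyVectorOfAnyRank clone is removed, and its uses in VectorOps.td are replaced with AnyVector.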
Closes https://github.com/llvm/llvm-project/issues/112913
From 0b96a318b75300e7acc1216c9d8c27318b8afb25 Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Fri, 18 Oct 2024 16:56:20 +0000
Subject: [PATCH] [MLIR] Merge AnyVector and AnyVectorOfAnyRank type constraints.
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 48 +++++++++----------
mlir/include/mlir/IR/CommonTypeConstraints.td | 4 +-
2 files changed, 25 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c02b16ea931706..11d3adb3c43e93 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -224,7 +224,7 @@ def Vector_ReductionOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
- AnyVectorOfAnyRank:$vector,
+ AnyVector:$vector,
Optional<AnyType>:$acc,
DefaultValuedAttr<
Arith_FastMathAttr,
@@ -349,7 +349,7 @@ def Vector_BroadcastOp :
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyType:$source)>,
- Results<(outs AnyVectorOfAnyRank:$vector)> {
+ Results<(outs AnyVector:$vector)> {
let summary = "broadcast operation";
let description = [{
Broadcasts the scalar or k-D vector value in the source operand
@@ -528,7 +528,7 @@ def Vector_InterleaveOp :
```
}];
- let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
+ let arguments = (ins AnyVector:$lhs, AnyVector:$rhs);
let results = (outs AnyVector:$result);
let assemblyFormat = [{
@@ -630,7 +630,7 @@ def Vector_ExtractElementOp :
TypesMatchWith<"result type matches element type of vector operand",
"vector", "result",
"::llvm::cast<VectorType>($_self).getElementType()">]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector,
+ Arguments<(ins AnyVector:$vector,
Optional<AnySignlessIntegerOrIndex>:$position)>,
Results<(outs AnyType:$result)> {
let summary = "extractelement operation";
@@ -697,7 +697,7 @@ def Vector_ExtractOp :
}];
let arguments = (ins
- AnyVectorOfAnyRank:$vector,
+ AnyVector:$vector,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
@@ -803,7 +803,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
}];
let arguments = (ins Variadic<AnyType>:$elements);
- let results = (outs AnyVectorOfAnyRank:$result);
+ let results = (outs AnyVector:$result);
let assemblyFormat = "$elements attr-dict `:` type($result)";
let hasCanonicalizer = 1;
}
@@ -814,9 +814,9 @@ def Vector_InsertElementOp :
"result", "source",
"::llvm::cast<VectorType>($_self).getElementType()">,
AllTypesMatch<["dest", "result"]>]>,
- Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
+ Arguments<(ins AnyType:$source, AnyVector:$dest,
Optional<AnySignlessIntegerOrIndex>:$position)>,
- Results<(outs AnyVectorOfAnyRank:$result)> {
+ Results<(outs AnyVector:$result)> {
let summary = "insertelement operation";
let description = [{
Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
@@ -884,11 +884,11 @@ def Vector_InsertOp :
let arguments = (ins
AnyType:$source,
- AnyVectorOfAnyRank:$dest,
+ AnyVector:$dest,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
- let results = (outs AnyVectorOfAnyRank:$result);
+ let results = (outs AnyVector:$result);
let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
@@ -1250,7 +1250,7 @@ def Vector_TransferReadOp :
AnyType:$padding,
Optional<VectorOf<[I1]>>:$mask,
BoolArrayAttr:$in_bounds)>,
- Results<(outs AnyVectorOfAnyRank:$vector)> {
+ Results<(outs AnyVector:$vector)> {
let summary = "Reads a supervector from memory into an SSA vector value.";
@@ -1492,7 +1492,7 @@ def Vector_TransferWriteOp :
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
- Arguments<(ins AnyVectorOfAnyRank:$vector,
+ Arguments<(ins AnyVector:$vector,
AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
@@ -1710,7 +1710,7 @@ def Vector_LoadOp : Vector_Op<"load"> {
[MemRead]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
- let results = (outs AnyVectorOfAnyRank:$result);
+ let results = (outs AnyVector:$result);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -1791,7 +1791,7 @@ def Vector_StoreOp : Vector_Op<"store"> {
}];
let arguments = (ins
- AnyVectorOfAnyRank:$valueToStore,
+ AnyVector:$valueToStore,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$base,
Variadic<Index>:$indices,
@@ -2199,8 +2199,8 @@ def Vector_CompressStoreOp :
def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [Pure]>,
- Arguments<(ins AnyVectorOfAnyRank:$source)>,
- Results<(outs AnyVectorOfAnyRank:$result)> {
+ Arguments<(ins AnyVector:$source)>,
+ Results<(outs AnyVector:$result)> {
let summary = "shape_cast casts between vector shapes";
let description = [{
The shape_cast operation casts between an n-D source vector shape and
@@ -2251,8 +2251,8 @@ def Vector_ShapeCastOp :
def Vector_BitCastOp :
Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
- Arguments<(ins AnyVectorOfAnyRank:$source)>,
- Results<(outs AnyVectorOfAnyRank:$result)>{
+ Arguments<(ins AnyVector:$source)>,
+ Results<(outs AnyVector:$result)>{
let summary = "bitcast casts between vectors";
let description = [{
The bitcast operation casts between vectors of the same rank, the minor 1-D
@@ -2561,9 +2561,9 @@ def Vector_TransposeOp :
```
}];
- let arguments = (ins AnyVectorOfAnyRank:$vector,
+ let arguments = (ins AnyVector:$vector,
DenseI64ArrayAttr:$permutation);
- let results = (outs AnyVectorOfAnyRank:$result);
+ let results = (outs AnyVector:$result);
let builders = [
OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
@@ -2593,7 +2593,7 @@ def Vector_PrintOp :
>,
]>,
Arguments<(ins Optional<Type<Or<[
- AnyVectorOfAnyRank.predicate,
+ AnyVector.predicate,
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
@@ -2814,7 +2814,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
"integer/index/float type">:$input);
- let results = (outs AnyVectorOfAnyRank:$aggregate);
+ let results = (outs AnyVector:$aggregate);
let builders = [
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
@@ -2873,11 +2873,11 @@ def Vector_ScanOp :
AllTypesMatch<["initial_value", "accumulated_value"]> ]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVector:$source,
- AnyVectorOfAnyRank:$initial_value,
+ AnyVector:$initial_value,
I64Attr:$reduction_dim,
BoolAttr:$inclusive)>,
Results<(outs AnyVector:$dest,
- AnyVectorOfAnyRank:$accumulated_value)> {
+ AnyVector:$accumulated_value)> {
let summary = "Scan operation";
let description = [{
Performs an inclusive/exclusive scan on an n-D vector along a single
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 48e4c24f838652..47de594bd51aae 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -657,9 +657,7 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
"::mlir::VectorType">;
-def AnyVector : VectorOf<[AnyType]>;
-// Temporary vector type clone that allows gradual transition to 0-D vectors.
-def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
+def AnyVector : VectorOfAnyRankOf<[AnyType]>;
def AnyFixedVector : FixedVectorOf<[AnyType]>;
More information about the Mlir-commits mailing list