[Mlir-commits] [mlir] [MLIR] Merge AnyVector and AnyVectorOfAnyRank type constraints. (PR #112937)

Harrison Hao llvmlistbot at llvm.org
Fri Oct 18 10:12:54 PDT 2024


https://github.com/harrisonGPU created https://github.com/llvm/llvm-project/pull/112937

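Merge the AnyVector and AnyVectorOfAnyRank type constraints: AnyVector is now defined as VectorOfAnyRankOf<[AnyType]>, the separate AnyVectorOfAnyRank definition is removed, and its uses in VectorOps.td are replaced with AnyVector.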

Closes https://github.com/llvm/llvm-project/issues/112913

>From 0b96a318b75300e7acc1216c9d8c27318b8afb25 Mon Sep 17 00:00:00 2001
From: Harrison Hao <tsworld1314 at gmail.com>
Date: Fri, 18 Oct 2024 16:56:20 +0000
Subject: [PATCH] [MLIR] Merge AnyVector and AnyVectorOfAnyRank type
 constraints.

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 48 +++++++++----------
 mlir/include/mlir/IR/CommonTypeConstraints.td |  4 +-
 2 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c02b16ea931706..11d3adb3c43e93 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -224,7 +224,7 @@ def Vector_ReductionOp :
      DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
-               AnyVectorOfAnyRank:$vector,
+               AnyVector:$vector,
                Optional<AnyType>:$acc,
                DefaultValuedAttr<
                  Arith_FastMathAttr,
@@ -349,7 +349,7 @@ def Vector_BroadcastOp :
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
     Arguments<(ins AnyType:$source)>,
-    Results<(outs AnyVectorOfAnyRank:$vector)> {
+    Results<(outs AnyVector:$vector)> {
   let summary = "broadcast operation";
   let description = [{
     Broadcasts the scalar or k-D vector value in the source operand
@@ -528,7 +528,7 @@ def Vector_InterleaveOp :
     ```
   }];
 
-  let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
+  let arguments = (ins AnyVector:$lhs, AnyVector:$rhs);
   let results = (outs AnyVector:$result);
 
   let assemblyFormat = [{
@@ -630,7 +630,7 @@ def Vector_ExtractElementOp :
      TypesMatchWith<"result type matches element type of vector operand",
                     "vector", "result",
                     "::llvm::cast<VectorType>($_self).getElementType()">]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector,
+    Arguments<(ins AnyVector:$vector,
                    Optional<AnySignlessIntegerOrIndex>:$position)>,
     Results<(outs AnyType:$result)> {
   let summary = "extractelement operation";
@@ -697,7 +697,7 @@ def Vector_ExtractOp :
   }];
 
   let arguments = (ins
-    AnyVectorOfAnyRank:$vector,
+    AnyVector:$vector,
     Variadic<Index>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
   );
@@ -803,7 +803,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
   }];
 
   let arguments = (ins Variadic<AnyType>:$elements);
-  let results = (outs AnyVectorOfAnyRank:$result);
+  let results = (outs AnyVector:$result);
   let assemblyFormat = "$elements attr-dict `:` type($result)";
   let hasCanonicalizer = 1;
 }
@@ -814,9 +814,9 @@ def Vector_InsertElementOp :
                     "result", "source",
                     "::llvm::cast<VectorType>($_self).getElementType()">,
      AllTypesMatch<["dest", "result"]>]>,
-     Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
+     Arguments<(ins AnyType:$source, AnyVector:$dest,
                     Optional<AnySignlessIntegerOrIndex>:$position)>,
-     Results<(outs AnyVectorOfAnyRank:$result)> {
+     Results<(outs AnyVector:$result)> {
   let summary = "insertelement operation";
   let description = [{
     Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
@@ -884,11 +884,11 @@ def Vector_InsertOp :
 
   let arguments = (ins
     AnyType:$source,
-    AnyVectorOfAnyRank:$dest,
+    AnyVector:$dest,
     Variadic<Index>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
   );
-  let results = (outs AnyVectorOfAnyRank:$result);
+  let results = (outs AnyVector:$result);
 
   let builders = [
     OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
@@ -1250,7 +1250,7 @@ def Vector_TransferReadOp :
                    AnyType:$padding,
                    Optional<VectorOf<[I1]>>:$mask,
                    BoolArrayAttr:$in_bounds)>,
-    Results<(outs AnyVectorOfAnyRank:$vector)> {
+    Results<(outs AnyVector:$vector)> {
 
   let summary = "Reads a supervector from memory into an SSA vector value.";
 
@@ -1492,7 +1492,7 @@ def Vector_TransferWriteOp :
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
   ]>,
-    Arguments<(ins AnyVectorOfAnyRank:$vector,
+    Arguments<(ins AnyVector:$vector,
                    AnyShaped:$source,
                    Variadic<Index>:$indices,
                    AffineMapAttr:$permutation_map,
@@ -1710,7 +1710,7 @@ def Vector_LoadOp : Vector_Op<"load"> {
       [MemRead]>:$base,
       Variadic<Index>:$indices,
       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
-  let results = (outs AnyVectorOfAnyRank:$result);
+  let results = (outs AnyVector:$result);
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -1791,7 +1791,7 @@ def Vector_StoreOp : Vector_Op<"store"> {
   }];
 
   let arguments = (ins
-      AnyVectorOfAnyRank:$valueToStore,
+      AnyVector:$valueToStore,
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
       Variadic<Index>:$indices,
@@ -2199,8 +2199,8 @@ def Vector_CompressStoreOp :
 
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [Pure]>,
-    Arguments<(ins AnyVectorOfAnyRank:$source)>,
-    Results<(outs AnyVectorOfAnyRank:$result)> {
+    Arguments<(ins AnyVector:$source)>,
+    Results<(outs AnyVector:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
     The shape_cast operation casts between an n-D source vector shape and
@@ -2251,8 +2251,8 @@ def Vector_ShapeCastOp :
 
 def Vector_BitCastOp :
   Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
-    Arguments<(ins AnyVectorOfAnyRank:$source)>,
-    Results<(outs AnyVectorOfAnyRank:$result)>{
+    Arguments<(ins AnyVector:$source)>,
+    Results<(outs AnyVector:$result)>{
   let summary = "bitcast casts between vectors";
   let description = [{
     The bitcast operation casts between vectors of the same rank, the minor 1-D
@@ -2561,9 +2561,9 @@ def Vector_TransposeOp :
     ```
   }];
 
-  let arguments = (ins AnyVectorOfAnyRank:$vector,
+  let arguments = (ins AnyVector:$vector,
                        DenseI64ArrayAttr:$permutation);
-  let results = (outs AnyVectorOfAnyRank:$result);
+  let results = (outs AnyVector:$result);
 
   let builders = [
     OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
@@ -2593,7 +2593,7 @@ def Vector_PrintOp :
     >,
   ]>,
   Arguments<(ins Optional<Type<Or<[
-    AnyVectorOfAnyRank.predicate,
+    AnyVector.predicate,
     AnyInteger.predicate, Index.predicate, AnyFloat.predicate
   ]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
                       "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
@@ -2814,7 +2814,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
 
   let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
                                  "integer/index/float type">:$input);
-  let results = (outs AnyVectorOfAnyRank:$aggregate);
+  let results = (outs AnyVector:$aggregate);
 
   let builders = [
     OpBuilder<(ins "Value":$element, "Type":$aggregateType),
@@ -2873,11 +2873,11 @@ def Vector_ScanOp :
     AllTypesMatch<["initial_value", "accumulated_value"]> ]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
                    AnyVector:$source,
-                   AnyVectorOfAnyRank:$initial_value,
+                   AnyVector:$initial_value,
                    I64Attr:$reduction_dim,
                    BoolAttr:$inclusive)>,
     Results<(outs AnyVector:$dest,
-                  AnyVectorOfAnyRank:$accumulated_value)> {
+                  AnyVector:$accumulated_value)> {
   let summary = "Scan operation";
   let description = [{
     Performs an inclusive/exclusive scan on an n-D vector along a single
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 48e4c24f838652..47de594bd51aae 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -657,9 +657,7 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
    ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
   "::mlir::VectorType">;
 
-def AnyVector : VectorOf<[AnyType]>;
-// Temporary vector type clone that allows gradual transition to 0-D vectors.
-def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
+def AnyVector : VectorOfAnyRankOf<[AnyType]>;
 
 def AnyFixedVector : FixedVectorOf<[AnyType]>;
 



More information about the Mlir-commits mailing list