[Mlir-commits] [mlir] [mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors (PR #117538)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 25 02:53:04 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-neon
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
- **[mlir][vector] Rename vector type TD definitions (nfc)**
- **[mlir][vector] Disable CompressStoreOp/ExpandLoadOp for scalable vectors**
---
Full diff: https://github.com/llvm/llvm-project/pull/117538.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td (+1-1)
- (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+27-27)
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+21-15)
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+27-17)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+16)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
index 9cc792093bf836..475b11f12c5f01 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -35,7 +35,7 @@ def ArmNeon_Dialect : Dialect {
//===----------------------------------------------------------------------===//
class NeonVectorOfLength<int length, Type elementType> : ShapedContainerType<
- [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorTypePred]>,
+ [elementType], And<[IsVectorOfShape<[length]>, IsFixedVectorOfAnyRankTypePred]>,
"a vector with length " # length,
"::mlir::VectorType">;
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index d7e8b22fbd2d35..cdcf4d8752e874 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -100,11 +100,11 @@ class ScalableMaskedFOp<string mnemonic, string op_description,
op_description # [{ on active lanes. Inactive lanes will keep the value of
the first operand.}];
let arguments = (ins
- ScalableVectorOf<[I1]>:$mask,
- ScalableVectorOf<[AnyFloat]>:$src1,
- ScalableVectorOf<[AnyFloat]>:$src2
+ ScalableVectorOfAnyRank<[I1]>:$mask,
+ ScalableVectorOfAnyRank<[AnyFloat]>:$src1,
+ ScalableVectorOfAnyRank<[AnyFloat]>:$src2
);
- let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
+ let results = (outs ScalableVectorOfAnyRank<[AnyFloat]>:$res);
let assemblyFormat =
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
}
@@ -123,11 +123,11 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
op_description # [{ on active lanes. Inactive lanes will keep the value of
the first operand.}];
let arguments = (ins
- ScalableVectorOf<[I1]>:$mask,
- ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
- ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+ ScalableVectorOfAnyRank<[I1]>:$mask,
+ ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src1,
+ ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$src2
);
- let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
+ let results = (outs ScalableVectorOfAnyRank<[I8, I16, I32, I64]>:$res);
let assemblyFormat =
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
}
@@ -511,55 +511,55 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def UdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"udot">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedAddIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"add">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedAddFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fadd">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedMulIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"mul">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedMulFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fmul">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedSubIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sub">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedSubFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fsub">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedSDivIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedUDivIIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"udiv">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ScalableMaskedDivFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
- Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
def ConvertFromSvboolIntrOp :
ArmSVE_IntrOp<"convert.from.svbool",
@@ -581,8 +581,8 @@ def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
/*overloadedOperands=*/[0],
/*overloadedResults=*/[],
/*numResults=*/2>,
- Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
- Arg<AnyScalableVector, "v2">:$v2)>;
+ Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+ Arg<AnyScalableVectorOfAnyRank, "v2">:$v2)>;
// Note: This multi-vector intrinsic requires SME2.
def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
@@ -590,10 +590,10 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
/*overloadedOperands=*/[0],
/*overloadedResults=*/[],
/*numResults=*/4>,
- Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
- Arg<AnyScalableVector, "v2">:$v2,
- Arg<AnyScalableVector, "v3">:$v3,
- Arg<AnyScalableVector, "v3">:$v4)>;
+ Arguments<(ins Arg<AnyScalableVectorOfAnyRank, "v1">:$v1,
+ Arg<AnyScalableVectorOfAnyRank, "v2">:$v2,
+ Arg<AnyScalableVectorOfAnyRank, "v3">:$v3,
+ Arg<AnyScalableVectorOfAnyRank, "v4">:$v4)>;
// Note: This intrinsic requires SME or SVE2.1.
def PselIntrOp : ArmSVE_IntrOp<"psel",
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cc4cafa869e63a..5911355abd5146 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -417,16 +417,18 @@ def Vector_BroadcastOp :
let hasVerifier = 1;
}
-def Vector_ShuffleOp :
- Vector_Op<"shuffle", [Pure,
- PredOpTrait<"first operand v1 and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>,
- PredOpTrait<"second operand v2 and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 1>>,
- InferTypeOpAdaptor]>,
- Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
- DenseI64ArrayAttr:$mask)>,
- Results<(outs AnyVector:$vector)> {
+def Vector_ShuffleOp
+ : Vector_Op<
+ "shuffle",
+ [Pure,
+ PredOpTrait<"first operand v1 and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ PredOpTrait<"second operand v2 and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 1>>,
+ InferTypeOpAdaptor]>,
+ Arguments<(ins AnyFixedVectorOfAnyRank:$v1, AnyFixedVectorOfAnyRank:$v2,
+ DenseI64ArrayAttr:$mask)>,
+ Results<(outs AnyVector:$vector)> {
let summary = "shuffle operation";
let description = [{
The shuffle operation constructs a permutation (or duplication) of elements
@@ -2082,9 +2084,9 @@ def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOf<[I1]>:$mask,
- AnyVector:$pass_thru)>,
- Results<(outs AnyVector:$result)> {
+ FixedVectorOf<[I1]>:$mask,
+ AnyFixedVector:$pass_thru)>,
+ Results<(outs AnyFixedVector:$result)> {
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2115,6 +2117,8 @@ def Vector_ExpandLoadOp :
correspond to those of the `llvm.masked.expandload`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
+ Note, at the moment this Op is only available for fixed-width vectors.
+
Examples:
```mlir
@@ -2149,8 +2153,8 @@ def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOf<[I1]>:$mask,
- AnyVector:$valueToStore)> {
+ FixedVectorOf<[I1]>:$mask,
+ AnyFixedVector:$valueToStore)> {
let summary = "writes elements selectively from a vector as defined by a mask";
@@ -2181,6 +2185,8 @@ def Vector_CompressStoreOp :
correspond to those of the `llvm.masked.compressstore`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
+ Note, at the moment this Op is only available for fixed-width vectors.
+
Examples:
```mlir
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 48e4c24f838652..ef1645117a7280 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,13 +24,16 @@ include "mlir/IR/DialectBase.td"
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
+def IsFixedVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
+ CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
// Whether a type is a fixed-length VectorType.
-def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
!::llvm::cast<VectorType>($_self).isScalable()}]>;
// Whether a type is a scalable VectorType.
@@ -432,17 +435,21 @@ class VectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
"::mlir::VectorType">;
+class FixedVectorOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
+ "fixed-length vector", "::mlir::VectorType">;
+
// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
class VectorOfAnyRankOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
"::mlir::VectorType">;
-class FixedVectorOf<list<Type> allowedTypes> :
- ShapedContainerType<allowedTypes, IsFixedVectorTypePred,
+class FixedVectorOfAnyRank<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsFixedVectorOfAnyRankTypePred,
"fixed-length vector", "::mlir::VectorType">;
-class ScalableVectorOf<list<Type> allowedTypes> :
+class ScalableVectorOfAnyRank<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypeWithAnyDimScalablePred,
"scalable vector", "::mlir::VectorType">;
@@ -467,7 +474,7 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
// Whether the number of elements of a fixed-length vector is from the given
// `allowedRanks` list
class IsFixedVectorOfRankPred<list<int> allowedRanks> :
- And<[IsFixedVectorTypePred,
+ And<[IsFixedVectorOfAnyRankTypePred,
Or<!foreach(allowedlength, allowedRanks,
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getRank()
== }]
@@ -509,8 +516,8 @@ class VectorOfRankAndType<list<int> allowedRanks,
// the type is from the given `allowedTypes` list
class FixedVectorOfRankAndType<list<int> allowedRanks,
list<Type> allowedTypes> : AllOfType<
- [FixedVectorOf<allowedTypes>, VectorOfRank<allowedRanks>],
- FixedVectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
+ [FixedVectorOfAnyRank<allowedTypes>, VectorOfRank<allowedRanks>],
+ FixedVectorOfAnyRank<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;
// Whether the number of elements of a vector is from the given
@@ -525,7 +532,7 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
// Whether the number of elements of a fixed-length vector is from the given
// `allowedLengths` list
class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
- And<[IsFixedVectorTypePred,
+ And<[IsFixedVectorOfAnyRankTypePred,
Or<!foreach(allowedlength, allowedLengths,
CPred<[{::llvm::cast<::mlir::VectorType>($_self).getNumElements()
== }]
@@ -612,8 +619,8 @@ class VectorOfLengthAndType<list<int> allowedLengths,
// `allowedLengths` list and the type is from the given `allowedTypes` list
class FixedVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
- [FixedVectorOf<allowedTypes>, FixedVectorOfLength<allowedLengths>],
- FixedVectorOf<allowedTypes>.summary #
+ [FixedVectorOfAnyRank<allowedTypes>, FixedVectorOfLength<allowedLengths>],
+ FixedVectorOfAnyRank<allowedTypes>.summary #
FixedVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
@@ -621,8 +628,8 @@ class FixedVectorOfLengthAndType<list<int> allowedLengths,
// `allowedLengths` list and the type is from the given `allowedTypes` list
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
- [ScalableVectorOf<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
- ScalableVectorOf<allowedTypes>.summary #
+ [ScalableVectorOfAnyRank<allowedTypes>, ScalableVectorOfLength<allowedLengths>],
+ ScalableVectorOfAnyRank<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
@@ -632,10 +639,10 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
list<int> allowedLengths,
list<Type> allowedTypes> : AllOfType<
- [ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
+ [ScalableVectorOfRank<allowedRanks>, ScalableVectorOfAnyRank<allowedTypes>,
ScalableVectorOfLength<allowedLengths>],
ScalableVectorOfRank<allowedRanks>.summary #
- ScalableVectorOf<allowedTypes>.summary #
+ ScalableVectorOfAnyRank<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
@@ -657,13 +664,16 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
"::mlir::VectorType">;
+// Unlike the following definitions, this one excludes 0-D vectors
def AnyVector : VectorOf<[AnyType]>;
-// Temporary vector type clone that allows gradual transition to 0-D vectors.
-def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
def AnyFixedVector : FixedVectorOf<[AnyType]>;
-def AnyScalableVector : ScalableVectorOf<[AnyType]>;
+def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
+
+def AnyFixedVectorOfAnyRank : FixedVectorOfAnyRank<[AnyType]>;
+
+def AnyScalableVectorOfAnyRank : ScalableVectorOfAnyRank<[AnyType]>;
// Shaped types.
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 0c093b0ccff141..ae336d4f5ddb8d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1519,6 +1519,14 @@ func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>,
// -----
+func.func @expand_base_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %pass_thru: vector<[16]xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{'vector.expandload' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
+ %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32>
+}
+
+// -----
+
func.func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
@@ -1551,6 +1559,14 @@ func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1
// -----
+func.func @compress_scalable(%base: memref<?xf32>, %mask: vector<[16]xi1>, %value: vector<[16]xf32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{'vector.compressstore' op operand #2 must be fixed-length vector of 1-bit signless integer values, but got 'vector<[16]xi1>}}
+ vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<[16]xi1>, vector<[16]xf32>
+}
+
+// -----
+
func.func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}}
``````````
</details>
https://github.com/llvm/llvm-project/pull/117538
More information about the Mlir-commits
mailing list