[Mlir-commits] [mlir] aa9647e - [mlir][vector] Add vector.scalable.insert/extract ops
Javier Setoain
llvmlistbot at llvm.org
Tue Nov 8 00:52:20 PST 2022
Author: Javier Setoain
Date: 2022-11-08T08:51:15Z
New Revision: aa9647e2d0b29caa3b31154246dd4d9c6a4e0c2f
URL: https://github.com/llvm/llvm-project/commit/aa9647e2d0b29caa3b31154246dd4d9c6a4e0c2f
DIFF: https://github.com/llvm/llvm-project/commit/aa9647e2d0b29caa3b31154246dd4d9c6a4e0c2f.diff
LOG: [mlir][vector] Add vector.scalable.insert/extract ops
These new operations match the semantics of
llvm.experimental.vector.insert and llvm.experimental.vector.extract
(now named llvm.vector.insert and llvm.vector.extract).
`vector.scalable.insert` and `vector.scalable.extract` allow,
respectively, inserting vectors into scalable vectors and extracting
vectors from scalable vectors.
The discussion about the inclusion of these operations is here:
https://discourse.llvm.org/t/rfc-interfacing-between-fixed-length-and-scalable-vectors-for-vls-vector-code-on-scalable-vector-architectures
Differential Revision: https://reviews.llvm.org/D127875
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e952284046b54..758e7c11c9f80 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -727,6 +727,114 @@ def Vector_InsertOp :
let hasVerifier = 1;
}
+def Vector_ScalableInsertOp :
+ Vector_Op<"scalable.insert", [Pure,
+ AllElementTypesMatch<["source", "dest"]>,
+ AllTypesMatch<["dest", "res"]>,
+ PredOpTrait<"position is a multiple of the source length.",
+ CPred<
+ "(getPos() % getSourceVectorType().getNumElements()) == 0"
+ >>]>,
+ Arguments<(ins VectorOfRank<[1]>:$source,
+ ScalableVectorOfRank<[1]>:$dest,
+ I64Attr:$pos)>,
+ Results<(outs ScalableVectorOfRank<[1]>:$res)> {
+ let summary = "insert subvector into scalable vector operation";
+ // NOTE: This operation is designed to map to `llvm.vector.insert`, and its
+ // documentation should be kept aligned with LLVM IR:
+ // https://llvm.org/docs/LangRef.html#llvm-vector-insert-intrinsic
+ let description = [{
+ This operation takes a rank-1 fixed-length or scalable subvector and
+ inserts it within the destination scalable vector starting from the
+ position specified by `pos`. If the source vector is scalable, the
+ insertion position will be scaled by the runtime scaling factor of the
+ source subvector.
+
+ The insertion position must be a multiple of the minimum size of the source
+ vector. For the operation to be well defined, the source vector must fit in
+ the destination vector from the specified position. Since the destination
+ vector is scalable and its runtime length is unknown, the validity of the
+ operation can't be verified nor guaranteed at compile time.
+
+ Example:
+
+ ```mlir
+ %2 = vector.scalable.insert %0, %1[8] : vector<4xf32> into vector<[16]xf32>
+ %5 = vector.scalable.insert %3, %4[0] : vector<8xf32> into vector<[4]xf32>
+ %8 = vector.scalable.insert %6, %7[0] : vector<[4]xf32> into vector<[8]xf32>
+ ```
+
+ Invalid example:
+ ```mlir
+ %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getSourceVectorType() {
+ return getSource().getType().cast<VectorType>();
+ }
+ VectorType getDestVectorType() {
+ return getDest().getType().cast<VectorType>();
+ }
+ }];
+}
+
+def Vector_ScalableExtractOp :
+ Vector_Op<"scalable.extract", [Pure,
+ AllElementTypesMatch<["source", "res"]>,
+ PredOpTrait<"position is a multiple of the result length.",
+ CPred<
+ "(getPos() % getResultVectorType().getNumElements()) == 0"
+ >>]>,
+ Arguments<(ins ScalableVectorOfRank<[1]>:$source,
+ I64Attr:$pos)>,
+ Results<(outs VectorOfRank<[1]>:$res)> {
+ let summary = "extract subvector from scalable vector operation";
+ // NOTE: This operation is designed to map to `llvm.vector.extract`, and its
+ // documentation should be kept aligned with LLVM IR:
+ // https://llvm.org/docs/LangRef.html#llvm-vector-extract-intrinsic
+ let description = [{
+ Takes a rank-1 scalable source vector and a position `pos` within the
+ source vector, and extracts a fixed-length or scalable rank-1 subvector
+ starting from that position. If the result vector is scalable, the
+ extraction position will be scaled by the runtime scaling factor of the
+ result vector.
+
+ The extraction position must be a multiple of the minimum size of the result
+ vector. For the operation to be well defined, the result vector must fit
+ within the source vector from the specified position. Since the source
+ vector is scalable and its runtime length is unknown, the validity of the
+ operation can't be verified nor guaranteed at compile time.
+
+ Example:
+
+ ```mlir
+ %1 = vector.scalable.extract %0[8] : vector<4xf32> from vector<[8]xf32>
+ %3 = vector.scalable.extract %2[0] : vector<[4]xf32> from vector<[8]xf32>
+ ```
+
+ Invalid example:
+ ```mlir
+ %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
+ ```
+ }];
+
+ let assemblyFormat = [{
+ $source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getSourceVectorType() {
+ return getSource().getType().cast<VectorType>();
+ }
+ VectorType getResultVectorType() {
+ return getRes().getType().cast<VectorType>();
+ }
+ }];
+}
+
def Vector_InsertStridedSliceOp :
Vector_Op<"insert_strided_slice", [Pure,
PredOpTrait<"operand #0 and result have same element type",
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 27e55ad07960e..da0e2de6bd2f0 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -579,11 +579,39 @@ class IsVectorOfRankPred<list<int> allowedRanks> :
== }]
# allowedlength>)>]>;
+// Whether the type is a fixed-length (non-scalable) vector whose rank is in
+// the given `allowedRanks` list. The rank check is delegated to
+// IsVectorOfRankPred so the two predicates cannot drift apart.
+class IsFixedVectorOfRankPred<list<int> allowedRanks> :
+  And<[IsFixedVectorTypePred,
+       IsVectorOfRankPred<allowedRanks>]>;
+
+// Whether the type is a scalable vector whose rank is in the given
+// `allowedRanks` list. The rank check is delegated to IsVectorOfRankPred so
+// the two predicates cannot drift apart.
+class IsScalableVectorOfRankPred<list<int> allowedRanks> :
+  And<[IsScalableVectorTypePred,
+       IsVectorOfRankPred<allowedRanks>]>;
+
// Any vector, fixed-length or scalable, where the rank is from the given
// `allowedRanks` list
class VectorOfRank<list<int> allowedRanks> : Type<
IsVectorOfRankPred<allowedRanks>,
" of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
+// Type constraint: a fixed-length (non-scalable) vector whose rank appears in
+// `allowedRanks`. The summary string has the same form as VectorOfRank's.
+class FixedVectorOfRank<list<int> allowedRanks>
+  : Type<IsFixedVectorOfRankPred<allowedRanks>,
+         " of ranks " # !interleave(allowedRanks, "/"),
+         "::mlir::VectorType">;
+
+// Type constraint: a scalable vector whose rank appears in `allowedRanks`.
+// The summary string has the same form as VectorOfRank's.
+class ScalableVectorOfRank<list<int> allowedRanks>
+  : Type<IsScalableVectorOfRankPred<allowedRanks>,
+         " of ranks " # !interleave(allowedRanks, "/"),
+         "::mlir::VectorType">;
+
// Any vector where the rank is from the given `allowedRanks` list and the type
// is from the given `allowedTypes` list
class VectorOfRankAndType<list<int> allowedRanks,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f4191e2133def..fd3cee552138b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -857,6 +857,37 @@ class VectorInsertOpConversion
}
};
+/// Rewrites `vector.scalable.insert` into the LLVM dialect `vector_insert`
+/// op, forwarding the converted source/destination operands and the static
+/// insertion position unchanged.
+struct VectorScalableInsertOpLowering
+    : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
+  using ConvertOpToLLVMPattern<
+      vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ScalableInsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto subVector = adaptor.getSource();
+    auto destVector = adaptor.getDest();
+    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(op, subVector, destVector,
+                                                     adaptor.getPos());
+    return success();
+  }
+};
+
+/// Lower vector.scalable.extract ops to LLVM vector.extract.
+///
+/// The LLVM op needs an explicit result type, so the op's result vector type
+/// is converted first; if the type converter cannot convert it, the pattern
+/// fails to match instead of building an op with a null type.
+struct VectorScalableExtractOpLowering
+    : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
+  using ConvertOpToLLVMPattern<
+      vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType =
+        typeConverter->convertType(extOp.getResultVectorType());
+    if (!resultType)
+      return failure();
+    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
+        extOp, resultType, adaptor.getSource(), adaptor.getPos());
+    return success();
+  }
+};
+
/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
///
/// Example:
@@ -1329,7 +1360,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
vector::MaskedStoreOpAdaptor>,
VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
+ VectorSplatOpLowering, VectorSplatNdOpLowering,
+ VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
+ converter);
// Transfer ops with rank > 1 are handled by VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 0a4732aecf0fc..24697bc18acdf 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2140,3 +2140,25 @@ func.func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0, 0, 0, 0]
// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
// CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
+
+// -----
+
+// Inserting a fixed-length subvector into a scalable vector lowers to
+// llvm.intr.vector.insert with the same operands and position; the second
+// insert consumes the result of the first.
+// CHECK-LABEL: @vector_scalable_insert
+// CHECK-SAME: %[[SUB:.*]]: vector<4xf32>, %[[SV:.*]]: vector<[4]xf32>
+func.func @vector_scalable_insert(%sub: vector<4xf32>, %dsv: vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK-NEXT: %[[TMP:.*]] = llvm.intr.vector.insert %[[SUB]], %[[SV]][0] : vector<4xf32> into vector<[4]xf32>
+ %0 = vector.scalable.insert %sub, %dsv[0] : vector<4xf32> into vector<[4]xf32>
+ // CHECK-NEXT: llvm.intr.vector.insert %[[SUB]], %[[TMP]][4] : vector<4xf32> into vector<[4]xf32>
+ %1 = vector.scalable.insert %sub, %0[4] : vector<4xf32> into vector<[4]xf32>
+ return %1 : vector<[4]xf32>
+}
+
+// -----
+
+// Extracting a fixed-length vector from a scalable vector lowers to
+// llvm.intr.vector.extract with the same result type and position.
+// CHECK-LABEL: @vector_scalable_extract
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
+ // CHECK-NEXT: %{{.*}} = llvm.intr.vector.extract %[[VEC]][0] : vector<8xf32> from vector<[4]xf32>
+ %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
+ return %0 : vector<8xf32>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 115cf41f5fd54..3366b1dd6ef02 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1632,3 +1632,16 @@ func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf3
return
}
+// -----
+
+// Position 2 is not a multiple of the 4-element source length, so the op
+// verifier must reject it.
+func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
+ // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
+ %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
+}
+
+// -----
+
+// Position 5 is not a multiple of the 4-element result length, so the op
+// verifier must reject it.
+func.func @vector_scalable_extract_unaligned(%vec: vector<[16]xf32>) {
+ // expected-error@+1 {{op failed to verify that position is a multiple of the result length.}}
+ %0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 7a6a2c3fdfe0b..88b1abbd209a7 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -853,3 +853,28 @@ func.func @vector_mask_tensor_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %
return
}
+// vector.scalable.insert with fixed-length (4 and 8 elements) and scalable
+// sources must be printed back in the same form it was written.
+// CHECK-LABEL: func @vector_scalable_insert(
+// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
+// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_insert(%sub0: vector<4xi32>, %sub1: vector<8xi32>,
+ %sub2: vector<[4]xi32>, %sv: vector<[8]xi32>) {
+ // CHECK-NEXT: vector.scalable.insert %[[SUB0]], %[[SV]][12] : vector<4xi32> into vector<[8]xi32>
+ %0 = vector.scalable.insert %sub0, %sv[12] : vector<4xi32> into vector<[8]xi32>
+ // CHECK-NEXT: vector.scalable.insert %[[SUB1]], %[[SV]][0] : vector<8xi32> into vector<[8]xi32>
+ %1 = vector.scalable.insert %sub1, %sv[0] : vector<8xi32> into vector<[8]xi32>
+ // CHECK-NEXT: vector.scalable.insert %[[SUB2]], %[[SV]][0] : vector<[4]xi32> into vector<[8]xi32>
+ %2 = vector.scalable.insert %sub2, %sv[0] : vector<[4]xi32> into vector<[8]xi32>
+ return
+ }
+
+// vector.scalable.extract with fixed-length and scalable results must be
+// printed back in the same form it was written.
+// CHECK-LABEL: func @vector_scalable_extract(
+// CHECK-SAME: %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
+ // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<16xi32> from vector<[8]xi32>
+ %0 = vector.scalable.extract %sv[0] : vector<16xi32> from vector<[8]xi32>
+ // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<[4]xi32> from vector<[8]xi32>
+ %1 = vector.scalable.extract %sv[0] : vector<[4]xi32> from vector<[8]xi32>
+ // CHECK-NEXT: vector.scalable.extract %[[SV]][4] : vector<4xi32> from vector<[8]xi32>
+ %2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
+ return
+ }
More information about the Mlir-commits
mailing list