[Mlir-commits] [mlir] f39c2a1 - [mlir][llvm] Add vector insert/extract intrinsics
Javier Setoain
llvmlistbot at llvm.org
Mon Jun 27 06:15:49 PDT 2022
Author: Javier Setoain
Date: 2022-06-27T14:12:18+01:00
New Revision: f39c2a11428313c6643464f75645ad6cf2cf1ea7
URL: https://github.com/llvm/llvm-project/commit/f39c2a11428313c6643464f75645ad6cf2cf1ea7
DIFF: https://github.com/llvm/llvm-project/commit/f39c2a11428313c6643464f75645ad6cf2cf1ea7.diff
LOG: [mlir][llvm] Add vector insert/extract intrinsics
These intrinsics will be needed to convert between fixed-length vectors
and scalable vectors.
This conversion will be needed for VLS (vector-length-specific)
vectorization, when interfacing with vector functions or intrinsics that
take scalable vectors as operands in a context where the length of our
vectors is known or assumed at compile time, but we still want to
generate scalable vector instructions.
Differential Revision: https://reviews.llvm.org/D127100
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 4178cab6bcb57..8427eb18dd423 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -408,6 +408,150 @@ def LLVM_StepVectorOp
let assemblyFormat = "attr-dict `:` type($res)";
}
+/// Inserts the vector `srcvec` into the vector `dstvec` starting at element
+/// `pos`; the result has the type of `dstvec`. Lowers to a call to the
+/// llvm.vector.insert intrinsic.
+///
+/// Besides the size and scalability restrictions, the verifier requires that
+/// both vectors have the same element type, that `pos` is a multiple of the
+/// (minimum) length of `srcvec`, and, when both vectors are fixed-length or
+/// both are scalable, that `srcvec` fits inside `dstvec` starting at `pos`.
+/// These mirror the operand restrictions of the LLVM intrinsic.
+def LLVM_vector_insert
+    : LLVM_Op<"intr.vector.insert",
+              [NoSideEffect, AllTypesMatch<["dstvec", "res"]>,
+               PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
+                 CPred<"getSrcVectorBitWidth() <= 131072">,
+                 CPred<"getDstVectorBitWidth() <= 131072">
+               ]>>,
+               PredOpTrait<"it is not inserting scalable into fixed-length vectors.",
+                 CPred<"!isScalableVectorType($srcvec.getType()) || "
+                       "isScalableVectorType($dstvec.getType())">>,
+               PredOpTrait<"vectors have the same element type.",
+                 CPred<"getVectorElementType($srcvec.getType()) == "
+                       "getVectorElementType($dstvec.getType())">>,
+               PredOpTrait<"position is a multiple of the source vector length.",
+                 CPred<"getPos() % getSrcVectorMinNumElements() == 0">>,
+               PredOpTrait<"source vector fits into the destination vector.",
+                 CPred<"isSrcVectorInBounds()">>]> {
+  let arguments = (ins LLVM_AnyVector:$srcvec, LLVM_AnyVector:$dstvec,
+                       I64Attr:$pos);
+  let results = (outs LLVM_AnyVector:$res);
+  let builders = [LLVM_OneResultOpBuilder];
+  string llvmBuilder = [{
+    $res = builder.CreateInsertVector(
+        $_resultType, $dstvec, $srcvec, builder.getInt64($pos));
+  }];
+  let assemblyFormat = "$srcvec `,` $dstvec `[` $pos `]` attr-dict `:` "
+    "type($srcvec) `into` type($res)";
+  let extraClassDeclaration = [{
+    // Known minimum number of elements of `vector`.
+    uint64_t getVectorMinNumElements(Type vector) {
+      return getVectorNumElements(vector).getKnownMinValue();
+    }
+    // Known minimum size of `vector` in bits.
+    // NOTE(review): relies on getIntOrFloatBitWidth, so presumably only valid
+    // for integer/float element types; confirm for vectors of pointers.
+    uint64_t getVectorBitWidth(Type vector) {
+      return getVectorMinNumElements(vector) *
+             getVectorElementType(vector).getIntOrFloatBitWidth();
+    }
+    uint64_t getSrcVectorBitWidth() {
+      return getVectorBitWidth(getSrcvec().getType());
+    }
+    uint64_t getDstVectorBitWidth() {
+      return getVectorBitWidth(getDstvec().getType());
+    }
+    uint64_t getSrcVectorMinNumElements() {
+      return getVectorMinNumElements(getSrcvec().getType());
+    }
+    // Whether `srcvec` fits inside `dstvec` when inserted at `pos`. This is
+    // only decidable statically when both vectors are fixed-length or both
+    // are scalable; mixed insertions are accepted here.
+    bool isSrcVectorInBounds() {
+      Type srcType = getSrcvec().getType();
+      Type dstType = getDstvec().getType();
+      if (isScalableVectorType(srcType) != isScalableVectorType(dstType))
+        return true;
+      uint64_t srcLen = getVectorMinNumElements(srcType);
+      uint64_t dstLen = getVectorMinNumElements(dstType);
+      // Written to avoid overflow of `getPos() + srcLen`.
+      return srcLen <= dstLen && getPos() <= dstLen - srcLen;
+    }
+  }];
+}
+
+/// Extracts a vector of the result type from `srcvec`, starting at element
+/// `pos`. Lowers to a call to the llvm.vector.extract intrinsic.
+///
+/// Besides the size and scalability restrictions, the verifier requires that
+/// the source and result have the same element type, that `pos` is a multiple
+/// of the (minimum) length of the result, and, when both vectors are
+/// fixed-length or both are scalable, that the extracted elements lie inside
+/// `srcvec`. These mirror the operand restrictions of the LLVM intrinsic.
+def LLVM_vector_extract
+    : LLVM_Op<"intr.vector.extract",
+              [NoSideEffect,
+               PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
+                 CPred<"getSrcVectorBitWidth() <= 131072">,
+                 CPred<"getResVectorBitWidth() <= 131072">
+               ]>>,
+               PredOpTrait<"it is not extracting scalable from fixed-length vectors.",
+                 CPred<"!isScalableVectorType($res.getType()) || "
+                       "isScalableVectorType($srcvec.getType())">>,
+               PredOpTrait<"vectors have the same element type.",
+                 CPred<"getVectorElementType($srcvec.getType()) == "
+                       "getVectorElementType($res.getType())">>,
+               PredOpTrait<"position is a multiple of the result vector length.",
+                 CPred<"getPos() % getResVectorMinNumElements() == 0">>,
+               PredOpTrait<"result vector fits into the source vector.",
+                 CPred<"isResVectorInBounds()">>]> {
+  let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos);
+  let results = (outs LLVM_AnyVector:$res);
+  let builders = [LLVM_OneResultOpBuilder];
+  string llvmBuilder = [{
+    $res = builder.CreateExtractVector(
+        $_resultType, $srcvec, builder.getInt64($pos));
+  }];
+  let assemblyFormat = "$srcvec `[` $pos `]` attr-dict `:` "
+    "type($res) `from` type($srcvec)";
+  let extraClassDeclaration = [{
+    // Known minimum number of elements of `vector`.
+    uint64_t getVectorMinNumElements(Type vector) {
+      return getVectorNumElements(vector).getKnownMinValue();
+    }
+    // Known minimum size of `vector` in bits.
+    // NOTE(review): relies on getIntOrFloatBitWidth, so presumably only valid
+    // for integer/float element types; confirm for vectors of pointers.
+    uint64_t getVectorBitWidth(Type vector) {
+      return getVectorMinNumElements(vector) *
+             getVectorElementType(vector).getIntOrFloatBitWidth();
+    }
+    uint64_t getSrcVectorBitWidth() {
+      return getVectorBitWidth(getSrcvec().getType());
+    }
+    uint64_t getResVectorBitWidth() {
+      return getVectorBitWidth(getRes().getType());
+    }
+    uint64_t getResVectorMinNumElements() {
+      return getVectorMinNumElements(getRes().getType());
+    }
+    // Whether the extracted elements lie inside `srcvec`. This is only
+    // decidable statically when both vectors are fixed-length or both are
+    // scalable; mixed extractions are accepted here.
+    bool isResVectorInBounds() {
+      Type srcType = getSrcvec().getType();
+      Type resType = getRes().getType();
+      if (isScalableVectorType(srcType) != isScalableVectorType(resType))
+        return true;
+      uint64_t srcLen = getVectorMinNumElements(srcType);
+      uint64_t resLen = getVectorMinNumElements(resType);
+      // Written to avoid overflow of `getPos() + resLen`.
+      return resLen <= srcLen && getPos() <= srcLen - resLen;
+    }
+  }];
+}
+
//
// LLVM Vector Predication operations.
//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index c747cb13f2c2b..9ede9fd5931f3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -162,6 +162,16 @@ def LLVM_AnyNonAggregate : Type<And<[LLVM_Type.predicate,
def LLVM_AnyVector : Type<CPred<"::mlir::LLVM::isCompatibleVectorType($_self)">,
"LLVM dialect-compatible vector type">;
+// Type constraint accepting any LLVM vector type that is not scalable.
+def LLVM_AnyFixedVector : Type<And<[LLVM_AnyVector.predicate,
+    CPred<"!::mlir::LLVM::isScalableVectorType($_self)">]>,
+    "LLVM dialect-compatible fixed-length vector type">;
+
+// Type constraint accepting any LLVM vector type that is scalable.
+def LLVM_AnyScalableVector : Type<And<[LLVM_AnyVector.predicate,
+    CPred<"::mlir::LLVM::isScalableVectorType($_self)">]>,
+    "LLVM dialect-compatible scalable vector type">;
+
// Type constraint accepting an LLVM vector type with an additional constraint
// on the vector element type.
class LLVM_VectorOf<Type element> : Type<
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 449ffb1bc9ab6..70b7de5164773 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1363,3 +1363,45 @@ func.func @invalid_res_struct_attr_value(%arg0 : !llvm.struct<(i32)>) -> (!llvm.
func.func @invalid_res_struct_attr_size(%arg0 : !llvm.struct<(i32)>) -> (!llvm.struct<(i32)> {llvm.struct_attrs = []}) {
return %arg0 : !llvm.struct<(i32)>
}
+
+// -----
+
+func.func @insert_vector_invalid_source_vector_size(%arg0 : vector<16385xi8>, %arg1 : vector<[16]xi8>) {
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<16385xi8> into vector<[16]xi8>
+}
+
+// -----
+
+func.func @insert_vector_invalid_dest_vector_size(%arg0 : vector<16xi8>, %arg1 : vector<[16385]xi8>) {
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<16xi8> into vector<[16385]xi8>
+}
+
+// -----
+
+func.func @insert_scalable_into_fixed_length_vector(%arg0 : vector<[8]xf32>, %arg1 : vector<16xf32>) {
+  // expected-error@+1 {{op failed to verify that it is not inserting scalable into fixed-length vectors.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<[8]xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @extract_vector_invalid_source_vector_size(%arg0 : vector<[16385]xi8>) {
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<16xi8> from vector<[16385]xi8>
+}
+
+// -----
+
+func.func @extract_vector_invalid_result_vector_size(%arg0 : vector<[16]xi8>) {
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<16385xi8> from vector<[16]xi8>
+}
+
+// -----
+
+func.func @extract_scalable_from_fixed_length_vector(%arg0 : vector<16xf32>) {
+  // expected-error@+1 {{op failed to verify that it is not extracting scalable from fixed-length vectors.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<[8]xf32> from vector<16xf32>
+}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 50a50af8eb7cf..9af27d4d1d39a 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -305,6 +305,23 @@ func.func @scalable_vect(%arg0: vector<[4]xf32>, %arg1: i32, %arg2: f32) {
return
}
+// CHECK-LABEL: @mixed_vect
+func.func @mixed_vect(%arg0: vector<8xf32>, %arg1: vector<4xf32>, %arg2: vector<[4]xf32>) {
+  // CHECK: = llvm.intr.vector.insert %{{.*}}, %{{.*}}[0] : vector<8xf32> into vector<[4]xf32>
+  %0 = llvm.intr.vector.insert %arg0, %arg2[0] : vector<8xf32> into vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.insert %{{.*}}, %{{.*}}[0] : vector<4xf32> into vector<[4]xf32>
+  %1 = llvm.intr.vector.insert %arg1, %arg2[0] : vector<4xf32> into vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.insert %{{.*}}, %{{.*}}[4] : vector<4xf32> into vector<[4]xf32>
+  %2 = llvm.intr.vector.insert %arg1, %1[4] : vector<4xf32> into vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.insert %{{.*}}, %{{.*}}[4] : vector<4xf32> into vector<8xf32>
+  %3 = llvm.intr.vector.insert %arg1, %arg0[4] : vector<4xf32> into vector<8xf32>
+  // CHECK: = llvm.intr.vector.extract %{{.*}}[0] : vector<8xf32> from vector<[4]xf32>
+  %4 = llvm.intr.vector.extract %2[0] : vector<8xf32> from vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.extract %{{.*}}[6] : vector<2xf32> from vector<8xf32>
+  %5 = llvm.intr.vector.extract %arg0[6] : vector<2xf32> from vector<8xf32>
+  return
+}
+
// CHECK-LABEL: @alloca
func.func @alloca(%size : i64) {
// CHECK: llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr<i32>
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index b0a07c1d6c59e..b9145be374662 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -680,6 +680,33 @@ llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>,
llvm.return
}
+// CHECK-LABEL: @vector_insert_extract
+llvm.func @vector_insert_extract(%f256: vector<8xi32>, %f128: vector<4xi32>,
+                                 %sv: vector<[4]xi32>) {
+  // CHECK: call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v8i32(<vscale x 4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i64 0)
+  %0 = llvm.intr.vector.insert %f256, %sv[0] :
+              vector<8xi32> into vector<[4]xi32>
+  // CHECK: call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i64 0)
+  %1 = llvm.intr.vector.insert %f128, %sv[0] :
+              vector<4xi32> into vector<[4]xi32>
+  // CHECK: call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i64 4)
+  %2 = llvm.intr.vector.insert %f128, %1[4] :
+              vector<4xi32> into vector<[4]xi32>
+  // CHECK: call <8 x i32> @llvm.vector.insert.v8i32.v4i32(<8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i64 4)
+  %3 = llvm.intr.vector.insert %f128, %f256[4] :
+              vector<4xi32> into vector<8xi32>
+  // CHECK: call <8 x i32> @llvm.vector.extract.v8i32.nxv4i32(<vscale x 4 x i32> %{{.*}}, i64 0)
+  %4 = llvm.intr.vector.extract %2[0] :
+              vector<8xi32> from vector<[4]xi32>
+  // CHECK: call <4 x i32> @llvm.vector.extract.v4i32.nxv4i32(<vscale x 4 x i32> %{{.*}}, i64 0)
+  %5 = llvm.intr.vector.extract %2[0] :
+              vector<4xi32> from vector<[4]xi32>
+  // CHECK: call <2 x i32> @llvm.vector.extract.v2i32.v8i32(<8 x i32> %{{.*}}, i64 6)
+  %6 = llvm.intr.vector.extract %f256[6] :
+              vector<2xi32> from vector<8xi32>
+  llvm.return
+}
+
// Check that intrinsics are declared with appropriate types.
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
// CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@@ -781,3 +808,9 @@ llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>,
// CHECK-DAG: declare <8 x i64> @llvm.vp.fptosi.v8i64.v8f64(<8 x double>, <8 x i1>, i32) #2
// CHECK-DAG: declare <8 x i64> @llvm.vp.ptrtoint.v8i64.v8p0(<8 x ptr>, <8 x i1>, i32) #2
// CHECK-DAG: declare <8 x ptr> @llvm.vp.inttoptr.v8p0.v8i64(<8 x i64>, <8 x i1>, i32) #2
+// CHECK-DAG: declare <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v8i32(<vscale x 4 x i32>, <8 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32>, <4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <8 x i32> @llvm.vector.insert.v8i32.v4i32(<8 x i32>, <4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <8 x i32> @llvm.vector.extract.v8i32.nxv4i32(<vscale x 4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <4 x i32> @llvm.vector.extract.v4i32.nxv4i32(<vscale x 4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <2 x i32> @llvm.vector.extract.v2i32.v8i32(<8 x i32>, i64 immarg) #2
More information about the Mlir-commits
mailing list