[Mlir-commits] [mlir] f39c2a1 - [mlir][llvm] Add vector insert/extract intrinsics

Javier Setoain llvmlistbot at llvm.org
Mon Jun 27 06:15:49 PDT 2022


Author: Javier Setoain
Date: 2022-06-27T14:12:18+01:00
New Revision: f39c2a11428313c6643464f75645ad6cf2cf1ea7

URL: https://github.com/llvm/llvm-project/commit/f39c2a11428313c6643464f75645ad6cf2cf1ea7
DIFF: https://github.com/llvm/llvm-project/commit/f39c2a11428313c6643464f75645ad6cf2cf1ea7.diff

LOG: [mlir][llvm] Add vector insert/extract intrinsics

These intrinsics will be needed to convert between fixed-length vectors
and scalable vectors.

These operations will be needed for VLS (vector-length-specific)
vectorization when interfacing with vector functions or intrinsics that
take scalable vectors as operands, in a context where the length of our
vectors is known or assumed at compile time but we still want to
generate scalable vector instructions.

Differential Revision: https://reviews.llvm.org/D127100

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 4178cab6bcb57..8427eb18dd423 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -408,6 +408,75 @@ def LLVM_StepVectorOp
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+/// Create a call to the llvm.vector.insert intrinsic, which inserts the
+/// vector `srcvec` into `dstvec` starting at element index `pos` and returns
+/// the resulting vector (same type as `dstvec`). Fixed-length vectors may be
+/// inserted into scalable ones, but not the other way around.
+def LLVM_vector_insert
+    : LLVM_Op<"intr.vector.insert",
+                 [NoSideEffect, AllTypesMatch<["dstvec", "res"]>,
+                  PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
+                    CPred<"getSrcVectorBitWidth() <= 131072">,
+                    CPred<"getDstVectorBitWidth() <= 131072">
+                  ]>>,
+                  PredOpTrait<"it is not inserting scalable into fixed-length vectors.",
+                    CPred<"!isScalableVectorType($srcvec.getType()) || "
+                          "isScalableVectorType($dstvec.getType())">>,
+                  // LLVM's IR verifier rejects vector.insert calls whose
+                  // operands have different element types; reject them here
+                  // instead of emitting invalid LLVM IR.
+                  PredOpTrait<"vectors have the same element type.",
+                    CPred<"getVectorElementType($srcvec.getType()) == "
+                          "getVectorElementType($dstvec.getType())">>]> {
+  let summary = "Insert a vector into another vector at a constant index";
+  let arguments = (ins LLVM_AnyVector:$srcvec, LLVM_AnyVector:$dstvec,
+                       I64Attr:$pos);
+  let results = (outs LLVM_AnyVector:$res);
+  let builders = [LLVM_OneResultOpBuilder];
+  string llvmBuilder = [{
+    $res = builder.CreateInsertVector(
+        $_resultType, $dstvec, $srcvec, builder.getInt64($pos));
+  }];
+  let assemblyFormat = "$srcvec `,` $dstvec `[` $pos `]` attr-dict `:` "
+    "type($srcvec) `into` type($res)";
+  let extraClassDeclaration = [{
+    /// Returns the size of `vector` in bits; for scalable vectors this is the
+    /// known minimum size (the size when vscale is 1).
+    /// Elements that are not integers or floats (e.g. LLVM pointers) have a
+    /// data-layout-dependent width and would make getIntOrFloatBitWidth()
+    /// assert, so 0 is returned for them and the size limit is not enforced.
+    uint64_t getVectorBitWidth(Type vector) {
+      Type elementType = getVectorElementType(vector);
+      if (!elementType.isIntOrFloat())
+        return 0;
+      return getVectorNumElements(vector).getKnownMinValue() *
+             elementType.getIntOrFloatBitWidth();
+    }
+    /// Bit width of the vector being inserted.
+    uint64_t getSrcVectorBitWidth() {
+      return getVectorBitWidth(getSrcvec().getType());
+    }
+    /// Bit width of the vector being inserted into.
+    uint64_t getDstVectorBitWidth() {
+      return getVectorBitWidth(getDstvec().getType());
+    }
+  }];
+}
+
+/// Create a call to the llvm.vector.extract intrinsic, which extracts a
+/// vector of the result type from `srcvec` starting at element index `pos`.
+/// Fixed-length vectors may be extracted from scalable ones, but not the
+/// other way around.
+def LLVM_vector_extract
+    : LLVM_Op<"intr.vector.extract",
+                 [NoSideEffect,
+                  PredOpTrait<"vectors are not bigger than 2^17 bits.", And<[
+                    CPred<"getSrcVectorBitWidth() <= 131072">,
+                    CPred<"getResVectorBitWidth() <= 131072">
+                  ]>>,
+                  PredOpTrait<"it is not extracting scalable from fixed-length vectors.",
+                    CPred<"!isScalableVectorType($res.getType()) || "
+                          "isScalableVectorType($srcvec.getType())">>,
+                  // LLVM's IR verifier rejects vector.extract calls whose
+                  // result and source element types differ; reject them here
+                  // instead of emitting invalid LLVM IR.
+                  PredOpTrait<"vectors have the same element type.",
+                    CPred<"getVectorElementType($res.getType()) == "
+                          "getVectorElementType($srcvec.getType())">>]> {
+  let summary = "Extract a vector from another vector at a constant index";
+  let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos);
+  let results = (outs LLVM_AnyVector:$res);
+  let builders = [LLVM_OneResultOpBuilder];
+  string llvmBuilder = [{
+    $res = builder.CreateExtractVector(
+        $_resultType, $srcvec, builder.getInt64($pos));
+  }];
+  let assemblyFormat = "$srcvec `[` $pos `]` attr-dict `:` "
+    "type($res) `from` type($srcvec)";
+  let extraClassDeclaration = [{
+    /// Returns the size of `vector` in bits; for scalable vectors this is the
+    /// known minimum size (the size when vscale is 1).
+    /// Elements that are not integers or floats (e.g. LLVM pointers) have a
+    /// data-layout-dependent width and would make getIntOrFloatBitWidth()
+    /// assert, so 0 is returned for them and the size limit is not enforced.
+    uint64_t getVectorBitWidth(Type vector) {
+      Type elementType = getVectorElementType(vector);
+      if (!elementType.isIntOrFloat())
+        return 0;
+      return getVectorNumElements(vector).getKnownMinValue() *
+             elementType.getIntOrFloatBitWidth();
+    }
+    /// Bit width of the vector being extracted from.
+    uint64_t getSrcVectorBitWidth() {
+      return getVectorBitWidth(getSrcvec().getType());
+    }
+    /// Bit width of the extracted (result) vector.
+    uint64_t getResVectorBitWidth() {
+      return getVectorBitWidth(getRes().getType());
+    }
+  }];
+}
+
 //
 // LLVM Vector Predication operations.
 //

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index c747cb13f2c2b..9ede9fd5931f3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -162,6 +162,16 @@ def LLVM_AnyNonAggregate : Type<And<[LLVM_Type.predicate,
 def LLVM_AnyVector : Type<CPred<"::mlir::LLVM::isCompatibleVectorType($_self)">,
                          "LLVM dialect-compatible vector type">;
 
+// Type constraint accepting any LLVM fixed-length vector type. The
+// compatible-vector check comes first: isScalableVectorType() is meant for
+// vector types only, and the negated check alone would also accept
+// non-vector types.
+def LLVM_AnyFixedVector : Type<And<[LLVM_AnyVector.predicate, CPred<
+                                "!::mlir::LLVM::isScalableVectorType($_self)">]>,
+                                "LLVM dialect-compatible fixed-length vector type">;
+
+// Type constraint accepting any LLVM scalable vector type. As for the
+// fixed-length constraint, the vector check guards isScalableVectorType().
+def LLVM_AnyScalableVector : Type<And<[LLVM_AnyVector.predicate, CPred<
+                                "::mlir::LLVM::isScalableVectorType($_self)">]>,
+                                "LLVM dialect-compatible scalable vector type">;
+
 // Type constraint accepting an LLVM vector type with an additional constraint
 // on the vector element type.
 class LLVM_VectorOf<Type element> : Type<

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 449ffb1bc9ab6..70b7de5164773 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1363,3 +1363,45 @@ func.func @invalid_res_struct_attr_value(%arg0 : !llvm.struct<(i32)>) -> (!llvm.
 func.func @invalid_res_struct_attr_size(%arg0 : !llvm.struct<(i32)>) -> (!llvm.struct<(i32)> {llvm.struct_attrs = []}) {
     return %arg0 : !llvm.struct<(i32)>
 }
+
+// -----
+
+func.func @insert_vector_invalid_source_vector_size(%arg0 : vector<16385xi8>, %arg1 : vector<[16]xi8>) {
+  // 16385 x i8 = 131080 bits, just over the 2^17 = 131072-bit limit.
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<16385xi8> into vector<[16]xi8>
+}
+
+// -----
+
+func.func @insert_vector_invalid_dest_vector_size(%arg0 : vector<16xi8>, %arg1 : vector<[16385]xi8>) {
+  // The destination's known minimum size (16385 x i8) exceeds the limit.
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<16xi8> into vector<[16385]xi8>
+}
+
+// -----
+
+func.func @insert_scalable_into_fixed_length_vector(%arg0 : vector<[8]xf32>, %arg1 : vector<16xf32>) {
+  // A scalable source cannot be inserted into a fixed-length destination.
+  // expected-error@+1 {{op failed to verify that it is not inserting scalable into fixed-length vectors.}}
+  %0 = llvm.intr.vector.insert %arg0, %arg1[0] : vector<[8]xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @extract_vector_invalid_source_vector_size(%arg0 : vector<[16385]xi8>) {
+  // The source's known minimum size (16385 x i8) exceeds the limit.
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<16xi8> from vector<[16385]xi8>
+}
+
+// -----
+
+func.func @extract_vector_invalid_result_vector_size(%arg0 : vector<[16]xi8>) {
+  // The result (16385 x i8 = 131080 bits) exceeds the limit.
+  // expected-error@+1 {{op failed to verify that vectors are not bigger than 2^17 bits.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<16385xi8> from vector<[16]xi8>
+}
+
+// -----
+
+func.func @extract_scalable_from_fixed_length_vector(%arg0 : vector<16xf32>) {
+  // A scalable result cannot be extracted from a fixed-length source.
+  // expected-error@+1 {{op failed to verify that it is not extracting scalable from fixed-length vectors.}}
+  %0 = llvm.intr.vector.extract %arg0[0] : vector<[8]xf32> from vector<16xf32>
+}

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 50a50af8eb7cf..9af27d4d1d39a 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -305,6 +305,23 @@ func.func @scalable_vect(%arg0: vector<[4]xf32>, %arg1: i32, %arg2: f32) {
   return
 }
 
+// CHECK-LABEL: @mixed_vect
+func.func @mixed_vect(%v8: vector<8xf32>, %v4: vector<4xf32>, %sv4: vector<[4]xf32>) {
+  // Fixed-length subvectors inserted into a scalable vector.
+  // CHECK: = llvm.intr.vector.insert {{.*}} : vector<8xf32> into vector<[4]xf32>
+  %ins8 = llvm.intr.vector.insert %v8, %sv4[0] : vector<8xf32> into vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.insert {{.*}} : vector<4xf32> into vector<[4]xf32>
+  %ins4 = llvm.intr.vector.insert %v4, %sv4[0] : vector<4xf32> into vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.insert {{.*}} : vector<4xf32> into vector<[4]xf32>
+  %ins4hi = llvm.intr.vector.insert %v4, %ins4[4] : vector<4xf32> into vector<[4]xf32>
+  // Fixed-length subvector inserted into a fixed-length vector.
+  // CHECK: = llvm.intr.vector.insert {{.*}} : vector<4xf32> into vector<8xf32>
+  %insfix = llvm.intr.vector.insert %v4, %v8[4] : vector<4xf32> into vector<8xf32>
+  // Fixed-length subvectors extracted from scalable and fixed-length vectors.
+  // CHECK: = llvm.intr.vector.extract {{.*}} : vector<8xf32> from vector<[4]xf32>
+  %ext8 = llvm.intr.vector.extract %ins4hi[0] : vector<8xf32> from vector<[4]xf32>
+  // CHECK: = llvm.intr.vector.extract {{.*}} : vector<2xf32> from vector<8xf32>
+  %ext2 = llvm.intr.vector.extract %v8[6] : vector<2xf32> from vector<8xf32>
+  return
+}
+
 // CHECK-LABEL: @alloca
 func.func @alloca(%size : i64) {
   // CHECK: llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr<i32>

diff  --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index b0a07c1d6c59e..b9145be374662 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -680,6 +680,33 @@ llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>,
   llvm.return
 }
 
+// CHECK-LABEL: @vector_insert_extract
+llvm.func @vector_insert_extract(%fixed8: vector<8xi32>, %fixed4: vector<4xi32>,
+                                 %scalable4: vector<[4]xi32>) {
+  // Fixed-length subvectors inserted into a scalable vector.
+  // CHECK: call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v8i32
+  %ins8 = llvm.intr.vector.insert %fixed8, %scalable4[0] : vector<8xi32> into vector<[4]xi32>
+  // CHECK: call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32
+  %ins4 = llvm.intr.vector.insert %fixed4, %scalable4[0] : vector<4xi32> into vector<[4]xi32>
+  // CHECK: call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32
+  %ins4hi = llvm.intr.vector.insert %fixed4, %ins4[4] : vector<4xi32> into vector<[4]xi32>
+  // Fixed-length subvector inserted into a fixed-length vector.
+  // CHECK: call <8 x i32> @llvm.vector.insert.v8i32.v4i32
+  %insfix = llvm.intr.vector.insert %fixed4, %fixed8[4] : vector<4xi32> into vector<8xi32>
+  // Fixed-length subvectors extracted from scalable and fixed-length vectors.
+  // CHECK: call <8 x i32> @llvm.vector.extract.v8i32.nxv4i32
+  %ext8 = llvm.intr.vector.extract %ins4hi[0] : vector<8xi32> from vector<[4]xi32>
+  // CHECK: call <4 x i32> @llvm.vector.extract.v4i32.nxv4i32
+  %ext4 = llvm.intr.vector.extract %ins4hi[0] : vector<4xi32> from vector<[4]xi32>
+  // CHECK: call <2 x i32> @llvm.vector.extract.v2i32.v8i32
+  %ext2 = llvm.intr.vector.extract %fixed8[6] : vector<2xi32> from vector<8xi32>
+  llvm.return
+}
+
 // Check that intrinsics are declared with appropriate types.
 // CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
 // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@@ -781,3 +808,9 @@ llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>,
 // CHECK-DAG: declare <8 x i64> @llvm.vp.fptosi.v8i64.v8f64(<8 x double>, <8 x i1>, i32) #2
 // CHECK-DAG: declare <8 x i64> @llvm.vp.ptrtoint.v8i64.v8p0(<8 x ptr>, <8 x i1>, i32) #2
 // CHECK-DAG: declare <8 x ptr> @llvm.vp.inttoptr.v8p0.v8i64(<8 x i64>, <8 x i1>, i32) #2
+// CHECK-DAG: declare <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v8i32(<vscale x 4 x i32>, <8 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32>, <4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <8 x i32> @llvm.vector.insert.v8i32.v4i32(<8 x i32>, <4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <8 x i32> @llvm.vector.extract.v8i32.nxv4i32(<vscale x 4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <4 x i32> @llvm.vector.extract.v4i32.nxv4i32(<vscale x 4 x i32>, i64 immarg) #2
+// CHECK-DAG: declare <2 x i32> @llvm.vector.extract.v2i32.v8i32(<8 x i32>, i64 immarg) #2


        


More information about the Mlir-commits mailing list