[Mlir-commits] [mlir] [mlir][llvm] add experimental.vector.interleave2 intrinsic (PR #79270)

Cullen Rhodes llvmlistbot at llvm.org
Mon Jan 29 04:12:21 PST 2024


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/79270

>From f1fe6a22df1ef9ac6d54e4d26afcdb03c2faa173 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 24 Jan 2024 10:15:14 +0000
Subject: [PATCH 1/5] [mlir][ArmSVE] add zip1 intrinsic

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 4 ++++
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 7 +++++++
 2 files changed, 11 insertions(+)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index e3f3d9e62e8fb39..754413a1ad491ec 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -410,4 +410,8 @@ def ConvertToSvboolIntrOp :
     /*overloadedResults=*/[]>,
     Arguments<(ins SVEPredicate:$mask)>;
 
+def Zip1IntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"zip1">,
+  Arguments<(ins AnyScalableVector, AnyScalableVector)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index b63d3f06515690a..002b1f9d804a7ce 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -314,3 +314,10 @@ llvm.func @arm_sve_convert_to_svbool(
     : (vector<[1]xi1>) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: @arm_sve_zip1
+// CHECK-NEXT: call <vscale x 8 x half> @llvm.aarch64.sve.zip1.nxv8f16(<vscale x 8 x half> %{{.*}}, <vscale x 8 x half> {{.*}})
+llvm.func @arm_sve_zip1(%arg0 : vector<[8]xf16>) -> vector<[8]xf16> {
+  %0 = "arm_sve.intr.zip1"(%arg0, %arg0) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+  llvm.return %0 : vector<[8]xf16>
+}

>From 0686a46ea03ae0683d42204cce05ec2d4a10ae29 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 26 Jan 2024 13:51:13 +0000
Subject: [PATCH 2/5] replace ArmSVE zip1 with the experimental.vector.interleave2 intrinsic

---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td   |  4 ----
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td     | 14 ++++++++++++++
 mlir/test/Dialect/LLVMIR/invalid.mlir           | 17 +++++++++++++++++
 mlir/test/Dialect/LLVMIR/roundtrip.mlir         |  7 +++++++
 mlir/test/Target/LLVMIR/arm-sve.mlir            |  7 -------
 5 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 754413a1ad491ec..e3f3d9e62e8fb39 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -410,8 +410,4 @@ def ConvertToSvboolIntrOp :
     /*overloadedResults=*/[]>,
     Arguments<(ins SVEPredicate:$mask)>;
 
-def Zip1IntrOp :
-  ArmSVE_IntrBinaryOverloadedOp<"zip1">,
-  Arguments<(ins AnyScalableVector, AnyScalableVector)>;
-
 #endif // ARMSVE_OPS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index a4f08fb92da9035..b0c9f0d2d030090 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -969,6 +969,20 @@ def LLVM_vector_extract
   }];
 }
 
+class HalfElementsVectorTypeConstraint<string lhs, string rhs>
+    : TypesMatchWith<
+        "'" # rhs # "'" # " has half elements of '" # lhs # "'",
+      lhs, rhs,
+      "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self))"
+        ".setDim(0, ::llvm::cast<VectorType>($_self).getShape()[0] / 2))">;
+
+def LLVM_experimental_vector_interleave2
+    : LLVM_OneResultIntrOp<"experimental.vector.interleave2",
+                           /*overloadedResults=*/[0], /*overloadedOperands=*/[],
+                           /*traits=*/[Pure, AllTypesMatch<["vec1", "vec2"]>,
+                                       HalfElementsVectorTypeConstraint<"res", "vec1">]>,
+      Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
+
 //
 // LLVM Vector Predication operations.
 //
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index d72ff8ca3c3aa7d..c8fad3c70594728 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1218,6 +1218,23 @@ func.func @extract_scalable_from_fixed_length_vector(%arg0 : vector<16xf32>) {
   %0 = llvm.intr.vector.extract %arg0[0] : vector<[8]xf32> from vector<16xf32>
 }
 
+
+// -----
+
+func.func @experimental_vector_interleave2_bad_type0(%vec1: vector<[2]xf16>, %vec2 : vector<[4]xf16>) {
+  // expected-error at +1 {{op failed to verify that all of {vec1, vec2} have same type}}
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[2]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+  return
+}
+
+// -----
+
+func.func @experimental_vector_interleave2_bad_type1(%vec1: vector<[2]xf16>, %vec2 : vector<[2]xf16>) {
+  // expected-error at +1 {{op failed to verify that 'vec1' has half elements of 'res'}}
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[2]xf16>, vector<[2]xf16>) -> vector<[8]xf16>
+  return
+}
+
 // -----
 
 func.func @invalid_bitcast_ptr_to_i64(%arg : !llvm.ptr) {
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 1958dd56bab7ad7..b157cf001418425 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -342,6 +342,13 @@ func.func @mixed_vect(%arg0: vector<8xf32>, %arg1: vector<4xf32>, %arg2: vector<
   return
 }
 
+// CHECK-LABEL: @experimental_vector_interleave2
+func.func @experimental_vector_interleave2(%vec1: vector<[4]xf16>, %vec2 : vector<[4]xf16>) {
+  // CHECK: = "llvm.intr.experimental.vector.interleave2"({{.*}}) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+  return
+}
+
 // CHECK-LABEL: @alloca
 func.func @alloca(%size : i64) {
   // CHECK: llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 002b1f9d804a7ce..b63d3f06515690a 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -314,10 +314,3 @@ llvm.func @arm_sve_convert_to_svbool(
     : (vector<[1]xi1>) -> vector<[16]xi1>
   llvm.return
 }
-
-// CHECK-LABEL: @arm_sve_zip1
-// CHECK-NEXT: call <vscale x 8 x half> @llvm.aarch64.sve.zip1.nxv8f16(<vscale x 8 x half> %{{.*}}, <vscale x 8 x half> {{.*}})
-llvm.func @arm_sve_zip1(%arg0 : vector<[8]xf16>) -> vector<[8]xf16> {
-  %0 = "arm_sve.intr.zip1"(%arg0, %arg0) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
-  llvm.return %0 : vector<[8]xf16>
-}

>From a8ee8ed19f329af037e0ba6018f5f533576920ae Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 26 Jan 2024 16:11:34 +0000
Subject: [PATCH 3/5] check that the result has an even number of elements

---
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td   | 24 ++++++++++---------
 mlir/test/Dialect/LLVMIR/invalid.mlir         | 10 +++++++-
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index b0c9f0d2d030090..6b55b66cafb606f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -969,19 +969,21 @@ def LLVM_vector_extract
   }];
 }
 
-class HalfElementsVectorTypeConstraint<string lhs, string rhs>
-    : TypesMatchWith<
-        "'" # rhs # "'" # " has half elements of '" # lhs # "'",
-      lhs, rhs,
-      "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self))"
-        ".setDim(0, ::llvm::cast<VectorType>($_self).getShape()[0] / 2))">;
-
 def LLVM_experimental_vector_interleave2
     : LLVM_OneResultIntrOp<"experimental.vector.interleave2",
-                           /*overloadedResults=*/[0], /*overloadedOperands=*/[],
-                           /*traits=*/[Pure, AllTypesMatch<["vec1", "vec2"]>,
-                                       HalfElementsVectorTypeConstraint<"res", "vec1">]>,
-      Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
+        /*overloadedResults=*/[0], /*overloadedOperands=*/[],
+        /*traits=*/[
+          Pure, AllTypesMatch<["vec1", "vec2"]>,
+          PredOpTrait<
+            "result has an even number of elements",
+            CPred<"::llvm::cast<::mlir::VectorType>($res.getType()).getNumElements() % 2 == 0">>,
+          TypesMatchWith<
+            "'vec1' has half as many elements as result",
+            "res", "vec1",
+            "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self))"
+              ".setDim(0, ::llvm::cast<VectorType>($_self).getShape()[0] / 2))">
+        ]>,
+        Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
 
 //
 // LLVM Vector Predication operations.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index c8fad3c70594728..950de64ce9bbf7e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1230,13 +1230,21 @@ func.func @experimental_vector_interleave2_bad_type0(%vec1: vector<[2]xf16>, %ve
 // -----
 
 func.func @experimental_vector_interleave2_bad_type1(%vec1: vector<[2]xf16>, %vec2 : vector<[2]xf16>) {
-  // expected-error at +1 {{op failed to verify that 'vec1' has half elements of 'res'}}
+  // expected-error at +1 {{op failed to verify that 'vec1' has half as many elements as result}}
   %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[2]xf16>, vector<[2]xf16>) -> vector<[8]xf16>
   return
 }
 
 // -----
 
+func.func @experimental_vector_interleave2_bad_type2(%vec1: vector<[1]xf16>, %vec2 : vector<[1]xf16>) {
+  // expected-error at +1 {{op failed to verify that result has an even number of elements}}
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[1]xf16>, vector<[1]xf16>) -> vector<[3]xf16>
+  return
+}
+
+// -----
+
 func.func @invalid_bitcast_ptr_to_i64(%arg : !llvm.ptr) {
   // expected-error at +1 {{can only cast pointers from and to pointers}}
   %1 = llvm.bitcast %arg : !llvm.ptr to i64

>From 8f919715b7c4908012c4a82169a70eabe4c4a9bd Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 26 Jan 2024 17:44:12 +0000
Subject: [PATCH 4/5] simplify type constraint by checking that the result has twice as many elements as 'vec1'

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 9 +++------
 mlir/test/Dialect/LLVMIR/invalid.mlir                | 9 +--------
 2 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 6b55b66cafb606f..5234864c0dd9499 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -974,14 +974,11 @@ def LLVM_experimental_vector_interleave2
         /*overloadedResults=*/[0], /*overloadedOperands=*/[],
         /*traits=*/[
           Pure, AllTypesMatch<["vec1", "vec2"]>,
-          PredOpTrait<
-            "result has an even number of elements",
-            CPred<"::llvm::cast<::mlir::VectorType>($res.getType()).getNumElements() % 2 == 0">>,
           TypesMatchWith<
-            "'vec1' has half as many elements as result",
-            "res", "vec1",
+            "result has twice as many elements as 'vec1'",
+            "vec1", "res",
             "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self))"
-              ".setDim(0, ::llvm::cast<VectorType>($_self).getShape()[0] / 2))">
+              ".setDim(0, ::llvm::cast<VectorType>($_self).getShape()[0] * 2))">
         ]>,
         Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 950de64ce9bbf7e..f2873fb5ea71385 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1230,18 +1230,11 @@ func.func @experimental_vector_interleave2_bad_type0(%vec1: vector<[2]xf16>, %ve
 // -----
 
 func.func @experimental_vector_interleave2_bad_type1(%vec1: vector<[2]xf16>, %vec2 : vector<[2]xf16>) {
-  // expected-error at +1 {{op failed to verify that 'vec1' has half as many elements as result}}
+  // expected-error at +1 {{op failed to verify that result has twice as many elements as 'vec1'}}
   %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[2]xf16>, vector<[2]xf16>) -> vector<[8]xf16>
   return
 }
 
-// -----
-
-func.func @experimental_vector_interleave2_bad_type2(%vec1: vector<[1]xf16>, %vec2 : vector<[1]xf16>) {
-  // expected-error at +1 {{op failed to verify that result has an even number of elements}}
-  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[1]xf16>, vector<[1]xf16>) -> vector<[3]xf16>
-  return
-}
 
 // -----
 

>From e7633e9e4012bfcdf0c151bbfcbae8d28d6da7fd Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 29 Jan 2024 12:09:48 +0000
Subject: [PATCH 5/5] vector type cast in type constraint crashes on !llvm.vec
 types

Use the LLVM vector type helpers (getVectorNumElements and
getVectorElementType) instead of casting to VectorType.
---
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td   |  9 +++++----
 mlir/test/Dialect/LLVMIR/invalid.mlir         | 20 +++++++++++++++++++
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 5234864c0dd9499..feb3578fe2d4966 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -974,11 +974,12 @@ def LLVM_experimental_vector_interleave2
         /*overloadedResults=*/[0], /*overloadedOperands=*/[],
         /*traits=*/[
           Pure, AllTypesMatch<["vec1", "vec2"]>,
-          TypesMatchWith<
+          PredOpTrait<
             "result has twice as many elements as 'vec1'",
-            "vec1", "res",
-            "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self))"
-              ".setDim(0, ::llvm::cast<VectorType>($_self).getShape()[0] * 2))">
+            And<[CPred<"getVectorNumElements($res.getType()) == "
+                       "getVectorNumElements($vec1.getType()) * 2">,
+                 CPred<"getVectorElementType($vec1.getType()) == "
+                       "getVectorElementType($res.getType())">]>>,
         ]>,
         Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index f2873fb5ea71385..de1ab9db8e8df8d 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1235,6 +1235,26 @@ func.func @experimental_vector_interleave2_bad_type1(%vec1: vector<[2]xf16>, %ve
   return
 }
 
+// -----
+
+/// result vector type is not scalable.
+
+func.func @experimental_vector_interleave2_bad_type2(%vec1: vector<[2]xf16>, %vec2 : vector<[2]xf16>) {
+  // expected-error at +1 {{op failed to verify that result has twice as many elements as 'vec1'}}
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[2]xf16>, vector<[2]xf16>) -> vector<4xf16>
+  return
+}
+
+// -----
+
+
+/// element type doesn't match.
+
+func.func @experimental_vector_interleave2_bad_type3(%vec1: vector<[2]xf16>, %vec2 : vector<[2]xf16>) {
+  // expected-error at +1 {{op failed to verify that result has twice as many elements as 'vec1'}}
+  %0 = "llvm.intr.experimental.vector.interleave2"(%vec1, %vec2) : (vector<[2]xf16>, vector<[2]xf16>) -> vector<[4]xf32>
+  return
+}
 
 // -----
 



More information about the Mlir-commits mailing list