[Mlir-commits] [mlir] 7dcca62 - [mlir][ArmSVE] Add `arm_sve.zip.x2` and `arm_sve.zip.x4` ops (#81278)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 16 03:34:37 PST 2024
Author: Benjamin Maxwell
Date: 2024-02-16T11:34:34Z
New Revision: 7dcca6213258bf2df2dd8a7d555c9a12c1484759
URL: https://github.com/llvm/llvm-project/commit/7dcca6213258bf2df2dd8a7d555c9a12c1484759
DIFF: https://github.com/llvm/llvm-project/commit/7dcca6213258bf2df2dd8a7d555c9a12c1484759.diff
LOG: [mlir][ArmSVE] Add `arm_sve.zip.x2` and `arm_sve.zip.x4` ops (#81278)
This adds ops for the two- and four-way SME 2 multi-vector zips.
See:
- https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--two-registers---Interleave-elements-from-two-vectors-?lang=en
- https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--four-registers---Interleave-elements-from-four-vectors-?lang=en
Added:
Modified:
mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSVE/invalid.mlir
mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
mlir/test/Dialect/ArmSVE/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index f237f232487e50..f2d330c98e7d6d 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -49,6 +49,12 @@ def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType<
def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType<
[16, 8, 4, 2, 1], [I1]>;
+// A constraint for a 1-D scalable vector `vector<[length]xT>` with `T` in `elementTypes`.
+class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedContainerType<
+ elementTypes, And<[IsVectorOfShape<[length]>, IsVectorTypeWithAnyDimScalablePred]>,
+ "a 1-D scalable vector with length " # length,
+ "::mlir::VectorType">;
+
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -321,6 +327,122 @@ def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
let assemblyFormat = "$source attr-dict `:` type($source)";
}
+// Inputs valid for the multi-vector zips (not including the 128-bit element zipqs)
+def ZipInputVectorType : AnyTypeOf<[
+ Scalable1DVectorOfLength<2, [I64, F64]>,
+ Scalable1DVectorOfLength<4, [I32, F32]>,
+ Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+ Scalable1DVectorOfLength<16, [I8]>],
+ "an SVE vector with element size <= 64-bit">;
+
+def ZipX2Op : ArmSVE_Op<"zip.x2", [
+ Pure,
+ AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
+> {
+ let summary = "Multi-vector two-way zip op";
+
+ let description = [{
+ This operation interleaves elements from two input SVE vectors, returning
+ two new SVE vectors (`resultV1` and `resultV2`), which contain the low and
+ high halves of the result respectively.
+
+ Example:
+ ```mlir
+ // sourceV1 = [ A1, A2, A3, ... An ]
+ // sourceV2 = [ B1, B2, B3, ... Bn ]
+ // (resultV1, resultV2) = [ A1, B1, A2, B2, A3, B3, ... An, Bn ]
+ %resultV1, %resultV2 = arm_sve.zip.x2 %sourceV1, %sourceV2 : vector<[16]xi8>
+ ```
+
+ Note: This requires SME 2 (`+sme2` in LLVM target features)
+
+ [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--two-registers---Interleave-elements-from-two-vectors-?lang=en)
+ }];
+
+ let arguments = (ins ZipInputVectorType:$sourceV1,
+ ZipInputVectorType:$sourceV2);
+
+ let results = (outs ZipInputVectorType:$resultV1,
+ ZipInputVectorType:$resultV2);
+
+ let builders = [
+ OpBuilder<(ins "Value":$v1, "Value":$v2), [{
+ build($_builder, $_state, v1.getType(), v1.getType(), v1, v2);
+ }]>];
+
+ let assemblyFormat = "$sourceV1 `,` $sourceV2 attr-dict `:` type($sourceV1)";
+
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return ::llvm::cast<VectorType>(getSourceV1().getType());
+ }
+ }];
+}
+
+def ZipX4Op : ArmSVE_Op<"zip.x4", [
+ Pure,
+ AllTypesMatch<[
+ "sourceV1", "sourceV2", "sourceV3", "sourceV4",
+ "resultV1", "resultV2", "resultV3", "resultV4"]>]
+> {
+ let summary = "Multi-vector four-way zip op";
+
+ let description = [{
+ This operation interleaves elements from four input SVE vectors, returning
+ four new SVE vectors, each of which contains a quarter of the result. The
+ first quarter will be in `resultV1`, second in `resultV2`, third in
+ `resultV3`, and fourth in `resultV4`.
+
+ Example:
+ ```mlir
+ // sourceV1 = [ A1, A2, ... An ]
+ // sourceV2 = [ B1, B2, ... Bn ]
+ // sourceV3 = [ C1, C2, ... Cn ]
+ // sourceV4 = [ D1, D2, ... Dn ]
+ // (resultV1, resultV2, resultV3, resultV4)
+ // = [ A1, B1, C1, D1, A2, B2, C2, D2, ... An, Bn, Cn, Dn ]
+ %resultV1, %resultV2, %resultV3, %resultV4 = arm_sve.zip.x4
+ %sourceV1, %sourceV2, %sourceV3, %sourceV4 : vector<[16]xi8>
+ ```
+
+ **Warning:** The result of this op is undefined for 64-bit elements on
+ hardware with a vector length of less than 256 bits!
+
+ Note: This requires SME 2 (`+sme2` in LLVM target features)
+
+ [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--four-registers---Interleave-elements-from-four-vectors-?lang=en)
+ }];
+
+ let arguments = (ins ZipInputVectorType:$sourceV1,
+ ZipInputVectorType:$sourceV2,
+ ZipInputVectorType:$sourceV3,
+ ZipInputVectorType:$sourceV4);
+
+ let results = (outs ZipInputVectorType:$resultV1,
+ ZipInputVectorType:$resultV2,
+ ZipInputVectorType:$resultV3,
+ ZipInputVectorType:$resultV4);
+
+ let builders = [
+ OpBuilder<(ins "Value":$v1, "Value":$v2, "Value":$v3, "Value":$v4), [{
+ build($_builder, $_state,
+ v1.getType(), v1.getType(),
+ v1.getType(), v1.getType(),
+ v1, v2, v3, v4);
+ }]>];
+
+ let assemblyFormat = [{
+ $sourceV1 `,` $sourceV2 `,` $sourceV3 `,` $sourceV4 attr-dict
+ `:` type($sourceV1)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return ::llvm::cast<VectorType>(getSourceV1().getType());
+ }
+ }];
+}
+
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
[Commutative]>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 32c87c1b824074..387937e811ced8 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -137,6 +137,10 @@ using ConvertToSvboolOpLowering =
using ConvertFromSvboolOpLowering =
SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
+// Multi-vector zips lower 1:1 to the `arm_sve.intr.zip.x{2,4}` intrinsic ops.
+using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
+using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
+
} // namespace
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -163,7 +167,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ScalableMaskedUDivIOpLowering,
ScalableMaskedDivFOpLowering,
ConvertToSvboolOpLowering,
- ConvertFromSvboolOpLowering>(converter);
+ ConvertFromSvboolOpLowering,
+ ZipX2OpLowering,
+ ZipX4OpLowering>(converter);
// clang-format on
}
@@ -184,7 +190,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedUDivIIntrOp,
ScalableMaskedDivFIntrOp,
ConvertToSvboolIntrOp,
- ConvertFromSvboolIntrOp>();
+ ConvertFromSvboolIntrOp,
+ ZipX2IntrOp,
+ ZipX4IntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
UdotOp,
@@ -199,6 +207,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedUDivIOp,
ScalableMaskedDivFOp,
ConvertToSvboolOp,
- ConvertFromSvboolOp>();
+ ConvertFromSvboolOp,
+ ZipX2Op,
+ ZipX4Op>();
// clang-format on
}
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
index a1fa0d0292b7b7..1258d3532c049c 100644
--- a/mlir/test/Dialect/ArmSVE/invalid.mlir
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -49,3 +49,18 @@ func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8
}
+// -----
+
+func.func @arm_sve_zip_x2_bad_vector_type(%a : vector<[7]xi8>) {
+ // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[7]xi8>'}}
+ arm_sve.zip.x2 %a, %a : vector<[7]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
+ // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[5]xf64>'}}
+ arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8e76fb7119b844..8d11c2bcaa8d51 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -187,3 +187,29 @@ func.func @convert_2d_mask_from_svbool(%svbool: vector<3x[16]xi1>) -> vector<3x[
// CHECK-NEXT: llvm.return %[[MASK]] : !llvm.array<3 x vector<[1]xi1>>
return %mask : vector<3x[1]xi1>
}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_zip_x2
+func.func @arm_sve_zip_x2(%a: vector<[8]xi16>, %b: vector<[8]xi16>)
+ -> (vector<[8]xi16>, vector<[8]xi16>)
+{
+ // CHECK: arm_sve.intr.zip.x2
+ %0, %1 = arm_sve.zip.x2 %a, %b : vector<[8]xi16>
+ return %0, %1 : vector<[8]xi16>, vector<[8]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_zip_x4
+func.func @arm_sve_zip_x4(
+ %a: vector<[16]xi8>,
+ %b: vector<[16]xi8>,
+ %c: vector<[16]xi8>,
+ %d: vector<[16]xi8>
+) -> (vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>)
+{
+ // CHECK: arm_sve.intr.zip.x4
+ %0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
+ return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index c9a0b6db8fa803..f7b79aa2f275c4 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -163,3 +163,67 @@ func.func @arm_sve_convert_from_svbool(%a: vector<[16]xi1>,
return
}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_zip_x2
+func.func @arm_sve_zip_x2(
+ %v1: vector<[2]xi64>,
+ %v2: vector<[2]xf64>,
+ %v3: vector<[4]xi32>,
+ %v4: vector<[4]xf32>,
+ %v5: vector<[8]xi16>,
+ %v6: vector<[8]xf16>,
+ %v7: vector<[8]xbf16>,
+ %v8: vector<[16]xi8>
+) {
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xi64>
+ %a1, %b1 = arm_sve.zip.x2 %v1, %v1 : vector<[2]xi64>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xf64>
+ %a2, %b2 = arm_sve.zip.x2 %v2, %v2 : vector<[2]xf64>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xi32>
+ %a3, %b3 = arm_sve.zip.x2 %v3, %v3 : vector<[4]xi32>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xf32>
+ %a4, %b4 = arm_sve.zip.x2 %v4, %v4 : vector<[4]xf32>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xi16>
+ %a5, %b5 = arm_sve.zip.x2 %v5, %v5 : vector<[8]xi16>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xf16>
+ %a6, %b6 = arm_sve.zip.x2 %v6, %v6 : vector<[8]xf16>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xbf16>
+ %a7, %b7 = arm_sve.zip.x2 %v7, %v7 : vector<[8]xbf16>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[16]xi8>
+ %a8, %b8 = arm_sve.zip.x2 %v8, %v8 : vector<[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_zip_x4
+func.func @arm_sve_zip_x4(
+ %v1: vector<[2]xi64>,
+ %v2: vector<[2]xf64>,
+ %v3: vector<[4]xi32>,
+ %v4: vector<[4]xf32>,
+ %v5: vector<[8]xi16>,
+ %v6: vector<[8]xf16>,
+ %v7: vector<[8]xbf16>,
+ %v8: vector<[16]xi8>
+) {
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xi64>
+ %a1, %b1, %c1, %d1 = arm_sve.zip.x4 %v1, %v1, %v1, %v1 : vector<[2]xi64>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xf64>
+ %a2, %b2, %c2, %d2 = arm_sve.zip.x4 %v2, %v2, %v2, %v2 : vector<[2]xf64>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xi32>
+ %a3, %b3, %c3, %d3 = arm_sve.zip.x4 %v3, %v3, %v3, %v3 : vector<[4]xi32>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xf32>
+ %a4, %b4, %c4, %d4 = arm_sve.zip.x4 %v4, %v4, %v4, %v4 : vector<[4]xf32>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xi16>
+ %a5, %b5, %c5, %d5 = arm_sve.zip.x4 %v5, %v5, %v5, %v5 : vector<[8]xi16>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xf16>
+ %a6, %b6, %c6, %d6 = arm_sve.zip.x4 %v6, %v6, %v6, %v6 : vector<[8]xf16>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xbf16>
+ %a7, %b7, %c7, %d7 = arm_sve.zip.x4 %v7, %v7, %v7, %v7 : vector<[8]xbf16>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[16]xi8>
+ %a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
+ return
+}
More information about the Mlir-commits mailing list