[Mlir-commits] [mlir] [mlir][ArmSVE] Add convert.from/to.svbool intrinsics (PR #68418)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Oct 6 06:54:56 PDT 2023


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/68418

These will be used in a future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not support this for any type smaller
than an svbool, which is vector<[16]xi1>).

Depends on #68399 



>From 6b24bac72e700df0439d75e70be43a187270e1f9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 5 Oct 2023 12:57:13 +0000
Subject: [PATCH 1/2] [mlir][ArmSVE] Restructure sources to match ArmSME
 dialect (NFC)

This rearranges the Arm SVE dialect to have the same structure as the
Arm SME dialect. So this just moves around some source files and adds an
ArmSVE_IntrOp base class for SVE intrinsics. This makes later changes
a little easier and more consistent with other dialects.
---
 mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt  |  7 +------
 .../mlir/Dialect/ArmSVE/{ => IR}/ArmSVE.td       | 16 +++++++++++-----
 .../mlir/Dialect/ArmSVE/{ => IR}/ArmSVEDialect.h |  4 ++--
 .../mlir/Dialect/ArmSVE/IR/CMakeLists.txt        |  6 ++++++
 .../Dialect/ArmSVE/{ => Transforms}/Transforms.h |  0
 mlir/include/mlir/InitAllDialects.h              |  2 +-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp     |  4 ++--
 mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp     | 10 +++++-----
 .../ArmSVE/Transforms/LegalizeForLLVMExport.cpp  |  4 ++--
 .../Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp |  4 ++--
 10 files changed, 32 insertions(+), 25 deletions(-)
 rename mlir/include/mlir/Dialect/ArmSVE/{ => IR}/ArmSVE.td (95%)
 rename mlir/include/mlir/Dialect/ArmSVE/{ => IR}/ArmSVEDialect.h (89%)
 create mode 100644 mlir/include/mlir/Dialect/ArmSVE/IR/CMakeLists.txt
 rename mlir/include/mlir/Dialect/ArmSVE/{ => Transforms}/Transforms.h (100%)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index 06595b7088a1e1b..f33061b2d87cffc 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1,6 +1 @@
-add_mlir_dialect(ArmSVE arm_sve ArmSVE)
-add_mlir_doc(ArmSVE ArmSVE Dialects/ -gen-dialect-doc -dialect=arm_sve)
-
-set(LLVM_TARGET_DEFINITIONS ArmSVE.td)
-mlir_tablegen(ArmSVEConversions.inc -gen-llvmir-conversions)
-add_public_tablegen_target(MLIRArmSVEConversionsIncGen)
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
similarity index 95%
rename from mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
rename to mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 5c86df1ef21f4b9..22f57a21fb9ae24 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -28,7 +28,6 @@ def ArmSVE_Dialect : Dialect {
     This dialect contains the definitions necessary to target specific Arm SVE
     scalable vector operations.
   }];
-
 }
 
 //===----------------------------------------------------------------------===//
@@ -38,16 +37,23 @@ def ArmSVE_Dialect : Dialect {
 class ArmSVE_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSVE_Dialect, mnemonic, traits> {}
 
-class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
-                                    list<Trait> traits = []> :
+class ArmSVE_IntrOp<string mnemonic,
+                                    list<Trait> traits = [],
+                                    list<int> overloadedOperands = [],
+                                    list<int> overloadedResults = []> :
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
                   /*string opName=*/"intr." # mnemonic,
                   /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
-                  /*list<int> overloadedResults=*/[0],
-                  /*list<int> overloadedOperands=*/[], // defined by result overload
+                  /*list<int> overloadedResults=*/overloadedResults,
+                  /*list<int> overloadedOperands=*/overloadedOperands,
                   /*list<Trait> traits=*/traits,
                   /*int numResults=*/1>;
 
+class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
+                                    list<Trait> traits = []>:
+  ArmSVE_IntrOp<mnemonic, traits,
+    /*overloadedOperands=*/[], /*overloadedResults=*/[0]>;
+
 class ScalableMaskedFOp<string mnemonic, string op_description,
                         list<Trait> traits = []> :
   ArmSVE_Op<mnemonic, !listconcat(traits,
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h
similarity index 89%
rename from mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
rename to mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h
index 6ecee1e4552b95c..a424e109b06d3e8 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h
@@ -19,9 +19,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h.inc"
 
 #define GET_OP_CLASSES
-#include "mlir/Dialect/ArmSVE/ArmSVE.h.inc"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVE.h.inc"
 
 #endif // MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/IR/CMakeLists.txt
new file mode 100644
index 000000000000000..06595b7088a1e1b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_dialect(ArmSVE arm_sve ArmSVE)
+add_mlir_doc(ArmSVE ArmSVE Dialects/ -gen-dialect-doc -dialect=arm_sve)
+
+set(LLVM_TARGET_DEFINITIONS ArmSVE.td)
+mlir_tablegen(ArmSVEConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRArmSVEConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
similarity index 100%
rename from mlir/include/mlir/Dialect/ArmSVE/Transforms.h
rename to mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 5ec36a7f289e586..d04ed373ecf045a 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -24,7 +24,7 @@
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 2929823bad32adf..b865a2671fff762 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -17,8 +17,8 @@
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
-#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
-#include "mlir/Dialect/ArmSVE/Transforms.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 4af836a93c2a169..b7f1020deba1e40 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -10,7 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -38,17 +38,17 @@ static Type getI1SameShape(Type type) {
 // Tablegen Definitions
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/ArmSVE/ArmSVEDialect.cpp.inc"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.cpp.inc"
 
 #define GET_OP_CLASSES
-#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVE.cpp.inc"
 
 #define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVETypes.cpp.inc"
 
 void ArmSVEDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
-#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVE.cpp.inc"
       >();
 }
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index b6723a52e177f5f..abbb978304068e2 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -8,8 +8,8 @@
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
-#include "mlir/Dialect/ArmSVE/Transforms.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
index bc1f0e934fa02f4..cd10811b68f0288 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
-#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 
@@ -35,7 +35,7 @@ class ArmSVEDialectLLVMIRTranslationInterface
   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const final {
     Operation &opInst = *op;
-#include "mlir/Dialect/ArmSVE/ArmSVEConversions.inc"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEConversions.inc"
 
     return failure();
   }

>From fac5375a7a94afd99bcb6c7766bb4a45782149cf Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 6 Oct 2023 12:56:37 +0000
Subject: [PATCH 2/2] [mlir][ArmSVE] Add convert.from/to.svbool intrinsics

These will be used in a future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not support this for any type smaller
than an svbool, which is vector<[16]xi1>).
---
 mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 24 ++++++++++
 mlir/test/Target/LLVMIR/arm-sve.mlir          | 44 +++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 22f57a21fb9ae24..2d8c7ca4fb109f1 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -30,6 +30,16 @@ def ArmSVE_Dialect : Dialect {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmSVE type definitions
+//===----------------------------------------------------------------------===//
+
+def SVBool : ScalableVectorOfRankAndLengthAndType<
+  [1], [16], [I1]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+  [1], [16, 8, 4, 2, 1], [I1]>;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -302,4 +312,18 @@ def ScalableMaskedDivFIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
   Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
 
+def ConvertFromSvboolIntrOp :
+  ArmSVE_IntrOp<"convert.from.svbool",
+    [TypeIs<"res", SVEPredicate>],
+    /*overloadedOperands=*/[],
+    /*overloadedResults=*/[0]>,
+  Arguments<(ins SVBool:$svbool)>;
+
+def ConvertToSvboolIntrOp :
+  ArmSVE_IntrOp<"convert.to.svbool",
+    [TypeIs<"res", SVBool>],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[]>,
+    Arguments<(ins SVEPredicate:$mask)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 999df8079e0727a..172a2f7d12d440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -272,3 +272,47 @@ llvm.func @get_vector_scale() -> i64 {
   %0 = "llvm.intr.vscale"() : () -> i64
   llvm.return %0 : i64
 }
+
+// CHECK-LABEL: @arm_sve_convert_from_svbool(
+// CHECK-SAME:                               <vscale x 16 x i1> %[[SVBOOL:[0-9]+]])
+llvm.func @arm_sve_convert_from_svbool(%nxv16i1 : vector<[16]xi1>) {
+  // CHECK: %[[RES0:.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res0 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[8]xi1>
+  // CHECK: %[[RES1:.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res1 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[4]xi1>
+  // CHECK: %[[RES2:.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res2 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[2]xi1>
+  // CHECK: %[[RES3:.*]] = call <vscale x 1 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv1i1(<vscale x 16 x i1> %[[SVBOOL]])
+  %res3 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+    : (vector<[16]xi1>) -> vector<[1]xi1>
+  llvm.return
+}
+
+// CHECK-LABEL: arm_sve_convert_to_svbool(
+// CHECK-SAME:                            <vscale x 8 x i1> %[[P8:[0-9]+]],
+// CHECK-SAME:                            <vscale x 4 x i1> %[[P4:[0-9]+]],
+// CHECK-SAME:                            <vscale x 2 x i1> %[[P2:[0-9]+]],
+// CHECK-SAME:                            <vscale x 1 x i1> %[[P1:[0-9]+]])
+llvm.func @arm_sve_convert_to_svbool(
+                                       %nxv8i1  : vector<[8]xi1>,
+                                       %nxv4i1  : vector<[4]xi1>,
+                                       %nxv2i1  : vector<[2]xi1>,
+                                       %nxv1i1  : vector<[1]xi1>
+) {
+  // CHECK-NEXT: %[[RES0:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %[[P8]])
+  %res0 = "arm_sve.intr.convert.to.svbool"(%nxv8i1)
+    : (vector<[8]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES1:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %[[P4]])
+  %res1 = "arm_sve.intr.convert.to.svbool"(%nxv4i1)
+    : (vector<[4]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES2:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %[[P2]])
+  %res2 = "arm_sve.intr.convert.to.svbool"(%nxv2i1)
+    : (vector<[2]xi1>) -> vector<[16]xi1>
+  // CHECK-NEXT: %[[RES3:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv1i1(<vscale x 1 x i1> %[[P1]])
+  %res3 = "arm_sve.intr.convert.to.svbool"(%nxv1i1)
+    : (vector<[1]xi1>) -> vector<[16]xi1>
+  llvm.return
+}



More information about the Mlir-commits mailing list