[Mlir-commits] [mlir] f880bd2 - [mlir][ArmSVE] Add basic mask generation operations

Javier Setoain llvmlistbot at llvm.org
Wed Jun 9 01:58:58 PDT 2021


Author: Javier Setoain
Date: 2021-06-09T09:56:53+01:00
New Revision: f880bd261f4e13d4d58a75886a2942f05783c7de

URL: https://github.com/llvm/llvm-project/commit/f880bd261f4e13d4d58a75886a2942f05783c7de
DIFF: https://github.com/llvm/llvm-project/commit/f880bd261f4e13d4d58a75886a2942f05783c7de.diff

LOG: [mlir][ArmSVE] Add basic mask generation operations

These `arm_sve.cmp` operations are needed to generate scalable vector
masks as long as scalable vectors are not part of the standard types.
Once scalable vectors are in standard, these operations can be removed
and `std.cmp` can be used instead.
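
For illustration, here is a usage sketch patterned on the tests added in
this commit (`%a` and `%b` stand for any two `!arm_sve.vector<4xi32>`
values):

```mlir
// Element-wise comparison yielding a scalable vector of i1 (a mask).
%mask = arm_sve.cmpi sge, %a, %b : !arm_sve.vector<4xi32>
// The mask can then predicate the existing masked arithmetic operations.
%res = arm_sve.masked.subi %mask, %a, %b : !arm_sve.vector<4xi1>,
                                           !arm_sve.vector<4xi32>
```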

Differential Revision: https://reviews.llvm.org/D103473

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
    mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
    mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
    mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
    mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
    mlir/test/Dialect/ArmSVE/roundtrip.mlir
    mlir/test/Target/LLVMIR/arm-sve.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index e34177bb50940..7114fdb9425a9 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -15,6 +15,7 @@
 
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
 include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td"
 
 //===----------------------------------------------------------------------===//
@@ -398,6 +399,123 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
 
 def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
 
+//===----------------------------------------------------------------------===//
+// ScalableCmpFOp
+//===----------------------------------------------------------------------===//
+
+def ScalableCmpFOp : ArmSVE_Op<"cmpf", [NoSideEffect, SameTypeOperands,
+    TypesMatchWith<"result type has i1 element type and same shape as operands",
+    "lhs", "result", "getI1SameShape($_self)">]> {
+  let summary = "floating-point comparison operation for scalable vectors";
+  let description = [{
+    The `arm_sve.cmpf` operation compares two scalable vectors of floating point
+    elements according to the float comparison rules and the predicate specified
+    by the respective attribute. The predicate defines the type of comparison:
+    (un)orderedness, (in)equality and signed less/greater than (or equal to) as
+    well as predicates that are always true or false. The result is a scalable
+    vector of i1 elements. Unlike `arm_sve.cmpi`, the operands are always
+    treated as signed. The u prefix indicates *unordered* comparison, not
+    unsigned comparison, so "une" means unordered not equal. For the sake of
+    human readability, the custom assembly form of the operation uses a
+    string-typed attribute for the predicate. The value of this attribute
+    corresponds to the lower-cased name of the predicate constant, e.g., "one"
+    means "ordered not equal". The string representation of the attribute is
+    merely syntactic sugar and is converted to an integer attribute by the parser.
+
+    Example:
+
+    ```mlir
+    %r = arm_sve.cmpf oeq, %0, %1 : !arm_sve.vector<4xf32>
+    ```
+  }];
+  let arguments = (ins
+    CmpFPredicateAttr:$predicate,
+    ScalableVectorOf<[AnyFloat]>:$lhs,
+    ScalableVectorOf<[AnyFloat]>:$rhs // TODO: This should support a simple scalar
+  );
+  let results = (outs ScalableVectorOf<[I1]>:$result);
+
+  let builders = [
+    OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs,
+                  "Value":$rhs), [{
+      buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs);
+    }]>];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+    static CmpFPredicate getPredicateByName(StringRef name);
+
+    CmpFPredicate getPredicate() {
+      return (CmpFPredicate)(*this)->getAttrOfType<IntegerAttr>(
+          getPredicateAttrName()).getInt();
+    }
+  }];
+
+  let verifier = [{ return success(); }];
+
+  let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
+}
+
+//===----------------------------------------------------------------------===//
+// ScalableCmpIOp
+//===----------------------------------------------------------------------===//
+
+def ScalableCmpIOp : ArmSVE_Op<"cmpi", [NoSideEffect, SameTypeOperands,
+    TypesMatchWith<"result type has i1 element type and same shape as operands",
+    "lhs", "result", "getI1SameShape($_self)">]> {
+  let summary = "integer comparison operation for scalable vectors";
+  let description = [{
+    The `arm_sve.cmpi` operation compares two scalable vectors of integer
+    elements according to the predicate specified by the respective attribute.
+
+    The predicate defines the type of comparison:
+
+    -   equal (mnemonic: `"eq"`; integer value: `0`)
+    -   not equal (mnemonic: `"ne"`; integer value: `1`)
+    -   signed less than (mnemonic: `"slt"`; integer value: `2`)
+    -   signed less than or equal (mnemonic: `"sle"`; integer value: `3`)
+    -   signed greater than (mnemonic: `"sgt"`; integer value: `4`)
+    -   signed greater than or equal (mnemonic: `"sge"`; integer value: `5`)
+    -   unsigned less than (mnemonic: `"ult"`; integer value: `6`)
+    -   unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`)
+    -   unsigned greater than (mnemonic: `"ugt"`; integer value: `8`)
+    -   unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`)
+
+    Example:
+
+    ```mlir
+    %r = arm_sve.cmpi uge, %0, %1 : !arm_sve.vector<4xi32>
+    ```
+  }];
+
+  let arguments = (ins
+      CmpIPredicateAttr:$predicate,
+      ScalableVectorOf<[I8, I16, I32, I64]>:$lhs,
+      ScalableVectorOf<[I8, I16, I32, I64]>:$rhs
+  );
+  let results = (outs ScalableVectorOf<[I1]>:$result);
+
+  let builders = [
+    OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs,
+                 "Value":$rhs), [{
+      buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs);
+    }]>];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+    static CmpIPredicate getPredicateByName(StringRef name);
+
+    CmpIPredicate getPredicate() {
+      return (CmpIPredicate)(*this)->getAttrOfType<IntegerAttr>(
+          getPredicateAttrName()).getInt();
+    }
+  }];
+
+  let verifier = [{ return success(); }];
+
+  let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
+}
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,

diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
index 7689eb4b13788..06fbf5a717ea5 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
@@ -19,6 +19,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc"

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 8f863978456c9..cfcda1f24214d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -708,34 +708,6 @@ def FloorFOp : FloatUnaryOp<"floorf"> {
 // CmpFOp
 //===----------------------------------------------------------------------===//
 
-// The predicate indicates the type of the comparison to perform:
-// (un)orderedness, (in)equality and less/greater than (or equal to) as
-// well as predicates that are always true or false.
-def CMPF_P_FALSE   : I64EnumAttrCase<"AlwaysFalse", 0, "false">;
-def CMPF_P_OEQ     : I64EnumAttrCase<"OEQ", 1, "oeq">;
-def CMPF_P_OGT     : I64EnumAttrCase<"OGT", 2, "ogt">;
-def CMPF_P_OGE     : I64EnumAttrCase<"OGE", 3, "oge">;
-def CMPF_P_OLT     : I64EnumAttrCase<"OLT", 4, "olt">;
-def CMPF_P_OLE     : I64EnumAttrCase<"OLE", 5, "ole">;
-def CMPF_P_ONE     : I64EnumAttrCase<"ONE", 6, "one">;
-def CMPF_P_ORD     : I64EnumAttrCase<"ORD", 7, "ord">;
-def CMPF_P_UEQ     : I64EnumAttrCase<"UEQ", 8, "ueq">;
-def CMPF_P_UGT     : I64EnumAttrCase<"UGT", 9, "ugt">;
-def CMPF_P_UGE     : I64EnumAttrCase<"UGE", 10, "uge">;
-def CMPF_P_ULT     : I64EnumAttrCase<"ULT", 11, "ult">;
-def CMPF_P_ULE     : I64EnumAttrCase<"ULE", 12, "ule">;
-def CMPF_P_UNE     : I64EnumAttrCase<"UNE", 13, "une">;
-def CMPF_P_UNO     : I64EnumAttrCase<"UNO", 14, "uno">;
-def CMPF_P_TRUE    : I64EnumAttrCase<"AlwaysTrue", 15, "true">;
-
-def CmpFPredicateAttr : I64EnumAttr<
-    "CmpFPredicate", "",
-    [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE,
-     CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT,
-     CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> {
-  let cppNamespace = "::mlir";
-}
-
 def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
     "result type has i1 element type and same shape as operands",
@@ -801,24 +773,6 @@ def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands,
 // CmpIOp
 //===----------------------------------------------------------------------===//
 
-def CMPI_P_EQ  : I64EnumAttrCase<"eq", 0>;
-def CMPI_P_NE  : I64EnumAttrCase<"ne", 1>;
-def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>;
-def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>;
-def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>;
-def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>;
-def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>;
-def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>;
-def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>;
-def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>;
-
-def CmpIPredicateAttr : I64EnumAttr<
-    "CmpIPredicate", "",
-    [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT,
-     CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> {
-  let cppNamespace = "::mlir";
-}
-
 def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
     "result type has i1 element type and same shape as operands",

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
index 802a32fce370a..e18ae4c3be2db 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
@@ -36,4 +36,50 @@ def AtomicRMWKindAttr : I64EnumAttr<
   let cppNamespace = "::mlir";
 }
 
+// The predicate indicates the type of the comparison to perform:
+// (un)orderedness, (in)equality and less/greater than (or equal to) as
+// well as predicates that are always true or false.
+def CMPF_P_FALSE   : I64EnumAttrCase<"AlwaysFalse", 0, "false">;
+def CMPF_P_OEQ     : I64EnumAttrCase<"OEQ", 1, "oeq">;
+def CMPF_P_OGT     : I64EnumAttrCase<"OGT", 2, "ogt">;
+def CMPF_P_OGE     : I64EnumAttrCase<"OGE", 3, "oge">;
+def CMPF_P_OLT     : I64EnumAttrCase<"OLT", 4, "olt">;
+def CMPF_P_OLE     : I64EnumAttrCase<"OLE", 5, "ole">;
+def CMPF_P_ONE     : I64EnumAttrCase<"ONE", 6, "one">;
+def CMPF_P_ORD     : I64EnumAttrCase<"ORD", 7, "ord">;
+def CMPF_P_UEQ     : I64EnumAttrCase<"UEQ", 8, "ueq">;
+def CMPF_P_UGT     : I64EnumAttrCase<"UGT", 9, "ugt">;
+def CMPF_P_UGE     : I64EnumAttrCase<"UGE", 10, "uge">;
+def CMPF_P_ULT     : I64EnumAttrCase<"ULT", 11, "ult">;
+def CMPF_P_ULE     : I64EnumAttrCase<"ULE", 12, "ule">;
+def CMPF_P_UNE     : I64EnumAttrCase<"UNE", 13, "une">;
+def CMPF_P_UNO     : I64EnumAttrCase<"UNO", 14, "uno">;
+def CMPF_P_TRUE    : I64EnumAttrCase<"AlwaysTrue", 15, "true">;
+
+def CmpFPredicateAttr : I64EnumAttr<
+    "CmpFPredicate", "",
+    [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE,
+     CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT,
+     CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> {
+  let cppNamespace = "::mlir";
+}
+
+def CMPI_P_EQ  : I64EnumAttrCase<"eq", 0>;
+def CMPI_P_NE  : I64EnumAttrCase<"ne", 1>;
+def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>;
+def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>;
+def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>;
+def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>;
+def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>;
+def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>;
+def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>;
+def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>;
+
+def CmpIPredicateAttr : I64EnumAttr<
+    "CmpIPredicate", "",
+    [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT,
+     CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> {
+  let cppNamespace = "::mlir";
+}
+
 #endif // STANDARD_OPS_BASE

diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index b86ba14303f89..5e5ce6ed63bc4 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -20,8 +20,13 @@
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
+using namespace arm_sve;
 
 static Type getI1SameShape(Type type);
+static void buildScalableCmpIOp(OpBuilder &build, OperationState &result,
+                                CmpIPredicate predicate, Value lhs, Value rhs);
+static void buildScalableCmpFOp(OpBuilder &build, OperationState &result,
+                                CmpFPredicate predicate, Value lhs, Value rhs);
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
@@ -29,7 +34,7 @@ static Type getI1SameShape(Type type);
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
 
-void arm_sve::ArmSVEDialect::initialize() {
+void ArmSVEDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
@@ -44,7 +49,7 @@ void arm_sve::ArmSVEDialect::initialize() {
 // ScalableVectorType
 //===----------------------------------------------------------------------===//
 
-Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const {
+Type ArmSVEDialect::parseType(DialectAsmParser &parser) const {
   llvm::SMLoc typeLoc = parser.getCurrentLocation();
   {
     Type genType;
@@ -57,7 +62,7 @@ Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const {
   return Type();
 }
 
-void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
+void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
   if (failed(generatedTypePrinter(type, os)))
     llvm_unreachable("unexpected 'arm_sve' type kind");
 }
@@ -69,8 +74,28 @@ void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
 // Return the scalable vector of the same shape and containing i1.
 static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
-  if (auto sVectorType = type.dyn_cast<arm_sve::ScalableVectorType>())
-    return arm_sve::ScalableVectorType::get(type.getContext(),
-                                            sVectorType.getShape(), i1Type);
+  if (auto sVectorType = type.dyn_cast<ScalableVectorType>())
+    return ScalableVectorType::get(type.getContext(), sVectorType.getShape(),
+                                   i1Type);
   return nullptr;
 }
+
+//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+
+static void buildScalableCmpFOp(OpBuilder &build, OperationState &result,
+                                CmpFPredicate predicate, Value lhs, Value rhs) {
+  result.addOperands({lhs, rhs});
+  result.types.push_back(getI1SameShape(lhs.getType()));
+  result.addAttribute(ScalableCmpFOp::getPredicateAttrName(),
+                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
+}
+
+static void buildScalableCmpIOp(OpBuilder &build, OperationState &result,
+                                CmpIPredicate predicate, Value lhs, Value rhs) {
+  result.addOperands({lhs, rhs});
+  result.types.push_back(getI1SameShape(lhs.getType()));
+  result.addAttribute(ScalableCmpIOp::getPredicateAttrName(),
+                      build.getI64IntegerAttr(static_cast<int64_t>(predicate)));
+}

diff --git a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
index 9177b5889b948..4a2393e7ac3d9 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVE
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMIR
+  MLIRStandard
   MLIRSideEffectInterfaces
   )

diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 845f407fba3fb..e43511a87a47b 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -143,6 +143,24 @@ configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
   // clang-format on
 }
 
+static void
+populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
+                                        OwningRewritePatternList &patterns) {
+  // clang-format off
+  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
+               OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
+              >(converter);
+  // clang-format on
+}
+
+static void
+configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
+  // clang-format off
+  target.addIllegalOp<ScalableCmpFOp,
+                      ScalableCmpIOp>();
+  // clang-format on
+}
+
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
@@ -175,6 +193,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ScalableMaskedDivFOpLowering>(converter);
   // clang-format on
   populateBasicSVEArithmeticExportPatterns(converter, patterns);
+  populateSVEMaskGenerationExportPatterns(converter, patterns);
 }
 
 void mlir::configureArmSVELegalizeForExportTarget(
@@ -225,4 +244,5 @@ void mlir::configureArmSVELegalizeForExportTarget(
                !hasScalableVectorType(op->getResultTypes());
       });
   configureBasicSVEArithmeticLegalizations(target);
+  configureSVEMaskGenerationLegalizations(target);
 }

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a9c88afec9533..512a0ab898e2d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -134,11 +134,15 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
   if (!isCompatibleType(type))
     return parser.emitError(trailingTypeLoc,
                             "expected LLVM dialect-compatible type");
-  if (LLVM::isCompatibleVectorType(type))
-    resultType = LLVM::getFixedVectorType(
-        resultType, LLVM::getVectorNumElements(type).getFixedValue());
-  assert(!type.isa<LLVM::LLVMScalableVectorType>() &&
-         "unhandled scalable vector");
+  if (LLVM::isCompatibleVectorType(type)) {
+    if (type.isa<LLVM::LLVMScalableVectorType>()) {
+      resultType = LLVM::LLVMScalableVectorType::get(
+          resultType, LLVM::getVectorNumElements(type).getKnownMinValue());
+    } else {
+      resultType = LLVM::getFixedVectorType(
+          resultType, LLVM::getVectorNumElements(type).getFixedValue());
+    }
+  }
 
   result.addTypes({resultType});
   return success();

diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 2b2eda0bf32e7..7fd2a63e63df8 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -121,6 +121,46 @@ func @arm_sve_arithf_masked(%a: !arm_sve.vector<4xf32>,
   return %3 : !arm_sve.vector<4xf32>
 }
 
+func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>,
+                        %b: !arm_sve.vector<4xf32>)
+                        -> !arm_sve.vector<4xi1> {
+  // CHECK: llvm.fcmp "oeq" {{.*}}: !llvm.vec<? x 4 x f32>
+  %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32>
+  return %0 : !arm_sve.vector<4xi1>
+}
+
+func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
+                        %b: !arm_sve.vector<4xi32>)
+                        -> !arm_sve.vector<4xi1> {
+  // CHECK: llvm.icmp "uge" {{.*}}: !llvm.vec<? x 4 x i32>
+  %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi1>
+}
+
+func @arm_sve_abs_diff(%a: !arm_sve.vector<4xi32>,
+                       %b: !arm_sve.vector<4xi32>)
+                       -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
+  %z = arm_sve.subi %a, %a : !arm_sve.vector<4xi32>
+  // CHECK: llvm.icmp "sge" {{.*}}: !llvm.vec<? x 4 x i32>
+  %agb = arm_sve.cmpi sge, %a, %b : !arm_sve.vector<4xi32>
+  // CHECK: llvm.icmp "slt" {{.*}}: !llvm.vec<? x 4 x i32>
+  %bga = arm_sve.cmpi slt, %a, %b : !arm_sve.vector<4xi32>
+  // CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+  %0 = arm_sve.masked.subi %agb, %a, %b : !arm_sve.vector<4xi1>,
+                                          !arm_sve.vector<4xi32>
+  // CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+  %1 = arm_sve.masked.subi %bga, %b, %a : !arm_sve.vector<4xi1>,
+                                          !arm_sve.vector<4xi32>
+  // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+  %2 = arm_sve.masked.addi %agb, %z, %0 : !arm_sve.vector<4xi1>,
+                                          !arm_sve.vector<4xi32>
+  // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+  %3 = arm_sve.masked.addi %bga, %2, %1 : !arm_sve.vector<4xi1>,
+                                          !arm_sve.vector<4xi32>
+  return %3 : !arm_sve.vector<4xi32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vscale
   %0 = arm_sve.vector_scale : index

diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 4666d16f33f24..2dde6c32a665e 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -103,6 +103,22 @@ func @arm_sve_masked_arithf(%a: !arm_sve.vector<4xf32>,
   return %3 : !arm_sve.vector<4xf32>
 }
 
+func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>,
+                        %b: !arm_sve.vector<4xf32>)
+                        -> !arm_sve.vector<4xi1> {
+  // CHECK: arm_sve.cmpf oeq, {{.*}}: !arm_sve.vector<4xf32>
+  %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32>
+  return %0 : !arm_sve.vector<4xi1>
+}
+
+func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
+                        %b: !arm_sve.vector<4xi32>)
+                        -> !arm_sve.vector<4xi1> {
+  // CHECK: arm_sve.cmpi uge, {{.*}}: !arm_sve.vector<4xi32>
+  %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi1>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vector_scale : index
   %0 = arm_sve.vector_scale : index

diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index cf367904f8996..1e857275ec0f3 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -139,6 +139,57 @@ llvm.func @arm_sve_arithf_masked(%arg0: !llvm.vec<? x 4 x f32>,
   llvm.return %3 : !llvm.vec<? x 4 x f32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_genf
+llvm.func @arm_sve_mask_genf(%arg0: !llvm.vec<? x 4 x f32>,
+                             %arg1: !llvm.vec<? x 4 x f32>)
+                             -> !llvm.vec<? x 4 x i1> {
+  // CHECK: fcmp oeq <vscale x 4 x float>
+  %0 = llvm.fcmp "oeq" %arg0, %arg1 : !llvm.vec<? x 4 x f32>
+  llvm.return %0 : !llvm.vec<? x 4 x i1>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_geni
+llvm.func @arm_sve_mask_geni(%arg0: !llvm.vec<? x 4 x i32>,
+                             %arg1: !llvm.vec<? x 4 x i32>)
+                             -> !llvm.vec<? x 4 x i1> {
+  // CHECK: icmp uge <vscale x 4 x i32>
+  %0 = llvm.icmp "uge" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+  llvm.return %0 : !llvm.vec<? x 4 x i1>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_abs_diff
+llvm.func @arm_sve_abs_diff(%arg0: !llvm.vec<? x 4 x i32>,
+                            %arg1: !llvm.vec<? x 4 x i32>)
+                            -> !llvm.vec<? x 4 x i32> {
+  // CHECK: sub <vscale x 4 x i32>
+  %0 = llvm.sub %arg0, %arg0  : !llvm.vec<? x 4 x i32>
+  // CHECK: icmp sge <vscale x 4 x i32>
+  %1 = llvm.icmp "sge" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+  // CHECK: icmp slt <vscale x 4 x i32>
+  %2 = llvm.icmp "slt" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
+  %3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
+                                               !llvm.vec<? x 4 x i32>,
+                                               !llvm.vec<? x 4 x i32>)
+                                               -> !llvm.vec<? x 4 x i32>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
+  %4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (!llvm.vec<? x 4 x i1>,
+                                               !llvm.vec<? x 4 x i32>,
+                                               !llvm.vec<? x 4 x i32>)
+                                               -> !llvm.vec<? x 4 x i32>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
+  %5 = "arm_sve.intr.add"(%1, %0, %3) : (!llvm.vec<? x 4 x i1>,
+                                         !llvm.vec<? x 4 x i32>,
+                                         !llvm.vec<? x 4 x i32>)
+                                         -> !llvm.vec<? x 4 x i32>
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
+  %6 = "arm_sve.intr.add"(%2, %5, %4) : (!llvm.vec<? x 4 x i1>,
+                                         !llvm.vec<? x 4 x i32>,
+                                         !llvm.vec<? x 4 x i32>)
+                                         -> !llvm.vec<? x 4 x i32>
+  llvm.return %6 : !llvm.vec<? x 4 x i32>
+}
+
 // CHECK-LABEL: define i64 @get_vector_scale()
 llvm.func @get_vector_scale() -> i64 {
   // CHECK: call i64 @llvm.vscale.i64()

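In summary, the lowering exercised by the tests above is one-to-one: each new
comparison op maps onto the corresponding LLVM dialect comparison, which in
turn translates to a plain `icmp`/`fcmp` on a scalable vector in LLVM IR. A
sketch of the two MLIR stages for the integer case (`%a` and `%b` are
placeholders):

```mlir
// ArmSVE dialect, before legalization for LLVM export:
%m = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32>

// LLVM dialect, after legalization; this translates to
// `icmp uge <vscale x 4 x i32>` in LLVM IR:
%m = llvm.icmp "uge" %a, %b : !llvm.vec<? x 4 x i32>
```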