[Mlir-commits] [mlir] f880bd2 - [mlir][ArmSVE] Add basic mask generation operations
Javier Setoain
llvmlistbot at llvm.org
Wed Jun 9 01:58:58 PDT 2021
Author: Javier Setoain
Date: 2021-06-09T09:56:53+01:00
New Revision: f880bd261f4e13d4d58a75886a2942f05783c7de
URL: https://github.com/llvm/llvm-project/commit/f880bd261f4e13d4d58a75886a2942f05783c7de
DIFF: https://github.com/llvm/llvm-project/commit/f880bd261f4e13d4d58a75886a2942f05783c7de.diff
LOG: [mlir][ArmSVE] Add basic mask generation operations
These `arm_sve.cmpf` and `arm_sve.cmpi` operations are needed to generate
scalable vector masks for as long as scalable vectors are not part of the
standard types. Once they are, these operations can be removed and
`std.cmpf`/`std.cmpi` can be used instead.
Differential Revision: https://reviews.llvm.org/D103473
Added:
Modified:
mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
mlir/test/Dialect/ArmSVE/roundtrip.mlir
mlir/test/Target/LLVMIR/arm-sve.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index e34177bb50940..7114fdb9425a9 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -15,6 +15,7 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td"
//===----------------------------------------------------------------------===//
@@ -398,6 +399,123 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
+//===----------------------------------------------------------------------===//
+// ScalableCmpFOp
+//===----------------------------------------------------------------------===//
+
+// Scalable-vector counterpart of `std.cmpf`: it reuses the CmpFPredicate enum
+// from StandardOpsBase.td and is lowered one-to-one to `llvm.fcmp`.
+def ScalableCmpFOp : ArmSVE_Op<"cmpf", [NoSideEffect, SameTypeOperands,
+    TypesMatchWith<"result type has i1 element type and same shape as operands",
+    "lhs", "result", "getI1SameShape($_self)">]> {
+  let summary = "floating-point comparison operation for scalable vectors";
+  let description = [{
+    The `arm_sve.cmpf` operation compares two scalable vectors of floating point
+    elements according to the float comparison rules and the predicate specified
+    by the respective attribute. The predicate defines the type of comparison:
+    (un)orderedness, (in)equality and less/greater than (or equal to) as well as
+    predicates that are always true or false. The result is a scalable vector of
+    i1 elements with the same shape as the operands. Unlike `arm_sve.cmpi`, there
+    is no signed/unsigned distinction: the u prefix indicates *unordered*
+    comparison, not unsigned comparison, so "une" means unordered not equal. For
+    the sake of readability by humans, the custom assembly form spells the
+    predicate as the lower-cased name of the predicate constant, e.g., "one"
+    means "ordered not equal". This spelling is merely syntactic sugar: the
+    parser converts it to an integer attribute.
+
+    Example:
+
+    ```mlir
+    %r = arm_sve.cmpf oeq, %0, %1 : !arm_sve.vector<4xf32>
+    ```
+  }];
+  let arguments = (ins
+    CmpFPredicateAttr:$predicate,
+    ScalableVectorOf<[AnyFloat]>:$lhs,
+    // TODO: Support a scalar rhs; this requires relaxing SameTypeOperands.
+    ScalableVectorOf<[AnyFloat]>:$rhs
+  );
+  let results = (outs ScalableVectorOf<[I1]>:$result);
+
+  let builders = [
+    OpBuilder<(ins "CmpFPredicate":$predicate, "Value":$lhs,
+                   "Value":$rhs), [{
+      buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs);
+    }]>];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+
+    // Map a predicate mnemonic (e.g. "oeq") to its enum value. The name must
+    // be one of the CmpFPredicate mnemonics; this is asserted.
+    static CmpFPredicate getPredicateByName(StringRef name) {
+      auto pred = symbolizeCmpFPredicate(name);
+      assert(pred && "invalid cmpf predicate name");
+      return *pred;
+    }
+
+    CmpFPredicate getPredicate() {
+      return (CmpFPredicate)(*this)->getAttrOfType<IntegerAttr>(
+          getPredicateAttrName()).getInt();
+    }
+  }];
+
+  // Operand and result types are checked by the type constraints and traits
+  // above; there is nothing else to verify.
+  let verifier = [{ return success(); }];
+
+  let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
+}
+
+//===----------------------------------------------------------------------===//
+// ScalableCmpIOp
+//===----------------------------------------------------------------------===//
+
+// Scalable-vector counterpart of `std.cmpi`: it reuses the CmpIPredicate enum
+// from StandardOpsBase.td and is lowered one-to-one to `llvm.icmp`.
+def ScalableCmpIOp : ArmSVE_Op<"cmpi", [NoSideEffect, SameTypeOperands,
+    TypesMatchWith<"result type has i1 element type and same shape as operands",
+    "lhs", "result", "getI1SameShape($_self)">]> {
+  let summary = "integer comparison operation for scalable vectors";
+  let description = [{
+    The `arm_sve.cmpi` operation compares two scalable vectors of integer
+    elements according to the predicate specified by the respective attribute.
+    The result is a scalable vector of i1 elements with the same shape as the
+    operands.
+
+    The predicate defines the type of comparison:
+
+    - equal (mnemonic: `"eq"`; integer value: `0`)
+    - not equal (mnemonic: `"ne"`; integer value: `1`)
+    - signed less than (mnemonic: `"slt"`; integer value: `2`)
+    - signed less than or equal (mnemonic: `"sle"`; integer value: `3`)
+    - signed greater than (mnemonic: `"sgt"`; integer value: `4`)
+    - signed greater than or equal (mnemonic: `"sge"`; integer value: `5`)
+    - unsigned less than (mnemonic: `"ult"`; integer value: `6`)
+    - unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`)
+    - unsigned greater than (mnemonic: `"ugt"`; integer value: `8`)
+    - unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`)
+
+    Example:
+
+    ```mlir
+    %r = arm_sve.cmpi uge, %0, %1 : !arm_sve.vector<4xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    CmpIPredicateAttr:$predicate,
+    ScalableVectorOf<[I8, I16, I32, I64]>:$lhs,
+    ScalableVectorOf<[I8, I16, I32, I64]>:$rhs
+  );
+  let results = (outs ScalableVectorOf<[I1]>:$result);
+
+  let builders = [
+    OpBuilder<(ins "CmpIPredicate":$predicate, "Value":$lhs,
+                   "Value":$rhs), [{
+      buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs);
+    }]>];
+
+  let extraClassDeclaration = [{
+    static StringRef getPredicateAttrName() { return "predicate"; }
+
+    // Map a predicate mnemonic (e.g. "uge") to its enum value. The name must
+    // be one of the CmpIPredicate mnemonics; this is asserted.
+    static CmpIPredicate getPredicateByName(StringRef name) {
+      auto pred = symbolizeCmpIPredicate(name);
+      assert(pred && "invalid cmpi predicate name");
+      return *pred;
+    }
+
+    CmpIPredicate getPredicate() {
+      return (CmpIPredicate)(*this)->getAttrOfType<IntegerAttr>(
+          getPredicateAttrName()).getInt();
+    }
+  }];
+
+  // Operand and result types are checked by the type constraints and traits
+  // above; there is nothing else to verify.
+  let verifier = [{ return success(); }];
+
+  let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
+}
+
def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
index 7689eb4b13788..06fbf5a717ea5 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h
@@ -19,6 +19,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc"
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 8f863978456c9..cfcda1f24214d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -708,34 +708,6 @@ def FloorFOp : FloatUnaryOp<"floorf"> {
// CmpFOp
//===----------------------------------------------------------------------===//
-// The predicate indicates the type of the comparison to perform:
-// (un)orderedness, (in)equality and less/greater than (or equal to) as
-// well as predicates that are always true or false.
-def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">;
-def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">;
-def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">;
-def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">;
-def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">;
-def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">;
-def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">;
-def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">;
-def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">;
-def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">;
-def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">;
-def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">;
-def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">;
-def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">;
-def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">;
-def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">;
-
-def CmpFPredicateAttr : I64EnumAttr<
- "CmpFPredicate", "",
- [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE,
- CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT,
- CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> {
- let cppNamespace = "::mlir";
-}
-
def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
"result type has i1 element type and same shape as operands",
@@ -801,24 +773,6 @@ def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands,
// CmpIOp
//===----------------------------------------------------------------------===//
-def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>;
-def CMPI_P_NE : I64EnumAttrCase<"ne", 1>;
-def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>;
-def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>;
-def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>;
-def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>;
-def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>;
-def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>;
-def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>;
-def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>;
-
-def CmpIPredicateAttr : I64EnumAttr<
- "CmpIPredicate", "",
- [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT,
- CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> {
- let cppNamespace = "::mlir";
-}
-
def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
"result type has i1 element type and same shape as operands",
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
index 802a32fce370a..e18ae4c3be2db 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
@@ -36,4 +36,50 @@ def AtomicRMWKindAttr : I64EnumAttr<
let cppNamespace = "::mlir";
}
+// The predicate indicates the type of the comparison to perform:
+// (un)orderedness, (in)equality and less/greater than (or equal to) as
+// well as predicates that are always true or false.
+// Shared by `std.cmpf` and `arm_sve.cmpf`; it lives in this base file so that
+// ArmSVE.td can reuse it by including only StandardOpsBase.td. The integer
+// value of each case is what the ops store in their I64 `predicate` attribute.
+// NOTE(review): arm_sve.cmpf is lowered to llvm.fcmp via
+// OneToOneConvertToLLVMPattern; if that pattern forwards this attribute
+// unchanged, these values must stay in sync with the LLVM dialect's fcmp
+// predicate numbering — confirm.
+def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">;
+def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">;
+def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">;
+def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">;
+def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">;
+def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">;
+def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">;
+def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">;
+def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">;
+def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">;
+def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">;
+def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">;
+def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">;
+def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">;
+def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">;
+def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">;
+
+def CmpFPredicateAttr : I64EnumAttr<
+ "CmpFPredicate", "",
+ [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE,
+ CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT,
+ CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> {
+ let cppNamespace = "::mlir";
+}
+
+// Integer comparison predicates shared by `std.cmpi` and `arm_sve.cmpi`. The
+// "s"/"u" prefixes select signed/unsigned ordering; "eq"/"ne" do not depend
+// on signedness. Each case's integer value is what the ops store in their I64
+// `predicate` attribute.
+// NOTE(review): arm_sve.cmpi is lowered to llvm.icmp via
+// OneToOneConvertToLLVMPattern; if that pattern forwards this attribute
+// unchanged, these values must stay in sync with the LLVM dialect's icmp
+// predicate numbering — confirm.
+def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>;
+def CMPI_P_NE : I64EnumAttrCase<"ne", 1>;
+def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>;
+def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>;
+def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>;
+def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>;
+def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>;
+def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>;
+def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>;
+def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>;
+
+def CmpIPredicateAttr : I64EnumAttr<
+ "CmpIPredicate", "",
+ [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT,
+ CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> {
+ let cppNamespace = "::mlir";
+}
+
#endif // STANDARD_OPS_BASE
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index b86ba14303f89..5e5ce6ed63bc4 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -20,8 +20,13 @@
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
+using namespace arm_sve;
+// Helpers referenced from the ODS-generated op classes included below
+// (ArmSVE.cpp.inc): getI1SameShape by the TypesMatchWith constraints and the
+// build* functions by the custom comparison-op builders. They are defined
+// later in this file.
static Type getI1SameShape(Type type);
+static void buildScalableCmpIOp(OpBuilder &build, OperationState &result,
+ CmpIPredicate predicate, Value lhs, Value rhs);
+static void buildScalableCmpFOp(OpBuilder &build, OperationState &result,
+ CmpFPredicate predicate, Value lhs, Value rhs);
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
@@ -29,7 +34,7 @@ static Type getI1SameShape(Type type);
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
-void arm_sve::ArmSVEDialect::initialize() {
+void ArmSVEDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
@@ -44,7 +49,7 @@ void arm_sve::ArmSVEDialect::initialize() {
// ScalableVectorType
//===----------------------------------------------------------------------===//
-Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const {
+Type ArmSVEDialect::parseType(DialectAsmParser &parser) const {
llvm::SMLoc typeLoc = parser.getCurrentLocation();
{
Type genType;
@@ -57,7 +62,7 @@ Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const {
return Type();
}
+// Print an ArmSVE dialect type using the ODS-generated type printer; a type
+// that printer does not handle is treated as unreachable.
-void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
+void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
if (failed(generatedTypePrinter(type, os)))
llvm_unreachable("unexpected 'arm_sve' type kind");
}
@@ -69,8 +74,28 @@ void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
// Return a scalable vector type with the same shape as `type` and i1 elements,
// or a null Type if `type` is not an arm_sve ScalableVectorType. Referenced
// from the TypesMatchWith constraints in ArmSVE.td and from the comparison-op
// builders later in this file.
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
- if (auto sVectorType = type.dyn_cast<arm_sve::ScalableVectorType>())
- return arm_sve::ScalableVectorType::get(type.getContext(),
- sVectorType.getShape(), i1Type);
+ if (auto sVectorType = type.dyn_cast<ScalableVectorType>())
+ return ScalableVectorType::get(type.getContext(), sVectorType.getShape(),
+ i1Type);
return nullptr;
}
+
+//===----------------------------------------------------------------------===//
+// CmpFOp
+//===----------------------------------------------------------------------===//
+
+// Populate `state` for an arm_sve.cmpf: the two operands, an i1 scalable-vector
+// result shaped like `lhs`, and the predicate as a 64-bit integer attribute.
+static void buildScalableCmpFOp(OpBuilder &builder, OperationState &state,
+                                CmpFPredicate predicate, Value lhs, Value rhs) {
+  Type maskType = getI1SameShape(lhs.getType());
+  IntegerAttr predicateAttr =
+      builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+  state.addOperands({lhs, rhs});
+  state.addTypes(maskType);
+  state.addAttribute(ScalableCmpFOp::getPredicateAttrName(), predicateAttr);
+}
+
+// Populate `state` for an arm_sve.cmpi: the two operands, an i1 scalable-vector
+// result shaped like `lhs`, and the predicate as a 64-bit integer attribute.
+static void buildScalableCmpIOp(OpBuilder &builder, OperationState &state,
+                                CmpIPredicate predicate, Value lhs, Value rhs) {
+  Type maskType = getI1SameShape(lhs.getType());
+  IntegerAttr predicateAttr =
+      builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
+  state.addOperands({lhs, rhs});
+  state.addTypes(maskType);
+  state.addAttribute(ScalableCmpIOp::getPredicateAttrName(), predicateAttr);
+}
diff --git a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
index 9177b5889b948..4a2393e7ac3d9 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSVE
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMIR
+ MLIRStandard
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 845f407fba3fb..e43511a87a47b 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -143,6 +143,24 @@ configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
// clang-format on
}
+// Register the one-to-one lowerings of the ArmSVE mask-generation ops
+// (arm_sve.cmpf / arm_sve.cmpi) to the LLVM dialect comparison ops.
+static void
+populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
+                                        OwningRewritePatternList &patterns) {
+  using CmpFLowering =
+      OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>;
+  using CmpILowering =
+      OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>;
+  patterns.add<CmpFLowering, CmpILowering>(converter);
+}
+
+// Mark the ArmSVE comparison ops illegal so the conversion has to rewrite them
+// with the mask-generation lowering patterns.
+static void
+configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
+  target.addIllegalOp<ScalableCmpFOp>();
+  target.addIllegalOp<ScalableCmpIOp>();
+}
+
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
@@ -175,6 +193,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ScalableMaskedDivFOpLowering>(converter);
// clang-format on
populateBasicSVEArithmeticExportPatterns(converter, patterns);
+ populateSVEMaskGenerationExportPatterns(converter, patterns);
}
void mlir::configureArmSVELegalizeForExportTarget(
@@ -225,4 +244,5 @@ void mlir::configureArmSVELegalizeForExportTarget(
!hasScalableVectorType(op->getResultTypes());
});
configureBasicSVEArithmeticLegalizations(target);
+ configureSVEMaskGenerationLegalizations(target);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a9c88afec9533..512a0ab898e2d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -134,11 +134,15 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
if (!isCompatibleType(type))
return parser.emitError(trailingTypeLoc,
"expected LLVM dialect-compatible type");
- if (LLVM::isCompatibleVectorType(type))
- resultType = LLVM::getFixedVectorType(
- resultType, LLVM::getVectorNumElements(type).getFixedValue());
- assert(!type.isa<LLVM::LLVMScalableVectorType>() &&
- "unhandled scalable vector");
+ if (LLVM::isCompatibleVectorType(type)) {
+ if (type.isa<LLVM::LLVMScalableVectorType>()) {
+ resultType = LLVM::LLVMScalableVectorType::get(
+ resultType, LLVM::getVectorNumElements(type).getKnownMinValue());
+ } else {
+ resultType = LLVM::getFixedVectorType(
+ resultType, LLVM::getVectorNumElements(type).getFixedValue());
+ }
+ }
result.addTypes({resultType});
return success();
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 2b2eda0bf32e7..7fd2a63e63df8 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -121,6 +121,46 @@ func @arm_sve_arithf_masked(%a: !arm_sve.vector<4xf32>,
return %3 : !arm_sve.vector<4xf32>
}
+// Verify that arm_sve.cmpf is legalized to llvm.fcmp with the predicate kept.
+func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>,
+ %b: !arm_sve.vector<4xf32>)
+ -> !arm_sve.vector<4xi1> {
+ // CHECK: llvm.fcmp "oeq" {{.*}}: !llvm.vec<? x 4 x f32>
+ %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32>
+ return %0 : !arm_sve.vector<4xi1>
+}
+
+// Verify that arm_sve.cmpi is legalized to llvm.icmp with the predicate kept.
+func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
+ %b: !arm_sve.vector<4xi32>)
+ -> !arm_sve.vector<4xi1> {
+ // CHECK: llvm.icmp "uge" {{.*}}: !llvm.vec<? x 4 x i32>
+ %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi1>
+}
+
+// Verify that masks produced by arm_sve.cmpi feed the masked arithmetic ops.
+// The function builds |a - b| from two complementary masks (a >= b, a < b),
+// starting from a zero vector %z = a - a.
+// NOTE(review): the result relies on inactive lanes of the masked ops keeping
+// their first operand — confirm against the SVE intrinsic semantics.
+func @arm_sve_abs_diff(%a: !arm_sve.vector<4xi32>,
+ %b: !arm_sve.vector<4xi32>)
+ -> !arm_sve.vector<4xi32> {
+ // CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
+ %z = arm_sve.subi %a, %a : !arm_sve.vector<4xi32>
+ // CHECK: llvm.icmp "sge" {{.*}}: !llvm.vec<? x 4 x i32>
+ %agb = arm_sve.cmpi sge, %a, %b : !arm_sve.vector<4xi32>
+ // CHECK: llvm.icmp "slt" {{.*}}: !llvm.vec<? x 4 x i32>
+ %bga = arm_sve.cmpi slt, %a, %b : !arm_sve.vector<4xi32>
+ // CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+ %0 = arm_sve.masked.subi %agb, %a, %b : !arm_sve.vector<4xi1>,
+ !arm_sve.vector<4xi32>
+ // CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+ %1 = arm_sve.masked.subi %bga, %b, %a : !arm_sve.vector<4xi1>,
+ !arm_sve.vector<4xi32>
+ // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+ %2 = arm_sve.masked.addi %agb, %z, %0 : !arm_sve.vector<4xi1>,
+ !arm_sve.vector<4xi32>
+ // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x i32>, !llvm.vec<? x 4 x i32>) -> !llvm.vec<? x 4 x i32>
+ %3 = arm_sve.masked.addi %bga, %2, %1 : !arm_sve.vector<4xi1>,
+ !arm_sve.vector<4xi32>
+ return %3 : !arm_sve.vector<4xi32>
+}
+
func @get_vector_scale() -> index {
// CHECK: arm_sve.vscale
%0 = arm_sve.vector_scale : index
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 4666d16f33f24..2dde6c32a665e 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -103,6 +103,22 @@ func @arm_sve_masked_arithf(%a: !arm_sve.vector<4xf32>,
return %3 : !arm_sve.vector<4xf32>
}
+// Verify that arm_sve.cmpf round-trips through its custom assembly form.
+func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>,
+ %b: !arm_sve.vector<4xf32>)
+ -> !arm_sve.vector<4xi1> {
+ // CHECK: arm_sve.cmpf oeq, {{.*}}: !arm_sve.vector<4xf32>
+ %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32>
+ return %0 : !arm_sve.vector<4xi1>
+}
+
+// Verify that arm_sve.cmpi round-trips through its custom assembly form.
+func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>,
+ %b: !arm_sve.vector<4xi32>)
+ -> !arm_sve.vector<4xi1> {
+ // CHECK: arm_sve.cmpi uge, {{.*}}: !arm_sve.vector<4xi32>
+ %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32>
+ return %0 : !arm_sve.vector<4xi1>
+}
+
func @get_vector_scale() -> index {
// CHECK: arm_sve.vector_scale : index
%0 = arm_sve.vector_scale : index
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index cf367904f8996..1e857275ec0f3 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -139,6 +139,57 @@ llvm.func @arm_sve_arithf_masked(%arg0: !llvm.vec<? x 4 x f32>,
llvm.return %3 : !llvm.vec<? x 4 x f32>
}
+// Verify that llvm.fcmp on a scalable vector translates to an LLVM IR fcmp
+// producing a <vscale x 4 x i1> mask.
+// CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_genf
+llvm.func @arm_sve_mask_genf(%arg0: !llvm.vec<? x 4 x f32>,
+ %arg1: !llvm.vec<? x 4 x f32>)
+ -> !llvm.vec<? x 4 x i1> {
+ // CHECK: fcmp oeq <vscale x 4 x float>
+ %0 = llvm.fcmp "oeq" %arg0, %arg1 : !llvm.vec<? x 4 x f32>
+ llvm.return %0 : !llvm.vec<? x 4 x i1>
+}
+
+// Verify that llvm.icmp on a scalable vector translates to an LLVM IR icmp
+// producing a <vscale x 4 x i1> mask.
+// CHECK-LABEL: define <vscale x 4 x i1> @arm_sve_mask_geni
+llvm.func @arm_sve_mask_geni(%arg0: !llvm.vec<? x 4 x i32>,
+ %arg1: !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i1> {
+ // CHECK: icmp uge <vscale x 4 x i32>
+ %0 = llvm.icmp "uge" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+ llvm.return %0 : !llvm.vec<? x 4 x i1>
+}
+
+// Verify translation of scalable-vector compares feeding the SVE predicated
+// sub/add intrinsics (the |a - b| pattern from the legalization test).
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_abs_diff
+llvm.func @arm_sve_abs_diff(%arg0: !llvm.vec<? x 4 x i32>,
+ %arg1: !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32> {
+ // CHECK: sub <vscale x 4 x i32>
+ %0 = llvm.sub %arg0, %arg0 : !llvm.vec<? x 4 x i32>
+ // CHECK: icmp sge <vscale x 4 x i32>
+ %1 = llvm.icmp "sge" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+ // CHECK: icmp slt <vscale x 4 x i32>
+ %2 = llvm.icmp "slt" %arg0, %arg1 : !llvm.vec<? x 4 x i32>
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
+ %3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (!llvm.vec<? x 4 x i1>,
+ !llvm.vec<? x 4 x i32>,
+ !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32>
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sub.nxv4i32
+ %4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (!llvm.vec<? x 4 x i1>,
+ !llvm.vec<? x 4 x i32>,
+ !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32>
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
+ %5 = "arm_sve.intr.add"(%1, %0, %3) : (!llvm.vec<? x 4 x i1>,
+ !llvm.vec<? x 4 x i32>,
+ !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32>
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.add.nxv4i32
+ %6 = "arm_sve.intr.add"(%2, %5, %4) : (!llvm.vec<? x 4 x i1>,
+ !llvm.vec<? x 4 x i32>,
+ !llvm.vec<? x 4 x i32>)
+ -> !llvm.vec<? x 4 x i32>
+ llvm.return %6 : !llvm.vec<? x 4 x i32>
+}
+
// CHECK-LABEL: define i64 @get_vector_scale()
llvm.func @get_vector_scale() -> i64 {
// CHECK: call i64 @llvm.vscale.i64()
More information about the Mlir-commits
mailing list