[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Oct 19 11:05:47 PDT 2023
================
@@ -507,4 +524,95 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
}];
}
+class HasMatchingMaskTypeConstraint<string operand, string maskGetter> : // Relates operand `<operand>` to its companion `<operand>Mask`; `maskGetter` names the accessor used below to test for the mask.
+ TypesMatchWith<
+ "shape of `" # operand # "Mask` matches `" # operand # "`", // Human-readable description of the constraint, built from the operand name.
+ operand, operand # "Mask", // The two names being related: `<operand>` and `<operand>Mask`.
+ "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))", // Type derived from `<operand>`: its VectorType cloned with a width-1 IntegerType (NOTE(review): the empty first argument presumably keeps the original shape -- confirm).
+ "!" # maskGetter # "() || std::equal_to<>()">; // Comparator text; NOTE(review): presumably expands so the equality check is skipped when the mask getter returns null -- confirm against TypesMatchWith.
+
+def OuterProductOp :
+ ArmSME_Op<"outerproduct", [Pure,
+ AttrSizedOperandSegments,
+ AllElementTypesMatch<["lhs", "rhs", "result"]>,
+ HasMatchingMaskTypeConstraint<"lhs", "getLhsMask">,
+ HasMatchingMaskTypeConstraint<"rhs", "getRhsMask">,
+ PredOpTrait<
+ "both `lhsMask` and `rhsMask` should be provided or neither",
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
+ PredOpTrait<
+ "result type is derived from `lhs` and `rhs`",
+ CPred<
+ "getResultType() == VectorType::get({"
+ "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
+ "getRhsType().getElementType(),"
+ "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
+ TypesMatchWith<"`result` and `acc` have the same type",
+ "result", "acc",
+ "::llvm::cast<mlir::Type>($_self)",
+ "!getAcc() || std::equal_to<>()">
----------------
MacDue wrote:
...and here
https://github.com/llvm/llvm-project/pull/69604
More information about the Mlir-commits
mailing list