[Mlir-commits] [mlir] [mlir][linalg] Extend elementwise (PR #124661)
Rolf Morel
llvmlistbot at llvm.org
Mon Feb 3 07:22:22 PST 2025
================
@@ -3611,5 +3620,260 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// ElementwiseOp
+//===----------------------------------------------------------------------===//
+//
+namespace {
+
+/// Pairs an elementwise arity category with the concrete per-arity function
+/// enum value. `category` selects which member of the `fn` union is the
+/// active one; readers must switch on `category` before touching `fn`.
+struct NAryCategoryAndFn {
+  // The enum category class {Unary, Binary, Ternary, ..}
+  ElementwiseNAryCategory category;
+
+  // Exactly one member is meaningful, as selected by `category` above.
+  union NAryFn {
+    UnaryFn unaryFn;
+    BinaryFn binaryFn;
+    TernaryFn ternaryFn;
+  } fn;
+
+  /// Returns the printable name of `category`. Read-only, hence `const`.
+  ::llvm::StringRef stringifyCategory() const {
+    return stringifyElementwiseNAryCategory(category);
+  }
+
+  /// Returns the printable name of the function stored in `fn`, reading the
+  /// union member that corresponds to `category`. Read-only, hence `const`.
+  ::llvm::StringRef stringifyFn() const {
+    switch (category) {
+    case ElementwiseNAryCategory::Unary:
+      return stringifyUnaryFn(fn.unaryFn);
+    case ElementwiseNAryCategory::Binary:
+      return stringifyBinaryFn(fn.binaryFn);
+    case ElementwiseNAryCategory::Ternary:
+      return stringifyTernaryFn(fn.ternaryFn);
+    }
+    llvm_unreachable("unknown-fn");
+  }
+};
+
+/// Returns the number of input operands taken by functions in `category`:
+/// 1 for Unary, 2 for Binary, 3 for Ternary.
+unsigned getArityFromCategory(ElementwiseNAryCategory category) {
+  switch (category) {
+  case ElementwiseNAryCategory::Unary:
+    return 1;
+  case ElementwiseNAryCategory::Binary:
+    return 2;
+  case ElementwiseNAryCategory::Ternary:
+    return 3;
+  }
+  llvm_unreachable("unhandled category");
+}
+} // namespace
+
+/// Decodes a flat `ElementwiseFn` value into its arity category plus the
+/// matching per-arity enum (`UnaryFn`, `BinaryFn` or `TernaryFn`).
+///
+/// The `ElementwiseFnLimits` values act as exclusive upper bounds of each
+/// arity range in the flattened enumeration. A binary value is therefore
+/// re-based by subtracting `LastUnary`, and a ternary value by subtracting
+/// `LastBinary`. Values at or beyond `LastTernary` are not handled.
+static NAryCategoryAndFn getNAryCategoryAndFn(ElementwiseFn fn) {
+  constexpr int unaryEnd = static_cast<int>(ElementwiseFnLimits::LastUnary);
+  constexpr int binaryEnd = static_cast<int>(ElementwiseFnLimits::LastBinary);
+  constexpr int ternaryEnd =
+      static_cast<int>(ElementwiseFnLimits::LastTernary);
+
+  const int flatIndex = static_cast<int>(fn);
+  NAryCategoryAndFn decoded;
+
+  if (flatIndex < unaryEnd) {
+    decoded.category = ElementwiseNAryCategory::Unary;
+    decoded.fn.unaryFn = static_cast<UnaryFn>(flatIndex);
+  } else if (flatIndex < binaryEnd) {
+    decoded.category = ElementwiseNAryCategory::Binary;
+    decoded.fn.binaryFn = static_cast<BinaryFn>(flatIndex - unaryEnd);
+  } else if (flatIndex < ternaryEnd) {
+    decoded.category = ElementwiseNAryCategory::Ternary;
+    decoded.fn.ternaryFn = static_cast<TernaryFn>(flatIndex - binaryEnd);
+  } else {
+    llvm_unreachable("unhandled ElementwiseFn");
+  }
+  return decoded;
+}
+
+/// Returns the rank of the op's first DPS init (output) operand.
+/// `llvm::cast` asserts that the operand's type is a ShapedType.
+unsigned ElementwiseOp::getResultRank() {
+  return llvm::cast<ShapedType>(getDpsInitOperand(0)->get().getType())
+      .getRank();
+}
+
+/// Returns one `parallel` iterator per dimension of the result, i.e. a loop
+/// nest as deep as the output rank with no reduction dimensions.
+SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>(getResultRank(),
+                                          utils::IteratorType::parallel);
+}
+
+/// Returns `numMaps` copies of the `numDims`-dimensional identity affine map.
+SmallVector<AffineMap>
+ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                                      MLIRContext *context) {
+  return SmallVector<AffineMap>(
+      numMaps, AffineMap::getMultiDimIdentityMap(numDims, context));
+}
+
+ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+ // Expect e.g. `kind = #linalg.elemwise_fn<add>`
+ Attribute attr;
+ mlir::linalg::ElementwiseFn elemwiseFnVal;
+ if (parser.parseKeyword("kind"))
+ return failure();
+ if (parser.parseEqual())
+ return failure();
+ if (succeeded(parser.parseAttribute(attr))) {
+ auto elemwiseFnAttr = dyn_cast<ElementwiseFnAttr>(attr);
+ if (!elemwiseFnAttr)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ElementwiseFn attribute");
+ elemwiseFnVal = elemwiseFnAttr.getValue();
+ } else {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected operation 'kind' attribute");
+ }
+ result.addAttribute(
+ "kind", ElementwiseFnAttr::get(parser.getContext(), elemwiseFnVal));
+
+ // Parse optional `indexing_maps`
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+ if (parser.parseLSquare())
+ return failure();
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ indexingMapsAttr.push_back(mapAttr);
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+ if (parser.parseRSquare())
+ return failure();
+ }
+
+ // At this stage of parsing the only way to infer number of region
+ // args is through op kind, as input output tensors are not parsed yet.
+ auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
+ auto arity = getArityFromCategory(arityAndCategory.category);
+ int numRegionArgs = arity + 1 /*output*/;
+ if (parseNamedStructuredOp(parser, result, numRegionArgs,
+ ElementwiseOp::getRegionBuilder())) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "unable to parse elemwise op");
+ }
+
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ // We need to infer the `number of indexing maps` needed from the result
----------------
rolfmorel wrote:
nit: you (need to) infer the _rank_ of the affine maps from the output type, not the number of maps.
https://github.com/llvm/llvm-project/pull/124661
More information about the Mlir-commits mailing list