[Mlir-commits] [mlir] [mlir][linalg] Extend elementwise (PR #124661)

Rolf Morel llvmlistbot at llvm.org
Mon Feb 3 07:46:54 PST 2025

@@ -3611,5 +3620,260 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+// ElementwiseOp
+namespace {
+/// Couples an elementwise function's arity category with the concrete
+/// per-category function enum. Only the union member that matches
+/// `category` holds a meaningful value.
+struct NAryCategoryAndFn {
+  // Arity category selecting the active `fn` member (Unary, Binary or
+  // Ternary).
+  ElementwiseNAryCategory category;
+  // Per-category function; read only the member selected by `category`.
+  union NAryFn {
+    UnaryFn unaryFn;
+    BinaryFn binaryFn;
+    TernaryFn ternaryFn;
+  } fn;
+  /// Returns the textual name of `category`.
+  ::llvm::StringRef stringifyCategory() {
+    return stringifyElementwiseNAryCategory(category);
+  }
+  /// Returns the textual name of the function, reading the union member
+  /// that corresponds to `category`.
+  ::llvm::StringRef stringifyFn() {
+    switch (category) {
+    case ElementwiseNAryCategory::Unary:
+      return stringifyUnaryFn(fn.unaryFn);
+    case ElementwiseNAryCategory::Binary:
+      return stringifyBinaryFn(fn.binaryFn);
+    case ElementwiseNAryCategory::Ternary:
+      return stringifyTernaryFn(fn.ternaryFn);
+    }
+    llvm_unreachable("unknown-fn");
+  }
+/// Returns the number of input operands (excluding the output) consumed by
+/// an elementwise function of the given arity category.
+unsigned getArityFromCategory(ElementwiseNAryCategory category) {
+  switch (category) {
+  case ElementwiseNAryCategory::Unary:
+    return 1;
+  case ElementwiseNAryCategory::Binary:
+    return 2;
+  case ElementwiseNAryCategory::Ternary:
+    return 3;
+  }
+  llvm_unreachable("unhandled category");
+} // namespace
+/// Decodes a flat ElementwiseFn value into its arity category and the
+/// matching per-category enum (UnaryFn, BinaryFn or TernaryFn).
+///
+/// Relies on ElementwiseFn listing all unary functions first, then all
+/// binary, then all ternary ones. Each ElementwiseFnLimits::Last* value acts
+/// as one past the end of its group (the comparisons below are strict `<`,
+/// and the next group is rebased by subtracting the previous limit).
+/// Values at or beyond LastTernary are treated as unreachable.
+static NAryCategoryAndFn getNAryCategoryAndFn(ElementwiseFn fn) {
+  constexpr int lastUnary = static_cast<int>(ElementwiseFnLimits::LastUnary);
+  constexpr int lastBinary = static_cast<int>(ElementwiseFnLimits::LastBinary);
+  constexpr int lastTernary =
+      static_cast<int>(ElementwiseFnLimits::LastTernary);
+  int val = static_cast<int>(fn);
+  NAryCategoryAndFn result;
+  // Values below lastUnary are cast directly to UnaryFn (assumes both enums
+  // number their unary cases identically).
+  if (val < lastUnary) {
+    result.category = ElementwiseNAryCategory::Unary;
+    result.fn.unaryFn = static_cast<UnaryFn>(val);
+    return result;
+  }
+  // Values in [lastUnary, lastBinary): subtract lastUnary to index BinaryFn.
+  if (val < lastBinary) {
+    result.category = ElementwiseNAryCategory::Binary;
+    result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary);
+    return result;
+  }
+  if (val >= lastTernary) {
+    llvm_unreachable("unhandled ElementwiseFn");
+  }
+  // Values in [lastBinary, lastTernary): subtract lastBinary to index
+  // TernaryFn.
+  result.category = ElementwiseNAryCategory::Ternary;
+  result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
+  return result;
+/// Returns the rank of the first DPS init (output) operand. Its type must be
+/// a ShapedType; llvm::cast asserts otherwise.
+unsigned ElementwiseOp::getResultRank() {
+  auto output = getDpsInitOperand(0)->get();
+  auto shapedType = llvm::cast<ShapedType>(output.getType());
+  return shapedType.getRank();
+/// Elementwise ops have no reduction loops: every loop is parallel, with one
+/// loop per dimension of the result.
+SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
+  auto rank = getResultRank();
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                                      MLIRContext *context) {
+  // Default: `numMaps` copies of the same identity map over `numDims`
+  // dimensions (no broadcast or transpose). Callers in this file pass one map
+  // per operand, i.e. arity inputs + 1 output.
+  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
+  return SmallVector<AffineMap>(numMaps, map);
+/// Custom parser. Accepts `kind = <ElementwiseFnAttr>`, then an optional
+/// `indexing_maps = [<affine map>, ...]`, then the remaining operand syntax
+/// handled by parseNamedStructuredOp. When `indexing_maps` is omitted,
+/// identity maps are synthesized from the rank of the last parsed operand.
+ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Expect e.g. `kind = #linalg.elemwise_fn<add>`
+  Attribute attr;
+  mlir::linalg::ElementwiseFn elemwiseFnVal;
+  if (parser.parseKeyword("kind"))
+    return failure();
+  if (parser.parseEqual())
+    return failure();
+  if (succeeded(parser.parseAttribute(attr))) {
+    auto elemwiseFnAttr = dyn_cast<ElementwiseFnAttr>(attr);
+    if (!elemwiseFnAttr)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected ElementwiseFn attribute");
+    elemwiseFnVal = elemwiseFnAttr.getValue();
+  } else {
+    // NOTE(review): parseAttribute has presumably already emitted a
+    // diagnostic on failure, so this may report a second error — confirm.
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected operation 'kind' attribute");
+  }
+  result.addAttribute(
+      "kind", ElementwiseFnAttr::get(parser.getContext(), elemwiseFnVal));
+  // Parse optional `indexing_maps`
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+    if (parser.parseLSquare())
+      return failure();
+    // Comma-separated list of affine maps; stop when no comma follows.
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      indexingMapsAttr.push_back(mapAttr);
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // At this stage of parsing the only way to infer the number of region
+  // args is through the op kind, since the input and output operands have
+  // not been parsed yet.
+  auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
+  auto arity = getArityFromCategory(arityAndCategory.category);
+  int numRegionArgs = arity + 1 /*output*/;
+  if (parseNamedStructuredOp(parser, result, numRegionArgs,
+                             ElementwiseOp::getRegionBuilder())) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "unable to parse elemwise op");
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    // We need to infer the `number of indexing maps` needed from the result
+    // type which is already parsed by now.
+    // NOTE(review): assumes the last parsed operand is the single output and
+    // that `result.operands` is non-empty (no bounds check here); presumably
+    // guaranteed by parseNamedStructuredOp — confirm.
+    auto resultType = result.operands[result.operands.size() - 1].getType();
+    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+    if (!shapedType)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "return type needs to be shaped type");
+    auto numDims = shapedType.getRank();
+    indexingMapsAttr = llvm::map_to_vector(
+        ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims,
+                                              parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+  return success();
+/// Custom printer mirroring parse(): prints `kind=<attr>`, prints
+/// `indexing_maps` only when they differ from the identity defaults the
+/// parser would synthesize, then delegates the remaining operand syntax to
+/// printNamedStructuredOp.
+void ElementwiseOp::print(OpAsmPrinter &p) {
+  p << " kind=";
+  p.printAttribute(getKindAttr());
+  // Keep these out of the trailing attribute dictionary; `kind` and
+  // `indexing_maps` have custom syntax in this printer.
+  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
+                                           "indexing_maps"};
+  // Recompute the default maps parse() would synthesize, so that they can be
+  // elided from the output when the op uses them.
+  auto category = getNAryCategoryAndFn(getKind()).category;
+  auto arity = getArityFromCategory(category);
+  auto numDims = getResultRank();
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
+/// Op-specific verifier; performs no checks of its own (see below).
+LogicalResult ElementwiseOp::verify() {
+  // All necessary checks are done either by
+  // - EnumAttr (e.g. unknown operation kind)
+  // - verifyStructuredOpInterface (incorrect map, sizes).
+  // NOTE(review): nothing here checks that the number of inputs matches the
+  // arity implied by `kind`; regionBuilder only asserts on the block-arg
+  // count. Confirm that verifyStructuredOpInterface covers this.
+  return success();
+/// Implements the block region builder for the ElementwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  ElementwiseFn elemwiseFn;
+  for (auto attr : attrs) {
+    if (attr.getName() == b.getStringAttr("kind")) {
+      auto funcTypeAttr = dyn_cast<ElementwiseFnAttr>(attr.getValue());
+      assert(funcTypeAttr && "op kind attribute incorrectly set");
+      elemwiseFn = funcTypeAttr.getValue();
+      break;
+    }
+  }
+  NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(elemwiseFn);
+  ElementwiseNAryCategory category = categoryAndFn.category;
+  unsigned numBlockArgs = getArityFromCategory(categoryAndFn.category) + 1;
+  assert(block.getNumArguments() == numBlockArgs &&
+         "Elementwise regionBuilder number of block args mismatch");
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+  Value result;
+  if (category == ElementwiseNAryCategory::Unary) {
+    result =
+        helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
rolfmorel wrote:

If I understand correctly, `categoryAndFn.fn.unaryFn`, `...fn.binaryFn`, and `...fn.ternaryFn` are only used in this function. Could we move the logic for retrieving the actual function here and simplify, or get rid of, `NAryCategoryAndFn`?


More information about the Mlir-commits mailing list