[Mlir-commits] [mlir] [mlir][linalg] Extend Linalg elemwise named ops semantics (PR #122753)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Jan 14 07:36:39 PST 2025


================
@@ -3611,5 +3621,283 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ElemwiseOp - with support for affine map, func_type and comp_type
+//===----------------------------------------------------------------------===//
+//
+namespace {
+/// Pairs the arity category of an ElemwiseFn with the matching
+/// category-specific function enum value (UnaryFn, BinaryFn or TernaryFn).
+struct NAryCategoryAndFn {
+  /// Selects which member of `fn` is active: Unary, Binary or Ternary.
+  ElemwiseNAryCategory category;
+
+  /// The category-specific function. Only the member selected by `category`
+  /// may be read; reading any other member is undefined behavior.
+  union NAryFn {
+    UnaryFn unaryFn;
+    BinaryFn binaryFn;
+    TernaryFn ternaryFn;
+  } fn;
+
+  /// Returns "unary", "binary" or "ternary" according to `category`.
+  ::llvm::StringRef stringifyCategory() const {
+    switch (category) {
+    case ElemwiseNAryCategory::Unary:
+      return "unary";
+    case ElemwiseNAryCategory::Binary:
+      return "binary";
+    case ElemwiseNAryCategory::Ternary:
+      return "ternary";
+    }
+    llvm_unreachable("unknown-category");
+  }
+
+  /// Returns the name of the active function, using the stringifier of the
+  /// enum that `category` selects (so only the active union member is read).
+  ::llvm::StringRef stringifyFn() const {
+    switch (category) {
+    case ElemwiseNAryCategory::Unary:
+      return stringifyUnaryFn(fn.unaryFn);
+    case ElemwiseNAryCategory::Binary:
+      return stringifyBinaryFn(fn.binaryFn);
+    case ElemwiseNAryCategory::Ternary:
+      return stringifyTernaryFn(fn.ternaryFn);
+    }
+    llvm_unreachable("unknown-fn");
+  }
+};
+} // namespace
+
+/// Returns the arity of a category: 1 for Unary, 2 for Binary, 3 for Ternary.
+/// Kept `static` rather than in the anonymous namespace, per the LLVM coding
+/// standards (anonymous namespaces are reserved for class declarations).
+static unsigned getArityFromCategory(ElemwiseNAryCategory category) {
+  switch (category) {
+  case ElemwiseNAryCategory::Unary:
+    return 1;
+  case ElemwiseNAryCategory::Binary:
+    return 2;
+  case ElemwiseNAryCategory::Ternary:
+    return 3;
+  }
+  llvm_unreachable("unhandled category");
+}
+
+/// Splits a flat ElemwiseFn into its arity category and the matching
+/// UnaryFn / BinaryFn / TernaryFn value.
+///
+/// The index arithmetic below assumes ElemwiseFn lists the unary functions
+/// first (ending at `erf`), then the binary ones (ending at `powf`), then the
+/// ternary ones (ending at `select`), each group in the same order as its
+/// per-category enum and with that enum numbered from 0.
+static NAryCategoryAndFn getNAryCategoryAndFn(ElemwiseFn fn) {
+  constexpr int unaryEnd = static_cast<int>(ElemwiseFn::erf);
+  constexpr int binaryEnd = static_cast<int>(ElemwiseFn::powf);
+  constexpr int ternaryEnd = static_cast<int>(ElemwiseFn::select);
+
+  const int index = static_cast<int>(fn);
+  NAryCategoryAndFn decoded;
+  if (index <= unaryEnd) {
+    decoded.category = ElemwiseNAryCategory::Unary;
+    decoded.fn.unaryFn = static_cast<UnaryFn>(index);
+  } else if (index <= binaryEnd) {
+    // Rebase so the first binary entry maps to BinaryFn value 0.
+    decoded.category = ElemwiseNAryCategory::Binary;
+    decoded.fn.binaryFn = static_cast<BinaryFn>(index - (unaryEnd + 1));
+  } else if (index <= ternaryEnd) {
+    // Rebase so the first ternary entry maps to TernaryFn value 0.
+    decoded.category = ElemwiseNAryCategory::Ternary;
+    decoded.fn.ternaryFn = static_cast<TernaryFn>(index - (binaryEnd + 1));
+  } else {
+    llvm_unreachable("unhandled ElemwiseFn");
+  }
+  return decoded;
+}
+
+/// Returns the rank of the shaped type of the op's first DPS init operand.
+unsigned ElemwiseOp::getResultRank() {
+  auto initType = getDpsInitOperand(0)->get().getType();
+  return llvm::cast<ShapedType>(initType).getRank();
+}
+
+/// Every loop of the op is parallel; there is one loop per dimension of the
+/// init operand (see getResultRank).
+SmallVector<utils::IteratorType> ElemwiseOp::getIteratorTypesArray() {
+  SmallVector<utils::IteratorType> iteratorTypes(
+      getResultRank(), utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+/// Builds `numMaps` copies of the `numDims`-dimensional identity affine map.
+SmallVector<AffineMap>
+ElemwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                                   MLIRContext *context) {
+  return SmallVector<AffineMap>(
+      numMaps, AffineMap::getMultiDimIdentityMap(numDims, context));
+}
+
+bool ElemwiseOp::hasUserDefinedMaps() {
+  auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+  auto arity = getArityFromCategory(category);
+
+  auto numDims = getResultRank();
+  SmallVector<AffineMap, 3> defaultMaps =
----------------
adam-smnk wrote:

nit: it can just be `SmallVector<AffineMap>`, as the size changes with the arity

https://github.com/llvm/llvm-project/pull/122753


More information about the Mlir-commits mailing list