[Mlir-commits] [mlir] [mlir][linalg] Extend elementwise (PR #124661)
Rolf Morel
llvmlistbot at llvm.org
Mon Feb 3 06:23:08 PST 2025
================
@@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
+
+// Join two I32EnumAttrCase lists so that the integer enum values in the
+// combined list do not overlap. The cases of `a` are kept as-is. Every case of
+// `b` is re-created with its value shifted by the offset `!size(a)`. Its
+// symbol and its string representation (`str`) are carried over unchanged.
+//
+// NOTE(review): the shift is only collision-free when the values in `a` are
+// exactly 0 .. !size(a)-1 (zero-based and contiguous) - confirm for all inputs.
+class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
+                                  list<I32EnumAttrCase> b> {
+  int aSize = !size(a);
+  // Forward `var.str` explicitly: I32EnumAttrCase defaults `str` to the
+  // symbol, which would otherwise discard any custom string form.
+  list<I32EnumAttrCase> result =
+    !foldl(a, b, acc, var,
+           acc # [I32EnumAttrCase<var.symbol,
+                                  !add(var.value, aSize),
+                                  var.str
+                 >]);
+}
+
+// Flatten a list of I32EnumAttrCase lists into one I32EnumAttrCase list.
+// Sublists are appended left to right through JoinTwoI32EnumAttrCaseList.
+// Each sublist's values are therefore offset by the number of cases already
+// accumulated before it, which keeps the resulting enum values disjoint.
+class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
+  list<I32EnumAttrCase> result =
+    !foldl(/*init=*/[]<I32EnumAttrCase>, l, flattened, sublist,
+           JoinTwoI32EnumAttrCaseList<flattened, sublist>.result);
+}
+
+// A single `enum class : i32` covering every element-wise op function.
+// Its cases are the UnaryFn, BinaryFn and TernaryFn cases, in that order,
+// renumbered by ConcatI32EnumAtrCaseList so that no two cases share a value.
+def ElementwiseFn : I32EnumAttr<
+    "ElementwiseFn", "",
+    ConcatI32EnumAtrCaseList<
+        [UnaryFn.enumerants, BinaryFn.enumerants, TernaryFn.enumerants]
+    >.result> {
+  let cppNamespace = "::mlir::linalg";
+  let genSpecializedAttr = 0;
+}
+
+// Define an `enum class : i32` that records the group boundaries of UnaryFn,
+// BinaryFn and TernaryFn within the unified ElementwiseFn enum.
+// Each Last* value is a running total of enumerant counts. It is therefore one
+// past the largest ElementwiseFn value of its group (an exclusive upper bound),
+// not the value of the group's last case. For example, LastUnary is the
+// ElementwiseFn value given to the first BinaryFn case.
+// NOTE(review): this assumes each source enum's values are zero-based and
+// contiguous, matching the offsets applied by ConcatI32EnumAtrCaseList - confirm.
+def ElementwiseFnLimits : I32EnumAttr<"ElementwiseFnLimits", "", []> {
+  // Cumulative case counts (exclusive bounds, per the note above).
+  int last_unary = !size(UnaryFn.enumerants);
+  int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
+  int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));
+
+  let enumerants = [
+    I32EnumAttrCase<"LastUnary", last_unary>,
+    I32EnumAttrCase<"LastBinary", last_binary>,
+    I32EnumAttrCase<"LastTernary", last_ternary>];
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
+// Define an `enum class : i32` to categorise elementwise ops.
+def ElementwiseNAryCategory : I32EnumAttr<"ElementwiseNAryCategory", "", [
----------------
rolfmorel wrote:
Couldn't this just be `ElementwiseArity`? The concept of `NAryCategory` is exactly arity, right?
https://github.com/llvm/llvm-project/pull/124661
More information about the Mlir-commits
mailing list