[Mlir-commits] [mlir] [mlir][linalg] Extend Linalg elemwise named ops semantics (PR #122753)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 13 09:43:38 PST 2025
llvmbot wrote:
Author: Javed Absar (javedabsar1)
Implements Linalg elemwise named-op following the
proposal and discussions in RFC:
Discussions on the RFC are still ongoing, especially about
`comp_type`, so that part is left open/unimplemented in this diff.
Patch is 33.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122753.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td (+5)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+51)
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+130)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+287)
- (added) mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir (+170)
- (added) mlir/test/Dialect/Linalg/elemwise/invalid.mlir (+39)
- (added) mlir/test/Dialect/Linalg/elemwise/round-trip.mlir (+72)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d31..115eaebc6aff54 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,6 +61,11 @@ def Linalg_Dialect : Dialect {
+// Define the enum-type Elemwise func attribute.
+def ElemwiseFnAttr : EnumAttr<Linalg_Dialect, ElemwiseFn, "elemwise_fn"> {
+ let assemblyFormat = "`<` $value `>`";
// Define the function attribute enums matching the OpDSL functions.
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d057..5135e9cd4386ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -15,6 +15,57 @@
include "mlir/IR/EnumAttr.td"
+// Define an `enum class : i32` to categorise element-wise op.
+def ElemwiseNAryCategory : I32EnumAttr<"ElemwiseNAryCategory", "", [
+ I32EnumAttrCase<"Unary", 0>,
+ I32EnumAttrCase<"Binary", 1>,
+ I32EnumAttrCase<"Ternary", 2>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+// Define a unified `enum class : i32` for all element-wise options.
+// Note: The order of individual fn (e.g. 'exp', 'log') within each
+// category (Unary, Binary etc.) must match the ordering of same fn
+// defined in UnaryFn, BinaryFn. This is to enable correct mapping
+// from this unified enum class to different category enums.
+def ElemwiseFn : I32EnumAttr<"ElemwiseFn", "", [
+ // Unary
+ I32EnumAttrCase<"exp", 0>,
+ I32EnumAttrCase<"log", 1>,
+ I32EnumAttrCase<"abs", 2>,
+ I32EnumAttrCase<"ceil", 3>,
+ I32EnumAttrCase<"floor", 4>,
+ I32EnumAttrCase<"negf", 5>,
+ I32EnumAttrCase<"reciprocal", 6>,
+ I32EnumAttrCase<"round", 7>,
+ I32EnumAttrCase<"sqrt", 8>,
+ I32EnumAttrCase<"rsqrt", 9>,
+ I32EnumAttrCase<"square", 10>,
+ I32EnumAttrCase<"tanh", 11>,
+ I32EnumAttrCase<"erf", 12>,
+ // Binary
+ I32EnumAttrCase<"add", 13>,
+ I32EnumAttrCase<"sub", 14>,
+ I32EnumAttrCase<"mul", 15>,
+ I32EnumAttrCase<"div", 16>,
+ I32EnumAttrCase<"div_unsigned", 17>,
+ I32EnumAttrCase<"max_signed", 18>,
+ I32EnumAttrCase<"min_signed", 19>,
+ I32EnumAttrCase<"max_unsigned", 20>,
+ I32EnumAttrCase<"min_unsigned", 21>,
+ I32EnumAttrCase<"powf", 22>,
+ // Ternary
+ I32EnumAttrCase<"select", 23>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
// Define the function attribute enums matching the OpDSL functions.
def UnaryFn : I32EnumAttr<"UnaryFn", "", [
I32EnumAttrCase<"exp", 0>,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..6d6ff7a5c7872a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -551,6 +551,136 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
+// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
+def ElemwiseOp : LinalgStructuredBase_Op<"elemwise", [
+ AttrSizedOperandSegments]> {
+ let summary = [{ Performs element-wise operation }];
+ let description = [{
+ Linalg op form which performs element-wise computation. The attribute
+ `func_type` describes the operation type (e.g. add, exp). The func_type
+ can be any valid unary, binary, or ternary operation.
+ Affine-maps for operands and result may be provided by the user. When
+ a user-defined indexing_map is not provided, identity map is inferred
+ for all operands. The default indexing maps are N identity-maps. ‘N’
+ depends on the arity of the elementwise op. The number of dims is
+ inferred from rank of the output type. In the case of default indexing
+ map, the input and output shapes must all match. Affine-map for operands
+ and result must be only projected permutations with no zero constants.
+ For element-wise iterator-type is always inferred as all ‘parallel’.
+ Iterator-type is needed for constructing this underlying structured op.
+ The number of dims of the iterator-type is inferred from the rank of
+ the result type.
+ Example:
+ Defining a unary linalg.elemwise with default indexing-map:
+ ```mlir
+ %exp = linalg.elemwise
+ func_type=#linalg.elemwise_fn<exp>
+ ins(%x : tensor<4x16x8xf32>)
+ outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
+ ```
+ Defining a binary linalg.elemwise with user-defined indexing-map:
+ ```mlir
+ %add = linalg.elemwise
+ func_type=#linalg.elemwise_fn<add>
+ indexing_maps = [#transpose, #broadcast, #identity]
+ ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
+ outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ ```
+ }];
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ ElemwiseFnAttr:$func_type,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildElemwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, ElemwiseOp::getRegionBuilder());
+ }]>
+ ];
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ /// Get the nary category enum, e.g. `ElemwiseNAryCategory::Unary`,
+ /// corresponding to the given fn, e.g. `ElemwiseFn::exp`
+ static ElemwiseNAryCategory getNAryCategory(ElemwiseFn fn);
+ /// Elementwise always has dynamic indexing maps, i.e. they are either
+ /// user-specified or default. The default is identity maps.
+ static bool hasDynamicIndexingMaps() { return true; }
+ /// Implements the block region builder for the ElemwiseOp. This is called
+ /// by the 'fillStructuredOpRegion'.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+ /// Returns elementwise op kind e.g. `add` inferred from func_type attr.
+ ElemwiseFn getElemwiseFnVal() {
+ return getFuncType();
+ }
+ /// Infer dimensionality of the `iteration space` from the result type.
+ /// Useful when other means are not possible, e.g. in the absence of a
+ /// user-provided indexing map.
+ unsigned getResultRank();
+ /// Elementwise op does not have to explicitly specify iterator type
+ /// as it is always 'parallel'. The number of 'parallel' loops is
+ /// inferred from other means (e.g. result tensor type).
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+ /// The default indexing maps are N identity-maps. 'N' depends on the
+ /// arity of the elementwise op. The default case is when all input
+ /// output tensors are same rank and no transpose/broadcast is needed.
+ static SmallVector<AffineMap>
+ getDefaultIndexingMaps(unsigned N, unsigned numDims,
+ MLIRContext *context);
+ /// Returns true if the user defined indexing maps are not equal to
+ /// the default (identity) map.
+ bool hasUserDefinedMaps();
+ /// destination passing style interface method.
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+ // Generic methods.
+ std::string getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+ }
+ }];
// Op definition for MatmulOp
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..c84220f5b4f2ce 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,15 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
+static void buildElemwiseOp(OpBuilder &b, OperationState &state,
+ std::optional<TypeRange> resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes,
+ RegionBuilderFn regionBuilder) {
+ return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
+ attributes, regionBuilder);
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3566,6 +3575,7 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
void MatmulOp::print(OpAsmPrinter &p) {
SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3611,5 +3621,282 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+// ElemwiseOp - with support for affine map and func_type (comp_type is not yet implemented)
+namespace {
+struct NAryCategoryAndFn {
+ // The enum category class {Unary, Binary, Ternary, ..}
+ ElemwiseNAryCategory category;
+ union NAryFn {
+ UnaryFn unaryFn;
+ BinaryFn binaryFn;
+ TernaryFn ternaryFn;
+ } fn;
+ ::llvm::StringRef stringifyCategory() {
+ switch (category) {
+ case ElemwiseNAryCategory::Unary:
+ return "unary";
+ case ElemwiseNAryCategory::Binary:
+ return "binary";
+ case ElemwiseNAryCategory::Ternary:
+ return "ternary";
+ }
+ llvm_unreachable("unknown-category");
+ }
+ ::llvm::StringRef stringifyFn() {
+ switch (category) {
+ case ElemwiseNAryCategory::Unary:
+ return stringifyUnaryFn(fn.unaryFn);
+ case ElemwiseNAryCategory::Binary:
+ return stringifyBinaryFn(fn.binaryFn);
+ case ElemwiseNAryCategory::Ternary:
+ return stringifyTernaryFn(fn.ternaryFn);
+ }
+ llvm_unreachable("unknown-fn");
+ }
+unsigned getArityFromCategory(ElemwiseNAryCategory category) {
+ switch (category) {
+ case ElemwiseNAryCategory::Unary:
+ return 1;
+ case ElemwiseNAryCategory::Binary:
+ return 2;
+ case ElemwiseNAryCategory::Ternary:
+ return 3;
+ }
+ llvm_unreachable("unhandled category");
+} // namespace
+static NAryCategoryAndFn getNAryCategoryAndFn(ElemwiseFn fn) {
+ constexpr int lastUnary = static_cast<int>(ElemwiseFn::erf);
+ constexpr int lastBinary = static_cast<int>(ElemwiseFn::powf);
+ constexpr int lastTernary = static_cast<int>(ElemwiseFn::select);
+ int val = static_cast<int>(fn);
+ NAryCategoryAndFn result;
+ if (val <= lastUnary) {
+ result.category = ElemwiseNAryCategory::Unary;
+ result.fn.unaryFn = static_cast<UnaryFn>(val);
+ return result;
+ }
+ if (val <= lastBinary) {
+ result.category = ElemwiseNAryCategory::Binary;
+ result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary - 1);
+ return result;
+ }
+ if (val > lastTernary) {
+ llvm_unreachable("unhandled ElemwiseFn");
+ }
+ result.category = ElemwiseNAryCategory::Ternary;
+ result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary - 1);
+ return result;
+unsigned ElemwiseOp::getResultRank() {
+ auto output = getDpsInitOperand(0)->get();
+ auto shapedType = llvm::cast<ShapedType>(output.getType());
+ return shapedType.getRank();
+SmallVector<utils::IteratorType> ElemwiseOp::getIteratorTypesArray() {
+ auto rank = getResultRank();
+ return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+ElemwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+ MLIRContext *context) {
+ auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
+ return SmallVector<AffineMap>(numMaps, map);
+bool ElemwiseOp::hasUserDefinedMaps() {
+ auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+ auto arity = getArityFromCategory(category);
+ auto numDims = getResultRank();
+ SmallVector<AffineMap, 3> defaultMaps =
+ getDefaultIndexingMaps(arity + 1, numDims, this->getContext());
+ SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
+ return defaultMaps != explicitMaps;
+ParseResult ElemwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+ // Expect e.g. `func_type = #linalg.elemwise_fn<add>`
+ Attribute attr;
+ mlir::linalg::ElemwiseFn elemwiseFnVal;
+ if (parser.parseKeyword("func_type"))
+ return failure();
+ if (parser.parseEqual())
+ return failure();
+ if (succeeded(parser.parseAttribute(attr))) {
+ auto elemwiseFnAttr = dyn_cast<ElemwiseFnAttr>(attr);
+ if (!elemwiseFnAttr)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ElemwiseFn attribute");
+ elemwiseFnVal = elemwiseFnAttr.getValue();
+ } else {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected 'func_type' attribute");
+ }
+ result.addAttribute("func_type",
+ ElemwiseFnAttr::get(parser.getContext(), elemwiseFnVal));
+ // Parse optional `indexing_maps`
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+ if (parser.parseLSquare())
+ return failure();
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ indexingMapsAttr.push_back(mapAttr);
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+ if (parser.parseRSquare())
+ return failure();
+ }
+ // At this stage of parsing the only way to infer number of region
+ // args is through op kind, as input output tensors are not parsed yet.
+ auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
+ auto arity = getArityFromCategory(arityAndCategory.category);
+ int numRegionArgs = arity + 1 /*output*/;
+ if (parseNamedStructuredOp(parser, result, numRegionArgs,
+ ElemwiseOp::getRegionBuilder())) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "unable to parse elemwise op");
+ }
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ // We need to infer the `number of indexing maps` needed from the result
+ // type which is already parsed by now.
+ auto resultType = result.operands[result.operands.size() - 1].getType();
+ auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+ if (!shapedType)
+ return parser.emitError(parser.getCurrentLocation(),
+ "return type needs to be shaped type");
+ auto numDims = shapedType.getRank();
+ indexingMapsAttr = llvm::map_to_vector(
+ ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims,
+ parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ }
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+ return success();
+void ElemwiseOp::print(OpAsmPrinter &p) {
+ p << " func_type=";
+ p.printAttribute(getFuncTypeAttr());
+ SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "func_type",
+ "indexing_maps"};
+ auto category = getNAryCategoryAndFn(getElemwiseFnVal()).category;
+ auto arity = getArityFromCategory(category);
+ auto numDims = getResultRank();
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ ElemwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+LogicalResult ElemwiseOp::verify() {
+ // All necessary checks are done either by
+ // - EnumAttr (e.g. unknown func_type)
+ // - verifyStructuredOpInterface (incorrect map, sizes).
+ return success();
+/// Implements the block region builder for the ElemwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+void ElemwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ ElemwiseFn elemwiseFn;
+ for (auto attr : attrs) {
+ if (attr.getName() == b.getStringAttr("func_type")) {
+ auto funcTypeAttr = dyn_cast<ElemwiseFnAttr>(attr.getValue());
+ assert(funcTypeAttr && "func_type attribute incorrectly set");
+ elemwiseFn = funcTypeAttr.getValue();
+ break;
+ }
+ }
+ NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(elemwiseFn);
+ ElemwiseNAryCategory category = categoryAndFn.category;
+ unsigned numBlockArgs = getArityFromCategory(categoryAndFn.category) + 1;
+ assert(block.getNumArguments() == numBlockArgs &&
+ "Elemwise regionBuilder number of block args mismatch");
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+ Value result;
+ if (category == ElemwiseNAryCategory::Unary) {
+ result =
+ helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
+ } else if (category == ElemwiseNAryCategory::Binary) {
+ result = helper.buildBinaryFn(categoryAndFn.fn.binaryFn,
+ block.getArgument(0), block.getArgument(1));
+ } else if (category == ElemwiseNAryCategory::Ternary) {
+ result = helper.buildTernaryFn(categoryAndFn.fn.ternaryFn,
+ block.getArgument(0), block.getArgument(1), block.getArgument(2));
+ } else
+ assert(false && "found unhandled category in elemwise print");
+ yields.push_back(result);
+ helper.yieldOutputs(yields);
+LogicalResult ElemwiseOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+void ElemwiseOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+Speculation::Speculatability ElemwiseOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elemwise/generalize-named-ops.mlir
new file mode 100644
index 00...
More information about the Mlir-commits
mailing list