[Mlir-commits] [mlir] [mlir][linalg] Extend elementwise (PR #124661)
Javed Absar
llvmlistbot at llvm.org
Thu Feb 20 13:36:11 PST 2025
https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/124661
>From 0b37550ef93b0028771ed2e731e215dd02611f4c Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Fri, 24 Jan 2025 08:30:07 -0500
Subject: [PATCH 1/3] [mlir][linalg] Extend elementwise
---
.../mlir/Dialect/Linalg/IR/LinalgBase.td | 6 +
.../mlir/Dialect/Linalg/IR/LinalgEnums.td | 44 +++
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 116 ++++++++
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 263 ++++++++++++++++++
.../element_wise/generalize_named_ops.mlir | 157 +++++++++++
.../Dialect/Linalg/element_wise/invalid.mlir | 54 ++++
.../Linalg/element_wise/round-trip.mlir | 88 ++++++
7 files changed, 728 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir
create mode 100644 mlir/test/Dialect/Linalg/element_wise/invalid.mlir
create mode 100644 mlir/test/Dialect/Linalg/element_wise/round-trip.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d3..00e3633610ccb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
}];
}
+// Attribute wrapping the unified ElementwiseFn enum. With the mnemonic below
+// and the `<` $value `>` format it is spelled `#linalg.elementwise_fn<add>`.
+def ElementwiseFnAttr
+    : EnumAttr<Linalg_Dialect, ElementwiseFn, "elementwise_fn"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
// Define the function attribute enums matching the OpDSL functions.
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d05..c41b541835751 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -55,6 +55,50 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
+
+// Join two I32EnumAttrCase lists. The cases of 'a' are kept as-is; every case
+// of 'b' is re-based by adding '!size(a)' to its value, so the integer values
+// in the combined list do not overlap. This assumes 'a' is numbered densely
+// as 0..size(a)-1 and 'b' starts at 0.
+// The string form ('str') of each re-based case is carried over explicitly.
+// Otherwise it would default back to the C++ symbol, and a case whose
+// printed name differs from its symbol would lose its spelling.
+class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
+                                  list<I32EnumAttrCase> b> {
+  int aSize = !size(a);
+  list<I32EnumAttrCase> result =
+      !foldl(a, b, acc, var,
+             acc # [I32EnumAttrCase<var.symbol,
+                                    !add(var.value, aSize),
+                                    var.str>]);
+}
+
+// Flatten a list of I32EnumAttrCase lists into a single I32EnumAttrCase list.
+// Groups are joined left to right with JoinTwoI32EnumAttrCaseList, so the
+// integer values of the resulting cases stay distinct.
+// (The "Atr" spelling in the class name is kept because users refer to it.)
+class ConcatI32EnumAtrCaseList<list<list<I32EnumAttrCase>> l> {
+  list<I32EnumAttrCase> result =
+      !foldl([]<I32EnumAttrCase>, l, flattened, group,
+             JoinTwoI32EnumAttrCaseList<flattened, group>.result);
+}
+
+// Unified `enum class : i32` over every element-wise op function: the UnaryFn
+// cases, then the BinaryFn cases, then the TernaryFn cases. Values are
+// re-based group by group so that they do not overlap.
+def ElementwiseFn : I32EnumAttr<
+    "ElementwiseFn", "",
+    ConcatI32EnumAtrCaseList<[
+      UnaryFn.enumerants, BinaryFn.enumerants, TernaryFn.enumerants
+    ]>.result> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
+// `enum class : i32` naming the arity group an elementwise function belongs
+// to: one, two or three input operands.
+def ElementwiseNAryCategory : I32EnumAttr<
+    "ElementwiseNAryCategory", "",
+    [I32EnumAttrCase<"Unary", 0>,
+     I32EnumAttrCase<"Binary", 1>,
+     I32EnumAttrCase<"Ternary", 2>]> {
+  let cppNamespace = "::mlir::linalg";
+  let genSpecializedAttr = 0;
+}
+
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e..2d82eef41c2f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -551,6 +551,122 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// Op definition for ElementwiseOp
+//===----------------------------------------------------------------------===//
+def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
+    AttrSizedOperandSegments]> {
+  let summary = [{ Performs element-wise operation }];
+  let description = [{
+    Linalg op form which performs element-wise computation.
+
+    The attribute `kind` describes the operation (e.g. add, exp). The operation
+    kind can be any elementwise n-ary (e.g. unary, binary, ternary) operation.
+
+    Affine-maps for operands and result are required to be provided by the user
+    when transpose and/or broadcast is needed on any operand. When a map is not
+    provided, default identity maps are inferred for each operand. The number
+    of dims in each of the identity maps is equal to the rank of the output type.
+    In the case of default indexing map, all input and output shapes must match.
+    User-defined affine-maps for operands and result must only be projected
+    permutations with no zero constants.
+
+    For elementwise, iterator-types are always 'all parallel'.
+    Iterator-types are needed for constructing the underlying structured op.
+    The number of dims of the iterator-types are inferred from the rank of
+    the result type.
+
+    Example:
+
+    Defining a unary linalg.elementwise with default indexing-map:
+    ```mlir
+    %exp = linalg.elementwise
+        kind=#linalg.elementwise_fn<exp>
+        ins(%x : tensor<4x16x8xf32>)
+        outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
+    ```
+
+    Defining a binary linalg.elementwise with user-defined indexing-map:
+    ```mlir
+    #identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+    #transpose = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+    #broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
+    %add = linalg.elementwise
+        kind=#linalg.elementwise_fn<add>
+        indexing_maps = [#transpose, #broadcast, #identity]
+        ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
+        outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    ElementwiseFnAttr:$kind,
+    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+  );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    // The `kind` attribute must be supplied through `attributes`: the region
+    // builder looks it up in the attribute list it is given.
+    OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildElementwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
+                           attributes, ElementwiseOp::getRegionBuilder());
+      }]>
+  ];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    /// Get the n-ary category enum, e.g. `ElementwiseNAryCategory::Unary`,
+    /// corresponding to the given fn, e.g. `ElementwiseFn::exp`.
+    /// NOTE(review): no out-of-line definition of this method is visible in
+    /// LinalgOps.cpp in this patch; confirm it is defined before it is used.
+    static ElementwiseNAryCategory getNAryCategory(ElementwiseFn fn);
+
+    /// Both user-specified and default indexing map will always depend on
+    /// the current Op instance.
+    static bool hasDynamicIndexingMaps() { return true; }
+
+    /// Implements the block region builder for the elementwiseOp. This is
+    /// called by the 'fillStructuredOpRegion'.
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    /// Returns rank of the result tensor/memref. Useful for knowing
+    /// the dimensionality of the iteration space when other means
+    /// are not available, e.g. in the absence of a user-provided indexing map.
+    unsigned getResultRank();
+
+    /// Returns N 'parallel' iterator types where N is rank of result.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// The default indexing maps are identities.
+    /// There will be N such maps, where N is the arity of the Op.
+    static SmallVector<AffineMap>
+    getDefaultIndexingMaps(unsigned N, unsigned numDims,
+                           MLIRContext *context);
+
+    /// Destination passing style interface method.
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+
+    // Generic methods.
+    std::string getLibraryCallName() {
+      return generateLibraryCallName(getOperation());
+    }
+  }];
+}
+
//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b..e1947a864d4d0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,6 +203,15 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
+/// Builds an ElementwiseOp through the common structured-op builder.
+///
+/// The `indexing_maps` attribute defaults to an empty array, and only the
+/// custom parser fills in default maps. If the caller passes no
+/// `indexing_maps` in `attributes`, this builder therefore installs one
+/// identity map per input/output operand, over the rank of the first output
+/// (the rank `getResultRank()` reports). This mirrors the parser's defaults.
+/// Without it, ops built from C++ carried an empty map list that does not
+/// match the operand count.
+static void buildElementwiseOp(OpBuilder &b, OperationState &state,
+                               std::optional<TypeRange> resultTensorTypes,
+                               ValueRange inputs, ValueRange outputs,
+                               ArrayRef<NamedAttribute> attributes,
+                               RegionBuilderFn regionBuilder) {
+  bool hasIndexingMaps = llvm::any_of(attributes, [](NamedAttribute attr) {
+    return attr.getName().getValue() == "indexing_maps";
+  });
+  if (!hasIndexingMaps && !outputs.empty()) {
+    if (auto shapedType = dyn_cast<ShapedType>(outputs.front().getType())) {
+      SmallVector<Attribute> maps = llvm::map_to_vector(
+          ElementwiseOp::getDefaultIndexingMaps(inputs.size() + outputs.size(),
+                                                shapedType.getRank(),
+                                                b.getContext()),
+          [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+      state.addAttribute("indexing_maps", b.getArrayAttr(maps));
+    }
+  }
+  buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, attributes,
+                    regionBuilder);
+}
+
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3611,5 +3620,259 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// ElementwiseOp
+//===----------------------------------------------------------------------===//
+//
+namespace {
+
+/// An ElementwiseFn decoded into its arity category plus the corresponding
+/// per-category function enum value.
+struct NAryCategoryAndFn {
+  // The enum category class {Unary, Binary, Ternary, ..}
+  ElementwiseNAryCategory category;
+
+  // Only the member selected by `category` is meaningful.
+  union NAryFn {
+    UnaryFn unaryFn;
+    BinaryFn binaryFn;
+    TernaryFn ternaryFn;
+  } fn;
+
+  /// Returns the textual name of `category`.
+  ::llvm::StringRef stringifyCategory() const {
+    return stringifyElementwiseNAryCategory(category);
+  }
+
+  /// Returns the textual name of the active function (e.g. "exp", "add").
+  ::llvm::StringRef stringifyFn() const {
+    switch (category) {
+    case ElementwiseNAryCategory::Unary:
+      return stringifyUnaryFn(fn.unaryFn);
+    case ElementwiseNAryCategory::Binary:
+      return stringifyBinaryFn(fn.binaryFn);
+    case ElementwiseNAryCategory::Ternary:
+      return stringifyTernaryFn(fn.ternaryFn);
+    }
+    llvm_unreachable("unknown-fn");
+  }
+};
+
+/// Returns the number of input operands taken by functions in `category`.
+unsigned getArityFromCategory(ElementwiseNAryCategory category) {
+  switch (category) {
+  case ElementwiseNAryCategory::Unary:
+    return 1;
+  case ElementwiseNAryCategory::Binary:
+    return 2;
+  case ElementwiseNAryCategory::Ternary:
+    return 3;
+  }
+  llvm_unreachable("unhandled category");
+}
+} // namespace
+
+/// Decodes a unified ElementwiseFn into its arity category and the matching
+/// UnaryFn / BinaryFn / TernaryFn value.
+///
+/// ElementwiseFn is built in LinalgEnums.td by concatenating the UnaryFn,
+/// BinaryFn and TernaryFn cases in that order. Each group is re-based to
+/// start right after the previous one.
+///
+/// The group boundaries are derived from the size of each per-category enum.
+/// They are not hard-coded to the current last member of each group
+/// (erf / powf / select). As a result, appending a case to UnaryFn or
+/// BinaryFn cannot silently move later functions into the wrong category.
+static NAryCategoryAndFn getNAryCategoryAndFn(ElementwiseFn fn) {
+  // Assumes each per-category enum is numbered densely from 0, which the
+  // TableGen concatenation relies on as well.
+  constexpr int lastUnary = static_cast<int>(getMaxEnumValForUnaryFn());
+  constexpr int lastBinary =
+      lastUnary + 1 + static_cast<int>(getMaxEnumValForBinaryFn());
+  constexpr int lastTernary =
+      lastBinary + 1 + static_cast<int>(getMaxEnumValForTernaryFn());
+  static_assert(lastTernary ==
+                    static_cast<int>(getMaxEnumValForElementwiseFn()),
+                "ElementwiseFn must be UnaryFn ++ BinaryFn ++ TernaryFn");
+
+  int val = static_cast<int>(fn);
+  NAryCategoryAndFn result;
+
+  if (val <= lastUnary) {
+    result.category = ElementwiseNAryCategory::Unary;
+    result.fn.unaryFn = static_cast<UnaryFn>(val);
+    return result;
+  }
+  if (val <= lastBinary) {
+    result.category = ElementwiseNAryCategory::Binary;
+    result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary - 1);
+    return result;
+  }
+  if (val > lastTernary)
+    llvm_unreachable("unhandled ElementwiseFn");
+  result.category = ElementwiseNAryCategory::Ternary;
+  result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary - 1);
+  return result;
+}
+
+/// The rank of the first DPS init (output) operand defines the number of
+/// loops of the iteration space.
+unsigned ElementwiseOp::getResultRank() {
+  Value init = getDpsInitOperand(0)->get();
+  return llvm::cast<ShapedType>(init.getType()).getRank();
+}
+
+/// One `parallel` iterator per result dimension; elementwise ops never reduce.
+SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
+  return SmallVector<utils::IteratorType>(getResultRank(),
+                                          utils::IteratorType::parallel);
+}
+
+/// Returns `numMaps` copies of the `numDims`-dimensional identity map.
+SmallVector<AffineMap>
+ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                                      MLIRContext *context) {
+  SmallVector<AffineMap> identityMaps;
+  identityMaps.assign(numMaps,
+                      AffineMap::getMultiDimIdentityMap(numDims, context));
+  return identityMaps;
+}
+
+/// Parses
+///   `kind` `=` #linalg.elementwise_fn<...>
+///   (`indexing_maps` `=` `[` affine-map (`,` affine-map)* `]`)?
+///   ins(...) outs(...) (-> result-types)?
+/// If no indexing maps are given, one identity map per operand is inferred
+/// from the rank of the last parsed operand (the output).
+ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Expect e.g. `kind = #linalg.elementwise_fn<add>`.
+  if (parser.parseKeyword("kind") || parser.parseEqual())
+    return failure();
+
+  Attribute attr;
+  SMLoc kindLoc = parser.getCurrentLocation();
+  if (failed(parser.parseAttribute(attr)))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected operation 'kind' attribute");
+  auto elemwiseFnAttr = dyn_cast<ElementwiseFnAttr>(attr);
+  if (!elemwiseFnAttr)
+    return parser.emitError(kindLoc, "expected ElementwiseFn attribute");
+  ElementwiseFn elemwiseFnVal = elemwiseFnAttr.getValue();
+  result.addAttribute("kind", elemwiseFnAttr);
+
+  // Parse optional `indexing_maps`.
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual() || parser.parseLSquare())
+      return failure();
+    do {
+      // Report a non-affine-map entry at the entry itself, not after it.
+      SMLoc mapLoc = parser.getCurrentLocation();
+      Attribute mapAttr;
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr))
+        return parser.emitError(mapLoc, "expected affine map attribute");
+      indexingMapsAttr.push_back(mapAttr);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (parser.parseRSquare())
+      return failure();
+  }
+
+  // At this stage of parsing the only way to infer the number of region
+  // args is through the op kind, as input/output operands are not parsed yet.
+  unsigned arity =
+      getArityFromCategory(getNAryCategoryAndFn(elemwiseFnVal).category);
+  int numRegionArgs = arity + 1 /*output*/;
+  if (parseNamedStructuredOp(parser, result, numRegionArgs,
+                             ElementwiseOp::getRegionBuilder()))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "unable to parse elemwise op");
+
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    // The number of dims of the default maps comes from the output operand,
+    // which has been parsed by now (it is the last operand).
+    Type resultType = result.operands.back().getType();
+    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+    if (!shapedType)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "return type needs to be shaped type");
+    indexingMapsAttr = llvm::map_to_vector(
+        ElementwiseOp::getDefaultIndexingMaps(arity + 1, shapedType.getRank(),
+                                              parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+  return success();
+}
+
+/// Prints `kind=...`, then `indexing_maps = [...]` only when the maps differ
+/// from the inferred identity defaults, then the ins/outs operand lists.
+void ElementwiseOp::print(OpAsmPrinter &p) {
+  p << " kind=";
+  p.printAttribute(getKindAttr());
+
+  // Defaults: one identity map per operand (inputs + the output) over the
+  // result rank.
+  unsigned numOperandMaps =
+      getArityFromCategory(getNAryCategoryAndFn(getKind()).category) + 1;
+  SmallVector<Attribute, 3> defaultMaps;
+  for (AffineMap map : ElementwiseOp::getDefaultIndexingMaps(
+           numOperandMaps, getResultRank(), getContext()))
+    defaultMaps.push_back(AffineMapAttr::get(map));
+
+  ArrayAttr maps = getIndexingMaps();
+  if (!llvm::equal(maps, defaultMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(maps, p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
+                                           "indexing_maps"};
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
+}
+
+/// No op-specific checks: unknown kinds are already rejected by the
+/// ElementwiseFn enum attribute, and map/shape consistency is checked by
+/// verifyStructuredOpInterface.
+LogicalResult ElementwiseOp::verify() { return success(); }
+
+/// Implements the block region builder for the ElementwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+///
+/// Looks up the `kind` attribute in `attrs` and emits the corresponding scalar
+/// unary, binary or ternary computation on the leading block arguments (the
+/// inputs). The result is then yielded.
+void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  ElementwiseFnAttr kindAttr;
+  for (NamedAttribute attr : attrs) {
+    if (attr.getName().getValue() == "kind") {
+      kindAttr = dyn_cast<ElementwiseFnAttr>(attr.getValue());
+      break;
+    }
+  }
+  assert(kindAttr && "op kind attribute missing or incorrectly set");
+  // In release builds, leave the block empty instead of decoding an
+  // indeterminate value (previously an uninitialized local was read here).
+  if (!kindAttr)
+    return;
+
+  NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(kindAttr.getValue());
+  // Computed inside the assert so NDEBUG builds have no unused variable.
+  assert(block.getNumArguments() ==
+             getArityFromCategory(categoryAndFn.category) + 1 &&
+         "Elementwise regionBuilder number of block args mismatch");
+
+  RegionBuilderHelper helper(b, block);
+  Value result;
+  switch (categoryAndFn.category) {
+  case ElementwiseNAryCategory::Unary:
+    result =
+        helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
+    break;
+  case ElementwiseNAryCategory::Binary:
+    result = helper.buildBinaryFn(categoryAndFn.fn.binaryFn,
+                                  block.getArgument(0), block.getArgument(1));
+    break;
+  case ElementwiseNAryCategory::Ternary:
+    result =
+        helper.buildTernaryFn(categoryAndFn.fn.ternaryFn, block.getArgument(0),
+                              block.getArgument(1), block.getArgument(2));
+    break;
+  }
+  assert(result && "found unhandled category in elementwise regionBuilder");
+
+  SmallVector<Value> yields = {result};
+  helper.yieldOutputs(yields);
+}
+
+/// Folding is delegated entirely to memref::foldMemRefCast.
+LogicalResult ElementwiseOp::fold(FoldAdaptor /*adaptor*/,
+                                  SmallVectorImpl<OpFoldResult> & /*results*/) {
+  LogicalResult folded = memref::foldMemRefCast(*this);
+  return folded;
+}
+
+/// Reports no memory effects for pure-tensor ops; otherwise defers to the
+/// generic linalg implementation.
+void ElementwiseOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (!hasPureTensorSemantics())
+    getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+/// Uses the generic linalg speculatability rules.
+Speculation::Speculatability ElementwiseOp::getSpeculatability() {
+  auto linalgOp = cast<LinalgOp>(getOperation());
+  return getGenericSpeculatabilityImpl(linalgOp);
+}
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir b/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir
new file mode 100644
index 0000000000000..2466a77acc236
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops -split-input-file | FileCheck %s
+
+// Unary `exp` with inferred identity maps becomes a linalg.generic whose
+// body applies math.exp to the input element.
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//
+// CHECK: @unary_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
+// CHECK: %[[EXP:.+]] = math.exp %[[A_ARG]] : f32
+// CHECK: linalg.yield %[[EXP]] : f32
+//
+func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_fn<exp>
+ ins(%A : tensor<8x16x32xf32>)
+ outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %r : tensor<8x16x32xf32>
+}
+// -----
+// Unary `tanh` with a user-given transpose+broadcast input map: the map is
+// carried over unchanged into the linalg.generic.
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @unary_transpose_broadcast_tanh(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
+// CHECK: %[[TANH:.+]] = math.tanh %[[A_ARG]] : f32
+// CHECK: linalg.yield %[[TANH]] : f32
+//
+func.func @unary_transpose_broadcast_tanh(%A : tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_fn<tanh>
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+ ins(%A : tensor<32x16xf32>)
+ outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %r : tensor<8x16x32xf32>
+}
+// -----
+// Binary `div` on memrefs (no SSA result) with default maps.
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @binary_div_on_memrefs(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[DIV:.+]] = arith.divf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[DIV]] : f32
+//
+func.func @binary_div_on_memrefs(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+ linalg.elementwise
+ kind=#linalg.elementwise_fn<div>
+ ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>)
+ outs(%C: memref<16x8xf32>)
+ return
+}
+// -----
+// Binary `mul` on tensors with default maps.
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @binary_mul_on_tensors(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[MUL]] : f32
+//
+func.func @binary_mul_on_tensors(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_fn<mul>
+ ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
+// -----
+// Binary `sub` where only the first operand is transposed.
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK: @binary_transpose_a(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[SUB:.+]] = arith.subf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[SUB]] : f32
+//
+func.func @binary_transpose_a(%A : tensor<8x16xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_fn<sub>
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>]
+ ins(%A, %B: tensor<8x16xf32>, tensor<16x8xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
+// -----
+// Binary `add` with a transposed first operand and a broadcast second operand.
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[BROADCAST:.+]] = affine_map<(d0, d1) -> (d0)>
+//
+// CHECK: @binary_transpose_a_broadcast_b(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[BROADCAST]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[ADD:.+]] = arith.addf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+//
+func.func @binary_transpose_a_broadcast_b(%A : tensor<8x16xf32>, %B: tensor<16xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_fn<add>
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>]
+ ins(%A, %B: tensor<8x16xf32>, tensor<16xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
+// -----
+// Ternary `select` with an i1 condition that is transposed and broadcast.
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @ternary(%[[A:.+]]: tensor<32x16xi1>, %[[B:.+]]: tensor<8x16x32xf32>, %[[C:.+]]: tensor<8x16x32xf32>, %[[D:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+//
+// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]]
+// CHECK-SAME: outs(%[[D]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i1, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32, %[[D_ARG:.+]]: f32)
+// CHECK: %[[SELECTED:.+]] = arith.select %[[A_ARG]], %[[B_ARG]], %[[C_ARG]] : f32
+// CHECK: linalg.yield %[[SELECTED]] : f32
+//
+func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8x16x32xf32>, %D : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_fn<select>
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+ ins(%A, %B, %C : tensor<32x16xi1>, tensor<8x16x32xf32>, tensor<8x16x32xf32>)
+ outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %r : tensor<8x16x32xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/element_wise/invalid.mlir b/mlir/test/Dialect/Linalg/element_wise/invalid.mlir
new file mode 100644
index 0000000000000..519183e580538
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/element_wise/invalid.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// An unknown `kind` spelling is rejected while parsing the ElementwiseFnAttr.
+func.func @misspelt_op_div(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+  // expected-error@+3 {{expected ::mlir::linalg::ElementwiseFn to be one of: exp, log, abs, ceil, floor}}
+  // expected-error@+2 {{failed to parse ElementwiseFnAttr parameter}}
+  // expected-error@+1 {{custom op 'linalg.elementwise' expected operation 'kind' attribute}}
+  linalg.elementwise kind=#linalg.elementwise_fn<dive> ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+  return
+}
+
+// -----
+
+// Explicit indexing maps must cover every input and output operand.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @missing_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+  // expected-error@+1 {{'linalg.elementwise' op expected the number of indexing_map (2) to be equal to the number of input/output operands (3)}}
+  linalg.elementwise kind=#linalg.elementwise_fn<div> indexing_maps = [#map, #map] ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+  return
+}
+
+// -----
+
+// A transposed operand needs a transposing map; identity maps give a shape mismatch.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @identity_map_when_transpose_expected(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+  // expected-error@+1 {{'linalg.elementwise' op inferred input/output operand #1 has shape's dimension #0 to be 8, but found 16}}
+  linalg.elementwise kind=#linalg.elementwise_fn<div> indexing_maps = [#map, #map, #map] ins(%A, %B: memref<8x16xf32>, memref<8x16xf32>) outs(%C: memref<16x8xf32>)
+  return
+}
+
+// -----
+
+// The loop count comes from the result rank, so input maps must have that many dims.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+func.func @incorrect_result_rank(%A : memref<8x16xf32>, %B: memref<8x16xf32>, %C: memref<8xf32>) {
+  // expected-error@+1 {{'linalg.elementwise' op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
+  linalg.elementwise kind=#linalg.elementwise_fn<div> indexing_maps = [#map, #map, #map1] ins(%A, %B: memref<8x16xf32>, memref<8x16xf32>) outs(%C: memref<8xf32>)
+  return
+}
+
+// -----
+
+// Operand count must match the arity implied by `kind`. The offsets below
+// resolve to the `return` line. The arity error is reported at the parser's
+// current location, presumably the token following the op.
+func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>, %C: memref<8x16x32xf32>) {
+  // expected-error@+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 2 args, got 3}}
+  // expected-error@+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
+  linalg.elementwise kind=#linalg.elementwise_fn<exp> ins(%A, %B : memref<8x16x32xf32>, memref<8x16x32xf32>) outs(%C: memref<8x16x32xf32>)
+  return
+}
+
+// -----
+
+// Same as above, with too few operands for a binary kind.
+func.func @binary_too_few_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>) {
+  // expected-error@+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 3 args, got 2}}
+  // expected-error@+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
+  linalg.elementwise kind=#linalg.elementwise_fn<add> ins(%A : memref<8x16x32xf32>) outs(%B: memref<8x16x32xf32>)
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/element_wise/round-trip.mlir b/mlir/test/Dialect/Linalg/element_wise/round-trip.mlir
new file mode 100644
index 0000000000000..f4659f89785e4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/element_wise/round-trip.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+//
+// Note - the functions are named @{unary|binary}_{identity|transpose|broadcast|transpose_a|...}_{exp|mul|div|..}
+
+// CHECK: @unary_identity_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: %{{.*}} = linalg.elementwise kind=#linalg.elementwise_fn<exp>
+// CHECK-SAME: ins(%[[A]] : tensor<8x16x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>)
+//
+func.func @unary_identity_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_fn<exp>
+      ins(%A : tensor<8x16x32xf32>)
+      outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @unary_projection_tanh(%[[A:.+]]: tensor<?x16xf32>,
+// CHECK-SAME: %[[B:.+]]: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
+// CHECK: {{.*}} = linalg.elementwise kind=#linalg.elementwise_fn<tanh>
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]] : tensor<?x16xf32>) outs(%[[B]] : tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
+//
+func.func @unary_projection_tanh(%A: tensor<?x16xf32>,
+                                 %B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_fn<tanh>
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+      ins(%A : tensor<?x16xf32>)
+      outs(%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
+  return %r : tensor<8x16x?xf32>
+}
+
+// -----
+
+// CHECK: @binary_identity_div(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>,
+// CHECK-SAME: %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_fn<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @binary_identity_div(%A: tensor<16x8xf32>, %B: tensor<16x8xf32>,
+                               %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_fn<div>
+      ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
+      outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+
+// -----
+
+// All three arguments have the same type and are matched by separate
+// directives, so A is restricted to an identifier (a greedy `.+` would
+// swallow the following argument as well).
+// CHECK: @binary_identity_mul_5Di(%[[A:[a-zA-Z0-9_]+]]: tensor<1x2x3x4x5xi32>,
+// CHECK-SAME: %[[B:.+]]: tensor<1x2x3x4x5xi32>,
+// CHECK-SAME: %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_fn<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
+// CHECK-SAME: outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
+//
+func.func @binary_identity_mul_5Di(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
+                                   %C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_fn<mul>
+      ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
+      outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
+  return %r : tensor<1x2x3x4x5xi32>
+}
+
+// -----
+
+// Explicit maps equal to the inferred identity defaults are elided on print.
+// CHECK: @redundant_maps
+// CHECK-NOT: indexing_maps
+//
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
+                          %C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
+  %r = linalg.elementwise kind=#linalg.elementwise_fn<mul>
+      indexing_maps = [#map, #map, #map]
+      ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
+      outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
+  return %r : tensor<1x2x3x4x5xi32>
+}
>From 1a68c5f937d0001a86974c56f54c8b36935b648b Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 1 Feb 2025 16:40:52 -0500
Subject: [PATCH 2/3] [mlir][linalg] change based on review comments
---
.../mlir/Dialect/Linalg/IR/LinalgEnums.td | 15 +++++++++++++++
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 8 ++++----
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 17 +++++++++--------
3 files changed, 28 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index c41b541835751..36e8edc11cc97 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -89,6 +89,21 @@ def ElementwiseFn :
let cppNamespace = "::mlir::linalg";
}
+// Define an `enum class : i32` that marks where each individual enum class
+// e.g. UnaryFn, BinaryFn, etc. end in the unified enum class ElementwiseFn.
+def ElementwiseFnLimits : I32EnumAttr<"ElementwiseFnLimits", "", []> {
+ int last_unary = !size(UnaryFn.enumerants);
+ int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
+ int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));
+
+ let enumerants = [
+ I32EnumAttrCase<"LastUnary", last_unary>,
+ I32EnumAttrCase<"LastBinary", last_binary>,
+ I32EnumAttrCase<"LastTernary", last_ternary>];
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
+
// Define an `enum class : i32` to categorise elementwise ops.
def ElementwiseNAryCategory : I32EnumAttr<"ElementwiseNAryCategory", "", [
I32EnumAttrCase<"Unary", 0>,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 2d82eef41c2f2..1a67174db89fb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -563,15 +563,15 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
The attribute `kind` describes the operation (e.g. add, exp). The operation
kind can be any elementwise nary (e.g. unary, binary) operation.
- Affine-maps for operands and result are reuired to be provided by the user
+ Affine-maps for operands and result are required to be provided by the user
when transpose and/or broadcast is needed on any operand. When a map is not
provided, default identity maps are inferred for each operand. The number
of dims in each of the identity maps is equal to the rank of the output type.
In the case of default indexing map, all input and output shapes must match.
- User-defined Affine-map for operands and result must only be projected
+ User-defined affine-map for operands and result must only be projected
permutations with no zero constants.
- For elementwise, iterator-types are always 'all parallel’.
+ For elementwise, iterator-types are always `all parallel`.
Iterator-types are needed for constructing the underlying structured op.
The number of dims of the iterator-types are inferred from the rank of
the result type.
@@ -597,7 +597,7 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
}];
let arguments = (ins
- Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
ElementwiseFnAttr:$kind,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e1947a864d4d0..0db2bf57bd23c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3667,28 +3667,29 @@ unsigned getArityFromCategory(ElementwiseNAryCategory category) {
} // namespace
static NAryCategoryAndFn getNAryCategoryAndFn(ElementwiseFn fn) {
- constexpr int lastUnary = static_cast<int>(ElementwiseFn::erf);
- constexpr int lastBinary = static_cast<int>(ElementwiseFn::powf);
- constexpr int lastTernary = static_cast<int>(ElementwiseFn::select);
+ constexpr int lastUnary = static_cast<int>(ElementwiseFnLimits::LastUnary);
+ constexpr int lastBinary = static_cast<int>(ElementwiseFnLimits::LastBinary);
+ constexpr int lastTernary =
+ static_cast<int>(ElementwiseFnLimits::LastTernary);
int val = static_cast<int>(fn);
NAryCategoryAndFn result;
- if (val <= lastUnary) {
+ if (val < lastUnary) {
result.category = ElementwiseNAryCategory::Unary;
result.fn.unaryFn = static_cast<UnaryFn>(val);
return result;
}
- if (val <= lastBinary) {
+ if (val < lastBinary) {
result.category = ElementwiseNAryCategory::Binary;
- result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary - 1);
+ result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary);
return result;
}
- if (val > lastTernary) {
+ if (val >= lastTernary) {
llvm_unreachable("unhandled ElementwiseFn");
}
result.category = ElementwiseNAryCategory::Ternary;
- result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary - 1);
+ result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
return result;
}
>From 42325d10ff7bf273044515532bcf6d91e84cdbe4 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Mon, 17 Feb 2025 13:36:14 -0500
Subject: [PATCH 3/3] [mlir][linalg] change based on even more review comments.
---
.../mlir/Dialect/Linalg/IR/LinalgBase.td | 6 +-
.../mlir/Dialect/Linalg/IR/LinalgEnums.td | 18 +-
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 49 +++---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 155 +++++++-----------
.../element_wise/generalize_named_ops.mlir | 30 ++--
.../Dialect/Linalg/element_wise/invalid.mlir | 16 +-
.../Linalg/element_wise/round-trip.mlir | 22 +--
7 files changed, 136 insertions(+), 160 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 00e3633610ccb..4452189fde64f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,9 +61,9 @@ def Linalg_Dialect : Dialect {
}];
}
-// Define the attribute enums matching elementwise op function (e.g., add).
-def ElementwiseFnAttr : EnumAttr<Linalg_Dialect,
- ElementwiseFn, "elementwise_fn"> {
+// Define the attribute enum matching the elementwise op kind (e.g., add).
+def ElementwiseKindAttr : EnumAttr<Linalg_Dialect,
+ ElementwiseKind, "elementwise_kind"> {
let assemblyFormat = "`<` $value `>`";
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 36e8edc11cc97..ce68afe471fe8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -78,8 +78,8 @@ class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
}
// Define a unified `enum class : i32` for all element-wise op functions.
-def ElementwiseFn :
- I32EnumAttr<"ElementwiseFn",
+def ElementwiseKind :
+ I32EnumAttr<"ElementwiseKind",
"",
ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
BinaryFn.enumerants,
@@ -90,8 +90,8 @@ def ElementwiseFn :
}
// Define an `enum class : i32` that marks where each individual enum class
-// e.g. UnaryFn, BinaryFn, etc. end in the unified enum class ElementwiseFn.
-def ElementwiseFnLimits : I32EnumAttr<"ElementwiseFnLimits", "", []> {
+// e.g. UnaryFn, BinaryFn, etc., ends (exclusive) in the unified enum class ElementwiseKind.
+def ElementwiseCaseLimits : I32EnumAttr<"ElementwiseCaseLimits", "", []> {
int last_unary = !size(UnaryFn.enumerants);
int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));
@@ -104,11 +104,11 @@ def ElementwiseFnLimits : I32EnumAttr<"ElementwiseFnLimits", "", []> {
let cppNamespace = "::mlir::linalg";
}
-// Define an `enum class : i32` to categorise elementwise ops.
-def ElementwiseNAryCategory : I32EnumAttr<"ElementwiseNAryCategory", "", [
- I32EnumAttrCase<"Unary", 0>,
- I32EnumAttrCase<"Binary", 1>,
- I32EnumAttrCase<"Ternary", 2>
+// Define an `enum class : i32` grouping elementwise ops by arity (case value == arity).
+def ElementwiseArityGroup : I32EnumAttr<"ElementwiseArityGroup", "", [
+ I32EnumAttrCase<"Unary", 1>,
+ I32EnumAttrCase<"Binary", 2>,
+ I32EnumAttrCase<"Ternary", 3>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 1a67174db89fb..59b92d8130a47 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -558,20 +558,19 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
AttrSizedOperandSegments]> {
let summary = [{ Performs element-wise operation }];
let description = [{
- Linalg op form which performs element-wise computation.
+ The attribute `kind` describes the arithmetic operation to perform. The
+ operation kind can be unary (e.g. exp), binary (e.g. add) or ternary
+ (e.g. select).
- The attribute `kind` describes the operation (e.g. add, exp). The operation
- kind can be any elementwise nary (e.g. unary, binary) operation.
+ By default, all indexing maps are identities. In the case of default
+ indexing maps, all input and output shapes must match. The number of dims in
+ each of the identity maps is equal to the rank of the output type.
Affine-maps for operands and result are required to be provided by the user
- when transpose and/or broadcast is needed on any operand. When a map is not
- provided, default identity maps are inferred for each operand. The number
- of dims in each of the identity maps is equal to the rank of the output type.
- In the case of default indexing map, all input and output shapes must match.
- User-defined affine-map for operands and result must only be projected
- permutations with no zero constants.
-
- For elementwise, iterator-types are always `all parallel`.
+ when a transpose and/or broadcast is needed on any operand. When a map is not
+ provided, default identity maps are inferred for each operand.
+
+ Iterator-types are always all `parallel`.
Iterator-types are needed for constructing the underlying structured op.
The number of dims of the iterator-types are inferred from the rank of
the result type.
@@ -581,7 +580,7 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
Defining a unary linalg.elemwise with default indexing-map:
```mlir
%exp = linalg.elemwise
- kind=#linalg.elemwise_fn<exp>
+ kind=#linalg.elementwise_kind<exp>
ins(%x : tensor<4x16x8xf32>)
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
```
@@ -589,7 +588,7 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
Defining a binary linalg.elemwise with user-defined indexing-map:
```mlir
%add = linalg.elemwise
- kind=#linalg.elemwise_fn<add>
+ kind=#linalg.elementwise_kind<add>
indexing_maps = [#transpose, #broadcast, #identity]
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
@@ -597,9 +596,9 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
}];
let arguments = (ins
- Variadic<AnyShaped>:$inputs,
+ Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
- ElementwiseFnAttr:$kind,
+ ElementwiseKindAttr:$kind,
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
);
@@ -612,7 +611,7 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
- buildElementwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
attributes, ElementwiseOp::getRegionBuilder());
}]>
];
@@ -622,12 +621,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
let hasVerifier = 1;
let extraClassDeclaration = structuredOpsBaseDecls # [{
- /// Get the nary category enum, e.g. `ElementwiseNAryCategory::Unary`,
- /// corresponding to the given fn, e.g. `ElementwiseFn::exp`
- static ElementwiseNAryCategory getNAryCategory(ElementwiseFn fn);
+ /// Get the arity enum corresponding to the kind of op, e.g. if arg is
+ /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
+ static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
/// Both user-specified and default indexing map will always depend on
- /// the current Op instance.
+ /// the current Op instance.
static bool hasDynamicIndexingMaps() { return true; }
/// Implements the block region builder for the elementwiseOp. This is
@@ -644,15 +643,19 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
/// Returns rank of the result tensor/memref. Useful for knowing
/// the dimensionality of the iteration space when others means
/// are not possible e.g. absence of user-provided indexing map.
- unsigned getResultRank();
+ unsigned getResultRank() {
+ Value output = getDpsInitOperand(0)->get();
+ ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
+ return shapedType.getRank();
+ }
/// Returns N 'parallel' iterator types where N is rank of result.
SmallVector<utils::IteratorType> getIteratorTypesArray();
/// The default indexing maps are identities.
- /// There will be N such maps, where N is the arity of the Op.
+ /// Callers request one map per operand: the arity of the Op plus one for the output.
static SmallVector<AffineMap>
- getDefaultIndexingMaps(unsigned N, unsigned numDims,
+ getDefaultIndexingMaps(unsigned NumMaps, unsigned numDims,
MLIRContext *context);
/// Destination passing style interface method.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0db2bf57bd23c..e3374a577846e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -203,15 +203,6 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
attributes, regionBuilder);
}
-static void buildElementwiseOp(OpBuilder &b, OperationState &state,
- std::optional<TypeRange> resultTensorTypes,
- ValueRange inputs, ValueRange outputs,
- ArrayRef<NamedAttribute> attributes,
- RegionBuilderFn regionBuilder) {
- return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
- attributes, regionBuilder);
-}
-
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
@@ -3625,80 +3616,51 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
//===----------------------------------------------------------------------===//
//
namespace {
+struct ArityGroupAndKind {
+ // The enum class {Unary, Binary, Ternary, ..}
+ ElementwiseArityGroup arityGroup;
-struct NAryCategoryAndFn {
- // The enum category class {Unary, Binary, Ternary, ..}
- ElementwiseNAryCategory category;
-
- union NAryFn {
+ // The kind (e.g. `exp` or `add`) belonging to the arity group.
+ union Kind {
UnaryFn unaryFn;
BinaryFn binaryFn;
TernaryFn ternaryFn;
- } fn;
-
- ::llvm::StringRef stringifyCategory() {
- return stringifyElementwiseNAryCategory(category);
- }
-
- ::llvm::StringRef stringifyFn() {
- switch (category) {
- case ElementwiseNAryCategory::Unary:
- return stringifyUnaryFn(fn.unaryFn);
- case ElementwiseNAryCategory::Binary:
- return stringifyBinaryFn(fn.binaryFn);
- case ElementwiseNAryCategory::Ternary:
- return stringifyTernaryFn(fn.ternaryFn);
- }
- llvm_unreachable("unknown-fn");
- }
+ } kind;
};
-unsigned getArityFromCategory(ElementwiseNAryCategory category) {
- switch (category) {
- case ElementwiseNAryCategory::Unary:
- return 1;
- case ElementwiseNAryCategory::Binary:
- return 2;
- case ElementwiseNAryCategory::Ternary:
- return 3;
- }
- llvm_unreachable("unhandled category");
+unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
+ return static_cast<unsigned>(arityGroup);
}
} // namespace
-static NAryCategoryAndFn getNAryCategoryAndFn(ElementwiseFn fn) {
- constexpr int lastUnary = static_cast<int>(ElementwiseFnLimits::LastUnary);
- constexpr int lastBinary = static_cast<int>(ElementwiseFnLimits::LastBinary);
+static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
+ constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
+ constexpr int lastBinary =
+ static_cast<int>(ElementwiseCaseLimits::LastBinary);
constexpr int lastTernary =
- static_cast<int>(ElementwiseFnLimits::LastTernary);
+ static_cast<int>(ElementwiseCaseLimits::LastTernary);
- int val = static_cast<int>(fn);
- NAryCategoryAndFn result;
+ int val = static_cast<int>(kind);
+ ArityGroupAndKind result;
if (val < lastUnary) {
- result.category = ElementwiseNAryCategory::Unary;
- result.fn.unaryFn = static_cast<UnaryFn>(val);
+ result.arityGroup = ElementwiseArityGroup::Unary;
+ result.kind.unaryFn = static_cast<UnaryFn>(val);
return result;
}
if (val < lastBinary) {
- result.category = ElementwiseNAryCategory::Binary;
- result.fn.binaryFn = static_cast<BinaryFn>(val - lastUnary);
+ result.arityGroup = ElementwiseArityGroup::Binary;
+ result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
return result;
}
if (val >= lastTernary) {
llvm_unreachable("unhandled ElementwiseFn");
}
- result.category = ElementwiseNAryCategory::Ternary;
- result.fn.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
+ result.arityGroup = ElementwiseArityGroup::Ternary;
+ result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
return result;
}
-unsigned ElementwiseOp::getResultRank() {
- auto output = getDpsInitOperand(0)->get();
- auto shapedType = llvm::cast<ShapedType>(output.getType());
- return shapedType.getRank();
-}
-
SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
auto rank = getResultRank();
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
@@ -3712,25 +3674,24 @@ ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
}
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
- // Expect e.g. `kind = #linalg.elemwise_fn<add>`
+ // Expect e.g. `kind = #linalg.elemwise_kind<add>`
Attribute attr;
- mlir::linalg::ElementwiseFn elemwiseFnVal;
- if (parser.parseKeyword("kind"))
- return failure();
- if (parser.parseEqual())
+ mlir::linalg::ElementwiseKind elemwiseKindVal;
+ if (parser.parseKeyword("kind") || parser.parseEqual())
return failure();
+
if (succeeded(parser.parseAttribute(attr))) {
- auto elemwiseFnAttr = dyn_cast<ElementwiseFnAttr>(attr);
- if (!elemwiseFnAttr)
+ auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
+ if (!elemwiseKindAttr)
return parser.emitError(parser.getCurrentLocation(),
- "expected ElementwiseFn attribute");
- elemwiseFnVal = elemwiseFnAttr.getValue();
+ "expected ElementwiseKind attribute");
+ elemwiseKindVal = elemwiseKindAttr.getValue();
} else {
return parser.emitError(parser.getCurrentLocation(),
"expected operation 'kind' attribute");
}
result.addAttribute(
- "kind", ElementwiseFnAttr::get(parser.getContext(), elemwiseFnVal));
+ "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
// Parse optional `indexing_maps`
SmallVector<Attribute, 3> indexingMapsAttr;
@@ -3756,9 +3717,9 @@ ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
// At this stage of parsing the only way to infer number of region
// args is through op kind, as input output tensors are not parsed yet.
- auto arityAndCategory = getNAryCategoryAndFn(elemwiseFnVal);
- auto arity = getArityFromCategory(arityAndCategory.category);
- int numRegionArgs = arity + 1 /*output*/;
+ auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
+ int numRegionArgs =
+ getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
if (parseNamedStructuredOp(parser, result, numRegionArgs,
ElementwiseOp::getRegionBuilder())) {
return parser.emitError(parser.getCurrentLocation(),
@@ -3767,7 +3728,7 @@ ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
// Initialize indexingMaps, if not supplied explicitly.
if (indexingMapsAttr.empty()) {
- // We need to infer the `number of indexing maps` needed from the result
+ // We need to infer the numDims of the indexing maps from the output
// type which is already parsed by now.
auto resultType = result.operands[result.operands.size() - 1].getType();
auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
@@ -3776,7 +3737,7 @@ ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
"return type needs to be shaped type");
auto numDims = shapedType.getRank();
indexingMapsAttr = llvm::map_to_vector(
- ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims,
+ ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
parser.getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
}
@@ -3791,12 +3752,12 @@ void ElementwiseOp::print(OpAsmPrinter &p) {
p.printAttribute(getKindAttr());
SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
"indexing_maps"};
- auto category = getNAryCategoryAndFn(getKind()).category;
- auto arity = getArityFromCategory(category);
- auto numDims = getResultRank();
+ unsigned arity = getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
+ unsigned numDims = getResultRank();
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
- ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
+ ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
+ getContext()),
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
@@ -3821,19 +3782,20 @@ LogicalResult ElementwiseOp::verify() {
/// 'fillStructuredOpRegion'.
void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
ArrayRef<NamedAttribute> attrs) {
- ElementwiseFn elemwiseFn;
+ ElementwiseKind elemwiseKind;
for (auto attr : attrs) {
if (attr.getName() == b.getStringAttr("kind")) {
- auto funcTypeAttr = dyn_cast<ElementwiseFnAttr>(attr.getValue());
- assert(funcTypeAttr && "op kind attribute incorrectly set");
- elemwiseFn = funcTypeAttr.getValue();
+ auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
+ assert(kindAttr && "op kind attribute incorrectly set");
+ elemwiseKind = kindAttr.getValue();
break;
}
}
- NAryCategoryAndFn categoryAndFn = getNAryCategoryAndFn(elemwiseFn);
- ElementwiseNAryCategory category = categoryAndFn.category;
- unsigned numBlockArgs = getArityFromCategory(categoryAndFn.category) + 1;
+ ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
+ auto arityGroup = groupAndKind.arityGroup;
+ auto kind = groupAndKind.kind;
+ unsigned numBlockArgs = getArityGroupAsUInt(arityGroup) + 1 /*output*/;
assert(block.getNumArguments() == numBlockArgs &&
"Elementwise regionBuilder number of block args mismatch");
@@ -3841,18 +3803,19 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
SmallVector<Value> yields;
Value result;
- if (category == ElementwiseNAryCategory::Unary) {
- result =
- helper.buildUnaryFn(categoryAndFn.fn.unaryFn, block.getArgument(0));
- } else if (category == ElementwiseNAryCategory::Binary) {
- result = helper.buildBinaryFn(categoryAndFn.fn.binaryFn,
- block.getArgument(0), block.getArgument(1));
- } else if (category == ElementwiseNAryCategory::Ternary) {
- result =
- helper.buildTernaryFn(categoryAndFn.fn.ternaryFn, block.getArgument(0),
- block.getArgument(1), block.getArgument(2));
+ if (arityGroup == ElementwiseArityGroup::Unary) {
+ result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+
+ } else if (arityGroup == ElementwiseArityGroup::Binary) {
+ result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
+ block.getArgument(1));
+
+ } else if (arityGroup == ElementwiseArityGroup::Ternary) {
+ result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
+ block.getArgument(1), block.getArgument(2));
+
} else
- assert(false && "found unhandled category in elemwise print");
+ assert(false && "found unhandled category in elemwise");
yields.push_back(result);
helper.yieldOutputs(yields);
diff --git a/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir b/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir
index 2466a77acc236..1b1725692a023 100644
--- a/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir
+++ b/mlir/test/Dialect/Linalg/element_wise/generalize_named_ops.mlir
@@ -14,7 +14,7 @@
//
func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<exp>
+ kind=#linalg.elementwise_kind<exp>
ins(%A : tensor<8x16x32xf32>)
outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
@@ -36,8 +36,9 @@ func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tens
//
func.func @unary_transpose_broadcast_tanh(%A : tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<tanh>
- indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+ kind=#linalg.elementwise_kind<tanh>
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
ins(%A : tensor<32x16xf32>)
outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
@@ -58,7 +59,7 @@ func.func @unary_transpose_broadcast_tanh(%A : tensor<32x16xf32>, %B: tensor<8x1
//
func.func @binary_div_on_memrefs(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
linalg.elementwise
- kind=#linalg.elementwise_fn<div>
+ kind=#linalg.elementwise_kind<div>
ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>)
outs(%C: memref<16x8xf32>)
return
@@ -79,7 +80,7 @@ func.func @binary_div_on_memrefs(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C
//
func.func @binary_mul_on_tensors(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<mul>
+ kind=#linalg.elementwise_kind<mul>
ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
@@ -101,8 +102,10 @@ func.func @binary_mul_on_tensors(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C
//
func.func @binary_transpose_a(%A : tensor<8x16xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<sub>
- indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>]
+ kind=#linalg.elementwise_kind<sub>
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>]
ins(%A, %B: tensor<8x16xf32>, tensor<16x8xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
@@ -125,8 +128,10 @@ func.func @binary_transpose_a(%A : tensor<8x16xf32>, %B: tensor<16x8xf32>, %C: t
//
func.func @binary_transpose_a_broadcast_b(%A : tensor<8x16xf32>, %B: tensor<16xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<add>
- indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>]
+ kind=#linalg.elementwise_kind<add>
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>]
ins(%A, %B: tensor<8x16xf32>, tensor<16xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
@@ -149,8 +154,11 @@ func.func @binary_transpose_a_broadcast_b(%A : tensor<8x16xf32>, %B: tensor<16xf
//
func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8x16x32xf32>, %D : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<select>
- indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+ kind=#linalg.elementwise_kind<select>
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
ins(%A, %B, %C : tensor<32x16xi1>, tensor<8x16x32xf32>, tensor<8x16x32xf32>)
outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
diff --git a/mlir/test/Dialect/Linalg/element_wise/invalid.mlir b/mlir/test/Dialect/Linalg/element_wise/invalid.mlir
index 519183e580538..4567befe502a0 100644
--- a/mlir/test/Dialect/Linalg/element_wise/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/element_wise/invalid.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @misspelt_op_div(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
- // expected-error at +3 {{expected ::mlir::linalg::ElementwiseFn to be one of: exp, log, abs, ceil, floor}}
- // expected-error at +2 {{failed to parse ElementwiseFnAttr parameter}}
+ // expected-error at +3 {{expected ::mlir::linalg::ElementwiseKind to be one of: exp, log, abs, ceil, floor}}
+ // expected-error at +2 {{failed to parse ElementwiseKindAttr parameter}}
// expected-error at +1 {{custom op 'linalg.elementwise' expected operation 'kind' attribute}}
- linalg.elementwise kind=#linalg.elementwise_fn<dive> ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<dive> ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
return
}
@@ -12,7 +12,7 @@ func.func @misspelt_op_div(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memr
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @missing_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
// expected-error at +1 {{'linalg.elementwise' op expected the number of indexing_map (2) to be equal to the number of input/output operands (3)}}
- linalg.elementwise kind=#linalg.elementwise_fn<div> indexing_maps = [#map, #map] ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map] ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
return
}
@@ -21,7 +21,7 @@ func.func @missing_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C:
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @identity_map_when_transpose_expected(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
// expected-error at +1 {{'linalg.elementwise' op inferred input/output operand #1 has shape's dimension #0 to be 8, but found 16}}
- linalg.elementwise kind=#linalg.elementwise_fn<div> indexing_maps = [#map, #map, #map] ins(%A, %B: memref<8x16xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map] ins(%A, %B: memref<8x16xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
return
}
@@ -31,7 +31,7 @@ func.func @identity_map_when_transpose_expected(%A : memref<8x16xf32>, %B: memre
#map1 = affine_map<(d0, d1) -> (d0)>
func.func @incorrect_result_rank(%A : memref<8x16xf32>, %B: memref<8x16xf32>, %C: memref<8xf32>) {
// expected-error at +1 {{'linalg.elementwise' op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
- linalg.elementwise kind=#linalg.elementwise_fn<div> indexing_maps = [#map, #map, #map1] ins(%A, %B: memref<8x16xf32>, memref<8x16xf32>) outs(%C: memref<8xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map1] ins(%A, %B: memref<8x16xf32>, memref<8x16xf32>) outs(%C: memref<8xf32>)
return
}
@@ -40,7 +40,7 @@ func.func @incorrect_result_rank(%A : memref<8x16xf32>, %B: memref<8x16xf32>, %C
func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>, %C: memref<8x16x32xf32>) {
// expected-error at +3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 2 args, got 3}}
// expected-error at +2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
- linalg.elementwise kind=#linalg.elementwise_fn<exp> ins(%A, %B : memref<8x16x32xf32>, memref<8x16x32xf32>) outs(%C: memref<8x16x32xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<exp> ins(%A, %B : memref<8x16x32xf32>, memref<8x16x32xf32>) outs(%C: memref<8x16x32xf32>)
return
}
@@ -49,6 +49,6 @@ func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>
func.func @binary_too_few_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>) {
// expected-error @+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 3 args, got 2}}
// expected-error @+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
- linalg.elementwise kind=#linalg.elementwise_fn<add> ins(%A : memref<8x16x32xf32>) outs(%B: memref<8x16x32xf32>)
+ linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%A : memref<8x16x32xf32>) outs(%B: memref<8x16x32xf32>)
return
}
diff --git a/mlir/test/Dialect/Linalg/element_wise/round-trip.mlir b/mlir/test/Dialect/Linalg/element_wise/round-trip.mlir
index f4659f89785e4..6ae2a77eb19f8 100644
--- a/mlir/test/Dialect/Linalg/element_wise/round-trip.mlir
+++ b/mlir/test/Dialect/Linalg/element_wise/round-trip.mlir
@@ -3,12 +3,12 @@
// Note - the functions are named @{unary|binary}_{identity|transpose|broadcast|transpose_a|...}_{exp|mul|div|..}
// CHECK: @unary_identity_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
-// CHECK: %{{.*}} = linalg.elementwise kind=#linalg.elementwise_fn<exp>
+// CHECK: %{{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<exp>
// CHECK-SAME ins(%[[A:.+]] : tensor<8x16x32xf32>) outs(%[[B:.+]] : tensor<8x16x32xf32>)
//
func.func @unary_identity_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<exp>
+ kind=#linalg.elementwise_kind<exp>
ins(%A : tensor<8x16x32xf32>)
outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %r : tensor<8x16x32xf32>
@@ -21,15 +21,16 @@ func.func @unary_identity_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>)
//
// CHECK: @unary_projection_tanh(%[[A:.+]]: tensor<?x16xf32>,
// CHECK-SAME: %[[B:.+]]: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
-// CHECK: {{.*}} = linalg.elementwise kind=#linalg.elementwise_fn<tanh>
+// CHECK: {{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<tanh>
// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]] : tensor<?x16xf32>) outs(%[[B]] : tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
//
func.func @unary_projection_tanh(%A: tensor<?x16xf32>,
%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<tanh>
- indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+ kind=#linalg.elementwise_kind<tanh>
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
ins(%A : tensor<?x16xf32>)
outs(%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
return %r : tensor<8x16x?xf32>
@@ -40,14 +41,14 @@ func.func @unary_projection_tanh(%A: tensor<?x16xf32>,
// CHECK: @binary_identity_div(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>,
// CHECK-SAME: %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
// CHECK: {{.*}} = linalg.elementwise
-// CHECK-SAME: kind=#linalg.elementwise_fn<div>
+// CHECK-SAME: kind=#linalg.elementwise_kind<div>
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
//
func.func @binary_identity_div(%A: tensor<16x8xf32>, %B: tensor<16x8xf32>,
%C: tensor<16x8xf32>) -> tensor<16x8xf32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<div>
+ kind=#linalg.elementwise_kind<div>
ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
return %r : tensor<16x8xf32>
@@ -59,14 +60,14 @@ func.func @binary_identity_div(%A: tensor<16x8xf32>, %B: tensor<16x8xf32>,
// CHECK-SAME: %[[B:.+]]: tensor<1x2x3x4x5xi32>,
// CHECK-SAME: %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
// CHECK: {{.*}} = linalg.elementwise
-// CHECK-SAME: kind=#linalg.elementwise_fn<mul>
+// CHECK-SAME: kind=#linalg.elementwise_kind<mul>
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
// CHECK-SAME: outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
//
func.func @binary_identity_mul_5Di(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
%r = linalg.elementwise
- kind=#linalg.elementwise_fn<mul>
+ kind=#linalg.elementwise_kind<mul>
ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
return %r : tensor<1x2x3x4x5xi32>
@@ -80,7 +81,8 @@ func.func @binary_identity_mul_5Di(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
- %r = linalg.elementwise kind=#linalg.elementwise_fn<mul>
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<mul>
indexing_maps = [#map, #map, #map]
ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
More information about the Mlir-commits
mailing list