[Mlir-commits] [mlir] 6de5d1e - [mlir][linalg] Extend elementwise (#124661)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 21 02:51:25 PST 2025
Author: Javed Absar
Date: 2025-02-21T10:51:21Z
New Revision: 6de5d1e46d1812de2bbbbe8d8d2c811e4d16acbe
URL: https://github.com/llvm/llvm-project/commit/6de5d1e46d1812de2bbbbe8d8d2c811e4d16acbe
DIFF: https://github.com/llvm/llvm-project/commit/6de5d1e46d1812de2bbbbe8d8d2c811e4d16acbe.diff
LOG: [mlir][linalg] Extend elementwise (#124661)
Implements Linalg elemwise named-op following the proposal and
discussions in RFC:
https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1
Added:
mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
mlir/test/Dialect/Linalg/elementwise/invalid.mlir
mlir/test/Dialect/Linalg/elementwise/round-trip.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d3..33601c5d6dad9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
}];
}
+// Define the attribute enums matching elementwise op kind (e.g., add).
+def ElementwiseKindAttr : EnumAttr<Linalg_Dialect,
+ ElementwiseKind, "elementwise_kind"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
// Define the function attribute enums matching the OpDSL functions.
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
let assemblyFormat = "`<` $value `>`";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d05..ce68afe471fe8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
+
+// Join two I32EnumAttrCase lists. This joining takes care that the
+// 'int enum values' in the combined list do not overlap. It does this
+// by adding to each element of second list the offset '!size(a)'.
+class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
+ list<I32EnumAttrCase> b> {
+ int aSize = !size(a);
+ list<I32EnumAttrCase> result =
+ !foldl(a, b, acc, var,
+ acc # [I32EnumAttrCase<var.symbol,
+ !add(var.value, aSize)
+ >]);
+}
+
+// Flatten 'list of list of I32EnumAttrCase' to 'list of I32EnumAttrCase'.
+// The flattening (via call to 'join') ensures no overlap in enum values.
+class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
+ list<I32EnumAttrCase> result =
+ !foldl([]<I32EnumAttrCase>, l, acc, var,
+ JoinTwoI32EnumAttrCaseList<acc, var>.result);
+}
+
+// Define a unified `enum class : i32` for all element-wise op functions.
+def ElementwiseKind :
+ I32EnumAttr<"ElementwiseKind",
+ "",
+ ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
+ BinaryFn.enumerants,
+ TernaryFn.enumerants]>.result
+ > {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
+
+// Define an `enum class : i32` that marks where each individual enum class
+// e.g. UnaryFn, BinaryFn, etc. end in the unified enum class ElementwiseKind.
+def ElementwiseCaseLimits : I32EnumAttr<"ElementwiseCaseLimits", "", []> {
+ int last_unary = !size(UnaryFn.enumerants);
+ int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
+ int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));
+
+ let enumerants = [
+ I32EnumAttrCase<"LastUnary", last_unary>,
+ I32EnumAttrCase<"LastBinary", last_binary>,
+ I32EnumAttrCase<"LastTernary", last_ternary>];
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
+
+// Define an `enum class : i32` to categorise elementwise ops by arity.
+def ElementwiseArityGroup : I32EnumAttr<"ElementwiseArityGroup", "", [
+ I32EnumAttrCase<"Unary", 1>,
+ I32EnumAttrCase<"Binary", 2>,
+ I32EnumAttrCase<"Ternary", 3>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::linalg";
+}
+
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index a5725d6f1507e..ce6e9e7bb28c4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// Op definition for ElementwiseOp
+//===----------------------------------------------------------------------===//
+def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
+ AttrSizedOperandSegments]> {
+ let summary = [{ Performs element-wise operation }];
+ let description = [{
+ The attribute `kind` describes arithmetic operation to perform. The
+ operation kind can be unary (e.g. exp), binary (e.g. add) or ternary
+ (e.g. select).
+
+ By default, all indexing maps are identities. In the case of default
+ indexing map, all input and output shapes must match. The number of dims in
+ each of the identity maps is equal to the rank of the output type.
+
+ Affine-maps for operands and result are required to be provided by the user
+ when a transpose and/or broadcast is needed on any operand. When a map is not
+ provided, default identity maps are inferred for each operand.
+
+ Iterator-types are always all `parallel`.
+ Iterator-types are needed for constructing the underlying structured op.
+
+ The number of dims of the iterator-types are inferred from the rank of
+ the result type.
+
+ Example:
+
+ Defining a unary linalg.elementwise with default indexing-map:
+ ```mlir
+ %exp = linalg.elementwise
+ kind=#linalg.elementwise_kind<exp>
+ ins(%x : tensor<4x16x8xf32>)
+ outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
+ ```
+
+ Defining a binary linalg.elementwise with user-defined indexing-map:
+ ```mlir
+ %add = linalg.elementwise
+ kind=#linalg.elementwise_kind<add>
+ indexing_maps = [#transpose, #broadcast, #identity]
+ ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
+ outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ ElementwiseKindAttr:$kind,
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+ );
+
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let regions = (region AnyRegion:$region);
+ let skipDefaultBuilders = 1;
+
+ let builders = [
+ OpBuilder<
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+ [{
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+ attributes, ElementwiseOp::getRegionBuilder());
+ }]>
+ ];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
+ /// Get the arity enum corresponding to the kind of op, e.g. if arg is
+ /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
+ static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
+
+ /// Both user-specified and default indexing map will always depend on
+ /// the current Op instance.
+ static bool hasDynamicIndexingMaps() { return true; }
+
+ /// Implements the block region builder for the elementwiseOp. This is
+ /// called by the 'fillStructuredOpRegion'.
+ static void regionBuilder(ImplicitLocOpBuilder &b,
+ Block &block, ArrayRef<NamedAttribute> attrs);
+
+ static std::function<void(ImplicitLocOpBuilder &,
+ Block &, ArrayRef<NamedAttribute>)>
+ getRegionBuilder() {
+ return regionBuilder;
+ }
+
+ /// Returns rank of the result tensor/memref. Useful for knowing
+ /// the dimensionality of the iteration space when other means
+ /// are not possible e.g. absence of user-provided indexing map.
+ unsigned getResultRank() {
+ Value output = getDpsInitOperand(0)->get();
+ ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
+ return shapedType.getRank();
+ }
+
+ /// Returns N 'parallel' iterator types where N is rank of result.
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+ /// The default indexing maps are identities.
+ /// There will be N+1 such maps, where N is the arity of the Op.
+ static SmallVector<AffineMap>
+ getDefaultIndexingMaps(unsigned NumMaps, unsigned numDims,
+ MLIRContext *context);
+
+ /// Destination passing style interface method.
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
+ return getOutputsMutable();
+ }
+
+ // Generic methods.
+ std::string getLibraryCallName() {
+ return generateLibraryCallName(getOperation());
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Op definition for MatmulOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 42ea0e1197ef1..161c334c4c985 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4058,6 +4058,233 @@ Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+//===----------------------------------------------------------------------===//
+// ElementwiseOp
+//===----------------------------------------------------------------------===//
+//
+namespace {
+struct ArityGroupAndKind {
+ // The enum class {Unary, Binary, Ternary, ..}
+ ElementwiseArityGroup arityGroup;
+
+ // The kind (e.g. `exp` or `add`) belonging to the arity group.
+ union Kind {
+ UnaryFn unaryFn;
+ BinaryFn binaryFn;
+ TernaryFn ternaryFn;
+ } kind;
+};
+
+unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
+ return static_cast<unsigned>(arityGroup);
+}
+} // namespace
+
+static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
+ constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
+ constexpr int lastBinary =
+ static_cast<int>(ElementwiseCaseLimits::LastBinary);
+ constexpr int lastTernary =
+ static_cast<int>(ElementwiseCaseLimits::LastTernary);
+
+ int val = static_cast<int>(kind);
+ ArityGroupAndKind result;
+
+ if (val < lastUnary) {
+ result.arityGroup = ElementwiseArityGroup::Unary;
+ result.kind.unaryFn = static_cast<UnaryFn>(val);
+ return result;
+ }
+ if (val < lastBinary) {
+ result.arityGroup = ElementwiseArityGroup::Binary;
+ result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
+ return result;
+ }
+ if (val >= lastTernary) {
+ llvm_unreachable("unhandled ElementwiseFn");
+ }
+ result.arityGroup = ElementwiseArityGroup::Ternary;
+ result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
+ return result;
+}
+
+SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
+ auto rank = getResultRank();
+ return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+}
+
+SmallVector<AffineMap>
+ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+ MLIRContext *context) {
+ auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
+ return SmallVector<AffineMap>(numMaps, map);
+}
+
+ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+ // Expect e.g. `kind = #linalg.elementwise_kind<add>`
+ Attribute attr;
+ mlir::linalg::ElementwiseKind elemwiseKindVal;
+ if (parser.parseKeyword("kind") || parser.parseEqual())
+ return failure();
+
+ if (succeeded(parser.parseAttribute(attr))) {
+ auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
+ if (!elemwiseKindAttr)
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ElementwiseKind attribute");
+ elemwiseKindVal = elemwiseKindAttr.getValue();
+ } else {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected operation 'kind' attribute");
+ }
+ result.addAttribute(
+ "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
+
+ // Parse optional `indexing_maps`
+ SmallVector<Attribute, 3> indexingMapsAttr;
+ Attribute mapAttr;
+ if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+ if (parser.parseEqual())
+ return failure();
+ if (parser.parseLSquare())
+ return failure();
+ do {
+ if (parser.parseAttribute(mapAttr))
+ return failure();
+ if (!isa<AffineMapAttr>(mapAttr))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected affine map attribute");
+ indexingMapsAttr.push_back(mapAttr);
+ if (parser.parseOptionalComma())
+ break;
+ } while (true);
+ if (parser.parseRSquare())
+ return failure();
+ }
+ // At this stage of parsing the only way to infer number of region
+ // args is through op kind, as input output tensors are not parsed yet.
+ auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
+ int numRegionArgs =
+ getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
+ if (parseNamedStructuredOp(parser, result, numRegionArgs,
+ ElementwiseOp::getRegionBuilder())) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "unable to parse elemwise op");
+ }
+
+ // Initialize indexingMaps, if not supplied explicitly.
+ if (indexingMapsAttr.empty()) {
+ // We need to infer the numDims of the indexing maps from the output
+ // type which is already parsed by now.
+ auto resultType = result.operands[result.operands.size() - 1].getType();
+ auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+ if (!shapedType)
+ return parser.emitError(parser.getCurrentLocation(),
+ "return type needs to be shaped type");
+ auto numDims = shapedType.getRank();
+ indexingMapsAttr = llvm::map_to_vector(
+ ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
+ parser.getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ }
+
+ result.addAttribute("indexing_maps",
+ parser.getBuilder().getArrayAttr(indexingMapsAttr));
+ return success();
+}
+
+void ElementwiseOp::print(OpAsmPrinter &p) {
+ p << " kind=";
+ p.printAttribute(getKindAttr());
+ SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
+ "indexing_maps"};
+ unsigned arity =
+ getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
+ unsigned numDims = getResultRank();
+
+ SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+ ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
+ getContext()),
+ [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+
+ if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+ p << " indexing_maps = [";
+ llvm::interleaveComma(getIndexingMaps(), p,
+ [&](Attribute attr) { p.printAttribute(attr); });
+ p << "]";
+ }
+
+ printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+ elidedAttrs);
+}
+
+LogicalResult ElementwiseOp::verify() {
+ // All necessary checks are done either by
+ // - EnumAttr (e.g. unknown operation kind)
+ // - verifyStructuredOpInterface (incorrect map, sizes).
+ return success();
+}
+
+/// Implements the block region builder for the ElementwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+ ArrayRef<NamedAttribute> attrs) {
+ ElementwiseKind elemwiseKind;
+ for (auto attr : attrs) {
+ if (attr.getName() == b.getStringAttr("kind")) {
+ auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
+ assert(kindAttr && "op kind attribute incorrectly set");
+ elemwiseKind = kindAttr.getValue();
+ break;
+ }
+ }
+
+ ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
+ auto arityGroup = groupAndKind.arityGroup;
+ auto kind = groupAndKind.kind;
+ unsigned numBlockArgs = getArityGroupAsUInt(arityGroup) + 1 /*output*/;
+ assert(block.getNumArguments() == numBlockArgs &&
+ "Elementwise regionBuilder number of block args mismatch");
+
+ RegionBuilderHelper helper(b, block);
+ SmallVector<Value> yields;
+ Value result;
+
+ if (arityGroup == ElementwiseArityGroup::Unary) {
+ result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+
+ } else if (arityGroup == ElementwiseArityGroup::Binary) {
+ result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
+ block.getArgument(1));
+
+ } else if (arityGroup == ElementwiseArityGroup::Ternary) {
+ result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
+ block.getArgument(1), block.getArgument(2));
+
+ } else
+ assert(false && "found unhandled category in elemwise");
+
+ yields.push_back(result);
+ helper.yieldOutputs(yields);
+}
+
+LogicalResult ElementwiseOp::fold(FoldAdaptor,
+ SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+
+void ElementwiseOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ if (hasPureTensorSemantics())
+ return;
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ElementwiseOp::getSpeculatability() {
+ return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
new file mode 100644
index 0000000000000..94a46d97e6e86
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops -split-input-file | FileCheck %s
+// CHECK: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//
+// CHECK: @unary_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
+// CHECK: %[[EXP:.+]] = math.exp %[[A_ARG]] : f32
+// CHECK: linalg.yield %[[EXP]] : f32
+//
+func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<exp>
+ ins(%A : tensor<8x16x32xf32>)
+ outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %r : tensor<8x16x32xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @unary_transpose_broadcast_tanh(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
+// CHECK: %[[TANH:.+]] = math.tanh %[[A_ARG]] : f32
+// CHECK: linalg.yield %[[TANH]] : f32
+//
+func.func @unary_transpose_broadcast_tanh(%A : tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<tanh>
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+ ins(%A : tensor<32x16xf32>)
+ outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %r : tensor<8x16x32xf32>
+}
+// -----
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @binary_div_on_memrefs(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[DIV:.+]] = arith.divf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[DIV]] : f32
+//
+func.func @binary_div_on_memrefs(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+ linalg.elementwise
+ kind=#linalg.elementwise_kind<div>
+ ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>)
+ outs(%C: memref<16x8xf32>)
+ return
+}
+// -----
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @binary_mul_on_tensors(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[MUL]] : f32
+//
+func.func @binary_mul_on_tensors(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<mul>
+ ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK: @binary_transpose_a(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[SUB:.+]] = arith.subf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[SUB]] : f32
+//
+func.func @binary_transpose_a(%A : tensor<8x16xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<sub>
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>]
+ ins(%A, %B: tensor<8x16xf32>, tensor<16x8xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[BROADCAST:.+]] = affine_map<(d0, d1) -> (d0)>
+//
+// CHECK: @binary_transpose_a_broadcast_b(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[BROADCAST]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[ADD:.+]] = arith.addf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+//
+func.func @binary_transpose_a_broadcast_b(%A : tensor<8x16xf32>, %B: tensor<16xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<add>
+ indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+ affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>]
+ ins(%A, %B: tensor<8x16xf32>, tensor<16xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @ternary(%[[A:.+]]: tensor<32x16xi1>, %[[B:.+]]: tensor<8x16x32xf32>, %[[C:.+]]: tensor<8x16x32xf32>, %[[D:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+//
+// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]]
+// CHECK-SAME: outs(%[[D]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i1, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32, %[[D_ARG:.+]]: f32)
+// CHECK: %[[SELECTED:.+]] = arith.select %[[A_ARG]], %[[B_ARG]], %[[C_ARG]] : f32
+// CHECK: linalg.yield %[[SELECTED]] : f32
+//
+func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8x16x32xf32>, %D : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<select>
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+ ins(%A, %B, %C : tensor<32x16xi1>, tensor<8x16x32xf32>, tensor<8x16x32xf32>)
+ outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %r : tensor<8x16x32xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/elementwise/invalid.mlir b/mlir/test/Dialect/Linalg/elementwise/invalid.mlir
new file mode 100644
index 0000000000000..3a47d231017b4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/invalid.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+func.func @misspelt_op_div(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+ // expected-error@+3 {{expected ::mlir::linalg::ElementwiseKind to be one of: exp, log, abs, ceil, floor}}
+ // expected-error@+2 {{failed to parse ElementwiseKindAttr parameter}}
+ // expected-error@+1 {{custom op 'linalg.elementwise' expected operation 'kind' attribute}}
+ linalg.elementwise kind=#linalg.elementwise_kind<dive> ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+ return
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @missing_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+ // expected-error@+1 {{'linalg.elementwise' op expected the number of indexing_map (2) to be equal to the number of input/output operands (3)}}
+ linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map] ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+ return
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @identity_map_when_transpose_expected(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+ // expected-error@+1 {{'linalg.elementwise' op inferred input/output operand #1 has shape's dimension #0 to be 8, but found 16}}
+ linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map] ins(%A, %B: memref<8x16xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+ return
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+func.func @incorrect_result_rank(%A : memref<8x16xf32>, %B: memref<8x16xf32>, %C: memref<8xf32>) {
+ // expected-error@+1 {{'linalg.elementwise' op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
+ linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map1] ins(%A, %B: memref<8x16xf32>, memref<8x16xf32>) outs(%C: memref<8xf32>)
+ return
+}
+
+// -----
+
+func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>, %C: memref<8x16x32xf32>) {
+ // expected-error@+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 2 args, got 3}}
+ // expected-error@+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
+ linalg.elementwise kind=#linalg.elementwise_kind<exp> ins(%A, %B : memref<8x16x32xf32>, memref<8x16x32xf32>) outs(%C: memref<8x16x32xf32>)
+ return
+}
+
+// -----
+
+func.func @binary_too_few_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>) {
+ // expected-error@+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 3 args, got 2}}
+ // expected-error@+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
+ linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%A : memref<8x16x32xf32>) outs(%B: memref<8x16x32xf32>)
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/elementwise/round-trip.mlir b/mlir/test/Dialect/Linalg/elementwise/round-trip.mlir
new file mode 100644
index 0000000000000..6ae2a77eb19f8
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/round-trip.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+//
+// Note - the functions are named @{unary|binary}_{identity|transpose|broadcast|transpose_a|...}_{exp|mul|div|..}
+
+// CHECK: @unary_identity_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: %{{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME: ins(%[[A]] : tensor<8x16x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>)
+//
+func.func @unary_identity_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<exp>
+ ins(%A : tensor<8x16x32xf32>)
+ outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+ return %r : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @unary_projection_tanh(%[[A:.+]]: tensor<?x16xf32>,
+// CHECK-SAME: %[[B:.+]]: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
+// CHECK: {{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<tanh>
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
+// CHECK-SAME: ins(%[[A]] : tensor<?x16xf32>) outs(%[[B]] : tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
+//
+func.func @unary_projection_tanh(%A: tensor<?x16xf32>,
+ %B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<tanh>
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+ ins(%A : tensor<?x16xf32>)
+ outs(%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
+ return %r : tensor<8x16x?xf32>
+}
+
+// -----
+
+// CHECK: @binary_identity_div(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>,
+// CHECK-SAME: %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<div>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @binary_identity_div(%A: tensor<16x8xf32>, %B: tensor<16x8xf32>,
+ %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<div>
+ ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
+ outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+ return %r : tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @binary_identity_mul_5Di(%[[A:.+]]: tensor<1x2x3x4x5xi32>,
+// CHECK-SAME: %[[B:.+]]: tensor<1x2x3x4x5xi32>,
+// CHECK-SAME: %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME: kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
+// CHECK-SAME: outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
+//
+func.func @binary_identity_mul_5Di(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
+ %C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<mul>
+ ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
+ outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
+ return %r : tensor<1x2x3x4x5xi32>
+}
+
+// -----
+
+// CHECK: @redundant_maps
+// CHECK-NOT: indexing_maps
+//
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
+ %C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
+ %r = linalg.elementwise
+ kind=#linalg.elementwise_kind<mul>
+ indexing_maps = [#map, #map, #map]
+ ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
+ outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
+ return %r : tensor<1x2x3x4x5xi32>
+}
More information about the Mlir-commits
mailing list