[Mlir-commits] [mlir] 6de5d1e - [mlir][linalg] Extend elementwise (#124661)

llvmlistbot at llvm.org
Fri Feb 21 02:51:25 PST 2025


Author: Javed Absar
Date: 2025-02-21T10:51:21Z
New Revision: 6de5d1e46d1812de2bbbbe8d8d2c811e4d16acbe

URL: https://github.com/llvm/llvm-project/commit/6de5d1e46d1812de2bbbbe8d8d2c811e4d16acbe
DIFF: https://github.com/llvm/llvm-project/commit/6de5d1e46d1812de2bbbbe8d8d2c811e4d16acbe.diff

LOG: [mlir][linalg] Extend elementwise (#124661)

Implements the Linalg elementwise named op following the proposal and
discussions in the RFC:
  https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1
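
For a quick sense of the new surface syntax, here is a minimal sketch of the
unified op (shapes and SSA names are illustrative; the tests added below are
the authoritative reference):

```mlir
func.func @example(%x: tensor<4x8xf32>, %y: tensor<4x8xf32>,
                   %out: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // One op covers unary, binary and ternary element-wise computations;
  // the arithmetic function is selected by the `kind` attribute.
  %r = linalg.elementwise
         kind=#linalg.elementwise_kind<add>
         ins(%x, %y : tensor<4x8xf32>, tensor<4x8xf32>)
         outs(%out : tensor<4x8xf32>) -> tensor<4x8xf32>
  return %r : tensor<4x8xf32>
}
```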

Added: 
    mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
    mlir/test/Dialect/Linalg/elementwise/invalid.mlir
    mlir/test/Dialect/Linalg/elementwise/round-trip.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 73f984dc072d3..33601c5d6dad9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
   }];
 }
 
+// Define the attribute enum matching the elementwise op kind (e.g. add).
+def ElementwiseKindAttr : EnumAttr<Linalg_Dialect,
+                                   ElementwiseKind, "elementwise_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 // Define the function attribute enums matching the OpDSL functions.
 def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
   let assemblyFormat = "`<` $value `>`";

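The declared assembly format renders the attribute as
`#linalg.elementwise_kind<...>`. A minimal sketch of the resulting syntax
(the alias name `#kind` is illustrative):

```mlir
// Per the "`<` $value `>`" assembly format, the attribute is written as:
#kind = #linalg.elementwise_kind<add>
```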
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index e615876a95d05..ce68afe471fe8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
 }
+
+// Join two I32EnumAttrCase lists such that the integer enum values in
+// the combined list do not overlap: each element of the second list is
+// offset by '!size(a)', the number of cases in the first list.
+class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
+                                  list<I32EnumAttrCase> b> {
+  int aSize = !size(a);
+  list<I32EnumAttrCase> result =
+             !foldl(a, b, acc, var,
+                    acc # [I32EnumAttrCase<var.symbol,
+                                           !add(var.value, aSize)
+                                           >]);
+}
+
+// Flatten a 'list of lists of I32EnumAttrCase' into a flat list.
+// The flattening (via 'JoinTwoI32EnumAttrCaseList') keeps enum values disjoint.
+class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
+  list<I32EnumAttrCase> result =
+             !foldl([]<I32EnumAttrCase>, l, acc, var,
+                    JoinTwoI32EnumAttrCaseList<acc, var>.result);
+}
+
+// Define a unified `enum class : i32` for all element-wise op functions.
+def ElementwiseKind :
+            I32EnumAttr<"ElementwiseKind",
+                        "",
+                        ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
+                                                  BinaryFn.enumerants,
+                                                  TernaryFn.enumerants]>.result
+                      > {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
+// Define an `enum class : i32` that marks where each individual enum class
+// (e.g. UnaryFn, BinaryFn) ends in the unified enum class ElementwiseKind.
+def ElementwiseCaseLimits : I32EnumAttr<"ElementwiseCaseLimits", "", []> {
+  int last_unary = !size(UnaryFn.enumerants);
+  int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
+  int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));
+
+  let enumerants = [
+         I32EnumAttrCase<"LastUnary", last_unary>,
+         I32EnumAttrCase<"LastBinary", last_binary>,
+         I32EnumAttrCase<"LastTernary", last_ternary>];
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
+// Define an `enum class : i32` to categorise elementwise ops by arity.
+def ElementwiseArityGroup : I32EnumAttr<"ElementwiseArityGroup", "", [
+  I32EnumAttrCase<"Unary", 1>,
+  I32EnumAttrCase<"Binary", 2>,
+  I32EnumAttrCase<"Ternary", 3>
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::linalg";
+}
+
 def TypeFn : I32EnumAttr<"TypeFn", "", [
   I32EnumAttrCase<"cast_signed", 0>,
   I32EnumAttrCase<"cast_unsigned", 1>

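In IR, the unified enum means one attribute spells every kind, whatever its
arity group. A sketch with one op per group (shapes are illustrative; `exp`,
`add` and `select` come from the unary, binary and ternary case lists
respectively):

```mlir
func.func @arity_groups(%cond: tensor<8xi1>, %a: tensor<8xf32>,
                        %b: tensor<8xf32>, %init: tensor<8xf32>) -> tensor<8xf32> {
  // Unary kind (from the UnaryFn cases): one input.
  %0 = linalg.elementwise kind=#linalg.elementwise_kind<exp>
         ins(%a : tensor<8xf32>) outs(%init : tensor<8xf32>) -> tensor<8xf32>
  // Binary kind (from the BinaryFn cases, offset past the unary ones): two inputs.
  %1 = linalg.elementwise kind=#linalg.elementwise_kind<add>
         ins(%0, %b : tensor<8xf32>, tensor<8xf32>)
         outs(%init : tensor<8xf32>) -> tensor<8xf32>
  // Ternary kind (from the TernaryFn cases): three inputs.
  %2 = linalg.elementwise kind=#linalg.elementwise_kind<select>
         ins(%cond, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>)
         outs(%init : tensor<8xf32>) -> tensor<8xf32>
  return %2 : tensor<8xf32>
}
```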
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index a5725d6f1507e..ce6e9e7bb28c4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for ElementwiseOp
+//===----------------------------------------------------------------------===//
+def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
+                   AttrSizedOperandSegments]> {
+  let summary = [{ Performs an element-wise operation }];
+  let description = [{
+    The attribute `kind` describes the arithmetic operation to perform.
+    The operation kind can be unary (e.g. exp), binary (e.g. add) or
+    ternary (e.g. select).
+
+    By default, all indexing maps are identities, in which case all
+    input and output shapes must match. The number of dims in each of
+    the identity maps equals the rank of the output type.
+
+    The user must provide affine maps for the operands and result when
+    a transpose and/or broadcast is needed on any operand. When no maps
+    are provided, default identity maps are inferred for each operand.
+
+    Iterator types are always all `parallel`; they are needed to
+    construct the underlying structured op.
+
+    The number of dims of the iterator types is inferred from the rank
+    of the result type.
+
+    Example:
+
+    Defining a unary `linalg.elementwise` with the default indexing maps:
+      ```mlir
+      %exp = linalg.elementwise
+             kind=#linalg.elementwise_kind<exp>
+             ins(%x : tensor<4x16x8xf32>)
+             outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
+      ```
+
+    Defining a binary `linalg.elementwise` with user-defined indexing maps:
+    ```mlir
+    %add = linalg.elementwise
+            kind=#linalg.elementwise_kind<add>
+            indexing_maps = [#transpose, #broadcast, #identity]
+            ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
+            outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+    ```
+  }];
+
+  let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      ElementwiseKindAttr:$kind,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+    );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, ElementwiseOp::getRegionBuilder());
+      }]>
+    ];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+      /// Get the arity enum corresponding to the kind of op, e.g. if arg is
+      /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
+      static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
+
+      /// Both user-specified and default indexing maps always depend on
+      /// the current op instance.
+      static bool hasDynamicIndexingMaps() { return true; }
+
+      /// Implements the block region builder for the ElementwiseOp. This is
+      /// called by 'fillStructuredOpRegion'.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      /// Returns the rank of the result tensor/memref. Useful for knowing
+      /// the dimensionality of the iteration space when other means are
+      /// not available, e.g. in the absence of user-provided indexing maps.
+      unsigned getResultRank() {
+        Value output = getDpsInitOperand(0)->get();
+        ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
+        return shapedType.getRank();
+      }
+
+      /// Returns N 'parallel' iterator types, where N is the rank of the result.
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// The default indexing maps are identities.
+      /// There will be N+1 such maps, where N is the arity of the Op.
+      static SmallVector<AffineMap>
+      getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                             MLIRContext *context);
+
+      /// Destination passing style interface method.
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      std::string getLibraryCallName() {
+        return generateLibraryCallName(getOperation());
+      }
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Op definition for MatmulOp
 //===----------------------------------------------------------------------===//

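The default-map behaviour described in the `ElementwiseOp` description above
means the following two spellings are equivalent; the printer elides maps
that match the inferred identities (cf. the `@redundant_maps` round-trip test
below; the alias name `#id` and function name are illustrative):

```mlir
#id = affine_map<(d0, d1) -> (d0, d1)>
func.func @implicit_vs_explicit(%a: tensor<4x8xf32>, %b: tensor<4x8xf32>,
                                %out: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // No maps given: identity maps are inferred from the result rank.
  %0 = linalg.elementwise kind=#linalg.elementwise_kind<mul>
         ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
         outs(%out : tensor<4x8xf32>) -> tensor<4x8xf32>
  // Explicit identity maps: parses the same and prints without the maps.
  %1 = linalg.elementwise kind=#linalg.elementwise_kind<mul>
         indexing_maps = [#id, #id, #id]
         ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
         outs(%out : tensor<4x8xf32>) -> tensor<4x8xf32>
  return %1 : tensor<4x8xf32>
}
```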
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 42ea0e1197ef1..161c334c4c985 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4058,6 +4058,233 @@ Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ElementwiseOp
+//===----------------------------------------------------------------------===//
+//
+namespace {
+struct ArityGroupAndKind {
+  // The enum class {Unary, Binary, Ternary, ..}
+  ElementwiseArityGroup arityGroup;
+
+  // The kind (e.g. `exp` or `add`) belonging to the arity group.
+  union Kind {
+    UnaryFn unaryFn;
+    BinaryFn binaryFn;
+    TernaryFn ternaryFn;
+  } kind;
+};
+
+unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
+  return static_cast<unsigned>(arityGroup);
+}
+} // namespace
+
+static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
+  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
+  constexpr int lastBinary =
+      static_cast<int>(ElementwiseCaseLimits::LastBinary);
+  constexpr int lastTernary =
+      static_cast<int>(ElementwiseCaseLimits::LastTernary);
+
+  int val = static_cast<int>(kind);
+  ArityGroupAndKind result;
+
+  if (val < lastUnary) {
+    result.arityGroup = ElementwiseArityGroup::Unary;
+    result.kind.unaryFn = static_cast<UnaryFn>(val);
+    return result;
+  }
+  if (val < lastBinary) {
+    result.arityGroup = ElementwiseArityGroup::Binary;
+    result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
+    return result;
+  }
+  if (val >= lastTernary) {
+    llvm_unreachable("unhandled ElementwiseFn");
+  }
+  result.arityGroup = ElementwiseArityGroup::Ternary;
+  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
+  return result;
+}
+
+SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
+  auto rank = getResultRank();
+  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
+}
+
+SmallVector<AffineMap>
+ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
+                                      MLIRContext *context) {
+  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
+  return SmallVector<AffineMap>(numMaps, map);
+}
+
+ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Expect e.g. `kind = #linalg.elementwise_kind<add>`.
+  Attribute attr;
+  mlir::linalg::ElementwiseKind elemwiseKindVal;
+  if (parser.parseKeyword("kind") || parser.parseEqual())
+    return failure();
+
+  if (succeeded(parser.parseAttribute(attr))) {
+    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
+    if (!elemwiseKindAttr)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected ElementwiseKind attribute");
+    elemwiseKindVal = elemwiseKindAttr.getValue();
+  } else {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected operation 'kind' attribute");
+  }
+  result.addAttribute(
+      "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
+
+  // Parse optional `indexing_maps`
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+    if (parser.parseLSquare())
+      return failure();
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr))
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      indexingMapsAttr.push_back(mapAttr);
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // At this stage of parsing, the only way to infer the number of region
+  // args is through the op kind, as input/output tensors are not parsed yet.
+  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
+  int numRegionArgs =
+      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
+  if (parseNamedStructuredOp(parser, result, numRegionArgs,
+                             ElementwiseOp::getRegionBuilder())) {
+    return parser.emitError(parser.getCurrentLocation(),
+                            "unable to parse elemwise op");
+  }
+
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    // We need to infer the numDims of the indexing maps from the output
+    // type, which has already been parsed by now.
+    auto resultType = result.operands[result.operands.size() - 1].getType();
+    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
+    if (!shapedType)
+      return parser.emitError(parser.getCurrentLocation(),
+                              "return type needs to be shaped type");
+    auto numDims = shapedType.getRank();
+    indexingMapsAttr = llvm::map_to_vector(
+        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
+                                              parser.getContext()),
+        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+  }
+
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+  return success();
+}
+
+void ElementwiseOp::print(OpAsmPrinter &p) {
+  p << " kind=";
+  p.printAttribute(getKindAttr());
+  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
+                                           "indexing_maps"};
+  unsigned arity =
+      getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
+  unsigned numDims = getResultRank();
+
+  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
+      ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
+                                            getContext()),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+
+  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
+    p << " indexing_maps = [";
+    llvm::interleaveComma(getIndexingMaps(), p,
+                          [&](Attribute attr) { p.printAttribute(attr); });
+    p << "]";
+  }
+
+  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
+                         elidedAttrs);
+}
+
+LogicalResult ElementwiseOp::verify() {
+  // All necessary checks are done either by:
+  // - EnumAttr (e.g. unknown operation kind), or
+  // - verifyStructuredOpInterface (incorrect maps, sizes).
+  return success();
+}
+
+/// Implements the block region builder for the ElementwiseOp. This is called by
+/// 'fillStructuredOpRegion'.
+void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                                  ArrayRef<NamedAttribute> attrs) {
+  ElementwiseKind elemwiseKind;
+  for (auto attr : attrs) {
+    if (attr.getName() == b.getStringAttr("kind")) {
+      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
+      assert(kindAttr && "op kind attribute incorrectly set");
+      elemwiseKind = kindAttr.getValue();
+      break;
+    }
+  }
+
+  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
+  auto arityGroup = groupAndKind.arityGroup;
+  auto kind = groupAndKind.kind;
+  unsigned numBlockArgs = getArityGroupAsUInt(arityGroup) + 1 /*output*/;
+  assert(block.getNumArguments() == numBlockArgs &&
+         "Elementwise regionBuilder number of block args mismatch");
+
+  RegionBuilderHelper helper(b, block);
+  SmallVector<Value> yields;
+  Value result;
+
+  if (arityGroup == ElementwiseArityGroup::Unary) {
+    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+
+  } else if (arityGroup == ElementwiseArityGroup::Binary) {
+    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
+                                  block.getArgument(1));
+
+  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
+    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
+                                   block.getArgument(1), block.getArgument(2));
+
+  } else
+    assert(false && "found unhandled category in elemwise");
+
+  yields.push_back(result);
+  helper.yieldOutputs(yields);
+}
+
+LogicalResult ElementwiseOp::fold(FoldAdaptor,
+                                  SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void ElementwiseOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ElementwiseOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
new file mode 100644
index 0000000000000..94a46d97e6e86
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize_named_ops.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops -split-input-file | FileCheck %s
+// CHECK: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//
+// CHECK: @unary_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:  ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
+// CHECK:   %[[EXP:.+]] = math.exp %[[A_ARG]] : f32
+// CHECK:   linalg.yield %[[EXP]] : f32
+//
+func.func @unary_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+               kind=#linalg.elementwise_kind<exp>
+               ins(%A : tensor<8x16x32xf32>)
+               outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @unary_transpose_broadcast_tanh(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:  ins(%[[A]]
+// CHECK-SAME: outs(%[[B]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32)
+// CHECK:   %[[TANH:.+]] = math.tanh %[[A_ARG]] : f32
+// CHECK:   linalg.yield %[[TANH]] : f32
+//
+func.func @unary_transpose_broadcast_tanh(%A : tensor<32x16xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+               kind=#linalg.elementwise_kind<tanh>
+               indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+               ins(%A : tensor<32x16xf32>)
+               outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
+// -----
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @binary_div_on_memrefs(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:  ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK:   %[[DIV:.+]] = arith.divf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK:   linalg.yield %[[DIV]] : f32
+//
+func.func @binary_div_on_memrefs(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+  linalg.elementwise
+               kind=#linalg.elementwise_kind<div>
+               ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>)
+               outs(%C: memref<16x8xf32>)
+  return
+}
+// -----
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @binary_mul_on_tensors(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:  ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK:   %[[MUL:.+]] = arith.mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK:   linalg.yield %[[MUL]] : f32
+//
+func.func @binary_mul_on_tensors(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %r = linalg.elementwise
+               kind=#linalg.elementwise_kind<mul>
+               ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
+               outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//
+// CHECK: @binary_transpose_a(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:  ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK:   %[[SUB:.+]] = arith.subf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK:   linalg.yield %[[SUB]] : f32
+//
+func.func @binary_transpose_a(%A : tensor<8x16xf32>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %r = linalg.elementwise
+               kind=#linalg.elementwise_kind<sub>
+               indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+                                affine_map<(d0, d1) -> (d0, d1)>,
+                                affine_map<(d0, d1) -> (d0, d1)>]
+               ins(%A, %B: tensor<8x16xf32>, tensor<16x8xf32>)
+               outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[TRANSPOSE:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-DAG: #[[BROADCAST:.+]] = affine_map<(d0, d1) -> (d0)>
+//
+// CHECK: @binary_transpose_a_broadcast_b(%[[A:.+]]: tensor<8x16xf32>, %[[B:.+]]: tensor<16xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[TRANSPOSE]], #[[BROADCAST]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:  ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK:   %[[ADD:.+]] = arith.addf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK:   linalg.yield %[[ADD]] : f32
+//
+func.func @binary_transpose_a_broadcast_b(%A : tensor<8x16xf32>, %B: tensor<16xf32>, %C: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %r = linalg.elementwise
+               kind=#linalg.elementwise_kind<add>
+               indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+                                affine_map<(d0, d1) -> (d0)>,
+                                affine_map<(d0, d1) -> (d0, d1)>]
+               ins(%A, %B: tensor<8x16xf32>, tensor<16xf32>)
+               outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+// -----
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @ternary(%[[A:.+]]: tensor<32x16xi1>, %[[B:.+]]: tensor<8x16x32xf32>, %[[C:.+]]: tensor<8x16x32xf32>, %[[D:.+]]: tensor<8x16x32xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[PROJECTION]], #[[IDENTITY]], #[[IDENTITY]], #[[IDENTITY]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+//
+// CHECK-SAME:  ins(%[[A]], %[[B]], %[[C]]
+// CHECK-SAME: outs(%[[D]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i1, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32, %[[D_ARG:.+]]: f32)
+// CHECK:   %[[SELECTED:.+]] = arith.select %[[A_ARG]], %[[B_ARG]], %[[C_ARG]] : f32
+// CHECK:   linalg.yield %[[SELECTED]] : f32
+//
+func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8x16x32xf32>, %D : tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+               kind=#linalg.elementwise_kind<select>
+               indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+                                affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+               ins(%A, %B, %C : tensor<32x16xi1>, tensor<8x16x32xf32>, tensor<8x16x32xf32>)
+               outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
\ No newline at end of file
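
For reference, here is the `linalg.generic` that `@unary_exp` above
generalizes to, reconstructed from the CHECK lines (a sketch, not
FileCheck-verified output):

```mlir
#id3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @unary_exp_generalized(%a: tensor<8x16x32xf32>,
                                 %b: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
  %r = linalg.generic
         {indexing_maps = [#id3, #id3],
          iterator_types = ["parallel", "parallel", "parallel"]}
         ins(%a : tensor<8x16x32xf32>) outs(%b : tensor<8x16x32xf32>) {
  ^bb0(%in: f32, %out: f32):
    // This is the region ElementwiseOp::regionBuilder emits for kind `exp`.
    %e = math.exp %in : f32
    linalg.yield %e : f32
  } -> tensor<8x16x32xf32>
  return %r : tensor<8x16x32xf32>
}
```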

diff --git a/mlir/test/Dialect/Linalg/elementwise/invalid.mlir b/mlir/test/Dialect/Linalg/elementwise/invalid.mlir
new file mode 100644
index 0000000000000..3a47d231017b4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/invalid.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+func.func @misspelt_op_div(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+  // expected-error @+3 {{expected ::mlir::linalg::ElementwiseKind to be one of: exp, log, abs, ceil, floor}}
+  // expected-error @+2 {{failed to parse ElementwiseKindAttr parameter}}
+  // expected-error @+1 {{custom op 'linalg.elementwise' expected operation 'kind' attribute}}
+  linalg.elementwise kind=#linalg.elementwise_kind<dive> ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+  return
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @missing_indexing_map(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+  // expected-error @+1 {{'linalg.elementwise' op expected the number of indexing_map (2) to be equal to the number of input/output operands (3)}}
+  linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map] ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+  return
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @identity_map_when_transpose_expected(%A : memref<8x16xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>) {
+  // expected-error @+1 {{'linalg.elementwise' op inferred input/output operand #1 has shape's dimension #0 to be 8, but found 16}}
+  linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map] ins(%A, %B: memref<8x16xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
+  return
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+func.func @incorrect_result_rank(%A : memref<8x16xf32>, %B: memref<8x16xf32>, %C: memref<8xf32>) {
+  // expected-error @+1 {{'linalg.elementwise' op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
+  linalg.elementwise kind=#linalg.elementwise_kind<div> indexing_maps = [#map, #map, #map1] ins(%A, %B: memref<8x16xf32>, memref<8x16xf32>) outs(%C: memref<8xf32>)
+  return
+}
+
+// -----
+
+func.func @unary_too_many_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>, %C:  memref<8x16x32xf32>) {
+  // expected-error @+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 2 args, got 3}}
+  // expected-error @+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
+  linalg.elementwise kind=#linalg.elementwise_kind<exp> ins(%A, %B : memref<8x16x32xf32>,  memref<8x16x32xf32>) outs(%C: memref<8x16x32xf32>)
+  return 
+}
+
+// -----
+
+func.func @binary_too_few_args(%A : memref<8x16x32xf32>, %B: memref<8x16x32xf32>) {
+  // expected-error @+3 {{custom op 'linalg.elementwise' [parseNamedStructuredOpRegion] ods-gen generated region expects 3 args, got 2}}
+  // expected-error @+2 {{custom op 'linalg.elementwise' unable to parse elemwise op}}
+  linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%A : memref<8x16x32xf32>) outs(%B: memref<8x16x32xf32>)
+  return 
+}
\ No newline at end of file

diff --git a/mlir/test/Dialect/Linalg/elementwise/round-trip.mlir b/mlir/test/Dialect/Linalg/elementwise/round-trip.mlir
new file mode 100644
index 0000000000000..6ae2a77eb19f8
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/elementwise/round-trip.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+//
+// Note - the functions are named @{unary|binary}_{identity|transpose|broadcast|transpose_a|...}_{exp|mul|div|..}
+
+// CHECK: @unary_identity_exp(%[[A:.+]]: tensor<8x16x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>)
+// CHECK: %{{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<exp>
+// CHECK-SAME:  ins(%[[A]] : tensor<8x16x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>)
+//
+func.func @unary_identity_exp(%A : tensor<8x16x32xf32>, %B: tensor<8x16x32xf32>) ->  tensor<8x16x32xf32> {
+  %r = linalg.elementwise
+         kind=#linalg.elementwise_kind<exp>
+         ins(%A : tensor<8x16x32xf32>)
+         outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
+  return %r : tensor<8x16x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[PROJECTION:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+//
+// CHECK: @unary_projection_tanh(%[[A:.+]]: tensor<?x16xf32>,
+// CHECK-SAME:                            %[[B:.+]]: tensor<8x16x?xf32>) ->  tensor<8x16x?xf32> {
+// CHECK: {{.*}} = linalg.elementwise kind=#linalg.elementwise_kind<tanh>
+// CHECK-SAME:       indexing_maps = [#[[PROJECTION]], #[[IDENTITY]]]
+// CHECK-SAME:       ins(%[[A]] : tensor<?x16xf32>) outs(%[[B]] : tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
+//
+func.func @unary_projection_tanh(%A: tensor<?x16xf32>,
+                                          %B: tensor<8x16x?xf32>) ->  tensor<8x16x?xf32> {
+  %r = linalg.elementwise
+         kind=#linalg.elementwise_kind<tanh>
+         indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d1)>,
+                          affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+         ins(%A : tensor<?x16xf32>)
+         outs(%B: tensor<8x16x?xf32>) -> tensor<8x16x?xf32>
+  return %r : tensor<8x16x?xf32>
+}
+
+// -----
+
+// CHECK: @binary_identity_div(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>,
+// CHECK-SAME:        %[[C:.+]]: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<div>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
+// CHECK-SAME:       outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @binary_identity_div(%A: tensor<16x8xf32>, %B: tensor<16x8xf32>,
+                      %C: tensor<16x8xf32>) ->  tensor<16x8xf32> {
+  %r = linalg.elementwise
+         kind=#linalg.elementwise_kind<div>
+         ins(%A, %B: tensor<16x8xf32>, tensor<16x8xf32>)
+         outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+
+// -----
+
+// CHECK: @binary_identity_mul_5Di(%[[A:.+]]: tensor<1x2x3x4x5xi32>,
+// CHECK-SAME:                     %[[B:.+]]: tensor<1x2x3x4x5xi32>,
+// CHECK-SAME:                     %[[C:.+]]: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:       kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME:       ins(%[[A]], %[[B]] : tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
+// CHECK-SAME:       outs(%[[C]] : tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
+//
+func.func @binary_identity_mul_5Di(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
+                                   %C: tensor<1x2x3x4x5xi32>) ->  tensor<1x2x3x4x5xi32> {
+  %r = linalg.elementwise
+         kind=#linalg.elementwise_kind<mul>
+         ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
+         outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
+  return %r : tensor<1x2x3x4x5xi32>
+}
+
+// -----
+
+// CHECK: @redundant_maps
+// CHECK-NOT: indexing_maps
+//
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
+                          %C: tensor<1x2x3x4x5xi32>) ->  tensor<1x2x3x4x5xi32> {
+  %r = linalg.elementwise
+         kind=#linalg.elementwise_kind<mul>
+         indexing_maps = [#map, #map, #map]
+         ins(%A, %B: tensor<1x2x3x4x5xi32>, tensor<1x2x3x4x5xi32>)
+         outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
+  return %r : tensor<1x2x3x4x5xi32>
+}

