[Mlir-commits] [mlir] [mlir][linalg] Extend Linalg elemwise named ops semantics (PR #122753)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jan 15 11:12:41 PST 2025


================
@@ -551,6 +551,136 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for ElemwiseOp - with user-defined maps, computation type etc.
+//===----------------------------------------------------------------------===//
+
+def ElemwiseOp : LinalgStructuredBase_Op<"elemwise", [
+                   AttrSizedOperandSegments]> {
+  let summary = [{ Performs an element-wise operation }];
+  let description = [{
+    Linalg op form that performs an element-wise computation. The attribute
+    `func_type` describes the operation type (e.g. add, exp). `func_type`
+    can be any valid unary, binary, or ternary operation.
+
+    Affine maps for the operands and result may be provided by the user. When
+    no user-defined `indexing_maps` are provided, an identity map is inferred
+    for every operand: the default indexing maps are N identity maps, one per
+    input and output, so `N` depends on the arity of the elementwise op. The
+    number of dims is inferred from the rank of the output type. With the
+    default indexing maps, the input and output shapes must all match. The
+    affine maps for operands and result must be projected permutations with
+    no zero constants.
+
+    For elementwise ops, the iterator types are always inferred as all
+    `parallel`. The iterator types are needed to construct the underlying
+    structured op, and their number is inferred from the rank of the result
+    type.
+
+    Example:
+    Defining a unary linalg.elemwise with default indexing maps:
+
+    ```mlir
+    %exp = linalg.elemwise
+            func_type=#linalg.elemwise_fn<exp>
+            ins(%x : tensor<4x16x8xf32>)
+            outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
+    ```
+
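+    Defining a binary linalg.elemwise with user-defined indexing maps:
+
+    ```mlir
+    #identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+    #transpose = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+    #broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+    %add = linalg.elemwise
+            func_type=#linalg.elemwise_fn<add>
+            indexing_maps = [#transpose, #broadcast, #identity]
+            ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
+            outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+    ```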
+  }];
+
+  let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      ElemwiseFnAttr:$func_type,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+    );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let regions = (region AnyRegion:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildElemwiseOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, ElemwiseOp::getRegionBuilder());
+      }]>
+    ];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+
+      /// Get the nary category enum, e.g. `ElemwiseNAryCategory::Unary`,
+      /// corresponding to the given fn, e.g. `ElemwiseFn::exp`
+      static ElemwiseNAryCategory getNAryCategory(ElemwiseFn fn);
+
+      /// Elementwise ops always have dynamic indexing maps, i.e. either
+      /// user-specified or default. The default maps are identity maps.
+      static bool hasDynamicIndexingMaps() { return true; }
+
+      /// Implements the block region builder for the ElemwiseOp. This is
+      /// called by `fillStructuredOpRegion`.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      /// Returns elementwise op kind e.g. `add` inferred from func_type attr.
+      ElemwiseFn getElemwiseFnVal() {
+        return getFuncType();
+      }
+
+      /// Infers the dimensionality of the `iteration space` from the result
+      /// type. Useful when other means are not available, e.g. when no
+      /// user-provided indexing map is given.
+      unsigned getResultRank();
+
+      /// Elementwise op does not have to explicitly specify iterator type
+      /// as it is always 'parallel'. The number of 'parallel' loops is
+      /// inferred from other means (e.g. result tensor type).
----------------
banach-space wrote:

This part is already included in the general Op description above:
> /// Elementwise op does not have to explicitly specify iterator type
> /// as it is always 'parallel'

Please only document _what_ `getIteratorTypesArray` does.
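A minimal stand-alone sketch of the semantics under discussion, under stated assumptions: `IteratorType`, `MapResult`, `SimpleAffineMap` and every helper function below are hypothetical, simplified stand-ins, not the MLIR C++ API. It models what the doc comment is meant to describe (one `parallel` iterator per result dimension), together with the default identity maps and the projected-permutation rule from the op description.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical stand-in for an iterator-type enum (not the real MLIR API).
enum class IteratorType { parallel, reduction };

// One result expression of a simplified affine map: either the dimension
// d<value> (isDim == true) or the integer constant <value>.
struct MapResult {
  bool isDim;
  int64_t value;
};

// Hypothetical, simplified affine map: (d0, ..., d<numDims-1>) -> (results).
struct SimpleAffineMap {
  unsigned numDims;
  std::vector<MapResult> results;
};

// Elementwise ops never reduce, so every loop is 'parallel'; the number of
// loops equals the rank of the result type.
std::vector<IteratorType> getIteratorTypesArray(unsigned resultRank) {
  return std::vector<IteratorType>(resultRank, IteratorType::parallel);
}

// Identity map (d0, ..., dN-1) -> (d0, ..., dN-1) for the given rank.
SimpleAffineMap identityMap(unsigned rank) {
  SimpleAffineMap map{rank, {}};
  for (unsigned d = 0; d < rank; ++d)
    map.results.push_back({true, static_cast<int64_t>(d)});
  return map;
}

// Default indexing maps: one identity map per input and output operand,
// all sized by the rank of the result.
std::vector<SimpleAffineMap> getDefaultIndexingMaps(unsigned numInputs,
                                                    unsigned numOutputs,
                                                    unsigned resultRank) {
  return std::vector<SimpleAffineMap>(numInputs + numOutputs,
                                      identityMap(resultRank));
}

// A projected permutation selects distinct, in-range dimensions (possibly
// dropping some, as a broadcast does) and contains no constant results.
bool isProjectedPermutation(const SimpleAffineMap &map) {
  std::set<int64_t> seen;
  for (const MapResult &r : map.results) {
    if (!r.isDim)
      return false; // constants (including 0) are rejected
    if (r.value < 0 || r.value >= static_cast<int64_t>(map.numDims))
      return false; // dimension out of range
    if (!seen.insert(r.value).second)
      return false; // the same dimension is used twice
  }
  return true;
}
```

Assuming `#transpose` and `#broadcast` in the binary example are `(d0, d1, d2) -> (d0, d2, d1)` and `(d0, d1, d2) -> (d0, d2)`, both pass `isProjectedPermutation`, while a map with a constant result such as `(d0, d1, d2) -> (d0, 0)` is rejected.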

https://github.com/llvm/llvm-project/pull/122753

