[Mlir-commits] [mlir] 6c9541d - Implement simple type polymorphism for linalg named ops.

Stella Laurenzo llvmlistbot at llvm.org
Sun Feb 21 14:33:42 PST 2021


Author: Stella Laurenzo
Date: 2021-02-21T14:30:31-08:00
New Revision: 6c9541d4ddfdab0dcb11436485b466a759c3126c

URL: https://github.com/llvm/llvm-project/commit/6c9541d4ddfdab0dcb11436485b466a759c3126c
DIFF: https://github.com/llvm/llvm-project/commit/6c9541d4ddfdab0dcb11436485b466a759c3126c.diff

LOG: Implement simple type polymorphism for linalg named ops.

* It was decided that this was the end of the line for the existing custom tc parser/generator, and this is the first step toward replacing it with a declarative format that maps well to mathy source languages.
* One such source language is implemented here: https://github.com/stellaraccident/mlir-linalgpy/blob/main/samples/mm.py
  * In fact, this is the exact source of the declarative `polymorphic_matmul` in this change (a usage sketch is included at the end of this list).
  * I am working separately to clean this Python implementation up and add it to MLIR (probably as `mlir.tools.linalg_opgen` or equivalent). The scope of the Python side is greater than just generating named ops: the ops are callable and directly emit `linalg.generic` ops fully dynamically, and this is intended to be a feature for frontends like npcomp to define custom linear algebra ops at runtime.
* There is more work required to handle full type polymorphism, especially with respect to integer formulations, since they require more type specificity.
* Followups to this change will bring the new generator to feature parity with the current one and then delete it. Roughly, this involves adding support for interface declarations and attribute symbol bindings.
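* For reference, here is a usage sketch of the generated op (mirroring the new test added in this change) on f32 tensors. The same IR builds for other numeric element types such as i32, where generalization produces `muli`/`addi` instead of `mulf`/`addf`:

    func @polymorphic_matmul_f32(%A: tensor<16x8xf32>, %B: tensor<8x32xf32>,
                                 %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
      // Named op whose scalar body comes from the generated regionBuilder; the
      // -linalg-generalize-named-ops pass rewrites it to an equivalent linalg.generic.
      %0 = linalg.polymorphic_matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
                                     outs(%C : tensor<16x32xf32>) -> tensor<16x32xf32>
      return %0 : tensor<16x32xf32>
    }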

Differential Revision: https://reviews.llvm.org/D97135

Added: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Modified: 
    mlir/docs/Dialects/Linalg.md
    mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
index dc9353c9748b..4efb0d1c9f00 100644
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -662,6 +662,18 @@ void batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
 }
 ```
 
+### YAML Based Named Structured Ops
+
+Linalg provides a declarative generation tool (`mlir-linalg-ods-yaml-gen`) to
+automatically produce named ops from a YAML-based op description format
+intended to capture the structure of the named ops and to be generated from a
+higher-level "mathy" DSL syntax. This facility is currently in flight and is
+intended to subsume the above when ready. See the C++ class to YAML mapping
+traits in `mlir-linalg-ods-yaml-gen.cpp` as the source of truth for the schema.
+
+Most of the above documentation roughly applies to this path and will be ported
+as migration continues.
+
 ## Open Issues and Design Alternatives<a name="open_issues"></a>
 
 Multiple open issues and design alternatives are in flight and it is time to lay

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
index f1691718025c..ffe7619ce983 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
@@ -1,8 +1,8 @@
 # Declare a function to generate ODS with mlir-linalg-ods-gen
-function(add_linalg_ods_gen tc_filename output_file)
+function(add_linalg_ods_tc_gen tc_filename output_file)
   set(TC_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${tc_filename})
-  set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.td)
-  set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.cpp.inc)
+  set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.td)
+  set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.cpp.inc)
   set_source_files_properties(
     ${GEN_ODS_FILE}
     PROPERTIES GENERATED TRUE)
@@ -20,17 +20,52 @@ function(add_linalg_ods_gen tc_filename output_file)
     ${MLIR_LINALG_ODS_GEN_TARGET}
     VERBATIM)
   add_custom_target(
-    MLIR${output_file}IncGen
+    MLIR${output_file}TcIncGen
     DEPENDS
     ${MLIR_LINALG_ODS_GEN_EXE}
     ${MLIR_LINALG_ODS_GEN_TARGET}
     ${GEN_ODS_FILE} ${GEN_CPP_FILE})
 endfunction()
 
-add_linalg_ods_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps)
+# Declare a function to generate ODS with mlir-linalg-ods-yaml-gen
+function(add_linalg_ods_yaml_gen yaml_ast_file output_file)
+  set(YAML_AST_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${yaml_ast_file})
+  set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.yamlgen.td)
+  set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.yamlgen.cpp.inc)
+  set_source_files_properties(
+    ${GEN_ODS_FILE}
+    PROPERTIES GENERATED TRUE)
+  set_source_files_properties(
+    ${GEN_CPP_FILE}
+    PROPERTIES GENERATED TRUE)
+  add_custom_command(
+    OUTPUT ${GEN_ODS_FILE} ${GEN_CPP_FILE}
+    COMMAND ${MLIR_LINALG_ODS_YAML_GEN_EXE} ${YAML_AST_SOURCE} -o-ods-decl=${GEN_ODS_FILE} -o-impl=${GEN_CPP_FILE}
+    MAIN_DEPENDENCY
+    ${YAML_AST_SOURCE}
+    DEPENDS
+    ${MLIR_LINALG_ODS_YAML_GEN_EXE}
+    ${MLIR_LINALG_ODS_YAML_GEN_TARGET})
+  add_custom_target(
+    MLIR${output_file}YamlIncGen
+    DEPENDS
+    ${MLIR_LINALG_ODS_YAML_GEN_EXE}
+    ${MLIR_LINALG_ODS_YAML_GEN_TARGET}
+    ${GEN_ODS_FILE} ${GEN_CPP_FILE})
+endfunction()
+
+# TODO: Delete tc generation and replace with the YAML variant once all ops are
+# ported.
+add_linalg_ods_tc_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps)
+add_linalg_ods_yaml_gen(LinalgNamedStructuredOps.yaml LinalgNamedStructuredOps)
+
 # Provide a short name for all external dependency that needs to
 # include Linalg in ODS
-add_custom_target(LinalgOdsGen DEPENDS MLIRLinalgNamedStructuredOpsIncGen)
+add_custom_target(LinalgOdsGen
+  DEPENDS
+  MLIRLinalgNamedStructuredOpsTcIncGen
+  MLIRLinalgNamedStructuredOpsYamlIncGen
+)
 add_dependencies(mlir-headers LinalgOdsGen)
 
 add_mlir_dialect(LinalgOps linalg)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
new file mode 100644
index 000000000000..02f08145da92
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -0,0 +1,50 @@
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: polymorphic_matmul
+  cpp_op_name: PolymorphicMatmulOp
+  doc: |-
+    Type polymorphic matrix multiplication.
+
+    This op is presently here to test a new path for generation and will replace
+    the existing 'matmul' op when ready. Do not use.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  - !<LinalgTensorDef>
+    name: B
+    usage: input
+    shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+  - !<LinalgTensorDef>
+    name: C
+    usage: output
+    shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_arg: A
+            - !ScalarExpression
+              scalar_arg: B
+

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 05a6bb766dd0..cb370926e810 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -343,7 +343,7 @@ def ConvOp : PoolingBase_Op<"conv", []> {
       // parallelized across; i.e. [zs] in the TF notation above whose number
       // match `xs` (i.e. 1 window loop per "image" dimension).
       // This may evolve in the future.
-      // Conditionally check nPar is large enough for cases of ill-formed op: 
+      // Conditionally check nPar is large enough for cases of ill-formed op:
       // this avoids overflows before hitting the verifier.
       assert(nPar > getNumBatchDimensions() + getNumInputFeatureDimensions() &&
              "expected at least one window dimension (i.e. memref ranks greater "
@@ -806,6 +806,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
 //===----------------------------------------------------------------------===//
 
 // This file is auto-generated from a TC def specification.
-include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.td"
+include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td"
 
 #endif // LINALG_STRUCTURED_OPS

diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index 8522919bacb3..2be8b5dd624f 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalg
   LINK_LIBS PUBLIC
   MLIRAffine
   MLIRIR
+  MLIRParser
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface
   MLIRStandard

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a5a9c76ff0f6..36e73bbabc37 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Parser.h"
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
@@ -121,6 +122,81 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }
 
+//===----------------------------------------------------------------------===//
+// Region builder helper.
+// TODO: Move this to a utility library.
+// The public methods on this class are referenced directly from generated code
+// and bind by name to math functions in the DSL as:
+//   `applyfn__{fnName}`
+// Examples:
+//   `applyfn__add`
+//   `applyfn__mul`
+// The naming convention is intentional in order to match snake-cased DSL names.
+// See mlir-linalg-ods-yaml-gen.cpp for the generator that emits calls to it.
+//
+// Implementations of the math functions must be polymorphic over numeric types,
+// internally performing necessary casts. If the function application makes no
+// sense, then the only recourse is to assert and return nullptr. This can be
+// extended later if it becomes possible to fail construction of the region. The
+// invariant should be enforced at a higher level.
+//
+// TODO: These helpers are currently type polymorphic over the class of integer
+// and floating point types, but they will not internally cast within bit
+// widths of a class (mixed precision such as i8->i32) or across classes
+// (i.e. mixed float and integer). Many such combinations are ambiguous or need
+// to be handled with care and work is being considered to extend the op
+language to make such cases explicit. In the meantime, violating this will
+// fail verification, which is deemed acceptable.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class RegionBuilderHelper {
+public:
+  RegionBuilderHelper(Block &block) : block(block) {}
+
+  Value applyfn__add(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder(lhs);
+    if (isFloatingPoint(lhs))
+      return builder.create<AddFOp>(lhs.getLoc(), lhs, rhs);
+    else if (isInteger(lhs))
+      return builder.create<AddIOp>(lhs.getLoc(), lhs, rhs);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
+  Value applyfn__mul(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder(lhs);
+    if (isFloatingPoint(lhs))
+      return builder.create<MulFOp>(lhs.getLoc(), lhs, rhs);
+    else if (isInteger(lhs))
+      return builder.create<MulIOp>(lhs.getLoc(), lhs, rhs);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
+  void yieldOutputs(ValueRange values) {
+    assert(!values.empty() && "linalg ops must yield outputs");
+    if (values.empty())
+      return;
+    Value first = values.front();
+    OpBuilder builder = getBuilder(first);
+    builder.create<YieldOp>(first.getLoc(), values);
+  }
+
+private:
+  Block &block;
+
+  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
+  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
+
+  OpBuilder getBuilder(Value value) {
+    OpBuilder builder(value.getContext());
+    builder.setInsertionPointToEnd(&block);
+    return builder;
+  }
+};
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
@@ -1868,7 +1944,8 @@ struct EraseDeadLinalgOp;
 struct FoldTensorCastOp;
 } // namespace
 
-#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
@@ -2032,7 +2109,8 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   unsigned actual = body->getNumArguments();
   unsigned expected = NamedStructuredOpType::getNumRegionArgs();
   if (expected != actual) {
-    if (errorHandler) errorHandler(expected, actual);
+    if (errorHandler)
+      errorHandler(expected, actual);
     return;
   }
 

diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
new file mode 100644
index 000000000000..186fb9627219
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
+
+func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+
+func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+                          outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i32)
+// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_ARG]], %[[B_ARG]] : i32
+// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<16x32xi32>
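
For reference, the `linalg.generic` produced for the f32 case above looks roughly like the following sketch. It is reconstructed from the indexing maps, iterator types, and region defined by the YAML description in this change; the CHECK lines only pin down the region body, so the exact printed form may differ slightly.

    #map_a = affine_map<(d0, d1, d2) -> (d0, d2)>
    #map_b = affine_map<(d0, d1, d2) -> (d2, d1)>
    #map_c = affine_map<(d0, d1, d2) -> (d0, d1)>
    func @generalized_matmul_f32(%A: tensor<16x8xf32>, %B: tensor<8x32xf32>,
                                 %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
      // Same iteration space and maps as the polymorphic_matmul YAML description,
      // with the scalar body materialized for the f32 element type.
      %0 = linalg.generic {indexing_maps = [#map_a, #map_b, #map_c],
                           iterator_types = ["parallel", "parallel", "reduction"]}
          ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
          outs(%C : tensor<16x32xf32>) {
      ^bb0(%a: f32, %b: f32, %c: f32):
        %mul = mulf %a, %b : f32
        %add = addf %c, %mul : f32
        linalg.yield %add : f32
      } -> tensor<16x32xf32>
      return %0 : tensor<16x32xf32>
    }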

diff --git a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt
index c1a594389beb..e0c2a90fa8eb 100644
--- a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt
+++ b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt
@@ -2,6 +2,13 @@ set(LLVM_LINK_COMPONENTS
   Core
   Support
   )
+
+set(LLVM_OPTIONAL_SOURCES
+  mlir-linalg-ods-gen.cpp
+  mlir-linalg-ods-yaml-gen.cpp
+)
+
+# Original mlir-linalg-ods-gen (to be replaced).
 add_llvm_tool(mlir-linalg-ods-gen
   mlir-linalg-ods-gen.cpp
 )
@@ -30,3 +37,35 @@ if(LLVM_USE_HOST_TOOLS)
     endif()
   endif()
 endif()
+
+
+# New mlir-linalg-ods-yaml-gen.
+add_llvm_tool(mlir-linalg-ods-yaml-gen
+  mlir-linalg-ods-yaml-gen.cpp
+)
+llvm_update_compile_flags(mlir-linalg-ods-yaml-gen)
+target_link_libraries(mlir-linalg-ods-yaml-gen PRIVATE
+  MLIRIR
+  MLIRSupport
+  MLIRParser
+  )
+
+set(MLIR_LINALG_ODS_YAML_GEN mlir-linalg-ods-yaml-gen CACHE
+  STRING "Native mlir-linalg-ods-yaml-gen executable. Saves building one when cross-compiling.")
+
+set(MLIR_LINALG_ODS_YAML_GEN_EXE ${MLIR_LINALG_ODS_YAML_GEN} PARENT_SCOPE)
+set(MLIR_LINALG_ODS_YAML_GEN_TARGET mlir-linalg-ods-yaml-gen PARENT_SCOPE)
+
+if(LLVM_USE_HOST_TOOLS)
+if ("${MLIR_LINALG_ODS_YAML_GEN_EXE}" STREQUAL mlir-linalg-ods-yaml-gen)
+  build_native_tool(mlir-linalg-ods-yaml-gen MLIR_LINALG_ODS_YAML_GEN_EXE DEPENDS mlir-linalg-ods-yaml-gen)
+  set(MLIR_LINALG_ODS_YAML_GEN_EXE ${MLIR_LINALG_ODS_YAML_GEN_EXE} PARENT_SCOPE)
+
+  add_custom_target(mlir-linalg-ods-yaml-gen-host DEPENDS ${MLIR_LINALG_ODS_YAML_GEN_EXE})
+  set(MLIR_LINALG_ODS_YAML_GEN_TARGET mlir-linalg-ods-yaml-gen-host PARENT_SCOPE)
+
+  if(NOT LLVM_BUILD_UTILS)
+    set_target_properties(mlir-linalg-ods-yaml-gen PROPERTIES EXCLUDE_FROM_ALL ON)
+  endif()
+endif()
+endif()

diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
new file mode 100644
index 000000000000..582eecde4f24
--- /dev/null
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -0,0 +1,878 @@
+//===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml  ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an ODS (and C++) generator from a YAML form
+// derived from the mathematical expression of linalg named ops. Typically a
+// math-oriented DSL will be used to export the essential representation to
+// this form, and maintaining the source of truth at the math level (versus
+// recreating it in MLIR) is deemed to have systemic value.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "llvm/Support/YAMLTraits.h"
+
+using namespace mlir;
+
+using llvm::yaml::Input;
+using llvm::yaml::IO;
+using llvm::yaml::MappingTraits;
+using llvm::yaml::ScalarEnumerationTraits;
+using llvm::yaml::ScalarTraits;
+
+#define DEBUG_TYPE "linalg-ods-gen"
+
+//===----------------------------------------------------------------------===//
+// Mapping structs (correspond to data types in the YAML description).
+// TODO: Since this is a schema/part of the contract, it should be moved to
+// a real header.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct LinalgYAMLContext {
+  MLIRContext *mlirContext;
+};
+
+struct LinalgOpMetadata {
+  std::string name;
+  std::string cppOpName;
+  Optional<std::string> doc;
+};
+
+struct SerializedAffineMap {
+  AffineMapAttr affineMapAttr;
+
+  AffineMap affineMap() { return affineMapAttr.getValue(); }
+};
+
+enum class LinalgTensorUsageDef {
+  input,
+  output,
+  temporary,
+};
+
+struct LinalgTensorDef {
+  std::string name;
+  LinalgTensorUsageDef usage;
+  SerializedAffineMap shape;
+};
+
+enum class LinalgIteratorTypeDef {
+  parallel,
+  reduction,
+};
+
+struct LinalgIndexingMapsConfig {
+  Optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
+};
+
+struct ScalarExpression;
+
+struct ScalarApply {
+  std::string fnName;
+  // NOTE: Must be a pure heap-allocated container (not SmallVector)
+  // due to recursive data type.
+  std::vector<ScalarExpression> operands;
+};
+
+struct ScalarExpression {
+  Optional<std::string> scalarArg;
+  Optional<ScalarApply> scalarApply;
+};
+
+struct ScalarAssign {
+  std::string arg;
+  ScalarExpression value;
+};
+
+struct LinalgStructuredOpConfig {
+  SmallVector<LinalgTensorDef> args;
+  LinalgIndexingMapsConfig indexingMaps;
+  SmallVector<LinalgIteratorTypeDef> iteratorTypes;
+  SmallVector<ScalarAssign> assignments;
+};
+
+struct LinalgOpConfig {
+  Optional<LinalgOpMetadata> metadata;
+  Optional<LinalgStructuredOpConfig> structuredOp;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Mapping traits.
+//===----------------------------------------------------------------------===//
+
+LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgTensorDef);
+LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap);
+LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef);
+LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign);
+LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression);
+LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig);
+
+namespace llvm {
+namespace yaml {
+
+/// Top-level type containing op metadata and one of a concrete op type.
+/// Currently, the only defined op type is `structured_op` (maps to
+/// `LinalgStructuredOpConfig`).
+template <>
+struct MappingTraits<LinalgOpConfig> {
+  static void mapping(IO &io, LinalgOpConfig &info) {
+    io.mapOptional("metadata", info.metadata);
+    io.mapOptional("structured_op", info.structuredOp);
+  }
+};
+
+/// A structured op models (at most) a single contraction by specifying:
+///   - A list of named arguments (`LinalgTensorDef`), which can be inputs,
+///     outputs, or temporaries.
+///   - A list of indexing maps (see `LinalgIndexingMaps`).
+///   - Iterator types (see `LinalgIteratorTypeDef`).
+///   - A list of scalar-level assignments (see `ScalarAssign`).
+template <>
+struct MappingTraits<LinalgStructuredOpConfig> {
+  static void mapping(IO &io, LinalgStructuredOpConfig &info) {
+    io.mapRequired("args", info.args);
+    io.mapRequired("indexing_maps", info.indexingMaps);
+    io.mapRequired("iterator_types", info.iteratorTypes);
+    io.mapRequired("assignments", info.assignments);
+  }
+};
+
+/// Maps a named tensor-argument to an operation, consisting of:
+///   - `name`: Must be unique within the operation.
+///   - `usage`: How the argument is used (input, output, etc).
+///   - `shape`: An AffineMap from all op symbols to the specific shape
+///     of this argument. Each shape must be normalized over the same list of
+///     symbols and have no dimension inputs.
+template <>
+struct MappingTraits<LinalgTensorDef> {
+  static void mapping(IO &io, LinalgTensorDef &info) {
+    io.mapRequired("name", info.name);
+    io.mapRequired("usage", info.usage);
+    io.mapRequired("shape", info.shape);
+  }
+};
+
+/// Usage enum for a named argument.
+template <>
+struct ScalarEnumerationTraits<LinalgTensorUsageDef> {
+  static void enumeration(IO &io, LinalgTensorUsageDef &value) {
+    io.enumCase(value, "input", LinalgTensorUsageDef::input);
+    io.enumCase(value, "output", LinalgTensorUsageDef::output);
+    io.enumCase(value, "temporary", LinalgTensorUsageDef::temporary);
+  }
+};
+
+/// Iterator type enum.
+template <>
+struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
+  static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
+    io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
+    io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
+  }
+};
+
+/// Metadata about the op (name, C++ name, and documentation).
+template <>
+struct MappingTraits<LinalgOpMetadata> {
+  static void mapping(IO &io, LinalgOpMetadata &info) {
+    io.mapRequired("name", info.name);
+    io.mapRequired("cpp_op_name", info.cppOpName);
+    io.mapOptional("doc", info.doc);
+  }
+};
+
+/// How the op's indexing maps are produced. Must be one of:
+///   - static_indexing_maps: A static list of AffineMaps, possibly with
+///     some symbols that bind to attributes of the op. Each indexing map must
+///     be normalized over the same list of dimensions, and its symbols must
+///     match the symbols for argument shapes.
+template <>
+struct MappingTraits<LinalgIndexingMapsConfig> {
+  static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
+    io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
+  }
+};
+
+/// Models an assignment to a named output.
+///   - The `arg` name must match a named output or temporary.
+///   - The `value` is a scalar expression for computing the value to
+///     assign (see `ScalarExpression`).
+template <>
+struct MappingTraits<ScalarAssign> {
+  static void mapping(IO &io, ScalarAssign &info) {
+    io.mapRequired("arg", info.arg);
+    io.mapRequired("value", info.value);
+  }
+};
+
+/// A scalar expression (RHS of an assignment). Must be one of:
+///   - `scalar_arg`: Name of an argument to the op.
+///   - `scalar_apply`: Result of evaluating a named function (see
+///      `ScalarApply`).
+template <>
+struct MappingTraits<ScalarExpression> {
+  static void mapping(IO &io, ScalarExpression &info) {
+    io.mapOptional("scalar_arg", info.scalarArg);
+    io.mapOptional("scalar_apply", info.scalarApply);
+  }
+};
+
+/// A scalar expression that evaluates a named function.
+/// Functions are generally "math" level and type polymorphic. Builtin
+/// functions include:
+///   - `add(lhs, rhs)`
+///   - `mul(lhs, rhs)`
+template <>
+struct MappingTraits<ScalarApply> {
+  static void mapping(IO &io, ScalarApply &info) {
+    io.mapRequired("fn_name", info.fnName);
+    io.mapRequired("operands", info.operands);
+  }
+};
+
+/// Helper mapping that treats an AffineMapAttr as its serialized string
+/// form.
+template <>
+struct ScalarTraits<SerializedAffineMap> {
+  static void output(const SerializedAffineMap &value, void *rawYamlContext,
+                     raw_ostream &out) {
+    assert(value.affineMapAttr);
+    value.affineMapAttr.print(out);
+  }
+  static StringRef input(StringRef scalar, void *rawYamlContext,
+                         SerializedAffineMap &value) {
+    assert(rawYamlContext);
+    auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
+    if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
+                        .dyn_cast_or_null<AffineMapAttr>())
+      value.affineMapAttr = attr;
+    else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
+      return "could not parse as an affine map attribute";
+    return StringRef();
+  }
+  static QuotingType mustQuote(StringRef) { return QuotingType::None; }
+};
+
+} // namespace yaml
+} // namespace llvm
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Generation utilities
+//===----------------------------------------------------------------------===//
+
+class GenerationContext {
+public:
+  GenerationContext(MLIRContext *context, raw_ostream *odsOut,
+                    raw_ostream *defnOut)
+      : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut),
+        defnOut(defnOut) {}
+
+  MLIRContext *getContext() { return context; }
+
+  void setLoc(Location loc) { this->loc = loc; }
+  Location getLoc() { return loc; }
+
+  bool shouldGenerateOds() { return odsOut; }
+  bool shouldGenerateDefns() { return defnOut; }
+
+  raw_ostream &odss() {
+    assert(odsOut && "ODS stream not defined");
+    return *odsOut;
+  }
+
+  raw_ostream &defns() {
+    assert(defnOut && "Definition stream not defined");
+    return *defnOut;
+  }
+
+private:
+  MLIRContext *context;
+  Location loc;
+  raw_ostream *odsOut;
+  raw_ostream *defnOut;
+};
+
+} // namespace
+
+static std::string generateCppExpression(SerializedAffineMap self,
+                                         StringRef contextName) {
+  std::string printedStr;
+  llvm::raw_string_ostream printedSs(printedStr);
+  self.affineMapAttr.print(printedSs);
+  printedSs.flush();
+
+  static const char exprFormat[] =
+      R"FMT(mlir::parseAttribute("{0}", {1}).cast<AffineMapAttr>().getValue())FMT";
+  return llvm::formatv(exprFormat, printedStr, contextName);
+}
+
+template <typename Container>
+static std::string interleaveToString(Container &container,
+                                      StringRef separator) {
+  std::string result;
+  llvm::raw_string_ostream ss(result);
+  llvm::interleave(container, ss, separator);
+  ss.flush();
+  return result;
+}
+
+static Optional<int>
+findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgTensorDef> &args) {
+  for (auto it : llvm::enumerate(args)) {
+    if (it.value().name == name)
+      return it.index();
+  }
+  return None;
+}
+
+static ScalarAssign *
+findAssignment(StringRef name, SmallVectorImpl<ScalarAssign> &assignments) {
+  for (auto &assign : assignments) {
+    if (assign.arg == name)
+      return &assign;
+  }
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Templates
+//===----------------------------------------------------------------------===//
+
+// A single line banner format. Parameters:
+// {0}: Single line comment
+static const char bannerFormat[] = R"FMT(
+//===----------------------------------------------------------------------===//
+// {0}
+//===----------------------------------------------------------------------===//
+)FMT";
+
+//===----------------------------------------------------------------------===//
+// Named generic op generation.
+// These ops model at most a single contraction and comply with the limitations
+// of a linalg.generic.
+//===----------------------------------------------------------------------===//
+
+// Template for Linalg named ops' ODS definitions. Parameters:
+// {0}: ODS/C++ op name
+// {1}: assembly op mnemonic
+// {2}: op interface list
+// {3}: documentation (summary + description)
+// {4}: op attribute list
+// {5}: the number of arguments for the op region
+// {6}: builder methods taking standalone attribute parameters
+// {7}: additional methods for attributes used by indexing maps
+static const char structuredOpOdsHeaderFormat[] = R"FMT(
+//===----------------------------------------------------------------------===//
+// Op definition for {0}
+//===----------------------------------------------------------------------===//
+
+def {0} : LinalgStructuredBase_Op<"{1}", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  SingleBlockImplicitTerminator<"YieldOp">
+  /*extraInterfaces=*/{2}]> {
+    {3}
+    let arguments = (ins
+      Variadic<AnyShaped>:$inputs,
+      Variadic<AnyShaped>:$outputs{4}
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilderDAG<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs),
+      [{{
+        $_state.addOperands(inputs);
+        $_state.addOperands(outputs);
+        $_state.addAttribute(
+          "operand_segment_sizes",
+          $_builder.getI32VectorAttr({{
+            static_cast<int32_t>(inputs.size()),
+            static_cast<int32_t>(outputs.size())}));
+        createAndFillStructuredOpRegion<{0}>(
+          $_builder,
+          $_state,
+          TypeRange(inputs),
+          TypeRange(outputs)/*, TODO: support captures*/);
+      }]>,
+      OpBuilderDAG<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs),
+      [{{
+        $_state.addOperands(inputs);
+        $_state.addOperands(outputs);
+        $_state.addTypes(resultTensorTypes);
+        $_state.addAttribute(
+          "operand_segment_sizes",
+          $_builder.getI32VectorAttr({{
+            static_cast<int32_t>(inputs.size()),
+            static_cast<int32_t>(outputs.size())}));
+        createAndFillStructuredOpRegion<{0}>(
+          $_builder,
+          $_state,
+          TypeRange(inputs),
+          TypeRange(outputs)/*, TODO: support captures*/);
+      }]>,
+      OpBuilderDAG<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
+      [{{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>
+      {6}
+    ];
+    let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
+    let parser = [{{
+      return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
+    }];
+    let hasFolder = 1;
+    let hasCanonicalizer = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{{
+      // Auto-generated.
+      ArrayAttr iterator_types();
+      ArrayAttr indexing_maps();
+      static void regionBuilder(Block &block, ValueRange captures);
+      static std::function<void(Block &, ValueRange)> getRegionBuilder() {{
+        return regionBuilder;
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      {7}
+    }];
+}
+)FMT";
+
+// The iterator_types() method implementation. Parameters:
+// {0}: Class name
+// {1}: Comma interleaved iterator type names.
+static const char structuredOpIteratorTypesFormat[] =
+    R"FMT(
+ArrayAttr {0}::iterator_types() {
+  return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} });
+}
+)FMT";
+
+// Implementations of getCanonicalizationPatterns, fold and getEffects.
+// Parameters:
+// {0}: Class name
+const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
+void {0}::getCanonicalizationPatterns(
+    OwningRewritePatternList &results,
+    MLIRContext *context) {{
+  results.insert<EraseDeadLinalgOp>();
+  results.insert<FoldTensorCastOp>();
+}
+LogicalResult {0}::fold(ArrayRef<Attribute>,
+                        SmallVectorImpl<OpFoldResult> &) {{
+  return foldMemRefCast(*this);
+}
+void {0}::getEffects(SmallVectorImpl<
+    SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
+  getGenericEffectsImpl(effects,
+    getOperation()->getResults(), getInputBuffers(), getOutputBuffers());
+}
+)FMT";
+
+static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
+                                               GenerationContext &genContext) {
+  if (!genContext.shouldGenerateOds())
+    return success();
+
+  raw_ostream &os = genContext.odss();
+
+  std::string interfaceNameList;
+  std::string attrList;
+  std::string attrMethods;
+  std::string attrBuilder;
+
+  std::string doc;
+  if (opConfig.metadata->doc) {
+    const char *docFmt = R"FMT(
+      let summary = [{ {0} }];
+      let description = [{
+        {1}
+      }];
+    )FMT";
+    StringRef summary, description;
+    std::tie(summary, description) =
+        StringRef(*opConfig.metadata->doc).trim().split('\n');
+    doc = llvm::formatv(docFmt, summary.trim(), description.trim());
+  }
+
+  os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppOpName,
+                      opConfig.metadata->name, interfaceNameList, doc, attrList,
+                      opConfig.structuredOp->args.size(), attrBuilder,
+                      attrMethods);
+
+  return success();
+}
+
+static LogicalResult
+generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
+                            GenerationContext &genContext) {
+  if (!genContext.shouldGenerateDefns())
+    return success();
+
+  raw_ostream &os = genContext.defns();
+  StringRef className = opConfig.metadata->cppOpName;
+
+  // Implementation banner.
+  std::string bannerComment = llvm::formatv("Implementation of {0}", className);
+  os << llvm::formatv(bannerFormat, bannerComment);
+
+  // Reference iterators.
+  {
+    std::string iteratorsStr;
+    llvm::raw_string_ostream ss(iteratorsStr);
+    llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss,
+                          [&](LinalgIteratorTypeDef it) {
+                            switch (it) {
+                            case LinalgIteratorTypeDef::parallel:
+                              ss << "getParallelIteratorTypeName()";
+                              break;
+                            case LinalgIteratorTypeDef::reduction:
+                              ss << "getReductionIteratorTypeName()";
+                              break;
+                            }
+                          });
+    ss.flush();
+    os << llvm::formatv(structuredOpIteratorTypesFormat, className,
+                        iteratorsStr);
+  }
+
+  // Static indexing maps.
+  if (auto &staticMaps =
+          opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
+    if (staticMaps->empty())
+      return emitError(genContext.getLoc()) << "op has no indexing maps";
+    AffineMap firstMap = staticMaps->front().affineMap();
+
+    // Symbol bindings.
+    {
+      // For each symbol, generate a declaration for it, either with an
+      // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
+      // an attribute).
+      // TODO: Implement attribute constants.
+      // TODO: Possibly lift into a top-level method.
+      static const char structuredOpSymbolBindingsFormat[] = R"FMT(
+static SmallVector<AffineExpr> getSymbolBindings({0} self) {
+  MLIRContext *context = self.getContext();
+  SmallVector<AffineExpr> exprs;
+{1}
+  return exprs;
+}
+)FMT";
+
+      unsigned symbolCount = firstMap.getNumSymbols();
+      SmallVector<std::string> symbolBindings;
+      for (unsigned i = 0; i < symbolCount; ++i) {
+        // TODO: Switch and emit constants for attribute bound symbols.
+        symbolBindings.push_back(llvm::formatv(
+            "  exprs.push_back(getAffineSymbolExpr({0}, context));", i));
+      }
+      std::string symbolBindingsStr;
+      llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
+      llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
+      symbolBindingsSs.flush();
+
+      os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
+                          symbolBindingsStr);
+    }
+
+    // Indexing maps.
+    {
+      // Parameters:
+      // {0}: Class name
+      // {1}: Comma-separated list of dimension variable names.
+      // {2}: Statements
+      static const char structuredOpIndexingMapsFormat[] = R"FMT(
+ArrayAttr {0}::indexing_maps() {
+  MLIRContext *context = getContext();
+  auto symbolBindings = getSymbolBindings(*this);
+  SmallVector<AffineMap> maps;
+  {2}
+  return Builder(context).getAffineMapArrayAttr(maps);
+}
+)FMT";
+
+      unsigned dimCount = firstMap.getNumDims();
+
+      // Generate a comma-separated list of dim identifiers to be passed to
+      // bindDims, ensuring that AffineExpr identifiers are bound in the right
+      // order to the proper AffineDimExpr.
+      // This results in vars in scope like: d0, d1, d2...
+      SmallVector<unsigned> dimIndices;
+      for (unsigned i = 0; i < dimCount; ++i)
+        dimIndices.push_back(i);
+      std::string dimIdentsStr;
+      llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
+      llvm::interleaveComma(dimIndices, dimIdentsSs,
+                            [&](unsigned i) { dimIdentsSs << "d" << i; });
+      dimIdentsSs.flush();
+
+      // Statements to add and simplify each affine map.
+      SmallVector<std::string> stmts;
+      for (auto &indexingMap : *staticMaps) {
+        // TODO: Assert that dim and symbol count match the first.
+        stmts.push_back(
+            llvm::formatv("maps.push_back({0});",
+                          generateCppExpression(indexingMap, "context")));
+        stmts.push_back(llvm::formatv(
+            "maps.back() = "
+            "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
+            "symbolBindings, {0}, 0));",
+            dimCount));
+      }
+
+      // TODO: This needs to be memoized and/or converted to non-parser based
+      // C++ codegen prior to real use.
+      os << llvm::formatv(structuredOpIndexingMapsFormat, className,
+                          dimIdentsStr, interleaveToString(stmts, "\n  "));
+    }
+  } else {
+    return emitError(genContext.getLoc())
+           << "generating code for non static indexing maps not currently "
+              "supported";
+  }
+
+  // getNumRegionArgs()
+  {
+    // Generates a getNumRegionArgs() method. Parameters:
+    // {0}: Class name
+    // {1}: Number of region args
+    static const char structuredOpGetNumRegionArgsFormat[] = R"FMT(
+unsigned {0}::getNumRegionArgs() {{ return {1}; }
+)FMT";
+    os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className,
+                        opConfig.structuredOp->args.size());
+  }
+
+  // getLibraryCallName()
+  {
+    // Generates a getLibraryCallName method. Parameters:
+    // {0}: Class name
+    static const char structuredOpGetLibraryCallFormat[] = R"FMT(
+std::string {0}::getLibraryCallName() {{
+  return generateLibraryCallName(getOperation());
+}
+)FMT";
+    os << llvm::formatv(structuredOpGetLibraryCallFormat, className);
+  }
+
+  // regionBuilder()
+  {
+    // Generates a regionBuilder method. Parameters.
+    // {0}: Class name
+    // {1}: Statements
+    static const char structuredOpRegionBuilderFormat[] = R"FMT(
+void {0}::regionBuilder(Block &block, ValueRange captures) {{
+  RegionBuilderHelper helper(block);
+  SmallVector<Value> yields;
+  {1}
+  helper.yieldOutputs(yields);
+}
+)FMT";
+    auto &args = opConfig.structuredOp->args;
+    auto &assignments = opConfig.structuredOp->assignments;
+    size_t generatedAssignmentCount = 0;
+    int localCounter = 0;
+    SmallVector<std::string> stmts;
+    for (LinalgTensorDef &arg : args) {
+      if (arg.usage != LinalgTensorUsageDef::output &&
+          arg.usage != LinalgTensorUsageDef::temporary)
+        continue;
+
+      // Find the assignment that correlates with the argument.
+      ScalarAssign *assignment = findAssignment(arg.name, assignments);
+      if (!assignment)
+        return emitError(genContext.getLoc())
+               << "no assignment found for output argument " << arg.name;
+      ++generatedAssignmentCount;
+
+      // Recursively generate the expression.
+      std::function<Optional<std::string>(ScalarExpression &)>
+          generateExpression =
+              [&](ScalarExpression &expression) -> Optional<std::string> {
+        if (expression.scalarArg) {
+          Optional<int> argIndex =
+              findTensorDefArgIndex(*expression.scalarArg, args);
+          if (!argIndex) {
+            emitError(genContext.getLoc())
+                << "scalar argument not defined on the op: " << *expression.scalarArg;
+            return None;
+          }
+          return std::string(
+              llvm::formatv("block.getArgument({0})", *argIndex));
+        } else if (expression.scalarApply) {
+          // Recursively generate operands.
+          SmallVector<std::string> operandCppValues;
+          for (ScalarExpression &operand : expression.scalarApply->operands) {
+            auto operandCppValue = generateExpression(operand);
+            if (!operandCppValue)
+              return None;
+            operandCppValues.push_back(*operandCppValue);
+          }
+          std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
+          stmts.push_back(
+              llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent,
+                            expression.scalarApply->fnName,
+                            interleaveToString(operandCppValues, ", ")));
+          return cppIdent;
+        } else {
+          emitError(genContext.getLoc()) << "unknown ScalarExpression type";
+          return None;
+        }
+      };
+      Optional<std::string> cppValue = generateExpression(assignment->value);
+      if (!cppValue)
+        return failure();
+      stmts.push_back(llvm::formatv("yields.push_back({0});", cppValue));
+    }
+
+    if (generatedAssignmentCount != assignments.size())
+      return emitError(genContext.getLoc())
+             << "mismatched number of assignments vs output arguments";
+
+    os << llvm::formatv(structuredOpRegionBuilderFormat, className,
+                        interleaveToString(stmts, "\n  "));
+  }
+
+  // Canonicalizers and folders.
+  os << llvm::formatv(structuredOpCanonicalizersAndFoldersFormat, className);
+
+  return success();
+}
+
+static LogicalResult generateOp(LinalgOpConfig &opConfig,
+                                GenerationContext &genContext) {
+  // Switch on op type being generated.
+  if (opConfig.structuredOp) {
+    return success(
+        succeeded(generateNamedGenericOpOds(opConfig, genContext)) &&
+        succeeded(generateNamedGenericOpDefns(opConfig, genContext)));
+  } else {
+    return emitError(genContext.getLoc()) << "unsupported operation type";
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Command line options and main
+//===----------------------------------------------------------------------===//
+
+static llvm::cl::opt<std::string>
+    inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
+                  llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
+
+static llvm::cl::opt<std::string>
+    outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
+                          llvm::cl::value_desc("filename"), llvm::cl::init(""));
+
+static llvm::cl::opt<std::string>
+    outputCppImplFilename("o-impl",
+                          llvm::cl::desc("C++ implementation file name"),
+                          llvm::cl::value_desc("filename"), llvm::cl::init(""));
+
+int main(int argc, char **argv) {
+  llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML");
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      mlir::openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  MLIRContext mlirContext;
+  LinalgYAMLContext yamlContext{&mlirContext};
+
+  std::vector<LinalgOpConfig> opConfigs;
+
+  // Parse input.
+  Input yin(file->getBuffer(), &yamlContext);
+  yin >> opConfigs;
+
+  if (yin.error())
+    return 1;
+
+  // Open output files.
+  std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl;
+  if (!outputOdsDeclFilename.empty()) {
+    outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage);
+    if (!outputOdsDecl) {
+      llvm::errs() << errorMessage << "\n";
+      return 1;
+    }
+  }
+
+  std::unique_ptr<llvm::ToolOutputFile> outputCppImpl;
+  if (!outputCppImplFilename.empty()) {
+    outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage);
+    if (!outputCppImpl) {
+      llvm::errs() << errorMessage << "\n";
+      return 1;
+    }
+  }
+
+  if (!outputOdsDecl && !outputCppImpl) {
+    llvm::errs() << "error: No output files specified\n";
+    return 1;
+  }
+
+  // Generate.
+  GenerationContext genContext(&mlirContext,
+                               outputOdsDecl ? &outputOdsDecl->os() : nullptr,
+                               outputCppImpl ? &outputCppImpl->os() : nullptr);
+
+  for (auto &opConfig : opConfigs) {
+    if (!opConfig.metadata) {
+      emitError(genContext.getLoc())
+          << "missing operation metadata on subsequent op";
+      return 1;
+    }
+
+    genContext.setLoc(NameLoc::get(
+        Identifier::get(opConfig.metadata->cppOpName, &mlirContext),
+        &mlirContext));
+    if (failed(generateOp(opConfig, genContext))) {
+      return 1;
+    }
+  }
+
+  if (outputOdsDecl)
+    outputOdsDecl->keep();
+  if (outputCppImpl)
+    outputCppImpl->keep();
+
+  return 0;
+}


        

