[Mlir-commits] [mlir] ff2ef4d - [mlir][linalg] Adapt yaml codegen to support scalar parameters.

Tobias Gysi llvmlistbot at llvm.org
Tue Jun 15 08:22:43 PDT 2021


Author: Tobias Gysi
Date: 2021-06-15T15:20:48Z
New Revision: ff2ef4d684821c373e989105ac51eeeca9c2027e

URL: https://github.com/llvm/llvm-project/commit/ff2ef4d684821c373e989105ac51eeeca9c2027e
DIFF: https://github.com/llvm/llvm-project/commit/ff2ef4d684821c373e989105ac51eeeca9c2027e.diff

LOG: [mlir][linalg] Adapt yaml codegen to support scalar parameters.

This patch updates the C++ YAML code generation to support the scalar operands introduced in https://reviews.llvm.org/D104220.

Differential Revision: https://reviews.llvm.org/D104224

Added: 
    

Modified: 
    mlir/docs/Tools/LinalgOpDsl.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/test/python/dialects/linalg/opsrun.py
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tools/LinalgOpDsl.md b/mlir/docs/Tools/LinalgOpDsl.md
index d944fc83829d5..3ae9c9b7f45fd 100644
--- a/mlir/docs/Tools/LinalgOpDsl.md
+++ b/mlir/docs/Tools/LinalgOpDsl.md
@@ -19,7 +19,7 @@ package, if available, to avoid building.
 
 ```shell
 # Dump the `core_named_ops.py` module as YAML.
-python -m python -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops
+python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops
 ```
 
 The tool is meant for use during both development and runtime, but not as

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 7e8d560e9bca6..9f5bf5d62755d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -11,21 +11,21 @@ metadata: !LinalgOpMetadata
   - LinalgContractionOpInterface
 structured_op: !LinalgStructuredOpConfig
   args:
-  - !<LinalgTensorDef>
+  - !LinalgOperandDefConfig
     name: A
     usage: input
     shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
-    element_type_var: T1
-  - !<LinalgTensorDef>
+    type_var: T1
+  - !LinalgOperandDefConfig
     name: B
     usage: input
     shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
-    element_type_var: T2
-  - !<LinalgTensorDef>
+    type_var: T2
+  - !LinalgOperandDefConfig
     name: C
     usage: output
     shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
-    element_type_var: U
+    type_var: U
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@@ -73,21 +73,21 @@ metadata: !LinalgOpMetadata
   - LinalgContractionOpInterface
 structured_op: !LinalgStructuredOpConfig
   args:
-  - !<LinalgTensorDef>
+  - !LinalgOperandDefConfig
     name: A
     usage: input
     shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-    element_type_var: T1
-  - !<LinalgTensorDef>
+    type_var: T1
+  - !LinalgOperandDefConfig
     name: B
     usage: input
     shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
-    element_type_var: T2
-  - !<LinalgTensorDef>
+    type_var: T2
+  - !LinalgOperandDefConfig
     name: C
     usage: output
     shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-    element_type_var: U
+    type_var: U
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
@@ -136,21 +136,21 @@ metadata: !LinalgOpMetadata
   - LinalgContractionOpInterface
 structured_op: !LinalgStructuredOpConfig
   args:
-  - !<LinalgTensorDef>
+  - !LinalgOperandDefConfig
     name: A
     usage: input
     shape: affine_map<()[s0, s1] -> (s0, s1)>
-    element_type_var: T1
-  - !<LinalgTensorDef>
+    type_var: T1
+  - !LinalgOperandDefConfig
     name: y
     usage: input
     shape: affine_map<()[s0, s1] -> (s1)>
-    element_type_var: T2
-  - !<LinalgTensorDef>
+    type_var: T2
+  - !LinalgOperandDefConfig
     name: x
     usage: output
     shape: affine_map<()[s0, s1] -> (s0)>
-    element_type_var: U
+    type_var: U
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -197,21 +197,21 @@ metadata: !LinalgOpMetadata
   - LinalgContractionOpInterface
 structured_op: !LinalgStructuredOpConfig
   args:
-  - !<LinalgTensorDef>
+  - !LinalgOperandDefConfig
     name: y
     usage: input
     shape: affine_map<()[s0, s1] -> (s1)>
-    element_type_var: T1
-  - !<LinalgTensorDef>
+    type_var: T1
+  - !LinalgOperandDefConfig
     name: A
     usage: input
     shape: affine_map<()[s0, s1] -> (s1, s0)>
-    element_type_var: T2
-  - !<LinalgTensorDef>
+    type_var: T2
+  - !LinalgOperandDefConfig
     name: x
     usage: output
     shape: affine_map<()[s0, s1] -> (s0)>
-    element_type_var: U
+    type_var: U
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d1)>
@@ -258,21 +258,21 @@ metadata: !LinalgOpMetadata
   - LinalgContractionOpInterface
 structured_op: !LinalgStructuredOpConfig
   args:
-  - !<LinalgTensorDef>
+  - !LinalgOperandDefConfig
     name: A
     usage: input
     shape: affine_map<()[s0] -> (s0)>
-    element_type_var: T1
-  - !<LinalgTensorDef>
+    type_var: T1
+  - !LinalgOperandDefConfig
     name: B
     usage: input
     shape: affine_map<()[s0] -> (s0)>
-    element_type_var: T2
-  - !<LinalgTensorDef>
+    type_var: T2
+  - !LinalgOperandDefConfig
     name: C
     usage: output
     shape: affine_map<()[s0] -> ()>
-    element_type_var: U
+    type_var: U
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0)[s0] -> (d0)>
@@ -319,18 +319,30 @@ metadata: !LinalgOpMetadata
     and runs them in parallel. The seed operand and the indices of the data
     element seed the random number generation. The min and max operands limit
     the range of the generated random numbers.
-
-    Note: The captures are hard-coded till there is capture support on the C++
-    side.
 structured_op: !LinalgStructuredOpConfig
   args:
-  - !<LinalgTensorDef>
+  - !LinalgOperandDefConfig
+    name: min
+    usage: input
+    type_var: F64
+  - !LinalgOperandDefConfig
+    name: max
+    usage: input
+    type_var: F64
+  - !LinalgOperandDefConfig
+    name: seed
+    usage: input
+    type_var: I32
+  - !LinalgOperandDefConfig
     name: O
     usage: output
     shape: affine_map<()[s0, s1] -> (s0, s1)>
-    element_type_var: T
+    type_var: T
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> ()>
+    - affine_map<(d0, d1)[s0, s1] -> ()>
+    - affine_map<(d0, d1)[s0, s1] -> ()>
     - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
   iterator_types:
   - parallel
@@ -401,11 +413,7 @@ structured_op: !LinalgStructuredOpConfig
                                                     - !ScalarExpression
                                                       scalar_index: 0
                                                 - !ScalarExpression
-                                                  symbolic_cast:
-                                                    type_var: I32
-                                                    operands:
-                                                    - !ScalarExpression
-                                                      scalar_const: '42 : i64'
+                                                  scalar_arg: seed
                                             - !ScalarExpression
                                               symbolic_cast:
                                                 type_var: I32
@@ -439,17 +447,9 @@ structured_op: !LinalgStructuredOpConfig
                         fn_name: sub
                         operands:
                         - !ScalarExpression
-                          symbolic_cast:
-                            type_var: F64
-                            operands:
-                            - !ScalarExpression
-                              scalar_const: '1000 : i64'
+                          scalar_arg: max
                         - !ScalarExpression
-                          symbolic_cast:
-                            type_var: F64
-                            operands:
-                            - !ScalarExpression
-                              scalar_const: '-1000 : i64'
+                          scalar_arg: min
                     - !ScalarExpression
                       symbolic_cast:
                         type_var: F64
@@ -457,8 +457,4 @@ structured_op: !LinalgStructuredOpConfig
                         - !ScalarExpression
                           scalar_const: '2.3283063999999999E-10 : f64'
             - !ScalarExpression
-              symbolic_cast:
-                type_var: F64
-                operands:
-                - !ScalarExpression
-                  scalar_const: '-1000 : i64'
+              scalar_arg: min

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 4a431bd1a54a6..35fe9fe69a5a7 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -30,16 +30,13 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
 
 // -----
 
-func @generalize_fill_rng_2d_f32(%O: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
+func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
 
 // CHECK-LABEL: @generalize_fill_rng_2d_f32
-// CHECK-SAME: (%[[O:.+]]: tensor<16x32xf32>)
-// CHECK-DAG:    %[[MIN:.+]] = constant -1000 : i64
-// CHECK-DAG:    %[[MAX:.+]] = constant 1000 : i64
-// CHECK-DAG:    %[[SEED:.+]] = constant 42 : i32
+// CHECK-DAG:  ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: f32
 // CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
 // CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
 // CHECK-DAG:    %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
@@ -50,27 +47,24 @@ func @generalize_fill_rng_2d_f32(%O: tensor<16x32xf32>) -> tensor<16x32xf32> {
 // CHECK-DAG:    %[[VAL1:.+]] = muli %[[VAL0]], %[[CST0]] : i32
 // CHECK-DAG:    %[[VAL2:.+]] = addi %[[VAL1]], %[[CST1]] : i32
 // Skip random number computation for the second index.
-// CHECK-DAG:    %[[MIN_CAST1:.+]] = sitofp %[[MIN]] : i64 to f64
-// CHECK-DAG:    %[[MAX_CAST:.+]] = sitofp %[[MAX]] : i64 to f64
-// CHECK-DAG:    %[[DIFF:.+]] = subf %[[MAX_CAST]], %[[MIN_CAST1]] : f64
+// CHECK-DAG:    %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
 // CHECK-DAG:    %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
 // CHECK-DAG:    %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
 // CHECK-DAG:    %[[VAL4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
-// CHECK-DAG:    %[[MIN_CAST2:.+]] = sitofp %[[MIN]] : i64 to f64
-// CHECK-DAG:    %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN_CAST2]] : f64
+// CHECK-DAG:    %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN]] : f64
 // CHECK-DAG:    %[[VAL6:.+]] = fptrunc %[[VAL5]] : f64 to f32
 // CHECK-NEXT:   linalg.yield %[[VAL6]] : f32
 // CHECK-NEXT: -> tensor<16x32xf32>
 
 // -----
 
-func @generalize_fill_rng_2d_i32(%O: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.fill_rng_2d outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32>
+func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
 
 // CHECK-LABEL: @generalize_fill_rng_2d_i32
-// CHECK-SAME: (%[[O:.+]]: tensor<16x32xi32>)
+// CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: i32
 // Verifies floating point to integer cast.
 // CHECK:        %[[VAL6:.+]] = fptosi %{{.+}} : f64 to i32
 // CHECK-NEXT:   linalg.yield %[[VAL6]] : i32

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 5f107c3071616..00a6528ddcd17 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -19,11 +19,11 @@ metadata: !LinalgOpMetadata
     Detailed description.
 structured_op: !LinalgStructuredOpConfig
   args:
-  - !<LinalgTensorDef>
+  - !LinalgOperandDefConfig
     name: O
     usage: output
     shape: affine_map<()[s0, s1] -> (s0, s1)>
-    element_type_var: T
+    type_var: T
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -58,7 +58,7 @@ structured_op: !LinalgStructuredOpConfig
 #  ODS-NEXT:  }];
 
 #       ODS:  let arguments =
-#  ODS-NEXT:    Variadic<AnyShaped>:$inputs,
+#  ODS-NEXT:    Variadic<AnyType>:$inputs,
 #  ODS-NEXT:    Variadic<AnyShaped>:$outputs
 
 #       ODS:  let builders =
@@ -103,18 +103,23 @@ metadata: !LinalgOpMetadata
     Detailed description.
 structured_op: !LinalgStructuredOpConfig
   args:
-  - !<LinalgTensorDef>
+  - !LinalgOperandDefConfig
+    name: value
+    usage: input
+    type_var: T
+  - !LinalgOperandDefConfig
     name: I
     usage: input
-    shape: affine_map<()[s0, s1] -> (s0, s1)>
-    element_type_var: T
-  - !<LinalgTensorDef>
+    shape: affine_map<()[s0, s1] -> (s1, s0)>
+    type_var: T
+  - !LinalgOperandDefConfig
     name: O
     usage: output
     shape: affine_map<()[s0, s1] -> (s0, s1)>
-    element_type_var: T
+    type_var: T
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> ()>
     - affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
     - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
   iterator_types:
@@ -124,15 +129,23 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      scalar_arg: I
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: value
+        - !ScalarExpression
+          scalar_arg: I
 
 # IMPL-LABEL:  Test2Op::iterator_types()
 #  IMPL-NEXT:  { getParallelIteratorTypeName(), getParallelIteratorTypeName() }
 
 #       IMPL:  Test2Op::indexing_maps()
+#       IMPL:  "affine_map<(d0, d1)[s0, s1] -> ()>"
 #       IMPL:  "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>"
 #       IMPL:  "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>"
 
 #       IMPL:  void Test2Op::regionBuilder(
 #       IMPL:    ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
-#       IMPL:  yields.push_back(block.getArgument(0));
+
+#       IMPL:   = helper.applyfn__add(block.getArgument(0), block.getArgument(1));

diff  --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index 2b58f38f36319..ae9b7d318d66c 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -131,6 +131,33 @@ def matmul_on_buffers(lhs, rhs, out):
 test_matmul_generic()
 
 
+def test_fill_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
+      def fill_on_buffers(min, max, seed, out):
+        linalg.fill_rng_2d(min, max, seed, outs=[out])
+
+    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: -480
+
+
+test_fill_builtin()
+
+
 def test_fill_generic():
   with Context() as ctx, Location.unknown():
     module = Module.create()

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index c907caef649de..e7660cbd6286e 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -62,17 +62,13 @@ struct SerializedAffineMap {
   AffineMap affineMap() { return affineMapAttr.getValue(); }
 };
 
-enum class LinalgTensorUsageDef {
-  input,
-  output,
-  temporary,
-};
+enum class LinalgOperandDefUsage { input, output };
 
-struct LinalgTensorDef {
+struct LinalgOperandDef {
   std::string name;
-  LinalgTensorUsageDef usage;
-  SerializedAffineMap shape;
-  std::string elementTypeVar;
+  LinalgOperandDefUsage usage;
+  Optional<SerializedAffineMap> shape;
+  std::string typeVar;
 };
 
 enum class LinalgIteratorTypeDef {
@@ -114,10 +110,10 @@ struct ScalarAssign {
 };
 
 struct LinalgStructuredOpConfig {
-  SmallVector<LinalgTensorDef> args;
+  SmallVector<LinalgOperandDef> args;
   LinalgIndexingMapsConfig indexingMaps;
   SmallVector<LinalgIteratorTypeDef> iteratorTypes;
-  SmallVector<ScalarAssign, 2> assignments;
+  std::vector<ScalarAssign> assignments;
 };
 
 struct LinalgOpConfig {
@@ -131,7 +127,7 @@ struct LinalgOpConfig {
 // Mapping traits.
 //===----------------------------------------------------------------------===//
 
-LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgTensorDef)
+LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
 LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
 LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
 LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
@@ -153,8 +149,8 @@ struct MappingTraits<LinalgOpConfig> {
 };
 
 /// A structured op models (at most) a single contraction by modeling
-///   - A list of named arguments (`LinalgTensorDef`), which can be inputs,
-///     outputs, or temporaries.
+///   - A list of named arguments (`LinalgOperandDef`), which can be inputs or
+///     outputs.
 ///   - List of indexing maps (see `LinalgIndexingMaps`).
 ///   - Iterator types (see `LinalgIteratorTypeDef`).
 ///   - List of scalar level assignment (see `ScalarAssign`).
@@ -168,31 +164,30 @@ struct MappingTraits<LinalgStructuredOpConfig> {
   }
 };
 
-/// Maps a named tensor-argument to an operation, consisting of:
+/// Maps a named tensor- or scalar-argument to an operation, consisting of:
 ///   - `name`: Must be unique within the operation.
 ///   - `usage`: How the argument is used (input, output, etc).
-///   - `shape`: An AffineMap from all op symbols to the specific shape
-///     of this argument. Each shape must be normalized over the same list of
-///     symbols and have no dimension inputs.
-///   - `element_type_var`: The symbolic type variable that binds to the scalar
-///     element type of this TensorDef.
+///   - `shape`: An optional AffineMap from all op symbols to the shape of the
+///     argument. Only tensor-arguments have a shape. Each shape must be
+///     normalized over the same list of symbols and have no dimension inputs.
+///   - `type_var`: The symbolic type variable that binds to the element or self
+///     type of the tensor- or scalar-argument, respectively.
 template <>
-struct MappingTraits<LinalgTensorDef> {
-  static void mapping(IO &io, LinalgTensorDef &info) {
+struct MappingTraits<LinalgOperandDef> {
+  static void mapping(IO &io, LinalgOperandDef &info) {
     io.mapRequired("name", info.name);
     io.mapRequired("usage", info.usage);
-    io.mapRequired("shape", info.shape);
-    io.mapRequired("element_type_var", info.elementTypeVar);
+    io.mapOptional("shape", info.shape);
+    io.mapRequired("type_var", info.typeVar);
   }
 };
 
 /// Usage enum for a named argument.
 template <>
-struct ScalarEnumerationTraits<LinalgTensorUsageDef> {
-  static void enumeration(IO &io, LinalgTensorUsageDef &value) {
-    io.enumCase(value, "input", LinalgTensorUsageDef::input);
-    io.enumCase(value, "output", LinalgTensorUsageDef::output);
-    io.enumCase(value, "temporary", LinalgTensorUsageDef::temporary);
+struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
+  static void enumeration(IO &io, LinalgOperandDefUsage &value) {
+    io.enumCase(value, "input", LinalgOperandDefUsage::input);
+    io.enumCase(value, "output", LinalgOperandDefUsage::output);
   }
 };
 
@@ -229,7 +224,7 @@ struct MappingTraits<LinalgIndexingMapsConfig> {
 };
 
 /// Models an assignment to a named output.
-///   - The `arg` name must match a named output or temporary.
+///   - The `arg` name must match a named output.
 ///   - The `value` is a scalar expression for computing the value to
 ///     assign (see `ScalarExpression`).
 template <>
@@ -366,7 +361,7 @@ static std::string interleaveToString(Container &container,
 }
 
 static Optional<int>
-findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgTensorDef> &args) {
+findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) {
   for (auto it : llvm::enumerate(args)) {
     if (it.value().name == name)
       return it.index();
@@ -376,7 +371,7 @@ findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgTensorDef> &args) {
 
 // Try to map the TypeVar to a predefined or an argument type.
 static Optional<std::string>
-findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
+findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
   // Handle all predefined types.
   if (typeVar == "I32")
     return std::string("helper.getIntegerType(32)");
@@ -389,7 +384,7 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
 
   // Search all argument types.
   for (auto it : llvm::enumerate(args)) {
-    if (it.value().elementTypeVar == typeVar)
+    if (it.value().typeVar == typeVar)
       return llvm::formatv("block.getArgument({0}).getType()", it.index())
           .str();
   }
@@ -397,8 +392,8 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
   return None;
 }
 
-static ScalarAssign *
-findAssignment(StringRef name, SmallVectorImpl<ScalarAssign> &assignments) {
+static ScalarAssign *findAssignment(StringRef name,
+                                    std::vector<ScalarAssign> &assignments) {
   for (auto &assign : assignments) {
     if (assign.arg == name)
       return &assign;
@@ -445,7 +440,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
   /*extraInterfaces=*/[{2}])> {
     {3}
     let arguments = (ins
-      Variadic<AnyShaped>:$inputs,
+      Variadic<AnyType>:$inputs,
       Variadic<AnyShaped>:$outputs{4}
     );
     let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
@@ -467,7 +462,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
           $_builder,
           $_state,
           TypeRange(inputs),
-          TypeRange(outputs)/*, TODO: support captures*/);
+          TypeRange(outputs));
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -485,7 +480,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
           $_builder,
           $_state,
           TypeRange(inputs),
-          TypeRange(outputs)/*, TODO: support captures*/);
+          TypeRange(outputs));
       }]>,
       OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -500,7 +495,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
     ];
     let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
     let parser = [{{
-      return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
+      return ::parseNamedStructuredOp<{0}>(parser, result);
     }];
     let hasFolder = 1;
 
@@ -768,9 +763,8 @@ void {0}::regionBuilder(
     size_t generatedAssignmentCount = 0;
     int localCounter = 0;
     SmallVector<std::string> stmts;
-    for (LinalgTensorDef &arg : args) {
-      if (arg.usage != LinalgTensorUsageDef::output &&
-          arg.usage != LinalgTensorUsageDef::temporary)
+    for (LinalgOperandDef &arg : args) {
+      if (arg.usage != LinalgOperandDefUsage::output)
         continue;
 
       // Find the assignment that correlates with the argument.


        


More information about the Mlir-commits mailing list