[Mlir-commits] [mlir] 25bb616 - [mlir][linalg][python] Add attribute support to the YAML codegen.

Tobias Gysi llvmlistbot at llvm.org
Thu Jun 24 05:45:26 PDT 2021


Author: Tobias Gysi
Date: 2021-06-24T12:33:48Z
New Revision: 25bb61649085c0a6e66630bbffe7faa54cd67829

URL: https://github.com/llvm/llvm-project/commit/25bb61649085c0a6e66630bbffe7faa54cd67829
DIFF: https://github.com/llvm/llvm-project/commit/25bb61649085c0a6e66630bbffe7faa54cd67829.diff

LOG: [mlir][linalg][python] Add attribute support to the YAML codegen.

Extend the YAML code generation to support the index attributes that https://reviews.llvm.org/D104711 added to the OpDSL.

Differential Revision: https://reviews.llvm.org/D104712

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/test/python/dialects/linalg/opdsl/arguments.py
    mlir/test/python/dialects/linalg/opsrun.py
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 9f5bf5d62755d..58872da9b1dab 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -13,19 +13,19 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: A
-    usage: input
-    shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+    usage: InputOperand
     type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
   - !LinalgOperandDefConfig
     name: B
-    usage: input
-    shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+    usage: InputOperand
     type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
   - !LinalgOperandDefConfig
     name: C
-    usage: output
-    shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+    usage: OutputOperand
     type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@@ -75,19 +75,19 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: A
-    usage: input
-    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+    usage: InputOperand
     type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
   - !LinalgOperandDefConfig
     name: B
-    usage: input
-    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+    usage: InputOperand
     type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
   - !LinalgOperandDefConfig
     name: C
-    usage: output
-    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+    usage: OutputOperand
     type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
@@ -138,19 +138,19 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: A
-    usage: input
-    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    usage: InputOperand
     type_var: T1
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
   - !LinalgOperandDefConfig
     name: y
-    usage: input
-    shape: affine_map<()[s0, s1] -> (s1)>
+    usage: InputOperand
     type_var: T2
+    shape_map: affine_map<()[s0, s1] -> (s1)>
   - !LinalgOperandDefConfig
     name: x
-    usage: output
-    shape: affine_map<()[s0, s1] -> (s0)>
+    usage: OutputOperand
     type_var: U
+    shape_map: affine_map<()[s0, s1] -> (s0)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -199,19 +199,19 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: y
-    usage: input
-    shape: affine_map<()[s0, s1] -> (s1)>
+    usage: InputOperand
     type_var: T1
+    shape_map: affine_map<()[s0, s1] -> (s1)>
   - !LinalgOperandDefConfig
     name: A
-    usage: input
-    shape: affine_map<()[s0, s1] -> (s1, s0)>
+    usage: InputOperand
     type_var: T2
+    shape_map: affine_map<()[s0, s1] -> (s1, s0)>
   - !LinalgOperandDefConfig
     name: x
-    usage: output
-    shape: affine_map<()[s0, s1] -> (s0)>
+    usage: OutputOperand
     type_var: U
+    shape_map: affine_map<()[s0, s1] -> (s0)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d1)>
@@ -260,19 +260,19 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: A
-    usage: input
-    shape: affine_map<()[s0] -> (s0)>
+    usage: InputOperand
     type_var: T1
+    shape_map: affine_map<()[s0] -> (s0)>
   - !LinalgOperandDefConfig
     name: B
-    usage: input
-    shape: affine_map<()[s0] -> (s0)>
+    usage: InputOperand
     type_var: T2
+    shape_map: affine_map<()[s0] -> (s0)>
   - !LinalgOperandDefConfig
     name: C
-    usage: output
-    shape: affine_map<()[s0] -> ()>
+    usage: OutputOperand
     type_var: U
+    shape_map: affine_map<()[s0] -> ()>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0)[s0] -> (d0)>
@@ -306,6 +306,83 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
+  cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp
+  doc: A depth-wise 2-D convolution operation.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s6, s7, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s4, s5, s3)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s8, s9)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s10, s11)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1 * s8 + d4 * s10, d2 * s9 + d5 * s11, d3)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d4, d5, d3)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
   cpp_class_name: FillRng2DOp
@@ -323,21 +400,21 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: min
-    usage: input
+    usage: InputOperand
     type_var: F64
   - !LinalgOperandDefConfig
     name: max
-    usage: input
+    usage: InputOperand
     type_var: F64
   - !LinalgOperandDefConfig
     name: seed
-    usage: input
+    usage: InputOperand
     type_var: I32
   - !LinalgOperandDefConfig
     name: O
-    usage: output
-    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    usage: OutputOperand
     type_var: T
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> ()>

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 35fe9fe69a5a7..b40ab139c3e73 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -30,6 +30,36 @@ func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>,
 
 // -----
 
+func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32(%input : tensor<1x4x16x1xf32>, %filter: tensor<2x2x1xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
+  %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %filter : tensor<1x4x16x1xf32>, tensor<2x2x1xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+  return %0: tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32
+// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[FILTER_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[IN_ARG]], %[[FILTER_ARG]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = addf %[[OUT_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<1x2x4x1xf32>
+
+// -----
+
+func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tensor<1x4x16x1xi32>, %filter: tensor<2x2x1xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+  %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %filter : tensor<1x4x16x1xi32>, tensor<2x2x1xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+  return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32
+// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[FILTER_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
+// CHECK-NEXT:   %[[MUL:.+]] = muli %[[IN_ARG]], %[[FILTER_ARG]] : i32
+// CHECK-NEXT:   %[[ADD:.+]] = addi %[[OUT_ARG]], %[[MUL]] : i32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+
+// -----
+
 func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 00a6528ddcd17..471890e5f4a45 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -21,9 +21,9 @@ structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: O
-    usage: output
-    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    usage: OutputOperand
     type_var: T
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
@@ -86,12 +86,13 @@ structured_op: !LinalgStructuredOpConfig
 
 # @linalg_structured_op
 # def test2(I=TensorDef(T, S.M, S.N),
-#           O=TensorDef(T, S.M, S.N, output=True)):
+#           O=TensorDef(T, S.M, S.N, output=True),
+#           strides=AttributeDef(S.SM, S.SN)):
 #   """Title.
 
 #   Detailed description.
 #   """
-#   O[D.m, D.n] = I[D.n, D.m]
+#   O[D.m, D.n] = I[D.n * S.SM, D.m * S.SN]
 
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
@@ -103,25 +104,25 @@ metadata: !LinalgOpMetadata
     Detailed description.
 structured_op: !LinalgStructuredOpConfig
   args:
-  - !LinalgOperandDefConfig
-    name: value
-    usage: input
-    type_var: T
   - !LinalgOperandDefConfig
     name: I
-    usage: input
-    shape: affine_map<()[s0, s1] -> (s1, s0)>
+    usage: InputOperand
     type_var: T
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
   - !LinalgOperandDefConfig
     name: O
-    usage: output
-    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    usage: OutputOperand
     type_var: T
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
-    - affine_map<(d0, d1)[s0, s1] -> ()>
-    - affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
-    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+    - affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>
+    - affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)>
   iterator_types:
   - parallel
   - parallel
@@ -129,23 +130,41 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      scalar_apply:
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: value
-        - !ScalarExpression
-          scalar_arg: I
+      scalar_arg: I
 
-# IMPL-LABEL:  Test2Op::iterator_types()
-#  IMPL-NEXT:  { getParallelIteratorTypeName(), getParallelIteratorTypeName() }
+# ODS-LABEL:  def Test2Op : LinalgStructuredBase_Op<"test2"
+
+#       ODS:  let arguments =
+#  ODS-NEXT:    Variadic<AnyType>:$inputs,
+#  ODS-NEXT:    Variadic<AnyShaped>:$outputs,
+#  ODS-NEXT:    RankedI64ElementsAttr<[2]>:$strides
+
+#       ODS:  "Attribute":$strides
+#       ODS:  $_state.addAttribute("strides", strides);
+
+#       ODS:  bool hasDynamicIndexingMaps();
+#  ODS-NEXT:  LogicalResult verifyIndexingMapRequiredAttributes();
+
+#       IMPL:  getSymbolBindings(Test2Op self)
+#       IMPL:  cst2 = self.strides().getValue<int64_t>({ 0 });
+#  IMPL-NEXT:  getAffineConstantExpr(cst2, context)
+#       IMPL:  cst3 = self.strides().getValue<int64_t>({ 1 });
+#  IMPL-NEXT:  getAffineConstantExpr(cst3, context)
 
 #       IMPL:  Test2Op::indexing_maps()
-#       IMPL:  "affine_map<(d0, d1)[s0, s1] -> ()>"
-#       IMPL:  "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>"
-#       IMPL:  "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>"
+#       IMPL:  = getSymbolBindings(*this);
+#       IMPL:  "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>"
+#       IMPL:  "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)>"
+
+#       IMPL:  Test2Op::getNumRegionArgs() { return 2; }
+
+#       IMPL:  Test2Op::hasDynamicIndexingMaps() { return true; }
+#       IMPL:  Test2Op::verifyIndexingMapRequiredAttributes()
+#       IMPL:  auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
+#       IMPL:  "missing indexing map required attribute 'strides'"
 
 #       IMPL:  void Test2Op::regionBuilder(
-#       IMPL:    ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
+#  IMPL-NEXT:    ImplicitLocOpBuilder &b, Block &block, ValueRange captures)
+#  IMPL-NEXT:    assert(2 > 0 && block.getNumArguments() == 2 &&
 
-#       IMPL:   = helper.applyfn__add(block.getArgument(0), block.getArgument(1));
+#       IMPL:   yields.push_back(block.getArgument(0));

diff  --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index 6c94bec316293..a70e3cdeca99b 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -55,7 +55,7 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
 # CHECK:     attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
 @linalg_structured_op
 def strided_copy(
-    I=TensorDef(T, S.W, S.H),
+    I=TensorDef(T, S.IH, S.IW),
     O=TensorDef(T, S.OH, S.OW, output=True),
-    strides=AttributeDef(S.S0, S.S1)):
-  O[D.oh, D.ow] = I[D.h * S.S0, D.w * S.S1]
+    strides=AttributeDef(S.SH, S.SW)):
+  O[D.oh, D.ow] = I[D.h * S.SH, D.w * S.SW]

diff  --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index 14217014fcd98..e315a5fe9889e 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -210,6 +210,36 @@ def fill_on_buffers(min, max, seed, out):
 test_fill_generic()
 
 
+def test_conv_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2, 1), f64),
+          MemRefType.get((1, 2, 4, 1), i32))
+      def conv_on_buffers(input, filter, output):
+        linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly(
+            input, filter, outs=[output], strides=[2, 4], dilations=[1, 2])
+
+    execution_engine = ExecutionEngine(transform(module, conv_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: 8
+
+
+test_conv_builtin()
+
+
 def test_conv_generic():
   with Context() as ctx, Location.unknown():
     module = Module.create()

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index e7660cbd6286e..00c4096d095cf 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -62,13 +62,14 @@ struct SerializedAffineMap {
   AffineMap affineMap() { return affineMapAttr.getValue(); }
 };
 
-enum class LinalgOperandDefUsage { input, output };
+enum class LinalgOperandDefUsage { input, output, attribute };
 
 struct LinalgOperandDef {
   std::string name;
   LinalgOperandDefUsage usage;
-  Optional<SerializedAffineMap> shape;
   std::string typeVar;
+  Optional<SerializedAffineMap> shapeMap;
+  Optional<SerializedAffineMap> attributeMap;
 };
 
 enum class LinalgIteratorTypeDef {
@@ -149,8 +150,8 @@ struct MappingTraits<LinalgOpConfig> {
 };
 
 /// A structured op models (at most) a single contraction by modeling
-///   - A list of named arguments (`LinalgOperandDef`), which can be inputs or
-///     outputs.
+///   - A list of named arguments (`LinalgOperandDef`), which can be inputs,
+///     outputs, or index attributes.
 ///   - List of indexing maps (see `LinalgIndexingMaps`).
 ///   - Iterator types (see `LinalgIteratorTypeDef`).
 ///   - List of scalar level assignment (see `ScalarAssign`).
@@ -164,21 +165,28 @@ struct MappingTraits<LinalgStructuredOpConfig> {
   }
 };
 
-/// Maps a named tensor- or scalar-argument to an operation, consisting of:
+/// Maps a named tensor, scalar or attribute argument to an operation,
+/// consisting of:
 ///   - `name`: Must be unique within the operation.
-///   - `usage`: How the argument is used (input, output, etc).
-///   - `shape`: An optional AffineMap from all op symbols to the shape of the
-///     argument. Only tensor-arguments have a shape. Each shape must be
-///     normalized over the same list of symbols and have no dimension inputs.
+///   - `usage`: How the argument is used (input, output, attribute, etc).
 ///   - `type_var`: The symbolic type variable that binds to the element or self
-///     type of the tensor- or scalar-argument, respectively.
+///     type of the tensor or scalar argument, respectively.
+///   - `shape_map`: An optional AffineMap from all op symbols to the shape of
+///     the argument. Only tensor arguments have a `shape_map`. Each shape must
+///     be normalized over the same list of symbols and have no dimension
+///     inputs.
+///   - `attribute_map`: An optional AffineMap from all op symbols to the
+///     attribute symbols. During op creation these symbols are replaced by the
+///     corresponding `name` attribute values. Only attribute arguments have
+///     an `attribute_map`.
 template <>
 struct MappingTraits<LinalgOperandDef> {
   static void mapping(IO &io, LinalgOperandDef &info) {
     io.mapRequired("name", info.name);
     io.mapRequired("usage", info.usage);
-    io.mapOptional("shape", info.shape);
     io.mapRequired("type_var", info.typeVar);
+    io.mapOptional("shape_map", info.shapeMap);
+    io.mapOptional("attribute_map", info.attributeMap);
   }
 };
 
@@ -186,8 +194,9 @@ struct MappingTraits<LinalgOperandDef> {
 template <>
 struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
   static void enumeration(IO &io, LinalgOperandDefUsage &value) {
-    io.enumCase(value, "input", LinalgOperandDefUsage::input);
-    io.enumCase(value, "output", LinalgOperandDefUsage::output);
+    io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input);
+    io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output);
+    io.enumCase(value, "IndexAttribute", LinalgOperandDefUsage::attribute);
   }
 };
 
@@ -425,9 +434,8 @@ static const char bannerFormat[] = R"FMT(
 // {2}: op interface list
 // {3}: documentation (summary + description)
 // {4}: op attribute list
-// {5}: the number of arguments for the op region
-// {6}: builder methods taking standalone attribute parameters
-// {7}: additional methods for attributes used by indexing maps
+// {5}: builder methods taking standalone attribute parameters
+// {6}: additional methods for attributes used by indexing maps
 static const char structuredOpOdsHeaderFormat[] = R"FMT(
 //===----------------------------------------------------------------------===//
 // Op definition for {0}
@@ -491,7 +499,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
         $_state.addTypes(resultTensorTypes);
         (void)$_state.addRegion();
       }]>
-      {6}
+      {5}
     ];
     let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
     let parser = [{{
@@ -514,11 +522,37 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
       // Generic methods.
       static unsigned getNumRegionArgs();
       std::string getLibraryCallName();
-      {7}
+      {6}
     }];
 }
 )FMT";
 
+// Builder method taking attribute parameters. Parameters:
+// {0}: Class name
+// {1}: Comma interleaved attribute parameters
+// {2}: Attribute initialization
+static const char structuredOpBuilderFormat[] = R"FMT(
+  , OpBuilder<
+  (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs, {1}),
+  [{{
+    $_state.addOperands(inputs);
+    $_state.addOperands(outputs);
+    $_state.addTypes(resultTensorTypes);
+    $_state.addAttribute(
+      "operand_segment_sizes",
+      $_builder.getI32VectorAttr({{
+        static_cast<int32_t>(inputs.size()),
+        static_cast<int32_t>(outputs.size())}));
+    createAndFillStructuredOpRegion<{0}>(
+      $_builder,
+      $_state,
+      TypeRange(inputs),
+      TypeRange(outputs));
+    {2}
+  }]>
+)FMT";
+
 // The iterator_types() method implementation. Parameters:
 // {0}: Class name
 // {1}: Comma interleaved iterator type names.
@@ -560,24 +594,53 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
 
   std::string doc;
   if (opConfig.metadata->doc) {
-    const char *docFmt = R"FMT(
-      let summary = [{ {0} }];
-      let description = [{
-        {1}
-      }];
-    )FMT";
+    static const char structuredOpDocFmt[] = R"FMT(
+  let summary = [{ {0} }];
+  let description = [{
+    {1}
+  }];
+)FMT";
     StringRef summary, description;
     std::tie(summary, description) =
         StringRef(*opConfig.metadata->doc).trim().split('\n');
-    doc = llvm::formatv(docFmt, summary.trim(), description.trim());
+    doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim());
   }
 
   interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
 
-  os << llvm::formatv(
-      structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName,
-      opConfig.metadata->name, interfaceNameList, doc, attrList,
-      opConfig.structuredOp->args.size(), attrBuilder, attrMethods);
+  // Assemble the attribute specific logic required for the op definition.
+  if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
+        return arg.usage == LinalgOperandDefUsage::attribute;
+      })) {
+    SmallVector<std::string> attrDefs;
+    SmallVector<std::string> attrParams;
+    SmallVector<std::string> attrStmts;
+    for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
+      if (arg.usage != LinalgOperandDefUsage::attribute)
+        continue;
+      assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
+      static const char defFmt[] = "RankedI64ElementsAttr<[{0}]>:${1}";
+      static const char paramFmt[] = "\"Attribute\":${0}";
+      static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
+      attrDefs.push_back(llvm::formatv(
+          defFmt, arg.attributeMap->affineMap().getNumResults(), arg.name));
+      attrParams.push_back(llvm::formatv(paramFmt, arg.name));
+      attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
+    }
+    attrList = ",\n" + llvm::join(attrDefs, ",\n");
+    attrMethods = R"(
+      bool hasDynamicIndexingMaps();
+      LogicalResult verifyIndexingMapRequiredAttributes();
+    )";
+    attrBuilder = llvm::formatv(
+        structuredOpBuilderFormat, opConfig.metadata->cppClassName,
+        llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
+  }
+
+  os << llvm::formatv(structuredOpOdsHeaderFormat,
+                      opConfig.metadata->cppClassName, opConfig.metadata->name,
+                      interfaceNameList, doc, attrList, attrBuilder,
+                      attrMethods);
 
   return success();
 }
@@ -595,6 +658,12 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
   std::string bannerComment = llvm::formatv("Implementation of {0}", className);
   os << llvm::formatv(bannerFormat, bannerComment);
 
+  // Compute the number of scalar and tensor arguments.
+  int64_t numOfArgs =
+      llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
+        return arg.usage != LinalgOperandDefUsage::attribute;
+      });
+
   // Reference iterators.
   {
     std::string iteratorsStr;
@@ -627,7 +696,6 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
       // For each symbol, generate a declaration for it, either with an
       // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
       // an attribute).
-      // TODO: Implement attribute constants.
       // TODO: Possibly lift into a top-level method.
       static const char structuredOpSymbolBindingsFormat[] = R"FMT(
 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
@@ -641,10 +709,33 @@ static SmallVector<AffineExpr> getSymbolBindings({0} self) {
       unsigned symbolCount = firstMap.getNumSymbols();
       SmallVector<std::string> symbolBindings;
       for (unsigned i = 0; i < symbolCount; ++i) {
-        // TODO: Switch and emit constants for attribute bound symbols.
         symbolBindings.push_back(llvm::formatv(
             "  exprs.push_back(getAffineSymbolExpr({0}, context));", i));
       }
+
+      // Access an index attribute. Parameters:
+      // {0}: Attribute name
+      // {1}: Symbol position
+      // {2}: Attribute index
+      static const char structuredOpAccessAttrFormat[] = R"FMT(
+int64_t cst{1} = self.{0}().getValue<int64_t>({ {2} });
+exprs.push_back(getAffineConstantExpr(cst{1}, context));
+)FMT";
+      // Update all symbol bindings mapped to an attribute.
+      for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
+        if (arg.usage != LinalgOperandDefUsage::attribute)
+          continue;
+        assert(arg.attributeMap.hasValue());
+        for (auto &en :
+             llvm::enumerate(arg.attributeMap->affineMap().getResults())) {
+          if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
+            symbolBindings[symbol.getPosition()] =
+                llvm::formatv(structuredOpAccessAttrFormat, arg.name,
+                              symbol.getPosition(), en.index());
+          }
+        }
+      }
+
       std::string symbolBindingsStr;
       llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
       llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
@@ -726,7 +817,7 @@ ArrayAttr {0}::indexing_maps() {
 unsigned {0}::getNumRegionArgs() {{ return {1}; }
 )FMT";
     os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className,
-                        opConfig.structuredOp->args.size());
+                        numOfArgs);
   }
 
   // getLibraryCallName()
@@ -741,6 +832,50 @@ std::string {0}::getLibraryCallName() {{
     os << llvm::formatv(structuredOpGetLibraryCallFormat, className);
   }
 
+  // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
+  if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
+        return arg.usage == LinalgOperandDefUsage::attribute;
+      })) {
+    std::vector<std::string> attrVerifications;
+    for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
+      if (arg.usage != LinalgOperandDefUsage::attribute)
+        continue;
+      assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
+      // Verify index attribute. Parameters:
+      // {0}: Attribute name
+      // {1}: Attribute size
+      static const char attrFmt[] = R"FMT(
+if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
+  if (!attr.getType().getElementType().isInteger(64))
+    return op->emitError(
+      "incorrect element type for indexing map required attribute '{0}'");
+  if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
+    return op->emitError(
+      "incorrect shape for indexing map required attribute '{0}'");
+} else {
+  return op->emitError(
+    "missing indexing map required attribute '{0}'");
+}
+)FMT";
+      attrVerifications.push_back(llvm::formatv(
+          attrFmt, arg.name, arg.attributeMap->affineMap().getNumResults()));
+    }
+
+    // Generates hasDynamicIndexingMaps and verifyIndexingMapRequiredAttributes.
+    // {0}: Class name
+    // {1}: Attribute verification
+    static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT(
+bool {0}::hasDynamicIndexingMaps() {{ return true; }
+LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
+  Operation *op = getOperation();
+  {1}
+  return success();
+}
+)FMT";
+    os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes,
+                        className, llvm::join(attrVerifications, "\n"));
+  }
+
   // regionBuilder()
   {
     // Generates a regionBuilder method. Parameters.
@@ -861,7 +996,6 @@ void {0}::regionBuilder(
       return emitError(genContext.getLoc())
              << "mismatched number of assignments vs output arguments";
 
-    int64_t numOfArgs = args.size();
     os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
                         interleaveToString(stmts, "\n  "));
   }


        


More information about the Mlir-commits mailing list