[Mlir-commits] [mlir] a3655de - [mlir][OpDSL] Add support for basic rank polymorphism.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 11 00:28:33 PST 2022


Author: gysit
Date: 2022-02-11T08:27:49Z
New Revision: a3655de2c81fc959590c109d81a010fc8e09c48e

URL: https://github.com/llvm/llvm-project/commit/a3655de2c81fc959590c109d81a010fc8e09c48e
DIFF: https://github.com/llvm/llvm-project/commit/a3655de2c81fc959590c109d81a010fc8e09c48e.diff

LOG: [mlir][OpDSL] Add support for basic rank polymorphism.

Previously, OpDSL did not support rank polymorphism, which required a separate implementation of linalg.fill. This revision extends OpDSL to support rank polymorphism for a limited class of operations that access only scalars and tensors of rank zero. At operation instantiation time, it scales these scalar computations to multi-dimensional pointwise computations by replacing the empty indexing maps with identity indexing maps. The revision does not change the DSL itself; instead, it adapts the Python emitter and the YAML generator to generate different indexing maps and iterators depending on the rank of the first output.
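
The rank-dependent map selection described above can be sketched in plain Python. This is an illustration only, not the code from this revision: it models affine maps as strings rather than using the MLIR Python bindings, and the helper name `rank_polymorphic_maps` and the "scalar"/"tensor" operand labels are hypothetical.

```python
from typing import List, Tuple


def rank_polymorphic_maps(operand_kinds: List[str],
                          rank: int) -> Tuple[List[str], List[str]]:
    """Return (indexing_maps, iterator_types) for a pointwise op.

    `operand_kinds` holds one hypothetical label per operand, either
    "scalar" or "tensor"; `rank` is the rank of the first output.
    """
    dims = ", ".join(f"d{i}" for i in range(rank))
    # Scalars are read with a map that has `rank` dims and no results.
    scalar_map = f"affine_map<({dims}) -> ()>"
    # Tensors are accessed with the identity map over all dims.
    tensor_map = f"affine_map<({dims}) -> ({dims})>"
    maps = [scalar_map if kind == "scalar" else tensor_map
            for kind in operand_kinds]
    # Every loop of a pointwise computation is parallel.
    return maps, ["parallel"] * rank


# A fill-like op: one scalar input and one tensor output of rank 2.
maps, iterators = rank_polymorphic_maps(["scalar", "tensor"], rank=2)
print(maps)
print(iterators)
```

For rank 0 both maps collapse to `affine_map<() -> ()>` and the iterator list is empty, which corresponds to the rank-zero form stored in the YAML definition.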

Additionally, the revision introduces a `linalg.fill_tensor` operation that in a future revision shall replace the current handwritten `linalg.fill` operation. `linalg.fill_tensor` is thus only temporarily available and will be renamed to `linalg.fill`.

Reviewed By: nicolasvasilache, stellaraccident

Differential Revision: https://reviews.llvm.org/D119003

Added: 
    mlir/test/python/dialects/linalg/opdsl/emit_fill.py

Modified: 
    mlir/docs/Dialects/Linalg/OpDSL.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
    mlir/test/python/dialects/linalg/ops.py
    mlir/test/python/integration/dialects/linalg/opsrun.py
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index 79f22a247bb27..deec3eae0fd2b 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -102,7 +102,7 @@ bound to a `TensorDef` as demonstrated by the matmul example. All parameters
 appear in the parameter list of the operation:
 
 ```python
-fill(val, in_tensor, outs=[out_tensor])
+copy_and_scale(val, in_tensor, outs=[out_tensor])
 ```
 
 ## Attributes
@@ -251,3 +251,31 @@ The following examples illustrate the lowering of signed and unsigned functions:
 
 Not all functions are applicable for all numeric types, and on mismatch, op
 verification will fail.
+
+## Pointwise Computations
+
+Pointwise computations are expressible in a rank polymorphic form that supports
+arbitrary ranked operands - all of them need to have the same rank - with a
+single operation definition.
+
+An example for a rank polymorphic operation is `fill`:
+
+```python
+@linalg_structured_op
+def fill(value=ScalarDef(T1),
+         O=TensorDef(U, output=True)):
+  O[None] = TypeFn.cast(U, value)
+```
+
+The operation sets the elements of the output tensor `O` to `value`. All
+operands are either scalars or rank zero tensors that are accessed using the
+index `None`. The operation thus performs a scalar computation that trivially
+extends to a multi-dimensional pointwise computation. As a result, we may use
+`fill` with arbitrary ranked output tensors:
+
+```python
+tensor_2d = linalg.InitTensorOp([4, 8], f32)
+tensor_3d = linalg.InitTensorOp([4, 8, 16], f32)
+fill(value, outs=[tensor_2d])
+fill(value, outs=[tensor_3d])
+```

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index dc5e5862e83c7..69a4cc407b9d7 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2522,6 +2522,42 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: fill_tensor
+  cpp_class_name: FillTensorOp
+  doc: |-
+    Fills the output tensor with the given value.
+
+    Works for arbitrary ranked output tensors since the operation performs scalar
+    accesses only and is thus rank polymorphic. Numeric casting is performed on
+    the value operand, promoting it to the same data type as the output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: value
+    usage: InputOperand
+    type_var: T1
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      type_fn:
+        fn_name: cast
+        type_var: U
+        operands:
+        - !ScalarExpression
+          scalar_arg: value
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
   cpp_class_name: FillRng2DOp

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 22568c8b67487..643bcaa5c2f02 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -14,6 +14,7 @@
 
 from .scalar_expr import *
 from .config import *
+from .comprehension import *
 import numpy as np
 
 __all__ = [
@@ -132,6 +133,25 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
   indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
      prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
 
+  # An operation that accesses only scalars and scalar/rank zero tensors is
+  # rank polymorhpic. We implement rank polymorphism by generating different
+  # indexing maps and iterators that match the rank of the first output tensor.
+  # An operation is rank polymorphic if the iteration domain has rank zero.
+  if not iterator_types_attr:
+    rank = ShapedType(outs[0].type).rank
+    iterator_types_attr = ArrayAttr.get([StringAttr.get("parallel")] * rank)
+    scalar_map = AffineMap.get(rank, 0, [])
+    tensor_map = AffineMap.get_identity(rank)
+    indexing_maps = []
+    for arg_def in all_arg_defs:
+      if arg_def.operand_def.kind == OperandKind.Scalar:
+        indexing_maps.append(scalar_map)
+      if (arg_def.operand_def.kind == OperandKind.InputTensor or
+          arg_def.operand_def.kind == OperandKind.OutputTensor):
+        indexing_maps.append(tensor_map)
+    indexing_maps_attr = ArrayAttr.get(
+        [AffineMapAttr.get(am) for am in indexing_maps])
+
   generic_op = linalg.GenericOp(
       result_tensors=result_types,
       inputs=ins,
@@ -172,19 +192,13 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
     raise NotImplementedError(
         f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")
 
+  # Set the index attributes used to compute the indexing maps.
   named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
-  linalg.fill_builtin_region(named_op.operation)
-  # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
-  # attribute that the non-yaml path does not. The non-yaml path hardcodes the
-  # indexing_maps in C++ directly.
-  named_op.operation.attributes[
-      "linalg.memoized_indexing_maps"] = indexing_maps_attr
-  # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
-
-  # Additionally set all named attributes.
   for name, value in index_attributes.items():
     named_op.operation.attributes[name] = value
 
+  linalg.fill_builtin_region(named_op.operation)
+
   if len(result_types) == 1:
     return named_op.result
   else:

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index d3651bd766fe7..80a8fb6ccf091 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -627,6 +627,17 @@ def pooling_ndhwc_min(
                D.ow * S.SW + D.kw * S.DW, D.c]))
 
 
+@linalg_structured_op
+def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+  """Fills the output tensor with the given value.
+
+  Works for arbitrary ranked output tensors since the operation performs scalar
+  accesses only and is thus rank polymorphic. Numeric casting is performed on
+  the value operand, promoting it to the same data type as the output.
+  """
+  O[None] = TypeFn.cast(U, value)
+
+
 @linalg_structured_op
 def fill_rng_2d(
     min=ScalarDef(F64),

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index b01191184b055..e5a7e74fc582d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -207,6 +207,35 @@ func @generalize_pooling_nhwc_sum_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
 
 // -----
 
+func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
+  %0 = linalg.fill_tensor ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
+  return %0: tensor<f32>
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
+
+// CHECK-LABEL: @generalize_fill_0d
+// CHECK:      linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
+  linalg.fill_tensor ins(%value: f64) outs(%O : memref<16x32xf32>)
+  return
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @generalize_fill
+// CHECK:      linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+
+// -----
+
 func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index 3634f4f83dd45..ee36510aaf004 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -175,3 +175,55 @@ structured_op: !LinalgStructuredOpConfig
 #  IMPL-NEXT:    assert(2 > 0 && block.getNumArguments() == 2 &&
 
 #       IMPL:   yields.push_back(block.getArgument(0));
+
+# @linalg_structured_op
+# def test3(value=ScalarDef(T1),
+#           O=TensorDef(U, output=True)):
+#   """Title.
+
+#   Detailed description.
+#   """
+#   O[None] = TypeFn.cast(U, value)
+
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: test3
+  cpp_class_name: Test3Op
+  doc: |-
+    Title.
+
+    Detailed description.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: value
+    usage: InputOperand
+    type_var: T1
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      type_fn:
+        fn_name: cast
+        type_var: U
+        operands:
+        - !ScalarExpression
+          scalar_arg: value
+
+#       IMPL:  Test3Op::iterator_types() {
+#  IMPL-NEXT:    int64_t rank = getRank(getOutputOperand(0));
+
+#       IMPL:  Test3Op::indexing_maps() {
+#  IMPL-NEXT:    MLIRContext *context = getContext();
+#  IMPL-NEXT:    AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
+#  IMPL-NEXT:    AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
new file mode 100644
index 0000000000000..75524691a4875
--- /dev/null
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py
@@ -0,0 +1,46 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import builtin
+from mlir.dialects import linalg
+from mlir.dialects import std
+
+from mlir.dialects.linalg.opdsl.lang import *
+
+T1 = TV.T1
+T2 = TV.T2
+
+
+@linalg_structured_op
+def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+  O[None] = TypeFn.cast(U, value)
+
+
+with Context() as ctx, Location.unknown():
+  module = Module.create()
+  f32 = F32Type.get()
+  with InsertionPoint(module.body):
+
+    # Fill indexing maps.
+    # CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
+    # CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+    # CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+    # CHECK-LABEL: @test_fill_0d
+    # CHECK: linalg.generic
+    # CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]
+    # CHECK-SAME: iterator_types = []
+    @builtin.FuncOp.from_py_func(f32, RankedTensorType.get([], f32))
+    def test_fill_0d(value, init_result):
+      return fill_poly(value, outs=[init_result])
+
+    # CHECK-LABEL: @test_fill_2d
+    # CHECK: linalg.generic
+    # CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]]]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel"]
+    @builtin.FuncOp.from_py_func(f32, RankedTensorType.get([4, 16], f32))
+    def test_fill_2d(value, init_result):
+      return fill_poly(value, outs=[init_result])
+
+
+print(module)

diff  --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 4f9f138683b83..ba57a131f7f33 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -126,7 +126,7 @@ def named_form(lhs, rhs):
         # CHECK-NEXT:    arith.mulf{{.*}} (f32, f32) -> f32
         # CHECK-NEXT:    arith.addf{{.*}} (f32, f32) -> f32
         # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
-        # CHECK-NEXT:    {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
+        # CHECK-NEXT:    operand_segment_sizes = dense<[2, 1]> : vector<2xi32>
         # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
         return linalg.matmul(lhs, rhs, outs=[init_result.result])
 

diff  --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index b75de12085501..5be00f8df3334 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -42,13 +42,42 @@ def log(*args):
 """
 
 fill_boiler = """
+func @main() -> i32 attributes {llvm.emit_c_interface} {
+  %O0 = memref.alloc() : memref<i32>
+  %O1 = memref.alloc() : memref<16xi32>
+  %O2 = memref.alloc() : memref<4x16xi32>
+
+  %val0 = arith.constant 1.0 : f32
+  %val1 = arith.constant 2.0 : f32
+  %val2 = arith.constant 3.0 : f32
+
+  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
+  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
+  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
+
+  %c0 = arith.constant 0 : index
+  %res0 = memref.load %O0[] : memref<i32>
+  %c8 = arith.constant 8 : index
+  %res1 = memref.load %O1[%c8] : memref<16xi32>
+  %c2 = arith.constant 2 : index
+  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>
+
+  %0 = arith.addi %res0, %res1 : i32
+  %1 = arith.addi %0, %res2 : i32
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %1 : i32
+}
+"""
+
+fill_rng_boiler = """
 func @main() -> i32 attributes {llvm.emit_c_interface} {
   %O = memref.alloc() : memref<4x16xi32>
   %min = arith.constant -1000.0 : f64
   %max = arith.constant 1000.0 : f64
   %seed = arith.constant 42 : i32
 
-  call @fill_on_buffers(%min, %max, %seed, %O) :
+  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
     (f64, f64, i32, memref<4x16xi32>) -> ()
 
   %c0 = arith.constant 0 : index
@@ -123,9 +152,9 @@ def transform(module, boilerplate):
 
   # TODO: Allow cloning functions from one module to another.
   # Atm we have to resort to string concatenation.
-  mod = Module.parse(
-      str(module.operation.regions[0].blocks[0].operations[0].operation) +
-      boilerplate)
+  ops = module.operation.regions[0].blocks[0].operations
+  mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
+
   pm = PassManager.parse(
       "builtin.func(convert-linalg-to-loops, lower-affine, " +
       "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm," +
@@ -192,6 +221,76 @@ def matmul_on_buffers(lhs, rhs, out):
 
 
 def test_fill_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+      def fill_0d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out])
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+      def fill_1d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out])
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+      def fill_2d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out])
+
+    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: 6
+
+
+test_fill_builtin()
+
+
+def test_fill_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f32 = F32Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+      def fill_0d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out], emit_generic=True)
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+      def fill_1d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out], emit_generic=True)
+
+      @builtin.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+      def fill_2d_on_buffers(value, out):
+        linalg.fill_tensor(value, outs=[out], emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: 6
+
+
+test_fill_generic()
+
+
+def test_fill_rng_builtin():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f64 = F64Type.get()
@@ -199,10 +298,10 @@ def test_fill_builtin():
     with InsertionPoint(module.body):
 
       @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
-      def fill_on_buffers(min, max, seed, out):
+      def fill_rng_on_buffers(min, max, seed, out):
         linalg.fill_rng_2d(min, max, seed, outs=[out])
 
-    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+    execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
 
     # TODO: FFI-based solution to allow testing and printing with python code.
     # Prepare arguments: one result i32.
@@ -215,10 +314,10 @@ def fill_on_buffers(min, max, seed, out):
     # CHECK: RESULT: -480
 
 
-test_fill_builtin()
+test_fill_rng_builtin()
 
 
-def test_fill_generic():
+def test_fill_rng_generic():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f64 = F64Type.get()
@@ -226,10 +325,10 @@ def test_fill_generic():
     with InsertionPoint(module.body):
 
       @builtin.FuncOp.from_py_func(f64, f64, i32, MemRefType.get((4, 16), i32))
-      def fill_on_buffers(min, max, seed, out):
+      def fill_rng_on_buffers(min, max, seed, out):
         linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)
 
-    execution_engine = ExecutionEngine(transform(module, fill_boiler))
+    execution_engine = ExecutionEngine(transform(module, fill_rng_boiler))
 
     # TODO: FFI-based solution to allow testing and printing with python code.
     # Prepare arguments: one result i32.
@@ -242,7 +341,7 @@ def fill_on_buffers(min, max, seed, out):
     # CHECK: RESULT: -480
 
 
-test_fill_generic()
+test_fill_rng_generic()
 
 
 def test_max_pooling_builtin():

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 925b90848ac02..d5d8ba6e0db12 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -558,16 +558,63 @@ static const char structuredOpBuilderFormat[] = R"FMT(
   }]>
 )FMT";
 
-// The iterator_types() method implementation. Parameters:
+// The iterator_types() method for structured ops. Parameters:
 // {0}: Class name
 // {1}: Comma interleaved iterator type names.
 static const char structuredOpIteratorTypesFormat[] =
     R"FMT(
-ArrayAttr {0}::iterator_types() {
+ArrayAttr {0}::iterator_types() {{
   return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} });
 }
 )FMT";
 
+// The iterator_types() method for rank polymorphic structured ops. Parameters:
+// {0}: Class name
+static const char rankPolyStructuredOpIteratorTypesFormat[] =
+    R"FMT(
+ArrayAttr {0}::iterator_types() {{
+  int64_t rank = getRank(getOutputOperand(0));
+  return Builder(getContext()).getStrArrayAttr(
+    SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+}
+)FMT";
+
+// The indexing_maps() method for structured ops. Parameters:
+// {0}: Class name
+// {1}: Comma-separated list of dimension variable names.
+// {2}: Statements
+static const char structuredOpIndexingMapsFormat[] = R"FMT(
+ArrayAttr {0}::indexing_maps() {{
+  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
+  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
+  if (cached)
+    return cached;
+
+  MLIRContext *context = getContext();
+  auto symbolBindings = getSymbolBindings(*this);
+  SmallVector<AffineMap> maps;
+  {2}
+  cached = Builder(context).getAffineMapArrayAttr(maps);
+  getOperation()->setAttr(memoizeAttr, cached);
+  return cached;
+}
+)FMT";
+
+// The indexing_maps() method for rank polymorphic structured ops. Parameters:
+// {0}: Class name
+static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
+ArrayAttr {0}::indexing_maps() {{
+  MLIRContext *context = getContext();
+  AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
+  AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
+    getNumParallelLoops(), context);
+  SmallVector<AffineMap> indexingMaps;
+  for (OpOperand *opOperand : getInputAndOutputOperands())
+    indexingMaps.push_back(isScalar(opOperand) ? scalarMap : tensorMap);
+  return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
+}
+)FMT";
+
 // Implementations of fold and getEffects.
 // Parameters:
 // {0}: Class name
@@ -681,8 +728,14 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
         return arg.usage != LinalgOperandDefUsage::attribute;
       });
 
-  // Reference iterators.
-  {
+  // An operation that accesses only scalars and scalar/rank zero tensors is
+  // rank polymorhpic. We implement rank polymorphism by generating different
+  // indexing maps and iterators that match the rank of the first output tensor.
+  // An operation is rank polymorphic if the iteration domain has rank zero.
+  bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty();
+
+  // Generate the iterator_types() method.
+  if (!isRankPolymorphic) {
     std::string iteratorsStr;
     llvm::raw_string_ostream ss(iteratorsStr);
     llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss,
@@ -699,22 +752,25 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
     ss.flush();
     os << llvm::formatv(structuredOpIteratorTypesFormat, className,
                         iteratorsStr);
+  } else {
+    os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className);
   }
 
-  // Static indexing maps.
+  // Generating the indexing_maps() method.
   if (auto &staticMaps =
           opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
     if (staticMaps->empty())
       return emitError(genContext.getLoc()) << "op has no indexing maps";
-    AffineMap firstMap = staticMaps->front().affineMap();
-
-    // Symbol bindings.
-    {
-      // For each symbol, generate a declaration for it, either with an
-      // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
-      // an attribute).
-      // TODO: Possibly lift into a top-level method.
-      static const char structuredOpSymbolBindingsFormat[] = R"FMT(
+    if (!isRankPolymorphic) {
+      AffineMap firstMap = staticMaps->front().affineMap();
+
+      // Symbol bindings.
+      {
+        // For each symbol, generate a declaration for it, either with an
+        // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
+        // an attribute).
+        // TODO: Possibly lift into a top-level method.
+        static const char structuredOpSymbolBindingsFormat[] = R"FMT(
 static SmallVector<AffineExpr> getSymbolBindings({0} self) {
   MLIRContext *context = self.getContext();
   SmallVector<AffineExpr> exprs;
@@ -723,101 +779,83 @@ static SmallVector<AffineExpr> getSymbolBindings({0} self) {
 }
 )FMT";
 
-      unsigned symbolCount = firstMap.getNumSymbols();
-      SmallVector<std::string> symbolBindings;
-      for (unsigned i = 0; i < symbolCount; ++i) {
-        symbolBindings.push_back(llvm::formatv(
-            "  exprs.push_back(getAffineSymbolExpr({0}, context));", i));
-      }
+        unsigned symbolCount = firstMap.getNumSymbols();
+        SmallVector<std::string> symbolBindings;
+        for (unsigned i = 0; i < symbolCount; ++i) {
+          symbolBindings.push_back(llvm::formatv(
+              "  exprs.push_back(getAffineSymbolExpr({0}, context));", i));
+        }
 
-      // Access an index attribute. Parameters:
-      // {0}: Attribute name
-      // {1}: Symbol position
-      // {2}: Attribute index
-      static const char structuredOpAccessAttrFormat[] = R"FMT(
+        // Access an index attribute. Parameters:
+        // {0}: Attribute name
+        // {1}: Symbol position
+        // {2}: Attribute index
+        static const char structuredOpAccessAttrFormat[] = R"FMT(
 int64_t cst{1} = self.{0}().getValues<int64_t>()[{2}];
 exprs.push_back(getAffineConstantExpr(cst{1}, context));
 )FMT";
-      // Update all symbol bindings mapped to an attribute.
-      for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
-        if (arg.usage != LinalgOperandDefUsage::attribute)
-          continue;
-        assert(arg.attributeMap.hasValue());
-        for (auto &en :
-             llvm::enumerate(arg.attributeMap->affineMap().getResults())) {
-          if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
-            symbolBindings[symbol.getPosition()] =
-                llvm::formatv(structuredOpAccessAttrFormat, arg.name,
-                              symbol.getPosition(), en.index());
+        // Update all symbol bindings mapped to an attribute.
+        for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
+          if (arg.usage != LinalgOperandDefUsage::attribute)
+            continue;
+          assert(arg.attributeMap.hasValue());
+          for (auto &en :
+               llvm::enumerate(arg.attributeMap->affineMap().getResults())) {
+            if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
+              symbolBindings[symbol.getPosition()] =
+                  llvm::formatv(structuredOpAccessAttrFormat, arg.name,
+                                symbol.getPosition(), en.index());
+            }
           }
         }
-      }
 
-      std::string symbolBindingsStr;
-      llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
-      llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
-      symbolBindingsSs.flush();
+        std::string symbolBindingsStr;
+        llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
+        llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
+        symbolBindingsSs.flush();
 
-      os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
-                          symbolBindingsStr);
-    }
-
-    // Indexing maps.
-    {
-      // Parameters:
-      // {0}: Class name
-      // {1}: Comma-separated list of dimension variable names.
-      // {2}: Statements
-      static const char structuredOpIndexingMapsFormat[] = R"FMT(
-ArrayAttr {0}::indexing_maps() {
-  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
-  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
-  if (cached)
-    return cached;
+        os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
+                            symbolBindingsStr);
+      }
 
-  MLIRContext *context = getContext();
-  auto symbolBindings = getSymbolBindings(*this);
-  SmallVector<AffineMap> maps;
-  {2}
-  cached = Builder(context).getAffineMapArrayAttr(maps);
-  getOperation()->setAttr(memoizeAttr, cached);
-  return cached;
-}
-)FMT";
+      // Indexing maps.
+      {
+        unsigned dimCount = firstMap.getNumDims();
+
+        // Generate a comma-separated list of dim identifiers to be passed to
+        // bindDims, ensuring tht AffineExpr identifiers are bound in the right
+        // order to the proper AffineDimExpr.
+        // This results in vars in scope like: d0, d1, d2...
+        SmallVector<unsigned> dimIndices;
+        for (unsigned i = 0; i < dimCount; ++i)
+          dimIndices.push_back(i);
+        std::string dimIdentsStr;
+        llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
+        llvm::interleaveComma(dimIndices, dimIdentsSs,
+                              [&](unsigned i) { dimIdentsSs << "d" << i; });
+        dimIdentsSs.flush();
+
+        // Statements to add and simplify each affine map.
+        SmallVector<std::string> stmts;
+        for (auto &indexingMap : *staticMaps) {
+          // TODO: Assert that dim and symbol count match the first.
+          stmts.push_back(
+              llvm::formatv("maps.push_back({0});",
+                            generateCppExpression(indexingMap, "context")));
+          stmts.push_back(llvm::formatv(
+              "maps.back() = "
+              "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
+              "symbolBindings, {0}, 0));",
+              dimCount));
+        }
 
-      unsigned dimCount = firstMap.getNumDims();
-
-      // Generate a comma-separated list of dim identifiers to be passed to
-      // bindDims, ensuring tht AffineExpr identifiers are bound in the right
-      // order to the proper AffineDimExpr.
-      // This results in vars in scope like: d0, d1, d2...
-      SmallVector<unsigned> dimIndices;
-      for (unsigned i = 0; i < dimCount; ++i)
-        dimIndices.push_back(i);
-      std::string dimIdentsStr;
-      llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
-      llvm::interleaveComma(dimIndices, dimIdentsSs,
-                            [&](unsigned i) { dimIdentsSs << "d" << i; });
-      dimIdentsSs.flush();
-
-      // Statements to add and simplify each affine map.
-      SmallVector<std::string> stmts;
-      for (auto &indexingMap : *staticMaps) {
-        // TODO: Assert that dim and symbol count match the first.
-        stmts.push_back(
-            llvm::formatv("maps.push_back({0});",
-                          generateCppExpression(indexingMap, "context")));
-        stmts.push_back(llvm::formatv(
-            "maps.back() = "
-            "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
-            "symbolBindings, {0}, 0));",
-            dimCount));
+        // TODO: This needs to be memoized and/or converted to non-parser based
+        // C++ codegen prior to real use.
+        os << llvm::formatv(structuredOpIndexingMapsFormat, className,
+                            dimIdentsStr, interleaveToString(stmts, "\n  "));
       }
-
-      // TODO: This needs to be memoized and/or converted to non-parser based
-      // C++ codegen prior to real use.
-      os << llvm::formatv(structuredOpIndexingMapsFormat, className,
-                          dimIdentsStr, interleaveToString(stmts, "\n  "));
+    } else {
+      os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className);
     }
   } else {
     return emitError(genContext.getLoc())


        

