[Mlir-commits] [mlir] 0d0c46a - [mlir] Improve documentation of shape dialect

Jacques Pienaar llvmlistbot at llvm.org
Wed Nov 17 14:07:19 PST 2021


Author: Jacques Pienaar
Date: 2021-11-17T14:07:06-08:00
New Revision: 0d0c46a35b3b30782107c63e833a91dbbe087feb

URL: https://github.com/llvm/llvm-project/commit/0d0c46a35b3b30782107c63e833a91dbbe087feb
DIFF: https://github.com/llvm/llvm-project/commit/0d0c46a35b3b30782107c63e833a91dbbe087feb.diff

LOG: [mlir] Improve documentation of shape dialect

Add small example of usage (brief which will be further refined).

Added: 
    mlir/docs/Dialects/Shape.md

Modified: 
    

Removed: 
    


################################################################################
diff --git a/mlir/docs/Dialects/Shape.md b/mlir/docs/Dialects/Shape.md
new file mode 100644
index 0000000000000..147c66e04d032
--- /dev/null
+++ b/mlir/docs/Dialects/Shape.md
@@ -0,0 +1,201 @@
+# 'shape' Dialect
+
+Description of operations & types within the Shape dialect as well as their
+[usage](#different-stages-of-lowering-shape-dialect).
+
+[include "Dialects/ShapeDialect.md"]
+
+## Different stages of lowering Shape dialect
+
+In this section we shall give a brief overview of the different uses of the
+shape dialect and the lowering between these uses. Currently we have 3 worlds /
+stages of lowering of shape functions:
+
+1.  _Error monadic/error carrying/user specification_:
+    This "input" form carries both the shape and whether it is in an error
+    state as a value. Hence at this level all operations are pure operations
+    producing and consuming values where the values could represent an error.
+
+2.  _Constrained_:
+    This form uses a variant of explicit evidence passing to allow leveraging
+    existing compiler infrastructure to preserve safety information during
+    optimization.
+
+3.  _Side-effecting/asserting_:
+    This final lowered form is an imperative form with side-effecting ops (e.g.,
+    assert) for final codegen.
+
+We are going to do a quick step through of the lowering using the example of
+a matmul.
+
+Starting from the shape function of matmul in the error monadic form
+below[^wip_form1]:
+
+```mlir
+shape.function_library @shplib {
+
+builtin.func @matmul(%lhs: !shape.value_shape, %rhs: !shape.value_shape) -> !shape.shape {
+  %c1 = shape.const_size 1
+  %c2 = shape.const_size 2
+  // We could also allow rank etc operations directly on value_shape too, that
+  // would make it nicer as "input" language, but keeping it explicit inside the
+  // IR instead and then we could have helper methods in front-end language.
+  %lhs_shape = shape.shape_of %lhs : !shape.value_shape -> !shape.shape
+  %rhs_shape = shape.shape_of %rhs : !shape.value_shape -> !shape.shape
+  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
+  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
+  // This is not minimal as one could ensure the ranks are the same below, also a
+  // variadic meet would make it more concise too.
+  %r = "shape.meet"(%lhs_rank, %rhs_rank) : (!shape.size, !shape.size) -> !shape.size
+  %rank = shape.meet %c2, %r, error="requires rank 2 operands" :
+    !shape.size, !shape.size -> !shape.size
+  %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
+    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
+  %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
+    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
+  %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
+    !shape.shape, !shape.shape -> !shape.shape
+  %res = shape.concat %l0, %r1
+  // Should have `shape.return %res requires %c, %rank` to enable
+  return %res : !shape.shape
+}
+
+} mapping {
+  foo.matmul = @matmul
+}
+```
+
+*   We are using the default builtin func and return here. Preferably we'd use
+    ‘shape\_func’ as a special function op that allows passing multiple results
+    back that affect correct execution (e.g., serves as an error join)
+    *   This would also mean one can't reify it inside a regular function
+        without handling the shape.return - that is a feature here as these are
+        more of a template.
+    *   Currently we also have not marked `meet` as having no side-effects to
+        avoid DCE until we have `shape.return`, at which point computing the
+        meet could be treated as purely computational returning error.
+*   Meet represents a constraint that should hold, so should not be used to see
+    *if* something is equal. E.g., this means `meet` can't be used to represent
+
+    ```
+       either(meet(x, y), meet(y,z))
+    ```
+
+*   This could have been written more concisely as something like
+
+    ```
+      concat(lhs[0], rhs[1]) if rank(lhs) == 2 &&
+        rank(rhs) == 2 && lhs[1] == rhs[0]
+    ```
+
+    but not focusing on front-end proper here.
+
+We are going to lower to the "most" nested form directly (see
+[test](https://github.com/tensorflow/tensorflow/blob/64062b5c51e04e370df26551d247496787d3f5c2/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir#L3088)
+for an example reification along with legalization). In the above this was in a
+separate shape function library, while here we would normally reify it as part
+of lowering, but for simplicity will show as a standalone shape function.
+
+```mlir
+func @matmul_shape1(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
+  %c1 = shape.const_size 1
+  %c2 = shape.const_size 2
+  // We allow `shape.shape_of` to return either a `!shape.shape` or
+  // `tensor<?xindex>` type, in the case where the input is a tensor the most
+  // refined type is a tensor of `index` but not required.
+  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> !shape.shape
+  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> !shape.shape
+  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
+  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
+  %w1 = shape.cstr_eq %lhs_rank, %rhs_rank : !shape.witness
+  %res = shape.assuming %w1 -> tensor<?xindex> {
+    %r = shape.any %lhs_rank, %rhs_rank : (!shape.size, !shape.size) -> !shape.size
+    // Error message needs an addition, currently only on cstr_require.
+    %w2 = shape.cstr_eq %c2, %r, error="requires rank 2 operands"
+    %res_1 = shape.assuming %w2 -> tensor<?xindex> {
+      // Here the lowered
+      //   %rank = shape.any %c2, %r : (!shape.size, !shape.size) -> !shape.size
+      // is dead and so elided further. But if `%rank` was actually consumed,
+      // then it could have been folded in `shape.any`.
+      %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
+        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
+      %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
+        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
+      %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
+        !shape.shape, !shape.shape -> !shape.shape
+      %res = concat(%l0, %r1)
+      shape.assuming_yield %res
+    }
+    shape.assuming_yield %res_1
+  }
+  return %res : tensor<?xindex>
+}
+```
+
+We can now hoist computations of constraints where possible (which in the case
+below is not too many as we need to verify the rank before we can split):
+
+```mlir
+func @matmul_shape2(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
+  %c1 = shape.const_size 1
+  %c2 = shape.const_size 2
+  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
+  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
+  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> tensor<index>
+  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> tensor<index>
+  %w1 = shape.cstr_eq %c2, %lhs_rank, error="requires rank 2 operands"
+  %w2 = shape.cstr_eq %c2, %rhs_rank, error="requires rank 2 operands"
+  %w = shape.assuming_all %w1, %w2
+  %res = shape.assuming %w -> tensor<?xindex> {
+    %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
+      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
+    %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
+      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
+    %w3 = shape.cstr_eq %l1, %r0, error="inner dimensions required to match"
+    %res_2 = shape.assuming %w3 {
+      %res = concat(%l0, %r1)
+      shape.assuming_yield %res
+    }
+    shape.assuming_yield %res_2
+  }
+  return %res
+}
+```
+
+The above form can now be lowered to the fully imperative form (see
+[test](https://github.com/tensorflow/mlir-hlo/blob/af14e1ded33c3164d4418c5d234b5b346b6d017c/tests/rank-specialization.mlir#L22)
+for example).
+
+```mlir
+func @matmul_shape3(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
+  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
+  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> tensor<index>
+  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> tensor<index>
+  %w1 = shape.shape_eq %lhs_rank, %rhs_rank
+  %w2 = shape.shape_eq %c2, %lhs_rank
+  %w3 = and %w1, %w2
+  assert %w3, "requires rank 2 operands"
+  %l0, %l1 = shape.split_at(%lhs_shape, %c1) : tensor<?xindex>
+  %r0, %r1 = shape.split_at(%rhs_shape, %c1) : tensor<?xindex>
+  %w4 = shape.shape_eq %l1, %r0
+  assert %w4, "inner dimensions required to match"
+  %res = concat(%l0, %r1)
+  return %res
+}
+```
+
+*   In this case form 3 is as easy to write as form 1 and close to it (but
+    only because no reordering was required). So it is a good question whether
+    the frontend authoring language could be more similar to the imperative
+    form (under discussion).
+*   The form presented above is an intermediate form during a lowering
+    pass. If used as input we would need to restrict the optimizations on it as
+    the `shape` dialect operations are no longer connected by producer-consumer
+    relationships to enforce guard checking.
+
+The above could be further lowered by using `tensor.dim`, `tensor.from_elements`,
+etc. (or one could even lower these by way of, say, the MHLO or TOSA dialect).
+
+[^wip_form1]: This form is the least used inside the current workflows and needs more work. In particular, in the example we use `shape_func`, whereas in the code we instead use the standard func, as form 1 isn't used explicitly.
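
As an illustrative aside (not part of the committed file), the contrast between form 1 (error carrying) and form 3 (asserting) in the matmul walkthrough can be mimicked in plain Python. This is only a sketch under stated assumptions: `ShapeError`, `meet`, `matmul_shape_monadic`, and `matmul_shape_asserting` are hypothetical names, shapes are modelled as tuples of ints, and nothing here is an MLIR API or the lowering the dialect actually performs.

```python
from dataclasses import dataclass
from typing import Tuple, Union


# Hypothetical stand-in for an error-carrying value (form 1); not an MLIR type.
@dataclass(frozen=True)
class ShapeError:
    message: str


Shape = Union[Tuple[int, ...], ShapeError]


def meet(a, b, error):
    """Return the common value if `a` and `b` agree, else an error value."""
    for v in (a, b):
        if isinstance(v, ShapeError):
            return v  # an incoming error propagates as an ordinary value
    return a if a == b else ShapeError(error)


def matmul_shape_monadic(lhs: Shape, rhs: Shape) -> Shape:
    """Form 1 analogue: a pure function whose result may itself be an error."""
    for v in (lhs, rhs):
        if isinstance(v, ShapeError):
            return v
    rank_msg = "requires rank 2 operands"
    rank = meet(meet(len(lhs), len(rhs), rank_msg), 2, rank_msg)
    if isinstance(rank, ShapeError):
        return rank
    # Split each shape at index 1, mirroring `shape.split_at` in the example.
    l0, l1 = lhs[:1], lhs[1:]
    r0, r1 = rhs[:1], rhs[1:]
    inner = meet(l1, r0, "inner dimensions required to match")
    if isinstance(inner, ShapeError):
        return inner
    return l0 + r1


def matmul_shape_asserting(lhs: Tuple[int, ...], rhs: Tuple[int, ...]) -> Tuple[int, ...]:
    """Form 3 analogue: imperative checks that fail by side effect."""
    if not (len(lhs) == len(rhs) == 2):
        raise AssertionError("requires rank 2 operands")
    if lhs[1:] != rhs[:1]:
        raise AssertionError("inner dimensions required to match")
    return lhs[:1] + rhs[1:]
```

In the monadic version a caller inspects the returned value to find out whether the computation failed, whereas the asserting version aborts at the failing check. This roughly corresponds to the difference between values that "could represent an error" and side-effecting `assert` ops described in the stages above.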


        

