[Mlir-commits] [mlir] [mlir] add transform tutorial chapter for Halide conv mapping (PR #66386)
llvmlistbot at llvm.org
Thu Sep 14 07:53:29 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
This chapter demonstrates how one can replicate Halide DSL transformations using Transform dialect operations that transform a payload expressed in Linalg. This was part of the live tutorial presented at EuroLLVM 2023.
--
Patch is 53.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66386.diff
3 Files Affected:
- (added) mlir/docs/Tutorials/transform/ChH.md (+690)
- (modified) mlir/docs/Tutorials/transform/_index.md (+1)
- (added) mlir/test/Examples/transform/ChH/full.mlir (+393)
<pre>
diff --git a/mlir/docs/Tutorials/transform/ChH.md b/mlir/docs/Tutorials/transform/ChH.md
new file mode 100644
index 000000000000000..aa40e2b24abe0d8
--- /dev/null
+++ b/mlir/docs/Tutorials/transform/ChH.md
@@ -0,0 +1,690 @@
+# Chapter H: Reproducing Halide Schedule
+
+This chapter demonstrates how a schedule from the [Halide
+DSL](http://halide-lang.org) can be implemented using the Transform dialect for
+structured ops.
+
+Note that the IR below is pseudo-code with types removed for brevity. It may
+also get out of sync with the current syntax. Always refer to the source code in
+[mlir/test/Examples/transform/ChH](https://github.com/llvm/llvm-project/tree/main/mlir/test/Examples/transform/ChH)
+as the source of truth.
+
+## Channeled Convolution
+
+The Transform dialect provides a substrate for implementing “transformation
+directive” domain-specific languages (DSLs) in MLIR. Such a DSL, at least in its
+scheduling part, can target the operations in the Transform dialect that are
+later applied by the compiler. Sets of transform operations, or even new
+dialects leveraging the same interfaces and infrastructure, can be added to
+support a specific DSL for a particular scheduling model. In this chapter, we
+will revisit the Halide DSL, which has (re)popularized the separate specification
+of schedules, originally for image processing programs.
+
+Two approaches to mapping Halide to the Transform dialect are possible:
+
+* Create a new dialect that corresponds to the computational part of the Halide
+ DSL, and define a set of transformations, wrapped into Transform dialect
+ operations, that correspond to the scheduling part of the DSL.
+* Map the Halide abstractions to the existing MLIR abstractions, for both
+ parts of the DSL.
+
+We will consider the latter approach as the computational part of the DSL easily
+maps to the structured ops in the Linalg dialect. This also gives us the
+opportunity to discuss how Linalg transformations on the so-called structured
+operations are similar to or different from the existing transformations.
+
+We will consider the 2D channeled convolution example extracted from Halide
+[application
+examples](https://github.com/halide/Halide/tree/294f80c49bf3bb8582446613c25fcce03b82bcd8/apps/conv_layer).
+
+```cpp
+// Sizes of the problem.
+const int N = 5, CI = 128, CO = 128, W = 100, H = 80;
+
+// Sized inputs. Note that the order of dimensions is
+// inverted in Halide with respect to C++, so the last dimension
+// in the list (N for input, CI for filter) is the least
+// frequently varying. The C++ equivalent is input[N][H+2][W+2][CI].
+Buffer<float, 4> input({CI, W+2, H+2, N}, "input");
+Buffer<float, 4> filter({CO, 3, 3, CI}, "filter");
+Buffer<float, 1> bias(std::vector<int>{CO}, "bias");
+
+// ... data initialization happens here ...
+
+// Declarations of "mathematical functions" for convolution and relu.
+Func conv("conv"), relu("relu");
+
+// Iterators/subscripts.
+Var x("x"), y("y"), c("c"), n("n");
+
+// 3D reduction domain (channels and 2 window dimensions),
+// dimensions are later referred to as r.x, r.y, r.z.
+RDom r(0, CI, 0, 3, 0, 3);
+
+// Core convolution with the result initialized to the bias value.
+// Note that the order of iterators is inverted in Halide DSL,
+// i.e. `n` corresponds to the least frequently varying (outermost) dimension
+// here and below.
+conv(c, x, y, n) = bias(c);
+conv(c, x, y, n) += filter(c, r.y, r.z, r.x) * input(r.x, x + r.y, y + r.z, n);
+
+// ReLU rectification, an elementwise operation.
+relu(c, x, y, n) = max(0, conv(c, x, y, n));
+```
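The dimension-order convention described in the comments above can be sketched in Python. The helper names below are hypothetical and not part of Halide. Reversing Halide's innermost-first extents gives the C++ row-major shape. A unit step along the first Halide dimension then lands on the adjacent element in memory.

```python
def halide_to_cpp_shape(extents):
    """Halide lists extents innermost-first; a C++ array declares them outermost-first."""
    return tuple(reversed(extents))


def row_major_offset(shape, index):
    """Flat offset of `index` within a row-major (C++) array of the given shape."""
    offset = 0
    for extent, i in zip(shape, index):
        assert 0 <= i < extent, "index out of bounds"
        offset = offset * extent + i
    return offset


def halide_offset(extents, coords):
    """Flat offset of Halide-ordered coordinates, e.g. input(ci, x, y, n)."""
    return row_major_offset(halide_to_cpp_shape(extents), tuple(reversed(coords)))


N, CI, W, H = 5, 128, 100, 80
input_extents = [CI, W + 2, H + 2, N]  # Buffer<float, 4> input({CI, W+2, H+2, N})

print(halide_to_cpp_shape(input_extents))         # (5, 82, 102, 128): input[N][H+2][W+2][CI]
print(halide_offset(input_extents, [1, 0, 0, 0]))  # 1: CI varies fastest
print(halide_offset(input_extents, [0, 0, 0, 1]))  # 1070592 = 82 * 102 * 128: N varies slowest
```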
+
+This can be almost directly converted to the Linalg dialect operating on tensors,
+which is conceptually closer to the “mathematical function” abstraction and is
+where the majority of transformations are available.
+
+```mlir
+// Bias. Using a named Linalg operation for brevity.
+%bias_init = tensor.empty() : !toutput
+%biased = linalg.broadcast ins(%bias : !tbias)
+                          outs(%bias_init : !toutput) dimensions = [0, 1, 2]
+
+// Convolution proper. While Linalg has named operations for 2D convolutions,
+// the one in the Halide example has an uncommon order of filter dimensions
+// and is not supported. It also takes the filter as the first argument. This
+// code recreates it faithfully using the generic form.
+%convolved = linalg.generic {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel",
+                    "reduction", "reduction", "reduction"],
+  indexing_maps = [
+    affine_map<(n, y, x, c, rz, ry, rx) -> (rx, rz, ry, c)>,
+    affine_map<(n, y, x, c, rz, ry, rx) -> (n, y+rz, x+ry, rx)>,
+    affine_map<(n, y, x, c, rz, ry, rx) -> (n, y, x, c)>
+  ]
+} ins(%filter, %input: !tfilter, !tinput) outs(%biased : !toutput) {
+^bb0(%f: f32, %in: f32, %b: f32):
+  // Note the fastmath attributes that allow operations to be recombined into
+  //   %0 = math.fma %in, %f, %b : f32
+  // later on and to reorder reductions.
+  %m1 = arith.mulf %in, %f {fastmath = #arith.fastmath<fast>} : f32
+  %0 = arith.addf %b, %m1 {fastmath = #arith.fastmath<fast>} : f32
+  linalg.yield %0 : f32
+} -> !toutput
+
+// ReLU is just a max(0, x).
+%c0 = arith.constant 0.0 : f32
+%relued = linalg.generic {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> ()>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ]
+} ins(%c0, %convolved : f32, !toutput)
+  outs(%output : !toutput) {
+^bb0(%cst: f32, %in: f32, %out: f32):
+  %0 = llvm.intr.maxnum(%cst, %in) : (f32, f32) -> f32
+  linalg.yield %0 : f32
+} -> !toutput
+```
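To make the "implicit loops" of `linalg.generic` concrete, here is a toy Python interpreter for the two generic operations above. It rests on simplifying assumptions and is not MLIR's implementation:

* tensors are modelled as dicts;
* indexing maps are Python functions;
* the helper name `generic` is ours.

It enumerates the iteration space in loop order, uses the indexing maps to pick operand elements, and runs the scalar body on them.

```python
import itertools


def generic(bounds, indexing_maps, ins, out, body):
    """Toy model of linalg.generic over dict-based "tensors".

    bounds:        extent of every iteration dimension, in loop order.
    indexing_maps: one function per operand (ins..., then out) mapping an
                   iteration point to the element index it accesses.
    ins, out:      dicts from index tuples to values; `out` plays the role of
                   the `outs` init tensor and is updated in place.
    body:          scalar function (in values..., current out value) -> new value.
    """
    *in_maps, out_map = indexing_maps
    for point in itertools.product(*(range(b) for b in bounds)):
        args = [t[m(*point)] for t, m in zip(ins, in_maps)]
        key = out_map(*point)
        out[key] = body(*args, out[key])
    return out


N, H, W, CI, CO = 1, 2, 3, 2, 2
P = itertools.product
inp = {k: float(k[0] + k[1] + 2 * k[2] + 3 * k[3])  # (n, y, x, ci)
       for k in P(range(N), range(H + 2), range(W + 2), range(CI))}
flt = {k: float(k[0] - k[1] + k[2] * k[3] + 1)  # (rx, rz, ry, c)
       for k in P(range(CI), range(3), range(3), range(CO))}
bias = [0.5, -20.0]

# %biased: broadcast bias[c] along n, y and x.
biased = {k: bias[k[3]] for k in P(range(N), range(H), range(W), range(CO))}

# %convolved: loops (n, y, x, c, rz, ry, rx); the filter is the first input.
conv = generic(
    bounds=[N, H, W, CO, 3, 3, CI],
    indexing_maps=[
        lambda n, y, x, c, rz, ry, rx: (rx, rz, ry, c),
        lambda n, y, x, c, rz, ry, rx: (n, y + rz, x + ry, rx),
        lambda n, y, x, c, rz, ry, rx: (n, y, x, c),
    ],
    ins=[flt, inp], out=dict(biased),
    body=lambda f, i, acc: acc + i * f)

# %relued: the scalar %c0 (indexing map to ()) is a one-element dict.
relu = generic(
    bounds=[N, H, W, CO],
    indexing_maps=[lambda *d: (), lambda *d: d, lambda *d: d],
    ins=[{(): 0.0}, conv], out=dict.fromkeys(conv, 0.0),
    body=lambda cst, x, _out: max(cst, x))
```

The output operand doubles as the accumulator, which is how the `outs(%biased ...)` operand carries the bias into the reduction.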
+
+In Halide, a function such as `conv` may consist of two parts: a “functional”
+initialization computation and an in-place update for reductions. This is
+expressed as two C++ statements in the embedded DSL, but internally is
+represented in a single object. Linalg doesn’t have such a capability, so the
+initialization and the update are represented as two distinct Linalg operations
+that are not connected to each other. Furthermore, the `x`, `y`, `c`, `n`
+variables in the Halide DSL correspond to implicit loops iterating over the
+corresponding objects, which implies that functions sharing these variables in
+their definitions also share the corresponding loops. In other words, the loop
+equivalent of the Halide definition starts in a fully-fused form. The Linalg
+model is the opposite with each structured operation corresponding to its own
+loop nest, resulting in a fully-distributed form. This will affect how the
+schedule is constructed later on.
+
+The loop structure for the Halide computation resembles the following (adapted
+from the debug dump with `HL_DEBUG_CODEGEN=1`):
+
+```python
+for n
+    for y
+        for x
+            for c
+                conv[n, y, x, c] = bias[c]
+                for rz
+                    for ry
+                        for rx
+                            conv[n, y, x, c] += filter[rx, rz, ry, c] * input[n, y+rz, x+ry, rx]
+                relu[n, y, x, c] = max(0, conv[n, y, x, c])
+```
+
+The loop structure for the Linalg computation is as follows (obtained by
+`mlir-opt --linalg-generalize-named-ops --empty-tensor-to-alloc-tensor
+--one-shot-bufferize --convert-linalg-to-loops`):
+
+```python
+for n
+    for y
+        for x
+            for c
+                init[n, y, x, c] = bias[c]
+for n
+    for y
+        for x
+            for c
+                for rz
+                    for ry
+                        for rx
+                            conv[n, y, x, c] += filter[rx, rz, ry, c] * input[n, y+rz, x+ry, rx]
+for n
+    for y
+        for x
+            for c
+                relu[n, y, x, c] = max(0, conv[n, y, x, c])
+```
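Both loop structures compute the same values; only the evaluation order differs. A small Python check runs both nests and compares the ReLU outputs. It uses toy sizes and made-up data, and the function names are illustrative only.

```python
import itertools

P = itertools.product
N, H, W, CI, CO = 2, 3, 4, 3, 2
inp = {k: float((7 * k[0] + 5 * k[1] + 3 * k[2] + k[3]) % 11 - 5)  # (n, y, x, ci)
       for k in P(range(N), range(H + 2), range(W + 2), range(CI))}
flt = {k: float((k[0] + 2 * k[1] + 3 * k[2] + 5 * k[3]) % 7 - 3)  # (rx, rz, ry, c)
       for k in P(range(CI), range(3), range(3), range(CO))}
bias = [1.0, -2.0]
points = list(P(range(N), range(H), range(W), range(CO)))  # (n, y, x, c)
window = list(P(range(3), range(3), range(CI)))             # (rz, ry, rx)


def halide_fused():
    """Halide: a single nest; init, update and relu share the n, y, x, c loops."""
    conv, relu = {}, {}
    for n, y, x, c in points:
        conv[n, y, x, c] = bias[c]
        for rz, ry, rx in window:
            conv[n, y, x, c] += flt[rx, rz, ry, c] * inp[n, y + rz, x + ry, rx]
        relu[n, y, x, c] = max(0.0, conv[n, y, x, c])
    return relu


def linalg_distributed():
    """Linalg: one loop nest per structured operation."""
    init = {(n, y, x, c): bias[c] for n, y, x, c in points}
    conv = dict(init)
    for n, y, x, c in points:
        for rz, ry, rx in window:
            conv[n, y, x, c] += flt[rx, rz, ry, c] * inp[n, y + rz, x + ry, rx]
    relu = {}
    for n, y, x, c in points:
        relu[n, y, x, c] = max(0.0, conv[n, y, x, c])
    return relu


assert halide_fused() == linalg_distributed()
```

Each output point accumulates its window in the same `rz, ry, rx` order in both versions, so the results agree exactly. The schedule below is what turns one form into the other.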
+
+## Mapping Halide Scheduling Primitives to Linalg Structured Transforms
+
+The complete Halide schedule listed in the example is as follows:
+
+```cpp
+Var co, ci, xo, xi;
+relu.split(c, co, ci, vec * tile_w)
+    .split(x, xo, xi, tile_h)
+    .reorder(ci, xi, xo, y, n, co)
+    .vectorize(ci, vec)
+    .unroll(ci)
+    .unroll(xi)
+    .parallel(y)
+    .parallel(n)
+    .parallel(co);
+
+conv.compute_at(relu, xo)
+    .vectorize(c, vec)
+    .unroll(c)
+    .unroll(x)
+    .unroll(y)
+    .update()
+    .reorder(c, x, y, r.x, r.y, r.z, n)
+    .vectorize(c, vec)
+    .unroll(c)
+    .unroll(x)
+    .unroll(y)
+    .unroll(r.x, 2);
+```
+
+We will consider only the case without parallelization to avoid the difference
+in parallel runtimes generated by Halide and used by MLIR. This schedule
+corresponds to a sequence of loop manipulations, unrolling and vectorization.
+The following directives are present and can be mapped to transformations on
+Linalg as described below.
+
+* `split` decomposes a loop dimension into two immediately nested loops with
+ the inner loop having at most the given number of iterations. This can be
+ understood as loop _strip-mining_ or a degenerate case of tiling a single
+ dimension using any of the `transform.structured.tile*` ops. We will be using
+ `transform.structured.tile_to_forall_op` as this kind of loop is best
+ supported by bufferization and can also be turned into a parallel loop later
+ on. Unlike Halide, this doesn’t add new dimensions to the original
+ operation, but rather creates a loop around it and rewrites the operation
+ itself to operate on a subset of the original data.
+* `reorder` rearranges the loops arbitrarily. In Linalg representation, loops
+ are implicit and are intended to remain so as long as possible to target
+ microkernels. The order of implicit loops in a `linalg.generic` operation
+ can be changed by using `transform.structured.interchange`, but this does
+ not apply to named operations that need to be “generalized” first by calling
+ `transform.structured.generalize`. However, this can only reorder implicit
+ dimensions and not the explicit loops materialized by tiling operations that
+ can no longer be “folded” into the original operation. Instead, we can
+ leverage this behavior by materializing loops directly in the desired order
+ by “tiling” to size 1.
+* `vectorize` indicates that the given dimension should be vectorized with the
+ given factor; if the loop extent is larger than the factor, the loop is
+ effectively split into two parts and the inner one is vectorized. In
+ contrast, structured Linalg op vectorization applies as a global
+ transformation to all suitable operations at, e.g., a function scope via
+ `transform.structured.vectorize`. It relies on MLIR’s support for
+ multidimensional vectors to directly map multidimensional tensors, which are
+ later decomposed into operations on smaller hardware-compatible vectors
+ during lowering.
+* `unroll` performs loop unrolling, fully or up to the given factor. It is
+ equivalent to `transform.loop.unroll`.
+* `compute_at` indicates that the value of the function must be computed
+ within the given loop that will be produced for another function; depending
+ on the relation between loops surrounding functions, this corresponds to
+ either a loop distribution or a producer/consumer fusion. Given that the
+ Linalg representation starts in the fully distributed form, it can be
+ represented as a sequence of `transform.structured.fuse_into_containing_op`
+ that operates on `forall` loops materialized by tiling beforehand.
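The `split` and `reorder` directives can be modelled on plain index ranges. In this Python sketch, `split` is a hypothetical helper that strip-mines a range. When the factor does not divide the extent, it simply shortens the last tile. Halide and the MLIR tiling transforms have their own ways of handling such tails, and those are not modelled here.

```python
import itertools


def split(extent, factor):
    """Strip-mine range(extent) into (outer, inner) pairs with inner < factor.

    The last outer iteration is shortened when `factor` does not divide `extent`.
    """
    for o in range((extent + factor - 1) // factor):
        for i in range(min(factor, extent - o * factor)):
            yield o, i


# split(c, co, ci, vec * tile_w) with vec * tile_w = 64 on 128 channels.
pairs = list(split(128, 64))
assert [o * 64 + i for o, i in pairs] == list(range(128))  # same points, same order
assert {o for o, _ in pairs} == {0, 1}                      # two outer iterations

# A non-dividing factor leaves a partial last tile: 10 = 4 + 4 + 2.
assert [i for o, i in split(10, 4) if o == 2] == [0, 1]

# reorder: permuting loops visits the same points in a different order.
x_outer = list(itertools.product(range(3), range(2)))                 # (x, y), x outer
y_outer = [(x, y) for y, x in itertools.product(range(2), range(3))]  # (x, y), y outer
assert sorted(x_outer) == sorted(y_outer) and x_outer != y_outer
```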
+
+
+## Recreating the Loop Structure
+
+The first three transformation directives for `relu` in the Halide schedule aim
+at producing the following loop structure.
+
+```python
+for co
+    for n
+        for y
+            for xo
+                for xi
+                    for ci
+                        relu[n, y, xo*tile_h + xi, co*tile_w*vec + ci] = ...
+```
+
+Note that the outer part of `c` is hoisted out of all of the surrounding
+loops. The implicit loop order for the operation is `n, y, x, c`, so the `co`
+loop needs to be materialized first in order to achieve the desired reordering.
+The remaining dimensions can be materialized as loops in one transformation.
+
+```mlir
+ // [n y x c]
+ %co, %relu2 = transform.structured.tile_to_forall_op %relu
+ tile_sizes [0, 0, 0, 64]
+ %n_y_xo, %relu3 = transform.structured.tile_to_forall_op %relu2
+ tile_sizes [1, 1, 5, 0]
+```
+
+This will result in the following loops being created in the IR with the nested
+elementwise operation operating on a smaller subset of original data via
+implicit loops.
+
+```mlir
+scf.forall (%co) in (2) {
+ scf.forall (%n, %y, %xo) in (5, 80, 20) {
+ tensor.extract_slice
+ // Implicit dimensions [ni=0:1, y=0:1, xi=0:5, ci=0:64]
+ %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } // ...
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice // ...
+ }
+ }
+}
+```
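The loop bounds in the IR above follow from the tile sizes. A non-zero tile size `t` on a dimension of extent `d` yields a `forall` dimension with `ceil(d / t)` iterations. A zero tile size leaves the dimension untiled.

Here is a Python sketch of this bookkeeping. The helper is hypothetical. The sizes are those of the Halide example, with the output laid out as `[N, H, W, CO]`.

```python
def forall_trip_counts(shape, tile_sizes):
    """Return (loop trip counts, tile shape) for a tile_to_forall-style tiling."""
    trips, tile_shape = [], []
    for extent, tile in zip(shape, tile_sizes):
        if tile == 0:  # 0 means "do not tile this dimension"
            tile_shape.append(extent)
        else:
            trips.append(-(-extent // tile))  # ceiling division
            tile_shape.append(min(tile, extent))
    return trips, tile_shape


N, H, W, CO = 5, 80, 100, 128
trips, relu2_shape = forall_trip_counts([N, H, W, CO], [0, 0, 0, 64])
print(trips, relu2_shape)  # [2] [5, 80, 100, 64]
trips, relu3_shape = forall_trip_counts(relu2_shape, [1, 1, 5, 0])
print(trips, relu3_shape)  # [5, 80, 20] [1, 1, 5, 64]
trips, vec_shape = forall_trip_counts(relu3_shape, [0, 0, 1, 16])
print(trips, vec_shape)    # [5, 4] [1, 1, 1, 16]: xi and ci loops around a 16-wide slice
```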
+
+The next loop restructuring transformations are `compute_at` and `reorder` on
+the `conv` function. They need to happen before loops are destroyed by
+unrolling and vectorization, and they are intended to produce the following
+final loop structure:
+
+```python
+for co
+    for n
+        for y
+            for xo
+                for xi
+                    for ci
+                        conv[n, y, xo*tile_h + xi, co*tile_w*vec + ci] = ...
+                for rz
+                    for ry
+                        for rx
+                            for xi
+                                for ci
+                                    conv[n, y, xo*tile_h + xi, co*tile_w*vec + ci] += ...
+                for xi
+                    for ci
+                        relu[n, y, xo*tile_h + xi, co*tile_w*vec + ci] = ...
+```
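As a sanity check of what this structure computes, the following Python sketch uses toy sizes and made-up data. `tile_x` and `tile_c` are stand-ins for `tile_h` and `tile_w*vec`, and the sizes are chosen so that tiles divide evenly. Each `(co, n, y, xo)` iteration initializes, accumulates and rectifies its own `xi × ci` block, mirroring `compute_at(relu, xo)`. The result matches an untiled reference.

```python
import itertools

P = itertools.product
N, H, W, CI, CO = 2, 3, 6, 3, 4
tile_x, tile_c = 3, 2  # stand-ins for tile_h and tile_w * vec

inp = {k: float((3 * k[0] + 5 * k[1] + 7 * k[2] + k[3]) % 9 - 4)  # (n, y, x, ci)
       for k in P(range(N), range(H + 2), range(W + 2), range(CI))}
flt = {k: float((k[0] + 2 * k[1] + 3 * k[2] + 4 * k[3]) % 5 - 2)  # (rx, rz, ry, c)
       for k in P(range(CI), range(3), range(3), range(CO))}
bias = [float(c - 1) for c in range(CO)]


def reference():
    out = {}
    for n, y, x, c in P(range(N), range(H), range(W), range(CO)):
        acc = bias[c]
        for rz, ry, rx in P(range(3), range(3), range(CI)):
            acc += flt[rx, rz, ry, c] * inp[n, y + rz, x + ry, rx]
        out[n, y, x, c] = max(0.0, acc)
    return out


def tiled_fused():
    relu = {}
    for co, n, y, xo in P(range(CO // tile_c), range(N), range(H), range(W // tile_x)):
        # conv is computed inside relu's xo loop, on a tile_x x tile_c block only.
        conv = {(xi, ci): bias[co * tile_c + ci]
                for xi, ci in P(range(tile_x), range(tile_c))}
        for rz, ry, rx in P(range(3), range(3), range(CI)):
            for xi, ci in P(range(tile_x), range(tile_c)):
                x, c = xo * tile_x + xi, co * tile_c + ci
                conv[xi, ci] += flt[rx, rz, ry, c] * inp[n, y + rz, x + ry, rx]
        for xi, ci in P(range(tile_x), range(tile_c)):
            relu[n, y, xo * tile_x + xi, co * tile_c + ci] = max(0.0, conv[xi, ci])
    return relu


assert tiled_fused() == reference()
```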
+
+Practically, this corresponds to fusing the convolution initialization and
+update into the `co, n, y, xo` loops materialized by tiling earlier. The
+structured op transformation set supports fusing the producer of a value into
+its consumer, so fusion happens in two stages:
+
+* first the main convolution update is fused into ReLU that uses it and has
+ loops materialized;
+* then the bias initialization is fused into the convolution+relu loop nest.
+
+Each stage consists of two transformations fusing the computational operation
+into the outer loop, then the inner loop.
+
+```mlir
+%conv2, %co2 = transform.structured.fuse_into_containing_op %conv into %co
+%conv3, %n_y_xo2 = transform.structured.fuse_into_containing_op %conv2
+ into %n_y_xo
+
+%bias2, %co3 = transform.structured.fuse_into_containing_op %bias into %co2
+%bias3, %n_y_xo3 = transform.structured.fuse_into_containing_op %bias2
+ into %n_y_xo2
+```
+
+To complete the structure, we need to put the `rz, ry, rx` loops outside the
+“tile” loops `xi, ci`. This can be achieved by materializing the corresponding
+loops from the convolution operation. However, these are reduction loops and it
+wouldn’t be valid to materialize them as intrinsically parallel “forall” loops.
+Instead, we use the dedicated “reduction tiling” transformation and produce
+sequential `scf.for` loops. (`scf.forall` loops can also express parallel
+reductions, but the corresponding transformation doesn’t handle reductions along
+more than one dimension at the time of writing.)
+
+```mlir
+%rz_ry_rx, %red_fill, %conv4, %comb
+ = transform.structured.tile_reduction_using_scf %conv3
+// n y x c rz ry rx
+ by tile_sizes=[0, 0, 0, 0, 1, 1, 1]
+```
+
+This transformation materializes the desired loops around the convolution
+operation. It is also more capable than merely producing (reduction) loops: the
+transformed code performs `tile_size` partial reductions of `N / tile_size`
+elements, potentially in parallel by changing the dimension kind of the
+structured operation inside the loop, and then performs a final reduction of
+these partial results by producing a new “combiner” structured operation after
+the loops. In our case, `tile_size = 1` along all dimensions, so the reduction
+is entirely performed by the generated loops. The combiner structured operation
+is still produced and adds up the reduction result with the initial value. This
+changes the order of floating-point operations (as would reduction tiling with a
+non-unit size) and may affect the final result due to the non-associativity of
+floating-point addition, but is explicitly allowed by the `fastmath` flags.
+Halide also emits LLVM IR with full `fastmath` flags.
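The partial-reduction scheme and its effect on floating-point results can be sketched in Python. This is a toy model of the idea, not of the IR the transform produces, and the function names are ours. `tile_size` accumulators each sum every `tile_size`-th element, and a combiner then adds the partial results to the initial value.

```python
def tiled_reduction(values, init, tile_size):
    """Add `values` to `init` using `tile_size` partial accumulators and a combiner."""
    assert len(values) % tile_size == 0, "remainders are not modelled"
    partial = [0.0] * tile_size  # start from the neutral element of addition
    for start in range(0, len(values), tile_size):  # the materialized reduction loop
        for lane in range(tile_size):  # independent partial sums inside the loop
            partial[lane] += values[start + lane]
    result = init  # the "combiner" after the loop
    for p in partial:
        result += p
    return result


def sequential(values, init):
    """Untiled reduction: accumulate directly into the initial value."""
    result = init
    for v in values:
        result += v
    return result


# Exactly representable values: any grouping gives the same result.
assert tiled_reduction([1.0, 2.0, 3.0, 4.0], 0.0, 2) == sequential([1.0, 2.0, 3.0, 4.0], 0.0) == 10.0

# With tile_size = 1 the loop still sums the values first and the combiner adds
# the initial value last, which regroups the additions:
print(sequential([1.0, 1.0], 1e16))          # 1e+16: each +1.0 is rounded away
print(tiled_reduction([1.0, 1.0], 1e16, 1))  # 1.0000000000000002e+16
```

Floating-point addition is commutative but not associative, so the regrouping in the last two lines changes the result. Reassociation of this kind is what the `fastmath` flags permit.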
+
+Finally, we need to produce innermost loops `xi` and `ci` that are still not
+explicit. As our next step is going to be vectorization along `ci`, we need to
+take into account the way it operates on MLIR structured operations: rather than
+selecting a specific vector size and loop/dimension to vectorize, it directly
+substitutes multidimensional vector types for tensor types and updates the
+operations accordingly. Therefore, our tensor type should not become trivial
+(i.e., size 1) but should retain a `vector_size`-sized dimension along the desired
+axis, `ci`. This can be achieved by tiling with `vector_size` as the tile size in
+that dimension:
+
+```mlir
+// n y xi ci
+%1, %c5 = transform.structured.tile_to_forall_op %conv4 tile_sizes [0, 0, 1, 16]
+%2, %b4 = transform.structured.tile_to_forall_op %bias3 tile_sizes [0, 0, 1, 16]
+%3, %r4 = transform.structured.tile_to_forall_op %relu3 tile_sizes [0, 0, 1, 16]
+%4, %c2 = transform.structured.tile_to_forall_op %comb tile_sizes [0, 0, 1, 16]
+```
+
+Note that the combiner operation produced by reduction tiling is also tiled here.
+
+
+## Explicit Loop Unrolling
+
+The remaining unhandled loop transformation is unrolling. Specifically,
+unrolling is requested for the innermost loops that form the 4x5 tile of
+16-element vector operations to ensure a contiguous sequence of `vfma`
+instructions using 20 512-bit vector registers as accumulators. Unrolling
+additional loops, `unroll(y)` and `unroll(r.x, 2)`, is requested in the
+schedule but _has no practical effect_. That is, the code, and all intermediate
+representations, produced by Halide with these directives removed are _strictly
+identical_ to the code with the full schedule. Therefore, we will only unroll
+the loops corresponding to the `xi` and `ci` dimensions that actually get
+unrolled by Halide.
+
+As tiling in the transform dialect produces handles to the loops materialized by
+tiling, unrolling those loops is just a matter of chaining the corresponding
+transformation. Note that the inner loop must be unrolled first as unrolling the
+outer loop will invalidate the handles to the inner loop.
+
+```mlir
+transform.loop.unroll %bias_ci {factor = 4}
+transform.loop.unroll %bias_xi {factor = 5}
+transform.loop.unroll %conv_ci {factor = 4}
+transform.loop.unroll %conv_xi {factor = 5}
+transform.loop.unroll %relu_ci {factor = 4}
+transform.loop.unroll %relu_xi {factor = 5}
+transform.loop.unroll %comb_ci {factor = 4}
+transform.loop.unroll %comb_xi {factor = 5}
+```
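The net effect of fully unrolling `ci` (4 iterations) and then `xi` (5 iterations) can be modelled in Python by replicating a loop body. The helper `unroll` and the statement strings are illustrative placeholders, not MLIR API. The result is 20 straight-line vector FMAs, one per accumulator, each covering 16 `f32` lanes, i.e. 512 bits.

```python
def unroll(trip_count, factor, body):
    """Unroll `for i in range(trip_count): body(i)` by `factor`.

    Returns one list of statements per remaining loop iteration; when
    factor == trip_count (full unrolling) exactly one iteration remains.
    """
    assert trip_count % factor == 0, "remainder loops are not modelled"
    return [[stmt for u in range(factor) for stmt in body(i * factor + u)]
            for i in range(trip_count // factor)]


VEC, TILE_W, TILE_H = 16, 4, 5  # f32 lanes per vector, ci trip count, xi trip count


def fma(xi, ci):
    """One vector FMA updating accumulator (xi, ci), as a placeholder string."""
    return [f"acc[{xi}][{ci}] += filter_vec[{ci}] * input[{xi}]"]


def xi_body(xi):
    # The inner ci loop is fully unrolled first, so the xi body is straight-line code.
    (stmts,) = unroll(TILE_W, TILE_W, lambda ci: fma(xi, ci))
    return stmts


(unrolled,) = unroll(TILE_H, TILE_H, xi_body)
print(len(unrolled))  # 20 vector FMAs, one per accumulator
print(VEC * 32)       # 512 bits per f32 accumulator vector
```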
+
+## Vectorization
+
+These transformations produced the desired loop structure and we are now ready
+to vectorize. Before proceeding it is desirable to simplify the code as tiling
+and fusion may have produced a lot of operations computing tensor subsets and
+loop ranges, some of which may be duplicated or excessively complex.
+Simplification involving canonicalization, common subexpression elimination,
+loop invariant code motion and various rewrite patterns can be applied directly
+from the transform dialect. Furthermore, an arbitrary combination of rewrite
+patterns can be applied _in one sweep_ to a given scope, a functionality that
+_cannot be achieved with conventional compiler passes_ that apply each group of
+patterns separately (at least without creating a new pass for each combination
+of pattern groups).
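To see why applying pattern groups together differs from running them as separate passes, consider a toy rewriter in Python on tiny expression trees. It is not MLIR's greedy pattern driver. It uses two pattern groups:

* group A folds an operation on two integer literals;
* group B applies the identities `x - x → 0` and `0 + y → y`.

On `((x - x) + 1) + 1`, running A once and then B once leaves `1 + 1`. The reason is that B creates a folding opportunity after A has already finished. Applying both groups in one sweep to a fixpoint reaches `2`.

```python
import operator

OPS = {"+": operator.add, "-": operator.sub}


def fold_constants(e):
    """Group A: fold an operation whose operands are both integer literals."""
    if isinstance(e, tuple) and isinstance(e[1], int) and isinstance(e[2], int):
        return OPS[e[0]](e[1], e[2])
    return None


def identities(e):
    """Group B: x - x -> 0 and 0 + y -> y."""
    if isinstance(e, tuple):
        op, lhs, rhs = e
        if op == "-" and lhs == rhs:
            return 0
        if op == "+" and lhs == 0:
            return rhs
    return None


def rewrite(e, patterns):
    """Rewrite children first, then retry all patterns until none applies."""
    while True:
        if isinstance(e, tuple):
            e = (e[0], rewrite(e[1], patterns), rewrite(e[2], patterns))
        for pattern in patterns:
            new = pattern(e)
            if new is not None:
                e = new
                break
        else:
            return e


expr = ("+", ("+", ("-", "x", "x"), 1), 1)
separate = rewrite(rewrite(expr, [fold_constants]), [identities])
combined = rewrite(expr, [fold_constants, identities])
print(separate)  # ('+', 1, 1)
print(combined)  # 2
```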
+
+```mlir
+%f00 = transform.structured.match ops{["func.func"]...
<truncated>
</pre>
</details>
https://github.com/llvm/llvm-project/pull/66386