[Mlir-commits] [mlir] [mlir] Allow unroll & jam on SCF loops with results (PR #98887)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Aug 12 22:47:16 PDT 2024


================
@@ -336,6 +336,119 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Unroll-and-jam by 4 (see the transform sequence below) of a loop carrying a
+// single iter_arg that the body never reads: the step 2 becomes 8 and the
+// iter_arg is replicated into four copies, all initialized from the same
+// constant.
+// CHECK-LABEL: @loop_unroll_and_jam_loop_with_results
+func.func @loop_unroll_and_jam_loop_with_results() -> index {
+  // CHECK:           %[[C0:.*]] = arith.constant 0
+  // CHECK:           %[[UB:.*]] = arith.constant 40
+  // CHECK:           %[[STEP:.*]] = arith.constant 8
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c2 = arith.constant 2 : index
+  // The iter_arg captures use an identifier character class instead of `.*`:
+  // with greedy `.*` the first continuation line could swallow the third
+  // `= %c0,` pair and leave too little text for the second continuation line.
+  // CHECK:           %[[RES:.*]]:4 = scf.for %[[I:.*]] = %[[C0]] to %[[UB]] step %[[STEP]]
+  // CHECK-SAME:       iter_args(%[[ARG0:[a-zA-Z0-9_]+]] = %[[C0]], %[[ARG1:[a-zA-Z0-9_]+]] = %[[C0]],
+  // CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]+]] = %[[C0]], %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]])
+  %sum = scf.for %i = %c0 to %c40 step %c2 iter_args(%does_not_alias_aggregated = %c0) -> (index) {
+    %twice = arith.addi %i, %i : index
+    // CHECK:         scf.yield %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : index, index, index, index
+    scf.yield %twice : index
+  }
+  return %sum : index
+}
+
+module attributes {transform.with_named_sequence} {
+  // Locate the scf.for enclosing the matched arith.addi and unroll-and-jam it
+  // by a factor of 4.
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %root
+        : (!transform.any_op) -> !transform.any_op
+    %loop = transform.get_parent_op %addi {op_name = "scf.for"}
+        : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %loop { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// Unroll-and-jam on a three-deep loop nest (`i`, `j`, `k`) in which every loop
+// carries a tensor through iter_args. The expected output below encodes that
+// the `i` and `j` loops get step 2, the `i` loop carries 2 iter_args, the `j`
+// and `k` loops carry 4, and the innermost `k` loop keeps step 1.
+// NOTE(review): presumably the transform sequence for this split unrolls and
+// jams the two outer loops by 2 each -- confirm against it.
+// TODO: minimize this test as requested in review (e.g. a single unary op such
+// as a negate or cast on a 3-D tensor); its length makes it brittle.
+// CHECK-LABEL: @unroll_jam_tiled_loops
+func.func @unroll_jam_tiled_loops(%A : tensor<8x16x4x8xbf16>, %B : tensor<16x8x8x4xbf16>) -> tensor<16x16x4x4xf32> {
+  // CHECK:      %[[C2:.*]] = arith.constant 2 : index
+  // CHECK:      %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:      %[[C1:.*]] = arith.constant 1 : index
+  // CHECK:      %[[C8:.*]] = arith.constant 8 : index
+  // CHECK:      %[[C16:.*]] = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %buf = memref.alloc() : memref<16x16x4x4xf32>
+  %ten = bufferization.to_tensor %buf restrict writable : memref<16x16x4x4xf32>
+  // Zero-filled tensor that seeds the iter_args of the whole nest.
+  // CHECK:      %[[AT:.*]] = linalg.fill {{.*}} -> tensor<16x16x4x4xf32>
+  %acc = linalg.fill ins(%c0_f32 : f32) outs(%ten : tensor<16x16x4x4xf32>) -> tensor<16x16x4x4xf32>
+  // CHECK:      %[[L0R:.*]]:2 = scf.for %{{.*}} = %[[C0]] to %[[C16]] step %[[C2]]
+  // CHECK-SAME:     iter_args(%[[L0IA0:.*]] = %[[AT]], %[[L0IA1:.*]] = %[[AT]])
+  %l0r = scf.for %i = %c0 to %c16 step %c1 iter_args(%acl0l1 = %acc) -> (tensor<16x16x4x4xf32>) {
+    // CHECK:        %[[L1R:.*]]:4 = scf.for %{{.*}} = %[[C0]] to %[[C16]] step %[[C2]]
+    // CHECK-SAME:       iter_args(%[[L1IA0:.*]] = %[[L0IA0]], %[[L1IA1:.*]] = %[[L0IA0]],
+    // CHECK-SAME:                 %[[L1IA2:.*]] = %[[L0IA1]], %[[L1IA3:.*]] = %[[L0IA1]])
+    %l1r = scf.for %j = %c0 to %c16 step %c1 iter_args(%acl1l2 = %acl0l1) -> (tensor<16x16x4x4xf32>) {
+      // CHECK:          %[[L2R:.*]]:4 = scf.for %{{.*}} = %[[C0]] to %[[C8]] step %[[C1]]
+      // CHECK-SAME:         iter_args(%[[L2IA0:.*]] = %[[L1IA0]], %[[L2IA1:.*]] = %[[L1IA1]],
+      // CHECK-SAME:                   %[[L2IA2:.*]] = %[[L1IA2]], %[[L2IA3:.*]] = %[[L1IA3]])
+      %l2r = scf.for %k = %c0 to %c8 step %c1 iter_args(%C = %acl1l2) -> (tensor<16x16x4x4xf32>) {
+        // Read tiles of A, B and the carried tensor C, combine them with a
+        // linalg.generic, and write the result tile back into C.
+        %ta = tensor.extract_slice %A[%k, %i, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1] : tensor<8x16x4x8xbf16> to tensor<1x1x4x8xbf16>
+        %tb = tensor.extract_slice %B[%j, %k, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1] : tensor<16x8x8x4xbf16> to tensor<1x1x8x4xbf16>
+        %tc = tensor.extract_slice %C[%j, %i, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : tensor<16x16x4x4xf32> to tensor<1x1x4x4xf32>
+        %rr = linalg.generic {
+                    indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>,
+                                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>,
+                                     affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>],
+                    iterator_types = ["parallel", "parallel", "reduction",
+                                      "parallel", "parallel", "reduction"]}
+                ins(%ta, %tb : tensor<1x1x4x8xbf16>, tensor<1x1x8x4xbf16>)
+                outs(%tc : tensor<1x1x4x4xf32>) {
+              ^bb0(%ia: bf16, %ib: bf16, %out: f32):
+                // Widen both bf16 inputs to f32, multiply, and accumulate.
+                %0 = arith.extf %ia : bf16 to f32
+                %1 = arith.extf %ib : bf16 to f32
+                %2 = arith.mulf %0, %1 : f32
+                %3 = arith.addf %out, %2 : f32
+                linalg.yield %3 : f32
+        } -> tensor<1x1x4x4xf32>
+        %is = tensor.insert_slice %rr into %C[%j, %i, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : tensor<1x1x4x4xf32> into tensor<16x16x4x4xf32>
+        // CHECK:            scf.yield %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
+        // CHECK-SAME:           tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>,
+        // CHECK-SAME:           tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>
+        scf.yield %is : tensor<16x16x4x4xf32>
+      }
+      // CHECK:          scf.yield %[[L2R]]#0, %[[L2R]]#1, %[[L2R]]#2, %[[L2R]]#3 :
+      // CHECK-SAME:         tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>,
+      // CHECK-SAME:         tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>
+      scf.yield %l2r : tensor<16x16x4x4xf32>
+    }
+    // NOTE(review): only results #0 and #2 of the `j` loop are forwarded to the
+    // `i` loop's yield; confirm that dropping #1 and #3 is intended.
+    // CHECK:        scf.yield %[[L1R]]#0, %[[L1R]]#2 :
+    // CHECK-SAME:       tensor<16x16x4x4xf32>, tensor<16x16x4x4xf32>
+    scf.yield %l1r : tensor<16x16x4x4xf32>
+  }
+  // NOTE(review): only the first of the `i` loop's two results is returned.
+  // CHECK:      return %[[L0R]]#0 : tensor<16x16x4x4xf32>
+  return %l0r : tensor<16x16x4x4xf32>
+}
----------------
ftynse wrote:

Could we minimize this test? I understand it comes from a real use case, but it looks obnoxiously long and therefore brittle. Using a unary operation (negate or cast) on a 3d tensor will already be an improvement.

https://github.com/llvm/llvm-project/pull/98887


More information about the Mlir-commits mailing list