[Mlir-commits] [mlir] [mlir][Linalg] Add speculation for LinalgStructuredOps (PR #108032)

Kunwar Grover llvmlistbot at llvm.org
Tue Sep 10 09:44:48 PDT 2024


================
@@ -1118,3 +1118,48 @@ func.func @hoist_from_scf_while(%arg0: i32, %arg1: i32) -> i32 {
   }
   return %0 : i32
 }
+
+// -----
+
+#trait = {  // attribute alias passed to the linalg.generic in the test function
+  indexing_maps = [
+    affine_map<(m, n, k) -> (m, k)>,  // first map: iteration space (m, n, k) -> (m, k)
+    affine_map<(m, n, k) -> (k, n)>,  // second map: iteration space (m, n, k) -> (k, n)
+    affine_map<(m, n, k) -> (m, n)>   // third map: iteration space (m, n, k) -> (m, n)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"]  // m and n are parallel, k is a reduction
+}
+
+// CHECK-LABEL: func @hoist_linalg_ops
+// CHECK: linalg.generic
+// CHECK: scf.for
+// CHECK-NOT: linalg.generic
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+func.func @hoist_linalg_ops(%a : tensor<128x128xf32>,  // test: loop-invariant linalg.generic inside scf.for
+                            %b : tensor<128x128xf32>,
+                            %c: tensor<128x128xf32>,
+                            %lb : index,
+                            %ub : index,
+                            %step : index,
+                            %output : tensor<?x128xf32>) -> tensor<?x128xf32> {
+  %final =
+  scf.for %i = %lb to %ub step %step iter_args(%acc = %output)
+                                            -> tensor<?x128xf32> {
+    %compute = linalg.generic #trait  // operands %a, %b, %c are all defined outside the loop
+               ins(%a, %b : tensor<128x128xf32>, tensor<128x128xf32>)
+               outs(%c : tensor<128x128xf32>) {
+    ^bb0(%in : f32, %in2 : f32, %in3 : f32):  // %in, %in2: input elements; %in3: output element
+      %mul = arith.mulf %in, %in2 : f32
+      %add = arith.addf %mul, %in3 : f32
+      linalg.yield %add : f32  // yield the multiply-accumulate result so %mul/%add are live
+    } -> tensor<128x128xf32>
+
+    %newacc = tensor.insert_slice %compute into  // offset uses %i, so this op depends on the induction variable
+                                  %output[%i, 0][128, 128][1, 1]  // NOTE(review): inserts into %output, leaving %acc unused; confirm intended
+                                  : tensor<128x128xf32> into tensor<?x128xf32>
+    scf.yield %newacc : tensor<?x128xf32>
+  }
+
+  func.return %final : tensor<?x128xf32>
+}
----------------
Groverkss wrote:

added

https://github.com/llvm/llvm-project/pull/108032


More information about the Mlir-commits mailing list