[Mlir-commits] [mlir] [mlir][Linalg] Add speculation for LinalgStructuredOps (PR #108032)
Kunwar Grover
llvmlistbot at llvm.org
Tue Sep 10 09:44:48 PDT 2024
================
@@ -1118,3 +1118,48 @@ func.func @hoist_from_scf_while(%arg0: i32, %arg1: i32) -> i32 {
}
return %0 : i32
}
+
+// -----
+
+#trait = {  // Matmul-style operand access: A(m, k), B(k, n), C(m, n).
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,  // first `ins` operand
+ affine_map<(m, n, k) -> (k, n)>,  // second `ins` operand
+ affine_map<(m, n, k) -> (m, n)>   // `outs` operand (k is absent: reduced)
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]  // m, n parallel; k reduced
+}
+
+// CHECK-LABEL: func @hoist_linalg_ops
+// CHECK: linalg.generic
+// CHECK: scf.for
+// CHECK-NOT: linalg.generic
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+func.func @hoist_linalg_ops(%a : tensor<128x128xf32>,
+                            %b : tensor<128x128xf32>,
+                            %c : tensor<128x128xf32>,
+                            %lb : index,
+                            %ub : index,
+                            %step : index,
+                            %output : tensor<?x128xf32>) -> tensor<?x128xf32> {
+  %final =
+  scf.for %i = %lb to %ub step %step iter_args(%acc = %output)
+                                            -> tensor<?x128xf32> {
+    %compute = linalg.generic #trait
+               ins(%a, %b : tensor<128x128xf32>, tensor<128x128xf32>)
+               outs(%c : tensor<128x128xf32>) {
+    ^bb0(%in : f32, %in2 : f32, %in3 : f32):
+      %mul = arith.mulf %in, %in2 : f32
+      %add = arith.addf %mul, %in3 : f32
+      linalg.yield %add : f32  // C(m, n) + A(m, k) * B(k, n)
+    } -> tensor<128x128xf32>
+    // %compute reads only loop-invariant tensors, so it is the hoisting candidate.
+    %newacc = tensor.insert_slice %compute into
+                                  %acc[%i, 0][128, 128][1, 1]
+                                  : tensor<128x128xf32> into tensor<?x128xf32>
+    scf.yield %newacc : tensor<?x128xf32>
+  }
+  // Insert into the loop-carried %acc so each iteration builds on the last.
+  func.return %final : tensor<?x128xf32>
+}
----------------
Groverkss wrote:
added
https://github.com/llvm/llvm-project/pull/108032
More information about the Mlir-commits mailing list