[Mlir-commits] [mlir] [MLIR][Linalg] Use Top-Down traversal to safely optimize multi-use producer fusion (PR #172216)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 18 14:45:16 PST 2025


================
@@ -32,3 +32,74 @@ func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 // CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
 //      CHECK:   %[[RESULT:.+]]:3 = linalg.generic
 //      CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
+
+func.func @multi_use_producer_2(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
+  %0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
+  %1 = llvm.mlir.constant(31 : index) : i64
+  %2 = tensor.empty() : tensor<1x32x32x8xf32>
+  %3 = tensor.empty() : tensor<1x32x32x8xindex>
+  %4:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0 : tensor<1x32x32x8xf32>) 
+  outs(%2, %3 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
+    ^bb0(%in: f32, %out: f32, %out_0: index):
+      %9 = linalg.index 1 : index
+      linalg.yield %0, %9 : f32, index
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
+
+  %5 = tensor.empty() : tensor<1x32x32x8xi64>
+  %6:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0, %4#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) 
+  outs(%2, %5 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
+      %9 = builtin.unrealized_conversion_cast %in_0 : index to i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+
+  %7 = tensor.empty() : tensor<1x32x32x8xi64>
+  %8:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+    ], 
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } 
+  ins(%arg0, %4#1, %6#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>) 
+  outs(%2, %7 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
+    ^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
+      %9 = llvm.sub %1, %in_1 : i64
+      linalg.yield %0, %9 : f32, i64
+  } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+  return %8#0 : tensor<1x32x32x8xf32>
+}
+// CHECK-LABEL: func @multi_use_producer_2(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x32x32x8xf32>)
+// CHECK-SAME: -> tensor<1x32x32x8xf32>
+// CHECK: %[[C31:.+]] = llvm.mlir.constant(31 : index) : i64
+// CHECK: %[[R0:.+]]:2 = linalg.generic {
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]], %[[ARG0]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>)
+// CHECK-SAME: outs(%[[INIT:.+]], %[[INIT_1:.+]] : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[IN_2:.+]]: f32, %[[IN_3:.+]]: f32, %[[IN_4:.+]]: f32, %[[OUT:.+]]: f32, %[[OUT_I:.+]]: i64):
+// CHECK: %[[IDX_9:.+]] = linalg.index 1 : index
+// CHECK: %[[C_9:.+]] = builtin.unrealized_conversion_cast %[[IDX_9]] : index to i64
+// CHECK: %[[C_SUB:.+]] = llvm.sub %[[C31]], %[[C_9]] : i64
+// CHECK: linalg.yield %[[C0:.+]], %[[C_SUB]] : f32, i64
+// CHECK: } -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
+// CHECK: return %[[R0]]#0 : tensor<1x32x32x8xf32>
----------------
MaheshRavishankar wrote:

Nit: Add a newline at the end of the file.

https://github.com/llvm/llvm-project/pull/172216


More information about the Mlir-commits mailing list