[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)

Rolf Morel llvmlistbot at llvm.org
Wed Oct 2 13:02:51 PDT 2024


================
@@ -0,0 +1,291 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_two_matmuls(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_two_matmuls
+// CHECK: %[[ACC:.+]] = linalg.matmul
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins({{.+}}) outs(%[[ACC]]
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_matmul_and_func_arg(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_matmul_and_func_arg
+// CHECK: %[[RES:.+]] = linalg.matmul
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_transposed_matmuls(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_transposed_matmuls
+// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul_transpose_b ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
+  return %4 : !type
+}
+
+
+// CHECK-LABEL: func.func @expect_no_fold_as_operands_do_not_dominate_each_other
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul_transpose_b
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
+  %0 = arith.constant dense<1.111111e+00> : !type
+  %cst = arith.constant 0.000000e+00 : f32
+  %1 = tensor.empty() : !type
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+  %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+  %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+  %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+  return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_dominated_op_is_not_a_contraction
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.sub
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.linalg.fold_add_into_dest
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {
----------------
rolfmorel wrote:

I have made the function names more specific and have now tried to group them as you suggest.

Note that previously I had them grouped as "interesting cases that ought to be working" and a group of "a test for each early exit condition". I guess it's a matter of personal preference which grouping is more helpful.

https://github.com/llvm/llvm-project/pull/110514


More information about the Mlir-commits mailing list