[Mlir-commits] [mlir] [MLIR][Linalg] Pattern to fold AddOp to accumulation via contraction op's dest (PR #110514)
Rolf Morel
llvmlistbot at llvm.org
Wed Oct 2 13:02:51 PDT 2024
================
@@ -0,0 +1,291 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_two_matmuls(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_two_matmuls
+// CHECK: %[[ACC:.+]] = linalg.matmul
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul ins({{.+}}) outs(%[[ACC]]
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_matmul_and_func_arg(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %arg1 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_matmul_and_func_arg
+// CHECK: %[[RES:.+]] = linalg.matmul
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @fold_add_on_transposed_matmuls(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul_transpose_a ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.matmul_transpose_b ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @fold_add_on_transposed_matmuls
+// CHECK: %[[ACC:.+]] = linalg.matmul_transpose_a
+// CHECK-NEXT: %[[RES:.+]] = linalg.matmul_transpose_b ins({{.+}}) outs(%[[ACC]]
+// CHECK-NOT: linalg.add
+// CHECK-NEXT: return %[[RES]]
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_operands_do_not_dominate_each_other(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul_transpose_b ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.add ins(%3, %3 : !type, !type) outs(%1 : !type) -> !type
+ return %4 : !type
+}
+
+
+// CHECK-LABEL: func.func @expect_no_fold_as_operands_do_not_dominate_each_other
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul_transpose_b
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_dominated_op_is_not_a_contraction(%arg0: !type, %arg1: !type) -> !type {
+ %0 = arith.constant dense<1.111111e+00> : !type
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = tensor.empty() : !type
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : !type) -> !type
+ %3 = linalg.matmul ins(%arg0, %0 : !type, !type) outs(%2 : !type) -> !type
+ %4 = linalg.sub ins(%arg1, %0 : !type, !type) outs(%2 : !type) -> !type
+ %5 = linalg.add ins(%3, %4 : !type, !type) outs(%1 : !type) -> !type
+ return %5 : !type
+}
+
+// CHECK-LABEL: func.func @expect_no_fold_as_dominated_op_is_not_a_contraction
+// CHECK: linalg.fill
+// CHECK-NEXT: linalg.matmul
+// CHECK-NEXT: linalg.sub
+// CHECK-NEXT: linalg.add
+// CHECK-NEXT: return
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.linalg.fold_add_into_dest
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!type = tensor<2048x2048xf32>
+func.func @expect_no_fold_as_orig_dest_not_additive_zero(%arg0: !type, %arg1: !type) -> !type {
----------------
rolfmorel wrote:
I have made the function names more specific and have now tried to group them as you suggest.
Note that previously I had them grouped into "interesting cases that ought to be working" and "a test for each early exit condition". I guess it's a matter of personal preference which grouping is more helpful.
https://github.com/llvm/llvm-project/pull/110514
More information about the Mlir-commits
mailing list