[Mlir-commits] [mlir] [mlir][nfc] Update Linalg matmul -> Vector OP test (PR #81416)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 11 06:20:53 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Updates "transform-op-matmul-to-outerproduct.mlir". Summary:
  * refines TD sequence so it's easier to reason about the compilation
    pipeline,
  * new input dims to be able to distinguish parallel from reduction
    dims,
  * updates LIT variable names (makes the output easier to follow),
  * removes "noise" from the expected LIT output (e.g. types).

These Linalg -> Vector tests using Transform Dialect are great reference
points for constructing lowering pipelines. This simplification +
clean-up will hopefully make it easier to follow.


---
Full diff: https://github.com/llvm/llvm-project/pull/81416.diff


1 Files Affected:

- (modified) mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir (+40-27) 


``````````diff
diff --git a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
index ee66073a9a4193..a1a0c413da0c1c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir
@@ -1,38 +1,51 @@
 // RUN: mlir-opt %s -transform-interpreter | FileCheck %s
 
-func.func @outerproduct_matmul(%A: memref<3x3xf32>, %B: memref<3x3xf32>, %C: memref<3x3xf32>) {
-  linalg.matmul ins(%A, %B: memref<3x3xf32>, memref<3x3xf32>)
+func.func @matmul_to_outerproduct(%A: memref<3x4xf32>, %B: memref<4x3xf32>, %C: memref<3x3xf32>) {
+  linalg.matmul ins(%A, %B: memref<3x4xf32>, memref<4x3xf32>)
             outs(%C: memref<3x3xf32>)
   return
 }
 
-// CHECK-LABEL:   func.func @outerproduct_matmul(
-// CHECK-SAME:      %[[VAL_0:.*]]: memref<3x3xf32>, %[[VAL_1:.*]]: memref<3x3xf32>, %[[VAL_2:.*]]: memref<3x3xf32>) {
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x3xf32>, vector<3x3xf32>
-// CHECK:           %[[VAL_8:.*]] = vector.transpose %[[VAL_5]], [1, 0] : vector<3x3xf32> to vector<3x3xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.extract %[[VAL_8]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_11:.*]] = vector.outerproduct %[[VAL_9]], %[[VAL_10]], %[[VAL_7]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_8]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_6]][1] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_14:.*]] = vector.outerproduct %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK:           %[[VAL_15:.*]] = vector.extract %[[VAL_8]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_6]][2] : vector<3xf32> from vector<3x3xf32>
-// CHECK:           %[[VAL_17:.*]] = vector.outerproduct %[[VAL_15]], %[[VAL_16]], %[[VAL_14]] {kind = #vector.kind<add>} : vector<3xf32>, vector<3xf32>
-// CHECK:           vector.transfer_write %[[VAL_17]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x3xf32>, memref<3x3xf32>
-// CHECK:           return
-// CHECK:         }
+// CHECK-LABEL:   func.func @matmul_to_outerproduct(
+// CHECK-SAME:      %[[A:.*]]: memref<3x4xf32>,
+// CHECK-SAME:      %[[B:.*]]: memref<4x3xf32>,
+// CHECK-SAME:      %[[C:.*]]: memref<3x3xf32>) {
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]
+// CHECK:           %[[VEC_A_T:.*]] = vector.transpose %[[VEC_A]], [1, 0] : vector<3x4xf32> to vector<4x3xf32>
+// CHECK:           %[[A0:.*]] = vector.extract %[[VEC_A_T]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B0:.*]] = vector.extract %[[VEC_B]][0] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[OP_0:.*]] = vector.outerproduct %[[A0]], %[[B0]], %[[VEC_C]]
+// CHECK:           %[[A1:.*]] = vector.extract %[[VEC_A_T]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B1:.*]] = vector.extract %[[VEC_B]][1] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[OP_1:.*]] = vector.outerproduct %[[A1]], %[[B1]], %[[OP_0]]
+// CHECK:           %[[A_2:.*]] = vector.extract %[[VEC_A_T]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B_2:.*]] = vector.extract %[[VEC_B]][2] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[OP_2:.*]] = vector.outerproduct %[[A_2]], %[[B_2]], %[[OP_1]]
+// CHECK:           %[[A_3:.*]] = vector.extract %[[VEC_A_T]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[B_3:.*]] = vector.extract %[[VEC_B]][3] : vector<3xf32> from vector<4x3xf32>
+// CHECK:           %[[RES:.*]] = vector.outerproduct %[[A_3]], %[[B_3]], %[[OP_2]]
+// CHECK:           vector.transfer_write %[[RES]], %[[C]]{{.*}} : vector<3x3xf32>, memref<3x3xf32>
 
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
-    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %2 {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+
+    // Vectorize: linalg.matmul -> vector.multi_reduction
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %func : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %matmul : !transform.any_op
+
+    // vector.multi_reduction --> vector.contract
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This transforms vector.contract to be
+      // more matmul-like and enables the lowering to outer product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.any_op
+
+    // vector.contract --> vector.outerproduct
+    transform.apply_patterns to %func {
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
     } : !transform.any_op
     transform.yield

``````````

</details>


https://github.com/llvm/llvm-project/pull/81416


More information about the Mlir-commits mailing list