[Mlir-commits] [mlir] [mlir][ArmSME] Update `OuterProductFusion` to account for recent changes (PR #102125)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Aug 6 04:35:21 PDT 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/102125

- Use vector.interleave rather than the LLVM intrinsic
- Remove dependency on LLVM dialect
- Remove manual outerproduct erases (these are now trivially dead)
- Remove comment explaining issues with previous tile allocator
- Update pipeline in `multi-tile-matmul-mixed-types.mlir`

From 44452f9e3aa9ce5d569d4f6df078db469b1238ba Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 6 Aug 2024 11:15:40 +0000
Subject: [PATCH] [mlir][ArmSME] Update `OuterProductFusion` to account for
 recent changes

- Use vector.interleave rather than the LLVM intrinsic
- Remove dependency on LLVM dialect
- Remove manual outerproduct erases (these are now trivially dead)
- Remove comment explaining issues with previous tile allocator
- Update pipeline in `multi-tile-matmul-mixed-types.mlir`
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 20 ++++----
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  2 +-
 .../Dialect/ArmSME/Transforms/CMakeLists.txt  |  1 -
 .../ArmSME/Transforms/OuterProductFusion.cpp  | 49 ++-----------------
 .../Dialect/ArmSME/outer-product-fusion.mlir  | 32 ++++++------
 .../ArmSME/multi-tile-matmul-mixed-types.mlir |  6 +--
 6 files changed, 33 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 9178655f010c9..3f1776f57e4c7 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -910,11 +910,11 @@ def FMopa2WayOp
     The 2 outer products in the example above can be fused into a single outer
     product as follows:
 
-	```mlir
-    %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-    %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+    ```mlir
+    %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
+    %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
     %0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
-	```
+    ```
 
     This is implemented in the `-arm-sme-outer-product-fusion` pass.
 
@@ -1167,13 +1167,13 @@ def SMopa4WayOp
     product as follows:
 
     ```mlir
-    %lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-    %lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-    %lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+    %lhs0 = vector.interleave %a0, %a2 : vector<[4]xi8> -> vector<[8]xi8>
+    %lhs1 = vector.interleave %a1, %a3 : vector<[4]xi8> -> vector<[8]xi8>
+    %lhs = vector.interleave %lhs0, %lhs1 : vector<[8]xi8> -> vector<[16]xi8>
 
-    %rhs0 = "llvm.intr.experimental.vector.interleave2"(%b0, %b2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-    %rhs1 = "llvm.intr.experimental.vector.interleave2"(%b1, %b3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-    %rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+    %rhs0 = vector.interleave %b0, %b2 : vector<[4]xi8> -> vector<[8]xi8>
+    %rhs1 = vector.interleave %b1, %b3 : vector<[4]xi8> -> vector<[8]xi8>
+    %rhs = vector.interleave %rhs0, %rhs1 : vector<[8]xi8> -> vector<[16]xi8>
 
     %0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
     ```
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 921234daad1f1..45efabf5fe1b4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -180,7 +180,7 @@ def OuterProductFusion
     https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
   }];
   let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
-  let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"];
+  let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
 }
 
 def VectorLegalization
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 8f9b5080e82db..a29624468ba2d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRPass
   MLIRArmSMEDialect
   MLIRFuncDialect
-  MLIRLLVMCommonConversion
   MLIRVectorDialect
   MLIRIndexDialect
   MLIRSCFDialect
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 1e711678dc9ab..ee1e374b25b04 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -80,15 +79,6 @@ static LogicalResult isCompatible(PatternRewriter &rewriter,
   return success();
 }
 
-// Create 'llvm.experimental.vector.interleave2' intrinsic from `lhs` and `rhs`.
-static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
-                                        Value lhs, Value rhs) {
-  auto inputType = cast<VectorType>(lhs.getType());
-  VectorType inputTypeX2 =
-      VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
-  return rewriter.create<LLVM::vector_interleave2>(loc, inputTypeX2, lhs, rhs);
-}
-
 // Fuse two 'arm_sme.outerproduct' operations that are chained via the
 // accumulator into 2-way outer product operation.
 //
@@ -106,10 +96,8 @@ static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
 //
 // Becomes:
 //
-//  %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
-//    : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-//  %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
-//    : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+//  %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
+//  %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
 //  %0 = arm_sme.fmopa_2way %a_packed, %b_packed
 //    : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
 class OuterProductFusion2Way
@@ -135,28 +123,7 @@ class OuterProductFusion2Way
 
     if (!op1->hasOneUse()) {
       // If the first outer product has uses other than as the input to another
-      // outer product, it can't be erased after fusion. This is a problem when
-      // it also has an accumulator as this will be used as the root for tile
-      // allocation and since the widening outer product uses the same
-      // accumulator it will get assigned the same tile ID, resulting in 3
-      // outer products accumulating to the same tile and incorrect results.
-      //
-      // Example:
-      //
-      //  %acc = arith.constant dense<0.0> ; root for tile allocation
-      //  %0 = arm_sme.outerproduct %a0, %b0 acc(%acc)
-      //  vector.print %0                  ; intermediary use, can't erase %0
-      //  %1 = arm_sme.outerproduct %a1, %b1 acc(%0)
-      //
-      // After fusion and tile allocation
-      //
-      //  %0 = arm_sme.zero {tile_id = 0 : i32}
-      //  %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32}
-      //  vector.print %1
-      //  %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32}
-      //
-      // No accumulator would be ok, but it's simpler to prevent this
-      // altogether, since it has no benefit.
+      // outer product, it can't be erased after fusion.
       return rewriter.notifyMatchFailure(op,
                                          kMatchFailureOuterProductNotSingleUse);
     }
@@ -169,7 +136,7 @@ class OuterProductFusion2Way
 
     auto loc = op.getLoc();
     auto packInputs = [&](Value lhs, Value rhs) {
-      return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
+      return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
     };
 
     auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -226,8 +193,6 @@ class OuterProductFusion2Way
       llvm_unreachable("unexpected arm_sme::CombiningKind!");
     }
 
-    rewriter.eraseOp(op1);
-
     return success();
   }
 
@@ -319,7 +284,7 @@ class OuterProductFusion4Way
 
     auto loc = op.getLoc();
     auto packInputs = [&](Value lhs, Value rhs) {
-      return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
+      return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs);
     };
 
     auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -400,10 +365,6 @@ class OuterProductFusion4Way
       llvm_unreachable("unexpected arm_sme::CombiningKind!");
     }
 
-    rewriter.eraseOp(op3);
-    rewriter.eraseOp(op2);
-    rewriter.eraseOp(op1);
-
     return success();
   }
 
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 4887d611643fb..9000551783576 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -4,10 +4,10 @@
 // CHECK-SAME:    %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
 // CHECK-SAME:    %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>
 // CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32>
-// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.vector.interleave2"(%[[A0]], %[[A1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.vector.interleave2"(%[[B0]], %[[B1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A0_MASK]], %[[A1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B0_MASK]], %[[B1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[A0]], %[[A1]] : vector<[4]xf16> -> vector<[8]xf16>
+// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[B0]], %[[B1]] : vector<[4]xf16> -> vector<[8]xf16>
+// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B1_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
 // CHECK-DAG: arm_sme.fmopa_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
 func.func @outerproduct_add_widening_2way_f16f16f32(
     %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
@@ -225,18 +225,18 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
 // CHECK-SAME:    %[[A2_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B2_MASK:[a-z0-9]+]]: vector<[4]xi1>,
 // CHECK-SAME:    %[[A3_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B3_MASK:[a-z0-9]+]]: vector<[4]xi1>
 // CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32>
-// CHECK-DAG: %[[LHS0:.*]] = "llvm.intr.vector.interleave2"(%[[A0]], %[[A2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[LHS1:.*]] = "llvm.intr.vector.interleave2"(%[[A1]], %[[A3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[RHS0:.*]] = "llvm.intr.vector.interleave2"(%[[B0]], %[[B2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[RHS1:.*]] = "llvm.intr.vector.interleave2"(%[[B1]], %[[B3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
-// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.vector.interleave2"(%[[LHS0]], %[[LHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
-// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.vector.interleave2"(%[[RHS0]], %[[RHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
-// CHECK-DAG: %[[LHS0_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A0_MASK]], %[[A2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[LHS1_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[A1_MASK]], %[[A3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS0_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B0_MASK]], %[[B2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS1_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[B1_MASK]], %[[B3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[LHS0_MASK]], %[[LHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
-// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.vector.interleave2"(%[[RHS0_MASK]], %[[RHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
+// CHECK-DAG: %[[LHS0:.*]] = vector.interleave %[[A0]], %[[A2]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS1:.*]] = vector.interleave %[[A1]], %[[A3]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS0:.*]] = vector.interleave %[[B0]], %[[B2]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS1:.*]] = vector.interleave %[[B1]], %[[B3]] : vector<[4]xi8> -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS:.*]] = vector.interleave %[[LHS0]], %[[LHS1]] : vector<[8]xi8> -> vector<[16]xi8>
+// CHECK-DAG: %[[RHS:.*]] = vector.interleave %[[RHS0]], %[[RHS1]] : vector<[8]xi8> -> vector<[16]xi8>
+// CHECK-DAG: %[[LHS0_MASK:.*]] = vector.interleave %[[A0_MASK]], %[[A2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS1_MASK:.*]] = vector.interleave %[[A1_MASK]], %[[A3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS0_MASK:.*]] = vector.interleave %[[B0_MASK]], %[[B2_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS1_MASK:.*]] = vector.interleave %[[B1_MASK]], %[[B3_MASK]] : vector<[4]xi1> -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS_MASK:.*]] = vector.interleave %[[LHS0_MASK]], %[[LHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = vector.interleave %[[RHS0_MASK]], %[[RHS1_MASK]] : vector<[8]xi1> -> vector<[16]xi1>
 // CHECK-DAG: arm_sme.smopa_4way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
 func.func @outerproduct_add_widening_4way_signed_i8i8i32(
     %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
index aabd9d2ce788e..5784ecbbe4014 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
@@ -1,11 +1,7 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule  \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN:   -arm-sme-vector-legalization -canonicalize -cse \
-// RUN:   -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
-// RUN:   -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
-// RUN:   -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
+// RUN:   -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
 // RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \



More information about the Mlir-commits mailing list