[Mlir-commits] [mlir] [mlir][ArmSME] Support widening outer products (PR #78975)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Jan 23 10:00:29 PST 2024


================
@@ -122,4 +122,43 @@ def TileAllocation
   let dependentDialects = ["func::FuncDialect"];
 }
 
+def OuterProductWidening
+    : Pass<"arm-sme-outer-product-widening", "mlir::func::FuncOp"> {
+  let summary = "Fold 'arm_sme.outerproduct' operations into widening variants";
+  let description = [{
+    This pass folds 'arm_sme.outerproduct' operations that are chained via the
+    accumulator into 2-way or 4-way ArmSME outer product operations.
+
+    For example:
+    ```mlir
+    %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+    %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+    %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+    %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+    ```
+
+    Becomes:
+
+    ```mlir
+    %undef = llvm.mlir.undef : vector<[8]xf16>
+    %a0_ins = vector.scalable.ins %a0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %a1_ins = vector.scalable.ins %a1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+    %b0_ins = vector.scalable.ins %b0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %b1_ins = vector.scalable.ins %b1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
----------------
MacDue wrote:

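Here the example should be inserting the non-extended vectors (i.e. %a0, %a1, %b0, %b1), not the `_ext` values: the inserts are typed `vector<[4]xf16>`, which is the type of the inputs before `arith.extf`.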
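
A sketch of how the four insert lines might read with the original operands, assuming the rest of the example stays as written (the op spelling and types are copied from the hunk above, only the operands change):

```mlir
// Sketch only: insert the original f16 vectors rather than the f32 extensions.
%a0_ins = vector.scalable.ins %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
%a1_ins = vector.scalable.ins %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
%b0_ins = vector.scalable.ins %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
%b1_ins = vector.scalable.ins %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
```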

https://github.com/llvm/llvm-project/pull/78975

