[Mlir-commits] [mlir] [mlir][ArmSME] Support widening outer products (PR #78975)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Jan 23 10:00:29 PST 2024
================
@@ -122,4 +122,43 @@ def TileAllocation
let dependentDialects = ["func::FuncDialect"];
}
+def OuterProductWidening
+ : Pass<"arm-sme-outer-product-widening", "mlir::func::FuncOp"> {
+ let summary = "Fold 'arm_sme.outerproduct' operations into widening variants";
+ let description = [{
+ This pass folds 'arm_sme.outerproduct' operations that are chained via the
+ accumulator into 2-way or 4-way ArmSME outer product operations.
+
+ For example:
+ ```mlir
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+ ```
+
+ Becomes:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[8]xf16>
+ %a0_ins = vector.scalable.ins %a0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a1_ins = vector.scalable.ins %a1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %b0_ins = vector.scalable.ins %b0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b1_ins = vector.scalable.ins %b1_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
----------------
MacDue wrote:
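The example should insert the non-extended vectors here (i.e. `%a0`, `%a1`, `%b0`, `%b1`): the insert ops are typed `vector<[4]xf16> into vector<[8]xf16>`, whereas the `_ext` values are `vector<[4]xf32>`.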
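
As a rough sketch of the suggested change (not the final wording of the pass description), the four insertions would take the original f16 values instead of the extended ones. The op spelling `vector.scalable.ins`, the `%undef` operand, and the types are copied from the diff as written and are assumed to stay unchanged:

```mlir
// Insert the original (non-extended) f16 vectors, so the operand type
// matches the vector<[4]xf16> source type in each op's signature.
%a0_ins = vector.scalable.ins %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
%a1_ins = vector.scalable.ins %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
%b0_ins = vector.scalable.ins %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
%b1_ins = vector.scalable.ins %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
```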
https://github.com/llvm/llvm-project/pull/78975