[all-commits] [llvm/llvm-project] e66629: [mlir][ArmSME] Support lowering masked vector.oute...

Benjamin Maxwell via All-commits all-commits at lists.llvm.org
Tue Oct 31 02:06:34 PDT 2023


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: e66629501189b15f24810b7a65cd814f76a25e23
      https://github.com/llvm/llvm-project/commit/e66629501189b15f24810b7a65cd814f76a25e23
  Author: Benjamin Maxwell <benjamin.maxwell at arm.com>
  Date:   2023-10-31 (Tue, 31 Oct 2023)

  Changed paths:
    M mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
    M mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    M mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    M mlir/test/Dialect/ArmSME/invalid.mlir
    M mlir/test/Dialect/ArmSME/roundtrip.mlir
    M mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
    M mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
    M mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
    M mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir

  Log Message:
  -----------
  [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (#69604)

This patch adds support for lowering masked outer products to SME. This
is done in two stages. First, vector.outerproduct ops (both masked and
non-masked) are rewritten to arm_sme.outerproduct ops. The
arm_sme.outerproduct op is close to vector.outerproduct, but supports
masking on the operands rather than the result. It also restricts itself
to the cases that can be lowered directly to SME.

The rewrite currently requires the mask to be defined by a
vector.create_mask op. E.g.:

```mlir
%mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
%result = vector.mask %mask {
             vector.outerproduct %vecA, %vecB
              : vector<[4]xf32>, vector<[4]xf32>
          } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```
This is rewritten to:
```mlir
%maskA = vector.create_mask %dimA : vector<[4]xi1>
%maskB = vector.create_mask %dimB : vector<[4]xi1>
%result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
              : vector<[4]xf32>, vector<[4]xf32>
```
(The same rewrite also applies to non-masked vector.outerproduct ops.)
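
Why the 2-D mask can be split into two 1-D operand masks can be checked
with a small model. The sketch below is illustrative Python, not code from
the patch: the helper names are hypothetical, and the scalable vector sizes
are modelled as fixed lengths (as if vscale were 1). It assumes create_mask
activates lane i exactly when i is below the given bound.

```python
def create_mask_1d(dim, lanes):
    """Model of a 1-D create_mask: lane i is active iff i < dim."""
    return [i < dim for i in range(lanes)]


def create_mask_2d(dim_a, dim_b, rows, cols):
    """Model of a 2-D create_mask: (i, j) is active iff i < dim_a and j < dim_b."""
    return [[i < dim_a and j < dim_b for j in range(cols)] for i in range(rows)]


def masks_agree(dim_a, dim_b, rows=4, cols=4):
    """Check that combining two operand masks reproduces the 2-D result mask."""
    mask_a = create_mask_1d(dim_a, rows)
    mask_b = create_mask_1d(dim_b, cols)
    # Element (i, j) of the outer product is computed only when both the
    # row lane and the column lane are active.
    combined = [[a and b for b in mask_b] for a in mask_a]
    return combined == create_mask_2d(dim_a, dim_b, rows, cols)
```

Under these assumptions, masks_agree(dim_a, dim_b) holds for any pair of
bounds, including bounds below zero or above the vector length.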

Second, the arm_sme.outerproduct op is lowered directly to SME intrinsics.



