[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Oct 26 01:55:59 PDT 2023


================
@@ -515,4 +536,93 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
+class HasMatchingMaskTypeConstraint<string operand> :
+  OptionalTypesMatchWith<
+    "shape of `" # operand #  "Mask` matches `" # operand # "`",
+    operand, operand # "Mask",
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
+
+def OuterProductOp :
+  ArmSME_Op<"outerproduct", [Pure,
+    AttrSizedOperandSegments,
+    AllElementTypesMatch<["lhs", "rhs", "result"]>,
+    HasMatchingMaskTypeConstraint<"lhs">,
+    HasMatchingMaskTypeConstraint<"rhs">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
+    PredOpTrait<
+      "result type is derived from `lhs` and `rhs`",
+      CPred<
+        "getResultType() == VectorType::get({"
+          "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
+          "getRhsType().getElementType(),"
+          "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
+    OptionalTypesMatchWith<"`result` and `acc` have the same type",
+      "result", "acc",
+      "::llvm::cast<mlir::Type>($_self)">
+  ]>
+{
+  let summary = "Vector outerproduct with optional fused add";
+
+  let description = [{
+    This op is based on `vector.outerproduct` with the extra conditions that:
+
+    * AXPY operations are not supported
+    * The only combining functions are "add" and "sub"
+    * Masking is performed on the inputs (rather than the output)
+
+    This is meant as an intermediate op for lowering `vector.outerproduct` to
+    SME. Types are not restricted to SVE/SME vectors at this level.
+
+    Example 1: Unmasked outerproduct (without accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 2: Unmasked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs acc(%accumulator)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 3: Masked outerproduct
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs masks(%lhsMask, %rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 4: Masked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs acc(%accumulator) masks(%lhsMask, %rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    VectorOfRank<[1]>:$lhs, VectorOfRank<[1]>:$rhs,
+    Optional<VectorOfRankAndType<[1],[I1]>>:$lhsMask,
+    Optional<VectorOfRankAndType<[1],[I1]>>:$rhsMask,
+    Optional<VectorOfRank<[2]>>:$acc,
----------------
banach-space wrote:

Why not restrict this to an SME tile type?
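For reference, a restriction to SME tile types would amount roughly to the check below. This is a standalone C++ model, not MLIR code. `VecTy`, `deriveOuterProductType` and `isSMETileType` are hypothetical names. Element types are reduced to a bit width, so f32 and i32 are treated alike. The model also assumes that an SME tile is `[N]x[N]` with N = 128 / element width (e.g. `vector<[4]x[4]xf32>`, `vector<[16]x[16]xi8>`).

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for mlir::VectorType: shape, per-dimension
// scalability flags, and element width in bits.
struct VecTy {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  unsigned elemBits;
};

// Mirrors the "result type is derived from `lhs` and `rhs`" predicate:
// vector<[M]xT>, vector<[N]xT> -> vector<[M]x[N]xT>.
std::optional<VecTy> deriveOuterProductType(const VecTy &lhs, const VecTy &rhs) {
  assert(lhs.shape.size() == lhs.scalableDims.size());
  assert(rhs.shape.size() == rhs.scalableDims.size());
  if (lhs.shape.size() != 1 || rhs.shape.size() != 1 ||
      lhs.elemBits != rhs.elemBits)
    return std::nullopt;
  return VecTy{{lhs.shape[0], rhs.shape[0]},
               {lhs.scalableDims[0], rhs.scalableDims[0]},
               lhs.elemBits};
}

// An SME tile type is 2-D, scalable in both dimensions, with N x N elements
// where N = 128 / element width (one 128-bit granule per row at vscale 1).
bool isSMETileType(const VecTy &t) {
  if (t.shape.size() != 2 || t.scalableDims.size() != 2 ||
      !t.scalableDims[0] || !t.scalableDims[1])
    return false;
  switch (t.elemBits) {
  case 8: case 16: case 32: case 64: case 128:
    break;
  default:
    return false;
  }
  const int64_t n = 128 / t.elemBits;
  return t.shape[0] == n && t.shape[1] == n;
}
```

With this model, `vector<[4]xf32>` and `vector<[4]xf32>` yield `vector<[4]x[4]xf32>`, which passes the tile check. A fixed-length `vector<4x4xf32>` or an oversized `vector<[8]x[8]xf32>` would be rejected.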

https://github.com/llvm/llvm-project/pull/69604

