[Mlir-commits] [mlir] [mlir][ArmSME] Support widening outer products (PR #78975)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Jan 23 05:35:30 PST 2024


================
@@ -814,6 +814,649 @@ let arguments = (ins
   }];
 }
 
+class OuterProductWideBase<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes,
+                           int numOuterProducts> :
+  ArmSME_Op<mnemonic, [
+    ArmSMETileOpInterface,
+    AttrSizedOperandSegments,
+    AllTypesMatch<["lhs", "rhs"]>,
+    HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+    HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+    >,
+    OptionalTypesMatchWith<"result and acc have the same type",
+                           "result", "acc", "::llvm::cast<Type>($_self)">,
+    // This trait ensures the input type matches the correct output type for
+    // ops that take multiple inputs and outputs (i.e., 4-way).
+    PredOpTrait<
+      "tile element size equals lhs element size * " # numOuterProducts,
+      CPred<"getTileType().getElementTypeBitWidth() == "
+            "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+    >,
+  ]> {
+
+  let arguments = (ins
+    AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+    Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+    Optional<AnyVector>:$acc);
+  let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs
+    oilist(
+        `acc` `` `(` $acc `)`
+      | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+    VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+    VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      // The outerproduct op allocates a new tile if no accumulator is passed.
+      if (!getAcc())
+        return arm_sme::getSMETileType(getResultType());
+      return std::nullopt;
+    }
+    VectorType getTileType() {
+      return getResultType();
+    }
+  }];
+}
+
+class OuterProductWide2Way<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes>
+  : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+                         allowedResultVectorTypes, /*numOuterProducts=*/2>;
+
+class OuterProductWide4Way<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes>
+  : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+                         allowedResultVectorTypes, /*numOuterProducts=*/4>;
+
+def FMopaWide2WayOp
+  : OuterProductWide2Way<"fmopa_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+      [nxnxv4f32]> {
+  let summary = "Floating-point sum of 2 outer products and accumulate";
+
+  let description = [{
+    This operation represents a sum of 2 widened outer products. It takes 2 1-D
+    scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
----------------
banach-space wrote:

```suggestion
    This operation represents a sum of 2 widening outer products. It takes 2 x 1D
    scalable vectors as input and a 2D scalable vector (ZA tile) as output.
```

https://github.com/llvm/llvm-project/pull/78975


More information about the Mlir-commits mailing list