[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Oct 19 11:05:47 PDT 2023


================
@@ -507,4 +524,95 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
+class HasMatchingMaskTypeConstraint<string operand, string maskGetter> :  // Op trait tying the type of `<operand>Mask` to the type of `<operand>`.
+  TypesMatchWith<  // Built on TypesMatchWith; arguments below are presumably (summary, lhs, rhs, transformer, comparator) -- confirm against OpBase.td.
+    "shape of `" # operand #  "Mask` matches `" # operand # "`",  // Summary text, assembled from the operand name.
+    operand, operand # "Mask",  // Compared pair: `operand` and the name formed by appending "Mask" to it.
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))",  // Expected mask type: the operand's VectorType cloned with a 1-bit integer element type; the empty first argument presumably keeps the shape -- confirm.
+    "!" # maskGetter # "() || std::equal_to<>()">;  // Comparator text: `!maskGetter() ||` is placed before the equality functor, presumably so the check passes when the mask is absent.
+
+def OuterProductOp :
+  ArmSME_Op<"outerproduct", [Pure,
+    AttrSizedOperandSegments,
+    AllElementTypesMatch<["lhs", "rhs", "result"]>,
+    HasMatchingMaskTypeConstraint<"lhs", "getLhsMask">,
+    HasMatchingMaskTypeConstraint<"rhs", "getRhsMask">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
+    PredOpTrait<
+      "result type is derived from `lhs` and `rhs`",
+      CPred<
+        "getResultType() == VectorType::get({"
+          "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
+          "getRhsType().getElementType(),"
+          "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
+    TypesMatchWith<"`result` and `acc` have the same type",
+      "result", "acc",
+      "::llvm::cast<mlir::Type>($_self)",
+      "!getAcc() || std::equal_to<>()">
----------------
MacDue wrote:

...and here

https://github.com/llvm/llvm-project/pull/69604


More information about the Mlir-commits mailing list