[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Oct 25 08:19:31 PDT 2023


================
@@ -427,12 +427,88 @@ struct TransposeOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.outerproduct.
+///
+/// If the vector.outerproduct is masked (and the mask from a
+/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
+/// operands.
+///
+/// Example:
+///
+///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+///   %result = vector.mask %mask {
+///                vector.outerproduct %vecA, %vecB
+///                 : vector<[4]xf32>, vector<[4]xf32>
+///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///
+/// is converted to:
+///
+///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
+///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
+///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+///                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+///
+struct VectorOuterProductToArmSMELowering
+    : public OpRewritePattern<vector::OuterProductOp> {
+
+  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
+                                PatternRewriter &rewriter) const override {
+    // AXPY operation not suited for SME.
+    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+      return outerProductOp.emitError("AXPY operations not supported");
----------------
MacDue wrote:

It does support it if you want it... You'd just have to mask out all but the first element of the LHS. We're choosing not to support it though.

https://github.com/llvm/llvm-project/pull/69604


More information about the Mlir-commits mailing list