[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Oct 26 01:55:58 PDT 2023


================
@@ -427,12 +427,88 @@ struct TransposeOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.outerproduct.
+///
+/// If the vector.outerproduct is masked (and the mask is from a
+/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
+/// operands.
+///
+/// Example:
+///
+///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+///   %result = vector.mask %mask {
+///                vector.outerproduct %vecA, %vecB
+///                 : vector<[4]xf32>, vector<[4]xf32>
+///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///
+/// is converted to:
+///
+///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
+///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
+///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+///                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+///
+struct VectorOuterProductToArmSMELowering
+    : public OpRewritePattern<vector::OuterProductOp> {
+
+  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
+                                PatternRewriter &rewriter) const override {
+    // AXPY operation not suited for SME.
+    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+      return outerProductOp.emitError("AXPY operations not supported");
+
+    auto kind = outerProductOp.getKind();
+    if (kind != vector::CombiningKind::ADD)
+      return outerProductOp.emitError("unsupported kind");
+
+    Value lhsMask = {};
+    Value rhsMask = {};
+    Operation *rootOp = outerProductOp;
+    if (outerProductOp.isMasked()) {
+      auto maskingOp = outerProductOp.getMaskingOp();
+      rewriter.setInsertionPoint(maskingOp);
+      rootOp = maskingOp;
+
+      // Attempt to extract masks from vector.create_mask.
+      // TODO: Add support for other mask sources.
+      auto mask = maskingOp.getMask();
+      auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return failure();
+
+      auto maskType = createMaskOp.getVectorType();
+      if (maskType.getRank() != 2)
+        return failure();
+
+      auto loc = outerProductOp.getLoc();
+
+      Value lhsMaskDim = createMaskOp.getOperand(0);
+      Value rhsMaskDim = createMaskOp.getOperand(1);
+
+      VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
+      lhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
+                                                      lhsMaskDim);
+      rhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
+                                                      rhsMaskDim);
+    }
+
+    rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
+        rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
+        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
+
+    return success();
----------------
banach-space wrote:

Try to avoid big blocks of indentation: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code

```suggestion
    if (!outerProductOp.isMasked()) {
        rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
        rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());

        return success();
    }
     
    auto maskingOp = outerProductOp.getMaskingOp();
    rewriter.setInsertionPoint(maskingOp);
    rootOp = maskingOp;

    // Attempt to extract masks from vector.create_mask.
    // TODO: Add support for other mask sources.
    auto mask = maskingOp.getMask();
    auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    auto maskType = createMaskOp.getVectorType();
    if (maskType.getRank() != 2)
      return failure();

    auto loc = outerProductOp.getLoc();

    Value lhsMaskDim = createMaskOp.getOperand(0);
    Value rhsMaskDim = createMaskOp.getOperand(1);

    VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
    lhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
                                                    lhsMaskDim);
    rhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
                                                    rhsMaskDim);
  }
```

https://github.com/llvm/llvm-project/pull/69604


More information about the Mlir-commits mailing list