[Mlir-commits] [mlir] [mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA (PR #65621)

Diego Caballero llvmlistbot at llvm.org
Sun Sep 10 23:14:56 PDT 2023


================
@@ -361,6 +361,111 @@ struct MoveVectorToTileSliceToArmSMELowering
   }
 };
 
+/// Lower `vector.outerproduct` to SME MOPA intrinsics.
+///
+/// Example:
+///
+///   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
+///     : vector<[4]xf32>, vector<[4]xf32>
+///
+/// is converted to:
+///
+///   "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
+///     : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
+///        vector<[4]xf32>) -> ()
+///
+/// Currently only supports FMOPA and BFMOPA (non-widening).
+struct VectorOuterProductToArmSMELowering
+    : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
+  using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::OuterProductOp outerProductOp,
+                  vector::OuterProductOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto isSupportedType = [](VectorType vectorType) {
+      // TODO: the FP outer product instruction variants are predicated on
+      // different features:
+      //
+      // * FMOPA (non-widening)
+      //   * half-precision   - +sme2p1,+sme-f16f16
+      //   * single-precision - +sme
+      //   * double-precision - +sme-f64f64
+      // * BFMOPA
+      //   * half-precision   - +sme2p1,+b16b16
+      //
+      // It should be possible to control lowering based on target features.
+      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
+        return false;
+
+      auto elementType = vectorType.getElementType();
+
+      if (!elementType.isF16() && !elementType.isBF16() &&
+          !elementType.isF32() && !elementType.isF64())
+        return false;
+
+      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
+                            vectorType.getElementTypeBitWidth();
+      if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
+        return false;
+
+      return true;
+    };
+
+    auto resultVectorType = outerProductOp.getResultVectorType();
+    if (!isSupportedType(resultVectorType))
+      return outerProductOp.emitError("unsupported type");
+
+    vector::CombiningKind kind = outerProductOp.getKind();
+    if (kind != vector::CombiningKind::ADD)
+      // TODO: support subtract.
+      return outerProductOp.emitError("unsupported kind");
+
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
+    if (maskableOp.isMasked())
+      // TODO: support masking.
+      return outerProductOp.emitError("masking is currently unsupported");
+
+    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+      // AXPY operation not suited for SME.
+      return failure();
+
+    auto loc = outerProductOp.getLoc();
+
+    Value acc = outerProductOp.getAcc();
+    if (!acc)
+      // Initalize accumulator with zero.
+      acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+
+    unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
+    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
+        loc, rewriter.getIntegerType(elementWidth), acc);
+
+    // Create all active predicate mask.
+    auto one = rewriter.create<arith::ConstantOp>(
----------------
dcaballe wrote:

We should use a `vector::ConstantMaskOp` here, instead of an `arith::ConstantOp`, to create the all-active predicate mask.

https://github.com/llvm/llvm-project/pull/65621


More information about the Mlir-commits mailing list