[Mlir-commits] [mlir] [mlir][ArmSME] Add initial SME vector legalization pass (PR #79152)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 23 07:15:04 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This adds a new pass (`-arm-sme-vector-legalization`) which legalizes vector operations so that they can be lowered to ArmSME. This initial patch adds decomposition for `vector.outerproduct`, `vector.transfer_read`, and `vector.transfer_write` when they operate on vector types larger than a single SME tile. For example, a [8]x[8]xf32 outer product would be decomposed into four [4]x[4]xf32 outer products, which could then be lowered to ArmSME. These three ops were chosen because supporting them alone is enough to lower matmuls that use all ZA accumulators to ArmSME.

To be legalized, a vector type must be a multiple of the SME tile size in each dimension; other than that, any shape can be used. E.g. `vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, and `vector<[16]x[4]xf32>` can all be lowered to four `vector<[4]x[4]xf32>` operations.

In the future, this pass will be extended with more SME-specific rewrites, such as legalizing the unrolling of the reduction dimension of matmuls (which is not a type decomposition); this is why the pass has quite a general name.

---

Patch is 50.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79152.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+3) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+19) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+7) 
- (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+22) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (+4-1) 
- (added) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+308) 
- (added) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+268) 
- (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir (+109) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir (+171) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index aef2959265a7cd..9ba8c43551257b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,6 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
 /// Pass that allocates tile IDs to ArmSME operations.
 std::unique_ptr<Pass> createTileAllocationPass();
 
+/// Pass that legalizes vectors so they can be lowered to ArmSME, e.g. by
+/// decomposing operations on vector types larger than a single SME tile into
+/// SME tile-sized operations.
+std::unique_ptr<Pass> createVectorLegalizationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8d1ba6ed34e805..9a6f5446de0094 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -122,4 +122,23 @@ def TileAllocation
   let dependentDialects = ["func::FuncDialect"];
 }
 
+// Registered as `-arm-sme-vector-legalization`; runs on `builtin.module`
+// (`mlir::ModuleOp`).
+def VectorLegalization
+  : Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
+  let summary = "Legalize vectors for ArmSME";
+  let description = [{
+    This pass legalizes vector operations so that they can be lowered to ArmSME.
+    This includes decomposing operations that operate on vector types larger
+    than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
+    tile-sized operations, as well as rewrites needed to get operations into
+    forms compatible with SME lowerings.
+  }];
+  let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
+  let dependentDialects = [
+    "func::FuncDialect",
+    "arm_sme::ArmSMEDialect",
+    "vector::VectorDialect",
+    "arith::ArithDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 41702008ee48fb..e89be8ed81e03f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -56,6 +56,13 @@ scf::ForOp createLoopOverTileSlices(
     PatternRewriter &rewriter, Location loc, Value initTile,
     std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
 
+/// Returns true if `vType` is a rank-2, all-scalable vector type with a valid
+/// SME tile element type whose dimensions are both multiples of the SME tile
+/// size. Note: returns false if `vType` exactly matches the size of one tile.
+bool isMultipleOfSMETileVectorType(VectorType vType);
+
+/// Returns the (square, all-scalable) vector type of one SME tile holding
+/// elements of `elementType`.
+VectorType getSMETileTypeForElement(Type elementType);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index d0391210555662..6a9e0221822267 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -94,4 +94,26 @@ scf::ForOp createLoopOverTileSlices(
   return forOp;
 }
 
+/// Checks whether `vType` spans more than one SME tile and splits evenly into
+/// SME tiles.
+bool isMultipleOfSMETileVectorType(VectorType vType) {
+  // Only rank-2 vectors whose dimensions are all scalable can map to SME tiles.
+  bool isScalable2D = vType.getRank() == 2 && vType.allDimsScalable();
+  if (!isScalable2D)
+    return false;
+
+  Type elementType = vType.getElementType();
+  if (!isValidSMETileElementType(elementType))
+    return false;
+
+  // Minimum number of elements in each dimension of an SME tile.
+  const int64_t tileDim = getSMETileSliceMinNumElts(elementType);
+  const int64_t rows = vType.getDimSize(0);
+  const int64_t cols = vType.getDimSize(1);
+
+  // A type that fits within a single tile is not considered a multiple of one.
+  if (rows <= tileDim && cols <= tileDim)
+    return false;
+  return rows % tileDim == 0 && cols % tileDim == 0;
+}
+
+/// Builds the vector type of one SME tile for `elementType`.
+VectorType getSMETileTypeForElement(Type elementType) {
+  // An SME tile is square: both dimensions hold the minimum number of
+  // elements per tile slice, and both dimensions are scalable.
+  const int64_t tileDim = getSMETileSliceMinNumElts(elementType);
+  return VectorType::get({tileDim, tileDim}, elementType,
+                         /*scalableDims=*/{true, true});
+}
+
 } // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 96eb5844204384..3c32fc2645ce1b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   TileAllocation.cpp
+  VectorLegalization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -9,10 +10,12 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRArmSMETransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRPass
   MLIRArmSMEDialect
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect
   MLIRSCFDialect
-  MLIRPass
+  MLIRSCFTransforms
+  MLIRFuncTransforms
   )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
new file mode 100644
index 00000000000000..a801ebe27413d5
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -0,0 +1,308 @@
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#define DEBUG_TYPE "arm-sme-vector-legalization"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_VECTORLEGALIZATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+/// One SME tile-sized piece of a larger vector: where it sits within that
+/// vector, plus the vector type of the tile itself.
+struct SMETile {
+  // Position of the tile. Units are vscale (SME tiles are scalable), i.e. the
+  // runtime offset is these values multiplied by vscale.
+  int row = 0;
+  int col = 0;
+  VectorType type;
+};
+
+/// Returns `indices` shifted by a constant offset scaled by vscale. For 2D:
+///   { indices[0] + scalableOffset[0] * vscale,
+///     indices[1] + scalableOffset[1] * vscale }
+/// `indices` and `scalableOffset` must have the same length.
+SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
+                                                Location loc,
+                                                ValueRange indices,
+                                                ArrayRef<int> scalableOffset) {
+  // A single vscale op is shared by every offset computation.
+  Value vscale = builder.create<vector::VectorScaleOp>(loc);
+  SmallVector<Value, 2> shifted;
+  for (auto [index, factor] : llvm::zip_equal(indices, scalableOffset)) {
+    Value multiplier = builder.create<arith::ConstantIndexOp>(loc, factor);
+    Value scaled = builder.create<arith::MulIOp>(loc, multiplier, vscale);
+    shifted.push_back(builder.create<arith::AddIOp>(loc, index, scaled));
+  }
+  return shifted;
+}
+
+/// Remaps `indices` (e.g. of a load/store) of a larger vector type to the
+/// indices of `smeTile`, one of the SME tiles that type decomposes into.
+SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
+                                             ValueRange indices,
+                                             SMETile smeTile) {
+  const int tileOffset[] = {smeTile.row, smeTile.col};
+  return addConstantScalableOffset(builder, loc, indices, tileOffset);
+}
+
+/// Returns true if `mask` can be decomposed for SME, i.e. there is no mask at
+/// all, or the mask is produced by a `vector.create_mask` op.
+bool isSupportedMaskOp(Value mask) {
+  if (!mask)
+    return true;
+  return static_cast<bool>(mask.getDefiningOp<vector::CreateMaskOp>());
+}
+
+/// Extracts the mask for `smeTile` from `mask`, the mask of the larger vector
+/// type the tile was decomposed from. `mask` must satisfy isSupportedMaskOp.
+/// Returns a null Value when there is no mask.
+Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
+                     SMETile smeTile) {
+  assert(isSupportedMaskOp(mask));
+  if (!mask)
+    return Value{};
+  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+  // The operands of `vector.create_mask` (from a 2D perspective) are the
+  // coordinates where the mask ends. Subtracting where this tile starts from
+  // those operands gives the `vector.create_mask` operands for this tile.
+  // NOTE(review): this relies on create_mask handling operands that are
+  // negative or larger than the tile (all-false / all-true in that dimension);
+  // confirm against the `vector.create_mask` semantics.
+  auto tileMaskDims = addConstantScalableOffset(
+      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
+  auto createTileMask = builder.create<vector::CreateMaskOp>(
+      loc, smeTile.type.clone(builder.getI1Type()), tileMaskDims);
+  return createTileMask.getResult();
+}
+
+/// Returns a lazy range over every SME tile (with its coordinates) contained in
+/// `type`, visiting tile offsets in the order StaticTileOffsetRange produces.
+/// When `transposeIndices` is set, each tile's row and column are swapped.
+/// `builder` is currently unused.
+auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
+                         VectorType smeTileType,
+                         bool transposeIndices = false) {
+  assert(isMultipleOfSMETileVectorType(type));
+  SmallVector<int64_t, 2> tileShape = {smeTileType.getDimSize(0),
+                                       smeTileType.getDimSize(1)};
+  auto toSMETile = [=](auto offsets) -> SMETile {
+    int first = int(offsets[0]);
+    int second = int(offsets[1]);
+    return transposeIndices ? SMETile{second, first, smeTileType}
+                            : SMETile{first, second, smeTileType};
+  };
+  return llvm::map_range(StaticTileOffsetRange(type.getShape(), tileShape),
+                         toSMETile);
+}
+
+/// Returns the number of SME tiles that fit into the vector type `type`.
+/// `type` must satisfy isMultipleOfSMETileVectorType.
+int getNumberOfSMETilesForVectorType(VectorType type) {
+  assert(isMultipleOfSMETileVectorType(type));
+  int64_t vectorRows = type.getDimSize(0);
+  int64_t vectorCols = type.getDimSize(1);
+  auto elementType = type.getElementType();
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+  // Both dimensions are exact multiples of the tile size (asserted above), so
+  // the tile count is the product of the per-dimension tile counts.
+  return (vectorRows / minNumElts) * (vectorCols / minNumElts);
+}
+
+/// Legalize `vector.outerproduct` operations to fit within SME tiles.
+///
+/// An outer product whose result type is a multiple of an SME tile is split
+/// into one tile-sized outer product per SME tile. Each one uses the LHS slice
+/// at the tile's row offset, the RHS slice at its column offset, and (if
+/// present) the matching accumulator tile. A mask is only supported when
+/// there is none or it comes from `vector.create_mask`.
+struct LegalizeVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::OuterProductOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = outerProductOp.getResultVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    Value mask;
+    // The op that gets replaced: the outer product itself, or the enclosing
+    // `vector.mask` op when the outer product is masked.
+    Operation *rootOp = outerProductOp;
+    auto loc = outerProductOp.getLoc();
+    if (outerProductOp.isMasked()) {
+      auto maskOp = outerProductOp.getMaskingOp();
+      mask = maskOp.getMask();
+      rootOp = maskOp;
+    }
+
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    // Accumulator values from the adaptor, indexed per SME tile below; an
+    // empty range is treated as "no accumulator".
+    ValueRange accSMETiles = adaptor.getAcc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    // 1D type of a single LHS/RHS slice: the tile type with its leading
+    // dimension dropped.
+    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
+
+    SmallVector<Value> resultSMETiles;
+    for (auto [index, tileTile] :
+         llvm::enumerate(decomposeToSMETiles(rewriter, vectorType, tileType))) {
+
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      // The tile at (row, col) is the outer product of the LHS slice starting
+      // at `row` and the RHS slice starting at `col`.
+      auto lhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getLhs(), tileTile.row);
+      auto rhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getRhs(), tileTile.col);
+      auto tileOuterProduct = rewriter.create<vector::OuterProductOp>(
+          loc, tileType, lhs, rhs,
+          !accSMETiles.empty() ? accSMETiles[index] : Value{},
+          outerProductOp.getKind());
+
+      // Wraps the tile outer product with `tileMask` (presumably leaving it
+      // unmasked when `tileMask` is null; confirm against maskOperation).
+      auto maskedOuterProduct =
+          vector::maskOperation(rewriter, tileOuterProduct, tileMask);
+      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
+    }
+
+    // Replace the root op's result with the per-tile results (1:N mapping).
+    rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+/// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
+/// get the help of the type conversion), but doing so results in the type
+/// conversion adding target materializations in the `vector.mask` region
+/// (invalid). This pattern matches on `vector.mask` then calls into the
+/// `vector.outerproduct` pattern to work around this issue.
+struct LegalizeMaskedVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::MaskOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (auto outerProductOp =
+            llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
+      // A fresh outer product pattern is built on every match, sharing this
+      // pattern's type converter and context.
+      LegalizeVectorOuterProductOp pattern(*getTypeConverter(), getContext());
+      // Calling through RewritePattern selects the generic
+      // matchAndRewrite(Operation *, PatternRewriter &) entry point, presumably
+      // so the 1:N base class builds the outer product's OpAdaptor (confirm).
+      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
+          outerProductOp, rewriter);
+    }
+    // Masks around any other op are not handled here.
+    return failure();
+  }
+};
+
+/// Legalize `vector.transfer_read` operations to fit within SME tiles.
+///
+/// A read of a vector type that is a multiple of an SME tile is replaced by one
+/// tile-sized read per SME tile. Each read has its indices, and its optional
+/// `vector.create_mask` mask, offset to that tile's position. Only reads whose
+/// permutation map is a permutation are handled.
+struct LegalizeTransferReadOp
+    : public OneToNOpConversionPattern<vector::TransferReadOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = readOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = readOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = readOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // transpose [1, 0]. In that case each tile's (row, col) offset is swapped
+    // when remapping the source indices and the mask.
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = readOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+
+    SmallVector<Value> resultSMETiles;
+    for (SMETile smeTile :
+         decomposeToSMETiles(rewriter, vectorType, tileType, transposed)) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, smeTile);
+      auto transferRead = rewriter.create<vector::TransferReadOp>(
+          loc, tileType, readOp.getSource(),
+          remapIndicesForSMETile(rewriter, loc, readOp.getIndices(), smeTile),
+          readOp.getPermutationMapAttr(), readOp.getPadding(), tileMask,
+          readOp.getInBoundsAttr());
+      resultSMETiles.push_back(transferRead);
+    }
+
+    rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+/// Legalize `vector.transfer_write` operations to fit within SME tiles.
+///
+/// A write of a vector type that is a multiple of an SME tile is replaced by
+/// one tile-sized write per SME tile, taking the tiles from the converted
+/// vector operand. Each write has its indices and optional mask offset to that
+/// tile's position. For tensor destinations the writes are chained, each one
+/// writing into the tensor produced by the previous one, and the final tensor
+/// replaces the original result. For memrefs the original op is erased.
+struct LegalizeTransferWriteOp
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // transpose [1, 0]. In that case each tile's (row, col) offset is swapped
+    // when remapping the destination indices and the mask.
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = writeOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto inputSMETiles = adaptor.getVector();
+
+    Value destTensorOrMemref = writeOp.getSource();
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, tileType, transposed))) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, smeTile);
+      auto tileWrite = rewriter.create<vector::TransferWriteOp>(
+          loc, inputSMETiles[index], destTensorOrMemref,
+          remapIndicesForSMETile(rewriter, loc, writeOp.getIndices(), smeTile),
+          writeOp.getPermutationMapAttr(), tileMask, writeOp.getInBoundsAttr());
+      // With tensor semantics each write yields a new tensor, which the next
+      // tile must be written into.
+      if (writeOp.hasPureTensorSemantics())
+        destTensorOrMemref = tileWrite.getResult();
+    }
+
+    if (writeOp.hasPureTensorSemantics())
+      rewriter.replaceOp(writeOp, destTensorOrMemref);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
+/// Pass that decomposes vector ops on multi-SME-tile vector types into SME
+/// tile-sized ops using a 1:N type conversion.
+struct VectorLegalizationPass
+    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+    OneToNTypeConverter converter;
+    RewritePatternSet patterns(context);
+
+    // Default rule: every type converts to itself.
+    converter.addConversion([](Type type) { return type; });
+    // Vector types that are multiples of an SME tile convert 1:N into one SME
+    // tile type per tile. Returning std::nullopt for other vector types leaves
+    // them to the other registered conversion (presumably the identity rule
+    // above; confirm against TypeConverter semantics).
+    converter.addConversion(
+        [](VectorType vectorType,
+           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
+          if (!isMultipleOfSMETileVectorType(vectorType))
+            return std::nullopt;
+          auto tileTileCount = getNumberOfSMETilesForVectorType(vectorType);
+          auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+          types = SmallVector<Type>(tileTileCount, tileType);
+          return success();
+        });
+
+    // Note: High benefit to ensure masked outer products are lowered first.
+    patterns.add<LegalizeMaskedVectorOuterProductOp>(converter, context, 1024);
+    patterns.add<LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
+                 LegalizeTransferWriteOp>(converter, context);
+    // Structural patterns so that func signatures/returns (and presumably SCF
+    // ops) carrying legalized vector types are converted too.
+    populateFuncTypeConversionPatterns(converter, patterns);
+    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Creates a new instance of the `-arm-sme-vector-legalization` pass.
+std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<VectorLegalizationPass>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
new file mode 100644
index 00000000000000..a20abeefedcfd4
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -0,0 +1,268 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @outerproduct_f32_scalable_8x8_no_acc(
+// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf32>,
+// CHECK-SAME:                                        %[[RHS:.*]]: vector<[8]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+// Without an accumulator, the [8]x[8] outer product is expected to split into
+// four [4]x[4] outer products over [4]-element slices of %lhs and %rhs.
+func.func @outerproduct_f32_scalable_8x8_no_acc(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
+  return %0 : vector<[8]x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_scalable_4x16_acc(
+// CHECK-SAME:                                      %[[LHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                      %[[RHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                      %[[ACC_0:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_1:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_2:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_3:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_4x16_acc(%lhs: vector<[4]xf32>, %rhs: vector<[16]xf32>, %acc: ve...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/79152


More information about the Mlir-commits mailing list