[Mlir-commits] [mlir] [mlir][ArmSME] Add initial SME vector legalization pass (PR #79152)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Jan 26 03:56:32 PST 2024


================
@@ -0,0 +1,379 @@
+//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass legalizes vector operations so they can be lowered to ArmSME.
+// Currently, this only implements the decomposition of vector operations that
+// use vector sizes larger than an SME tile, into multiple SME-sized operations.
+//
+// Note: In the context of this pass 'tile' always refers to an SME tile.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#define DEBUG_TYPE "arm-sme-vector-legalization"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_VECTORLEGALIZATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+// Common match failure reasons.
+static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
+    "op vector size is not multiple of SME tiles");
+static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP(
+    "op mask is unsupported for legalization/decomposition");
+static constexpr StringLiteral
+    MATCH_FAILURE_NON_PERMUTATION_MAP("op affine map is not a permutation");
+
+/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
+/// larger vector type. The (`row`, `col`) are the position of the tile in the
+/// original vector type. For example for an [8]x[8] tile would have four
+/// [4]x[4] sub-tiles.
+///
+///           8 x vscale
+/// ┌─────────────┬─────────────┐
+/// │(0,0)        │(0,4)        │
+/// │             │             │
+/// ├─────────────┼─────────────┤ 8 x vscale
+/// │(4,0)        │(4,4)        │
+/// │             │             │
+/// └─────────────┴─────────────┘
----------------
banach-space wrote:

Love this! 

> For example for an [8]x[8] tile would have four [4]x[4] sub-tiles.

Should this mention that it's an example specifically for 32-bit elements? I guess the intent is to keep this fairly abstract, so [8]x[8] could, in theory, be decomposed into 16 [2]x[2] tiles? I think that you can just replace:

> For example for an [8]x[8] tile would have four [4]x[4] sub-tiles.

with e.g.

> For example for an [8]x[8] tile with four [4]x[4] sub-tiles, we would have:

and then it should work for all cases.

https://github.com/llvm/llvm-project/pull/79152

