[Mlir-commits] [mlir] [mlir][ArmSVE] Add `-arm-sve-legalize-vector-storage` pass (PR #68794)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Oct 24 10:18:27 PDT 2023
================
@@ -0,0 +1,315 @@
+//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::arm_sve {
+#define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sve
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+/// Name of the unit attribute this pass attaches to every
+/// `unrealized_conversion_cast` it creates, so that later rewrites (see
+/// getLegalMemRef) can recognize the casts as their own.
+constexpr StringLiteral kPassLabel("__arm_sve_legalize_vector_storage__");
+
+namespace {
+
+/// Returns true if `type` is a legal SVE predicate mask with a "logical" size:
+/// an i1 vector whose trailing dimension is scalable and holds one bit per
+/// masked lane (e.g. vector<[4]xi1>). Such masks have fewer than 16 lanes,
+/// which is too small to be stored to memory. Only the trailing dimension may
+/// be scalable, and its size must be a power of two.
+bool isLogicalSVEPredicateType(VectorType type) {
+  // Must be a non-0-d vector of i1.
+  if (type.getRank() == 0 || !type.getElementType().isInteger(1))
+    return false;
+
+  // Exactly the trailing dimension must be scalable.
+  auto scalableDims = type.getScalableDims();
+  if (!scalableDims.back() ||
+      llvm::is_contained(scalableDims.drop_back(), true))
+    return false;
+
+  // The trailing dimension must be a power of two that is smaller than 16.
+  int64_t trailingDim = type.getShape().back();
+  return trailingDim < 16 && llvm::isPowerOf2_32(trailingDim);
+}
+
+/// Widens a logical SVE predicate type (as accepted by
+/// isLogicalSVEPredicateType) to the svbool shape by setting its trailing
+/// dimension to 16, e.g. vector<[4]xi1> -> vector<[16]xi1>.
+VectorType widenScalableMaskTypeToSvbool(VectorType type) {
+  assert(isLogicalSVEPredicateType(type));
+  const int64_t trailingDimIdx = type.getRank() - 1;
+  VectorType::Builder svboolBuilder(type);
+  return svboolBuilder.setDim(trailingDimIdx, 16);
+}
+
+/// Replaces `op` with a legalized version of itself.
+///
+/// A clone of `op` is inserted at the rewriter's insertion point and handed to
+/// `callback`. Cloning keeps every property and attribute of the original op.
+/// Whatever `callback` returns is used as the replacement for `op`.
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
+                              TLegalizerCallback callback) {
+  auto legalizedOp = op.clone();
+  rewriter.insert(legalizedOp);
+  rewriter.replaceOp(op, callback(legalizedOp));
+}
+
+/// Like replaceOpWithLegalizedOp, but the value produced by `callback` is
+/// wrapped in an `unrealized_conversion_cast` back to `op`'s original result
+/// type, so existing users of `op` still see the type they expect.
+///
+/// The cast carries the `kPassLabel` unit attribute so it can be identified as
+/// created by this pass.
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
+                                       TLegalizerCallback callback) {
+  auto wrapInTaggedCast = [&](TOp legalizedOp) {
+    NamedAttribute passTag(rewriter.getStringAttr(kPassLabel),
+                           rewriter.getUnitAttr());
+    Type originalType = op.getResult().getType();
+    // Keep `callback(...)` inside this call: the ValueRange only views its
+    // (temporary) result.
+    return rewriter.create<UnrealizedConversionCastOp>(
+        op.getLoc(), TypeRange{originalType},
+        ValueRange{callback(legalizedOp)}, passTag);
+  };
+  replaceOpWithLegalizedOp(rewriter, op, wrapInTaggedCast);
+}
+
+/// Extracts the legal memref value from the `unrealized_conversion_casts` added
+/// by this pass.
+///
+/// Returns the cast's first operand when `illegalMemref` is produced by an
+/// `unrealized_conversion_cast` tagged with `kPassLabel`. Otherwise (a block
+/// argument, a different op, or an untagged cast) returns failure().
+static FailureOr<Value> getLegalMemRef(Value illegalMemref) {
+  // Use a checked cast. Matching on the attribute alone and then calling
+  // `llvm::cast` would assert if any other op happened to carry the label.
+  auto unrealizedConversion =
+      llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(
+          illegalMemref.getDefiningOp());
+  if (!unrealizedConversion || !unrealizedConversion->hasAttr(kPassLabel))
+    return failure();
+  return unrealizedConversion.getOperand(0);
+}
+
+/// The default alignment of an alloca may request overaligned sizes for SVE
+/// types, which will fail during stack frame allocation. This rewrite
+/// explicitly adds a reasonable alignment to allocas of scalable types.
+///
+/// Only allocas of scalable vectors that have no explicit alignment match.
+/// After rewriting, the alloca has an alignment and no longer matches, so the
+/// pattern applies at most once per alloca.
+struct RelaxScalableVectorAllocaAlignment
+    : public OpRewritePattern<memref::AllocaOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
+                                PatternRewriter &rewriter) const override {
+    auto memrefElementType = allocaOp.getType().getElementType();
+    auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
+    if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
+      return failure();
+
+    // Set alignment based on the defaults for SVE vectors and predicates
+    // (an i1 element type denotes a predicate).
+    unsigned alignment = vectorType.getElementType().isInteger(1) ? 2 : 16;
+    // In-place changes must go through the rewriter so that the pattern
+    // driver (and any listeners) are notified that the op was modified.
+    rewriter.updateRootInPlace(allocaOp,
+                               [&] { allocaOp.setAlignment(alignment); });
+
+    return success();
+  }
+};
+
+/// Replaces allocations of SVE predicates smaller than an svbool with a wider
+/// allocation and a tagged unrealized conversion.
----------------
MacDue wrote:
> 1. Could you provide a definition of svbool_t somewhere?
I've added a short definition here (it's now defined in a few places: the ArmSVE op docs, the pass docs, and some code comments) :)
> 2. Could you document what tag you are referring here to?
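The tag is shown in the example: it is the `__arm_sve_legalize_vector_storage__` unit attribute attached to the `unrealized_conversion_cast` that wraps the widened allocation.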
https://github.com/llvm/llvm-project/pull/68794
More information about the Mlir-commits
mailing list