[Mlir-commits] [mlir] [mlir][ArmSVE] Add `-arm-sve-legalize-vector-storage` pass (PR #68794)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Oct 25 00:59:39 PDT 2023


================
@@ -0,0 +1,327 @@
+//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::arm_sve {
+#define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sve
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+// Tag attached to the unrealized_conversion_casts produced by this pass. It is
+// used to detect IR that this pass failed to completely legalize, so that an
+// error can be reported: if everything was successfully legalized, no tagged
+// ops will remain after this pass.
+constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__");
+
+namespace {
+
+/// Returns true if `type` is a "logical" SVE predicate mask: an i1 vector whose
+/// only scalable dimension is the trailing one, and whose trailing dimension
+/// is a power of two smaller than 16 (e.g. vector<[4]xi1>). Such a mask has as
+/// many bits as the lanes it masks, which makes it a legal SVE type that is
+/// nevertheless too small to be stored to memory.
+bool isLogicalSVEMaskType(VectorType type) {
+  // A rank-0 vector has no trailing dimension to inspect.
+  if (type.getRank() <= 0)
+    return false;
+  // Masks are vectors of i1.
+  if (!type.getElementType().isInteger(1))
+    return false;
+
+  ArrayRef<bool> scalableDims = type.getScalableDims();
+  // The trailing dimension has to be scalable...
+  if (!scalableDims.back())
+    return false;
+  // ...and a power of two strictly below 16.
+  int64_t trailingDim = type.getShape().back();
+  if (trailingDim >= 16 || !llvm::isPowerOf2_32(trailingDim))
+    return false;
+
+  // None of the leading dimensions may be scalable.
+  return !llvm::is_contained(scalableDims.drop_back(), true);
+}
+
+/// Returns a copy of the logical SVE mask type `type` whose trailing dimension
+/// has been set to 16. `type` must satisfy isLogicalSVEMaskType() (asserted).
+VectorType widenScalableMaskTypeToSvbool(VectorType type) {
+  assert(isLogicalSVEMaskType(type));
+  const unsigned trailingDimPos = type.getRank() - 1;
+  VectorType::Builder svboolBuilder(type);
+  svboolBuilder.setDim(trailingDimPos, 16);
+  return svboolBuilder;
+}
+
+/// Replaces `op` with whatever `callback` returns when invoked on a clone of
+/// `op`. The clone is inserted through `rewriter` before `callback` runs.
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
+                              TLegalizerCallback callback) {
+  // Work on a clone so that any properties/attributes of `op` are preserved.
+  auto clonedOp = op.clone();
+  rewriter.insert(clonedOp);
+  // The callback legalizes the clone and yields the replacement for `op`.
+  rewriter.replaceOp(op, callback(clonedOp));
+}
+
+/// Replaces `op` with an `unrealized_conversion_cast` from the value returned
+/// by `callback` (invoked on a clone of `op`) back to the original result type
+/// of `op`. The cast carries the kSVELegalizerTag unit attribute.
+template <typename TOp, typename TLegalizerCallback>
+void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
+                                       TLegalizerCallback callback) {
+  replaceOpWithLegalizedOp(rewriter, op, [&](TOp legalizedOp) {
+    Location loc = op.getLoc();
+    Type originalType = op.getResult().getType();
+    // Label the cast so that it can be recognised as coming from this pass.
+    NamedAttribute passTag(rewriter.getStringAttr(kSVELegalizerTag),
+                           rewriter.getUnitAttr());
+    return rewriter.create<UnrealizedConversionCastOp>(
+        loc, TypeRange{originalType}, ValueRange{callback(legalizedOp)},
+        passTag);
+  });
+}
+
+/// Looks through an `unrealized_conversion_cast` added by this pass and
+/// returns the legal SVE memref it was created from (the cast's first
+/// operand). Fails if `illegalMemref` has no defining op, or if its defining
+/// op does not carry the kSVELegalizerTag attribute.
+static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) {
+  Operation *producer = illegalMemref.getDefiningOp();
+  const bool isTaggedByThisPass =
+      producer && producer->hasAttr(kSVELegalizerTag);
+  if (!isTaggedByThisPass)
+    return failure();
+  return llvm::cast<UnrealizedConversionCastOp>(producer).getOperand(0);
+}
+
+/// The default alignment of an alloca may request overaligned sizes for SVE
+/// types, which will fail during stack frame allocation. This rewrite
+/// explicitly adds a reasonable alignment to allocas of scalable types.
+///
+/// The pattern only matches `memref.alloca` ops whose element type is a
+/// scalable vector and that do not already have an alignment set; it sets the
+/// alignment to 2 for i1 (predicate) vectors and to 16 for all other vectors.
+struct RelaxScalableVectorAllocaAlignment
+    : public OpRewritePattern<memref::AllocaOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
+                                PatternRewriter &rewriter) const override {
+    auto memrefElementType = allocaOp.getType().getElementType();
+    auto vectorType = llvm::dyn_cast<VectorType>(memrefElementType);
+    // Leave non-scalable allocas and allocas with an explicit alignment alone.
+    if (!vectorType || !vectorType.isScalable() || allocaOp.getAlignment())
+      return failure();
+
+    // Set alignment based on the defaults for SVE vectors and predicates.
+    const unsigned alignment =
+        vectorType.getElementType().isInteger(1) ? 2 : 16;
+    // The op is modified in place, so the change has to go through the
+    // rewriter for the pattern driver to be notified of it.
+    rewriter.updateRootInPlace(allocaOp,
+                               [&] { allocaOp.setAlignment(alignment); });
+
+    return success();
+  }
+};
+
+/// Replaces allocations of SVE predicates smaller than an svbool_t (illegal)
+/// with a wider allocation of svbool_t (legal) followed by a tagged unrealized
+/// conversion to the original type.
+///
+/// svbool = vector<...x[16]xi1>, which maps to some multiple of full SVE
+/// predicate registers. A full predicate is the smallest quantity that can be
+/// loaded/stored.
----------------
banach-space wrote:

```suggestion
/// svbool_t = vector<...x[16]xi1>, which maps to some multiple of full SVE
/// predicate registers. A full predicate is the smallest quantity that can be
/// loaded/stored.
```

I would, perhaps, add this definition at the top of this file and just refer to it in various places. For example:
```
// At the top:
// [1] svbool_t = (...)

// Further down:
/// Replaces allocations of SVE predicates smaller than an svbool_t (_illegal_ to load/store)
/// with a wider allocation of svbool_t (_legal_ to load/store) followed by a tagged unrealized
/// conversion to the original type [1].
```

This way it is easy to re-use the same definition multiple times. You could also add a link to ACLE that defines `svbool_t`?

https://github.com/llvm/llvm-project/pull/68794


More information about the Mlir-commits mailing list