[Mlir-commits] [mlir] [mlir][ArmSME] Add arith-to-arm-sme conversion pass (PR #78197)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Jan 17 08:08:40 PST 2024


================
@@ -0,0 +1,127 @@
+//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "arith-to-arm-sme"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion helpers
+//===----------------------------------------------------------------------===//
+
+/// Returns true iff 'elemType' is float/integer and 'val' is a non-null zero splat.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+  if (llvm::isa<FloatType>(elemType))
+    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+  if (llvm::isa<IntegerType>(elemType))
+    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+  return false; // Element types other than float/integer never count as zero.
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+/// Conversion pattern for dense arith.constant.
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = dyn_cast<VectorType>(constantOp.getType());
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!denseAttr || !denseAttr.isSplat())
+      return failure();
+
+    auto tileElementType = tileType.getElementType();
+
+    // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
+    if (isSplatZero(tileElementType, denseAttr)) {
+      rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
+      return success();
+    }
+
+    // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
+    // ops that broadcast the constant to each tile slice.
+    auto loc = constantOp.getLoc();
+
+    // To fill a tile with a constant, we create a 1-D splat of the constant,
+    // then move that into each tile slice (the largest unit we can set at once,
+    // outside of operations like the outerproduct).
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+    auto denseAttr1D = DenseElementsAttr::get(
+        tileSliceType, denseAttr.getSplatValue<Attribute>());
+    auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
+
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+    arm_sme::LoopBodyBuilder loopBody = [&](OpBuilder &b, Location loc,
----------------
MacDue wrote:

e.g.:
```suggestion
    auto loopBody = [&](OpBuilder &b, Location loc,
```

https://github.com/llvm/llvm-project/pull/78197


More information about the Mlir-commits mailing list