[Mlir-commits] [mlir] [mlir][scf][vector] Add `scf.parallel` vectorizer (PR #94168)

Maksim Levental llvmlistbot at llvm.org
Mon Jun 3 08:51:19 PDT 2024


================
@@ -0,0 +1,648 @@
+//===- SCFVectorize.cpp - SCF vectorization utilities ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/SCFVectorize.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/IRMapping.h"
+
+using namespace mlir;
+
+static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); }
+
+/// Returns the type bitwidth for vectorization purposes, or 0 if the type
+/// cannot be vectorized.
+static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
+  if (!isSupportedVecElem(type))
+    return 0;
+
+  if (DL)
+    return DL->getTypeSizeInBits(type);
+
+  if (type.isIntOrFloat())
+    return type.getIntOrFloatBitWidth();
+
+  return 0;
+}
+
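+/// Returns the widest bitwidth among the operand and result types of `op`.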
+static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) {
+  unsigned ret = 0;
+  for (auto arg : op.getOperands())
+    ret = std::max(ret, getTypeBitWidth(arg.getType(), DL));
+
+  for (auto res : op.getResults())
+    ret = std::max(ret, getTypeBitWidth(res.getType(), DL));
+
+  return ret;
+}
+
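+/// Ops with the `Vectorizable` trait can be applied to vector operands
+/// elementwise.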
+static bool isSupportedVectorOp(Operation &op) {
+  return op.hasTrait<OpTrait::Vectorizable>();
+}
+
+/// Check if one `ValueRange` is a permutation of another, i.e. contains the
+/// same values, potentially in a different order. Assumes the values within
+/// each range are distinct, which holds for loop induction variables.
+static bool isRangePermutation(ValueRange val1, ValueRange val2) {
+  if (val1.size() != val2.size())
+    return false;
+
+  for (auto v1 : val1) {
+    auto it = llvm::find(val2, v1);
+    if (it == val2.end())
+      return false;
+  }
+  return true;
+}
+
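+/// Shared implementation of `canTriviallyVectorizeMemOp` for load and store
+/// ops.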
+template <typename Op>
+static std::optional<unsigned>
+canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
+                               const DataLayout *DL) {
+  auto loopIndexVars = loop.getInductionVars();
+  assert(dim < loopIndexVars.size());
+  auto memref = memOp.getMemRef();
+  auto type = cast<MemRefType>(memref.getType());
+  auto width = getTypeBitWidth(type.getElementType(), DL);
+  if (width == 0)
+    return std::nullopt;
+
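+  // Only identity-layout (contiguous) memrefs are supported.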
+  if (!type.getLayout().isIdentity())
+    return std::nullopt;
+
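+  // Every access index must be one of the loop induction variables.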
+  if (!isRangePermutation(memOp.getIndices(), loopIndexVars))
+    return std::nullopt;
+
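+  // The innermost (fastest-varying) index must be the dimension being
+  // vectorized for the access to be contiguous.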
+  if (memOp.getIndices().back() != loopIndexVars[dim])
+    return std::nullopt;
+
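+  // The memref must be defined outside the loop.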
+  DominanceInfo dom;
+  if (!dom.properlyDominates(memref, loop))
+    return std::nullopt;
+
+  return width;
+}
+
+/// Check if a memref load/store can be converted into a vectorized
+/// load/store.
+///
+/// Returns the memref element bitwidth or `std::nullopt` if the access cannot
+/// be vectorized.
+static std::optional<unsigned>
+canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
+                           const DataLayout *DL) {
+  assert(dim < loop.getInductionVars().size());
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return canTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
+
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return canTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
+
+  return std::nullopt;
+}
+
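+/// Shared implementation of `canGatherScatter` for load and store ops.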
+template <typename Op>
+static std::optional<unsigned> canGatherScatterImpl(scf::ParallelOp loop, Op op,
+                                                    const DataLayout *DL) {
+  auto memref = op.getMemRef();
+  auto memrefType = cast<MemRefType>(memref.getType());
+  auto width = getTypeBitWidth(memrefType.getElementType(), DL);
+  if (width == 0)
+    return std::nullopt;
+
+  // The access must use a single index into an identity-layout memref, and
+  // the memref must be defined outside the loop.
+  if (op.getIndices().size() != 1 || !memrefType.getLayout().isIdentity())
+    return std::nullopt;
+
+  DominanceInfo dom;
+  if (!dom.properlyDominates(memref, loop))
+    return std::nullopt;
+
+  return width;
+}
+
+/// Check if a memref access can be converted into a gather/scatter.
+///
+/// Returns the memref element bitwidth or `std::nullopt` if the access cannot
+/// be vectorized.
+static std::optional<unsigned>
+canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return canGatherScatterImpl(loop, storeOp, DL);
+
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return canGatherScatterImpl(loop, loadOp, DL);
+
+  return std::nullopt;
+}
+
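+/// Returns the memref element bitwidth if the access can be vectorized as
+/// either a contiguous load/store or a gather/scatter, `std::nullopt`
+/// otherwise.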
+static std::optional<unsigned> canVectorizeMemrefOp(scf::ParallelOp loop,
+                                                    unsigned dim, Operation &op,
+                                                    const DataLayout *DL) {
+  if (auto w = canTriviallyVectorizeMemOp(loop, dim, op, DL))
+    return w;
+
+  return canGatherScatter(loop, op, DL);
+}
+
+/// Returns the `vector.reduction` kind for the given `scf.parallel` reduce op,
+/// or `std::nullopt` if the reduction cannot be handled by `vector.reduction`.
+static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
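+  // The reduction body must consist of a single combining op (plus the
+  // terminator).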
+  if (!llvm::hasSingleElement(body.without_terminator()))
+    return std::nullopt;
+
+  // TODO: Move getCombinerOpKind to vector dialect.
+  return linalg::getCombinerOpKind(&body.front());
+}
+
+std::optional<scf::SCFVectorizeInfo>
+mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
+                                unsigned vectorBitwidth, const DataLayout *DL) {
+  assert(dim < loop.getStep().size());
+  assert(vectorBitwidth > 0);
----------------
makslevental wrote:

nit: some assert strings here please (I don't remember if they need to end in periods or not)
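
e.g., something like this (the exact wording here is only illustrative):

    assert(dim < loop.getStep().size() && "invalid loop dimension");
    assert(vectorBitwidth > 0 && "vector bitwidth must be non-zero");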

https://github.com/llvm/llvm-project/pull/94168

