[Mlir-commits] [mlir] [mlir][scf][vector] Add `scf.parallel` vectorizer (PR #94168)
Maksim Levental
llvmlistbot at llvm.org
Mon Jun 3 10:57:01 PDT 2024
================
@@ -0,0 +1,648 @@
+//===- SCFVectorize.cpp - SCF vectorization utilities ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/SCFVectorize.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/IRMapping.h"
+
+using namespace mlir;
+
+static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); }
+
+/// Return type bitwidth for vectorization purposes or 0 if type cannot be
+/// vectorized.
+static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
+ if (!isSupportedVecElem(type))
+ return 0;
+
+ if (DL)
+ return DL->getTypeSizeInBits(type);
+
+ if (type.isIntOrFloat())
+ return type.getIntOrFloatBitWidth();
+
+ return 0;
+}
+
+static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) {
+ unsigned ret = 0;
+ for (auto arg : op.getOperands())
+ ret = std::max(ret, getTypeBitWidth(arg.getType(), DL));
+
+ for (auto res : op.getResults())
+ ret = std::max(ret, getTypeBitWidth(res.getType(), DL));
+
+ return ret;
+}
+
+static bool isSupportedVectorOp(Operation &op) {
+ return op.hasTrait<OpTrait::Vectorizable>();
+}
+
+/// Check if one `ValueRange` is permutation of another, i.e. contains same
+/// values, potentially in different order.
+static bool isRangePermutation(ValueRange val1, ValueRange val2) {
+ if (val1.size() != val2.size())
+ return false;
+
+ for (auto v1 : val1) {
+ auto it = llvm::find(val2, v1);
+ if (it == val2.end())
+ return false;
+ }
+ return true;
+}
+
+template <typename Op>
+static std::optional<unsigned>
+canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
+ const DataLayout *DL) {
+ auto loopIndexVars = loop.getInductionVars();
+ assert(dim < loopIndexVars.size());
+ auto memref = memOp.getMemRef();
+ auto type = cast<MemRefType>(memref.getType());
+ auto width = getTypeBitWidth(type.getElementType(), DL);
----------------
makslevental wrote:
Sometimes I agree with you about `auto` and sometimes I don't (clangd often can't deduce the type quickly, or at all, and I end up guessing). Either way, it's a pretty well-established convention :shrug:
https://github.com/llvm/llvm-project/pull/94168
More information about the Mlir-commits
mailing list