[Mlir-commits] [mlir] [mlir][scf][vector] Add `scf.parallel` vectorizer (PR #94168)
Ivan Butygin
llvmlistbot at llvm.org
Sat Jun 8 15:15:19 PDT 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/94168
>From db95496d83bebfc1db2cbc1ac6c1d04d706b6499 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 5 Dec 2023 16:44:04 -0600
Subject: [PATCH 01/10] [mlir][scf] upstream numba's scf vectorizer
---
mlir/include/mlir/Transforms/SCFVectorize.h | 49 ++
mlir/lib/Transforms/CMakeLists.txt | 1 +
mlir/lib/Transforms/SCFVectorize.cpp | 661 ++++++++++++++++++++
3 files changed, 711 insertions(+)
create mode 100644 mlir/include/mlir/Transforms/SCFVectorize.h
create mode 100644 mlir/lib/Transforms/SCFVectorize.cpp
diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
new file mode 100644
index 0000000000000..d754b38d5bc23
--- /dev/null
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -0,0 +1,49 @@
+//===- SCFVectorize.h - SCF vectorization utilities -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SCFVECTORIZE_H_
+#define MLIR_TRANSFORMS_SCFVECTORIZE_H_
+
+#include <memory>
+#include <optional>
+
+namespace mlir {
+class OpBuilder;
+class Pass;
+struct LogicalResult;
+namespace scf {
+class ParallelOp;
+}
+} // namespace mlir
+
+namespace mlir {
+struct SCFVectorizeInfo {
+ unsigned dim = 0;
+ unsigned factor = 0;
+ unsigned count = 0;
+ bool masked = false;
+};
+
+std::optional<SCFVectorizeInfo> getLoopVectorizeInfo(mlir::scf::ParallelOp loop,
+ unsigned dim,
+ unsigned vectorBitWidth);
+
+struct SCFVectorizeParams {
+ unsigned dim = 0;
+ unsigned factor = 0;
+ bool masked = false;
+};
+
+mlir::LogicalResult vectorizeLoop(mlir::OpBuilder &builder,
+ mlir::scf::ParallelOp loop,
+ const SCFVectorizeParams &params);
+
+std::unique_ptr<mlir::Pass> createSCFVectorizePass();
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
\ No newline at end of file
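The header above also declares a standalone pass constructor (`createSCFVectorizePass`). A minimal sketch, not part of the patch, of how such a pass would typically be scheduled; the helper name is illustrative:

#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/SCFVectorize.h"

// Hypothetical wiring: run the new SCF vectorizer as part of a pass pipeline.
static void addSCFVectorization(mlir::PassManager &pm) {
  pm.addPass(mlir::createSCFVectorizePass());
}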
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46..ed71c73c938ed 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms
PrintIR.cpp
RemoveDeadValues.cpp
SCCP.cpp
+ SCFVectorize.cpp
SROA.cpp
StripDebugInfo.cpp
SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
new file mode 100644
index 0000000000000..d7545ee30e29a
--- /dev/null
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -0,0 +1,661 @@
+//===- ControlFlowSink.cpp - Code to perform control-flow sinking ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/SCFVectorize.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+static unsigned getTypeBitWidth(mlir::Type type) {
+ if (mlir::isa<mlir::IndexType>(type))
+ return 64; // TODO: unhardcode
+
+ if (type.isIntOrFloat())
+ return type.getIntOrFloatBitWidth();
+
+ return 0;
+}
+
+static unsigned getArgsTypeWidth(mlir::Operation &op) {
+ unsigned ret = 0;
+ for (auto arg : op.getOperands())
+ ret = std::max(ret, getTypeBitWidth(arg.getType()));
+
+ for (auto res : op.getResults())
+ ret = std::max(ret, getTypeBitWidth(res.getType()));
+
+ return ret;
+}
+
+static bool isSupportedVectorOp(mlir::Operation &op) {
+ return op.hasTrait<mlir::OpTrait::Vectorizable>();
+}
+
+static bool isSupportedVecElem(mlir::Type type) {
+ return type.isIntOrIndexOrFloat();
+}
+
+static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) {
+ if (val1.size() != val2.size())
+ return false;
+
+ for (auto v1 : val1) {
+ auto it = llvm::find(val2, v1);
+ if (it == val2.end())
+ return false;
+ }
+ return true;
+}
+
+template <typename Op>
+static std::optional<unsigned>
+cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
+ Op memOp) {
+ auto loopIndexVars = loop.getInductionVars();
+ assert(dim < loopIndexVars.size());
+ auto memref = memOp.getMemRef();
+ auto type = mlir::cast<mlir::MemRefType>(memref.getType());
+ auto width = getTypeBitWidth(type.getElementType());
+ if (width == 0)
+ return std::nullopt;
+
+ if (!type.getLayout().isIdentity())
+ return std::nullopt;
+
+ if (!isRangePermutation(memOp.getIndices(), loopIndexVars))
+ return std::nullopt;
+
+ if (memOp.getIndices().back() != loopIndexVars[dim])
+ return std::nullopt;
+
+ mlir::DominanceInfo dom;
+ if (!dom.properlyDominates(memref, loop))
+ return std::nullopt;
+
+ return width;
+}
+
+static std::optional<unsigned>
+cavTriviallyVectorizeMemOp(mlir::scf::ParallelOp loop, unsigned dim,
+ mlir::Operation &op) {
+ assert(dim < loop.getInductionVars().size());
+ if (auto storeOp = mlir::dyn_cast<mlir::memref::StoreOp>(op))
+ return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp);
+
+ if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op))
+ return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp);
+
+ return std::nullopt;
+}
+
+template <typename T>
+static bool isOp(mlir::Operation &op) {
+ return mlir::isa<T>(op);
+}
+
+static std::optional<mlir::vector::CombiningKind>
+getReductionKind(mlir::scf::ReduceOp op) {
+ mlir::Block &body = op.getReductionOperator().front();
+ if (!llvm::hasSingleElement(body.without_terminator()))
+ return std::nullopt;
+
+ mlir::Operation &redOp = body.front();
+
+ using fptr_t = bool (*)(mlir::Operation &);
+ using CC = mlir::vector::CombiningKind;
+ const std::pair<fptr_t, CC> handlers[] = {
+ // clang-format off
+ {&isOp<mlir::arith::AddIOp>, CC::ADD},
+ {&isOp<mlir::arith::AddFOp>, CC::ADD},
+ {&isOp<mlir::arith::MulIOp>, CC::MUL},
+ {&isOp<mlir::arith::MulFOp>, CC::MUL},
+ // clang-format on
+ };
+
+ for (auto &&[handler, cc] : handlers) {
+ if (handler(redOp))
+ return cc;
+ }
+
+ return std::nullopt;
+}
+
+std::optional<mlir::SCFVectorizeInfo>
+mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
+ unsigned vectorBitwidth) {
+ assert(dim < loop.getStep().size());
+ assert(vectorBitwidth > 0);
+ unsigned factor = vectorBitwidth / 8;
+ if (factor <= 1)
+ return std::nullopt;
+
+ if (!mlir::isConstantIntValue(loop.getStep()[dim], 1))
+ return std::nullopt;
+
+ unsigned count = 0;
+ bool masked = true;
+
+ for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+ if (auto reduce = mlir::dyn_cast<mlir::scf::ReduceOp>(op)) {
+ if (!getReductionKind(reduce))
+ masked = false;
+
+ continue;
+ }
+
+ if (op.getNumRegions() > 0)
+ return std::nullopt;
+
+ if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op)) {
+ auto newFactor = vectorBitwidth / *w;
+ if (newFactor > 1) {
+ factor = std::min(factor, newFactor);
+ ++count;
+ }
+ continue;
+ }
+
+ if (!isSupportedVectorOp(op)) {
+ masked = false;
+ continue;
+ }
+
+ auto width = getArgsTypeWidth(op);
+ if (width == 0)
+ return std::nullopt;
+
+ auto newFactor = vectorBitwidth / width;
+ if (newFactor <= 1)
+ continue;
+
+ factor = std::min(factor, newFactor);
+
+ ++count;
+ }
+
+ if (count == 0)
+ return std::nullopt;
+
+ return SCFVectorizeInfo{dim, factor, count, masked};
+}
+
+static mlir::arith::FastMathFlags getFMF(mlir::Operation &op) {
+ if (auto fmf = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(op))
+ return fmf.getFastMathFlagsAttr().getValue();
+
+ return mlir::arith::FastMathFlags::none;
+}
+
+mlir::LogicalResult
+mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
+ const mlir::SCFVectorizeParams &params) {
+ auto dim = params.dim;
+ auto factor = params.factor;
+ auto masked = params.masked;
+ assert(dim < loop.getStep().size());
+ assert(factor > 1);
+ assert(mlir::isConstantIntValue(loop.getStep()[dim], 1));
+
+ mlir::OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPoint(loop);
+
+ auto lower = llvm::to_vector(loop.getLowerBound());
+ auto upper = llvm::to_vector(loop.getUpperBound());
+ auto step = llvm::to_vector(loop.getStep());
+
+ auto loc = loop.getLoc();
+
+ auto origIndexVar = loop.getInductionVars()[dim];
+
+ mlir::Value factorVal =
+ builder.create<mlir::arith::ConstantIndexOp>(loc, factor);
+
+ auto origLower = lower[dim];
+ auto origUpper = upper[dim];
+ mlir::Value count =
+ builder.create<mlir::arith::SubIOp>(loc, origUpper, origLower);
+ mlir::Value newCount;
+ if (masked) {
+ mlir::Value incCount =
+ builder.create<mlir::arith::AddIOp>(loc, count, factorVal);
+ mlir::Value one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
+ incCount = builder.create<mlir::arith::SubIOp>(loc, incCount, one);
+ newCount = builder.create<mlir::arith::DivSIOp>(loc, incCount, factorVal);
+ } else {
+ newCount = builder.create<mlir::arith::DivSIOp>(loc, count, factorVal);
+ }
+
+ mlir::Value zero = builder.create<mlir::arith::ConstantIndexOp>(loc, 0);
+ lower[dim] = zero;
+ upper[dim] = newCount;
+
+ auto newLoop = builder.create<mlir::scf::ParallelOp>(loc, lower, upper, step,
+ loop.getInitVals());
+ auto newIndexVar = newLoop.getInductionVars()[dim];
+
+ auto toVectorType = [&](mlir::Type elemType) -> mlir::VectorType {
+ int64_t f = factor;
+ return mlir::VectorType::get(f, elemType);
+ };
+
+ mlir::IRMapping mapping;
+ mlir::IRMapping scalarMapping;
+
+ auto createPosionVec = [&](mlir::VectorType vecType) -> mlir::Value {
+ return builder.create<mlir::ub::PoisonOp>(loc, vecType, nullptr);
+ };
+
+ auto getVecVal = [&](mlir::Value orig) -> mlir::Value {
+ if (auto mapped = mapping.lookupOrNull(orig))
+ return mapped;
+
+ if (orig == origIndexVar) {
+ auto vecType = toVectorType(builder.getIndexType());
+ llvm::SmallVector<mlir::Attribute> elems(factor);
+ for (auto i : llvm::seq(0u, factor))
+ elems[i] = builder.getIndexAttr(i);
+ auto attr = mlir::DenseElementsAttr::get(vecType, elems);
+ mlir::Value vec =
+ builder.create<mlir::arith::ConstantOp>(loc, vecType, attr);
+
+ mlir::Value idx =
+ builder.create<mlir::arith::MulIOp>(loc, newIndexVar, factorVal);
+ idx = builder.create<mlir::arith::AddIOp>(loc, idx, origLower);
+ idx = builder.create<mlir::vector::SplatOp>(loc, idx, vecType);
+ vec = builder.create<mlir::arith::AddIOp>(loc, idx, vec);
+ mapping.map(orig, vec);
+ return vec;
+ }
+ auto type = orig.getType();
+ assert(isSupportedVecElem(type));
+
+ mlir::Value val = orig;
+ auto origIndexVars = loop.getInductionVars();
+ auto it = llvm::find(origIndexVars, orig);
+ if (it != origIndexVars.end())
+ val = newLoop.getInductionVars()[it - origIndexVars.begin()];
+
+ auto vecType = toVectorType(type);
+ mlir::Value vec = builder.create<mlir::vector::SplatOp>(loc, val, vecType);
+ mapping.map(orig, vec);
+ return vec;
+ };
+
+ llvm::DenseMap<mlir::Value, llvm::SmallVector<mlir::Value>> unpackedVals;
+ auto getUnpackedVals = [&](mlir::Value val) -> mlir::ValueRange {
+ auto it = unpackedVals.find(val);
+ if (it != unpackedVals.end())
+ return it->second;
+
+ auto &ret = unpackedVals[val];
+ assert(ret.empty());
+ if (!isSupportedVecElem(val.getType())) {
+ ret.resize(factor, val);
+ return ret;
+ }
+
+ auto vecVal = getVecVal(val);
+ ret.resize(factor);
+ for (auto i : llvm::seq(0u, factor)) {
+ mlir::Value idx = builder.create<mlir::arith::ConstantIndexOp>(loc, i);
+ ret[i] = builder.create<mlir::vector::ExtractElementOp>(loc, vecVal, idx);
+ }
+ return ret;
+ };
+
+ auto setUnpackedVals = [&](mlir::Value origVal, mlir::ValueRange newVals) {
+ assert(newVals.size() == factor);
+ assert(unpackedVals.count(origVal) == 0);
+ unpackedVals[origVal].append(newVals.begin(), newVals.end());
+
+ auto type = origVal.getType();
+ if (!isSupportedVecElem(type))
+ return;
+
+ auto vecType = toVectorType(type);
+
+ mlir::Value vec = createPosionVec(vecType);
+ for (auto i : llvm::seq(0u, factor)) {
+ mlir::Value idx = builder.create<mlir::arith::ConstantIndexOp>(loc, i);
+ vec = builder.create<mlir::vector::InsertElementOp>(loc, newVals[i], vec,
+ idx);
+ }
+ mapping.map(origVal, vec);
+ };
+
+ mlir::Value mask;
+ auto getMask = [&]() -> mlir::Value {
+ if (mask)
+ return mask;
+
+ mlir::OpFoldResult maskSize;
+ if (masked) {
+ mlir::Value size =
+ builder.create<mlir::arith::MulIOp>(loc, factorVal, newIndexVar);
+ maskSize =
+ builder.create<mlir::arith::SubIOp>(loc, count, size).getResult();
+ } else {
+ maskSize = builder.getIndexAttr(factor);
+ }
+ auto vecType = toVectorType(builder.getI1Type());
+ mask = builder.create<mlir::vector::CreateMaskOp>(loc, vecType, maskSize);
+
+ return mask;
+ };
+
+ mlir::DominanceInfo dom;
+
+ auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
+ return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op);
+ };
+
+ auto getMemrefVecIndices = [&](mlir::ValueRange indices) {
+ scalarMapping.clear();
+ scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
+
+ llvm::SmallVector<mlir::Value> ret(indices.size());
+ for (auto &&[i, val] : llvm::enumerate(indices)) {
+ if (val == origIndexVar) {
+ mlir::Value idx =
+ builder.create<mlir::arith::MulIOp>(loc, newIndexVar, factorVal);
+ idx = builder.create<mlir::arith::AddIOp>(loc, idx, origLower);
+ ret[i] = idx;
+ continue;
+ }
+ ret[i] = scalarMapping.lookup(val);
+ }
+
+ return ret;
+ };
+
+ auto canGatherScatter = [&](auto op) {
+ auto memref = op.getMemRef();
+ auto memrefType = mlir::cast<mlir::MemRefType>(memref.getType());
+ if (!isSupportedVecElem(memrefType.getElementType()))
+ return false;
+
+ return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
+ memrefType.getLayout().isIdentity();
+ };
+
+ auto genLoad = [&](auto loadOp) {
+ auto indices = getMemrefVecIndices(loadOp.getIndices());
+ auto resType = toVectorType(loadOp.getResult().getType());
+ auto memref = loadOp.getMemRef();
+ mlir::Value vecLoad;
+ if (masked) {
+ auto mask = getMask();
+ auto init = createPosionVec(resType);
+ vecLoad = builder.create<mlir::vector::MaskedLoadOp>(loc, resType, memref,
+ indices, mask, init);
+ } else {
+ vecLoad =
+ builder.create<mlir::vector::LoadOp>(loc, resType, memref, indices);
+ }
+ mapping.map(loadOp.getResult(), vecLoad);
+ };
+
+ auto genStore = [&](auto storeOp) {
+ auto indices = getMemrefVecIndices(storeOp.getIndices());
+ auto value = getVecVal(storeOp.getValueToStore());
+ auto memref = storeOp.getMemRef();
+ if (masked) {
+ auto mask = getMask();
+ builder.create<mlir::vector::MaskedStoreOp>(loc, memref, indices, mask,
+ value);
+ } else {
+ builder.create<mlir::vector::StoreOp>(loc, value, memref, indices);
+ }
+ };
+
+ llvm::SmallVector<mlir::Value> duplicatedArgs;
+ llvm::SmallVector<mlir::Value> duplicatedResults;
+
+ builder.setInsertionPointToStart(newLoop.getBody());
+ for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+ loc = op.getLoc();
+ if (isSupportedVectorOp(op)) {
+ for (auto arg : op.getOperands())
+ getVecVal(arg); // init mapper for op args
+
+ auto newOp = builder.clone(op, mapping);
+ for (auto res : newOp->getResults())
+ res.setType(toVectorType(res.getType()));
+
+ continue;
+ }
+
+ if (auto reduceOp = mlir::dyn_cast<mlir::scf::ReduceOp>(op)) {
+ scalarMapping.clear();
+ auto &reduceBody = reduceOp.getReductionOperator().front();
+ assert(reduceBody.getNumArguments() == 2);
+
+ mlir::Value reduceVal;
+ if (auto redKind = getReductionKind(reduceOp)) {
+ mlir::Value redArg = getVecVal(reduceOp.getOperand());
+ if (redArg) {
+ auto neutral = mlir::arith::getNeutralElement(&reduceBody.front());
+ assert(neutral);
+ mlir::Value neutralVal =
+ builder.create<mlir::arith::ConstantOp>(loc, *neutral);
+ mlir::Value neutralVec = builder.create<mlir::vector::SplatOp>(
+ loc, neutralVal, redArg.getType());
+ auto mask = getMask();
+ redArg = builder.create<mlir::arith::SelectOp>(loc, mask, redArg,
+ neutralVec);
+ }
+
+ auto fmf = getFMF(reduceBody.front());
+ reduceVal = builder.create<mlir::vector::ReductionOp>(loc, *redKind,
+ redArg, fmf);
+ } else {
+ if (masked)
+ return op.emitError("Cannot vectorize op in masked mode");
+
+ auto reduceTerm =
+ mlir::cast<mlir::scf::ReduceReturnOp>(reduceBody.getTerminator());
+ auto lhs = reduceBody.getArgument(0);
+ auto rhs = reduceBody.getArgument(1);
+ auto unpacked = getUnpackedVals(reduceOp.getOperand());
+ assert(unpacked.size() == factor);
+ reduceVal = unpacked.front();
+ for (auto i : llvm::seq(1u, factor)) {
+ mlir::Value val = unpacked[i];
+ scalarMapping.map(lhs, reduceVal);
+ scalarMapping.map(rhs, val);
+ for (auto &redOp : reduceBody.without_terminator())
+ builder.clone(redOp, scalarMapping);
+
+ reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult());
+ }
+ }
+ scalarMapping.clear();
+ scalarMapping.map(reduceOp.getOperand(), reduceVal);
+ builder.clone(op, scalarMapping);
+ continue;
+ }
+
+ if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op)) {
+ if (canTriviallyVectorizeMemOp(loadOp)) {
+ genLoad(loadOp);
+ continue;
+ }
+ if (canGatherScatter(loadOp)) {
+ auto resType = toVectorType(loadOp.getResult().getType());
+ auto memref = loadOp.getMemRef();
+ auto mask = getMask();
+ auto indexVec = getVecVal(loadOp.getIndices()[0]);
+ auto init = createPosionVec(resType);
+
+ auto gather = builder.create<mlir::vector::GatherOp>(
+ loc, resType, memref, zero, indexVec, mask, init);
+ mapping.map(loadOp.getResult(), gather.getResult());
+ continue;
+ }
+ }
+
+ if (auto storeOp = mlir::dyn_cast<mlir::memref::StoreOp>(op)) {
+ if (canTriviallyVectorizeMemOp(storeOp)) {
+ genStore(storeOp);
+ continue;
+ }
+ if (canGatherScatter(storeOp)) {
+ auto memref = storeOp.getMemRef();
+ auto value = getVecVal(storeOp.getValueToStore());
+ auto mask = getMask();
+ auto indexVec = getVecVal(storeOp.getIndices()[0]);
+
+ builder.create<mlir::vector::ScatterOp>(loc, memref, zero, indexVec,
+ mask, value);
+ }
+ }
+
+ // Fallback: Failed to vectorize op, just duplicate it `factor` times
+ if (masked)
+ return op.emitError("Cannot vectorize op in masked mode");
+
+ scalarMapping.clear();
+
+ auto numArgs = op.getNumOperands();
+ auto numResults = op.getNumResults();
+ duplicatedArgs.resize(numArgs * factor);
+ duplicatedResults.resize(numResults * factor);
+
+ for (auto &&[i, arg] : llvm::enumerate(op.getOperands())) {
+ auto unpacked = getUnpackedVals(arg);
+ assert(unpacked.size() == factor);
+ for (auto j : llvm::seq(0u, factor))
+ duplicatedArgs[j * numArgs + i] = unpacked[j];
+ }
+
+ for (auto i : llvm::seq(0u, factor)) {
+ auto args = mlir::ValueRange(duplicatedArgs)
+ .drop_front(numArgs * i)
+ .take_front(numArgs);
+ scalarMapping.map(op.getOperands(), args);
+ auto results = builder.clone(op, scalarMapping)->getResults();
+
+ for (auto j : llvm::seq(0u, numResults))
+ duplicatedResults[j * factor + i] = results[j];
+ }
+
+ for (auto i : llvm::seq(0u, numResults)) {
+ auto results = mlir::ValueRange(duplicatedResults)
+ .drop_front(factor * i)
+ .take_front(factor);
+ setUnpackedVals(op.getResult(i), results);
+ }
+ }
+
+ if (masked) {
+ loop->replaceAllUsesWith(newLoop.getResults());
+ loop->erase();
+ } else {
+ builder.setInsertionPoint(loop);
+ mlir::Value newLower =
+ builder.create<mlir::arith::MulIOp>(loc, newCount, factorVal);
+ newLower = builder.create<mlir::arith::AddIOp>(loc, origLower, newLower);
+
+ auto lowerCopy = llvm::to_vector(loop.getLowerBound());
+ lowerCopy[dim] = newLower;
+ loop.getLowerBoundMutable().assign(lowerCopy);
+ loop.getInitValsMutable().assign(newLoop.getResults());
+ }
+
+ return mlir::success();
+}
+
+llvm::StringRef getVectorLengthName() { return "numba.vector_length"; }
+
+static std::optional<unsigned> getVectorLength(mlir::Operation *op) {
+ auto func = op->getParentOfType<mlir::FunctionOpInterface>();
+ if (!func)
+ return std::nullopt;
+
+ auto attr = func->getAttrOfType<mlir::IntegerAttr>(getVectorLengthName());
+ if (!attr)
+ return std::nullopt;
+
+ auto val = attr.getInt();
+ if (val <= 0 || val > std::numeric_limits<unsigned>::max())
+ return std::nullopt;
+
+ return static_cast<unsigned>(val);
+}
+
+namespace {
+struct SCFVectorizePass
+ : public mlir::PassWrapper<SCFVectorizePass, mlir::OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFVectorizePass)
+
+ void getDependentDialects(mlir::DialectRegistry ®istry) const override {
+ registry.insert<mlir::arith::ArithDialect>();
+ registry.insert<mlir::scf::SCFDialect>();
+ registry.insert<mlir::ub::UBDialect>();
+ registry.insert<mlir::vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ llvm::SmallVector<
+ std::pair<mlir::scf::ParallelOp, mlir::SCFVectorizeParams>>
+ toVectorize;
+
+ auto getBenefit = [](const mlir::SCFVectorizeInfo &info) {
+ return info.factor * info.count * (int(info.masked) + 1);
+ };
+
+ getOperation()->walk([&](mlir::scf::ParallelOp loop) {
+ auto len = getVectorLength(loop);
+ if (!len)
+ return;
+
+ std::optional<mlir::SCFVectorizeInfo> best;
+ for (auto dim : llvm::seq(0u, loop.getNumLoops())) {
+ auto info = mlir::getLoopVectorizeInfo(loop, dim, *len);
+ if (!info)
+ continue;
+
+ if (!best) {
+ best = *info;
+ continue;
+ }
+
+ if (getBenefit(*info) > getBenefit(*best))
+ best = *info;
+ }
+
+ if (!best)
+ return;
+
+ toVectorize.emplace_back(
+ loop,
+ mlir::SCFVectorizeParams{best->dim, best->factor, best->masked});
+ });
+
+ if (toVectorize.empty())
+ return markAllAnalysesPreserved();
+
+ mlir::OpBuilder builder(&getContext());
+ for (auto &&[loop, params] : toVectorize) {
+ builder.setInsertionPoint(loop);
+ if (mlir::failed(mlir::vectorizeLoop(builder, loop, params)))
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> mlir::createSCFVectorizePass() {
+ return std::make_unique<SCFVectorizePass>();
+}
>From c241fe2eb24f4c2dc21de412a68a371566d9bb17 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 25 Apr 2024 10:57:05 -0500
Subject: [PATCH 02/10] get new stuff
---
mlir/include/mlir/Transforms/SCFVectorize.h | 27 ++-
mlir/lib/Transforms/SCFVectorize.cpp | 210 ++++++++++++++------
2 files changed, 172 insertions(+), 65 deletions(-)
diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
index d754b38d5bc23..93a7864b976ec 100644
--- a/mlir/include/mlir/Transforms/SCFVectorize.h
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -22,23 +22,48 @@ class ParallelOp;
} // namespace mlir
namespace mlir {
+
+/// Loop vectorization info
struct SCFVectorizeInfo {
+ /// Loop dimension on which to vectorize.
unsigned dim = 0;
+
+ /// Biggest vector width, in elements.
unsigned factor = 0;
+
+ /// Number of ops that will be vectorized.
unsigned count = 0;
+
+ /// Can use masked vector ops for out-of-bounds memory accesses.
bool masked = false;
};
+/// Collect vectorization statistics on specified `scf.parallel` dimension.
+/// Return `SCFVectorizeInfo` or `std::nullopt` if loop cannot be vectorized on
+/// specified dimension.
+///
+/// `vectorBitwidth` - maximum vector size, in bits.
std::optional<SCFVectorizeInfo> getLoopVectorizeInfo(mlir::scf::ParallelOp loop,
unsigned dim,
- unsigned vectorBitWidth);
+ unsigned vectorBitwidth);
+/// Vectorization params
struct SCFVectorizeParams {
+ /// Loop dimension on which to vectorize.
unsigned dim = 0;
+
+ /// Desired vector length, in elements
unsigned factor = 0;
+
+ /// Use masked vector ops for memory access outside loop bounds.
bool masked = false;
};
+/// Vectorize loop on specified dimension with specified factor.
+///
+/// If `masked` is `true` and loop bound is not divisible by `factor`, instead
+/// of generating a second loop to process remaining iterations, extend loop count
+/// and generate masked vector ops to handle out-of-bounds memory accesses.
mlir::LogicalResult vectorizeLoop(mlir::OpBuilder &builder,
mlir::scf::ParallelOp loop,
const SCFVectorizeParams &params);
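The two entry points documented above are meant to be used as a query-then-transform pair; the pass at the bottom of SCFVectorize.cpp does exactly this for every dimension and picks the most profitable one. A minimal sketch, assuming a 256-bit vector budget and querying only dimension 0 (the helper name is illustrative):

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/SCFVectorize.h"

// Sketch of the intended usage at this point in the PR (the signature changes
// in a later commit): query, then vectorize in place.
static mlir::LogicalResult tryVectorize(mlir::scf::ParallelOp loop) {
  auto info = mlir::getLoopVectorizeInfo(loop, /*dim=*/0,
                                         /*vectorBitwidth=*/256);
  if (!info)
    return mlir::failure();
  mlir::OpBuilder builder(loop);
  mlir::SCFVectorizeParams params{info->dim, info->factor, info->masked};
  return mlir::vectorizeLoop(builder, loop, params);
}

To make the masked/unmasked distinction concrete: with a lower bound of 0, 10 iterations and factor 4, masked mode runs ceil(10/4) = 3 vector iterations and masks off the two out-of-bounds lanes, while unmasked mode runs floor(10/4) = 2 vector iterations and repurposes the original loop to handle iterations 8 and 9.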
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index d7545ee30e29a..13a9eca9cd2d3 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -16,7 +16,17 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
-
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/MemRef/IR/MemRef.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/Dialect/UB/IR/UBOps.h>
+#include <mlir/Dialect/Vector/IR/VectorOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/Interfaces/FunctionInterfaces.h>
+#include <mlir/Pass/Pass.h>
+
+/// Return type bitwidth for vectorization purposes or 0 if type cannot be
+/// vectorized.
static unsigned getTypeBitWidth(mlir::Type type) {
if (mlir::isa<mlir::IndexType>(type))
return 64; // TODO: unhardcode
@@ -46,6 +56,8 @@ static bool isSupportedVecElem(mlir::Type type) {
return type.isIntOrIndexOrFloat();
}
+/// Check if one `ValueRange` is a permutation of another, i.e. contains the same
+/// values, potentially in a different order.
static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) {
if (val1.size() != val2.size())
return false;
@@ -86,6 +98,10 @@ cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
return width;
}
+/// Check if memref load/store can be converted into vectorized load/store
+///
+/// Returns memref element bitwidth or `std::nullopt` if access cannot be
+/// vectorized.
static std::optional<unsigned>
cavTriviallyVectorizeMemOp(mlir::scf::ParallelOp loop, unsigned dim,
mlir::Operation &op) {
@@ -104,9 +120,10 @@ static bool isOp(mlir::Operation &op) {
return mlir::isa<T>(op);
}
+/// Returns `vector.reduce` kind for specified `scf.parallel` reduce op or
+/// `std::nullopt` if reduction cannot be handled by `vector.reduce`.
static std::optional<mlir::vector::CombiningKind>
-getReductionKind(mlir::scf::ReduceOp op) {
- mlir::Block &body = op.getReductionOperator().front();
+getReductionKind(mlir::Block &body) {
if (!llvm::hasSingleElement(body.without_terminator()))
return std::nullopt;
@@ -140,23 +157,31 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
if (factor <= 1)
return std::nullopt;
+ /// Only step==1 is supported for now.
if (!mlir::isConstantIntValue(loop.getStep()[dim], 1))
return std::nullopt;
unsigned count = 0;
bool masked = true;
- for (mlir::Operation &op : loop.getBody()->without_terminator()) {
- if (auto reduce = mlir::dyn_cast<mlir::scf::ReduceOp>(op)) {
- if (!getReductionKind(reduce))
- masked = false;
+ /// Check if `scf.reduce` can be handled by `vector.reduce`.
+ /// If not, we can still vectorize the loop, but we cannot use masked
+ /// vectorization.
+ auto reduce =
+ mlir::cast<mlir::scf::ReduceOp>(loop.getBody()->getTerminator());
+ for (mlir::Region &reg : reduce.getReductions()) {
+ if (!getReductionKind(reg.front()))
+ masked = false;
- continue;
- }
+ continue;
+ }
+ for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+ /// Ops with nested regions are not supported yet.
if (op.getNumRegions() > 0)
return std::nullopt;
+ /// Check mem ops.
if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op)) {
auto newFactor = vectorBitwidth / *w;
if (newFactor > 1) {
@@ -166,6 +191,8 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
continue;
}
+ /// If we meet an op which cannot be vectorized, we can replicate it and still
+ /// potentially vectorize other ops, but we cannot use masked vectorization.
if (!isSupportedVectorOp(op)) {
masked = false;
continue;
@@ -184,12 +211,14 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
++count;
}
+ /// No ops to vectorize.
if (count == 0)
return std::nullopt;
return SCFVectorizeInfo{dim, factor, count, masked};
}
+/// Get fastmath flags if the op supports them, or the default (none).
static mlir::arith::FastMathFlags getFMF(mlir::Operation &op) {
if (auto fmf = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(op))
return fmf.getFastMathFlagsAttr().getValue();
@@ -226,6 +255,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
mlir::Value count =
builder.create<mlir::arith::SubIOp>(loc, origUpper, origLower);
mlir::Value newCount;
+
+ // Compute new loop count, ceildiv if masked, floordiv otherwise.
if (masked) {
mlir::Value incCount =
builder.create<mlir::arith::AddIOp>(loc, count, factorVal);
@@ -240,6 +271,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
lower[dim] = zero;
upper[dim] = newCount;
+ // Vectorized loop.
auto newLoop = builder.create<mlir::scf::ParallelOp>(loc, lower, upper, step,
loop.getInitVals());
auto newIndexVar = newLoop.getInductionVars()[dim];
@@ -256,10 +288,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
return builder.create<mlir::ub::PoisonOp>(loc, vecType, nullptr);
};
+ // Get vector value in new loop for provided `orig` value in source loop.
auto getVecVal = [&](mlir::Value orig) -> mlir::Value {
+ // Use cached value if present.
if (auto mapped = mapping.lookupOrNull(orig))
return mapped;
+ // Vectorized loop index: the loop index is divided by factor, so for factor N the
+ // vectorized index will look like `splat(idx) + (0, 1, ..., N - 1)`
if (orig == origIndexVar) {
auto vecType = toVectorType(builder.getIndexType());
llvm::SmallVector<mlir::Attribute> elems(factor);
@@ -283,9 +319,16 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
mlir::Value val = orig;
auto origIndexVars = loop.getInductionVars();
auto it = llvm::find(origIndexVars, orig);
+
+ // If loop index, but not on vectorized dimension, just take new loop index
+ // and splat it.
if (it != origIndexVars.end())
val = newLoop.getInductionVars()[it - origIndexVars.begin()];
+ // Values which are defined inside loop body are preemptively added to the
+ // mapper and not handled here. Values defined outside body are just
+ // splatted.
+
auto vecType = toVectorType(type);
mlir::Value vec = builder.create<mlir::vector::SplatOp>(loc, val, vecType);
mapping.map(orig, vec);
@@ -293,18 +336,28 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
};
llvm::DenseMap<mlir::Value, llvm::SmallVector<mlir::Value>> unpackedVals;
+
+ // Get unpacked values for provided `orig` value in source loop.
+ // Values are returned as `ValueRange` and not as vector value.
auto getUnpackedVals = [&](mlir::Value val) -> mlir::ValueRange {
+ // Use cached values if present.
auto it = unpackedVals.find(val);
if (it != unpackedVals.end())
return it->second;
+ // Values which are defined inside loop body are preemptively added to the
+ // cache and not handled here.
+
auto &ret = unpackedVals[val];
assert(ret.empty());
if (!isSupportedVecElem(val.getType())) {
+ // Non vectorizable value, it must be a value defined outside the loop,
+ // just replicate it.
ret.resize(factor, val);
return ret;
}
+ // Get vector value and extract elements from it.
auto vecVal = getVecVal(val);
ret.resize(factor);
for (auto i : llvm::seq(0u, factor)) {
@@ -314,6 +367,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
return ret;
};
+ // Add unpacked values to the cache.
auto setUnpackedVals = [&](mlir::Value origVal, mlir::ValueRange newVals) {
assert(newVals.size() == factor);
assert(unpackedVals.count(origVal) == 0);
@@ -323,6 +377,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
if (!isSupportedVecElem(type))
return;
+ // If the type is vectorizable, construct a vector and add it to the vector
+ // cache as well.
auto vecType = toVectorType(type);
mlir::Value vec = createPosionVec(vecType);
@@ -335,6 +391,9 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
};
mlir::Value mask;
+
+ // Construct the mask value and cache it. If not in masked mode, the mask is
+ // always all 1s.
auto getMask = [&]() -> mlir::Value {
if (mask)
return mask;
@@ -360,6 +419,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op);
};
+ // Get indices for vectorized memref load/store.
auto getMemrefVecIndices = [&](mlir::ValueRange indices) {
scalarMapping.clear();
scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
@@ -379,6 +439,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
return ret;
};
+ // Check if memref access can be converted into gather/scatter.
auto canGatherScatter = [&](auto op) {
auto memref = op.getMemRef();
auto memrefType = mlir::cast<mlir::MemRefType>(memref.getType());
@@ -389,6 +450,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
memrefType.getLayout().isIdentity();
};
+ // Create vectorized memref load for specified non-vectorized load.
auto genLoad = [&](auto loadOp) {
auto indices = getMemrefVecIndices(loadOp.getIndices());
auto resType = toVectorType(loadOp.getResult().getType());
@@ -406,6 +468,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
mapping.map(loadOp.getResult(), vecLoad);
};
+ // Create vectorized memref store for specified non-vectorized store.
auto genStore = [&](auto storeOp) {
auto indices = getMemrefVecIndices(storeOp.getIndices());
auto value = getVecVal(storeOp.getValueToStore());
@@ -426,6 +489,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
for (mlir::Operation &op : loop.getBody()->without_terminator()) {
loc = op.getLoc();
if (isSupportedVectorOp(op)) {
+ // If op can be vectorized, clone it with vectorized inputs and update
+ // results to vectorized types.
for (auto arg : op.getOperands())
getVecVal(arg); // init mapper for op args
@@ -436,56 +501,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
continue;
}
- if (auto reduceOp = mlir::dyn_cast<mlir::scf::ReduceOp>(op)) {
- scalarMapping.clear();
- auto &reduceBody = reduceOp.getReductionOperator().front();
- assert(reduceBody.getNumArguments() == 2);
-
- mlir::Value reduceVal;
- if (auto redKind = getReductionKind(reduceOp)) {
- mlir::Value redArg = getVecVal(reduceOp.getOperand());
- if (redArg) {
- auto neutral = mlir::arith::getNeutralElement(&reduceBody.front());
- assert(neutral);
- mlir::Value neutralVal =
- builder.create<mlir::arith::ConstantOp>(loc, *neutral);
- mlir::Value neutralVec = builder.create<mlir::vector::SplatOp>(
- loc, neutralVal, redArg.getType());
- auto mask = getMask();
- redArg = builder.create<mlir::arith::SelectOp>(loc, mask, redArg,
- neutralVec);
- }
-
- auto fmf = getFMF(reduceBody.front());
- reduceVal = builder.create<mlir::vector::ReductionOp>(loc, *redKind,
- redArg, fmf);
- } else {
- if (masked)
- return op.emitError("Cannot vectorize op in masked mode");
-
- auto reduceTerm =
- mlir::cast<mlir::scf::ReduceReturnOp>(reduceBody.getTerminator());
- auto lhs = reduceBody.getArgument(0);
- auto rhs = reduceBody.getArgument(1);
- auto unpacked = getUnpackedVals(reduceOp.getOperand());
- assert(unpacked.size() == factor);
- reduceVal = unpacked.front();
- for (auto i : llvm::seq(1u, factor)) {
- mlir::Value val = unpacked[i];
- scalarMapping.map(lhs, reduceVal);
- scalarMapping.map(rhs, val);
- for (auto &redOp : reduceBody.without_terminator())
- builder.clone(redOp, scalarMapping);
-
- reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult());
- }
- }
- scalarMapping.clear();
- scalarMapping.map(reduceOp.getOperand(), reduceVal);
- builder.clone(op, scalarMapping);
- continue;
- }
-
+ // Vectorize memref load/store ops; vector load/store is preferred over
+ // gather/scatter.
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op)) {
if (canTriviallyVectorizeMemOp(loadOp)) {
genLoad(loadOp);
@@ -558,6 +575,70 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
}
}
+ // Vectorize `scf.reduce` op.
+ auto reduceOp =
+ mlir::cast<mlir::scf::ReduceOp>(loop.getBody()->getTerminator());
+ llvm::SmallVector<mlir::Value> reduceVals;
+ reduceVals.reserve(reduceOp.getNumOperands());
+
+ for (auto &&[body, arg] :
+ llvm::zip(reduceOp.getReductions(), reduceOp.getOperands())) {
+ scalarMapping.clear();
+ mlir::Block &reduceBody = body.front();
+ assert(reduceBody.getNumArguments() == 2);
+
+ mlir::Value reduceVal;
+ if (auto redKind = getReductionKind(reduceBody)) {
+ // Generate `vector.reduce` if possible.
+ mlir::Value redArg = getVecVal(arg);
+ if (redArg) {
+ auto neutral = mlir::arith::getNeutralElement(&reduceBody.front());
+ assert(neutral);
+ mlir::Value neutralVal =
+ builder.create<mlir::arith::ConstantOp>(loc, *neutral);
+ mlir::Value neutralVec = builder.create<mlir::vector::SplatOp>(
+ loc, neutralVal, redArg.getType());
+ auto mask = getMask();
+ redArg = builder.create<mlir::arith::SelectOp>(loc, mask, redArg,
+ neutralVec);
+ }
+
+ auto fmf = getFMF(reduceBody.front());
+ reduceVal =
+ builder.create<mlir::vector::ReductionOp>(loc, *redKind, redArg, fmf);
+ } else {
+ if (masked)
+ return reduceOp.emitError("Cannot vectorize reduce op in masked mode");
+
+ // If `vector.reduce` cannot be used, unpack values and reduce them
+ // individually.
+
+ auto reduceTerm =
+ mlir::cast<mlir::scf::ReduceReturnOp>(reduceBody.getTerminator());
+ auto lhs = reduceBody.getArgument(0);
+ auto rhs = reduceBody.getArgument(1);
+ auto unpacked = getUnpackedVals(arg);
+ assert(unpacked.size() == factor);
+ reduceVal = unpacked.front();
+ for (auto i : llvm::seq(1u, factor)) {
+ mlir::Value val = unpacked[i];
+ scalarMapping.map(lhs, reduceVal);
+ scalarMapping.map(rhs, val);
+ for (auto &redOp : reduceBody.without_terminator())
+ builder.clone(redOp, scalarMapping);
+
+ reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult());
+ }
+ }
+ reduceVals.emplace_back(reduceVal);
+ }
+
+ // Clone `scf.reduce` op to reduce across loop iterations.
+ if (!reduceVals.empty())
+ builder.clone(*reduceOp)->setOperands(reduceVals);
+
+ // If in masked mode remove old loop, otherwise update loop bounds to
+ // repurpose it for handling remaining values.
if (masked) {
loop->replaceAllUsesWith(newLoop.getResults());
loop->erase();
@@ -576,14 +657,12 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
return mlir::success();
}
-llvm::StringRef getVectorLengthName() { return "numba.vector_length"; }
-
static std::optional<unsigned> getVectorLength(mlir::Operation *op) {
auto func = op->getParentOfType<mlir::FunctionOpInterface>();
if (!func)
return std::nullopt;
- auto attr = func->getAttrOfType<mlir::IntegerAttr>(getVectorLengthName());
+ auto attr = func->getAttrOfType<mlir::IntegerAttr>("mlir.vector_length");
if (!attr)
return std::nullopt;
@@ -599,7 +678,8 @@ struct SCFVectorizePass
: public mlir::PassWrapper<SCFVectorizePass, mlir::OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFVectorizePass)
- void getDependentDialects(mlir::DialectRegistry &registry) const override {
+ virtual void
+ getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::arith::ArithDialect>();
registry.insert<mlir::scf::SCFDialect>();
registry.insert<mlir::ub::UBDialect>();
@@ -611,6 +691,8 @@ struct SCFVectorizePass
std::pair<mlir::scf::ParallelOp, mlir::SCFVectorizeParams>>
toVectorize;
+ // Simple heuristic: total number of elements processed by vector ops, but
+ // prefer masked mode over non-masked.
auto getBenefit = [](const mlir::SCFVectorizeInfo &info) {
return info.factor * info.count * (int(info.masked) + 1);
};
@@ -658,4 +740,4 @@ struct SCFVectorizePass
std::unique_ptr<mlir::Pass> mlir::createSCFVectorizePass() {
return std::make_unique<SCFVectorizePass>();
-}
+}
\ No newline at end of file
>From c9a4bfe563013cfc64f0c44643c2c8e97b48757f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 2 Jun 2024 17:07:51 +0200
Subject: [PATCH 03/10] working on pass
---
mlir/include/mlir/Transforms/SCFVectorize.h | 20 +-
mlir/lib/Transforms/SCFVectorize.cpp | 449 +++++++-----------
mlir/test/Transforms/test-scf-vectorize.mlir | 272 +++++++++++
mlir/test/lib/Transforms/CMakeLists.txt | 1 +
mlir/test/lib/Transforms/TestSCFVectorize.cpp | 110 +++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
6 files changed, 570 insertions(+), 284 deletions(-)
create mode 100644 mlir/test/Transforms/test-scf-vectorize.mlir
create mode 100644 mlir/test/lib/Transforms/TestSCFVectorize.cpp
diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
index 93a7864b976ec..d2a5e3085ae37 100644
--- a/mlir/include/mlir/Transforms/SCFVectorize.h
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -9,12 +9,10 @@
#ifndef MLIR_TRANSFORMS_SCFVECTORIZE_H_
#define MLIR_TRANSFORMS_SCFVECTORIZE_H_
-#include <memory>
#include <optional>
namespace mlir {
-class OpBuilder;
-class Pass;
+class DataLayout;
struct LogicalResult;
namespace scf {
class ParallelOp;
@@ -43,9 +41,9 @@ struct SCFVectorizeInfo {
/// specified dimension.
///
/// `vectorBitwidth` - maximum vector size, in bits.
-std::optional<SCFVectorizeInfo> getLoopVectorizeInfo(mlir::scf::ParallelOp loop,
- unsigned dim,
- unsigned vectorBitwidth);
+std::optional<SCFVectorizeInfo>
+getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
+ unsigned vectorBitwidth, const DataLayout *DL = nullptr);
/// Vectorization params
struct SCFVectorizeParams {
@@ -64,11 +62,9 @@ struct SCFVectorizeParams {
/// If `masked` is `true` and loop bound is not divisible by `factor`, instead
/// of generating a second loop to process remaining iterations, extend loop count
/// and generate masked vector ops to handle out-of-bounds memory accesses.
-mlir::LogicalResult vectorizeLoop(mlir::OpBuilder &builder,
- mlir::scf::ParallelOp loop,
- const SCFVectorizeParams &params);
-
-std::unique_ptr<mlir::Pass> createSCFVectorizePass();
+mlir::LogicalResult vectorizeLoop(mlir::scf::ParallelOp loop,
+ const SCFVectorizeParams &params,
+ const DataLayout *DL = nullptr);
} // namespace mlir
-#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
\ No newline at end of file
+#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
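With this commit both entry points optionally take a `DataLayout` (so `index` and element bitwidths come from the layout instead of hard-coded defaults) and `vectorizeLoop` creates its own builder. A minimal sketch of the updated interface, assuming the caller simply queries the closest layout (the helper name is illustrative):

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/SCFVectorize.h"

// Sketch only: query dimension 0 with a 256-bit budget and vectorize in place,
// resolving type sizes through the closest data layout.
static mlir::LogicalResult tryVectorizeWithDL(mlir::scf::ParallelOp loop) {
  mlir::DataLayout dl = mlir::DataLayout::closest(loop.getOperation());
  auto info = mlir::getLoopVectorizeInfo(loop, /*dim=*/0,
                                         /*vectorBitwidth=*/256, &dl);
  if (!info)
    return mlir::failure();
  mlir::SCFVectorizeParams params{info->dim, info->factor, info->masked};
  return mlir::vectorizeLoop(loop, params, &dl);
}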
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index 13a9eca9cd2d3..29e184e584a56 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -1,4 +1,4 @@
-//===- ControlFlowSink.cpp - Code to perform control-flow sinking ---------===//
+//===- SCFVectorize.cpp - SCF vectorization utilities ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -9,27 +9,25 @@
#include "mlir/Transforms/SCFVectorize.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/IRMapping.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Pass/Pass.h"
-#include <mlir/Dialect/Arith/IR/Arith.h>
-#include <mlir/Dialect/MemRef/IR/MemRef.h>
-#include <mlir/Dialect/SCF/IR/SCF.h>
-#include <mlir/Dialect/UB/IR/UBOps.h>
-#include <mlir/Dialect/Vector/IR/VectorOps.h>
-#include <mlir/IR/IRMapping.h>
-#include <mlir/Interfaces/FunctionInterfaces.h>
-#include <mlir/Pass/Pass.h>
+
+using namespace mlir;
+
+static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); }
/// Return type bitwidth for vectorization purposes or 0 if type cannot be
/// vectorized.
-static unsigned getTypeBitWidth(mlir::Type type) {
- if (mlir::isa<mlir::IndexType>(type))
- return 64; // TODO: unhardcode
+static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
+ if (!isSupportedVecElem(type))
+ return 0;
+
+ if (DL)
+ return DL->getTypeSizeInBits(type);
if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();
@@ -37,28 +35,24 @@ static unsigned getTypeBitWidth(mlir::Type type) {
return 0;
}
-static unsigned getArgsTypeWidth(mlir::Operation &op) {
+static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) {
unsigned ret = 0;
for (auto arg : op.getOperands())
- ret = std::max(ret, getTypeBitWidth(arg.getType()));
+ ret = std::max(ret, getTypeBitWidth(arg.getType(), DL));
for (auto res : op.getResults())
- ret = std::max(ret, getTypeBitWidth(res.getType()));
+ ret = std::max(ret, getTypeBitWidth(res.getType(), DL));
return ret;
}
-static bool isSupportedVectorOp(mlir::Operation &op) {
- return op.hasTrait<mlir::OpTrait::Vectorizable>();
-}
-
-static bool isSupportedVecElem(mlir::Type type) {
- return type.isIntOrIndexOrFloat();
+static bool isSupportedVectorOp(Operation &op) {
+ return op.hasTrait<OpTrait::Vectorizable>();
}
/// Check if one `ValueRange` is a permutation of another, i.e. contains the same
/// values, potentially in a different order.
-static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) {
+static bool isRangePermutation(ValueRange val1, ValueRange val2) {
if (val1.size() != val2.size())
return false;
@@ -72,13 +66,13 @@ static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) {
template <typename Op>
static std::optional<unsigned>
-cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
- Op memOp) {
+cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
+ const DataLayout *DL) {
auto loopIndexVars = loop.getInductionVars();
assert(dim < loopIndexVars.size());
auto memref = memOp.getMemRef();
- auto type = mlir::cast<mlir::MemRefType>(memref.getType());
- auto width = getTypeBitWidth(type.getElementType());
+ auto type = cast<MemRefType>(memref.getType());
+ auto width = getTypeBitWidth(type.getElementType(), DL);
if (width == 0)
return std::nullopt;
@@ -91,7 +85,7 @@ cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
if (memOp.getIndices().back() != loopIndexVars[dim])
return std::nullopt;
- mlir::DominanceInfo dom;
+ DominanceInfo dom;
if (!dom.properlyDominates(memref, loop))
return std::nullopt;
@@ -103,54 +97,69 @@ cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
/// Returns memref element bitwidth or `std::nullopt` if access cannot be
/// vectorized.
static std::optional<unsigned>
-cavTriviallyVectorizeMemOp(mlir::scf::ParallelOp loop, unsigned dim,
- mlir::Operation &op) {
+cavTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
+ const DataLayout *DL) {
assert(dim < loop.getInductionVars().size());
- if (auto storeOp = mlir::dyn_cast<mlir::memref::StoreOp>(op))
- return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp);
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
+
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
+
+ return std::nullopt;
+}
+
+template <typename Op>
+static std::optional<unsigned> canGatherScatterImpl(scf::ParallelOp loop, Op op,
+ const DataLayout *DL) {
+ auto memref = op.getMemRef();
+ auto memrefType = cast<MemRefType>(memref.getType());
+ auto width = getTypeBitWidth(memrefType.getElementType(), DL);
+ if (width == 0)
+ return std::nullopt;
+
+ DominanceInfo dom;
+ return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
+ memrefType.getLayout().isIdentity();
+}
- if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op))
- return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp);
+/// Check if memref access can be converted into gather/scatter.
+///
+/// Returns memref element bitwidth or `std::nullopt` if access cannot be
+/// vectorized.
+static std::optional<unsigned>
+canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return canGatherScatterImpl(loop, storeOp, DL);
+
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return canGatherScatterImpl(loop, loadOp, DL);
return std::nullopt;
}
-template <typename T>
-static bool isOp(mlir::Operation &op) {
- return mlir::isa<T>(op);
+static std::optional<unsigned> cenVectorizeMemrefOp(scf::ParallelOp loop,
+ unsigned dim, Operation &op,
+ const DataLayout *DL) {
+ if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op, DL))
+ return w;
+
+ return canGatherScatter(loop, op, DL);
}
/// Returns `vector.reduce` kind for specified `scf.parallel` reduce op or
/// `std::nullopt` if reduction cannot be handled by `vector.reduce`.
-static std::optional<mlir::vector::CombiningKind>
-getReductionKind(mlir::Block &body) {
+static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
if (!llvm::hasSingleElement(body.without_terminator()))
return std::nullopt;
- mlir::Operation &redOp = body.front();
-
- using fptr_t = bool (*)(mlir::Operation &);
- using CC = mlir::vector::CombiningKind;
- const std::pair<fptr_t, CC> handlers[] = {
- // clang-format off
- {&isOp<mlir::arith::AddIOp>, CC::ADD},
- {&isOp<mlir::arith::AddFOp>, CC::ADD},
- {&isOp<mlir::arith::MulIOp>, CC::MUL},
- {&isOp<mlir::arith::MulFOp>, CC::MUL},
- // clang-format on
- };
-
- for (auto &&[handler, cc] : handlers) {
- if (handler(redOp))
- return cc;
- }
-
- return std::nullopt;
+ // TODO: Move getCombinerOpKind to vector dialect.
+ return linalg::getCombinerOpKind(&body.front());
}
-std::optional<mlir::SCFVectorizeInfo>
-mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
- unsigned vectorBitwidth) {
+std::optional<SCFVectorizeInfo>
+mlir::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
+ unsigned vectorBitwidth, const DataLayout *DL) {
assert(dim < loop.getStep().size());
assert(vectorBitwidth > 0);
unsigned factor = vectorBitwidth / 8;
@@ -158,7 +167,7 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
return std::nullopt;
/// Only step==1 is supported for now.
- if (!mlir::isConstantIntValue(loop.getStep()[dim], 1))
+ if (!isConstantIntValue(loop.getStep()[dim], 1))
return std::nullopt;
unsigned count = 0;
@@ -167,22 +176,21 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
/// Check if `scf.reduce` can be handled by `vector.reduce`.
/// If not, we can still vectorize the loop, but we cannot use masked
/// vectorization.
- auto reduce =
- mlir::cast<mlir::scf::ReduceOp>(loop.getBody()->getTerminator());
- for (mlir::Region &reg : reduce.getReductions()) {
+ auto reduce = cast<scf::ReduceOp>(loop.getBody()->getTerminator());
+ for (Region &reg : reduce.getReductions()) {
if (!getReductionKind(reg.front()))
masked = false;
continue;
}
- for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+ for (Operation &op : loop.getBody()->without_terminator()) {
/// Ops with nested regions are not supported yet.
if (op.getNumRegions() > 0)
return std::nullopt;
/// Check mem ops.
- if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op)) {
+ if (auto w = cenVectorizeMemrefOp(loop, dim, op, DL)) {
auto newFactor = vectorBitwidth / *w;
if (newFactor > 1) {
factor = std::min(factor, newFactor);
@@ -198,7 +206,7 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
continue;
}
- auto width = getArgsTypeWidth(op);
+ auto width = getArgsTypeWidth(op, DL);
if (width == 0)
return std::nullopt;
@@ -219,26 +227,24 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
}
/// Get fastmath flags if the op supports them, or the default (none).
-static mlir::arith::FastMathFlags getFMF(mlir::Operation &op) {
- if (auto fmf = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(op))
+static arith::FastMathFlags getFMF(Operation &op) {
+ if (auto fmf = dyn_cast<arith::ArithFastMathInterface>(op))
return fmf.getFastMathFlagsAttr().getValue();
- return mlir::arith::FastMathFlags::none;
+ return arith::FastMathFlags::none;
}
-mlir::LogicalResult
-mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
- const mlir::SCFVectorizeParams &params) {
+LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
+ const SCFVectorizeParams &params,
+ const DataLayout *DL) {
auto dim = params.dim;
auto factor = params.factor;
auto masked = params.masked;
assert(dim < loop.getStep().size());
assert(factor > 1);
- assert(mlir::isConstantIntValue(loop.getStep()[dim], 1));
-
- mlir::OpBuilder::InsertionGuard g(builder);
- builder.setInsertionPoint(loop);
+ assert(isConstantIntValue(loop.getStep()[dim], 1));
+ OpBuilder builder(loop);
auto lower = llvm::to_vector(loop.getLowerBound());
auto upper = llvm::to_vector(loop.getUpperBound());
auto step = llvm::to_vector(loop.getStep());
@@ -247,49 +253,53 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
auto origIndexVar = loop.getInductionVars()[dim];
- mlir::Value factorVal =
- builder.create<mlir::arith::ConstantIndexOp>(loc, factor);
+ Value factorVal = builder.create<arith::ConstantIndexOp>(loc, factor);
auto origLower = lower[dim];
auto origUpper = upper[dim];
- mlir::Value count =
- builder.create<mlir::arith::SubIOp>(loc, origUpper, origLower);
- mlir::Value newCount;
+ Value count = builder.createOrFold<arith::SubIOp>(loc, origUpper, origLower);
+ Value newCount;
// Compute new loop count, ceildiv if masked, floordiv otherwise.
if (masked) {
- mlir::Value incCount =
- builder.create<mlir::arith::AddIOp>(loc, count, factorVal);
- mlir::Value one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
- incCount = builder.create<mlir::arith::SubIOp>(loc, incCount, one);
- newCount = builder.create<mlir::arith::DivSIOp>(loc, incCount, factorVal);
+ newCount = builder.createOrFold<arith::CeilDivSIOp>(loc, count, factorVal);
} else {
- newCount = builder.create<mlir::arith::DivSIOp>(loc, count, factorVal);
+ newCount = builder.createOrFold<arith::DivSIOp>(loc, count, factorVal);
}
- mlir::Value zero = builder.create<mlir::arith::ConstantIndexOp>(loc, 0);
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
lower[dim] = zero;
upper[dim] = newCount;
// Vectorized loop.
- auto newLoop = builder.create<mlir::scf::ParallelOp>(loc, lower, upper, step,
- loop.getInitVals());
+ auto newLoop = builder.create<scf::ParallelOp>(loc, lower, upper, step,
+ loop.getInitVals());
auto newIndexVar = newLoop.getInductionVars()[dim];
- auto toVectorType = [&](mlir::Type elemType) -> mlir::VectorType {
+ auto toVectorType = [&](Type elemType) -> VectorType {
int64_t f = factor;
- return mlir::VectorType::get(f, elemType);
+ return VectorType::get(f, elemType);
+ };
+
+ IRMapping mapping;
+ IRMapping scalarMapping;
+
+ auto createPosionVec = [&](VectorType vecType) -> Value {
+ return builder.create<ub::PoisonOp>(loc, vecType, nullptr);
};
- mlir::IRMapping mapping;
- mlir::IRMapping scalarMapping;
+ Value indexVarMult;
+ auto getrIndexVarMult = [&]() -> Value {
+ if (indexVarMult)
+ return indexVarMult;
- auto createPosionVec = [&](mlir::VectorType vecType) -> mlir::Value {
- return builder.create<mlir::ub::PoisonOp>(loc, vecType, nullptr);
+ indexVarMult =
+ builder.createOrFold<arith::MulIOp>(loc, newIndexVar, factorVal);
+ return indexVarMult;
};
// Get vector value in new loop for provided `orig` value in source loop.
- auto getVecVal = [&](mlir::Value orig) -> mlir::Value {
+ auto getVecVal = [&](Value orig) -> Value {
// Use cached value if present.
if (auto mapped = mapping.lookupOrNull(orig))
return mapped;
@@ -298,25 +308,23 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
// vectorized index will look like `splat(idx) + (0, 1, ..., N - 1)`
if (orig == origIndexVar) {
auto vecType = toVectorType(builder.getIndexType());
- llvm::SmallVector<mlir::Attribute> elems(factor);
+ llvm::SmallVector<Attribute> elems(factor);
for (auto i : llvm::seq(0u, factor))
elems[i] = builder.getIndexAttr(i);
- auto attr = mlir::DenseElementsAttr::get(vecType, elems);
- mlir::Value vec =
- builder.create<mlir::arith::ConstantOp>(loc, vecType, attr);
-
- mlir::Value idx =
- builder.create<mlir::arith::MulIOp>(loc, newIndexVar, factorVal);
- idx = builder.create<mlir::arith::AddIOp>(loc, idx, origLower);
- idx = builder.create<mlir::vector::SplatOp>(loc, idx, vecType);
- vec = builder.create<mlir::arith::AddIOp>(loc, idx, vec);
+ auto attr = DenseElementsAttr::get(vecType, elems);
+ Value vec = builder.create<arith::ConstantOp>(loc, vecType, attr);
+
+ Value idx = getrIndexVarMult();
+ idx = builder.createOrFold<arith::AddIOp>(loc, idx, origLower);
+ idx = builder.create<vector::SplatOp>(loc, idx, vecType);
+ vec = builder.createOrFold<arith::AddIOp>(loc, idx, vec);
mapping.map(orig, vec);
return vec;
}
auto type = orig.getType();
assert(isSupportedVecElem(type));
- mlir::Value val = orig;
+ Value val = orig;
auto origIndexVars = loop.getInductionVars();
auto it = llvm::find(origIndexVars, orig);
@@ -330,16 +338,16 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
// splatted.
auto vecType = toVectorType(type);
- mlir::Value vec = builder.create<mlir::vector::SplatOp>(loc, val, vecType);
+ Value vec = builder.create<vector::SplatOp>(loc, val, vecType);
mapping.map(orig, vec);
return vec;
};
- llvm::DenseMap<mlir::Value, llvm::SmallVector<mlir::Value>> unpackedVals;
+ llvm::DenseMap<Value, llvm::SmallVector<Value>> unpackedVals;
// Get unpacked values for provided `orig` value in source loop.
// Values are returned as `ValueRange` and not as vector value.
- auto getUnpackedVals = [&](mlir::Value val) -> mlir::ValueRange {
+ auto getUnpackedVals = [&](Value val) -> ValueRange {
// Use cached values if present.
auto it = unpackedVals.find(val);
if (it != unpackedVals.end())
@@ -361,14 +369,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
auto vecVal = getVecVal(val);
ret.resize(factor);
for (auto i : llvm::seq(0u, factor)) {
- mlir::Value idx = builder.create<mlir::arith::ConstantIndexOp>(loc, i);
- ret[i] = builder.create<mlir::vector::ExtractElementOp>(loc, vecVal, idx);
+ Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
+ ret[i] = builder.create<vector::ExtractElementOp>(loc, vecVal, idx);
}
return ret;
};
// Add unpacked values to the cache.
- auto setUnpackedVals = [&](mlir::Value origVal, mlir::ValueRange newVals) {
+ auto setUnpackedVals = [&](Value origVal, ValueRange newVals) {
assert(newVals.size() == factor);
assert(unpackedVals.count(origVal) == 0);
unpackedVals[origVal].append(newVals.begin(), newVals.end());
@@ -381,55 +389,53 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
// well.
auto vecType = toVectorType(type);
- mlir::Value vec = createPosionVec(vecType);
+ Value vec = createPosionVec(vecType);
for (auto i : llvm::seq(0u, factor)) {
- mlir::Value idx = builder.create<mlir::arith::ConstantIndexOp>(loc, i);
- vec = builder.create<mlir::vector::InsertElementOp>(loc, newVals[i], vec,
- idx);
+ Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
+ vec = builder.create<vector::InsertElementOp>(loc, newVals[i], vec, idx);
}
mapping.map(origVal, vec);
};
- mlir::Value mask;
+ Value mask;
// Construct the mask value and cache it. If not in masked mode, the mask is
// always all 1s.
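// For example, with factor 4, count 10 and the new induction variable equal
// to 2, maskSize = 10 - 2 * 4 = 2, i.e. only the first two lanes are active.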
- auto getMask = [&]() -> mlir::Value {
+ auto getMask = [&]() -> Value {
if (mask)
return mask;
- mlir::OpFoldResult maskSize;
+ OpFoldResult maskSize;
if (masked) {
- mlir::Value size =
- builder.create<mlir::arith::MulIOp>(loc, factorVal, newIndexVar);
- maskSize =
- builder.create<mlir::arith::SubIOp>(loc, count, size).getResult();
+ Value size = getrIndexVarMult();
+ maskSize = builder.createOrFold<arith::SubIOp>(loc, count, size);
} else {
maskSize = builder.getIndexAttr(factor);
}
auto vecType = toVectorType(builder.getI1Type());
- mask = builder.create<mlir::vector::CreateMaskOp>(loc, vecType, maskSize);
+ mask = builder.create<vector::CreateMaskOp>(loc, vecType, maskSize);
return mask;
};
- mlir::DominanceInfo dom;
-
auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
- return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op);
+ return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
+ };
+
+ auto canGatherScatter = [&](auto op) {
+ return !!::canGatherScatterImpl(loop, op, DL);
};
// Get indices for vectorized memref load/store.
- auto getMemrefVecIndices = [&](mlir::ValueRange indices) {
+ auto getMemrefVecIndices = [&](ValueRange indices) {
scalarMapping.clear();
scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
- llvm::SmallVector<mlir::Value> ret(indices.size());
+ llvm::SmallVector<Value> ret(indices.size());
for (auto &&[i, val] : llvm::enumerate(indices)) {
if (val == origIndexVar) {
- mlir::Value idx =
- builder.create<mlir::arith::MulIOp>(loc, newIndexVar, factorVal);
- idx = builder.create<mlir::arith::AddIOp>(loc, idx, origLower);
+ Value idx = getrIndexVarMult();
+ idx = builder.createOrFold<arith::AddIOp>(loc, idx, origLower);
ret[i] = idx;
continue;
}
@@ -439,31 +445,19 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
return ret;
};
- // Check if memref access can be converted into gather/scatter.
- auto canGatherScatter = [&](auto op) {
- auto memref = op.getMemRef();
- auto memrefType = mlir::cast<mlir::MemRefType>(memref.getType());
- if (!isSupportedVecElem(memrefType.getElementType()))
- return false;
-
- return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
- memrefType.getLayout().isIdentity();
- };
-
// Create vectorized memref load for specified non-vectorized load.
auto genLoad = [&](auto loadOp) {
auto indices = getMemrefVecIndices(loadOp.getIndices());
auto resType = toVectorType(loadOp.getResult().getType());
auto memref = loadOp.getMemRef();
- mlir::Value vecLoad;
+ Value vecLoad;
if (masked) {
auto mask = getMask();
auto init = createPosionVec(resType);
- vecLoad = builder.create<mlir::vector::MaskedLoadOp>(loc, resType, memref,
- indices, mask, init);
+ vecLoad = builder.create<vector::MaskedLoadOp>(loc, resType, memref,
+ indices, mask, init);
} else {
- vecLoad =
- builder.create<mlir::vector::LoadOp>(loc, resType, memref, indices);
+ vecLoad = builder.create<vector::LoadOp>(loc, resType, memref, indices);
}
mapping.map(loadOp.getResult(), vecLoad);
};
@@ -475,18 +469,17 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
auto memref = storeOp.getMemRef();
if (masked) {
auto mask = getMask();
- builder.create<mlir::vector::MaskedStoreOp>(loc, memref, indices, mask,
- value);
+ builder.create<vector::MaskedStoreOp>(loc, memref, indices, mask, value);
} else {
- builder.create<mlir::vector::StoreOp>(loc, value, memref, indices);
+ builder.create<vector::StoreOp>(loc, value, memref, indices);
}
};
- llvm::SmallVector<mlir::Value> duplicatedArgs;
- llvm::SmallVector<mlir::Value> duplicatedResults;
+ llvm::SmallVector<Value> duplicatedArgs;
+ llvm::SmallVector<Value> duplicatedResults;
builder.setInsertionPointToStart(newLoop.getBody());
- for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+ for (Operation &op : loop.getBody()->without_terminator()) {
loc = op.getLoc();
if (isSupportedVectorOp(op)) {
// If op can be vectorized, clone it with vectorized inputs and update
@@ -503,7 +496,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
// Vectorize memref load/store ops; vector load/store is preferred over
// gather/scatter.
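// For example, a unit-stride access A[%i] on the vectorized dimension maps
// to vector.load/vector.store, while a strided access such as A[%i * 2]
// has to use vector.gather/vector.scatter.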
- if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op)) {
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
if (canTriviallyVectorizeMemOp(loadOp)) {
genLoad(loadOp);
continue;
@@ -515,14 +508,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
auto indexVec = getVecVal(loadOp.getIndices()[0]);
auto init = createPosionVec(resType);
- auto gather = builder.create<mlir::vector::GatherOp>(
+ auto gather = builder.create<vector::GatherOp>(
loc, resType, memref, zero, indexVec, mask, init);
mapping.map(loadOp.getResult(), gather.getResult());
continue;
}
}
- if (auto storeOp = mlir::dyn_cast<mlir::memref::StoreOp>(op)) {
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
if (canTriviallyVectorizeMemOp(storeOp)) {
genStore(storeOp);
continue;
@@ -533,8 +526,9 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
auto mask = getMask();
auto indexVec = getVecVal(storeOp.getIndices()[0]);
- builder.create<mlir::vector::ScatterOp>(loc, memref, zero, indexVec,
- mask, value);
+ builder.create<vector::ScatterOp>(loc, memref, zero, indexVec, mask,
+ value);
+ continue;
}
}
@@ -557,7 +551,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
}
for (auto i : llvm::seq(0u, factor)) {
- auto args = mlir::ValueRange(duplicatedArgs)
+ auto args = ValueRange(duplicatedArgs)
.drop_front(numArgs * i)
.take_front(numArgs);
scalarMapping.map(op.getOperands(), args);
@@ -568,7 +562,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
}
for (auto i : llvm::seq(0u, numResults)) {
- auto results = mlir::ValueRange(duplicatedResults)
+ auto results = ValueRange(duplicatedResults)
.drop_front(factor * i)
.take_front(factor);
setUnpackedVals(op.getResult(i), results);
@@ -576,36 +570,33 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
}
// Vectorize `scf.reduce` op.
- auto reduceOp =
- mlir::cast<mlir::scf::ReduceOp>(loop.getBody()->getTerminator());
- llvm::SmallVector<mlir::Value> reduceVals;
+ auto reduceOp = cast<scf::ReduceOp>(loop.getBody()->getTerminator());
+ llvm::SmallVector<Value> reduceVals;
reduceVals.reserve(reduceOp.getNumOperands());
for (auto &&[body, arg] :
llvm::zip(reduceOp.getReductions(), reduceOp.getOperands())) {
scalarMapping.clear();
- mlir::Block &reduceBody = body.front();
+ Block &reduceBody = body.front();
assert(reduceBody.getNumArguments() == 2);
- mlir::Value reduceVal;
+ Value reduceVal;
if (auto redKind = getReductionKind(reduceBody)) {
// Generate `vector.reduce` if possible.
- mlir::Value redArg = getVecVal(arg);
+ Value redArg = getVecVal(arg);
if (redArg) {
- auto neutral = mlir::arith::getNeutralElement(&reduceBody.front());
+ auto neutral = arith::getNeutralElement(&reduceBody.front());
assert(neutral);
- mlir::Value neutralVal =
- builder.create<mlir::arith::ConstantOp>(loc, *neutral);
- mlir::Value neutralVec = builder.create<mlir::vector::SplatOp>(
- loc, neutralVal, redArg.getType());
+ Value neutralVal = builder.create<arith::ConstantOp>(loc, *neutral);
+ Value neutralVec =
+ builder.create<vector::SplatOp>(loc, neutralVal, redArg.getType());
auto mask = getMask();
- redArg = builder.create<mlir::arith::SelectOp>(loc, mask, redArg,
- neutralVec);
+ redArg = builder.create<arith::SelectOp>(loc, mask, redArg, neutralVec);
}
auto fmf = getFMF(reduceBody.front());
reduceVal =
- builder.create<mlir::vector::ReductionOp>(loc, *redKind, redArg, fmf);
+ builder.create<vector::ReductionOp>(loc, *redKind, redArg, fmf);
} else {
if (masked)
return reduceOp.emitError("Cannot vectorize reduce op in masked mode");
@@ -613,15 +604,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
// If `vector.reduce` cannot be used, unpack values and reduce them
// individually.
- auto reduceTerm =
- mlir::cast<mlir::scf::ReduceReturnOp>(reduceBody.getTerminator());
+ auto reduceTerm = cast<scf::ReduceReturnOp>(reduceBody.getTerminator());
auto lhs = reduceBody.getArgument(0);
auto rhs = reduceBody.getArgument(1);
auto unpacked = getUnpackedVals(arg);
assert(unpacked.size() == factor);
reduceVal = unpacked.front();
for (auto i : llvm::seq(1u, factor)) {
- mlir::Value val = unpacked[i];
+ Value val = unpacked[i];
scalarMapping.map(lhs, reduceVal);
scalarMapping.map(rhs, val);
for (auto &redOp : reduceBody.without_terminator())
@@ -644,9 +634,9 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
loop->erase();
} else {
builder.setInsertionPoint(loop);
- mlir::Value newLower =
- builder.create<mlir::arith::MulIOp>(loc, newCount, factorVal);
- newLower = builder.create<mlir::arith::AddIOp>(loc, origLower, newLower);
+ Value newLower =
+ builder.createOrFold<arith::MulIOp>(loc, newCount, factorVal);
+ newLower = builder.createOrFold<arith::AddIOp>(loc, origLower, newLower);
auto lowerCopy = llvm::to_vector(loop.getLowerBound());
lowerCopy[dim] = newLower;
@@ -654,90 +644,5 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
loop.getInitValsMutable().assign(newLoop.getResults());
}
- return mlir::success();
-}
-
-static std::optional<unsigned> getVectorLength(mlir::Operation *op) {
- auto func = op->getParentOfType<mlir::FunctionOpInterface>();
- if (!func)
- return std::nullopt;
-
- auto attr = func->getAttrOfType<mlir::IntegerAttr>("mlir.vector_length");
- if (!attr)
- return std::nullopt;
-
- auto val = attr.getInt();
- if (val <= 0 || val > std::numeric_limits<unsigned>::max())
- return std::nullopt;
-
- return static_cast<unsigned>(val);
+ return success();
}
-
-namespace {
-struct SCFVectorizePass
- : public mlir::PassWrapper<SCFVectorizePass, mlir::OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFVectorizePass)
-
- virtual void
- getDependentDialects(mlir::DialectRegistry ®istry) const override {
- registry.insert<mlir::arith::ArithDialect>();
- registry.insert<mlir::scf::SCFDialect>();
- registry.insert<mlir::ub::UBDialect>();
- registry.insert<mlir::vector::VectorDialect>();
- }
-
- void runOnOperation() override {
- llvm::SmallVector<
- std::pair<mlir::scf::ParallelOp, mlir::SCFVectorizeParams>>
- toVectorize;
-
- // Simple heuristic: total number of elements processed by vector ops, but
- // prefer masked mode over non-masked.
- auto getBenefit = [](const mlir::SCFVectorizeInfo &info) {
- return info.factor * info.count * (int(info.masked) + 1);
- };
-
- getOperation()->walk([&](mlir::scf::ParallelOp loop) {
- auto len = getVectorLength(loop);
- if (!len)
- return;
-
- std::optional<mlir::SCFVectorizeInfo> best;
- for (auto dim : llvm::seq(0u, loop.getNumLoops())) {
- auto info = mlir::getLoopVectorizeInfo(loop, dim, *len);
- if (!info)
- continue;
-
- if (!best) {
- best = *info;
- continue;
- }
-
- if (getBenefit(*info) > getBenefit(*best))
- best = *info;
- }
-
- if (!best)
- return;
-
- toVectorize.emplace_back(
- loop,
- mlir::SCFVectorizeParams{best->dim, best->factor, best->masked});
- });
-
- if (toVectorize.empty())
- return markAllAnalysesPreserved();
-
- mlir::OpBuilder builder(&getContext());
- for (auto &&[loop, params] : toVectorize) {
- builder.setInsertionPoint(loop);
- if (mlir::failed(mlir::vectorizeLoop(builder, loop, params)))
- return signalPassFailure();
- }
- }
-};
-} // namespace
-
-std::unique_ptr<mlir::Pass> mlir::createSCFVectorizePass() {
- return std::make_unique<SCFVectorizePass>();
-}
\ No newline at end of file
diff --git a/mlir/test/Transforms/test-scf-vectorize.mlir b/mlir/test/Transforms/test-scf-vectorize.mlir
new file mode 100644
index 0000000000000..f4a817b44aa39
--- /dev/null
+++ b/mlir/test/Transforms/test-scf-vectorize.mlir
@@ -0,0 +1,272 @@
+// RUN: mlir-opt %s --test-scf-vectorize=vector-bitwidth=128 -split-input-file | FileCheck %s
+
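+// With vector-bitwidth=128 and 32-bit element types, the expected
+// vectorization factor is 128 / 32 = 4, hence the vector<4x...> types in
+// the checks below.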
+// CHECK-LABEL: @test
+// CHECK-SAME: (%[[A:.*]]: memref<?xi32>, %[[B:.*]]: memref<?xi32>, %[[C:.*]]: memref<?xi32>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?xi32>
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index
+// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) {
+// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+// CHECK: %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1>
+// CHECK: %[[P:.*]] = ub.poison : vector<4xi32>
+// CHECK: %[[A_VAL:.*]] = vector.maskedload %[[A]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xi32>, vector<4xi1>, vector<4xi32> into vector<4xi32>
+// CHECK: %[[P:.*]] = ub.poison : vector<4xi32>
+// CHECK: %[[B_VAL:.*]] = vector.maskedload %[[B]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xi32>, vector<4xi1>, vector<4xi32> into vector<4xi32>
+// CHECK: %[[RES:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xi32>
+// CHECK: vector.maskedstore %[[C]][%1], %[[MASK]], %[[RES]] : memref<?xi32>, vector<4xi1>, vector<4xi32>
+// CHECK: scf.reduce
+func.func @test(%A: memref<?xi32>, %B: memref<?xi32>, %C: memref<?xi32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %count = memref.dim %A, %c0 : memref<?xi32>
+ scf.parallel (%i) = (%c0) to (%count) step (%c1) {
+ %1 = memref.load %A[%i] : memref<?xi32>
+ %2 = memref.load %B[%i] : memref<?xi32>
+ %3 = arith.addi %1, %2 : i32
+ memref.store %3, %C[%i] : memref<?xi32>
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test
+// CHECK-SAME: (%[[A:.*]]: memref<?xindex>, %[[B:.*]]: memref<?xindex>, %[[C:.*]]: memref<?xindex>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?xindex>
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index
+// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) {
+// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+// CHECK: %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1>
+// CHECK: %[[P:.*]] = ub.poison : vector<4xindex>
+// CHECK: %[[A_VAL:.*]] = vector.maskedload %[[A]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xindex>, vector<4xi1>, vector<4xindex> into vector<4xindex>
+// CHECK: %[[P:.*]] = ub.poison : vector<4xindex>
+// CHECK: %[[B_VAL:.*]] = vector.maskedload %[[B]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xindex>, vector<4xi1>, vector<4xindex> into vector<4xindex>
+// CHECK: %[[RES:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xindex>
+// CHECK: vector.maskedstore %[[C]][%1], %[[MASK]], %[[RES]] : memref<?xindex>, vector<4xi1>, vector<4xindex>
+// CHECK: scf.reduce
+
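+// The data layout entry below declares index to be 32 bits wide, so
+// index-typed elements are also vectorized with factor 128 / 32 = 4.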
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
+func.func @test(%A: memref<?xindex>, %B: memref<?xindex>, %C: memref<?xindex>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %count = memref.dim %A, %c0 : memref<?xindex>
+ scf.parallel (%i) = (%c0) to (%count) step (%c1) {
+ %1 = memref.load %A[%i] : memref<?xindex>
+ %2 = memref.load %B[%i] : memref<?xindex>
+ %3 = arith.addi %1, %2 : index
+ memref.store %3, %C[%i] : memref<?xindex>
+ }
+ return
+}
+}
+
+// -----
+
+func.func private @non_vectorizable(i32) -> (i32)
+
+// CHECK-LABEL: @test
+// CHECK-SAME: (%[[A:.*]]: memref<?xi32>, %[[B:.*]]: memref<?xi32>, %[[C:.*]]: memref<?xi32>) {
+// CHECK: %[[C00:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C00]] : memref<?xi32>
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[COUNT:.*]] = arith.divsi %[[DIM]], %[[C4]] : index
+// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) {
+// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+// CHECK: %[[A_VAL:.*]] = vector.load %[[A]][%[[MULT]]] : memref<?xi32>, vector<4xi32>
+// CHECK: %[[B_VAL:.*]] = vector.load %[[B]][%[[MULT]]] : memref<?xi32>, vector<4xi32>
+// CHECK: %[[R1:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xi32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[E0:.*]] = vector.extractelement %[[R1]][%[[C0]] : index] : vector<4xi32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[E1:.*]] = vector.extractelement %[[R1]][%[[C1]] : index] : vector<4xi32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[E2:.*]] = vector.extractelement %[[R1]][%[[C2]] : index] : vector<4xi32>
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[E3:.*]] = vector.extractelement %[[R1]][%[[C3]] : index] : vector<4xi32>
+// CHECK: %[[R2:.*]] = func.call @non_vectorizable(%[[E0]]) : (i32) -> i32
+// CHECK: %[[R3:.*]] = func.call @non_vectorizable(%[[E1]]) : (i32) -> i32
+// CHECK: %[[R4:.*]] = func.call @non_vectorizable(%[[E2]]) : (i32) -> i32
+// CHECK: %[[R5:.*]] = func.call @non_vectorizable(%[[E3]]) : (i32) -> i32
+// CHECK: %[[RES1:.*]] = ub.poison : vector<4xi32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[RES2:.*]] = vector.insertelement %[[R2]], %[[RES1]][%[[C0]] : index] : vector<4xi32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[RES3:.*]] = vector.insertelement %[[R3]], %[[RES2]][%[[C1]] : index] : vector<4xi32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[RES4:.*]] = vector.insertelement %[[R4]], %[[RES3]][%[[C2]] : index] : vector<4xi32>
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[RES5:.*]] = vector.insertelement %[[R5]], %[[RES4]][%[[C3]] : index] : vector<4xi32>
+// CHECK: vector.store %[[RES5]], %[[C]][%[[MULT]]] : memref<?xi32>, vector<4xi32>
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: %[[UB1:.*]] = arith.muli %[[COUNT]], %[[C4]] : index
+// CHECK: %[[UB2:.*]] = arith.addi %[[UB1]], %[[C00]] : index
+// CHECK: scf.parallel (%[[I:.*]]) = (%[[UB2]]) to (%[[DIM]]) step (%{{.*}}) {
+// CHECK: %[[A_VAL:.*]] = memref.load %[[A]][%[[I]]] : memref<?xi32>
+// CHECK: %[[B_VAL:.*]] = memref.load %[[B]][%[[I]]] : memref<?xi32>
+// CHECK: %[[R1:.*]] = arith.addi %[[A_VAL:.*]], %[[B_VAL:.*]] : i32
+// CHECK: %[[R2:.*]] = func.call @non_vectorizable(%[[R1]]) : (i32) -> i32
+// CHECK: memref.store %[[R2]], %[[C]][%[[I]]] : memref<?xi32>
+// CHECK: scf.reduce
+// CHECK: }
+func.func @test(%A: memref<?xi32>, %B: memref<?xi32>, %C: memref<?xi32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %count = memref.dim %A, %c0 : memref<?xi32>
+ scf.parallel (%i) = (%c0) to (%count) step (%c1) {
+ %1 = memref.load %A[%i] : memref<?xi32>
+ %2 = memref.load %B[%i] : memref<?xi32>
+ %3 = arith.addi %1, %2 : i32
+ %4 = func.call @non_vectorizable(%3) : (i32) -> (i32)
+ memref.store %4, %C[%i] : memref<?xi32>
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test
+// CHECK-SAME: (%[[A:.*]]: memref<?xindex>, %[[B:.*]]: memref<?xindex>, %[[C:.*]]: memref<?xindex>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?xindex>
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index
+// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) {
+// CHECK: %[[OFFSETS:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+// CHECK: %[[O1:.*]] = vector.splat %[[MULT]] : vector<4xindex>
+// CHECK: %[[O2:.*]] = arith.addi %[[O1]], %[[OFFSETS]] : vector<4xindex>
+// CHECK: %[[O3:.*]] = vector.splat %[[C2]] : vector<4xindex>
+// CHECK: %[[O4:.*]] = arith.muli %[[O2]], %[[O3]] : vector<4xindex>
+// CHECK: %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1>
+// CHECK: %[[P:.*]] = ub.poison : vector<4xindex>
+// CHECK: %[[A_VAL:.*]] = vector.gather %arg0[%{{.*}}] [%[[O4]]], %[[MASK]], %[[P]] : memref<?xindex>, vector<4xindex>, vector<4xi1>, vector<4xindex> into vector<4xindex>
+// CHECK: %[[P:.*]] = ub.poison : vector<4xindex>
+// CHECK: %[[B_VAL:.*]] = vector.maskedload %[[B]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xindex>, vector<4xi1>, vector<4xindex> into vector<4xindex>
+// CHECK: %[[RES:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xindex>
+// CHECK: vector.scatter %[[C]][%{{.*}}] [%[[O4]]], %[[MASK]], %[[RES]] : memref<?xindex>, vector<4xindex>, vector<4xi1>, vector<4xindex>
+// CHECK: scf.reduce
+
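+// The accesses indexed by %0 = %i * 2 are strided, so they are vectorized
+// as vector.gather/vector.scatter, while the unit-stride access to %B
+// still uses a plain masked load.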
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
+func.func @test(%A: memref<?xindex>, %B: memref<?xindex>, %C: memref<?xindex>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %count = memref.dim %A, %c0 : memref<?xindex>
+ scf.parallel (%i) = (%c0) to (%count) step (%c1) {
+ %0 = arith.muli %i, %c2 : index
+ %1 = memref.load %A[%0] : memref<?xindex>
+ %2 = memref.load %B[%i] : memref<?xindex>
+ %3 = arith.addi %1, %2 : index
+ memref.store %3, %C[%0] : memref<?xindex>
+ }
+ return
+}
+}
+
+// -----
+
+// CHECK-LABEL: @test
+// CHECK-SAME: (%[[A:.*]]: memref<?xf32>) -> f32 {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[INIT:.*]] = arith.constant 0.0{{.*}} : f32
+// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?xf32>
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index
+// CHECK: %[[RES:.*]] = scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) init (%[[INIT]]) -> f32 {
+// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index
+// CHECK: %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1>
+// CHECK: %[[P:.*]] = ub.poison : vector<4xf32>
+// CHECK: %[[A_VAL:.*]] = vector.maskedload %[[A]][%[[MULT]]], %[[MASK]], %[[P]] : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+// CHECK: %[[N:.*]] = arith.constant 0.0{{.*}} : f32
+// CHECK: %[[N_SPLAT:.*]] = vector.splat %[[N]] : vector<4xf32>
+// CHECK: %[[RED1:.*]] = arith.select %[[MASK]], %[[A_VAL]], %[[N_SPLAT]] : vector<4xi1>, vector<4xf32>
+// CHECK: %[[RED2:.*]] = vector.reduction <add>, %[[RED1]] : vector<4xf32> into f32
+// CHECK: scf.reduce(%[[RED2]] : f32) {
+// CHECK: ^bb0(%[[R_ARG1:.*]]: f32, %[[R_ARG2:.*]]: f32):
+// CHECK: %[[R:.*]] = arith.addf %[[R_ARG1]], %[[R_ARG2]] : f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: return %[[RES]] : f32
+func.func @test(%A: memref<?xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %init = arith.constant 0.0 : f32
+ %count = memref.dim %A, %c0 : memref<?xf32>
+ %res = scf.parallel (%i) = (%c0) to (%count) step (%c1) init (%init) -> f32 {
+ %1 = memref.load %A[%i] : memref<?xf32>
+ scf.reduce(%1 : f32) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %2 = arith.addf %lhs, %rhs : f32
+ scf.reduce.return %2 : f32
+ }
+ }
+ return %res : f32
+}
+
+
+// -----
+
+func.func private @combine(f32, f32) -> (f32)
+
+// CHECK-LABEL: @test
+// CHECK-SAME: (%[[A:.*]]: memref<?xf32>) -> f32 {
+// CHECK: %[[C00:.*]] = arith.constant 0 : index
+// CHECK: %[[INIT:.*]] = arith.constant 0.0{{.*}} : f32
+// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C00]] : memref<?xf32>
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[COUNT:.*]] = arith.divsi %[[DIM]], %[[C4]] : index
+// CHECK: %[[RES:.*]] = scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) init (%[[INIT]]) -> f32 {
+// CHECK: %[[MULT:.*]] = arith.muli %arg1, %c4 : index
+// CHECK: %[[A_VAL:.*]] = vector.load %[[A]][%[[MULT]]] : memref<?xf32>, vector<4xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[E0:.*]] = vector.extractelement %[[A_VAL]][%[[C0]] : index] : vector<4xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[E1:.*]] = vector.extractelement %[[A_VAL]][%[[C1]] : index] : vector<4xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[E2:.*]] = vector.extractelement %[[A_VAL]][%[[C2]] : index] : vector<4xf32>
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[E3:.*]] = vector.extractelement %[[A_VAL]][%[[C3]] : index] : vector<4xf32>
+// CHECK: %[[R0:.*]] = func.call @combine(%[[E0]], %[[E1]]) : (f32, f32) -> f32
+// CHECK: %[[R1:.*]] = func.call @combine(%[[R0]], %[[E2]]) : (f32, f32) -> f32
+// CHECK: %[[R2:.*]] = func.call @combine(%[[R1]], %[[E3]]) : (f32, f32) -> f32
+// CHECK: scf.reduce(%[[R2]] : f32) {
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = func.call @combine(%[[LHS]], %[[RHS]]) : (f32, f32) -> f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: }
+// CHECK: %[[UB1:.*]] = arith.muli %[[COUNT]], %[[C4]] : index
+// CHECK: %[[UB2:.*]] = arith.addi %[[UB1]], %[[C00]] : index
+// CHECK: %[[RES1:.*]] = scf.parallel (%[[I:.*]]) = (%[[UB2]]) to (%[[DIM]]) step (%{{.*}}) init (%[[RES]]) -> f32 {
+// CHECK: %[[A_VAL:.*]] = memref.load %[[A]][%[[I]]] : memref<?xf32>
+// CHECK: scf.reduce(%[[A_VAL]] : f32) {
+// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+// CHECK: %[[R:.*]] = func.call @combine(%[[LHS]], %[[RHS]]) : (f32, f32) -> f32
+// CHECK: scf.reduce.return %[[R]] : f32
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[RES1]] : f32
+func.func @test(%A: memref<?xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %init = arith.constant 0.0 : f32
+ %count = memref.dim %A, %c0 : memref<?xf32>
+ %res = scf.parallel (%i) = (%c0) to (%count) step (%c1) init (%init) -> f32 {
+ %1 = memref.load %A[%i] : memref<?xf32>
+ scf.reduce(%1 : f32) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %2 = func.call @combine(%lhs, %rhs) : (f32, f32) -> (f32)
+ scf.reduce.return %2 : f32
+ }
+ }
+ return %res : f32
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 975a41ac3d5fe..01c92199b6f3a 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_library(MLIRTestTransforms
TestInlining.cpp
TestIntRangeInference.cpp
TestMakeIsolatedFromAbove.cpp
+ TestSCFVectorize.cpp
${MLIRTestTransformsPDLSrc}
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Transforms/TestSCFVectorize.cpp b/mlir/test/lib/Transforms/TestSCFVectorize.cpp
new file mode 100644
index 0000000000000..84ea190a33de2
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestSCFVectorize.cpp
@@ -0,0 +1,110 @@
+//===- TestSCFVectorize.cpp - SCF vectorization test pass -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/SCFVectorize.h"
+
+#include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+namespace {
+struct TestSCFVectorizePass
+ : public PassWrapper<TestSCFVectorizePass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFVectorizePass)
+
+ TestSCFVectorizePass() = default;
+ TestSCFVectorizePass(const TestSCFVectorizePass &pass) : PassWrapper(pass) {}
+
+ Option<unsigned> vectorBitwidth{*this, "vector-bitwidth",
+ llvm::cl::desc("Target vector bitwidth "),
+ llvm::cl::init(128)};
+
+ StringRef getArgument() const final { return "test-scf-vectorize"; }
+ StringRef getDescription() const final { return "Test SCF vectorization"; }
+
+ virtual void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<arith::ArithDialect>();
+ registry.insert<scf::SCFDialect>();
+ registry.insert<ub::UBDialect>();
+ registry.insert<vector::VectorDialect>();
+ }
+
+ LogicalResult initializeOptions(
+ StringRef options,
+ function_ref<LogicalResult(const Twine &)> errorHandler) override {
+ if (failed(PassWrapper::initializeOptions(options, errorHandler)))
+ return failure();
+
+ if (vectorBitwidth <= 0)
+ return errorHandler("Invalid vector bitwidth: " +
+ llvm::Twine(vectorBitwidth));
+
+ return success();
+ }
+
+ void runOnOperation() override {
+ Operation *root = getOperation();
+ auto &DLAnalysis = getAnalysis<DataLayoutAnalysis>();
+
+ llvm::SmallVector<std::pair<scf::ParallelOp, SCFVectorizeParams>>
+ toVectorize;
+
+ // Simple heuristic: total number of elements processed by vector ops, but
+ // prefer masked mode over non-masked.
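+ // For example, with factor 4 and 3 vectorizable ops the benefit is
+ // 4 * 3 * 2 = 24 in masked mode and 4 * 3 * 1 = 12 otherwise.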
+ auto getBenefit = [](const SCFVectorizeInfo &info) {
+ return info.factor * info.count * (int(info.masked) + 1);
+ };
+
+ root->walk([&](scf::ParallelOp loop) {
+ const DataLayout &DL = DLAnalysis.getAbove(loop);
+ std::optional<SCFVectorizeInfo> best;
+ for (auto dim : llvm::seq(0u, loop.getNumLoops())) {
+ auto info = getLoopVectorizeInfo(loop, dim, vectorBitwidth, &DL);
+ if (!info)
+ continue;
+
+ if (!best) {
+ best = *info;
+ continue;
+ }
+
+ if (getBenefit(*info) > getBenefit(*best))
+ best = *info;
+ }
+
+ if (!best)
+ return;
+
+ toVectorize.emplace_back(
+ loop, SCFVectorizeParams{best->dim, best->factor, best->masked});
+ });
+
+ if (toVectorize.empty())
+ return markAllAnalysesPreserved();
+
+ for (auto &&[loop, params] : toVectorize) {
+ const DataLayout &DL = DLAnalysis.getAbove(loop);
+ if (failed(vectorizeLoop(loop, params, &DL)))
+ return signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTesSCFVectorize() { PassRegistration<TestSCFVectorizePass>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index d2ba3d06835fb..1ddf437233326 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -75,6 +75,7 @@ void registerInliner();
void registerMemRefBoundCheck();
void registerPatternsTestPass();
void registerSimpleParametricTilingPass();
+void registerTesSCFVectorize();
void registerTestAffineLoopParametricTilingPass();
void registerTestAliasAnalysisPass();
void registerTestArithEmulateWideIntPass();
@@ -204,6 +205,7 @@ void registerTestPasses() {
mlir::test::registerMemRefBoundCheck();
mlir::test::registerPatternsTestPass();
mlir::test::registerSimpleParametricTilingPass();
+ mlir::test::registerTesSCFVectorize();
mlir::test::registerTestAffineLoopParametricTilingPass();
mlir::test::registerTestAliasAnalysisPass();
mlir::test::registerTestArithEmulateWideIntPass();
>From e58d2924b0109a371d677d8b9aa6d6bc9cfff00c Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 12:19:00 +0200
Subject: [PATCH 04/10] fix typo
---
mlir/lib/Transforms/SCFVectorize.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index 29e184e584a56..c74cfa4abf80d 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -66,7 +66,7 @@ static bool isRangePermutation(ValueRange val1, ValueRange val2) {
template <typename Op>
static std::optional<unsigned>
-cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
+canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
const DataLayout *DL) {
auto loopIndexVars = loop.getInductionVars();
assert(dim < loopIndexVars.size());
@@ -97,14 +97,14 @@ cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
/// Returns memref element bitwidth or `std::nullopt` if access cannot be
/// vectorized.
static std::optional<unsigned>
-cavTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
+canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
const DataLayout *DL) {
assert(dim < loop.getInductionVars().size());
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
- return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
+ return canTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
if (auto loadOp = dyn_cast<memref::LoadOp>(op))
- return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
+ return canTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
return std::nullopt;
}
@@ -141,7 +141,7 @@ canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
static std::optional<unsigned> cenVectorizeMemrefOp(scf::ParallelOp loop,
unsigned dim, Operation &op,
const DataLayout *DL) {
- if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op, DL))
+ if (auto w = canTriviallyVectorizeMemOp(loop, dim, op, DL))
return w;
return canGatherScatter(loop, op, DL);
@@ -419,7 +419,7 @@ LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
};
auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
- return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
+ return !!::canTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
};
auto canGatherScatter = [&](auto op) {
>From d5167004f47c47354847e276d75aa8918f582072 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 12:20:14 +0200
Subject: [PATCH 05/10] use has_value()
---
mlir/lib/Transforms/SCFVectorize.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index c74cfa4abf80d..7b907b655976a 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -419,11 +419,11 @@ LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
};
auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
- return !!::canTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
+ return ::canTriviallyVectorizeMemOpImpl(loop, dim, op, DL).has_value();
};
auto canGatherScatter = [&](auto op) {
- return !!::canGatherScatterImpl(loop, op, DL);
+ return ::canGatherScatterImpl(loop, op, DL).has_value();
};
// Get indices for vectorized memref load/store.
>From c33fb1ad508bf326cfc73aafe2625cdde6c6e486 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 12:44:54 +0200
Subject: [PATCH 06/10] move files to scf dialect
---
.../{ => Dialect/SCF}/Transforms/SCFVectorize.h | 5 ++---
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt | 3 ++-
.../{ => Dialect/SCF}/Transforms/SCFVectorize.cpp | 14 +++++++-------
mlir/lib/Transforms/CMakeLists.txt | 1 -
.../SCF}/test-scf-vectorize.mlir | 0
mlir/test/lib/Dialect/SCF/CMakeLists.txt | 1 +
.../SCF}/TestSCFVectorize.cpp | 14 +++++++-------
mlir/test/lib/Transforms/CMakeLists.txt | 1 -
8 files changed, 19 insertions(+), 20 deletions(-)
rename mlir/include/mlir/{ => Dialect/SCF}/Transforms/SCFVectorize.h (98%)
rename mlir/lib/{ => Dialect/SCF}/Transforms/SCFVectorize.cpp (97%)
rename mlir/test/{Transforms => Dialect/SCF}/test-scf-vectorize.mlir (100%)
rename mlir/test/lib/{Transforms => Dialect/SCF}/TestSCFVectorize.cpp (87%)
diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Dialect/SCF/Transforms/SCFVectorize.h
similarity index 98%
rename from mlir/include/mlir/Transforms/SCFVectorize.h
rename to mlir/include/mlir/Dialect/SCF/Transforms/SCFVectorize.h
index d2a5e3085ae37..ebaa1edac531c 100644
--- a/mlir/include/mlir/Transforms/SCFVectorize.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/SCFVectorize.h
@@ -17,9 +17,7 @@ struct LogicalResult;
namespace scf {
class ParallelOp;
}
-} // namespace mlir
-
-namespace mlir {
+namespace scf {
/// Loop vectorization info
struct SCFVectorizeInfo {
@@ -65,6 +63,7 @@ struct SCFVectorizeParams {
mlir::LogicalResult vectorizeLoop(mlir::scf::ParallelOp loop,
const SCFVectorizeParams ¶ms,
const DataLayout *DL = nullptr);
+} // namespace scf
} // namespace mlir
#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index d363ffe941fce..898f20efa7078 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,10 +13,11 @@ add_mlir_dialect_library(MLIRSCFTransforms
ParallelLoopCollapsing.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
+ SCFVectorize.cpp
StructuralTypeConversions.cpp
TileUsingInterface.cpp
- WrapInZeroTripCheck.cpp
UpliftWhileToFor.cpp
+ WrapInZeroTripCheck.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
similarity index 97%
rename from mlir/lib/Transforms/SCFVectorize.cpp
rename to mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 7b907b655976a..7bc7fa544f286 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/SCFVectorize.h"
+#include "mlir/Dialect/SCF/Transforms/SCFVectorize.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind
@@ -157,9 +157,9 @@ static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
return linalg::getCombinerOpKind(&body.front());
}
-std::optional<SCFVectorizeInfo>
-mlir::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
- unsigned vectorBitwidth, const DataLayout *DL) {
+std::optional<scf::SCFVectorizeInfo>
+mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
+ unsigned vectorBitwidth, const DataLayout *DL) {
assert(dim < loop.getStep().size());
assert(vectorBitwidth > 0);
unsigned factor = vectorBitwidth / 8;
@@ -234,9 +234,9 @@ static arith::FastMathFlags getFMF(Operation &op) {
return arith::FastMathFlags::none;
}
-LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
- const SCFVectorizeParams ¶ms,
- const DataLayout *DL) {
+LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
+ const scf::SCFVectorizeParams ¶ms,
+ const DataLayout *DL) {
auto dim = params.dim;
auto factor = params.factor;
auto masked = params.masked;
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index ed71c73c938ed..90c0298fb5e46 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,7 +14,6 @@ add_mlir_library(MLIRTransforms
PrintIR.cpp
RemoveDeadValues.cpp
SCCP.cpp
- SCFVectorize.cpp
SROA.cpp
StripDebugInfo.cpp
SymbolDCE.cpp
diff --git a/mlir/test/Transforms/test-scf-vectorize.mlir b/mlir/test/Dialect/SCF/test-scf-vectorize.mlir
similarity index 100%
rename from mlir/test/Transforms/test-scf-vectorize.mlir
rename to mlir/test/Dialect/SCF/test-scf-vectorize.mlir
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 792430cc84b65..9af1459d17df9 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRSCFTestPasses
TestLoopParametricTiling.cpp
TestLoopUnrolling.cpp
TestSCFUtils.cpp
+ TestSCFVectorize.cpp
TestSCFWrapInZeroTripCheck.cpp
TestUpliftWhileToFor.cpp
TestWhileOpBuilder.cpp
diff --git a/mlir/test/lib/Transforms/TestSCFVectorize.cpp b/mlir/test/lib/Dialect/SCF/TestSCFVectorize.cpp
similarity index 87%
rename from mlir/test/lib/Transforms/TestSCFVectorize.cpp
rename to mlir/test/lib/Dialect/SCF/TestSCFVectorize.cpp
index 84ea190a33de2..3f92dd03438bd 100644
--- a/mlir/test/lib/Transforms/TestSCFVectorize.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFVectorize.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/SCFVectorize.h"
+#include "mlir/Dialect/SCF/Transforms/SCFVectorize.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -58,20 +58,20 @@ struct TestSCFVectorizePass
Operation *root = getOperation();
auto &DLAnalysis = getAnalysis<DataLayoutAnalysis>();
- llvm::SmallVector<std::pair<scf::ParallelOp, SCFVectorizeParams>>
+ llvm::SmallVector<std::pair<scf::ParallelOp, scf::SCFVectorizeParams>>
toVectorize;
// Simple heuristic: total number of elements processed by vector ops, but
// prefer masked mode over non-masked.
- auto getBenefit = [](const SCFVectorizeInfo &info) {
+ auto getBenefit = [](const scf::SCFVectorizeInfo &info) {
return info.factor * info.count * (int(info.masked) + 1);
};
root->walk([&](scf::ParallelOp loop) {
const DataLayout &DL = DLAnalysis.getAbove(loop);
- std::optional<SCFVectorizeInfo> best;
+ std::optional<scf::SCFVectorizeInfo> best;
for (auto dim : llvm::seq(0u, loop.getNumLoops())) {
- auto info = getLoopVectorizeInfo(loop, dim, vectorBitwidth, &DL);
+ auto info = scf::getLoopVectorizeInfo(loop, dim, vectorBitwidth, &DL);
if (!info)
continue;
@@ -88,7 +88,7 @@ struct TestSCFVectorizePass
return;
toVectorize.emplace_back(
- loop, SCFVectorizeParams{best->dim, best->factor, best->masked});
+ loop, scf::SCFVectorizeParams{best->dim, best->factor, best->masked});
});
if (toVectorize.empty())
@@ -96,7 +96,7 @@ struct TestSCFVectorizePass
for (auto &&[loop, params] : toVectorize) {
const DataLayout &DL = DLAnalysis.getAbove(loop);
- if (failed(vectorizeLoop(loop, params, &DL)))
+ if (failed(scf::vectorizeLoop(loop, params, &DL)))
return signalPassFailure();
}
}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 01c92199b6f3a..975a41ac3d5fe 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,7 +26,6 @@ add_mlir_library(MLIRTestTransforms
TestInlining.cpp
TestIntRangeInference.cpp
TestMakeIsolatedFromAbove.cpp
- TestSCFVectorize.cpp
${MLIRTestTransformsPDLSrc}
EXCLUDE_FROM_LIBMLIR
>From d8427d7f672c8be8695c71b2e76c04dae4285286 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 19:10:39 +0200
Subject: [PATCH 07/10] getTypeBitWidth std::optional
---
.../Dialect/SCF/Transforms/SCFVectorize.cpp | 46 +++++++++++--------
1 file changed, 28 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 7bc7fa544f286..7e7add925dbb9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -20,11 +20,12 @@ using namespace mlir;
static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); }
-/// Return type bitwidth for vectorization purposes or 0 if type cannot be
+/// Return type bitwidth for vectorization purposes or empty if type cannot be
/// vectorized.
-static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
+static std::optional<unsigned> getTypeBitWidth(Type type,
+ const DataLayout *DL) {
if (!isSupportedVecElem(type))
- return 0;
+ return std::nullopt;
if (DL)
return DL->getTypeSizeInBits(type);
@@ -32,16 +33,21 @@ static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();
- return 0;
+ return std::nullopt;
}
-static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) {
+static std::optional<unsigned> getArgsTypeWidth(Operation &op,
+ const DataLayout *DL) {
unsigned ret = 0;
- for (auto arg : op.getOperands())
- ret = std::max(ret, getTypeBitWidth(arg.getType(), DL));
+ for (auto r : {ValueRange(op.getOperands()), ValueRange(op.getResults())}) {
+ for (auto arg : op.getOperands()) {
+ std::optional<unsigned> w = getTypeBitWidth(arg.getType(), DL);
+ if (!w)
+ return std::nullopt;
- for (auto res : op.getResults())
- ret = std::max(ret, getTypeBitWidth(res.getType(), DL));
+ ret = std::max(ret, *w);
+ }
+ }
return ret;
}
@@ -72,8 +78,8 @@ canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
assert(dim < loopIndexVars.size());
auto memref = memOp.getMemRef();
auto type = cast<MemRefType>(memref.getType());
- auto width = getTypeBitWidth(type.getElementType(), DL);
- if (width == 0)
+ std::optional<unsigned> width = getTypeBitWidth(type.getElementType(), DL);
+ if (!width)
return std::nullopt;
if (!type.getLayout().isIdentity())
@@ -114,13 +120,17 @@ static std::optional<unsigned> canGatherScatterImpl(scf::ParallelOp loop, Op op,
const DataLayout *DL) {
auto memref = op.getMemRef();
auto memrefType = cast<MemRefType>(memref.getType());
- auto width = getTypeBitWidth(memrefType.getElementType(), DL);
- if (width == 0)
+ std::optional<unsigned> width =
+ getTypeBitWidth(memrefType.getElementType(), DL);
+ if (!width)
return std::nullopt;
DominanceInfo dom;
- return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
- memrefType.getLayout().isIdentity();
+ if (!dom.properlyDominates(memref, loop) || op.getIndices().size() != 1 ||
+ !memrefType.getLayout().isIdentity())
+ return std::nullopt;
+
+ return width;
}
// Check if memref access can be converted into gather/scatter.
@@ -206,11 +216,11 @@ mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
continue;
}
- auto width = getArgsTypeWidth(op, DL);
- if (width == 0)
+ std::optional<unsigned> width = getArgsTypeWidth(op, DL);
+ if (!width)
return std::nullopt;
- auto newFactor = vectorBitwidth / width;
+ auto newFactor = vectorBitwidth / *width;
if (newFactor <= 1)
continue;
>From f8da459b45d300ef080ac27dc7ba308ccda65c4b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 19:17:05 +0200
Subject: [PATCH 08/10] update assert messages
---
.../Dialect/SCF/Transforms/SCFVectorize.cpp | 30 +++++++++----------
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 7e7add925dbb9..536efc72a0305 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -75,7 +75,7 @@ static std::optional<unsigned>
canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
const DataLayout *DL) {
auto loopIndexVars = loop.getInductionVars();
- assert(dim < loopIndexVars.size());
+ assert(dim < loopIndexVars.size() && "Invalid loop dimension");
auto memref = memOp.getMemRef();
auto type = cast<MemRefType>(memref.getType());
std::optional<unsigned> width = getTypeBitWidth(type.getElementType(), DL);
@@ -105,7 +105,7 @@ canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
static std::optional<unsigned>
canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
const DataLayout *DL) {
- assert(dim < loop.getInductionVars().size());
+ assert(dim < loop.getInductionVars().size() && "Invalid loop dimension");
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
return canTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
@@ -170,8 +170,8 @@ static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
std::optional<scf::SCFVectorizeInfo>
mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
unsigned vectorBitwidth, const DataLayout *DL) {
- assert(dim < loop.getStep().size());
- assert(vectorBitwidth > 0);
+ assert(dim < loop.getStep().size() && "Invalid loop dimension");
+ assert(vectorBitwidth > 0 && "Invalid vector bitwidth");
unsigned factor = vectorBitwidth / 8;
if (factor <= 1)
return std::nullopt;
@@ -250,9 +250,9 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
auto dim = params.dim;
auto factor = params.factor;
auto masked = params.masked;
- assert(dim < loop.getStep().size());
- assert(factor > 1);
- assert(isConstantIntValue(loop.getStep()[dim], 1));
+ assert(dim < loop.getStep().size() && "Invalid loop dimension");
+ assert(factor > 1 && "Invalid vectorize factor");
+ assert(isConstantIntValue(loop.getStep()[dim], 1) && "Loop step must be 1");
OpBuilder builder(loop);
auto lower = llvm::to_vector(loop.getLowerBound());
@@ -332,7 +332,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
return vec;
}
auto type = orig.getType();
- assert(isSupportedVecElem(type));
+ assert(isSupportedVecElem(type) && "Unsupported vector element type");
Value val = orig;
auto origIndexVars = loop.getInductionVars();
@@ -367,7 +367,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// cache and not handled here.
auto &ret = unpackedVals[val];
- assert(ret.empty());
+ assert(ret.empty() && "Invalid unpackedVals state");
if (!isSupportedVecElem(val.getType())) {
// Non vectorizable value, it must be a value defined outside the loop,
// just replicate it.
@@ -387,8 +387,8 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// Add unpacked values to the cache.
auto setUnpackedVals = [&](Value origVal, ValueRange newVals) {
- assert(newVals.size() == factor);
- assert(unpackedVals.count(origVal) == 0);
+ assert(newVals.size() == factor && "Invalid values count");
+ assert(unpackedVals.count(origVal) == 0 && "Invalid unpackedVals state");
unpackedVals[origVal].append(newVals.begin(), newVals.end());
auto type = origVal.getType();
@@ -555,7 +555,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
for (auto &&[i, arg] : llvm::enumerate(op.getOperands())) {
auto unpacked = getUnpackedVals(arg);
- assert(unpacked.size() == factor);
+ assert(unpacked.size() == factor && "Invalid unpacked size");
for (auto j : llvm::seq(0u, factor))
duplicatedArgs[j * numArgs + i] = unpacked[j];
}
@@ -588,7 +588,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
llvm::zip(reduceOp.getReductions(), reduceOp.getOperands())) {
scalarMapping.clear();
Block &reduceBody = body.front();
- assert(reduceBody.getNumArguments() == 2);
+ assert(reduceBody.getNumArguments() == 2 && "Malformed scf.reduce");
Value reduceVal;
if (auto redKind = getReductionKind(reduceBody)) {
@@ -596,7 +596,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
Value redArg = getVecVal(arg);
if (redArg) {
auto neutral = arith::getNeutralElement(&reduceBody.front());
- assert(neutral);
+ assert(neutral && "getNeutralElement has unexpectedly failed");
Value neutralVal = builder.create<arith::ConstantOp>(loc, *neutral);
Value neutralVec =
builder.create<vector::SplatOp>(loc, neutralVal, redArg.getType());
@@ -618,7 +618,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
auto lhs = reduceBody.getArgument(0);
auto rhs = reduceBody.getArgument(1);
auto unpacked = getUnpackedVals(arg);
- assert(unpacked.size() == factor);
+ assert(unpacked.size() == factor && "Invalid unpacked size");
reduceVal = unpacked.front();
for (auto i : llvm::seq(1u, factor)) {
Value val = unpacked[i];
>From c4f9d1e7f55320a279594ffe83f7d2c07eca5860 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 19:42:01 +0200
Subject: [PATCH 09/10] remove auto
---
.../Dialect/SCF/Transforms/SCFVectorize.cpp | 139 +++++++++---------
1 file changed, 69 insertions(+), 70 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 536efc72a0305..b7d1281fb20ca 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -40,7 +40,7 @@ static std::optional<unsigned> getArgsTypeWidth(Operation &op,
const DataLayout *DL) {
unsigned ret = 0;
for (auto r : {ValueRange(op.getOperands()), ValueRange(op.getResults())}) {
- for (auto arg : op.getOperands()) {
+ for (Value arg : op.getOperands()) {
std::optional<unsigned> w = getTypeBitWidth(arg.getType(), DL);
if (!w)
return std::nullopt;
@@ -62,7 +62,7 @@ static bool isRangePermutation(ValueRange val1, ValueRange val2) {
if (val1.size() != val2.size())
return false;
- for (auto v1 : val1) {
+ for (Value v1 : val1) {
auto it = llvm::find(val2, v1);
if (it == val2.end())
return false;
@@ -74,9 +74,9 @@ template <typename Op>
static std::optional<unsigned>
canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
const DataLayout *DL) {
- auto loopIndexVars = loop.getInductionVars();
+ ValueRange loopIndexVars = loop.getInductionVars();
assert(dim < loopIndexVars.size() && "Invalid loop dimension");
- auto memref = memOp.getMemRef();
+ Value memref = memOp.getMemRef();
auto type = cast<MemRefType>(memref.getType());
std::optional<unsigned> width = getTypeBitWidth(type.getElementType(), DL);
if (!width)
@@ -118,7 +118,7 @@ canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
template <typename Op>
static std::optional<unsigned> canGatherScatterImpl(scf::ParallelOp loop, Op op,
const DataLayout *DL) {
- auto memref = op.getMemRef();
+ Value memref = op.getMemRef();
auto memrefType = cast<MemRefType>(memref.getType());
std::optional<unsigned> width =
getTypeBitWidth(memrefType.getElementType(), DL);
@@ -151,7 +151,7 @@ canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
static std::optional<unsigned> cenVectorizeMemrefOp(scf::ParallelOp loop,
unsigned dim, Operation &op,
const DataLayout *DL) {
- if (auto w = canTriviallyVectorizeMemOp(loop, dim, op, DL))
+ if (std::optional<unsigned> w = canTriviallyVectorizeMemOp(loop, dim, op, DL))
return w;
return canGatherScatter(loop, op, DL);
@@ -200,8 +200,8 @@ mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
return std::nullopt;
/// Check mem ops.
- if (auto w = cenVectorizeMemrefOp(loop, dim, op, DL)) {
- auto newFactor = vectorBitwidth / *w;
+ if (std::optional<unsigned> w = cenVectorizeMemrefOp(loop, dim, op, DL)) {
+ unsigned newFactor = vectorBitwidth / *w;
if (newFactor > 1) {
factor = std::min(factor, newFactor);
++count;
@@ -220,7 +220,7 @@ mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
if (!width)
return std::nullopt;
- auto newFactor = vectorBitwidth / *width;
+ unsigned newFactor = vectorBitwidth / *width;
if (newFactor <= 1)
continue;
@@ -247,26 +247,26 @@ static arith::FastMathFlags getFMF(Operation &op) {
LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
const scf::SCFVectorizeParams ¶ms,
const DataLayout *DL) {
- auto dim = params.dim;
- auto factor = params.factor;
- auto masked = params.masked;
+ unsigned dim = params.dim;
+ unsigned factor = params.factor;
+ bool masked = params.masked;
assert(dim < loop.getStep().size() && "Invalid loop dimension");
assert(factor > 1 && "Invalid vectorize factor");
assert(isConstantIntValue(loop.getStep()[dim], 1) && "Loop step must be 1");
OpBuilder builder(loop);
- auto lower = llvm::to_vector(loop.getLowerBound());
- auto upper = llvm::to_vector(loop.getUpperBound());
- auto step = llvm::to_vector(loop.getStep());
+ SmallVector<Value> lower = llvm::to_vector(loop.getLowerBound());
+ SmallVector<Value> upper = llvm::to_vector(loop.getUpperBound());
+ SmallVector<Value> step = llvm::to_vector(loop.getStep());
- auto loc = loop.getLoc();
+ Location loc = loop.getLoc();
- auto origIndexVar = loop.getInductionVars()[dim];
+ Value origIndexVar = loop.getInductionVars()[dim];
Value factorVal = builder.create<arith::ConstantIndexOp>(loc, factor);
- auto origLower = lower[dim];
- auto origUpper = upper[dim];
+ Value origLower = lower[dim];
+ Value origUpper = upper[dim];
Value count = builder.createOrFold<arith::SubIOp>(loc, origUpper, origLower);
Value newCount;
@@ -284,10 +284,10 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// Vectorized loop.
auto newLoop = builder.create<scf::ParallelOp>(loc, lower, upper, step,
loop.getInitVals());
- auto newIndexVar = newLoop.getInductionVars()[dim];
+ Value newIndexVar = newLoop.getInductionVars()[dim];
auto toVectorType = [&](Type elemType) -> VectorType {
- int64_t f = factor;
+ auto f = static_cast<int64_t>(factor);
return VectorType::get(f, elemType);
};
@@ -311,14 +311,14 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// Get vector value in new loop for provided `orig` value in source loop.
auto getVecVal = [&](Value orig) -> Value {
// Use cached value if present.
- if (auto mapped = mapping.lookupOrNull(orig))
+ if (Value mapped = mapping.lookupOrNull(orig))
return mapped;
// Vectorized loop index: the loop index is divided by the factor, so for
// factor N the vectorized index looks like `splat(idx) + (0, 1, ..., N - 1)`
if (orig == origIndexVar) {
- auto vecType = toVectorType(builder.getIndexType());
- llvm::SmallVector<Attribute> elems(factor);
+ VectorType vecType = toVectorType(builder.getIndexType());
+ SmallVector<Attribute> elems(factor);
for (auto i : llvm::seq(0u, factor))
elems[i] = builder.getIndexAttr(i);
auto attr = DenseElementsAttr::get(vecType, elems);
@@ -331,11 +331,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
mapping.map(orig, vec);
return vec;
}
- auto type = orig.getType();
+ Type type = orig.getType();
assert(isSupportedVecElem(type) && "Unsupported vector element type");
Value val = orig;
- auto origIndexVars = loop.getInductionVars();
+ ValueRange origIndexVars = loop.getInductionVars();
auto it = llvm::find(origIndexVars, orig);
// If loop index, but not on vectorized dimension, just take new loop index
@@ -346,14 +346,12 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// Values which are defined inside loop body are preemptively added to the
// mapper and not handled here. Values defined outside body are just
// splatted.
-
- auto vecType = toVectorType(type);
- Value vec = builder.create<vector::SplatOp>(loc, val, vecType);
+ Value vec = builder.create<vector::SplatOp>(loc, val, toVectorType(type));
mapping.map(orig, vec);
return vec;
};
- llvm::DenseMap<Value, llvm::SmallVector<Value>> unpackedVals;
+ llvm::DenseMap<Value, SmallVector<Value>> unpackedVals;
// Get unpacked values for provided `orig` value in source loop.
// Values are returned as `ValueRange` and not as vector value.
@@ -376,7 +374,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
}
// Get vector value and extract elements from it.
- auto vecVal = getVecVal(val);
+ Value vecVal = getVecVal(val);
ret.resize(factor);
for (auto i : llvm::seq(0u, factor)) {
Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
@@ -391,13 +389,13 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
assert(unpackedVals.count(origVal) == 0 && "Invalid unpackedVals state");
unpackedVals[origVal].append(newVals.begin(), newVals.end());
- auto type = origVal.getType();
+ Type type = origVal.getType();
if (!isSupportedVecElem(type))
return;
// If the type is vectorizable, construct a vector and add it to the vector
// cache as well.
- auto vecType = toVectorType(type);
+ VectorType vecType = toVectorType(type);
Value vec = createPosionVec(vecType);
for (auto i : llvm::seq(0u, factor)) {
@@ -422,7 +420,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
} else {
maskSize = builder.getIndexAttr(factor);
}
- auto vecType = toVectorType(builder.getI1Type());
+ VectorType vecType = toVectorType(builder.getI1Type());
mask = builder.create<vector::CreateMaskOp>(loc, vecType, maskSize);
return mask;
@@ -437,11 +435,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
};
// Get indices for vectorized memref load/store.
- auto getMemrefVecIndices = [&](ValueRange indices) {
+ auto getMemrefVecIndices = [&](ValueRange indices) -> SmallVector<Value> {
scalarMapping.clear();
scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
- llvm::SmallVector<Value> ret(indices.size());
+ SmallVector<Value> ret(indices.size());
for (auto &&[i, val] : llvm::enumerate(indices)) {
if (val == origIndexVar) {
Value idx = getrIndexVarMult();
@@ -457,13 +455,13 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// Create vectorized memref load for specified non-vectorized load.
auto genLoad = [&](auto loadOp) {
- auto indices = getMemrefVecIndices(loadOp.getIndices());
- auto resType = toVectorType(loadOp.getResult().getType());
- auto memref = loadOp.getMemRef();
+ SmallVector<Value> indices = getMemrefVecIndices(loadOp.getIndices());
+ VectorType resType = toVectorType(loadOp.getResult().getType());
+ Value memref = loadOp.getMemRef();
Value vecLoad;
if (masked) {
- auto mask = getMask();
- auto init = createPosionVec(resType);
+ Value mask = getMask();
+ Value init = createPosionVec(resType);
vecLoad = builder.create<vector::MaskedLoadOp>(loc, resType, memref,
indices, mask, init);
} else {
@@ -474,19 +472,19 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// Create vectorized memref store for specified non-vectorized store.
auto genStore = [&](auto storeOp) {
- auto indices = getMemrefVecIndices(storeOp.getIndices());
- auto value = getVecVal(storeOp.getValueToStore());
- auto memref = storeOp.getMemRef();
+ SmallVector<Value> indices = getMemrefVecIndices(storeOp.getIndices());
+ Value value = getVecVal(storeOp.getValueToStore());
+ Value memref = storeOp.getMemRef();
if (masked) {
- auto mask = getMask();
+ Value mask = getMask();
builder.create<vector::MaskedStoreOp>(loc, memref, indices, mask, value);
} else {
builder.create<vector::StoreOp>(loc, value, memref, indices);
}
};
- llvm::SmallVector<Value> duplicatedArgs;
- llvm::SmallVector<Value> duplicatedResults;
+ SmallVector<Value> duplicatedArgs;
+ SmallVector<Value> duplicatedResults;
builder.setInsertionPointToStart(newLoop.getBody());
for (Operation &op : loop.getBody()->without_terminator()) {
@@ -494,11 +492,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
if (isSupportedVectorOp(op)) {
// If op can be vectorized, clone it with vectorized inputs and update
// results to vectorized types.
- for (auto arg : op.getOperands())
+ for (Value arg : op.getOperands())
getVecVal(arg); // init mapper for op args
- auto newOp = builder.clone(op, mapping);
- for (auto res : newOp->getResults())
+ Operation *newOp = builder.clone(op, mapping);
+ for (Value res : newOp->getResults())
res.setType(toVectorType(res.getType()));
continue;
@@ -512,11 +510,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
continue;
}
if (canGatherScatter(loadOp)) {
- auto resType = toVectorType(loadOp.getResult().getType());
- auto memref = loadOp.getMemRef();
- auto mask = getMask();
- auto indexVec = getVecVal(loadOp.getIndices()[0]);
- auto init = createPosionVec(resType);
+ VectorType resType = toVectorType(loadOp.getResult().getType());
+ Value memref = loadOp.getMemRef();
+ Value mask = getMask();
+ Value indexVec = getVecVal(loadOp.getIndices()[0]);
+ Value init = createPosionVec(resType);
auto gather = builder.create<vector::GatherOp>(
loc, resType, memref, zero, indexVec, mask, init);
@@ -531,10 +529,10 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
continue;
}
if (canGatherScatter(storeOp)) {
- auto memref = storeOp.getMemRef();
- auto value = getVecVal(storeOp.getValueToStore());
- auto mask = getMask();
- auto indexVec = getVecVal(storeOp.getIndices()[0]);
+ Value memref = storeOp.getMemRef();
+ Value value = getVecVal(storeOp.getValueToStore());
+ Value mask = getMask();
+ Value indexVec = getVecVal(storeOp.getIndices()[0]);
builder.create<vector::ScatterOp>(loc, memref, zero, indexVec, mask,
value);
@@ -554,7 +552,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
duplicatedResults.resize(numResults * factor);
for (auto &&[i, arg] : llvm::enumerate(op.getOperands())) {
- auto unpacked = getUnpackedVals(arg);
+ ValueRange unpacked = getUnpackedVals(arg);
assert(unpacked.size() == factor && "Invalid unpacked size");
for (auto j : llvm::seq(0u, factor))
duplicatedArgs[j * numArgs + i] = unpacked[j];
@@ -565,7 +563,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
.drop_front(numArgs * i)
.take_front(numArgs);
scalarMapping.map(op.getOperands(), args);
- auto results = builder.clone(op, scalarMapping)->getResults();
+ ValueRange results = builder.clone(op, scalarMapping)->getResults();
for (auto j : llvm::seq(0u, numResults))
duplicatedResults[j * factor + i] = results[j];
@@ -581,7 +579,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// Vectorize `scf.reduce` op.
auto reduceOp = cast<scf::ReduceOp>(loop.getBody()->getTerminator());
- llvm::SmallVector<Value> reduceVals;
+ SmallVector<Value> reduceVals;
reduceVals.reserve(reduceOp.getNumOperands());
for (auto &&[body, arg] :
@@ -595,16 +593,17 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// Generate `vector.reduce` if possible.
Value redArg = getVecVal(arg);
if (redArg) {
- auto neutral = arith::getNeutralElement(&reduceBody.front());
+ std::optional<TypedAttr> neutral =
+ arith::getNeutralElement(&reduceBody.front());
assert(neutral && "getNeutralElement has unexpectedly failed");
Value neutralVal = builder.create<arith::ConstantOp>(loc, *neutral);
Value neutralVec =
builder.create<vector::SplatOp>(loc, neutralVal, redArg.getType());
- auto mask = getMask();
+ Value mask = getMask();
redArg = builder.create<arith::SelectOp>(loc, mask, redArg, neutralVec);
}
- auto fmf = getFMF(reduceBody.front());
+ arith::FastMathFlags fmf = getFMF(reduceBody.front());
reduceVal =
builder.create<vector::ReductionOp>(loc, *redKind, redArg, fmf);
} else {
@@ -615,16 +614,16 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
// individually.
auto reduceTerm = cast<scf::ReduceReturnOp>(reduceBody.getTerminator());
- auto lhs = reduceBody.getArgument(0);
- auto rhs = reduceBody.getArgument(1);
- auto unpacked = getUnpackedVals(arg);
+ Value lhs = reduceBody.getArgument(0);
+ Value rhs = reduceBody.getArgument(1);
+ ValueRange unpacked = getUnpackedVals(arg);
assert(unpacked.size() == factor && "Invalid unpacked size");
reduceVal = unpacked.front();
for (auto i : llvm::seq(1u, factor)) {
Value val = unpacked[i];
scalarMapping.map(lhs, reduceVal);
scalarMapping.map(rhs, val);
- for (auto &redOp : reduceBody.without_terminator())
+ for (Operation &redOp : reduceBody.without_terminator())
builder.clone(redOp, scalarMapping);
reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult());
@@ -648,7 +647,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
builder.createOrFold<arith::MulIOp>(loc, newCount, factorVal);
newLower = builder.createOrFold<arith::AddIOp>(loc, origLower, newLower);
- auto lowerCopy = llvm::to_vector(loop.getLowerBound());
+ SmallVector<Value> lowerCopy = llvm::to_vector(loop.getLowerBound());
lowerCopy[dim] = newLower;
loop.getLowerBoundMutable().assign(lowerCopy);
loop.getInitValsMutable().assign(newLoop.getResults());
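A note on the last hunk above: when an `scf.reduce` body has no matching `vector.reduction` kind, the vectorizer unrolls the reduction lane by lane, cloning the body `factor - 1` times so that it behaves like a left fold over the unpacked values, seeded with lane 0. A plain-C++ sketch of that fold (illustrative only; the function name is made up):

  #include <cstddef>
  #include <vector>

  // Fold the per-lane values with the (cloned) scalar reduction body.
  template <typename T, typename BinOp>
  T foldLanes(const std::vector<T> &lanes, BinOp body) {
    T acc = lanes.front();                  // reduceVal = unpacked.front();
    for (std::size_t i = 1; i < lanes.size(); ++i)
      acc = body(acc, lanes[i]);            // clone body with (lhs, rhs) mapped
    return acc;
  }

  // foldLanes(std::vector<int>{1, 2, 3, 4},
  //           [](int a, int b) { return a + b; }) == 10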
>From a4ea5d981ac3a01e5ef5ceb21175388f02820d21 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 3 Jun 2024 19:48:27 +0200
Subject: [PATCH 10/10] remove tmp var
---
mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index b7d1281fb20ca..d441dffc6c58a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -287,8 +287,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
Value newIndexVar = newLoop.getInductionVars()[dim];
auto toVectorType = [&](Type elemType) -> VectorType {
- auto f = static_cast<int64_t>(factor);
- return VectorType::get(f, elemType);
+ return VectorType::get(factor, elemType);
};
IRMapping mapping;
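The temporary dropped here only restated an integral conversion: assuming the usual `VectorType::get(ArrayRef<int64_t> shape, Type elementType)` builder, the `unsigned factor` widens to `int64_t` at the call site, so the named `f` added nothing. A plain-C++ sketch of the same simplification (illustrative only; the stand-in function is made up):

  #include <cstdint>

  // Hypothetical stand-in for a builder taking a signed 64-bit size.
  static std::int64_t makeVectorShape(std::int64_t numElements) {
    return numElements;
  }

  int main() {
    unsigned factor = 4;
    // Before: std::int64_t f = static_cast<std::int64_t>(factor);
    //         makeVectorShape(f);
    // After:  the widening unsigned -> int64_t conversion happens in the call.
    (void)makeVectorShape(factor);
    return 0;
  }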