[Mlir-commits] [mlir] [mlir][scf][vector] Add `scf.parallel` vectorizer (PR #94168)
llvmlistbot at llvm.org
Sun Jun 2 18:40:35 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-core
Author: Ivan Butygin (Hardcode84)
Changes:
Add `scf.parallel` vectorizer utilities and a test pass.
Add 2 functions:
* `getLoopVectorizeInfo` - collects `scf.parallel` loop vectorization info for a specific dimension and target vector size. Returns the number of ops that can potentially be vectorized, the vectorization factor, and whether masked mode can be used.
* `vectorizeLoop` - unrolls the specified `scf.parallel` dimension `factor` times and vectorizes ops where possible. Non-vectorizable ops are replicated.
`scf.reduce` reductions are supported and will use vector reductions where possible.
Ops with nested regions other than `scf.reduce` are not supported yet. A usage sketch of the two entry points follows below.
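A minimal sketch (not part of this patch) of how a pass might drive the two entry points; the `tryVectorize` helper and the 128-bit vector width are made-up example values:

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/SCFVectorize.h"

// Hypothetical driver: query dimension 0 of a loop and vectorize it if the
// analysis finds anything profitable.
static void tryVectorize(mlir::scf::ParallelOp loop,
                         const mlir::DataLayout *DL) {
  const unsigned vectorBitwidth = 128; // example target vector width, in bits

  // Returns std::nullopt if the loop cannot be vectorized on this dimension.
  auto info = mlir::getLoopVectorizeInfo(loop, /*dim=*/0, vectorBitwidth, DL);
  if (!info)
    return;

  mlir::SCFVectorizeParams params;
  params.dim = info->dim;
  params.factor = info->factor;
  params.masked = info->masked; // only use masked mode if the analysis allows

  // Rewrites the loop in place; non-vectorizable ops are replicated.
  if (mlir::failed(mlir::vectorizeLoop(loop, params, DL)))
    return;
}
```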
Vectorizer has 2 modes:
* Masked - unroll the loop to `ceildiv` number of iterations and use masked vector ops to handle out-of-bounds accesses.
* Non-masked - unroll to `floordiv` number of iterations and add a second loop to handle the remaining items (see the sketch after this list).
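To make the trip-count and mask arithmetic of the two modes concrete, here is a small stand-alone illustration (example values only, not taken from the patch) for a loop with bounds [0, 10), step 1 and factor 4:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  // Hypothetical loop: lower = 0, upper = 10, step = 1, vector factor = 4.
  const int lower = 0, upper = 10, factor = 4;
  const int count = upper - lower; // 10 original iterations

  // Masked mode: ceildiv -> 3 vector iterations; the last one runs with a
  // mask of length count - i*factor = 2 to guard out-of-bounds lanes.
  const int maskedTrips = (count + factor - 1) / factor; // 3
  for (int i = 0; i < maskedTrips; ++i) {
    const int active = std::min(factor, count - i * factor); // 4, 4, 2
    std::printf("masked iter %d: indices [%d, %d), %d active lanes\n", i,
                lower + i * factor, lower + i * factor + factor, active);
  }

  // Non-masked mode: floordiv -> 2 full vector iterations, plus a second
  // (scalar) loop for the count % factor = 2 remaining items.
  const int fullTrips = count / factor; // 2
  const int remainder = count % factor; // 2
  std::printf("non-masked: %d vector iterations + %d remainder items\n",
              fullTrips, remainder);
  return 0;
}
```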
Upstreamed from the `numba-mlir` project: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/SCFVectorize.cpp
Some initial upstreaming work was done by @makslevental.
---
Patch is 51.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94168.diff
7 Files Affected:
- (added) mlir/include/mlir/Transforms/SCFVectorize.h (+70)
- (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Transforms/SCFVectorize.cpp (+648)
- (added) mlir/test/Transforms/test-scf-vectorize.mlir (+272)
- (modified) mlir/test/lib/Transforms/CMakeLists.txt (+1)
- (added) mlir/test/lib/Transforms/TestSCFVectorize.cpp (+110)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+24-22)
``````````diff
diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
new file mode 100644
index 0000000000000..d2a5e3085ae37
--- /dev/null
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -0,0 +1,70 @@
+//===- SCFVectorize.h - SCF vectorization utilities -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SCFVECTORIZE_H_
+#define MLIR_TRANSFORMS_SCFVECTORIZE_H_
+
+#include <optional>
+
+namespace mlir {
+class DataLayout;
+struct LogicalResult;
+namespace scf {
+class ParallelOp;
+}
+} // namespace mlir
+
+namespace mlir {
+
+/// Loop vectorization info
+struct SCFVectorizeInfo {
+ /// Loop dimension on which to vectorize.
+ unsigned dim = 0;
+
+ /// Maximum vector length, in elements.
+ unsigned factor = 0;
+
+ /// Number of ops which will be vectorized.
+ unsigned count = 0;
+
+ /// Can use masked vector ops for out-of-bounds memory accesses.
+ bool masked = false;
+};
+
+/// Collect vectorization statistics on the specified `scf.parallel` dimension.
+/// Returns `SCFVectorizeInfo` or `std::nullopt` if the loop cannot be
+/// vectorized on the specified dimension.
+///
+/// `vectorBitwidth` - maximum vector size, in bits.
+std::optional<SCFVectorizeInfo>
+getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
+ unsigned vectorBitwidth, const DataLayout *DL = nullptr);
+
+/// Vectorization params
+struct SCFVectorizeParams {
+ /// Loop dimension on which to vectorize.
+ unsigned dim = 0;
+
+ /// Desired vector length, in elements
+ unsigned factor = 0;
+
+ /// Use masked vector ops for memory access outside loop bounds.
+ bool masked = false;
+};
+
+/// Vectorize the loop on the specified dimension with the specified factor.
+///
+/// If `masked` is `true` and the loop bound is not divisible by `factor`,
+/// instead of generating a second loop to process the remaining iterations,
+/// extend the loop count and generate masked vector ops to handle
+/// out-of-bounds memory accesses.
+mlir::LogicalResult vectorizeLoop(mlir::scf::ParallelOp loop,
+ const SCFVectorizeParams &params,
+ const DataLayout *DL = nullptr);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46..ed71c73c938ed 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms
PrintIR.cpp
RemoveDeadValues.cpp
SCCP.cpp
+ SCFVectorize.cpp
SROA.cpp
StripDebugInfo.cpp
SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
new file mode 100644
index 0000000000000..29e184e584a56
--- /dev/null
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -0,0 +1,648 @@
+//===- SCFVectorize.cpp - SCF vectorization utilities ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/SCFVectorize.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/IRMapping.h"
+
+using namespace mlir;
+
+static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); }
+
+/// Return type bitwidth for vectorization purposes or 0 if type cannot be
+/// vectorized.
+static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
+ if (!isSupportedVecElem(type))
+ return 0;
+
+ if (DL)
+ return DL->getTypeSizeInBits(type);
+
+ if (type.isIntOrFloat())
+ return type.getIntOrFloatBitWidth();
+
+ return 0;
+}
+
+static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) {
+ unsigned ret = 0;
+ for (auto arg : op.getOperands())
+ ret = std::max(ret, getTypeBitWidth(arg.getType(), DL));
+
+ for (auto res : op.getResults())
+ ret = std::max(ret, getTypeBitWidth(res.getType(), DL));
+
+ return ret;
+}
+
+static bool isSupportedVectorOp(Operation &op) {
+ return op.hasTrait<OpTrait::Vectorizable>();
+}
+
+/// Check if one `ValueRange` is a permutation of another, i.e. contains the
+/// same values, potentially in a different order.
+static bool isRangePermutation(ValueRange val1, ValueRange val2) {
+ if (val1.size() != val2.size())
+ return false;
+
+ for (auto v1 : val1) {
+ auto it = llvm::find(val2, v1);
+ if (it == val2.end())
+ return false;
+ }
+ return true;
+}
+
+template <typename Op>
+static std::optional<unsigned>
+cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
+ const DataLayout *DL) {
+ auto loopIndexVars = loop.getInductionVars();
+ assert(dim < loopIndexVars.size());
+ auto memref = memOp.getMemRef();
+ auto type = cast<MemRefType>(memref.getType());
+ auto width = getTypeBitWidth(type.getElementType(), DL);
+ if (width == 0)
+ return std::nullopt;
+
+ if (!type.getLayout().isIdentity())
+ return std::nullopt;
+
+ if (!isRangePermutation(memOp.getIndices(), loopIndexVars))
+ return std::nullopt;
+
+ if (memOp.getIndices().back() != loopIndexVars[dim])
+ return std::nullopt;
+
+ DominanceInfo dom;
+ if (!dom.properlyDominates(memref, loop))
+ return std::nullopt;
+
+ return width;
+}
+
+/// Check if memref load/store can be converted into vectorized load/store
+///
+/// Returns memref element bitwidth or `std::nullopt` if access cannot be
+/// vectorized.
+static std::optional<unsigned>
+cavTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
+ const DataLayout *DL) {
+ assert(dim < loop.getInductionVars().size());
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
+
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
+
+ return std::nullopt;
+}
+
+template <typename Op>
+static std::optional<unsigned> canGatherScatterImpl(scf::ParallelOp loop, Op op,
+ const DataLayout *DL) {
+ auto memref = op.getMemRef();
+ auto memrefType = cast<MemRefType>(memref.getType());
+ auto width = getTypeBitWidth(memrefType.getElementType(), DL);
+ if (width == 0)
+ return std::nullopt;
+
+ DominanceInfo dom;
+ if (!dom.properlyDominates(memref, loop) || op.getIndices().size() != 1 ||
+ !memrefType.getLayout().isIdentity())
+ return std::nullopt;
+ return width;
+}
+
+/// Check if memref access can be converted into gather/scatter.
+///
+/// Returns memref element bitwidth or `std::nullopt` if access cannot be
+/// vectorized.
+static std::optional<unsigned>
+canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
+ if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+ return canGatherScatterImpl(loop, storeOp, DL);
+
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+ return canGatherScatterImpl(loop, loadOp, DL);
+
+ return std::nullopt;
+}
+
+static std::optional<unsigned> cenVectorizeMemrefOp(scf::ParallelOp loop,
+ unsigned dim, Operation &op,
+ const DataLayout *DL) {
+ if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op, DL))
+ return w;
+
+ return canGatherScatter(loop, op, DL);
+}
+
+/// Returns `vector.reduce` kind for the specified `scf.parallel` reduce op or
+/// `std::nullopt` if the reduction cannot be handled by `vector.reduce`.
+static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
+ if (!llvm::hasSingleElement(body.without_terminator()))
+ return std::nullopt;
+
+ // TODO: Move getCombinerOpKind to vector dialect.
+ return linalg::getCombinerOpKind(&body.front());
+}
+
+std::optional<SCFVectorizeInfo>
+mlir::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
+ unsigned vectorBitwidth, const DataLayout *DL) {
+ assert(dim < loop.getStep().size());
+ assert(vectorBitwidth > 0);
+ unsigned factor = vectorBitwidth / 8;
+ if (factor <= 1)
+ return std::nullopt;
+
+ /// Only step==1 is supported for now.
+ if (!isConstantIntValue(loop.getStep()[dim], 1))
+ return std::nullopt;
+
+ unsigned count = 0;
+ bool masked = true;
+
+ /// Check if `scf.reduce` can be handled by `vector.reduce`.
+ /// If not, we can still vectorize the loop, but we cannot use masked
+ /// vectorization.
+ auto reduce = cast<scf::ReduceOp>(loop.getBody()->getTerminator());
+ for (Region &reg : reduce.getReductions()) {
+ if (!getReductionKind(reg.front()))
+ masked = false;
+ }
+
+ for (Operation &op : loop.getBody()->without_terminator()) {
+ /// Ops with nested regions are not supported yet.
+ if (op.getNumRegions() > 0)
+ return std::nullopt;
+
+ /// Check mem ops.
+ if (auto w = cenVectorizeMemrefOp(loop, dim, op, DL)) {
+ auto newFactor = vectorBitwidth / *w;
+ if (newFactor > 1) {
+ factor = std::min(factor, newFactor);
+ ++count;
+ }
+ continue;
+ }
+
+ /// If we meet an op which cannot be vectorized, we can replicate it and
+ /// still potentially vectorize other ops, but we cannot use masked
+ /// vectorization.
+ if (!isSupportedVectorOp(op)) {
+ masked = false;
+ continue;
+ }
+
+ auto width = getArgsTypeWidth(op, DL);
+ if (width == 0)
+ return std::nullopt;
+
+ auto newFactor = vectorBitwidth / width;
+ if (newFactor <= 1)
+ continue;
+
+ factor = std::min(factor, newFactor);
+
+ ++count;
+ }
+
+ /// No ops to vectorize.
+ if (count == 0)
+ return std::nullopt;
+
+ return SCFVectorizeInfo{dim, factor, count, masked};
+}
+
+/// Get fastmath flags if the op supports them, or default (none) otherwise.
+static arith::FastMathFlags getFMF(Operation &op) {
+ if (auto fmf = dyn_cast<arith::ArithFastMathInterface>(op))
+ return fmf.getFastMathFlagsAttr().getValue();
+
+ return arith::FastMathFlags::none;
+}
+
+LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
+ const SCFVectorizeParams &params,
+ const DataLayout *DL) {
+ auto dim = params.dim;
+ auto factor = params.factor;
+ auto masked = params.masked;
+ assert(dim < loop.getStep().size());
+ assert(factor > 1);
+ assert(isConstantIntValue(loop.getStep()[dim], 1));
+
+ OpBuilder builder(loop);
+ auto lower = llvm::to_vector(loop.getLowerBound());
+ auto upper = llvm::to_vector(loop.getUpperBound());
+ auto step = llvm::to_vector(loop.getStep());
+
+ auto loc = loop.getLoc();
+
+ auto origIndexVar = loop.getInductionVars()[dim];
+
+ Value factorVal = builder.create<arith::ConstantIndexOp>(loc, factor);
+
+ auto origLower = lower[dim];
+ auto origUpper = upper[dim];
+ Value count = builder.createOrFold<arith::SubIOp>(loc, origUpper, origLower);
+ Value newCount;
+
+ // Compute new loop count, ceildiv if masked, floordiv otherwise.
+ if (masked) {
+ newCount = builder.createOrFold<arith::CeilDivSIOp>(loc, count, factorVal);
+ } else {
+ newCount = builder.createOrFold<arith::DivSIOp>(loc, count, factorVal);
+ }
+
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ lower[dim] = zero;
+ upper[dim] = newCount;
+
+ // Vectorized loop.
+ auto newLoop = builder.create<scf::ParallelOp>(loc, lower, upper, step,
+ loop.getInitVals());
+ auto newIndexVar = newLoop.getInductionVars()[dim];
+
+ auto toVectorType = [&](Type elemType) -> VectorType {
+ int64_t f = factor;
+ return VectorType::get(f, elemType);
+ };
+
+ IRMapping mapping;
+ IRMapping scalarMapping;
+
+ auto createPosionVec = [&](VectorType vecType) -> Value {
+ return builder.create<ub::PoisonOp>(loc, vecType, nullptr);
+ };
+
+ Value indexVarMult;
+ auto getrIndexVarMult = [&]() -> Value {
+ if (indexVarMult)
+ return indexVarMult;
+
+ indexVarMult =
+ builder.createOrFold<arith::MulIOp>(loc, newIndexVar, factorVal);
+ return indexVarMult;
+ };
+
+ // Get vector value in new loop for provided `orig` value in source loop.
+ auto getVecVal = [&](Value orig) -> Value {
+ // Use cached value if present.
+ if (auto mapped = mapping.lookupOrNull(orig))
+ return mapped;
+
+ // Vectorized loop index: the loop index is divided by the factor, so for
+ // factor N the vectorized index will look like `splat(idx) + (0, 1, ..., N - 1)`
+ if (orig == origIndexVar) {
+ auto vecType = toVectorType(builder.getIndexType());
+ llvm::SmallVector<Attribute> elems(factor);
+ for (auto i : llvm::seq(0u, factor))
+ elems[i] = builder.getIndexAttr(i);
+ auto attr = DenseElementsAttr::get(vecType, elems);
+ Value vec = builder.create<arith::ConstantOp>(loc, vecType, attr);
+
+ Value idx = getrIndexVarMult();
+ idx = builder.createOrFold<arith::AddIOp>(loc, idx, origLower);
+ idx = builder.create<vector::SplatOp>(loc, idx, vecType);
+ vec = builder.createOrFold<arith::AddIOp>(loc, idx, vec);
+ mapping.map(orig, vec);
+ return vec;
+ }
+ auto type = orig.getType();
+ assert(isSupportedVecElem(type));
+
+ Value val = orig;
+ auto origIndexVars = loop.getInductionVars();
+ auto it = llvm::find(origIndexVars, orig);
+
+ // If it is a loop index, but not on the vectorized dimension, just take the
+ // new loop index and splat it.
+ if (it != origIndexVars.end())
+ val = newLoop.getInductionVars()[it - origIndexVars.begin()];
+
+ // Values which are defined inside loop body are preemptively added to the
+ // mapper and not handled here. Values defined outside body are just
+ // splatted.
+
+ auto vecType = toVectorType(type);
+ Value vec = builder.create<vector::SplatOp>(loc, val, vecType);
+ mapping.map(orig, vec);
+ return vec;
+ };
+
+ llvm::DenseMap<Value, llvm::SmallVector<Value>> unpackedVals;
+
+ // Get unpacked values for the provided value from the source loop.
+ // Values are returned as a `ValueRange` and not as a vector value.
+ auto getUnpackedVals = [&](Value val) -> ValueRange {
+ // Use cached values if present.
+ auto it = unpackedVals.find(val);
+ if (it != unpackedVals.end())
+ return it->second;
+
+ // Values which are defined inside loop body are preemptively added to the
+ // cache and not handled here.
+
+ auto &ret = unpackedVals[val];
+ assert(ret.empty());
+ if (!isSupportedVecElem(val.getType())) {
+ // Non-vectorizable value, it must be a value defined outside the loop,
+ // just replicate it.
+ ret.resize(factor, val);
+ return ret;
+ }
+
+ // Get vector value and extract elements from it.
+ auto vecVal = getVecVal(val);
+ ret.resize(factor);
+ for (auto i : llvm::seq(0u, factor)) {
+ Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
+ ret[i] = builder.create<vector::ExtractElementOp>(loc, vecVal, idx);
+ }
+ return ret;
+ };
+
+ // Add unpacked values to the cache.
+ auto setUnpackedVals = [&](Value origVal, ValueRange newVals) {
+ assert(newVals.size() == factor);
+ assert(unpackedVals.count(origVal) == 0);
+ unpackedVals[origVal].append(newVals.begin(), newVals.end());
+
+ auto type = origVal.getType();
+ if (!isSupportedVecElem(type))
+ return;
+
+ // If the type is vectorizable, construct a vector and add it to the vector
+ // cache as well.
+ auto vecType = toVectorType(type);
+
+ Value vec = createPosionVec(vecType);
+ for (auto i : llvm::seq(0u, factor)) {
+ Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
+ vec = builder.create<vector::InsertElementOp>(loc, newVals[i], vec, idx);
+ }
+ mapping.map(origVal, vec);
+ };
+
+ Value mask;
+
+ // Construct the mask value and cache it. If not in masked mode, the mask is
+ // always all 1s.
+ auto getMask = [&]() -> Value {
+ if (mask)
+ return mask;
+
+ OpFoldResult maskSize;
+ if (masked) {
+ Value size = getrIndexVarMult();
+ maskSize = builder.createOrFold<arith::SubIOp>(loc, count, size);
+ } else {
+ maskSize = builder.getIndexAttr(factor);
+ }
+ auto vecType = toVectorType(builder.getI1Type());
+ mask = builder.create<vector::CreateMaskOp>(loc, vecType, maskSize);
+
+ return mask;
+ };
+
+ auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
+ return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
+ };
+
+ auto canGatherScatter = [&](auto op) {
+ return !!::canGatherScatterImpl(loop, op, DL);
+ };
+
+ // Get indices for vectorized memref load/store.
+ auto getMemrefVecIndices = [&](ValueRange indices) {
+ scalarMapping.clear();
+ scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
+
+ llvm::SmallVector<Value> ret(indices.size());
+ for (auto &&[i, val] : llvm::enumerate(indices)) {
+ if (val == origIndexVar) {
+ Value idx = getrIndexVarMult();
+ idx = builder.createOrFold<arith::AddIOp>(loc, idx, origLower);
+ ret[i] = idx;
+ continue;
+ }
+ ret[i] = scalarMapping.lookup(val);
+ }
+
+ return ret;
+ };
+
+ // Create a vectorized memref load for the specified non-vectorized load.
+ auto genLoad = [&](auto loadOp) {
+ auto indices = getMemrefVecIndices(loadOp.getIndices());
+ auto resType = toVectorType(loadOp.getResult().getType());
+ auto memref = loadOp.getMemRef();
+ Value vecLoad;
+ if (masked) {
+ auto mask = getMask();
+ auto init = createPosionVec(resType);
+ vecLoad = builder.create<vector::MaskedLoadOp>(loc, resType, memref,
+ indices, mask, init);
+ } else {
+ vecLoad = builder.create<vector::LoadOp>(loc, resType, memref, indices);
+ }
+ mapping.map(loadOp.getResult(), vecLoad);
+ };
+
+ // Create a vectorized memref store for the specified non-vectorized store.
+ auto genStore = [&](auto storeOp) {
+ auto indices = getMemrefVecIndices(storeOp.getIndices());
+ auto value = getVecVal(storeOp.getValueToStore());
+ auto memref = storeOp.getMemRef();
+ if (masked) {
+ auto mask = getMask();
+ builder.create<vector::MaskedStoreOp>(loc, memref, indices, mask, value);
+ } else {
+ builder.create<vector::StoreOp>(loc, value, memref, indices);
+ }
+ };
+
+ llvm::SmallVector<Value> duplicatedArgs;
+ llvm::SmallVector<Value> duplicatedResults;
+
+ builder.setInsertionPointToStart(newLoop.getBody());
+ for (Operation &op : loop.getBody()->without_terminator()) {
+ loc = op.getLoc();
+ if (isSupportedVectorOp(op)) {
+ // If op can be vectorized, clone it with vectorized inputs and update
+ // results to vectorized types.
+ for (auto arg : op.getOperands())
+ getVecVal(arg); // init mapper for op args
+
+ auto newOp = builder.clone(op, mapping);
+ for (auto res : newOp->getResults())
+ res.setType(toVectorType(res.getType()));
+
+ continue;
+ }
+
+ // Vectorize memref load/store ops; vector load/store are preferred over
+ // gather/scatter.
+ if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+ if (canTriviallyVectorizeMemOp(loadOp)) {
+ genLoad(loadOp);
+ continue;
+ }
+ if (canGatherScatter(loadOp)) {
+ auto resType = toVectorType(loadOp.getResult().getType());
+ auto memref = loadOp.getMemRef();
+ auto mask = getMask();
+ auto indexVec = getVecVal(loadOp.getIndices()[0]);
+ auto init = createPosionVec(resType);
+
+ auto gather = builder.create<vector::GatherOp>(
+ loc, resType, memref, zero, indexVec, mask, init);
+ mapping.map(loadOp.getResult(), gather.getResult());
+ continue;
+ }
+ }
+
+ if (auto stor...
[truncated]
``````````
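As an aside on the `getVecVal` logic in the patch above, a tiny arithmetic sketch (illustrative values only) of how the per-lane indices for the vectorized dimension are formed:

```cpp
#include <cstdio>

int main() {
  // Illustrative values: new loop index 2, factor 4, original lower bound 0.
  const int newIndexVar = 2, factor = 4, origLower = 0;

  // getVecVal builds: splat(newIndexVar * factor + origLower) + (0, 1, ..., factor - 1)
  for (int lane = 0; lane < factor; ++lane) {
    const int idx = newIndexVar * factor + origLower + lane; // 8, 9, 10, 11
    std::printf("lane %d -> original loop index %d\n", lane, idx);
  }
  return 0;
}
```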
https://github.com/llvm/llvm-project/pull/94168