[Mlir-commits] [mlir] 5f26497 - [mlir][vector] Use `DenseI64ArrayAttr` in vector.multi_reduction (#102637)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Aug 10 06:10:27 PDT 2024
Author: Benjamin Maxwell
Date: 2024-08-10T14:10:24+01:00
New Revision: 5f26497da7de10c4eeec33b5a5cfcb47e96836cc
URL: https://github.com/llvm/llvm-project/commit/5f26497da7de10c4eeec33b5a5cfcb47e96836cc
DIFF: https://github.com/llvm/llvm-project/commit/5f26497da7de10c4eeec33b5a5cfcb47e96836cc.diff
LOG: [mlir][vector] Use `DenseI64ArrayAttr` in vector.multi_reduction (#102637)
This avoids some unnecessary conversions between `int64_t` and
`IntegerAttr`.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 925eb80dbe71ec..b96f5c2651bce5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -286,7 +286,7 @@ def Vector_MultiDimReductionOp :
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVector:$source,
AnyType:$acc,
- I64ArrayAttr:$reduction_dims)>,
+ DenseI64ArrayAttr:$reduction_dims)>,
Results<(outs AnyType:$dest)> {
let summary = "Multi-dimensional reduction operation";
let description = [{
@@ -325,8 +325,8 @@ def Vector_MultiDimReductionOp :
SmallVector<bool> getReductionMask() {
SmallVector<bool> res(getSourceVectorType().getRank(), false);
- for (auto ia : getReductionDims().getAsRange<IntegerAttr>())
- res[ia.getInt()] = true;
+ for (int64_t dim : getReductionDims())
+ res[dim] = true;
return res;
}
static SmallVector<bool> getReductionMask(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ab4485c37e5e7f..44bd4aa76ffbd6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -445,8 +445,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
for (const auto &en : llvm::enumerate(reductionMask))
if (en.value())
reductionDims.push_back(en.index());
- build(builder, result, kind, source, acc,
- builder.getI64ArrayAttr(reductionDims));
+ build(builder, result, kind, source, acc, reductionDims);
}
OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
@@ -466,12 +465,14 @@ LogicalResult MultiDimReductionOp::verify() {
SmallVector<bool> scalableDims;
Type inferredReturnType;
auto sourceScalableDims = getSourceVectorType().getScalableDims();
- for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
- if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
- return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
- })) {
- targetShape.push_back(it.value());
- scalableDims.push_back(sourceScalableDims[it.index()]);
+ for (auto [dimIdx, dimSize] :
+ llvm::enumerate(getSourceVectorType().getShape()))
+ if (!llvm::any_of(getReductionDims(),
+ [dimIdx = dimIdx](int64_t reductionDimIdx) {
+ return reductionDimIdx == static_cast<int64_t>(dimIdx);
+ })) {
+ targetShape.push_back(dimSize);
+ scalableDims.push_back(sourceScalableDims[dimIdx]);
}
// TODO: update to also allow 0-d vectors when available.
if (targetShape.empty())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index ac576ed0b4f097..716da55ba09aec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -67,10 +67,7 @@ class InnerOuterDimReductionConversion
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Separate reduction and parallel dims
- auto reductionDimsRange =
- multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
- auto reductionDims = llvm::to_vector<4>(llvm::map_range(
- reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
+ ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
reductionDims.end());
int64_t reductionSize = reductionDims.size();
More information about the Mlir-commits
mailing list