[Mlir-commits] [mlir] [mlir][vector] Use `DenseI64ArrayAttr` in vector.multi_reduction (PR #102637)

Benjamin Maxwell llvmlistbot at llvm.org
Sat Aug 10 05:39:02 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/102637

>From 1714e7f2dec96b9a5bc2f43f3b08359ddfcf9e5c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 9 Aug 2024 15:45:40 +0000
Subject: [PATCH 1/2] [mlir][vector] Use `DenseI64ArrayAttr` in
 vector.multi_reduction

This avoids unnecessary conversions between int64_t and IntegerAttr when
reading or building the op's reduction dims.
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td           | 6 +++---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp                   | 7 +++----
 .../Vector/Transforms/LowerVectorMultiReduction.cpp        | 5 +----
 3 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 925eb80dbe71ec..b96f5c2651bce5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -286,7 +286,7 @@ def Vector_MultiDimReductionOp :
     Arguments<(ins Vector_CombiningKindAttr:$kind,
                    AnyVector:$source,
                    AnyType:$acc,
-                   I64ArrayAttr:$reduction_dims)>,
+                   DenseI64ArrayAttr:$reduction_dims)>,
     Results<(outs AnyType:$dest)> {
   let summary = "Multi-dimensional reduction operation";
   let description = [{
@@ -325,8 +325,8 @@ def Vector_MultiDimReductionOp :
 
     SmallVector<bool> getReductionMask() {
       SmallVector<bool> res(getSourceVectorType().getRank(), false);
-      for (auto ia : getReductionDims().getAsRange<IntegerAttr>())
-        res[ia.getInt()] = true;
+      for (int64_t dim : getReductionDims())
+        res[dim] = true;
       return res;
     }
     static SmallVector<bool> getReductionMask(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ab4485c37e5e7f..60b4f93a53ad43 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -445,8 +445,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder,
   for (const auto &en : llvm::enumerate(reductionMask))
     if (en.value())
       reductionDims.push_back(en.index());
-  build(builder, result, kind, source, acc,
-        builder.getI64ArrayAttr(reductionDims));
+  build(builder, result, kind, source, acc, reductionDims);
 }
 
 OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
@@ -467,8 +466,8 @@ LogicalResult MultiDimReductionOp::verify() {
   Type inferredReturnType;
   auto sourceScalableDims = getSourceVectorType().getScalableDims();
   for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
-    if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
+    if (!llvm::any_of(getReductionDims(), [&](int64_t dim) {
+          return dim == static_cast<int64_t>(it.index());
         })) {
       targetShape.push_back(it.value());
       scalableDims.push_back(sourceScalableDims[it.index()]);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index ac576ed0b4f097..716da55ba09aec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -67,10 +67,7 @@ class InnerOuterDimReductionConversion
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
 
     // Separate reduction and parallel dims
-    auto reductionDimsRange =
-        multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
-    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
-        reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
     llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                   reductionDims.end());
     int64_t reductionSize = reductionDims.size();

>From 24780a2d410d1a61ecc8b47fefdd668d7b47d184 Mon Sep 17 00:00:00 2001
From: MacDue <macdue at dueutil.tech>
Date: Sat, 10 Aug 2024 13:37:24 +0100
Subject: [PATCH 2/2] Fix up

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 60b4f93a53ad43..44bd4aa76ffbd6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -465,12 +465,14 @@ LogicalResult MultiDimReductionOp::verify() {
   SmallVector<bool> scalableDims;
   Type inferredReturnType;
   auto sourceScalableDims = getSourceVectorType().getScalableDims();
-  for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
-    if (!llvm::any_of(getReductionDims(), [&](int64_t dim) {
-          return dim == static_cast<int64_t>(it.index());
-        })) {
-      targetShape.push_back(it.value());
-      scalableDims.push_back(sourceScalableDims[it.index()]);
+  for (auto [dimIdx, dimSize] :
+       llvm::enumerate(getSourceVectorType().getShape()))
+    if (!llvm::any_of(getReductionDims(),
+                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
+                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
+                      })) {
+      targetShape.push_back(dimSize);
+      scalableDims.push_back(sourceScalableDims[dimIdx]);
     }
   // TODO: update to also allow 0-d vectors when available.
   if (targetShape.empty())



More information about the Mlir-commits mailing list