[Mlir-commits] [mlir] b4444dc - [mlir][vector] Use `DenseI64ArrayAttr` for shuffle masks (#101163)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 30 07:00:17 PDT 2024


Author: Benjamin Maxwell
Date: 2024-07-30T15:00:14+01:00
New Revision: b4444dca47c41436aa781bfd38aac6eca856ef23

URL: https://github.com/llvm/llvm-project/commit/b4444dca47c41436aa781bfd38aac6eca856ef23
DIFF: https://github.com/llvm/llvm-project/commit/b4444dca47c41436aa781bfd38aac6eca856ef23.diff

LOG: [mlir][vector] Use `DenseI64ArrayAttr` for shuffle masks (#101163)

Follow-on from #100997. This again removes some boilerplate conversions
to/from IntegerAttr and int64_t (otherwise, this is an NFC).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3cdbd21874567..434ff3956c250 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -421,7 +421,7 @@ def Vector_ShuffleOp :
                  TCresVTEtIsSameAsOpBase<0, 1>>,
      InferTypeOpAdaptor]>,
      Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
-                    I64ArrayAttr:$mask)>,
+                    DenseI64ArrayAttr:$mask)>,
      Results<(outs AnyVector:$vector)> {
   let summary = "shuffle operation";
   let description = [{
@@ -459,11 +459,7 @@ def Vector_ShuffleOp :
                : vector<f32>, vector<f32>           ; yields vector<2xf32>
     ```
   }];
-  let builders = [
-    OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
-  ];
-  let hasFolder = 1;
-  let hasCanonicalizer = 1;
+
   let extraClassDeclaration = [{
     VectorType getV1VectorType() {
       return ::llvm::cast<VectorType>(getV1().getType());
@@ -475,7 +471,10 @@ def Vector_ShuffleOp :
       return ::llvm::cast<VectorType>(getVector().getType());
     }
   }];
+
   let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
+
+  let hasFolder = 1;
   let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f6b1c42dcd24c..53e18a2e9d299 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -994,7 +994,7 @@ class VectorShuffleOpConversion
     auto v2Type = shuffleOp.getV2VectorType();
     auto vectorType = shuffleOp.getResultVectorType();
     Type llvmType = typeConverter->convertType(vectorType);
-    auto maskArrayAttr = shuffleOp.getMask();
+    ArrayRef<int64_t> mask = shuffleOp.getMask();
 
     // Bail if result type cannot be lowered.
     if (!llvmType)
@@ -1015,7 +1015,7 @@ class VectorShuffleOpConversion
     if (rank <= 1 && v1Type == v2Type) {
       Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
           loc, adaptor.getV1(), adaptor.getV2(),
-          LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
+          llvm::to_vector_of<int32_t>(mask));
       rewriter.replaceOp(shuffleOp, llvmShuffleOp);
       return success();
     }
@@ -1029,8 +1029,7 @@ class VectorShuffleOpConversion
       eltType = cast<VectorType>(llvmType).getElementType();
     Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
     int64_t insPos = 0;
-    for (const auto &en : llvm::enumerate(maskArrayAttr)) {
-      int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
+    for (int64_t extPos : mask) {
       Value value = adaptor.getV1();
       if (extPos >= v1Dim) {
         extPos -= v1Dim;

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 890706bf1bb2e..21b8858989839 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -527,10 +527,7 @@ struct VectorShuffleOpConvert final
       return rewriter.notifyMatchFailure(shuffleOp,
                                          "unsupported result vector type");
 
-    SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
-        shuffleOp.getMask(), [](Attribute attr) -> int32_t {
-          return cast<IntegerAttr>(attr).getValue().getZExtValue();
-        });
+    auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
 
     VectorType oldV1Type = shuffleOp.getV1VectorType();
     VectorType oldV2Type = shuffleOp.getV2VectorType();

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 669ae586e5786..5047bd925d4c5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2464,11 +2464,6 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // ShuffleOp
 //===----------------------------------------------------------------------===//
 
-void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
-                      Value v2, ArrayRef<int64_t> mask) {
-  build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
-}
-
 LogicalResult ShuffleOp::verify() {
   VectorType resultType = getResultVectorType();
   VectorType v1Type = getV1VectorType();
@@ -2491,8 +2486,8 @@ LogicalResult ShuffleOp::verify() {
       return emitOpError("dimension mismatch");
   }
   // Verify mask length.
-  auto maskAttr = getMask().getValue();
-  int64_t maskLength = maskAttr.size();
+  ArrayRef<int64_t> mask = getMask();
+  int64_t maskLength = mask.size();
   if (maskLength <= 0)
     return emitOpError("invalid mask length");
   if (maskLength != resultType.getDimSize(0))
@@ -2500,10 +2495,9 @@ LogicalResult ShuffleOp::verify() {
   // Verify all indices.
   int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
                       (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
-  for (const auto &en : llvm::enumerate(maskAttr)) {
-    auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
-    if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
-      return emitOpError("mask index #") << (en.index() + 1) << " out of range";
+  for (auto [idx, maskPos] : llvm::enumerate(mask)) {
+    if (maskPos < 0 || maskPos >= indexSize)
+      return emitOpError("mask index #") << (idx + 1) << " out of range";
   }
   return success();
 }
@@ -2527,13 +2521,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
   return success();
 }
 
-static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
-  uint64_t expected = begin;
-  return idxArr.size() == width &&
-         llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
-                      [&expected](auto attr) {
-                        return attr.getZExtValue() == expected++;
-                      });
+template <typename T>
+static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
+  T expected = begin;
+  return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
+           return value == expected++;
+         });
 }
 
 OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
@@ -2568,8 +2561,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
   SmallVector<Attribute> results;
   auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
   auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
-  for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
-    int64_t i = index.getZExtValue();
+  for (int64_t i : this->getMask()) {
     if (i >= lhsSize) {
       results.push_back(rhsElements[i - lhsSize]);
     } else {
@@ -2590,13 +2582,13 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
   LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
                                 PatternRewriter &rewriter) const override {
     VectorType v1VectorType = shuffleOp.getV1VectorType();
-    ArrayAttr mask = shuffleOp.getMask();
+    ArrayRef<int64_t> mask = shuffleOp.getMask();
     if (v1VectorType.getRank() > 0)
       return failure();
     if (mask.size() != 1)
       return failure();
     VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
-    if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
+    if (mask[0] == 0)
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
                                                        shuffleOp.getV1());
     else
@@ -2651,11 +2643,11 @@ class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
           op, "ShuffleOp types don't match an interleave");
     }
 
-    ArrayAttr shuffleMask = op.getMask();
+    ArrayRef<int64_t> shuffleMask = op.getMask();
     int64_t resultVectorSize = resultType.getNumElements();
     for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
-      int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
-      int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
+      int64_t maskValueA = shuffleMask[i * 2];
+      int64_t maskValueB = shuffleMask[(i * 2) + 1];
       if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
         return rewriter.notifyMatchFailure(op,
                                            "ShuffleOp mask not interleaving");

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 37216cea7b615..ec2ef3fc7501c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -225,8 +225,7 @@ class Convert1DExtractStridedSliceIntoShuffle
          off += stride)
       offsets.push_back(off);
     rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
-                                           op.getVector(),
-                                           rewriter.getI64ArrayAttr(offsets));
+                                           op.getVector(), offsets);
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 4a3ae1b850517..868397f2daaae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -232,8 +232,7 @@ struct LinearizeVectorExtractStridedSlice final
     }
     // Perform a shuffle to extract the kD vector.
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        extractOp, dstType, srcVector, srcVector,
-        rewriter.getI64ArrayAttr(indices));
+        extractOp, dstType, srcVector, srcVector, indices);
     return success();
   }
 
@@ -298,20 +297,17 @@ struct LinearizeVectorShuffle final
     // that needs to be shuffled to the destination vector. If shuffleSliceLen >
     // 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
     // elements) instead of scalars.
-    ArrayAttr mask = shuffleOp.getMask();
+    ArrayRef<int64_t> mask = shuffleOp.getMask();
     int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
     llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
-    for (auto [i, value] :
-         llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
-
-      int64_t v = value.getZExtValue();
+    for (auto [i, value] : llvm::enumerate(mask)) {
       std::iota(indices.begin() + shuffleSliceLen * i,
                 indices.begin() + shuffleSliceLen * (i + 1),
-                shuffleSliceLen * v);
+                shuffleSliceLen * value);
     }
 
-    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
+                                                   vec2, indices);
     return success();
   }
 
@@ -368,8 +364,7 @@ struct LinearizeVectorExtract final
     llvm::SmallVector<int64_t, 2> indices(size);
     std::iota(indices.begin(), indices.end(), linearizedOffset);
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
-        rewriter.getI64ArrayAttr(indices));
+        extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
 
     return success();
   }
@@ -452,8 +447,7 @@ struct LinearizeVectorInsert final
                                            // [offset+srcNumElements, end)
 
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
-        rewriter.getI64ArrayAttr(indices));
+        insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
 
     return success();
   }


        


More information about the Mlir-commits mailing list