[Mlir-commits] [mlir] [mlir][vector] Use `DenseI64ArrayAttr` for constant_mask dim sizes (PR #100997)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Jul 29 04:04:00 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/100997

>From 10b7a1250b593152151143f67b08330c93129b48 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 29 Jul 2024 10:55:54 +0000
Subject: [PATCH] [mlir][vector] Use `DenseI64ArrayAttr` for constant_mask dim
 sizes

This avoids a bunch of boilerplate conversions to/from IntegerAttrs
and int64_ts. Other than that, this is an NFC.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 44 +++++++------------
 .../Vector/Transforms/LowerVectorMask.cpp     |  6 +--
 .../Transforms/VectorDropLeadUnitDim.cpp      |  6 +--
 .../Transforms/VectorEmulateNarrowType.cpp    | 17 +++----
 5 files changed, 30 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 39ad03c801140..3cdbd21874567 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2443,7 +2443,7 @@ def Vector_TypeCastOp :
 
 def Vector_ConstantMaskOp :
   Vector_Op<"constant_mask", [Pure]>,
-    Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
+    Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
     Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a constant vector mask";
   let description = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d297c40760cd8..669ae586e5786 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
     // Inspect constant mask index. If the index exceeds the
     // dimension size, all bits are set. If the index is zero
     // or less, no bits are set.
-    ArrayAttr masks = m.getMaskDimSizes();
+    ArrayRef<int64_t> masks = m.getMaskDimSizes();
     auto shape = m.getType().getShape();
     bool allTrue = true;
     bool allFalse = true;
     for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
-      int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
-      if (i < dimSize)
+      if (maskIdx < dimSize)
         allTrue = false;
-      if (i > 0)
+      if (maskIdx > 0)
         allFalse = false;
     }
     if (allTrue)
@@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final
     if (extractStridedSliceOp.hasNonUnitStrides())
       return failure();
     // Gather constant mask dimension sizes.
-    SmallVector<int64_t, 4> maskDimSizes;
-    populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
+    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     // Gather strided slice offsets and sizes.
     SmallVector<int64_t, 4> sliceOffsets;
     populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
@@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final
     // region.
     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
         extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
-        vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
+        sliceMaskDimSizes);
     return success();
   }
 };
@@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
     }
 
     if (constantMaskOp) {
-      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
       auto numMaskOperands = maskDimSizes.size();
 
       // Check every mask dim size to see whether it can be dropped
       for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
            --i) {
-        if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
+        if (maskDimSizes[i] != 1)
           return failure();
       }
 
       auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
-      ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
-
       rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
-                                                          newMaskOperandsAttr);
+                                                          newMaskOperands);
       return success();
     }
 
@@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
 
     // ConstantMaskOp case.
     auto maskDimSizes = constantMaskOp.getMaskDimSizes();
-    SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
-    applyPermutationToVector(newMaskDimSizes, permutation);
+    auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
 
     rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
-        transpOp, transpOp.getResultVectorType(),
-        ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
+        transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
     return success();
   }
 };
@@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
   if (resultType.getRank() == 0) {
     if (getMaskDimSizes().size() != 1)
       return emitError("array attr must have length 1 for 0-D vectors");
-    auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
+    auto dim = getMaskDimSizes()[0];
     if (dim != 0 && dim != 1)
       return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
     return success();
@@ -5846,9 +5840,8 @@ LogicalResult ConstantMaskOp::verify() {
   // result dimension size.
   auto resultShape = resultType.getShape();
   auto resultScalableDims = resultType.getScalableDims();
-  SmallVector<int64_t, 4> maskDimSizes;
-  for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
-    int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
+  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
+  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
     if (maskDimSize < 0 || maskDimSize > resultShape[index])
       return emitOpError(
           "array attr of size out of bounds of vector result dimension size");
@@ -5856,7 +5849,6 @@ LogicalResult ConstantMaskOp::verify() {
         maskDimSize != resultShape[index])
       return emitOpError(
           "only supports 'none set' or 'all set' scalable dimensions");
-    maskDimSizes.push_back(maskDimSize);
   }
   // Verify that if one mask dim size is zero, they all should be zero (because
   // the mask region is a conjunction of each mask dimension interval).
@@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
   // Check the corner case of 0-D vectors first.
   if (resultType.getRank() == 0) {
     assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
-    return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
+    return getMaskDimSizes()[0] == 1;
   }
-  for (const auto [resultSize, intAttr] :
+  for (const auto [resultSize, maskDimSize] :
        llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
-    int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
     if (maskDimSize < resultSize)
       return false;
   }
@@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
     }
 
     // Replace 'createMaskOp' with ConstantMaskOp.
-    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
-        createMaskOp, retTy,
-        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
+    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
+                                                maskDimSizes);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index dfeb7bc53adad..bfc05c71f5340 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -111,7 +111,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
     if (rank == 0) {
       assert(dimSizes.size() == 1 &&
              "Expected exactly one dim size for a 0-D vector");
-      bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
+      bool value = dimSizes.front() == 1;
       rewriter.replaceOpWithNewOp<arith::ConstantOp>(
           op, dstType,
           DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
@@ -119,7 +119,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
       return success();
     }
 
-    int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
+    int64_t trueDimSize = dimSizes.front();
 
     if (rank == 1) {
       if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
@@ -147,7 +147,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
 
     VectorType lowType = VectorType::Builder(dstType).dropDim(0);
     Value trueVal = rewriter.create<vector::ConstantMaskOp>(
-        loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
+        loc, lowType, dimSizes.drop_front());
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstType, rewriter.getZeroAttr(dstType));
     for (int64_t d = 0; d < trueDimSize; d++)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 7ed3dea42b771..42ac717b44c4b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -550,9 +550,7 @@ struct CastAwayConstantMaskLeadingOneDim
       return failure();
 
     int64_t dropDim = oldType.getRank() - newType.getRank();
-    SmallVector<int64_t> dimSizes;
-    for (auto attr : mask.getMaskDimSizes())
-      dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+    ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
 
     // If any of the dropped unit dims has a size of `0`, the entire mask is a
     // zero mask, else the unit dim has no effect on the mask.
@@ -563,7 +561,7 @@ struct CastAwayConstantMaskLeadingOneDim
     newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
 
     auto newMask = rewriter.create<vector::ConstantMaskOp>(
-        mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
+        mask.getLoc(), newType, newDimSizes);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ac2a4d3abcc68..d3296ee38c249 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -83,17 +83,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
     newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                     newMaskOperands);
   } else if (constantMaskOp) {
-    ArrayRef<Attribute> maskDimSizes =
-        constantMaskOp.getMaskDimSizes().getValue();
+    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     size_t numMaskOperands = maskDimSizes.size();
-    auto origIndex =
-        cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
-    IntegerAttr maskIndexAttr =
-        rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
-    SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
-    newMaskDimSizes.push_back(maskIndexAttr);
-    newMask = rewriter.create<vector::ConstantMaskOp>(
-        loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+    int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+    int64_t maskIndex = (origIndex + scale - 1) / scale;
+    SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+    newMaskDimSizes.push_back(maskIndex);
+    newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+                                                      newMaskDimSizes);
   }
 
   while (!extractOps.empty()) {



More information about the Mlir-commits mailing list