[Mlir-commits] [mlir] [mlir][vector] Update the internal representation of in_bounds (PR #100336)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Aug 13 09:46:19 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/100336

>From 490cf97ac3a7ac424a21579fe57b9ce99a33a94f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 13 Aug 2024 12:08:52 +0000
Subject: [PATCH 1/2] [mlir] Add the ability to override attribute
 parsing/printing in attr-dicts

This adds a `parseNamedAttrFn` callback to
`AsmParser::parseOptionalAttrDict()`.

If parseNamedAttrFn is provided, it can override the default parsing of a
named attribute. parseNamedAttrFn is passed the name of an attribute; if it
can parse the attribute, it returns the parsed attribute, otherwise it
returns `failure()` to indicate that generic parsing should be used. Note:
returning a null Attribute from parseNamedAttrFn indicates a parser error.

It also adds a `printNamedAttrFn` callback to
`OpAsmPrinter::printOptionalAttrDict()`.

If printNamedAttrFn is provided, it can override the default printing of a
named attribute. printNamedAttrFn is passed a NamedAttribute; if it prints
the attribute, it returns `success()`, otherwise it returns `failure()` to
indicate that generic printing should be used.
---
 mlir/include/mlir/IR/OpImplementation.h | 23 ++++++++++---
 mlir/lib/AsmParser/AsmParserImpl.h      |  7 ++--
 mlir/lib/AsmParser/AttributeParser.cpp  | 16 +++++++--
 mlir/lib/AsmParser/Parser.h             |  4 ++-
 mlir/lib/IR/AsmPrinter.cpp              | 45 ++++++++++++++++---------
 5 files changed, 71 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ae412c7227f8ea..5891cbffc9542d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
   /// If the specified operation has attributes, print out an attribute
   /// dictionary with their values.  elidedAttrs allows the client to ignore
   /// specific well known attributes, commonly used if the attribute value is
-  /// printed some other way (like as a fixed operand).
+  /// printed some other way (like as a fixed operand). If printNamedAttrFn is
+  /// provided, it can override the default printing of a named attribute: it
+  /// is passed a NamedAttribute and returns `success()` if it printed the
+  /// attribute, or `failure()` to indicate that generic printing should be
+  /// used.
   virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                                     ArrayRef<StringRef> elidedAttrs = {}) = 0;
+                                     ArrayRef<StringRef> elidedAttrs = {},
+                                     function_ref<LogicalResult(NamedAttribute)>
+                                         printNamedAttrFn = nullptr) = 0;
 
   /// If the specified operation has attributes, print out an attribute
   /// dictionary prefixed with 'attributes'.
@@ -1116,8 +1122,17 @@ class AsmParser {
     return parseResult;
   }
 
-  /// Parse a named dictionary into 'result' if it is present.
-  virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
+  /// Parse a named dictionary into 'result' if it is present. If
+  /// parseNamedAttrFn is provided, it can override the default parsing of a
+  /// named attribute: it is passed the attribute's name and returns the
+  /// parsed attribute if it handled it, or `failure()` to indicate that
+  /// generic parsing should be used.
+  /// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
+  /// error.
+  virtual ParseResult parseOptionalAttrDict(
+      NamedAttrList &result,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+          nullptr) = 0;
 
   /// Parse a named dictionary into 'result' if the `attributes` keyword is
   /// present.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index b12687833e3fde..808b2ca282f64b 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
   }
 
   /// Parse a named dictionary into 'result' if it is present.
-  ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
+  ParseResult parseOptionalAttrDict(
+      NamedAttrList &result,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+          nullptr) override {
     if (parser.getToken().isNot(Token::l_brace))
       return success();
-    return parser.parseAttributeDict(result);
+    return parser.parseAttributeDict(result, parseNamedAttrFn);
   }
 
   /// Parse a named dictionary into 'result' if the `attributes` keyword is
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..b687d822e7cb7d 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
 ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
 ///
-ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+ParseResult Parser::parseAttributeDict(
+    NamedAttrList &attributes,
+    function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
   llvm::SmallDenseSet<StringAttr> seenKeys;
   auto parseElt = [&]() -> ParseResult {
     // The name of an attribute can either be a bare identifier, or a string.
@@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
       return success();
     }
 
-    auto attr = parseAttribute();
+    Attribute attr = nullptr;
+    FailureOr<Attribute> customParsedAttribute;
+    // Try to parse with the `parseNamedAttrFn` callback.
+    if (parseNamedAttrFn &&
+        succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
+      attr = *customParsedAttribute;
+    } else {
+      // Otherwise, use generic attribute parser.
+      attr = parseAttribute();
+    }
+
     if (!attr)
       return failure();
     attributes.push_back({*nameId, attr});
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..d5d90f391fd391 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -256,7 +256,9 @@ class Parser {
   }
 
   /// Parse an attribute dictionary.
-  ParseResult parseAttributeDict(NamedAttrList &attributes);
+  ParseResult parseAttributeDict(
+      NamedAttrList &attributes,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);
 
   /// Parse a distinct attribute.
   Attribute parseDistinctAttr(Type type);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..cd9f70c8868b83 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -452,10 +452,13 @@ class AsmPrinter::Impl {
   void printDimensionList(ArrayRef<int64_t> shape);
 
 protected:
-  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {},
-                             bool withKeyword = false);
-  void printNamedAttribute(NamedAttribute attr);
+  void printOptionalAttrDict(
+      ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
+      bool withKeyword = false,
+      function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
+  void printNamedAttribute(
+      NamedAttribute attr,
+      function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
   void printTrailingLocation(Location loc, bool allowAlias = true);
   void printLocationInternal(LocationAttr loc, bool pretty = false,
                              bool isTopLevel = false);
@@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
   /// Print the given set of attributes with names not included within
   /// 'elidedAttrs'.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {}) override {
-    if (attrs.empty())
-      return;
+                             ArrayRef<StringRef> elidedAttrs = {},
+                             function_ref<LogicalResult(NamedAttribute)>
+                                 printNamedAttrFn = nullptr) override {
+    (void)printNamedAttrFn;
     if (elidedAttrs.empty()) {
       for (const NamedAttribute &attr : attrs)
         printAttribute(attr.getValue());
@@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Default([&](Type type) { return printDialectType(type); });
 }
 
-void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                                             ArrayRef<StringRef> elidedAttrs,
-                                             bool withKeyword) {
+void AsmPrinter::Impl::printOptionalAttrDict(
+    ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
+    bool withKeyword,
+    function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
   // If there are no attributes, then there is nothing to be done.
   if (attrs.empty())
     return;
@@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
 
     // Otherwise, print them all out in braces.
     os << " {";
-    interleaveComma(filteredAttrs,
-                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
+    interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+      printNamedAttribute(attr, printNamedAttrFn);
+    });
     os << '}';
   };
 
@@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
   if (!filteredAttrs.empty())
     printFilteredAttributesFn(filteredAttrs);
 }
-void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
+void AsmPrinter::Impl::printNamedAttribute(
+    NamedAttribute attr,
+    function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
   // Print the name without quotes if possible.
   ::printKeywordOrString(attr.getName().strref(), os);
 
@@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
     return;
 
   os << " = ";
+  if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
+    // The `printNamedAttrFn` callback printed the value; skip generic printing.
+    return;
+  }
   printAttribute(attr.getValue());
 }
 
@@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
 
   /// Print an optional attribute dictionary with a given set of elided values.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {}) override {
-    Impl::printOptionalAttrDict(attrs, elidedAttrs);
+                             ArrayRef<StringRef> elidedAttrs = {},
+                             function_ref<LogicalResult(NamedAttribute)>
+                                 printNamedAttrFn = nullptr) override {
+    Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
+                                printNamedAttrFn);
   }
   void printOptionalAttrDictWithKeyword(
       ArrayRef<NamedAttribute> attrs,

>From 380f20259caf35bbe1f03cde646e9a5d5ab5af76 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 18 Jul 2024 11:24:06 +0100
Subject: [PATCH 2/2] Update the internal representation of `in_bounds`

This PR updates the internal representation of the `in_bounds` attribute
for `xfer_read`/`xfer_write` Ops. Currently it is stored as an `ArrayAttr`
of `BoolAttr`s; this PR changes it to a `DenseBoolArrayAttr`.

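Note that the custom assembly format of `xfer_{read|write}` is unchanged:
`printTransferAttrs`/`parseTransferAttrs` use the new attr-dict hooks to
print and parse `in_bounds` as a bracketed list, so the Ops still read:

```mlir
vector.transfer_read %arg0[%0, %1], %cst {in_bounds = [true], permutation_map = #map3} : memref<12x16xf32>, vector<8xf32>
```

rather than using the generic `DenseBoolArrayAttr` syntax:
```mlir
vector.transfer_read %arg0[%0, %1], %cst {in_bounds = array<i1: true>, permutation_map = #map3} : memref<12x16xf32>, vector<8xf32>
```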
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 10 ++--
 .../mlir/Interfaces/VectorInterfaces.td       |  4 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    |  5 +-
 .../ArmSME/Transforms/VectorLegalization.cpp  | 12 ++--
 .../Linalg/Transforms/Vectorization.cpp       | 14 ++---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 55 ++++++++++++-------
 .../Vector/Transforms/LowerVectorTransfer.cpp | 23 ++++----
 .../Transforms/VectorDropLeadUnitDim.cpp      | 14 ++---
 .../Transforms/VectorTransferOpTransforms.cpp | 10 ++--
 .../VectorTransferSplitRewritePatterns.cpp    |  2 +-
 .../Vector/Transforms/VectorTransforms.cpp    | 10 ++--
 11 files changed, 86 insertions(+), 73 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b96f5c2651bce5..386da3d977e468 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1248,7 +1248,7 @@ def Vector_TransferReadOp :
                    AffineMapAttr:$permutation_map,
                    AnyType:$padding,
                    Optional<VectorOf<[I1]>>:$mask,
-                   BoolArrayAttr:$in_bounds)>,
+                   DenseBoolArrayAttr:$in_bounds)>,
     Results<(outs AnyVectorOfAnyRank:$vector)> {
 
   let summary = "Reads a supervector from memory into an SSA vector value.";
@@ -1443,7 +1443,7 @@ def Vector_TransferReadOp :
                    "Value":$source,
                    "ValueRange":$indices,
                    "AffineMapAttr":$permutationMapAttr,
-                   "ArrayAttr":$inBoundsAttr)>,
+                   "DenseBoolArrayAttr":$inBoundsAttr)>,
     /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
     OpBuilder<(ins "VectorType":$vectorType,
                    "Value":$source,
@@ -1495,7 +1495,7 @@ def Vector_TransferWriteOp :
                    Variadic<Index>:$indices,
                    AffineMapAttr:$permutation_map,
                    Optional<VectorOf<[I1]>>:$mask,
-                   BoolArrayAttr:$in_bounds)>,
+                   DenseBoolArrayAttr:$in_bounds)>,
     Results<(outs Optional<AnyRankedTensor>:$result)> {
 
   let summary = "The vector.transfer_write op writes a supervector to memory.";
@@ -1606,13 +1606,13 @@ def Vector_TransferWriteOp :
                    "ValueRange":$indices,
                    "AffineMapAttr":$permutationMapAttr,
                    "Value":$mask,
-                   "ArrayAttr":$inBoundsAttr)>,
+                   "DenseBoolArrayAttr":$inBoundsAttr)>,
     /// 2. Builder with type inference that sets an empty mask (variant with attrs).
     OpBuilder<(ins "Value":$vector,
                    "Value":$dest,
                    "ValueRange":$indices,
                    "AffineMapAttr":$permutationMapAttr,
-                   "ArrayAttr":$inBoundsAttr)>,
+                   "DenseBoolArrayAttr":$inBoundsAttr)>,
     /// 3. Builder with type inference that sets an empty mask (variant without attrs).
     OpBuilder<(ins "Value":$vector,
                    "Value":$dest,
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 7ea62c2ae2ab13..b2a381b4510085 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -98,7 +98,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
         dimension whether it is in-bounds or not. (Broadcast dimensions are
         always in-bounds).
       }],
-      /*retTy=*/"::mlir::ArrayAttr",
+      /*retTy=*/"::mlir::ArrayRef<bool>",
       /*methodName=*/"getInBounds",
       /*args=*/(ins)
     >,
@@ -241,7 +241,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       if ($_op.isBroadcastDim(dim))
         return true;
       auto inBounds = $_op.getInBounds();
-      return ::llvm::cast<::mlir::BoolAttr>(inBounds[dim]).getValue();
+      return inBounds[dim];
     }
 
     /// Helper function to account for the fact that `permutationMap` results
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..cba39905fcc571 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -264,10 +264,11 @@ static void generateInBoundsCheck(
 }
 
 /// Given an ArrayAttr, return a copy where the first element is dropped.
-static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
+static DenseBoolArrayAttr dropFirstElem(OpBuilder &b, DenseBoolArrayAttr attr) {
   if (!attr)
     return attr;
-  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
+  return DenseBoolArrayAttr::get(b.getContext(),
+                                 attr.asArrayRef().drop_front());
 }
 
 /// Add the pass label to a vector transfer op if its rank is not the target
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..84305758e9fc01 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -497,8 +497,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
           loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
           AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
           sliceMask,
-          rewriter.getBoolArrayAttr(
-              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
+          rewriter.getDenseBoolArrayAttr(writeOp.getInBounds().drop_front()));
     }
 
     rewriter.eraseOp(writeOp);
@@ -691,13 +690,12 @@ struct LiftIllegalVectorTransposeToMemory
         transposeOp.getPermutation(), getContext());
     auto transposedSubview = rewriter.create<memref::TransposeOp>(
         loc, readSubview, AffineMapAttr::get(transposeMap));
-    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+    DenseBoolArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
     // - The `in_bounds` attribute
     if (inBoundsAttr) {
-      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
-                                            inBoundsAttr.end());
+      SmallVector<bool> inBoundsValues(inBoundsAttr.asArrayRef());
       applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
-      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
+      inBoundsAttr = rewriter.getDenseBoolArrayAttr(inBoundsValues);
     }
 
     VectorType legalReadType = resultType.clone(readType.getElementType());
@@ -902,7 +900,7 @@ struct LowerIllegalTransposeStoreViaZA
           rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
       auto smeWrite = rewriter.create<vector::TransferWriteOp>(
           loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
-          transposeMap, subMask, writeOp.getInBounds());
+          transposeMap, subMask, writeOp.getInBoundsAttr());
 
       if (writeOp.hasPureTensorSemantics())
         destTensorOrMemref = smeWrite.getResult();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2be..4d1f68fe9c4ade 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -646,7 +646,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
   if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
     auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
     SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
-    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+    maskedWriteOp.setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
   }
 
   LDBG("vectorized op: " << *write << "\n");
@@ -1364,7 +1364,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
       SmallVector<bool> inBounds(readType.getRank(), true);
       cast<vector::TransferReadOp>(maskOp.getMaskableOp())
-          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+          .setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
     }
 
     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
@@ -2397,7 +2397,7 @@ struct PadOpVectorizationWithTransferReadPattern
     rewriter.modifyOpInPlace(xferOp, [&]() {
       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
       xferOp->setAttr(xferOp.getInBoundsAttrName(),
-                      rewriter.getBoolArrayAttr(inBounds));
+                      rewriter.getDenseBoolArrayAttr(inBounds));
       xferOp.getSourceMutable().assign(padOp.getSource());
       xferOp.getPaddingMutable().assign(padValue);
     });
@@ -2476,7 +2476,7 @@ struct PadOpVectorizationWithTransferWritePattern
     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         xferOp, padOp.getSource().getType(), xferOp.getVector(),
         padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
-        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
+        xferOp.getMask(), rewriter.getDenseBoolArrayAttr(inBounds));
     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
 
     return success();
@@ -2780,7 +2780,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   Value res = rewriter.create<vector::TransferReadOp>(
       xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
       xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
-      rewriter.getBoolArrayAttr(
+      rewriter.getDenseBoolArrayAttr(
           SmallVector<bool>(vectorType.getRank(), false)));
 
   if (maybeFillOp)
@@ -2839,7 +2839,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
   rewriter.create<vector::TransferWriteOp>(
       xferOp.getLoc(), vector, out, xferOp.getIndices(),
       xferOp.getPermutationMapAttr(), xferOp.getMask(),
-      rewriter.getBoolArrayAttr(
+      rewriter.getDenseBoolArrayAttr(
           SmallVector<bool>(vector.getType().getRank(), false)));
 
   rewriter.eraseOp(copyOp);
@@ -3339,7 +3339,7 @@ struct Conv1DGenerator
       SmallVector<bool> inBounds(maskShape.size(), true);
       auto xferOp = cast<VectorTransferOpInterface>(opToMask);
       xferOp->setAttr(xferOp.getInBoundsAttrName(),
-                      rewriter.getBoolArrayAttr(inBounds));
+                      rewriter.getDenseBoolArrayAttr(inBounds));
 
       SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
           cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 44bd4aa76ffbd6..3228307370e2be 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3758,7 +3758,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
 void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                            VectorType vectorType, Value source,
                            ValueRange indices, AffineMapAttr permutationMapAttr,
-                           /*optional*/ ArrayAttr inBoundsAttr) {
+                           /*optional*/ DenseBoolArrayAttr inBoundsAttr) {
   Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
   Value padding = builder.create<arith::ConstantOp>(
       result.location, elemType, builder.getZeroAttr(elemType));
@@ -3773,8 +3773,8 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                            std::optional<ArrayRef<bool>> inBounds) {
   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
   auto inBoundsAttr = (inBounds && !inBounds.value().empty())
-                          ? builder.getBoolArrayAttr(inBounds.value())
-                          : builder.getBoolArrayAttr(
+                          ? builder.getDenseBoolArrayAttr(inBounds.value())
+                          : builder.getDenseBoolArrayAttr(
                                 SmallVector<bool>(vectorType.getRank(), false));
   build(builder, result, vectorType, source, indices, permutationMapAttr,
         inBoundsAttr);
@@ -3789,8 +3789,8 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
       llvm::cast<ShapedType>(source.getType()), vectorType);
   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
   auto inBoundsAttr = (inBounds && !inBounds.value().empty())
-                          ? builder.getBoolArrayAttr(inBounds.value())
-                          : builder.getBoolArrayAttr(
+                          ? builder.getDenseBoolArrayAttr(inBounds.value())
+                          : builder.getDenseBoolArrayAttr(
                                 SmallVector<bool>(vectorType.getRank(), false));
   build(builder, result, vectorType, source, indices, permutationMapAttr,
         padding,
@@ -3842,7 +3842,7 @@ static LogicalResult
 verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                  VectorType vectorType, VectorType maskType,
                  VectorType inferredMaskType, AffineMap permutationMap,
-                 ArrayAttr inBounds) {
+                 ArrayRef<bool> inBounds) {
   if (op->hasAttr("masked")) {
     return op->emitOpError("masked attribute has been removed. "
                            "Use in_bounds instead.");
@@ -3915,8 +3915,7 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
            << AffineMapAttr::get(permutationMap)
            << " vs inBounds of size: " << inBounds.size();
   for (unsigned int i = 0, e = permutationMap.getNumResults(); i < e; ++i)
-    if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
-        !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
+    if (isa<AffineConstantExpr>(permutationMap.getResult(i)) && !inBounds[i])
       return op->emitOpError("requires broadcast dimensions to be in-bounds");
 
   return success();
@@ -3930,7 +3929,25 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
   // Elide in_bounds attribute if all dims are out-of-bounds.
   if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
     elidedAttrs.push_back(op.getInBoundsAttrName());
-  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs,
+                          [&](NamedAttribute attr) -> LogicalResult {
+                            if (attr.getName() != op.getInBoundsAttrName())
+                              return failure();
+                            cast<DenseBoolArrayAttr>(attr.getValue()).print(p);
+                            return success();
+                          });
+}
+
+template <typename XferOp>
+static ParseResult parseTransferAttrs(OpAsmParser &parser,
+                                      OperationState &result) {
+  auto inBoundsAttrName = XferOp::getInBoundsAttrName(result.name);
+  return parser.parseOptionalAttrDict(
+      result.attributes, [&](StringRef name) -> FailureOr<Attribute> {
+        if (name != inBoundsAttrName)
+          return failure();
+        return DenseBoolArrayAttr::parse(parser, {});
+      });
 }
 
 void TransferReadOp::print(OpAsmPrinter &p) {
@@ -3972,7 +3989,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
     if (parser.parseOperand(maskInfo))
       return failure();
   }
-  if (parser.parseOptionalAttrDict(result.attributes) ||
+  if (parseTransferAttrs<TransferReadOp>(parser, result) ||
       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
     return failure();
   if (types.size() != 2)
@@ -3997,7 +4014,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
   Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
   if (!inBoundsAttr) {
     result.addAttribute(inBoundsAttrName,
-                        builder.getBoolArrayAttr(
+                        builder.getDenseBoolArrayAttr(
                             SmallVector<bool>(permMap.getNumResults(), false)));
   }
   if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
@@ -4125,7 +4142,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
     return failure();
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(op.getContext());
-  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
+  op.setInBoundsAttr(b.getDenseBoolArrayAttr(newInBounds));
   return success();
 }
 
@@ -4295,7 +4312,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             Value vector, Value dest, ValueRange indices,
                             AffineMapAttr permutationMapAttr,
                             /*optional*/ Value mask,
-                            /*optional*/ ArrayAttr inBoundsAttr) {
+                            /*optional*/ DenseBoolArrayAttr inBoundsAttr) {
   Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
   build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
         mask, inBoundsAttr);
@@ -4305,7 +4322,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
 void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
                             Value vector, Value dest, ValueRange indices,
                             AffineMapAttr permutationMapAttr,
-                            /*optional*/ ArrayAttr inBoundsAttr) {
+                            /*optional*/ DenseBoolArrayAttr inBoundsAttr) {
   build(builder, result, vector, dest, indices, permutationMapAttr,
         /*mask=*/Value(), inBoundsAttr);
 }
@@ -4319,8 +4336,8 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
   auto permutationMapAttr = AffineMapAttr::get(permutationMap);
   auto inBoundsAttr =
       (inBounds && !inBounds.value().empty())
-          ? builder.getBoolArrayAttr(inBounds.value())
-          : builder.getBoolArrayAttr(SmallVector<bool>(
+          ? builder.getDenseBoolArrayAttr(inBounds.value())
+          : builder.getDenseBoolArrayAttr(SmallVector<bool>(
                 llvm::cast<VectorType>(vector.getType()).getRank(), false));
   build(builder, result, vector, dest, indices, permutationMapAttr,
         /*mask=*/Value(), inBoundsAttr);
@@ -4352,7 +4369,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   ParseResult hasMask = parser.parseOptionalComma();
   if (hasMask.succeeded() && parser.parseOperand(maskInfo))
     return failure();
-  if (parser.parseOptionalAttrDict(result.attributes) ||
+  if (parseTransferAttrs<TransferWriteOp>(parser, result) ||
       parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
     return failure();
   if (types.size() != 2)
@@ -4378,7 +4395,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
   Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
   if (!inBoundsAttr) {
     result.addAttribute(inBoundsAttrName,
-                        builder.getBoolArrayAttr(
+                        builder.getDenseBoolArrayAttr(
                             SmallVector<bool>(permMap.getNumResults(), false)));
   }
   if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
@@ -4731,7 +4748,7 @@ struct SwapExtractSliceOfTransferWrite
     auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
         transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
         transferOp.getIndices(), transferOp.getPermutationMapAttr(),
-        rewriter.getBoolArrayAttr(newInBounds));
+        rewriter.getDenseBoolArrayAttr(newInBounds));
     rewriter.modifyOpInPlace(insertOp, [&]() {
       insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
     });
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 344cfc0cbffb93..7ecfb766f0b799 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -22,15 +22,14 @@ using namespace mlir::vector;
 
 /// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
 /// permutation based on the given indices.
-static ArrayAttr
-inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+static DenseBoolArrayAttr
+inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayRef<bool> inBounds,
                              const SmallVector<unsigned> &permutation) {
   SmallVector<bool> newInBoundsValues(permutation.size());
   size_t index = 0;
   for (unsigned pos : permutation)
-    newInBoundsValues[pos] =
-        cast<BoolAttr>(attr.getValue()[index++]).getValue();
-  return builder.getBoolArrayAttr(newInBoundsValues);
+    newInBoundsValues[pos] = inBounds[index++];
+  return builder.getDenseBoolArrayAttr(newInBoundsValues);
 }
 
 /// Extend the rank of a vector Value by `addedRanks` by adding outer unit
@@ -132,7 +131,7 @@ struct TransferReadPermutationLowering
     }
 
     // Transpose in_bounds attribute.
-    ArrayAttr newInBoundsAttr =
+    DenseBoolArrayAttr newInBoundsAttr =
         inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
 
     // Generate new transfer_read operation.
@@ -205,7 +204,7 @@ struct TransferWritePermutationLowering
                     });
 
     // Transpose in_bounds attribute.
-    ArrayAttr newInBoundsAttr =
+    DenseBoolArrayAttr newInBoundsAttr =
         inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
 
     // Generate new transfer_write operation.
@@ -298,7 +297,8 @@ struct TransferWriteNonPermutationLowering
     for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
       newInBoundsValues.push_back(op.isDimInBounds(i));
     }
-    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
+    DenseBoolArrayAttr newInBoundsAttr =
+        rewriter.getDenseBoolArrayAttr(newInBoundsValues);
     auto newWrite = rewriter.create<vector::TransferWriteOp>(
         op.getLoc(), newVec, op.getSource(), op.getIndices(),
         AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
@@ -386,11 +386,8 @@ struct TransferOpReduceRank
 
     VectorType newReadType = VectorType::get(
         newShape, originalVecType.getElementType(), newScalableDims);
-    ArrayAttr newInBoundsAttr =
-        op.getInBounds()
-            ? rewriter.getArrayAttr(
-                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
-            : ArrayAttr();
+    DenseBoolArrayAttr newInBoundsAttr = rewriter.getDenseBoolArrayAttr(
+        op.getInBounds().take_back(reducedShapeRank));
     Value newRead = rewriter.create<vector::TransferReadOp>(
         op.getLoc(), newReadType, op.getSource(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 42ac717b44c4b9..e08285bdf772d0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -248,10 +248,9 @@ struct CastAwayTransferReadLeadingOneDim
         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                        rewriter.getContext());
 
-    ArrayAttr inBoundsAttr;
-    if (read.getInBounds())
-      inBoundsAttr = rewriter.getArrayAttr(
-          read.getInBoundsAttr().getValue().take_back(newType.getRank()));
+    DenseBoolArrayAttr inBoundsAttr;
+    inBoundsAttr = rewriter.getDenseBoolArrayAttr(
+        read.getInBoundsAttr().asArrayRef().take_back(newType.getRank()));
 
     Value mask = Value();
     if (read.getMask()) {
@@ -302,10 +301,9 @@ struct CastAwayTransferWriteLeadingOneDim
         AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                        rewriter.getContext());
 
-    ArrayAttr inBoundsAttr;
-    if (write.getInBounds())
-      inBoundsAttr = rewriter.getArrayAttr(
-          write.getInBoundsAttr().getValue().take_back(newType.getRank()));
+    DenseBoolArrayAttr inBoundsAttr;
+    inBoundsAttr = rewriter.getDenseBoolArrayAttr(
+        write.getInBoundsAttr().asArrayRef().take_back(newType.getRank()));
 
     auto newVector = rewriter.create<vector::ExtractOp>(
         write.getLoc(), write.getVector(), splatZero(dropDim));
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 4c93d3841bf878..52773b2570994c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -411,7 +411,7 @@ class TransferReadDropUnitDimsPattern
     auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
         loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
         transferReadOp.getPadding(), maskOp,
-        rewriter.getBoolArrayAttr(inBounds));
+        rewriter.getDenseBoolArrayAttr(inBounds));
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, vectorType, newTransferReadOp);
     rewriter.replaceOp(transferReadOp, shapeCast);
@@ -480,7 +480,7 @@ class TransferWriteDropUnitDimsPattern
         loc, reducedVectorType, vector);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
-        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+        identityMap, maskOp, rewriter.getDenseBoolArrayAttr(inBounds));
 
     return success();
   }
@@ -640,7 +640,8 @@ class FlattenContiguousRowMajorTransferReadPattern
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
-    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+    SmallVector<bool> inBounds(1, true);
+    flatRead.setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
 
     // 4. Replace the old transfer_read with the new one reading from the
     // collapsed shape
@@ -735,7 +736,8 @@ class FlattenContiguousRowMajorTransferWritePattern
     vector::TransferWriteOp flatWrite =
         rewriter.create<vector::TransferWriteOp>(
             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
-    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+    SmallVector<bool> inBounds(1, true);
+    flatWrite.setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
 
     // 4. Replace the old transfer_write with the new one writing the
     // collapsed shape
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index ee622e886f6185..3c482413d761e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -523,7 +523,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
     return failure();
 
   SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
-  auto inBoundsAttr = b.getBoolArrayAttr(bools);
+  auto inBoundsAttr = b.getDenseBoolArrayAttr(bools);
   if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
     b.modifyOpInPlace(xferOp, [&]() {
       xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f59a378e03512..ddeb6111545b65 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1122,7 +1122,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
 
     rewriter.modifyOpInPlace(xferOp, [&]() {
       xferOp.getMaskMutable().assign(mask);
-      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+      xferOp.setInBoundsAttr(rewriter.getDenseBoolArrayAttr({true}));
     });
 
     return success();
@@ -1321,8 +1321,8 @@ class DropInnerMostUnitDimsTransferRead
         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
             srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
             strides));
-    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
-        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
+    DenseBoolArrayAttr inBoundsAttr = rewriter.getDenseBoolArrayAttr(
+        readOp.getInBounds().drop_back(dimsToDrop));
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
         loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
@@ -1412,8 +1412,8 @@ class DropInnerMostUnitDimsTransferWrite
         cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
             srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
             strides));
-    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
-        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
+    DenseBoolArrayAttr inBoundsAttr = rewriter.getDenseBoolArrayAttr(
+        writeOp.getInBounds().drop_back(dimsToDrop));
 
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
         loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);



More information about the Mlir-commits mailing list