[Mlir-commits] [mlir] [mlir][vector] Update the internal representation of in_bounds (PR #100336)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Aug 13 09:46:19 PDT 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/100336
>From 490cf97ac3a7ac424a21579fe57b9ce99a33a94f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 13 Aug 2024 12:08:52 +0000
Subject: [PATCH 1/2] [mlir] Add the ability to override attribute
parsing/printing in attr-dicts
This adds a `parseNamedAttrFn` callback to
`AsmParser::parseOptionalAttrDict()`.
If parseNamedAttrFn is provided the default parsing can be overridden
for a named attribute. parseNamedAttrFn is passed the name of an
attribute, if it can parse the attribute it returns the parsed
attribute, otherwise, it returns `failure()` which indicates that
generic parsing should be used. Note: Returning a null Attribute from
parseNamedAttrFn indicates a parser error.
It also adds `printNamedAttrFn` to
`AsmPrinter::printOptionalAttrDict()`.
If printNamedAttrFn is provided the default printing can be overridden
for a named attribute. printNamedAttrFn is passed a NamedAttribute, if
it prints the attribute it returns `success()`, otherwise, it returns
`failure()` which indicates that generic printing should be used.
---
mlir/include/mlir/IR/OpImplementation.h | 23 ++++++++++---
mlir/lib/AsmParser/AsmParserImpl.h | 7 ++--
mlir/lib/AsmParser/AttributeParser.cpp | 16 +++++++--
mlir/lib/AsmParser/Parser.h | 4 ++-
mlir/lib/IR/AsmPrinter.cpp | 45 ++++++++++++++++---------
5 files changed, 71 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ae412c7227f8ea..5891cbffc9542d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
/// specific well known attributes, commonly used if the attribute value is
- /// printed some other way (like as a fixed operand).
+ /// printed some other way (like as a fixed operand). If printNamedAttrFn is
+ /// provided the default printing can be overridden for a named attribute.
+ /// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute
+ /// it returns `success()`, otherwise, it returns `failure()` which indicates
+ /// that generic printing should be used.
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) = 0;
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) = 0;
/// If the specified operation has attributes, print out an attribute
/// dictionary prefixed with 'attributes'.
@@ -1116,8 +1122,17 @@ class AsmParser {
return parseResult;
}
- /// Parse a named dictionary into 'result' if it is present.
- virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
+ /// Parse a named dictionary into 'result' if it is present. If
+ /// parseNamedAttrFn is provided the default parsing can be overridden for a
+ /// named attribute. parseNamedAttrFn is passed the name of an attribute, if
+ /// it can parse the attribute it returns the parsed attribute, otherwise, it
+ /// returns `failure()` which indicates that generic parsing should be used.
+ /// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
+ /// error.
+ virtual ParseResult parseOptionalAttrDict(
+ NamedAttrList &result,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+ nullptr) = 0;
/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index b12687833e3fde..808b2ca282f64b 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
}
/// Parse a named dictionary into 'result' if it is present.
- ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
+ ParseResult parseOptionalAttrDict(
+ NamedAttrList &result,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+ nullptr) override {
if (parser.getToken().isNot(Token::l_brace))
return success();
- return parser.parseAttributeDict(result);
+ return parser.parseAttributeDict(result, parseNamedAttrFn);
}
/// Parse a named dictionary into 'result' if the `attributes` keyword is
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..b687d822e7cb7d 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
/// | `{` attribute-entry (`,` attribute-entry)* `}`
/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
///
-ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+ParseResult Parser::parseAttributeDict(
+ NamedAttrList &attributes,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
llvm::SmallDenseSet<StringAttr> seenKeys;
auto parseElt = [&]() -> ParseResult {
// The name of an attribute can either be a bare identifier, or a string.
@@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return success();
}
- auto attr = parseAttribute();
+ Attribute attr = nullptr;
+ FailureOr<Attribute> customParsedAttribute;
+ // Try to parse with `parseNamedAttrFn` callback.
+ if (parseNamedAttrFn &&
+ succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
+ attr = *customParsedAttribute;
+ } else {
+ // Otherwise, use generic attribute parser.
+ attr = parseAttribute();
+ }
+
if (!attr)
return failure();
attributes.push_back({*nameId, attr});
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..d5d90f391fd391 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -256,7 +256,9 @@ class Parser {
}
/// Parse an attribute dictionary.
- ParseResult parseAttributeDict(NamedAttrList &attributes);
+ ParseResult parseAttributeDict(
+ NamedAttrList &attributes,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);
/// Parse a distinct attribute.
Attribute parseDistinctAttr(Type type);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..cd9f70c8868b83 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -452,10 +452,13 @@ class AsmPrinter::Impl {
void printDimensionList(ArrayRef<int64_t> shape);
protected:
- void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {},
- bool withKeyword = false);
- void printNamedAttribute(NamedAttribute attr);
+ void printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
+ bool withKeyword = false,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
+ void printNamedAttribute(
+ NamedAttribute attr,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
void printTrailingLocation(Location loc, bool allowAlias = true);
void printLocationInternal(LocationAttr loc, bool pretty = false,
bool isTopLevel = false);
@@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
/// Print the given set of attributes with names not included within
/// 'elidedAttrs'.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) override {
- if (attrs.empty())
- return;
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) override {
+ (void)printNamedAttrFn;
if (elidedAttrs.empty()) {
for (const NamedAttribute &attr : attrs)
printAttribute(attr.getValue());
@@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Default([&](Type type) { return printDialectType(type); });
}
-void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs,
- bool withKeyword) {
+void AsmPrinter::Impl::printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
+ bool withKeyword,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
@@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
// Otherwise, print them all out in braces.
os << " {";
- interleaveComma(filteredAttrs,
- [&](NamedAttribute attr) { printNamedAttribute(attr); });
+ interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+ printNamedAttribute(attr, printNamedAttrFn);
+ });
os << '}';
};
@@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs);
}
-void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
+void AsmPrinter::Impl::printNamedAttribute(
+ NamedAttribute attr,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// Print the name without quotes if possible.
::printKeywordOrString(attr.getName().strref(), os);
@@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
return;
os << " = ";
+ if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
+ /// If we print via the `printNamedAttrFn` callback skip printing.
+ return;
+ }
printAttribute(attr.getValue());
}
@@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) override {
- Impl::printOptionalAttrDict(attrs, elidedAttrs);
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) override {
+ Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
+ printNamedAttrFn);
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
>From 380f20259caf35bbe1f03cde646e9a5d5ab5af76 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 18 Jul 2024 11:24:06 +0100
Subject: [PATCH 2/2] Update the internal representation of `in_bounds`
This PR updates the internal representation of the `in_bounds` attribute
for `xfer_read`/`xfer_write` Ops. Currently we use `ArrayAttr` - that's
being updated to `DenseBoolArrayAttribute`.
Note that this means that the asm format of the `xfer_{read|write}`
will change from:
```mlir
vector.transfer_read %arg0[%0, %1], %cst {in_bounds = [true], permutation_map = #map3} : memref<12x16xf32>, vector<8xf32>
```
to:
```mlir
vector.transfer_read %arg0[%0, %1], %cst {in_bounds = array<i1: true>, permutation_map = #map3} : memref<12x16xf32>, vector<8xf32>
```
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 10 ++--
.../mlir/Interfaces/VectorInterfaces.td | 4 +-
.../Conversion/VectorToSCF/VectorToSCF.cpp | 5 +-
.../ArmSME/Transforms/VectorLegalization.cpp | 12 ++--
.../Linalg/Transforms/Vectorization.cpp | 14 ++---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 55 ++++++++++++-------
.../Vector/Transforms/LowerVectorTransfer.cpp | 23 ++++----
.../Transforms/VectorDropLeadUnitDim.cpp | 14 ++---
.../Transforms/VectorTransferOpTransforms.cpp | 10 ++--
.../VectorTransferSplitRewritePatterns.cpp | 2 +-
.../Vector/Transforms/VectorTransforms.cpp | 10 ++--
11 files changed, 86 insertions(+), 73 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b96f5c2651bce5..386da3d977e468 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1248,7 +1248,7 @@ def Vector_TransferReadOp :
AffineMapAttr:$permutation_map,
AnyType:$padding,
Optional<VectorOf<[I1]>>:$mask,
- BoolArrayAttr:$in_bounds)>,
+ DenseBoolArrayAttr:$in_bounds)>,
Results<(outs AnyVectorOfAnyRank:$vector)> {
let summary = "Reads a supervector from memory into an SSA vector value.";
@@ -1443,7 +1443,7 @@ def Vector_TransferReadOp :
"Value":$source,
"ValueRange":$indices,
"AffineMapAttr":$permutationMapAttr,
- "ArrayAttr":$inBoundsAttr)>,
+ "DenseBoolArrayAttr":$inBoundsAttr)>,
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
@@ -1495,7 +1495,7 @@ def Vector_TransferWriteOp :
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
Optional<VectorOf<[I1]>>:$mask,
- BoolArrayAttr:$in_bounds)>,
+ DenseBoolArrayAttr:$in_bounds)>,
Results<(outs Optional<AnyRankedTensor>:$result)> {
let summary = "The vector.transfer_write op writes a supervector to memory.";
@@ -1606,13 +1606,13 @@ def Vector_TransferWriteOp :
"ValueRange":$indices,
"AffineMapAttr":$permutationMapAttr,
"Value":$mask,
- "ArrayAttr":$inBoundsAttr)>,
+ "DenseBoolArrayAttr":$inBoundsAttr)>,
/// 2. Builder with type inference that sets an empty mask (variant with attrs).
OpBuilder<(ins "Value":$vector,
"Value":$dest,
"ValueRange":$indices,
"AffineMapAttr":$permutationMapAttr,
- "ArrayAttr":$inBoundsAttr)>,
+ "DenseBoolArrayAttr":$inBoundsAttr)>,
/// 3. Builder with type inference that sets an empty mask (variant without attrs).
OpBuilder<(ins "Value":$vector,
"Value":$dest,
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index 7ea62c2ae2ab13..b2a381b4510085 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -98,7 +98,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
dimension whether it is in-bounds or not. (Broadcast dimensions are
always in-bounds).
}],
- /*retTy=*/"::mlir::ArrayAttr",
+ /*retTy=*/"::mlir::ArrayRef<bool>",
/*methodName=*/"getInBounds",
/*args=*/(ins)
>,
@@ -241,7 +241,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
if ($_op.isBroadcastDim(dim))
return true;
auto inBounds = $_op.getInBounds();
- return ::llvm::cast<::mlir::BoolAttr>(inBounds[dim]).getValue();
+ return inBounds[dim];
}
/// Helper function to account for the fact that `permutationMap` results
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..cba39905fcc571 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -264,10 +264,11 @@ static void generateInBoundsCheck(
}
/// Given an ArrayAttr, return a copy where the first element is dropped.
-static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
+static DenseBoolArrayAttr dropFirstElem(OpBuilder &b, DenseBoolArrayAttr attr) {
if (!attr)
return attr;
- return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
+ return DenseBoolArrayAttr::get(b.getContext(),
+ attr.asArrayRef().drop_front());
}
/// Add the pass label to a vector transfer op if its rank is not the target
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..84305758e9fc01 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -497,8 +497,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
sliceMask,
- rewriter.getBoolArrayAttr(
- ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
+ rewriter.getDenseBoolArrayAttr(writeOp.getInBounds().drop_front()));
}
rewriter.eraseOp(writeOp);
@@ -691,13 +690,12 @@ struct LiftIllegalVectorTransposeToMemory
transposeOp.getPermutation(), getContext());
auto transposedSubview = rewriter.create<memref::TransposeOp>(
loc, readSubview, AffineMapAttr::get(transposeMap));
- ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
+ DenseBoolArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
// - The `in_bounds` attribute
if (inBoundsAttr) {
- SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
- inBoundsAttr.end());
+ SmallVector<bool> inBoundsValues(inBoundsAttr.asArrayRef());
applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
- inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
+ inBoundsAttr = rewriter.getDenseBoolArrayAttr(inBoundsValues);
}
VectorType legalReadType = resultType.clone(readType.getElementType());
@@ -902,7 +900,7 @@ struct LowerIllegalTransposeStoreViaZA
rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
auto smeWrite = rewriter.create<vector::TransferWriteOp>(
loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
- transposeMap, subMask, writeOp.getInBounds());
+ transposeMap, subMask, writeOp.getInBoundsAttr());
if (writeOp.hasPureTensorSemantics())
destTensorOrMemref = smeWrite.getResult();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2be..4d1f68fe9c4ade 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -646,7 +646,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
- maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+ maskedWriteOp.setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
}
LDBG("vectorized op: " << *write << "\n");
@@ -1364,7 +1364,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
SmallVector<bool> inBounds(readType.getRank(), true);
cast<vector::TransferReadOp>(maskOp.getMaskableOp())
- .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+ .setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
}
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
@@ -2397,7 +2397,7 @@ struct PadOpVectorizationWithTransferReadPattern
rewriter.modifyOpInPlace(xferOp, [&]() {
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
xferOp->setAttr(xferOp.getInBoundsAttrName(),
- rewriter.getBoolArrayAttr(inBounds));
+ rewriter.getDenseBoolArrayAttr(inBounds));
xferOp.getSourceMutable().assign(padOp.getSource());
xferOp.getPaddingMutable().assign(padValue);
});
@@ -2476,7 +2476,7 @@ struct PadOpVectorizationWithTransferWritePattern
auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
xferOp, padOp.getSource().getType(), xferOp.getVector(),
padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
- xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
+ xferOp.getMask(), rewriter.getDenseBoolArrayAttr(inBounds));
rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
return success();
@@ -2780,7 +2780,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
Value res = rewriter.create<vector::TransferReadOp>(
xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
- rewriter.getBoolArrayAttr(
+ rewriter.getDenseBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false)));
if (maybeFillOp)
@@ -2839,7 +2839,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
rewriter.create<vector::TransferWriteOp>(
xferOp.getLoc(), vector, out, xferOp.getIndices(),
xferOp.getPermutationMapAttr(), xferOp.getMask(),
- rewriter.getBoolArrayAttr(
+ rewriter.getDenseBoolArrayAttr(
SmallVector<bool>(vector.getType().getRank(), false)));
rewriter.eraseOp(copyOp);
@@ -3339,7 +3339,7 @@ struct Conv1DGenerator
SmallVector<bool> inBounds(maskShape.size(), true);
auto xferOp = cast<VectorTransferOpInterface>(opToMask);
xferOp->setAttr(xferOp.getInBoundsAttrName(),
- rewriter.getBoolArrayAttr(inBounds));
+ rewriter.getDenseBoolArrayAttr(inBounds));
SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 44bd4aa76ffbd6..3228307370e2be 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3758,7 +3758,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, AffineMapAttr permutationMapAttr,
- /*optional*/ ArrayAttr inBoundsAttr) {
+ /*optional*/ DenseBoolArrayAttr inBoundsAttr) {
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
Value padding = builder.create<arith::ConstantOp>(
result.location, elemType, builder.getZeroAttr(elemType));
@@ -3773,8 +3773,8 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
std::optional<ArrayRef<bool>> inBounds) {
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
- ? builder.getBoolArrayAttr(inBounds.value())
- : builder.getBoolArrayAttr(
+ ? builder.getDenseBoolArrayAttr(inBounds.value())
+ : builder.getDenseBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
build(builder, result, vectorType, source, indices, permutationMapAttr,
inBoundsAttr);
@@ -3789,8 +3789,8 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
llvm::cast<ShapedType>(source.getType()), vectorType);
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
- ? builder.getBoolArrayAttr(inBounds.value())
- : builder.getBoolArrayAttr(
+ ? builder.getDenseBoolArrayAttr(inBounds.value())
+ : builder.getDenseBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
build(builder, result, vectorType, source, indices, permutationMapAttr,
padding,
@@ -3842,7 +3842,7 @@ static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
VectorType vectorType, VectorType maskType,
VectorType inferredMaskType, AffineMap permutationMap,
- ArrayAttr inBounds) {
+ ArrayRef<bool> inBounds) {
if (op->hasAttr("masked")) {
return op->emitOpError("masked attribute has been removed. "
"Use in_bounds instead.");
@@ -3915,8 +3915,7 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
<< AffineMapAttr::get(permutationMap)
<< " vs inBounds of size: " << inBounds.size();
for (unsigned int i = 0, e = permutationMap.getNumResults(); i < e; ++i)
- if (isa<AffineConstantExpr>(permutationMap.getResult(i)) &&
- !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
+ if (isa<AffineConstantExpr>(permutationMap.getResult(i)) && !inBounds[i])
return op->emitOpError("requires broadcast dimensions to be in-bounds");
return success();
@@ -3930,7 +3929,25 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
// Elide in_bounds attribute if all dims are out-of-bounds.
if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
elidedAttrs.push_back(op.getInBoundsAttrName());
- p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+ p.printOptionalAttrDict(op->getAttrs(), elidedAttrs,
+ [&](NamedAttribute attr) -> LogicalResult {
+ if (attr.getName() != op.getInBoundsAttrName())
+ return failure();
+ cast<DenseBoolArrayAttr>(attr.getValue()).print(p);
+ return success();
+ });
+}
+
+template <typename XferOp>
+static ParseResult parseTransferAttrs(OpAsmParser &parser,
+ OperationState &result) {
+ auto inBoundsAttrName = XferOp::getInBoundsAttrName(result.name);
+ return parser.parseOptionalAttrDict(
+ result.attributes, [&](StringRef name) -> FailureOr<Attribute> {
+ if (name != inBoundsAttrName)
+ return failure();
+ return DenseBoolArrayAttr::parse(parser, {});
+ });
}
void TransferReadOp::print(OpAsmPrinter &p) {
@@ -3972,7 +3989,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseOperand(maskInfo))
return failure();
}
- if (parser.parseOptionalAttrDict(result.attributes) ||
+ if (parseTransferAttrs<TransferReadOp>(parser, result) ||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
return failure();
if (types.size() != 2)
@@ -3997,7 +4014,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
if (!inBoundsAttr) {
result.addAttribute(inBoundsAttrName,
- builder.getBoolArrayAttr(
+ builder.getDenseBoolArrayAttr(
SmallVector<bool>(permMap.getNumResults(), false)));
}
if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
@@ -4125,7 +4142,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
return failure();
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(op.getContext());
- op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
+ op.setInBoundsAttr(b.getDenseBoolArrayAttr(newInBounds));
return success();
}
@@ -4295,7 +4312,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
AffineMapAttr permutationMapAttr,
/*optional*/ Value mask,
- /*optional*/ ArrayAttr inBoundsAttr) {
+ /*optional*/ DenseBoolArrayAttr inBoundsAttr) {
Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
mask, inBoundsAttr);
@@ -4305,7 +4322,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
AffineMapAttr permutationMapAttr,
- /*optional*/ ArrayAttr inBoundsAttr) {
+ /*optional*/ DenseBoolArrayAttr inBoundsAttr) {
build(builder, result, vector, dest, indices, permutationMapAttr,
/*mask=*/Value(), inBoundsAttr);
}
@@ -4319,8 +4336,8 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
auto inBoundsAttr =
(inBounds && !inBounds.value().empty())
- ? builder.getBoolArrayAttr(inBounds.value())
- : builder.getBoolArrayAttr(SmallVector<bool>(
+ ? builder.getDenseBoolArrayAttr(inBounds.value())
+ : builder.getDenseBoolArrayAttr(SmallVector<bool>(
llvm::cast<VectorType>(vector.getType()).getRank(), false));
build(builder, result, vector, dest, indices, permutationMapAttr,
/*mask=*/Value(), inBoundsAttr);
@@ -4352,7 +4369,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
ParseResult hasMask = parser.parseOptionalComma();
if (hasMask.succeeded() && parser.parseOperand(maskInfo))
return failure();
- if (parser.parseOptionalAttrDict(result.attributes) ||
+ if (parseTransferAttrs<TransferWriteOp>(parser, result) ||
parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
return failure();
if (types.size() != 2)
@@ -4378,7 +4395,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
if (!inBoundsAttr) {
result.addAttribute(inBoundsAttrName,
- builder.getBoolArrayAttr(
+ builder.getDenseBoolArrayAttr(
SmallVector<bool>(permMap.getNumResults(), false)));
}
if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
@@ -4731,7 +4748,7 @@ struct SwapExtractSliceOfTransferWrite
auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
transferOp.getIndices(), transferOp.getPermutationMapAttr(),
- rewriter.getBoolArrayAttr(newInBounds));
+ rewriter.getDenseBoolArrayAttr(newInBounds));
rewriter.modifyOpInPlace(insertOp, [&]() {
insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
});
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 344cfc0cbffb93..7ecfb766f0b799 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -22,15 +22,14 @@ using namespace mlir::vector;
/// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
/// permutation based on the given indices.
-static ArrayAttr
-inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+static DenseBoolArrayAttr
+inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayRef<bool> inBounds,
const SmallVector<unsigned> &permutation) {
SmallVector<bool> newInBoundsValues(permutation.size());
size_t index = 0;
for (unsigned pos : permutation)
- newInBoundsValues[pos] =
- cast<BoolAttr>(attr.getValue()[index++]).getValue();
- return builder.getBoolArrayAttr(newInBoundsValues);
+ newInBoundsValues[pos] = inBounds[index++];
+ return builder.getDenseBoolArrayAttr(newInBoundsValues);
}
/// Extend the rank of a vector Value by `addedRanks` by adding outer unit
@@ -132,7 +131,7 @@ struct TransferReadPermutationLowering
}
// Transpose in_bounds attribute.
- ArrayAttr newInBoundsAttr =
+ DenseBoolArrayAttr newInBoundsAttr =
inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
// Generate new transfer_read operation.
@@ -205,7 +204,7 @@ struct TransferWritePermutationLowering
});
// Transpose in_bounds attribute.
- ArrayAttr newInBoundsAttr =
+ DenseBoolArrayAttr newInBoundsAttr =
inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);
// Generate new transfer_write operation.
@@ -298,7 +297,8 @@ struct TransferWriteNonPermutationLowering
for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
newInBoundsValues.push_back(op.isDimInBounds(i));
}
- ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
+ DenseBoolArrayAttr newInBoundsAttr =
+ rewriter.getDenseBoolArrayAttr(newInBoundsValues);
auto newWrite = rewriter.create<vector::TransferWriteOp>(
op.getLoc(), newVec, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
@@ -386,11 +386,8 @@ struct TransferOpReduceRank
VectorType newReadType = VectorType::get(
newShape, originalVecType.getElementType(), newScalableDims);
- ArrayAttr newInBoundsAttr =
- op.getInBounds()
- ? rewriter.getArrayAttr(
- op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
- : ArrayAttr();
+ DenseBoolArrayAttr newInBoundsAttr = rewriter.getDenseBoolArrayAttr(
+ op.getInBounds().take_back(reducedShapeRank));
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 42ac717b44c4b9..e08285bdf772d0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -248,10 +248,9 @@ struct CastAwayTransferReadLeadingOneDim
AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
rewriter.getContext());
- ArrayAttr inBoundsAttr;
- if (read.getInBounds())
- inBoundsAttr = rewriter.getArrayAttr(
- read.getInBoundsAttr().getValue().take_back(newType.getRank()));
+ DenseBoolArrayAttr inBoundsAttr;
+ inBoundsAttr = rewriter.getDenseBoolArrayAttr(
+ read.getInBoundsAttr().asArrayRef().take_back(newType.getRank()));
Value mask = Value();
if (read.getMask()) {
@@ -302,10 +301,9 @@ struct CastAwayTransferWriteLeadingOneDim
AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
rewriter.getContext());
- ArrayAttr inBoundsAttr;
- if (write.getInBounds())
- inBoundsAttr = rewriter.getArrayAttr(
- write.getInBoundsAttr().getValue().take_back(newType.getRank()));
+ DenseBoolArrayAttr inBoundsAttr;
+ inBoundsAttr = rewriter.getDenseBoolArrayAttr(
+ write.getInBoundsAttr().asArrayRef().take_back(newType.getRank()));
auto newVector = rewriter.create<vector::ExtractOp>(
write.getLoc(), write.getVector(), splatZero(dropDim));
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 4c93d3841bf878..52773b2570994c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -411,7 +411,7 @@ class TransferReadDropUnitDimsPattern
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
transferReadOp.getPadding(), maskOp,
- rewriter.getBoolArrayAttr(inBounds));
+ rewriter.getDenseBoolArrayAttr(inBounds));
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, vectorType, newTransferReadOp);
rewriter.replaceOp(transferReadOp, shapeCast);
@@ -480,7 +480,7 @@ class TransferWriteDropUnitDimsPattern
loc, reducedVectorType, vector);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
- identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+ identityMap, maskOp, rewriter.getDenseBoolArrayAttr(inBounds));
return success();
}
@@ -640,7 +640,8 @@ class FlattenContiguousRowMajorTransferReadPattern
vectorType.getElementType());
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
- flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+ SmallVector<bool> inBounds(1, true);
+ flatRead.setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
// 4. Replace the old transfer_read with the new one reading from the
// collapsed shape
@@ -735,7 +736,8 @@ class FlattenContiguousRowMajorTransferWritePattern
vector::TransferWriteOp flatWrite =
rewriter.create<vector::TransferWriteOp>(
loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
- flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+ SmallVector<bool> inBounds(1, true);
+ flatWrite.setInBoundsAttr(rewriter.getDenseBoolArrayAttr(inBounds));
// 4. Replace the old transfer_write with the new one writing the
// collapsed shape
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index ee622e886f6185..3c482413d761e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -523,7 +523,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
return failure();
SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
- auto inBoundsAttr = b.getBoolArrayAttr(bools);
+ auto inBoundsAttr = b.getDenseBoolArrayAttr(bools);
if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
b.modifyOpInPlace(xferOp, [&]() {
xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f59a378e03512..ddeb6111545b65 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1122,7 +1122,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getMaskMutable().assign(mask);
- xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+ xferOp.setInBoundsAttr(rewriter.getDenseBoolArrayAttr({true}));
});
return success();
@@ -1321,8 +1321,8 @@ class DropInnerMostUnitDimsTransferRead
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
strides));
- ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
- readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
+ DenseBoolArrayAttr inBoundsAttr = rewriter.getDenseBoolArrayAttr(
+ readOp.getInBounds().drop_back(dimsToDrop));
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
auto permMap = getTransferMinorIdentityMap(
@@ -1412,8 +1412,8 @@ class DropInnerMostUnitDimsTransferWrite
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
strides));
- ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
- writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
+ DenseBoolArrayAttr inBoundsAttr = rewriter.getDenseBoolArrayAttr(
+ writeOp.getInBounds().drop_back(dimsToDrop));
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
More information about the Mlir-commits mailing list