[Mlir-commits] [mlir] d3ddcfd - [mlir][DialectUtils] Generalize `extractFromI64ArrayAttr` helper
Matthias Springer
llvmlistbot at llvm.org
Wed Jul 12 09:04:20 PDT 2023
Author: Matthias Springer
Date: 2023-07-12T17:59:40+02:00
New Revision: d3ddcfd448d08699e89b8e49e1775f9e30fcc53a
URL: https://github.com/llvm/llvm-project/commit/d3ddcfd448d08699e89b8e49e1775f9e30fcc53a
DIFF: https://github.com/llvm/llvm-project/commit/d3ddcfd448d08699e89b8e49e1775f9e30fcc53a.diff
LOG: [mlir][DialectUtils] Generalize `extractFromI64ArrayAttr` helper
Generalize `extractFromI64ArrayAttr` to `extractFromIntegerArrayAttr`, so that arbitrary integer/bool types can be extracted.
Differential Revision: https://reviews.llvm.org/D154974
Added:
Modified:
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 47910e2069761a..8c9b5e567f6699 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -15,6 +15,7 @@
#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
@@ -57,8 +58,14 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec);
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
+/// Extract integer values from the assumed ArrayAttr of IntegerAttr.
+template <typename IntTy>
+SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {
+ return llvm::to_vector(
+ llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy {
+ return cast<IntegerAttr>(a).getInt();
+ }));
+}
/// Given a value, try to extract a constant Attribute. If this fails, return
/// the original value.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4e0aa88464647e..31fdca7affbcc6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -416,9 +416,10 @@ DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
- SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
+ SmallVector<int64_t> tileSizes =
+ extractFromIntegerArrayAttr<int64_t>(getTileSizes());
SmallVector<int64_t> tileInterchange =
- extractFromI64ArrayAttr(getTileInterchange());
+ extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
@@ -471,7 +472,7 @@ void transform::FuseOp::print(OpAsmPrinter &p) {
LogicalResult transform::FuseOp::verify() {
SmallVector<int64_t> permutation =
- extractFromI64ArrayAttr(getTileInterchange());
+ extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
if (!std::is_permutation(sequence.begin(), sequence.end(),
permutation.begin(), permutation.end())) {
@@ -479,7 +480,8 @@ LogicalResult transform::FuseOp::verify() {
<< getTileInterchange();
}
- SmallVector<int64_t> sizes = extractFromI64ArrayAttr(getTileSizes());
+ SmallVector<int64_t> sizes =
+ extractFromIntegerArrayAttr<int64_t>(getTileSizes());
size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
if (numExpectedLoops != getNumResults() - 1)
return emitOpError() << "expects " << numExpectedLoops << " loop results";
@@ -1571,7 +1573,8 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
// Convert the integer packing flags to booleans.
SmallVector<bool> packPaddings;
- for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
+ for (int64_t packPadding :
+ extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
packPaddings.push_back(static_cast<bool>(packPadding));
// Convert the padding values to attributes.
@@ -1611,15 +1614,17 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
// Extract the transpose vectors.
SmallVector<SmallVector<int64_t>> transposePaddings;
for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
- transposePaddings.push_back(
- extractFromI64ArrayAttr(cast<ArrayAttr>(transposeVector)));
+ transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
+ cast<ArrayAttr>(transposeVector)));
LinalgOp paddedOp;
LinalgPaddingOptions options;
- options.paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions());
+ options.paddingDimensions =
+ extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
if (getPadToMultipleOf().has_value())
- padToMultipleOf = extractFromI64ArrayAttr(*getPadToMultipleOf());
+ padToMultipleOf =
+ extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
options.padToMultipleOf = padToMultipleOf;
options.paddingValues = paddingValues;
options.packPaddings = packPaddings;
@@ -1650,7 +1655,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
LogicalResult transform::PadOp::verify() {
SmallVector<int64_t> packPaddings =
- extractFromI64ArrayAttr(getPackPaddings());
+ extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
if (any_of(packPaddings, [](int64_t packPadding) {
return packPadding != 0 && packPadding != 1;
})) {
@@ -1660,7 +1665,7 @@ LogicalResult transform::PadOp::verify() {
}
SmallVector<int64_t> paddingDimensions =
- extractFromI64ArrayAttr(getPaddingDimensions());
+ extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
if (any_of(paddingDimensions,
[](int64_t paddingDimension) { return paddingDimension < 0; })) {
return emitOpError() << "expects padding_dimensions to contain positive "
@@ -1674,7 +1679,7 @@ LogicalResult transform::PadOp::verify() {
}
ArrayAttr transposes = getTransposePaddings();
for (Attribute attr : transposes) {
- SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
+ SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(), transpose.end())) {
@@ -1791,7 +1796,7 @@ transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
LinalgPromotionOptions promotionOptions;
if (!getOperandsToPromote().empty())
promotionOptions = promotionOptions.setOperandsToPromote(
- extractFromI64ArrayAttr(getOperandsToPromote()));
+ extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
if (getUseFullTilesByDefault())
promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
getUseFullTilesByDefault());
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 75d2dcec13643d..7db793b766a1b1 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -68,14 +68,6 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
}
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
- return llvm::to_vector<4>(
- llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> int64_t {
- return cast<IntegerAttr>(a).getInt();
- }));
-}
-
/// Given a value, try to extract a constant Attribute. If this fails, return
/// the original value.
OpFoldResult getAsOpFoldResult(Value val) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index eec84569a21a16..88e3a455340750 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -165,7 +165,7 @@ class NVVMDialectLLVMIRTranslationInterface
if (!dyn_cast<ArrayAttr>(attribute.getValue()))
return failure();
SmallVector<int64_t> values =
- extractFromI64ArrayAttr(attribute.getValue());
+ extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
if (values.size() > 1)
generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
@@ -175,7 +175,7 @@ class NVVMDialectLLVMIRTranslationInterface
if (!dyn_cast<ArrayAttr>(attribute.getValue()))
return failure();
SmallVector<int64_t> values =
- extractFromI64ArrayAttr(attribute.getValue());
+ extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
if (values.size() > 1)
generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
More information about the Mlir-commits
mailing list