[Mlir-commits] [mlir] d3ddcfd - [mlir][DialectUtils] Generalize `extractFromI64ArrayAttr` helper

Matthias Springer llvmlistbot at llvm.org
Wed Jul 12 09:04:20 PDT 2023


Author: Matthias Springer
Date: 2023-07-12T17:59:40+02:00
New Revision: d3ddcfd448d08699e89b8e49e1775f9e30fcc53a

URL: https://github.com/llvm/llvm-project/commit/d3ddcfd448d08699e89b8e49e1775f9e30fcc53a
DIFF: https://github.com/llvm/llvm-project/commit/d3ddcfd448d08699e89b8e49e1775f9e30fcc53a.diff

LOG: [mlir][DialectUtils] Generalize `extractFromI64ArrayAttr` helper

Generalize `extractFromI64ArrayAttr` to `extractFromIntegerArrayAttr`, so that arbitrary integer/bool types can be extracted.

Differential Revision: https://reviews.llvm.org/D154974

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp
    mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 47910e2069761a..8c9b5e567f6699 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
 #define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
 
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallVector.h"
@@ -57,8 +58,14 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                 SmallVectorImpl<Value> &dynamicVec,
                                 SmallVectorImpl<int64_t> &staticVec);
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
+/// Convert `attr`, assumed to be an ArrayAttr of IntegerAttr, to IntTy values.
+template <typename IntTy>
+SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {
+  return llvm::to_vector(
+      llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> IntTy {
+        return cast<IntegerAttr>(a).getInt(); // int64_t, converted to IntTy
+      }));
+}
 
 /// Given a value, try to extract a constant Attribute. If this fails, return
 /// the original value.

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4e0aa88464647e..31fdca7affbcc6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -416,9 +416,10 @@ DiagnosedSilenceableFailure
 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
                          mlir::transform::TransformResults &transformResults,
                          mlir::transform::TransformState &state) {
-  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
+  SmallVector<int64_t> tileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
   SmallVector<int64_t> tileInterchange =
-      extractFromI64ArrayAttr(getTileInterchange());
+      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
 
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.interchangeVector = tileInterchange;
@@ -471,7 +472,7 @@ void transform::FuseOp::print(OpAsmPrinter &p) {
 
 LogicalResult transform::FuseOp::verify() {
   SmallVector<int64_t> permutation =
-      extractFromI64ArrayAttr(getTileInterchange());
+      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
   if (!std::is_permutation(sequence.begin(), sequence.end(),
                            permutation.begin(), permutation.end())) {
@@ -479,7 +480,8 @@ LogicalResult transform::FuseOp::verify() {
                          << getTileInterchange();
   }
 
-  SmallVector<int64_t> sizes = extractFromI64ArrayAttr(getTileSizes());
+  SmallVector<int64_t> sizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
   size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
   if (numExpectedLoops != getNumResults() - 1)
     return emitOpError() << "expects " << numExpectedLoops << " loop results";
@@ -1571,7 +1573,8 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
 
     // Convert the integer packing flags to booleans.
     SmallVector<bool> packPaddings;
-    for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
+    for (int64_t packPadding :
+         extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
       packPaddings.push_back(static_cast<bool>(packPadding));
 
     // Convert the padding values to attributes.
@@ -1611,15 +1614,17 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     // Extract the transpose vectors.
     SmallVector<SmallVector<int64_t>> transposePaddings;
     for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
-      transposePaddings.push_back(
-          extractFromI64ArrayAttr(cast<ArrayAttr>(transposeVector)));
+      transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
+          cast<ArrayAttr>(transposeVector)));
 
     LinalgOp paddedOp;
     LinalgPaddingOptions options;
-    options.paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions());
+    options.paddingDimensions =
+        extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
     SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
     if (getPadToMultipleOf().has_value())
-      padToMultipleOf = extractFromI64ArrayAttr(*getPadToMultipleOf());
+      padToMultipleOf =
+          extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
     options.padToMultipleOf = padToMultipleOf;
     options.paddingValues = paddingValues;
     options.packPaddings = packPaddings;
@@ -1650,7 +1655,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
 
 LogicalResult transform::PadOp::verify() {
   SmallVector<int64_t> packPaddings =
-      extractFromI64ArrayAttr(getPackPaddings());
+      extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
   if (any_of(packPaddings, [](int64_t packPadding) {
         return packPadding != 0 && packPadding != 1;
       })) {
@@ -1660,7 +1665,7 @@ LogicalResult transform::PadOp::verify() {
   }
 
   SmallVector<int64_t> paddingDimensions =
-      extractFromI64ArrayAttr(getPaddingDimensions());
+      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
   if (any_of(paddingDimensions,
              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
     return emitOpError() << "expects padding_dimensions to contain positive "
@@ -1674,7 +1679,7 @@ LogicalResult transform::PadOp::verify() {
   }
   ArrayAttr transposes = getTransposePaddings();
   for (Attribute attr : transposes) {
-    SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
+    SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
     if (!std::is_permutation(sequence.begin(), sequence.end(),
                              transpose.begin(), transpose.end())) {
@@ -1791,7 +1796,7 @@ transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
   LinalgPromotionOptions promotionOptions;
   if (!getOperandsToPromote().empty())
     promotionOptions = promotionOptions.setOperandsToPromote(
-        extractFromI64ArrayAttr(getOperandsToPromote()));
+        extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
   if (getUseFullTilesByDefault())
     promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
         getUseFullTilesByDefault());

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 75d2dcec13643d..7db793b766a1b1 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -68,14 +68,6 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec);
 }
 
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> int64_t {
-        return cast<IntegerAttr>(a).getInt();
-      }));
-}
-
 /// Given a value, try to extract a constant Attribute. If this fails, return
 /// the original value.
 OpFoldResult getAsOpFoldResult(Value val) {

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index eec84569a21a16..88e3a455340750 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -165,7 +165,7 @@ class NVVMDialectLLVMIRTranslationInterface
       if (!dyn_cast<ArrayAttr>(attribute.getValue()))
         return failure();
       SmallVector<int64_t> values =
-          extractFromI64ArrayAttr(attribute.getValue());
+          extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
       generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
       if (values.size() > 1)
         generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
@@ -175,7 +175,7 @@ class NVVMDialectLLVMIRTranslationInterface
       if (!dyn_cast<ArrayAttr>(attribute.getValue()))
         return failure();
       SmallVector<int64_t> values =
-          extractFromI64ArrayAttr(attribute.getValue());
+          extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
       generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
       if (values.size() > 1)
         generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());


        


More information about the Mlir-commits mailing list