[Mlir-commits] [mlir] d624c1b - [mlir][NFC] Move asOpFoldResult helper functions to StaticValueUtils

Matthias Springer llvmlistbot at llvm.org
Wed Jul 14 18:33:54 PDT 2021


Author: Matthias Springer
Date: 2021-07-15T10:28:57+09:00
New Revision: d624c1b50946b206b6274371fcc107f89d04a307

URL: https://github.com/llvm/llvm-project/commit/d624c1b50946b206b6274371fcc107f89d04a307
DIFF: https://github.com/llvm/llvm-project/commit/d624c1b50946b206b6274371fcc107f89d04a307.diff

LOG: [mlir][NFC] Move asOpFoldResult helper functions to StaticValueUtils

Differential Revision: https://reviews.llvm.org/D105602

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 3284c022a7255..5838f1d1fb241 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -44,6 +44,14 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
 SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
 
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val);
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
 

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 32efdc2fe19ee..12fbb8ce839bc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -89,20 +89,6 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
 
-/// Helper function to convert a Value into an OpFoldResult, if the Value is
-/// known to be a constant index value.
-static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
-  return llvm::to_vector<4>(
-      llvm::map_range(values, [](Value v) -> OpFoldResult {
-        APInt intValue;
-        if (v.getType().isa<IndexType>() &&
-            matchPattern(v, m_ConstantInt(&intValue))) {
-          return IntegerAttr::get(v.getType(), intValue.getSExtValue());
-        }
-        return v;
-      }));
-}
-
 /// Helper function to convert a vector of `OpFoldResult`s into a vector of
 /// `Value`s.
 static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 59a858f57d731..4c3bb414bbcf3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
@@ -798,14 +799,6 @@ static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
   return builder.create<ConstantIndexOp>(loc, *intVal);
 }
 
-/// Given a value, try to extract a constant index-type integer as an Attribute.
-/// If this fails, return the original value.
-static OpFoldResult asOpFoldResult(OpBuilder &builder, Value val) {
-  if (auto constInt = getConstantIntValue(val))
-    return builder.getIndexAttr(*constInt);
-  return val;
-}
-
 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
   auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
@@ -895,7 +888,7 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     // ExtractSliceOp length will be zero in that case. (Effectively reading no
     // data from the source.)
     Value newOffset = min(max(sub(offset, low), zero), srcSize);
-    newOffsets.push_back(asOpFoldResult(rewriter, newOffset));
+    newOffsets.push_back(getAsOpFoldResult(newOffset));
 
     // The original ExtractSliceOp was reading until position `offset + length`.
     // Therefore, the corresponding position within the source tensor is:
@@ -915,7 +908,7 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     // The new ExtractSliceOp length is `endLoc - newOffset`.
     Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
     Value newLength = sub(endLoc, newOffset);
-    newLengths.push_back(asOpFoldResult(rewriter, newLength));
+    newLengths.push_back(getAsOpFoldResult(newLength));
 
     // Check if newLength is zero. In that case, no SubTensorOp should be
     // executed.

diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index bf7d662dbfcc9..b6cb4b7da5a6e 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -47,6 +47,22 @@ SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
       }));
 }
 
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val) {
+  Attribute attr;
+  if (matchPattern(val, m_Constant(&attr)))
+    return attr;
+  return val;
+}
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
+  return llvm::to_vector<4>(
+      llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
+}
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.


        


More information about the Mlir-commits mailing list