[Mlir-commits] [mlir] d624c1b - [mlir][NFC] Move asOpFoldResult helper functions to StaticValueUtils
Matthias Springer
llvmlistbot at llvm.org
Wed Jul 14 18:33:54 PDT 2021
Author: Matthias Springer
Date: 2021-07-15T10:28:57+09:00
New Revision: d624c1b50946b206b6274371fcc107f89d04a307
URL: https://github.com/llvm/llvm-project/commit/d624c1b50946b206b6274371fcc107f89d04a307
DIFF: https://github.com/llvm/llvm-project/commit/d624c1b50946b206b6274371fcc107f89d04a307.diff
LOG: [mlir][NFC] Move asOpFoldResult helper functions to StaticValueUtils
Differential Revision: https://reviews.llvm.org/D105602
Added:
Modified:
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 3284c022a7255..5838f1d1fb241 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -44,6 +44,14 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val);
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
+
/// If ofr is a constant integer or an IntegerAttr, return the integer.
Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 32efdc2fe19ee..12fbb8ce839bc 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -89,20 +89,6 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
-/// Helper function to convert a Value into an OpFoldResult, if the Value is
-/// known to be a constant index value.
-static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
- return llvm::to_vector<4>(
- llvm::map_range(values, [](Value v) -> OpFoldResult {
- APInt intValue;
- if (v.getType().isa<IndexType>() &&
- matchPattern(v, m_ConstantInt(&intValue))) {
- return IntegerAttr::get(v.getType(), intValue.getSExtValue());
- }
- return v;
- }));
-}
-
/// Helper function to convert a vector of `OpFoldResult`s into a vector of
/// `Value`s.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 59a858f57d731..4c3bb414bbcf3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
@@ -798,14 +799,6 @@ static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
return builder.create<ConstantIndexOp>(loc, *intVal);
}
-/// Given a value, try to extract a constant index-type integer as an Attribute.
-/// If this fails, return the original value.
-static OpFoldResult asOpFoldResult(OpBuilder &builder, Value val) {
- if (auto constInt = getConstantIntValue(val))
- return builder.getIndexAttr(*constInt);
- return val;
-}
-
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
@@ -895,7 +888,7 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
// ExtractSliceOp length will be zero in that case. (Effectively reading no
// data from the source.)
Value newOffset = min(max(sub(offset, low), zero), srcSize);
- newOffsets.push_back(asOpFoldResult(rewriter, newOffset));
+ newOffsets.push_back(getAsOpFoldResult(newOffset));
// The original ExtractSliceOp was reading until position `offset + length`.
// Therefore, the corresponding position within the source tensor is:
@@ -915,7 +908,7 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
// The new ExtractSliceOp length is `endLoc - newOffset`.
Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
Value newLength = sub(endLoc, newOffset);
- newLengths.push_back(asOpFoldResult(rewriter, newLength));
+ newLengths.push_back(getAsOpFoldResult(newLength));
// Check if newLength is zero. In that case, no SubTensorOp should be
// executed.
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index bf7d662dbfcc9..b6cb4b7da5a6e 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -47,6 +47,22 @@ SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
}));
}
+/// Given a value, try to extract a constant Attribute. If this fails, return
+/// the original value.
+OpFoldResult getAsOpFoldResult(Value val) {
+ Attribute attr;
+ if (matchPattern(val, m_Constant(&attr)))
+ return attr;
+ return val;
+}
+
+/// Given an array of values, try to extract a constant Attribute from each
+/// value. If this fails, return the original value.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
+ return llvm::to_vector<4>(
+ llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
+}
+
/// If ofr is a constant integer or an IntegerAttr, return the integer.
Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
More information about the Mlir-commits mailing list