[Mlir-commits] [mlir] 0813700 - [mlir][NFC] Cleanup: Move helper functions to StaticValueUtils
Matthias Springer
llvmlistbot at llvm.org
Sat Jun 26 23:57:47 PDT 2021
Author: Matthias Springer
Date: 2021-06-27T15:56:48+09:00
New Revision: 0813700de1af72173ad18202fcbd3eafce90d184
URL: https://github.com/llvm/llvm-project/commit/0813700de1af72173ad18202fcbd3eafce90d184
DIFF: https://github.com/llvm/llvm-project/commit/0813700de1af72173ad18202fcbd3eafce90d184.diff
LOG: [mlir][NFC] Cleanup: Move helper functions to StaticValueUtils
Reduce code duplication: Move various helper functions that are duplicated in TensorDialect, MemRefDialect, LinalgDialect, and StandardDialect into a new StaticValueUtils.cpp.
Differential Revision: https://reviews.llvm.org/D104687
Added:
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index df568d6795d43..ffd65f7138efc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -269,13 +269,13 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
// Return true if low padding is guaranteed to be 0.
bool hasZeroLowPad() {
return llvm::all_of(getMixedLowPad(), [](OpFoldResult ofr) {
- return mlir::isEqualConstantInt(ofr, 0);
+ return getConstantIntValue(ofr) == static_cast<int64_t>(0);
});
}
// Return true if high padding is guaranteed to be 0.
bool hasZeroHighPad() {
return llvm::all_of(getMixedHighPad(), [](OpFoldResult ofr) {
- return mlir::isEqualConstantInt(ofr, 0);
+ return getConstantIntValue(ofr) == static_cast<int64_t>(0);
});
}
}];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f6b78ae385d04..5f533df137419 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -13,6 +13,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index bff62c716dfe6..477474b41da46 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -114,21 +114,6 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs);
-/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
-/// IntegerAttr, return the integer.
-llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
-
-/// Return true if ofr and value are the same integer.
-/// Ignore integer bitwidth and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType has no bitwidth.
-bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
-
-/// Return true if ofr1 and ofr2 are the same integer constant attribute values
-/// or the same SSA value.
-/// Ignore integer bitwitdh and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType have no bitwidth.
-bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
-
/// Returns the identity value attribute associated with an AtomicRMWKind op.
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc);
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
new file mode 100644
index 0000000000000..3284c022a7255
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -0,0 +1,58 @@
+//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines utilities for dealing with static values, e.g.,
+// converting back and forth between Value and OpFoldResult. Such functionality
+// is used in multiple dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
+#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
+/// it is a Value or into `staticVec` if it is an IntegerAttr.
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+void dispatchIndexOpFoldResult(OpFoldResult ofr,
+ SmallVectorImpl<Value> &dynamicVec,
+ SmallVectorImpl<int64_t> &staticVec,
+ int64_t sentinel);
+
+/// Helper function to dispatch multiple OpFoldResults into either the
+/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs).
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
+ SmallVectorImpl<Value> &dynamicVec,
+ SmallVectorImpl<int64_t> &staticVec,
+ int64_t sentinel);
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
+
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 8d58570148910..7df3d1e95bab4 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -13,6 +13,7 @@
#ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
#define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -30,8 +31,6 @@ struct Range {
class OffsetSizeAndStrideOpInterface;
-bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
-
namespace detail {
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index 62f24f2b97362..2ba9038cec775 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -444,7 +444,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) {
- return ::mlir::isEqualConstantInt(ofr, 1);
+ return ::mlir::getConstantIntValue(ofr) == static_cast<int64_t>(1);
});
}]
>,
@@ -456,7 +456,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) {
- return ::mlir::isEqualConstantInt(ofr, 0);
+ return ::mlir::getConstantIntValue(ofr) == static_cast<int64_t>(0);
});
}]
>,
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 8e808d75e205e..db5918e95f182 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -3388,14 +3389,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
}
};
-/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
- return llvm::to_vector<4>(
- llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
-}
-
/// Conversion pattern that transforms a subview op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9a1ceebba97d5..109a1c60ddc39 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
@@ -116,24 +117,6 @@ static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
}));
}
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
- SmallVectorImpl<Value> &dynamicVec,
- SmallVectorImpl<int64_t> &staticVec,
- int64_t sentinel) {
- if (auto v = ofr.dyn_cast<Value>()) {
- dynamicVec.push_back(v);
- staticVec.push_back(sentinel);
- return;
- }
- APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
- staticVec.push_back(apInt.getSExtValue());
-}
-
/// This is a common class used for patterns of the form
/// ```
/// someop(memrefcast(%src)) -> someop(%src)
@@ -819,14 +802,6 @@ LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
// PadTensorOp
//===----------------------------------------------------------------------===//
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
- return llvm::to_vector<4>(
- llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
-}
-
static LogicalResult verify(PadTensorOp op) {
auto sourceType = op.source().getType().cast<RankedTensorType>();
auto resultType = op.result().getType().cast<RankedTensorType>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index d02570af3622b..c951e70f18d83 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -110,6 +110,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 829b988dbad73..92382a6906835 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -814,8 +814,8 @@ struct GenericPadTensorOpVectorizationPattern
readInBounds.push_back(false);
// Write is out-of-bounds if low padding > 0.
writeInBounds.push_back(
- isEqualConstantIntOrValue(padOp.getMixedLowPad()[i],
- rewriter.getIndexAttr(0)));
+ getConstantIntValue(padOp.getMixedLowPad()[i]) ==
+ static_cast<int64_t>(0));
} else {
// Neither source nor result dim of padOp is static. Cannot vectorize
// the copy.
@@ -1098,9 +1098,9 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
if (!llvm::all_of(
- llvm::zip(insertOp.getMixedSizes(), expectedSizes),
- [](auto it) { return isEqualConstantInt(std::get<0>(it),
- std::get<1>(it)); }))
+ llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
+ return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
+ }))
return failure();
// Generate TransferReadOp: Read entire source tensor and add high padding.
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 6ac47b11996a3..6f9aeaa19cb22 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRMemRef
LINK_LIBS PUBLIC
MLIRDialect
+ MLIRDialectUtils
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRefUtils
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 8d003577eb533..cc4e7a49363a5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -32,40 +33,6 @@ Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
return builder.create<mlir::ConstantOp>(loc, type, value);
}
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
- return llvm::to_vector<4>(
- llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
-}
-
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
- SmallVectorImpl<Value> &dynamicVec,
- SmallVectorImpl<int64_t> &staticVec,
- int64_t sentinel) {
- if (auto v = ofr.dyn_cast<Value>()) {
- dynamicVec.push_back(v);
- staticVec.push_back(sentinel);
- return;
- }
- APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
- staticVec.push_back(apInt.getSExtValue());
-}
-
-static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
- SmallVectorImpl<Value> &dynamicVec,
- SmallVectorImpl<int64_t> &staticVec,
- int64_t sentinel) {
- for (auto ofr : ofrs)
- dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
-}
-
//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 73c0c4a607b63..837986fc03535 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -33,38 +33,6 @@
using namespace mlir;
-/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
-/// IntegerAttr, return the integer.
-llvm::Optional<int64_t> mlir::getConstantIntValue(OpFoldResult ofr) {
- Attribute attr = ofr.dyn_cast<Attribute>();
- // Note: isa+cast-like pattern allows writing the condition below as 1 line.
- if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
- attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
- if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
- return intAttr.getValue().getSExtValue();
- return llvm::None;
-}
-
-/// Return true if ofr and value are the same integer.
-/// Ignore integer bitwidth and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType has no bitwidth.
-bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) {
- auto ofrValue = getConstantIntValue(ofr);
- return ofrValue && *ofrValue == value;
-}
-
-/// Return true if ofr1 and ofr2 are the same integer constant attribute values
-/// or the same SSA value.
-/// Ignore integer bitwidth and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType has no bitwidth.
-bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
- auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
- if (cst1 && cst2 && *cst1 == *cst2)
- return true;
- auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
- return v1 && v2 && v1 == v2;
-}
-
//===----------------------------------------------------------------------===//
// StandardOpsDialect Interfaces
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index e1fad1b358f00..4b6886ef244d0 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRTensor
LINK_LIBS PUBLIC
MLIRCastInterfaces
+ MLIRDialectUtils
MLIRIR
MLIRSideEffectInterfaces
MLIRSupport
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8a4b212db0329..28a5f5df21cef 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
@@ -516,32 +517,6 @@ static LogicalResult verify(ReshapeOp op) {
// ExtractSliceOp
//===----------------------------------------------------------------------===//
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
- SmallVectorImpl<Value> &dynamicVec,
- SmallVectorImpl<int64_t> &staticVec,
- int64_t sentinel) {
- if (auto v = ofr.dyn_cast<Value>()) {
- dynamicVec.push_back(v);
- staticVec.push_back(sentinel);
- return;
- }
- APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
- staticVec.push_back(apInt.getSExtValue());
-}
-
-static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
- SmallVectorImpl<Value> &dynamicVec,
- SmallVectorImpl<int64_t> &staticVec,
- int64_t sentinel) {
- for (auto ofr : ofrs)
- dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
-}
-
/// An extract_slice op result type can be fully inferred from the source type
/// and the static representation of offsets, sizes and strides. Special
/// sentinels encode the dynamic case.
@@ -563,14 +538,6 @@ Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType,
sourceRankedTensorType.getElementType());
}
-/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
- return llvm::to_vector<4>(
- llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
-}
-
Type ExtractSliceOp::inferResultType(
RankedTensorType sourceRankedTensorType,
ArrayRef<OpFoldResult> leadingStaticOffsets,
@@ -890,17 +857,16 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
ShapedType shapedType) {
OpBuilder b(op.getContext());
for (OpFoldResult ofr : op.getMixedOffsets())
- if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0)))
+ if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
return failure();
// Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
// is appropriate.
auto shape = shapedType.getShape();
for (auto it : llvm::zip(op.getMixedSizes(), shape))
- if (!isEqualConstantIntOrValue(std::get<0>(it),
- b.getIndexAttr(std::get<1>(it))))
+ if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
return failure();
for (OpFoldResult ofr : op.getMixedStrides())
- if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1)))
+ if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
return failure();
return success();
}
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index a640e3581b4a3..098b6b48b032f 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_library(MLIRDialectUtils
StructuredOpsUtils.cpp
+ StaticValueUtils.cpp
LINK_LIBS PUBLIC
MLIRIR
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
new file mode 100644
index 0000000000000..bf7d662dbfcc9
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -0,0 +1,79 @@
+//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/APSInt.h"
+
+namespace mlir {
+
+/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
+/// it is a Value or into `staticVec` if it is an IntegerAttr.
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+void dispatchIndexOpFoldResult(OpFoldResult ofr,
+ SmallVectorImpl<Value> &dynamicVec,
+ SmallVectorImpl<int64_t> &staticVec,
+ int64_t sentinel) {
+ if (auto v = ofr.dyn_cast<Value>()) {
+ dynamicVec.push_back(v);
+ staticVec.push_back(sentinel);
+ return;
+ }
+ APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
+ staticVec.push_back(apInt.getSExtValue());
+}
+
+void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
+ SmallVectorImpl<Value> &dynamicVec,
+ SmallVectorImpl<int64_t> &staticVec,
+ int64_t sentinel) {
+ for (OpFoldResult ofr : ofrs)
+ dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
+}
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+ return llvm::to_vector<4>(
+ llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+ return a.cast<IntegerAttr>().getInt();
+ }));
+}
+
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+ // Case 1: Check for Constant integer.
+ if (auto val = ofr.dyn_cast<Value>()) {
+ APSInt intVal;
+ if (matchPattern(val, m_ConstantInt(&intVal)))
+ return intVal.getSExtValue();
+ return llvm::None;
+ }
+ // Case 2: Check for IntegerAttr.
+ Attribute attr = ofr.dyn_cast<Attribute>();
+ if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+ return intAttr.getValue().getSExtValue();
+ return llvm::None;
+}
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
+ auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
+ if (cst1 && cst2 && *cst1 == *cst2)
+ return true;
+ auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
+ return v1 && v1 == v2;
+}
+
+} // namespace mlir
+
More information about the Mlir-commits
mailing list