[Mlir-commits] [mlir] 8c88565 - [mlir][Interfaces] Add ValueBoundsOpInterface
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 5 17:57:30 PDT 2023
Author: Matthias Springer
Date: 2023-04-06T02:57:14+02:00
New Revision: 8c885658edf599da277d6d8f2f66bf4cf6f2b934
URL: https://github.com/llvm/llvm-project/commit/8c885658edf599da277d6d8f2f66bf4cf6f2b934
DIFF: https://github.com/llvm/llvm-project/commit/8c885658edf599da277d6d8f2f66bf4cf6f2b934.diff
LOG: [mlir][Interfaces] Add ValueBoundsOpInterface
Ops can implement this interface to specify lower/upper bounds for their result values and block arguments. Bounds can be specified for:
* Index-type values
* Dimension sizes of shaped values
The bounds are added to a constraint set. Users can query this constraint set to compute bounds with respect to a user-specified set of values. Only EQ bounds are supported at the moment.
This revision also contains interface implementations for various tensor dialect ops, which illustrate how to implement this interface.
Differential Revision: https://reviews.llvm.org/D145681
Added:
mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h
mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td
mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/test/Dialect/Affine/value-bounds-reification.mlir
mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
Modified:
mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
mlir/lib/Interfaces/CMakeLists.txt
mlir/test/lib/Dialect/Affine/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index dbba3b7549ff1..2abd329b51620 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -14,12 +14,21 @@
#ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
+class AffineApplyOp;
+class Location;
+class OpBuilder;
+class OpFoldResult;
class RewritePatternSet;
class RewriterBase;
-class AffineApplyOp;
+class Value;
+
+namespace presburger {
+enum class BoundType;
+} // namespace presburger
/// Populate patterns that expand affine index operations into more fundamental
/// operations (not necessarily restricted to Affine dialect).
@@ -40,6 +49,32 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
/// maximally compose chains of AffineApplyOps.
FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
+/// Reify a bound for the given index-typed value or shape dimension size in
+/// terms of the owning op's operands. `dim` must be `nullopt` if and only if
+/// `value` is index-typed.
+FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
+ presburger::BoundType type, Value value,
+ std::optional<int64_t> dim);
+
+/// Reify a bound for the given index-typed value or shape dimension size in
+/// terms of SSA values for which `stopCondition` is met. `dim` must be
+/// `nullopt` if and only if `value` is index-typed.
+///
+/// Example:
+/// %0 = arith.addi %a, %b : index
+/// %1 = arith.addi %0, %c : index
+///
+/// * If `stopCondition` evaluates to "true" for %0 and %c, "%0 + %c" is an EQ
+/// bound for %1.
+/// * If `stopCondition` evaluates to "true" for %a, %b and %c, "%a + %b + %c"
+/// is an EQ bound for %1.
+/// * Otherwise, if the owners of %a, %b or %c do not implement the
+/// ValueBoundsOpInterface, no bound can be computed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+ Value value, std::optional<int64_t> dim,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition);
+
} // namespace mlir
#endif // MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644
index 0000000000000..a75ee9df3217d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_TENSOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace tensor {
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 9090189a1bd18..998503a612016 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -65,6 +65,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -142,6 +143,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
+ tensor::registerValueBoundsOpInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
}
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index f24d34070fb62..2cb1d489c4bfd 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
add_mlir_interface(SideEffectInterfaces)
add_mlir_interface(TilingInterface)
+add_mlir_interface(ValueBoundsOpInterface)
add_mlir_interface(VectorInterfaces)
add_mlir_interface(ViewLikeInterface)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
new file mode 100644
index 0000000000000..97d27a04df893
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -0,0 +1,198 @@
+//===- ValueBoundsOpInterface.h - Value Bounds ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_
+#define MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_
+
+#include "mlir/Analysis/FlatLinearValueConstraints.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+
+using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
+
+/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
+/// constraint system and mapping of constrained variables to index-typed
+/// values or dimension sizes of shaped values.
+///
+/// Interface implementations of `ValueBoundsOpInterface` use `addBounds` to
+/// insert constraints about their results and/or region block arguments into
+/// the constraint set in the form of an AffineExpr. When a bound should be
+/// expressed in terms of another value/dimension, `getExpr` can be used to
+/// retrieve an AffineExpr that represents the specified value/dimension.
+///
+/// When a value/dimension is retrieved for the first time through `getExpr`,
+/// it is added to an internal worklist. See `computeBound` for more details.
+///
+/// Note: Any modification of existing IR invalidates the data stored in this
+/// class. Adding new operations is allowed.
+class ValueBoundsConstraintSet {
+protected:
+ /// Helper class that builds a bound for a shaped value dimension or
+ /// index-typed value.
+ class BoundBuilder {
+ public:
+ /// Specify a dimension, assuming that the underlying value is a shaped
+ /// value.
+ BoundBuilder &operator[](int64_t dim);
+
+ // These overloaded operators add lower/upper/equality bounds.
+ void operator<(AffineExpr expr);
+ void operator<=(AffineExpr expr);
+ void operator>(AffineExpr expr);
+ void operator>=(AffineExpr expr);
+ void operator==(AffineExpr expr);
+ void operator<(OpFoldResult ofr);
+ void operator<=(OpFoldResult ofr);
+ void operator>(OpFoldResult ofr);
+ void operator>=(OpFoldResult ofr);
+ void operator==(OpFoldResult ofr);
+ void operator<(int64_t i);
+ void operator<=(int64_t i);
+ void operator>(int64_t i);
+ void operator>=(int64_t i);
+ void operator==(int64_t i);
+
+ protected:
+ friend class ValueBoundsConstraintSet;
+ BoundBuilder(ValueBoundsConstraintSet &cstr, Value value)
+ : cstr(cstr), value(value) {}
+
+ private:
+ BoundBuilder(const BoundBuilder &) = delete;
+ BoundBuilder &operator=(const BoundBuilder &) = delete;
+ bool operator==(const BoundBuilder &) = delete;
+ bool operator!=(const BoundBuilder &) = delete;
+
+ ValueBoundsConstraintSet &cstr;
+ Value value;
+ std::optional<int64_t> dim = std::nullopt;
+ };
+
+public:
+ /// The stop condition when traversing the backward slice of a shaped value/
+ /// index-type value. The traversal continues until the stop condition
+ /// evaluates to "true" for a value.
+ using StopConditionFn = function_ref<bool(Value)>;
+
+ /// Compute a bound for the given index-typed value or shape dimension size.
+ /// The computed bound is stored in `resultMap`. The operands of the bound are
+ /// stored in `mapOperands`. An operand is either an index-type SSA value
+ /// or a shaped value and a dimension.
+ ///
+ /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
+ /// is computed in terms of values for which `stopCondition` evaluates to
+ /// "true". To that end, the backward slice (reverse use-def chain) of the
+ /// given value is visited in a worklist-driven manner and the constraint set
+ /// is populated according to `ValueBoundsOpInterface` for each visited value.
+ static LogicalResult computeBound(AffineMap &resultMap,
+ ValueDimList &mapOperands,
+ presburger::BoundType type, Value value,
+ std::optional<int64_t> dim,
+ StopConditionFn stopCondition);
+
+ /// Add a bound for the given index-typed value or shaped value. This function
+ /// returns a builder that adds the bound.
+ BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
+
+ /// Return an expression that represents the given index-typed value or shaped
+ /// value dimension. If this value/dimension was not used so far, it is added
+ /// to the worklist.
+ ///
+ /// `dim` must be `nullopt` if and only if the given value is of index type.
+ AffineExpr getExpr(Value value, std::optional<int64_t> dim = std::nullopt);
+
+ /// Return an expression that represents a constant or index-typed SSA value.
+ /// In case of a value, if this value was not used so far, it is added to the
+ /// worklist.
+ AffineExpr getExpr(OpFoldResult ofr);
+
+ /// Return an expression that represents a constant.
+ AffineExpr getExpr(int64_t constant);
+
+protected:
+ /// Dimension identifier to indicate a value is index-typed. This is used for
+ /// internal data structures/API only.
+ static constexpr int64_t kIndexValue = -1;
+
+ /// An index-typed value or the dimension of a shaped-type value.
+ using ValueDim = std::pair<Value, int64_t>;
+
+ ValueBoundsConstraintSet(Value value, std::optional<int64_t> dim);
+
+ /// Iteratively process all elements on the worklist until an index-typed
+ /// value or shaped value meets `stopCondition`. Such values are not processed
+ /// any further.
+ void processWorklist(StopConditionFn stopCondition);
+
+ /// Bound the given column in the underlying constraint set by the given
+ /// expression.
+ void addBound(presburger::BoundType type, int64_t pos, AffineExpr expr);
+
+ /// Return the column position of the given value/dimension. Asserts that the
+ /// value/dimension exists in the constraint set.
+ int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
+
+ /// Insert a value/dimension into the constraint set. If `isSymbol` is set to
+ /// "false", a dimension is added.
+ ///
+ /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
+ /// cannot be multiplied. Furthermore, bounds can only be queried for
+ /// dimensions but not for symbols.
+ int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
+
+ /// Project out the given column in the constraint set.
+ void projectOut(int64_t pos);
+
+ /// Project out all columns for which the condition holds.
+ void projectOut(function_ref<bool(ValueDim)> condition);
+
+ /// Mapping of columns to values/shape dimensions.
+ SmallVector<ValueDim> positionToValueDim;
+ /// Reverse mapping of values/shape dimensions to columns.
+ DenseMap<ValueDim, int64_t> valueDimToPosition;
+
+ /// Worklist of values/shape dimensions that have not been processed yet.
+ SetVector<int64_t> worklist;
+
+ /// Constraint system of equalities and inequalities.
+ FlatLinearConstraints cstr;
+
+ /// Builder for constructing affine expressions.
+ Builder builder;
+};
+
+} // namespace mlir
+
+#include "mlir/Interfaces/ValueBoundsOpInterface.h.inc"
+
+namespace mlir {
+
+/// Default implementation for destination style ops: Tied OpResults and
+/// OpOperands have the same type.
+template <typename ConcreteOp>
+struct DstValueBoundsOpInterfaceExternalModel
+ : public ValueBoundsOpInterface::ExternalModel<
+ DstValueBoundsOpInterfaceExternalModel<ConcreteOp>, ConcreteOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto dstOp = cast<DestinationStyleOpInterface>(op);
+ assert(value.getDefiningOp() == dstOp);
+
+ Value tiedOperand = dstOp.getTiedOpOperand(value.cast<OpResult>())->get();
+ cstr.bound(value)[dim] == cstr.getExpr(tiedOperand, dim);
+ }
+};
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_VALUEBOUNDSOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td
new file mode 100644
index 0000000000000..a080209101a53
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.td
@@ -0,0 +1,63 @@
+//===-- ValueBoundsOpInterface.td - Value Bounds -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef VALUEBOUNDSOPINTERFACE
+#define VALUEBOUNDSOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def ValueBoundsOpInterface : OpInterface<"ValueBoundsOpInterface"> {
+ let description = [{
+ This interface allows operations with index-typed and/or shaped value-typed
+ results/block arguments to specify range bounds. These bounds are stored in
+ a constraint set. The constraint set can then be queried to compute bounds
+ in terms of other values that are stored in the constraint set.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate the constraint set with bounds for the given index-typed
+ value.
+
+ Note: If `value` is a block argument, it must belong to an entry block
+ of a region. Unstructured control flow graphs are not supported at the
+ moment.
+ }],
+ /*retType=*/"void",
+ /*methodName=*/"populateBoundsForIndexValue",
+ /*args=*/(ins "::mlir::Value":$value,
+ "::mlir::ValueBoundsConstraintSet &":$cstr),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ llvm_unreachable("not implemented");
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate the constraint set with bounds for the size of the specified
+ dimension of the given shaped value.
+
+ Note: If `value` is a block argument, it must belong to an entry block
+ of a region. Unstructured control flow graphs are not supported at the
+ moment.
+ }],
+ /*retType=*/"void",
+ /*methodName=*/"populateBoundsForShapedValueDim",
+ /*args=*/(ins "::mlir::Value":$value,
+ "int64_t":$dim,
+ "::mlir::ValueBoundsConstraintSet &":$cstr),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ llvm_unreachable("not implemented");
+ }]
+ >,
+ ];
+}
+
+#endif // VALUEBOUNDSOPINTERFACE
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 0535b8b5d8f52..60d14a373348d 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
LoopUnroll.cpp
LoopUnrollAndJam.cpp
PipelineDataTransfer.cpp
+ ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
@@ -33,7 +34,9 @@ add_mlir_dialect_library(MLIRAffineTransforms
MLIRPass
MLIRSCFUtils
MLIRSideEffectInterfaces
+ MLIRTensorDialect
MLIRTransformUtils
+ MLIRValueBoundsOpInterface
MLIRVectorDialect
MLIRVectorUtils
MLIRVectorToLLVM
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
new file mode 100644
index 0000000000000..80b4daa83b136
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -0,0 +1,87 @@
+//===- ReifyValueBounds.cpp --- Reify value bounds with affine ops ------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+FailureOr<OpFoldResult> mlir::reifyValueBound(OpBuilder &b, Location loc,
+ presburger::BoundType type,
+ Value value,
+ std::optional<int64_t> dim) {
+ // We are trying to reify a bound for `value`. Construct a stop condition that
+ // evaluates to "true" for any SSA value except for `value`. I.e., the bound
+ // will be computed in terms of any SSA values except for `value`. The first
+ // such values are operands of the owner of `value`.
+ auto stopCondition = [&](Value v) {
+ // Reify in terms of SSA values that are different from `value`.
+ return v != value;
+ };
+ return reifyValueBound(b, loc, type, value, dim, stopCondition);
+}
+
+FailureOr<OpFoldResult>
+mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+ Value value, std::optional<int64_t> dim,
+ function_ref<bool(Value)> stopCondition) {
+ // Compute bound.
+ AffineMap boundMap;
+ ValueDimList mapOperands;
+ if (failed(ValueBoundsConstraintSet::computeBound(boundMap, mapOperands, type,
+ value, dim, stopCondition)))
+ return failure();
+
+ // Materialize tensor.dim/memref.dim ops.
+ SmallVector<Value> operands;
+ for (auto valueDim : mapOperands) {
+ Value value = valueDim.first;
+ std::optional<int64_t> dim = valueDim.second;
+
+ if (!dim.has_value()) {
+ // This is an index-typed value.
+ assert(value.getType().isIndex() && "expected index type");
+ operands.push_back(value);
+ continue;
+ }
+
+ assert(cast<ShapedType>(value.getType()).isDynamicDim(*dim) &&
+ "expected dynamic dim");
+ if (isa<RankedTensorType>(value.getType())) {
+ // A tensor dimension is used: generate a tensor.dim.
+ operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
+ } else if (isa<MemRefType>(value.getType())) {
+ // A memref dimension is used: generate a memref.dim.
+ operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
+ } else {
+ llvm_unreachable("cannot generate DimOp for unsupported shaped type");
+ }
+ }
+
+ // Simplify and return bound.
+ mlir::canonicalizeMapAndOperands(&boundMap, &operands);
+ // Check for special cases where no affine.apply op is needed.
+ if (boundMap.isSingleConstant()) {
+ // Bound is a constant: return an IntegerAttr.
+ return static_cast<OpFoldResult>(
+ b.getIndexAttr(boundMap.getSingleConstantResult()));
+ }
+ // No affine.apply op is needed if the bound is a single SSA value.
+ if (auto expr = boundMap.getResult(0).dyn_cast<AffineDimExpr>())
+ return static_cast<OpFoldResult>(operands[expr.getPosition()]);
+ if (auto expr = boundMap.getResult(0).dyn_cast<AffineSymbolExpr>())
+ return static_cast<OpFoldResult>(
+ operands[expr.getPosition() + boundMap.getNumDims()]);
+ // General case: build affine.apply op.
+ return static_cast<OpFoldResult>(
+ b.create<AffineApplyOp>(loc, boundMap, operands).getResult());
+}
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
index 141a9d1f2dc18..9cb2b0e68c521 100644
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -3,11 +3,13 @@ set(LLVM_OPTIONAL_SOURCES
TensorInferTypeOpInterfaceImpl.cpp
TensorOps.cpp
TensorTilingInterfaceImpl.cpp
+ ValueBoundsOpInterfaceImpl.cpp
)
add_mlir_dialect_library(MLIRTensorDialect
TensorDialect.cpp
TensorOps.cpp
+ ValueBoundsOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/mlir/Dialect/Tensor
@@ -32,6 +34,7 @@ add_mlir_dialect_library(MLIRTensorDialect
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
MLIRSupport
+ MLIRValueBoundsOpInterface
MLIRViewLikeInterface
)
diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..8d6baa184af7b
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -0,0 +1,129 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace tensor {
+namespace {
+
+struct CastOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto castOp = cast<CastOp>(op);
+ assert(value == castOp.getResult() && "invalid value");
+
+ if (castOp.getResult().getType().isa<RankedTensorType>() &&
+ castOp.getSource().getType().isa<RankedTensorType>()) {
+ cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
+ }
+ }
+};
+
+struct DimOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
+ void populateBoundsForIndexValue(Operation *op, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto dimOp = cast<DimOp>(op);
+ assert(value == dimOp.getResult() && "invalid value");
+
+ auto constIndex = dimOp.getConstantIndex();
+ if (!constIndex.has_value())
+ return;
+ cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
+ }
+};
+
+struct EmptyOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto emptyOp = cast<EmptyOp>(op);
+ assert(value == emptyOp.getResult() && "invalid value");
+
+ cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim];
+ }
+};
+
+struct ExtractSliceOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
+ ExtractSliceOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto extractSliceOp = cast<ExtractSliceOp>(op);
+ assert(value == extractSliceOp.getResult() && "invalid value");
+
+ llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
+ int64_t ctr = -1;
+ for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
+ // Skip over rank-reduced dimensions.
+ if (!dropped.test(i))
+ ++ctr;
+ if (ctr == dim) {
+ cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
+ return;
+ }
+ }
+ llvm_unreachable("could not find non-rank-reduced dim");
+ }
+};
+
+struct PadOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto padOp = cast<PadOp>(op);
+ assert(value == padOp.getResult() && "invalid value");
+
+ AffineExpr expr = cstr.getExpr(padOp.getSource(), dim) +
+ cstr.getExpr(padOp.getMixedLowPad()[dim]) +
+ cstr.getExpr(padOp.getMixedHighPad()[dim]);
+ cstr.bound(value)[dim] == expr;
+ }
+};
+
+struct RankOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
+ void populateBoundsForIndexValue(Operation *op, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ auto rankOp = cast<RankOp>(op);
+ assert(value == rankOp.getResult() && "invalid value");
+
+ auto tensorType = rankOp.getTensor().getType().dyn_cast<RankedTensorType>();
+ if (!tensorType)
+ return;
+ cstr.bound(value) == tensorType.getRank();
+ }
+};
+
+} // namespace
+} // namespace tensor
+} // namespace mlir
+
+void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
+ tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
+ tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
+ tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
+ tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
+ *ctx);
+ tensor::InsertOp::attachInterface<
+ DstValueBoundsOpInterfaceExternalModel<tensor::InsertOp>>(*ctx);
+ tensor::InsertSliceOp::attachInterface<
+ DstValueBoundsOpInterfaceExternalModel<tensor::InsertSliceOp>>(*ctx);
+ tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
+ tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index 38ad0e4a2231c..20073e7030557 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -14,6 +14,7 @@ set(LLVM_OPTIONAL_SOURCES
ShapedOpInterfaces.cpp
SideEffectInterfaces.cpp
TilingInterface.cpp
+ ValueBoundsOpInterface.cpp
VectorInterfaces.cpp
ViewLikeInterface.cpp
)
@@ -52,4 +53,18 @@ add_mlir_interface_library(TilingInterface)
add_mlir_interface_library(VectorInterfaces)
add_mlir_interface_library(ViewLikeInterface)
+add_mlir_library(MLIRValueBoundsOpInterface
+ ValueBoundsOpInterface.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
+
+ DEPENDS
+ MLIRValueBoundsOpInterfaceIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAnalysis
+ MLIRIR
+ )
+
add_subdirectory(Utils)
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
new file mode 100644
index 0000000000000..73757b611520a
--- /dev/null
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -0,0 +1,397 @@
+//===- ValueBoundsOpInterface.cpp - Value Bounds -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/APSInt.h"
+
+using namespace mlir;
+using presburger::BoundType;
+using presburger::VarKind;
+
+namespace mlir {
+#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
+} // namespace mlir
+
+/// Returns the integer held by `ofr` when it is either an SSA value produced
+/// by a constant integer op or an IntegerAttr; std::nullopt otherwise.
+static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+  // SSA value: only constant integers qualify.
+  if (Value val = ofr.dyn_cast<Value>()) {
+    APSInt intVal;
+    if (!matchPattern(val, m_ConstantInt(&intVal)))
+      return std::nullopt;
+    return intVal.getSExtValue();
+  }
+  // Attribute (possibly null): only IntegerAttr qualifies.
+  auto intAttr = ofr.dyn_cast<Attribute>().dyn_cast_or_null<IntegerAttr>();
+  if (!intAttr)
+    return std::nullopt;
+  return intAttr.getValue().getSExtValue();
+}
+
+/// Creates a constraint set whose first variable is `value` (or dimension
+/// `dim` of the shaped `value`), inserted as a dimension variable rather than
+/// a symbol and enqueued for processing.
+ValueBoundsConstraintSet::ValueBoundsConstraintSet(Value value,
+                                                   std::optional<int64_t> dim)
+    : builder(value.getContext()) {
+  (void)insert(value, dim, /*isSymbol=*/false);
+}
+
+#ifndef NDEBUG
+/// Debug-only sanity check of a (value, dim) pair:
+/// * index-typed values must not carry a dim,
+/// * shaped values must carry a non-negative dim, which must also be smaller
+///   than the rank for ranked shapes.
+/// Values of any other type are unsupported.
+static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
+  if (value.getType().isIndex()) {
+    assert(!dim.has_value() && "invalid dim value");
+  } else if (auto shapedType = dyn_cast<ShapedType>(value.getType())) {
+    // Check presence first: dereferencing an empty optional is UB.
+    assert(dim.has_value() && "expected dim for shaped value");
+    assert(*dim >= 0 && "invalid dim value");
+    if (shapedType.hasRank())
+      assert(*dim < shapedType.getRank() && "invalid dim value");
+  } else {
+    llvm_unreachable("unsupported type");
+  }
+}
+#endif // NDEBUG
+
+/// Adds a bound of kind `type` on the variable at position `pos`. `expr` is
+/// interpreted over all dimension and symbol variables currently in `cstr`.
+void ValueBoundsConstraintSet::addBound(BoundType type, int64_t pos,
+                                        AffineExpr expr) {
+  AffineMap boundMap =
+      AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr);
+  LogicalResult status = cstr.addBound(type, pos, boundMap);
+  (void)status;
+  assert(succeeded(status) && "failed to add bound to constraint system");
+}
+
+/// Returns an affine expression for `value` (or dimension `dim` of a shaped
+/// `value`). Static dimension sizes and constant index values become constant
+/// expressions. Any other value is mapped to a variable of the constraint set,
+/// which is created on demand, and returned as a dim or symbol expression.
+AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
+                                             std::optional<int64_t> dim) {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+  if (auto shapedType = dyn_cast<ShapedType>(value.getType())) {
+    // Static dimension: no variable needed.
+    if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
+      return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
+  } else if (std::optional<int64_t> constInt = getConstantIntValue(value)) {
+    // Constant index value: no variable needed.
+    return builder.getAffineConstantExpr(*constInt);
+  }
+
+  // Dynamic value: make sure it is mapped to a variable, then refer to it.
+  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
+  if (valueDimToPosition.find(valueDim) == valueDimToPosition.end())
+    (void)insert(value, dim);
+  int64_t pos = getPos(value, dim);
+  int64_t numDims = cstr.getNumDimVars();
+  if (pos < numDims)
+    return builder.getAffineDimExpr(pos);
+  return builder.getAffineSymbolExpr(pos - numDims);
+}
+
+/// Returns an affine expression for `ofr`. SSA values are treated like
+/// index-typed values (no dim); attributes must be integer constants.
+AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
+  Value value = ofr.dyn_cast<Value>();
+  if (!value) {
+    std::optional<int64_t> constInt = getConstantIntValue(ofr);
+    assert(constInt.has_value() && "expected Integer constant");
+    return builder.getAffineConstantExpr(*constInt);
+  }
+  return getExpr(value, /*dim=*/std::nullopt);
+}
+
+/// Returns the constant affine expression `constant`.
+AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {
+  AffineExpr constExpr = builder.getAffineConstantExpr(constant);
+  return constExpr;
+}
+
+/// Adds a fresh variable for `value` (and `dim`, for shaped values) to the
+/// constraint set, as a symbol or as a dimension variable depending on
+/// `isSymbol`, and enqueues it on the worklist. The pair must not be mapped
+/// yet. Returns the position of the new variable.
+int64_t ValueBoundsConstraintSet::insert(Value value,
+                                         std::optional<int64_t> dim,
+                                         bool isSymbol) {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
+  assert((valueDimToPosition.find(valueDim) == valueDimToPosition.end()) &&
+         "already mapped");
+  VarKind kind = isSymbol ? VarKind::Symbol : VarKind::SetDim;
+  int64_t pos = cstr.appendVar(kind);
+  positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
+
+  // The insertion may have shifted later entries; re-sync their positions.
+  int64_t numEntries = positionToValueDim.size();
+  for (int64_t i = pos; i < numEntries; ++i)
+    valueDimToPosition[positionToValueDim[i]] = i;
+
+  worklist.insert(pos);
+  return pos;
+}
+
+/// Returns the position of the variable mapped to `value` (and `dim`). The
+/// pair must already be mapped.
+int64_t ValueBoundsConstraintSet::getPos(Value value,
+                                         std::optional<int64_t> dim) const {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+  assert((value.isa<OpResult>() ||
+          value.cast<BlockArgument>().getOwner()->isEntryBlock()) &&
+         "unstructured control flow is not supported");
+#endif // NDEBUG
+
+  ValueDim key = std::make_pair(value, dim.value_or(kIndexValue));
+  auto it = valueDimToPosition.find(key);
+  assert(it != valueDimToPosition.end() && "expected mapped entry");
+  return it->second;
+}
+
+/// Returns the op that owns `value`: the defining op of an op result, or the
+/// parent op of the block that owns a block argument.
+static Operation *getOwnerOfValue(Value value) {
+  auto bbArg = value.dyn_cast<BlockArgument>();
+  if (!bbArg)
+    return value.getDefiningOp();
+  return bbArg.getOwner()->getParentOp();
+}
+
+/// Processes enqueued variables until the worklist is empty. For each popped
+/// (value, dim) pair:
+/// * a static dimension size of a ranked shaped value is added as an EQ bound;
+/// * otherwise, unless `stopCondition` holds for the value, the
+///   ValueBoundsOpInterface implementation of the op owning the value (if
+///   any) populates bounds, which may enqueue further variables.
+/// NOTE(review): the worklist stores variable positions. `insert` re-syncs
+/// `valueDimToPosition` after inserting in the middle, but not the worklist,
+/// so this presumably relies on new variables only being appended at the end
+/// while processing - confirm.
+void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+ while (!worklist.empty()) {
+ int64_t pos = worklist.pop_back_val();
+ ValueDim valueDim = positionToValueDim[pos];
+ Value value = valueDim.first;
+ int64_t dim = valueDim.second;
+
+ // Check for static dim size.
+ if (dim != kIndexValue) {
+ auto shapedType = cast<ShapedType>(value.getType());
+ if (shapedType.hasRank() && !shapedType.isDynamicDim(dim)) {
+ bound(value)[dim] == getExpr(shapedType.getDimSize(dim));
+ continue;
+ }
+ }
+
+ // Do not process any further if the stop condition is met.
+ if (stopCondition(value))
+ continue;
+
+ // Query `ValueBoundsOpInterface` for constraints. New items may be added to
+ // the worklist.
+ auto valueBoundsOp =
+ dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+ // Ops without the interface contribute no constraints.
+ if (!valueBoundsOp)
+ continue;
+ if (dim == kIndexValue) {
+ valueBoundsOp.populateBoundsForIndexValue(value, *this);
+ } else {
+ valueBoundsOp.populateBoundsForShapedValueDim(value, dim, *this);
+ }
+ }
+}
+
+/// Projects out the variable at position `pos` and removes it from both
+/// position mappings.
+void ValueBoundsConstraintSet::projectOut(int64_t pos) {
+  assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
+         "invalid position");
+  cstr.projectOut(pos);
+  ValueDim removed = positionToValueDim[pos];
+  bool erased = valueDimToPosition.erase(removed);
+  (void)erased;
+  assert(erased && "inconsistent reverse mapping");
+  positionToValueDim.erase(positionToValueDim.begin() + pos);
+
+  // Every later entry moved one slot to the front; re-sync their positions.
+  int64_t numEntries = positionToValueDim.size();
+  for (int64_t i = pos; i < numEntries; ++i)
+    valueDimToPosition[positionToValueDim[i]] = i;
+}
+
+/// Projects out every variable whose (value, dim) pair satisfies `condition`.
+void ValueBoundsConstraintSet::projectOut(
+    function_ref<bool(ValueDim)> condition) {
+  for (int64_t pos = 0;
+       pos < static_cast<int64_t>(positionToValueDim.size());) {
+    if (!condition(positionToValueDim[pos])) {
+      ++pos;
+      continue;
+    }
+    // Projecting out moves the next variable into `pos`, so `pos` is
+    // re-examined instead of advanced.
+    projectOut(pos);
+  }
+}
+
+/// Computes a bound of kind `type` for `value` (or dimension `dim` of a shaped
+/// `value`), expressed in terms of (value, dim) pairs for which
+/// `stopCondition` holds. On success, `resultMap` is the bound and
+/// `mapOperands` lists the (value, dim) pairs that its dims/symbols refer to,
+/// in order. Only EQ bounds are supported at the moment. Returns failure if
+/// no such bound can be computed.
+LogicalResult ValueBoundsConstraintSet::computeBound(
+    AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
+    Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+  // Only EQ bounds are supported at the moment.
+  assert(type == BoundType::EQ && "unsupported bound type");
+
+  Builder b(value.getContext());
+  mapOperands.clear();
+
+  if (stopCondition(value)) {
+    // Special case: If the stop condition is satisfied for the input
+    // value/dimension, directly return it.
+    mapOperands.push_back(std::make_pair(value, dim));
+    resultMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+                               b.getAffineDimExpr(0));
+    return success();
+  }
+
+  // Process the backward slice of `value` (i.e., reverse use-def chain) until
+  // `stopCondition` is met.
+  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
+  ValueBoundsConstraintSet cstr(value, dim);
+  cstr.processWorklist(stopCondition);
+
+  // Project out all variables (apart from `valueDim`) that do not match the
+  // stop condition.
+  cstr.projectOut([&](ValueDim p) {
+    // Do not project out `valueDim`.
+    if (valueDim == p)
+      return false;
+    return !stopCondition(p.first);
+  });
+
+  // Compute lower and upper bounds for `valueDim`.
+  int64_t pos = cstr.getPos(value, dim);
+  SmallVector<AffineMap> lb(1), ub(1);
+  cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
+                           /*getClosedUB=*/true);
+  // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
+  // case, no lower/upper bound can be computed at the moment.
+  if (lb.empty() || !lb[0] || ub.empty() || !ub[0] ||
+      lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1)
+    return failure();
+
+  // Look for same lower and upper bound: EQ bound.
+  if (ub[0] != lb[0])
+    return failure();
+  AffineMap bound = lb[0];
+
+  // Gather all SSA values that are used in the computed bound.
+  const int64_t numDimVars = cstr.cstr.getNumDimVars();
+  const int64_t numDimAndSymbolVars = cstr.cstr.getNumDimAndSymbolVars();
+  assert(numDimAndSymbolVars ==
+             static_cast<int64_t>(cstr.positionToValueDim.size()) &&
+         "inconsistent mapping state");
+  SmallVector<AffineExpr> replacementDims, replacementSymbols;
+  int64_t numDims = 0, numSymbols = 0;
+  for (int64_t i = 0; i < numDimAndSymbolVars; ++i) {
+    // Skip `value`.
+    if (i == pos)
+      continue;
+    // Check if the position `i` is used in the generated bound. If so, it must
+    // be included in the generated affine.apply op. `bound` has no dim for
+    // `pos`, so dims after `pos` are shifted down by one. This matches the
+    // positional layout of `replacementDims`, which also skips `pos`.
+    bool isDim = i < numDimVars;
+    bool used = isDim ? bound.isFunctionOfDim(i < pos ? i : i - 1)
+                      : bound.isFunctionOfSymbol(i - numDimVars);
+
+    if (!used) {
+      // Not used: Remove dim/symbol from the result.
+      if (isDim)
+        replacementDims.push_back(b.getAffineConstantExpr(0));
+      else
+        replacementSymbols.push_back(b.getAffineConstantExpr(0));
+      continue;
+    }
+
+    if (isDim)
+      replacementDims.push_back(b.getAffineDimExpr(numDims++));
+    else
+      replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++));
+
+    // Distinct names: do not shadow the outer `value`/`dim`/`valueDim`.
+    const ValueDim &usedValueDim = cstr.positionToValueDim[i];
+    Value usedValue = usedValueDim.first;
+    int64_t usedDim = usedValueDim.second;
+    if (usedDim == kIndexValue) {
+      // An index-type value is used: can be used directly in the affine.apply
+      // op.
+      assert(usedValue.getType().isIndex() && "expected index type");
+      mapOperands.push_back(std::make_pair(usedValue, std::nullopt));
+      continue;
+    }
+
+    assert(cast<ShapedType>(usedValue.getType()).isDynamicDim(usedDim) &&
+           "expected dynamic dim");
+    mapOperands.push_back(std::make_pair(usedValue, usedDim));
+  }
+
+  resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols,
+                                          numDims, numSymbols);
+  return success();
+}
+
+/// Selects dimension `dim` of the shaped value constrained by this builder.
+/// A dimension may be selected at most once.
+ValueBoundsConstraintSet::BoundBuilder &
+ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
+  auto &selectedDim = this->dim;
+  assert(!selectedDim.has_value() && "dim was already set");
+  selectedDim = dim;
+#ifndef NDEBUG
+  assertValidValueDim(value, selectedDim);
+#endif // NDEBUG
+  return *this;
+}
+
+// AffineExpr comparisons: each adds one bound on the selected value (or value
+// dimension). For integers, `x <= e` is `x < e + 1` and `x > e` is
+// `x >= e + 1`; the non-primitive operators are expressed that way.
+
+void ValueBoundsConstraintSet::BoundBuilder::operator<(AffineExpr expr) {
+#ifndef NDEBUG
+  assertValidValueDim(value, this->dim);
+#endif // NDEBUG
+  int64_t pos = cstr.getPos(value, this->dim);
+  cstr.addBound(BoundType::UB, pos, expr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator<=(AffineExpr expr) {
+  AffineExpr strictBound = expr + 1;
+  operator<(strictBound);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator>(AffineExpr expr) {
+  AffineExpr nonStrictBound = expr + 1;
+  operator>=(nonStrictBound);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator>=(AffineExpr expr) {
+#ifndef NDEBUG
+  assertValidValueDim(value, this->dim);
+#endif // NDEBUG
+  int64_t pos = cstr.getPos(value, this->dim);
+  cstr.addBound(BoundType::LB, pos, expr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator==(AffineExpr expr) {
+#ifndef NDEBUG
+  assertValidValueDim(value, this->dim);
+#endif // NDEBUG
+  int64_t pos = cstr.getPos(value, this->dim);
+  cstr.addBound(BoundType::EQ, pos, expr);
+}
+
+// OpFoldResult comparisons: convert the operand to an AffineExpr via
+// `getExpr` (which may map a new variable) and forward to the AffineExpr
+// overload.
+
+void ValueBoundsConstraintSet::BoundBuilder::operator<(OpFoldResult ofr) {
+  AffineExpr expr = cstr.getExpr(ofr);
+  operator<(expr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator<=(OpFoldResult ofr) {
+  AffineExpr expr = cstr.getExpr(ofr);
+  operator<=(expr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator>(OpFoldResult ofr) {
+  AffineExpr expr = cstr.getExpr(ofr);
+  operator>(expr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator>=(OpFoldResult ofr) {
+  AffineExpr expr = cstr.getExpr(ofr);
+  operator>=(expr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator==(OpFoldResult ofr) {
+  AffineExpr expr = cstr.getExpr(ofr);
+  operator==(expr);
+}
+
+// Integer comparisons: wrap the constant in an AffineExpr and forward to the
+// AffineExpr overload.
+
+void ValueBoundsConstraintSet::BoundBuilder::operator<(int64_t i) {
+  AffineExpr constExpr = cstr.getExpr(i);
+  operator<(constExpr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator<=(int64_t i) {
+  AffineExpr constExpr = cstr.getExpr(i);
+  operator<=(constExpr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator>(int64_t i) {
+  AffineExpr constExpr = cstr.getExpr(i);
+  operator>(constExpr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator>=(int64_t i) {
+  AffineExpr constExpr = cstr.getExpr(i);
+  operator>=(constExpr);
+}
+
+void ValueBoundsConstraintSet::BoundBuilder::operator==(int64_t i) {
+  AffineExpr constExpr = cstr.getExpr(i);
+  operator==(constExpr);
+}
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
new file mode 100644
index 0000000000000..c376af1089aec
--- /dev/null
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
+// RUN: -verify-diagnostics -split-input-file | FileCheck %s
+
+// With "reify-to-func-args", bounds are expressed in terms of function
+// arguments: the dynamic dims are traced back through tensor.dim,
+// tensor.insert, tensor.cast and tensor.empty to %sz0 and %sz2, while the
+// static size 10 from the tensor.empty result type becomes a constant.
+// CHECK-LABEL: func @reify_through_chain(
+// CHECK-SAME: %[[sz0:.*]]: index, %[[sz2:.*]]: index
+// CHECK: %[[c10:.*]] = arith.constant 10 : index
+// CHECK: return %[[sz0]], %[[c10]], %[[sz2]]
+func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index) {
+ %c2 = arith.constant 2 : index
+ %0 = tensor.empty(%sz0, %sz2) : tensor<?x10x?xf32>
+ %1 = tensor.cast %0 : tensor<?x10x?xf32> to tensor<?x?x?xf32>
+ %pos = arith.constant 0 : index
+ %f = arith.constant 0.0 : f32
+ %2 = tensor.insert %f into %1[%pos, %pos, %pos] : tensor<?x?x?xf32>
+ %3 = tensor.dim %2, %c2 : tensor<?x?x?xf32>
+
+ %4 = "test.reify_bound"(%2) {dim = 0} : (tensor<?x?x?xf32>) -> (index)
+ %5 = "test.reify_bound"(%2) {dim = 1} : (tensor<?x?x?xf32>) -> (index)
+ %6 = "test.reify_bound"(%3) : (index) -> (index)
+
+ return %4, %5, %6 : index, index, index
+}
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
new file mode 100644
index 0000000000000..576759e4f21ca
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN: -split-input-file | FileCheck %s
+
+// Nothing is known about the result of an unknown op, so no bound can be
+// reified for it.
+func.func @unknown_op() -> index {
+ %0 = "test.foo"() : () -> (tensor<?x?xf32>)
+ // expected-error @below{{could not reify bound}}
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?x?xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// The static size of the cast source is propagated through tensor.cast.
+// CHECK-LABEL: func @cast(
+// CHECK: %[[c10:.*]] = arith.constant 10 : index
+// CHECK: return %[[c10]]
+func.func @cast(%t: tensor<10xf32>) -> index {
+ %0 = tensor.cast %t : tensor<10xf32> to tensor<?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// An unranked cast source carries no size information, so reification fails.
+func.func @cast_unranked(%t: tensor<*xf32>) -> index {
+ %0 = tensor.cast %t : tensor<*xf32> to tensor<?xf32>
+ // expected-error @below{{could not reify bound}}
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// The original tensor.dim stays in place; the reified bound is a second
+// tensor.dim on %t. The capture is therefore redefined on purpose, so the
+// return is matched against the newly created op.
+// CHECK-LABEL: func @dim(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: return %[[dim]]
+func.func @dim(%t: tensor<?xf32>) -> index {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.dim %t, %c0 : tensor<?xf32>
+ %1 = "test.reify_bound"(%0) : (index) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// Static dims of tensor.empty become constants; dynamic dims map to the
+// corresponding size operands.
+// CHECK-LABEL: func @empty(
+// CHECK-SAME: %[[sz:.*]]: index
+// CHECK: %[[c6:.*]] = arith.constant 6 : index
+// CHECK: return %[[c6]], %[[sz]]
+func.func @empty(%sz: index) -> (index, index) {
+ %0 = tensor.empty(%sz) : tensor<6x?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<6x?xf32>) -> (index)
+ %2 = "test.reify_bound"(%0) {dim = 1} : (tensor<6x?xf32>) -> (index)
+ return %1, %2 : index, index
+}
+
+// -----
+
+// The result size of tensor.extract_slice is its dynamic size operand.
+// CHECK-LABEL: func @extract_slice_dynamic(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[sz:.*]]: index
+// CHECK: return %[[sz]]
+func.func @extract_slice_dynamic(%t: tensor<?xf32>, %sz: index) -> index {
+ %0 = tensor.extract_slice %t[2][%sz][1] : tensor<?xf32> to tensor<?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// A static result size is materialized as a constant.
+// CHECK-LABEL: func @extract_slice_static(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
+// CHECK: %[[c5:.*]] = arith.constant 5 : index
+// CHECK: return %[[c5]]
+func.func @extract_slice_static(%t: tensor<?xf32>) -> index {
+ %0 = tensor.extract_slice %t[2][5][1] : tensor<?xf32> to tensor<5xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// The unit-size source dim 0 is dropped, so result dim 0 corresponds to the
+// slice size %sz of source dim 1.
+// CHECK-LABEL: func @extract_slice_rank_reduce(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[sz:.*]]: index
+// CHECK: return %[[sz]]
+func.func @extract_slice_rank_reduce(%t: tensor<?x?xf32>, %sz: index) -> index {
+ %0 = tensor.extract_slice %t[0, 2][1, %sz][1, 1] : tensor<?x?xf32> to tensor<?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// The result of tensor.insert has the same size as its destination %t.
+// CHECK-LABEL: func @insert(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]]
+// CHECK: return %[[dim]]
+func.func @insert(%t: tensor<?xf32>, %f: f32, %pos: index) -> index {
+ %0 = tensor.insert %f into %t[%pos] : tensor<?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+ return %1 : index
+}
+
+// -----
+
+// Each padded dim is source dim + low pad + high pad:
+//   dim 0 = dim(%t, 0) + %a + %a  (s0 + s1 * 2)
+//   dim 1 = 7 + 5 + %b            (s0 + 12)
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+// CHECK: #[[$map1:.*]] = affine_map<()[s0] -> (s0 + 12)>
+// CHECK-LABEL: func @pad(
+// CHECK-SAME: %[[t:.*]]: tensor<?x7xf32>, %[[a:.*]]: index, %[[b:.*]]: index
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[dim0:.*]] = tensor.dim %[[t]], %[[c0]]
+// CHECK: %[[bound0:.*]] = affine.apply #[[$map]]()[%[[dim0]], %[[a]]]
+// CHECK: %[[bound1:.*]] = affine.apply #[[$map1]]()[%[[b]]]
+// CHECK: return %[[bound0]], %[[bound1]]
+func.func @pad(%t: tensor<?x7xf32>, %a: index, %b: index) -> (index, index) {
+ %pad = arith.constant 0.0 : f32
+ %0 = tensor.pad %t low[%a, 5] high[%a, %b] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %pad : f32
+ } : tensor<?x7xf32> to tensor<?x?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?x?xf32>) -> (index)
+ %2 = "test.reify_bound"(%0) {dim = 1} : (tensor<?x?xf32>) -> (index)
+ return %1, %2 : index, index
+}
+
+// -----
+
+// The rank of a ranked tensor is a constant.
+// CHECK-LABEL: func @rank(
+// CHECK-SAME: %[[t:.*]]: tensor<5xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: return %[[c1]]
+func.func @rank(%t: tensor<5xf32>) -> index {
+ %0 = tensor.rank %t : tensor<5xf32>
+ %1 = "test.reify_bound"(%0) : (index) -> (index)
+ return %1 : index
+}
diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
index 8fe9113f69848..d76884d5a32d7 100644
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_library(MLIRAffineTransformsTestPasses
TestAffineLoopUnswitching.cpp
TestAffineLoopParametricTiling.cpp
TestDecomposeAffineOps.cpp
+ TestReifyValueBounds.cpp
TestLoopFusion.cpp
TestLoopMapping.cpp
TestLoopPermutation.cpp
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
new file mode 100644
index 0000000000000..eccf55136eddf
--- /dev/null
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -0,0 +1,120 @@
+//===- TestReifyValueBounds.cpp - Test value bounds reification -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Pass/Pass.h"
+
+#define PASS_NAME "test-affine-reify-value-bounds"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that replaces every "test.reify_bound" op in a function with a
+/// reified bound of its operand (see `testReifyValueBounds`).
+struct TestReifyValueBounds
+    : public PassWrapper<TestReifyValueBounds, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
+
+  StringRef getArgument() const final { return PASS_NAME; }
+  StringRef getDescription() const final {
+    return "Tests ValueBoundsOpInterface with affine dialect reification";
+  }
+  TestReifyValueBounds() = default;
+  TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass) {}
+
+  /// Reification may create affine.apply ops, and this pass creates
+  /// arith.constant ops, so both dialects must be loaded.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, arith::ArithDialect>();
+  }
+
+  void runOnOperation() override;
+
+private:
+  Option<bool> reifyToFuncArgs{
+      *this, "reify-to-func-args",
+      llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)};
+};
+
+} // namespace
+
+/// Look for "test.reify_bound" ops in the input and replace their results with
+/// the reified values. Such an op must have exactly one operand and one index
+/// result. An index-typed operand is reified directly. A shaped operand needs
+/// an integer "dim" attribute that selects the dimension to reify. With
+/// `reifyToFuncArgs`, bounds are expressed in terms of function block
+/// arguments; otherwise they are expressed in terms of the op's operands.
+/// Returns failure if some bound could not be reified.
+static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
+                                          bool reifyToFuncArgs) {
+  IRRewriter rewriter(funcOp.getContext());
+  WalkResult result = funcOp.walk([&](Operation *op) {
+    // Only test.reify_bound ops are of interest.
+    if (op->getName().getStringRef() != "test.reify_bound")
+      return WalkResult::advance();
+
+    if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
+        !op->getResultTypes()[0].isIndex()) {
+      op->emitOpError("invalid op");
+      return WalkResult::skip();
+    }
+    Value value = op->getOperand(0);
+    bool isIndexValue = value.getType().isIndex();
+    if (isIndexValue == op->hasAttrOfType<IntegerAttr>("dim")) {
+      // Op should have a "dim" attribute if and only if the operand is NOT an
+      // index-typed value (i.e., it is a shaped value).
+      op->emitOpError("invalid op");
+      return WalkResult::skip();
+    }
+
+    std::optional<int64_t> dim =
+        isIndexValue ? std::nullopt
+                     : std::make_optional<int64_t>(
+                           op->getAttrOfType<IntegerAttr>("dim").getInt());
+
+    // Reify value bound.
+    rewriter.setInsertionPointAfter(op);
+    FailureOr<OpFoldResult> reified;
+    if (!reifyToFuncArgs) {
+      // Reify in terms of the op's operands.
+      reified = reifyValueBound(rewriter, op->getLoc(),
+                                presburger::BoundType::EQ, value, dim);
+    } else {
+      // Reify in terms of function block arguments.
+      auto stopCondition = [](Value v) {
+        auto bbArg = v.dyn_cast<BlockArgument>();
+        if (!bbArg)
+          return false;
+        return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp());
+      };
+      reified =
+          reifyValueBound(rewriter, op->getLoc(), presburger::BoundType::EQ,
+                          value, dim, stopCondition);
+    }
+    if (failed(reified)) {
+      op->emitOpError("could not reify bound");
+      return WalkResult::interrupt();
+    }
+
+    // Replace the op with the reified bound.
+    if (auto val = reified->dyn_cast<Value>()) {
+      rewriter.replaceOp(op, val);
+      return WalkResult::skip();
+    }
+    Value constOp = rewriter.create<arith::ConstantIndexOp>(
+        op->getLoc(), reified->get<Attribute>().cast<IntegerAttr>().getInt());
+    rewriter.replaceOp(op, constOp);
+    return WalkResult::skip();
+  });
+  return failure(result.wasInterrupted());
+}
+
+/// Runs the reification on the current function; signals failure if any bound
+/// could not be reified.
+void TestReifyValueBounds::runOnOperation() {
+  func::FuncOp funcOp = getOperation();
+  LogicalResult status = testReifyValueBounds(funcOp, reifyToFuncArgs);
+  if (failed(status))
+    signalPassFailure();
+}
+
+namespace mlir {
+/// Registers the "test-affine-reify-value-bounds" test pass.
+void registerTestAffineReifyValueBoundsPass() {
+  PassRegistration<TestReifyValueBounds> registration;
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2ddd83f36f52d..12eee4ba924bd 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -40,6 +40,7 @@ void registerSliceAnalysisTestPass();
void registerSymbolTestPasses();
void registerRegionTestPasses();
void registerTestAffineDataCopyPass();
+void registerTestAffineReifyValueBoundsPass();
void registerTestDecomposeAffineOpPass();
void registerTestAffineLoopUnswitchingPass();
void registerTestAllReduceLoweringPass();
@@ -151,6 +152,7 @@ void registerTestPasses() {
registerSymbolTestPasses();
registerRegionTestPasses();
registerTestAffineDataCopyPass();
+ registerTestAffineReifyValueBoundsPass();
registerTestDecomposeAffineOpPass();
registerTestAffineLoopUnswitchingPass();
registerTestAllReduceLoweringPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 711a51d2bc929..6932d051f4df6 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2836,7 +2836,9 @@ cc_library(
":SCFDialect",
":SCFUtils",
":Support",
+ ":TensorDialect",
":Transforms",
+ ":ValueBoundsOpInterface",
":VectorDialect",
":VectorUtils",
"//llvm:Support",
@@ -5653,8 +5655,12 @@ cc_library(
"include/mlir/Transforms/InliningUtils.h",
"lib/Dialect/Tensor/IR/TensorDialect.cpp",
"lib/Dialect/Tensor/IR/TensorOps.cpp",
+ "lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp",
+ ],
+ hdrs = [
+ "include/mlir/Dialect/Tensor/IR/Tensor.h",
+ "include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h",
],
- hdrs = ["include/mlir/Dialect/Tensor/IR/Tensor.h"],
includes = ["include"],
deps = [
":AffineDialect",
@@ -5673,6 +5679,7 @@ cc_library(
":Support",
":TensorOpsIncGen",
":TilingInterface",
+ ":ValueBoundsOpInterface",
":ViewLikeInterface",
"//llvm:Support",
],
@@ -8808,6 +8815,52 @@ cc_library(
],
)
+# TableGen sources of ValueBoundsOpInterface.
+td_library(
+ name = "ValueBoundsOpInterfaceTdFiles",
+ srcs = [
+ "include/mlir/Interfaces/ValueBoundsOpInterface.td",
+ ],
+ includes = ["include"],
+ deps = [
+ ":OpBaseTdFiles",
+ ],
+)
+
+# Generates the op interface declarations (.h.inc) and definitions (.cpp.inc)
+# from ValueBoundsOpInterface.td.
+gentbl_cc_library(
+ name = "ValueBoundsOpInterfaceIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/mlir/Interfaces/ValueBoundsOpInterface.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/mlir/Interfaces/ValueBoundsOpInterface.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Interfaces/ValueBoundsOpInterface.td",
+ deps = [
+ ":ValueBoundsOpInterfaceTdFiles",
+ ],
+)
+
+# ValueBoundsOpInterface and the ValueBoundsConstraintSet implementation.
+# NOTE(review): the CMake MLIRValueBoundsOpInterface library does not list
+# DestinationStyleOpInterface; keep both build systems in sync.
+cc_library(
+ name = "ValueBoundsOpInterface",
+ srcs = ["lib/Interfaces/ValueBoundsOpInterface.cpp"],
+ hdrs = ["include/mlir/Interfaces/ValueBoundsOpInterface.h"],
+ includes = ["include"],
+ deps = [
+ ":Analysis",
+ ":DestinationStyleOpInterface",
+ ":IR",
+ ":Support",
+ ":ValueBoundsOpInterfaceIncGen",
+ "//llvm:Support",
+ ],
+)
+
cc_library(
name = "TilingInterface",
srcs = ["lib/Interfaces/TilingInterface.cpp"],
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 23d076de07e9d..28a1a1943f7c9 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -559,6 +559,7 @@ cc_library(
"//mlir:SCFDialect",
"//mlir:Support",
"//mlir:Transforms",
+ "//mlir:ValueBoundsOpInterface",
"//mlir:VectorDialect",
"//mlir:VectorUtils",
],
More information about the Mlir-commits
mailing list