[Mlir-commits] [mlir] d31ba52 - [mlir][Interface] Factor out common IndexingMapOpInterface behavior in a new generic interface (#145313)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 23 22:56:36 PDT 2025
Author: Nicolas Vasilache
Date: 2025-06-24T07:56:32+02:00
New Revision: d31ba5256327d30f264c2f671bf197877b242cde
URL: https://github.com/llvm/llvm-project/commit/d31ba5256327d30f264c2f671bf197877b242cde
DIFF: https://github.com/llvm/llvm-project/commit/d31ba5256327d30f264c2f671bf197877b242cde.diff
LOG: [mlir][Interface] Factor out common IndexingMapOpInterface behavior in a new generic interface (#145313)
Refactor the verifiers to make use of the shared logic, and make
`vector.contract` use this interface as well.
In the process, the confusingly named `getStaticShape` has been removed.
Note: the verifier for IndexingMapOpInterface is currently called
manually from other verifiers, as it was unclear how to prevent it from
taking precedence over more meaningful error messages.
Added:
mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
mlir/lib/Interfaces/IndexingMapOpInterface.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/IR/CMakeLists.txt
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Interfaces/CMakeLists.txt
mlir/test/Dialect/Linalg/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0f960fb5ad795..0ebbeea937554 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -20,6 +20,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/RawOstreamExtras.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 74c4c0a8835f2..ca1cba8747bd8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -14,6 +14,7 @@
#define LINALG_IR_LINALGINTERFACES
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/IR/OpBase.td"
// The 'LinalgContractionOpInterface' provides access to the
@@ -222,59 +223,11 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
];
}
-def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
- let description = [{
- Interface for operations that connect an iteration domain to operands via
- affine maps. Provides methods to access indexing maps between iteration
- domain and operand index spaces.
- }];
- let cppNamespace = "::mlir::linalg";
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Return the indexing maps attribute within the current operation.
- }],
- /*retTy=*/"ArrayAttr",
- /*methodName=*/"getIndexingMaps"
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the indexing maps within the current operation.
- }],
- /*retTy=*/"SmallVector<AffineMap>",
- /*methodName=*/"getIndexingMapsArray",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto range = $_op.getIndexingMaps()
- .template getAsValueRange<AffineMapAttr>();
- return {range.begin(), range.end()};
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the input or output indexing map for `opOperand`.
- }],
- /*retTy=*/"AffineMap",
- /*methodName=*/"getMatchingIndexingMap",
- /*args=*/(ins "OpOperand*":$opOperand),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(opOperand->getOwner() == this->getOperation());
- auto indexingMaps =
- $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
- return *(indexingMaps.begin() + opOperand->getOperandNumber());
- }]
- >,
- ];
-}
-
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface
- : OpInterface<"LinalgOp", [
- DestinationStyleOpInterface,
- IndexingMapOpInterface
- ]> {
+ : OpInterface<"LinalgOp",
+ [DestinationStyleOpInterface, IndexingMapOpInterface]
+ > {
let cppNamespace = "::mlir::linalg";
let methods = [
//===------------------------------------------------------------------===//
@@ -464,30 +417,6 @@ def LinalgStructuredInterface
return getBlock()->getArguments().take_back($_op.getNumDpsInits());
}]
>,
- InterfaceMethod<
- /*desc=*/[{
- Return the `opOperand` shape or an empty vector for scalars or vectors
- not wrapped within a tensor or a memref.
- }],
- /*retTy=*/"ArrayRef<int64_t>",
- /*methodName=*/"getShape",
- /*args=*/(ins "OpOperand*":$opOperand),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(opOperand->getOwner() == this->getOperation());
- Type t = opOperand->get().getType();
- // A VectorType is an elemental type, do not consider its rank for the operand.
- if (isa<VectorType>(t))
- return {};
- if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
- // Failsafe.
- assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
- "expected a ranked tensor or memref in LinalgInterface::getRank");
- return shapedType.getShape();
- }
- return {};
- }]
- >,
InterfaceMethod<
/*desc=*/[{
Return the block argument for an `opOperand`.
@@ -620,7 +549,12 @@ def LinalgStructuredInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
+ for (OpOperand &opOperand : this->getOperation()->getOpOperands()) {
+ if (auto shapedType = dyn_cast<ShapedType>(opOperand.get().getType())) {
+ if (ShapedType::isDynamicShape(shapedType.getShape())) return true;
+ }
+ }
+ return false;
}]
>,
InterfaceMethod<
@@ -738,53 +672,6 @@ def LinalgStructuredInterface
//===------------------------------------------------------------------===//
// Linalg generalization hooks.
//===------------------------------------------------------------------===//
- InterfaceMethod<
- /*desc=*/[{
- Hook to provide a custom AffineMap used to compute all the operand
- subshapes given loop bounds. This is used to answer the question: "given
- an iteration space over the codomain, what are the subshapes of the
- operands involved in the computation".
- The default behavior is to just concatenate all the indexing maps.
- A custom AffineMap allows providing a map that can be used to
- compute subshapes even in cases where the concatenation of indexing maps
- (i.e. the data traversal order) is not a simple permutation of the loop
- traversal order. It is then possible to define ops with skewed data
- traversal order for which we can still easily compute hyperrectangular
- loop bounds and subviews.
- }],
- /*retTy=*/"AffineMap",
- /*methodName=*/"getLoopsToShapesMap",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto maps = $_op.getIndexingMapsArray();
- return concatAffineMaps(maps, $_op.getContext());
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Hook to provide a custom AffineMap used to construct the
- hyperrectangular loop iteration space given all the operand subshapes.
- This is used to answer the question:
- "Given a list of operand ranges, what is the subportion of the iteration
- space involved in the computation".
- This is the inverse problem of `getLoopsToShapesMap`.
- Return the empty AffineMap when such an AffineMap cannot be constructed.
- The default behavior is based on a very simple inference procedure that
- only works with permutation affine maps.
- A more advanced Tensor-Comprehension like inference is possible but has
- proven to be ambiguous in unfavorable case.
- A safer and more robust alternative is to allow each op to define
- its own AffineMap.
- }],
- /*retTy=*/"AffineMap",
- /*methodName=*/"getShapesToLoopsMap",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return inversePermutation(getLoopsToShapesMap());
- }]
- >,
InterfaceMethod<
/*desc=*/[{
Checks if the given operands can be dropped, and the remaining
@@ -798,39 +685,30 @@ def LinalgStructuredInterface
return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
}]
>,
+ //===------------------------------------------------------------------===//
+ // IndexingMapOpInterface interface methods implementation.
+ //===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
- Like `getShape`, but only returns statically-known information, without
- generating any new IR. For each shape dimension, returns >=0 if that
- dimension is statically known, or ShapedType::kDynamic otherwise.
- }],
- /*retTy=*/"SmallVector<int64_t>",
- /*methodName=*/"getStaticShape",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- SmallVector<int64_t> res;
- for (OpOperand &opOperand : this->getOperation()->getOpOperands())
- llvm::append_range(res, getShape(&opOperand));
- return res;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Returns the statically-known loop ranges. Composes
- `getShapesToLoopsMap()` with the result of `getStaticShape`.
- Returns ShapedType::kDynamic for non-statically-known loop ranges.
- This is expected to be called by a valid Linalg op
+ Return the `opOperand` shape or an empty vector for scalars or vectors
+ not wrapped within a tensor or a memref.
}],
- /*retTy=*/"SmallVector<int64_t, 4>",
- /*methodName=*/"getStaticLoopRanges",
- /*args=*/(ins),
+ /*retTy=*/"ArrayRef<int64_t>",
+ /*methodName=*/"getShape",
+ /*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- SmallVector<int64_t> viewSizes = getStaticShape();
- AffineMap invertedMap = getShapesToLoopsMap();
- assert(invertedMap && "expected a valid Linalg op to call the method");
- return invertedMap.compose(viewSizes);
+ Type t = opOperand->get().getType();
+ // A VectorType is an elemental type, do not consider its rank for the operand.
+ if (isa<VectorType>(t))
+ return {};
+ if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+ // Failsafe.
+ assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+ "expected a ranked tensor or memref in LinalgInterface::getRank");
+ return shapedType.getShape();
+ }
+ return {};
}]
>,
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 98fb6075cbf32..364c1728715e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 926a92eff2ebb..02e62930a742d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -21,6 +21,7 @@ include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -33,6 +34,7 @@ include "mlir/IR/EnumAttr.td"
// than the current set: {*, +}.
def Vector_ContractionOp :
Vector_Op<"contract", [
+ IndexingMapOpInterface,
Pure,
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
PredOpTrait<"third operand acc and result have same element type",
@@ -207,6 +209,16 @@ def Vector_ContractionOp :
.template getAsValueRange<IteratorTypeAttr, IteratorType>();
return {range.begin(), range.end()};
}
+
+ //===------------------------------------------------------------------===//
+ // IndexingMapOpInterface interface methods implementation.
+ //===------------------------------------------------------------------===//
+ ArrayRef<int64_t> getShape(OpOperand * opOperand) {
+ Type t = opOperand->get().getType();
+ if (auto vt = dyn_cast<VectorType>(t))
+ return vt.getShape();
+ return {};
+ }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index d81298bb4daf0..067e0511e4e75 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_interface(CopyOpInterface)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
+add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
diff --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
new file mode 100644
index 0000000000000..40252613a21f4
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
@@ -0,0 +1,27 @@
+//===- IndexingMapOpInterface.h ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
+#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+/// Verify that `op` conforms to the invariants of IndexingMapOpInterface.
+LogicalResult verifyIndexingMapOpInterface(Operation *op);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/IndexingMapOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
new file mode 100644
index 0000000000000..fdcc183d99215
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
@@ -0,0 +1,153 @@
+//===- IndexingMapOpInterface.td - Interface Declaration -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the IndexingMapOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
+#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
+ let description = [{
+ Interface for operations that connect an iteration domain to operands via
+ affine maps. Provides methods to access indexing maps between iteration
+ domain and operand index spaces.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps attribute within the current operation.
+ }],
+ /*retTy=*/"ArrayAttr",
+ /*methodName=*/"getIndexingMaps"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps within the current operation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMapsArray",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto range = $_op.getIndexingMaps()
+ .template getAsValueRange<AffineMapAttr>();
+ return {range.begin(), range.end()};
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the input or output indexing map for `opOperand`.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getMatchingIndexingMap",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(opOperand->getOwner() == this->getOperation());
+ auto indexingMaps =
+ $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
+ return *(indexingMaps.begin() + opOperand->getOperandNumber());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Hook to provide a custom AffineMap used to compute all the operand
+ subshapes given loop bounds. This is used to answer the question: "given
+ an iteration space over the codomain, what are the subshapes of the
+ operands involved in the computation".
+ The default behavior is to just concatenate all the indexing maps.
+ A custom AffineMap allows providing a map that can be used to
+ compute subshapes even in cases where the concatenation of indexing maps
+ (i.e. the data traversal order) is not a simple permutation of the loop
+ traversal order. It is then possible to define ops with skewed data
+ traversal order for which we can still easily compute hyperrectangular
+ loop bounds and subviews.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getLoopsToShapesMap",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto maps = $_op.getIndexingMapsArray();
+ return concatAffineMaps(maps, $_op.getContext());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Hook to provide a custom AffineMap used to construct the
+ hyperrectangular loop iteration space given all the operand subshapes.
+ This is used to answer the question:
+ "Given a list of operand ranges, what is the subportion of the iteration
+ space involved in the computation".
+ This is the inverse problem of `getLoopsToShapesMap`.
+ Return the empty AffineMap when such an AffineMap cannot be constructed.
+ The default behavior is based on a very simple inference procedure that
+ only works with permutation affine maps.
+ A more advanced Tensor-Comprehension like inference is possible but has
+ proven to be ambiguous in unfavorable case.
+ A safer and more robust alternative is to allow each op to define
+ its own AffineMap.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getShapesToLoopsMap",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return inversePermutation($_op.getLoopsToShapesMap());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the static shape of `opOperand` (note this is op-specific
+ behavior). Returns ShapedType::kDynamic for operand dimensions that
+ are not statically known.
+ }],
+ /*retTy=*/"SmallVector<int64_t>",
+ /*methodName=*/"getStaticOperandShape",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<int64_t> res;
+ llvm::append_range(res, $_op.getShape(opOperand));
+ return res;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns loop ranges by composing `getShapesToLoopsMap()` with the
+ flattened list of operand shapes, as returned by the op's `getShape`.
+ Returns ShapedType::kDynamic for non-statically-known loop ranges.
+ }],
+ /*retTy=*/"SmallVector<int64_t>",
+ /*methodName=*/"getStaticLoopRanges",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<int64_t> allShapesSizes;
+ for (OpOperand &opOperand : this->getOperation()->getOpOperands())
+ llvm::append_range(allShapesSizes, $_op.getShape(&opOperand));
+ AffineMap invertedMap = $_op.getShapesToLoopsMap();
+ assert(invertedMap && "expected a valid op");
+ return invertedMap.compose(allShapesSizes);
+ }]
+ >
+ ];
+ let extraClassDeclaration = [{
+ // Verifier implementation for IndexingMapOpInterface.
+ // This must be called manually as part of other verifiers so that the
+ // verification order, and meaningful error messages, are not preempted.
+ LogicalResult verifyImpl();
+ }];
+}
+
+#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index b4aeb44ac8faf..ec433284e17ad 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRFunctionInterfaces
+ MLIRIndexingMapOpInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRParser
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 139e9901b0a29..ca7f31dd6b518 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1251,38 +1251,20 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
return failure();
- // All input/output operands must be indexed.
- if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
- linalgOp->getNumOperands())
- return op->emitOpError("expected the number of indexing_map (")
- << linalgOp.getIndexingMapsArray().size()
- << ") to be equal to the number of input/output operands ("
- << linalgOp->getNumOperands() << ")";
+ // Delayed calling of IndexingMapOpInterface::verifyImpl.
+ if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
+ return failure();
// Set this flag if this op has user defined maps. This is required to guard
// the below error condition which assume default indexing maps.
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
-
- // Symbols disallowed.
- if (indexingMap.getNumSymbols() != 0)
- return op->emitOpError("unexpected symbols in indexing_map #")
- << opOperand.getOperandNumber();
-
// Domain must be consistent.
unsigned numLoops = linalgOp.getNumLoops();
if (indexingMap.getNumDims() != numLoops)
return op->emitOpError("expected indexing_map #")
<< opOperand.getOperandNumber() << " to have " << numLoops
<< " dim(s) to match the number of loops";
-
- int64_t rank = linalgOp.getRank(&opOperand);
-
- if (indexingMap.getNumResults() != rank)
- return op->emitOpError("expected operand rank (")
- << rank << ") to match the result rank of indexing_map #"
- << opOperand.getOperandNumber() << " ("
- << indexingMap.getNumResults() << ")";
}
SmallVector<unsigned> redDims;
linalgOp.getReductionDims(redDims);
@@ -1290,67 +1272,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
if (!linalgOp.getShapesToLoopsMap())
return op->emitOpError("expected the shape-to-loops map to be non-null");
- // Check if given shapes match to inferred shapes.
- SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
- SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
- // Verify only static cases since we can't get exact dimension sizes and
- // loop ranges for dynamic cases in this stage.
- if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
- for (int64_t &range : endLoopRangeValues)
- range -= 1;
- for (OpOperand &opOperand : linalgOp->getOpOperands()) {
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
- SmallVector<int64_t, 4> startIndices =
- indexingMap.compose(startLoopRangeValues);
- SmallVector<int64_t, 4> endIndices =
- indexingMap.compose(endLoopRangeValues);
- ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
- for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
- // Ignore dynamic dimension or the case that the dimension size is 0
- if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
- continue;
-
- // The first index or last index should be the maximum or the minimum in
- // the inferred index ranges since the range is increasing or
- // decreasing. The size of dimensions of input/output operands and the
- // maximum value + 1 in the inferred range should be the same. But, for
- // now we check if the inferred ranges are in boundary of input/output
- // operands' size or not in case that Affine Expressions are complicated
- // such as d0 * 3
- // + d1 since it is not easy to handle the issues.
- // Found the case that this solution can't check, for example, (d0, d1)
- // -> (d1 - d0)
- int64_t inferredDimSize =
- std::max(startIndices[dim], endIndices[dim]) + 1;
- if (std::min(startIndices[dim], endIndices[dim]) < 0) {
- std::string mapStr;
- {
- llvm::raw_string_ostream os(mapStr);
- os << indexingMap;
- }
- return op->emitOpError(
- "unexpected result less than 0 at expression #")
- << dim << " in " << mapStr;
- }
- if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
- if (inferredDimSize != shape[dim]) {
- return op->emitOpError("inferred input/output operand #")
- << opOperand.getOperandNumber() << " has shape's dimension #"
- << dim << " to be " << inferredDimSize << ", but found "
- << shape[dim];
- }
- } else {
- if (inferredDimSize > shape[dim]) {
- return op->emitOpError("inferred input/output operand #")
- << opOperand.getOperandNumber() << " has shape's dimension #"
- << dim << " to be greater than or equal to "
- << inferredDimSize << ", but found " << shape[dim];
- }
- }
- }
- }
- }
-
// Check the region has exactly one block.
if (linalgOp->getNumRegions() != 1 ||
!llvm::hasSingleElement(linalgOp->getRegion(0)))
diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
index f56ef485069f8..8d6d9dc690b55 100644
--- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 5e6dde36d7f9f..c5d9a729a4136 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -398,7 +398,10 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
return rewriter.notifyMatchFailure(genericOp,
"invalid indexing maps for operation");
}
- SmallVector<int64_t> dims = genericOp.getStaticShape();
+
+ SmallVector<int64_t> allShapesSizes;
+ for (OpOperand &opOperand : genericOp->getOpOperands())
+ llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand));
// 1a. Get the allowed list of dimensions to drop from the `options`.
SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
@@ -411,7 +414,7 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
llvm::SmallDenseSet<unsigned> unitDims;
for (const auto &expr : enumerate(invertedMap.getResults())) {
if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
- if (dims[dimExpr.getPosition()] == 1 &&
+ if (allShapesSizes[dimExpr.getPosition()] == 1 &&
unitDimsFilter.count(expr.index()))
unitDims.insert(expr.index());
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ff28bd7c48342..ff8e0b8977ae8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -31,6 +31,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
@@ -2217,7 +2218,9 @@ static LogicalResult vectorizeLinalgOpPrecondition(
LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
// tensor with dimension of 0 cannot be vectorized.
- if (llvm::is_contained(linalgOp.getStaticShape(), 0))
+ if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
+ return llvm::is_contained(linalgOp.getShape(&operand), 0);
+ }))
return failure();
// Check API contract for input vector sizes.
if (!inputVectorSizes.empty() &&
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 204462ffd047c..d464230c87723 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRVectorDialect
MLIRDataLayoutInterfaces
MLIRDestinationStyleOpInterface
MLIRDialectUtils
+ MLIRIndexingMapOpInterface
MLIRIR
MLIRMaskableOpInterface
MLIRMaskingOpInterface
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ee9ab61b670c4..5e0f36064be3b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1063,7 +1063,8 @@ LogicalResult ContractionOp::verify() {
if (!isSupportedCombiningKind(getKind(), elementType))
return emitOpError("unsupported contraction type");
- return success();
+ // Delayed calling of IndexingMapOpInterface::verifyImpl.
+ return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
}
// MaskableOpInterface methods.
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index a25694cfff5f2..af923d98c76ff 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
DestinationStyleOpInterface.cpp
FunctionImplementation.cpp
FunctionInterfaces.cpp
+ IndexingMapOpInterface.cpp
InferIntRangeInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
@@ -62,6 +63,7 @@ add_mlir_library(MLIRFunctionInterfaces
MLIRIR
)
+add_mlir_interface_library(IndexingMapOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
add_mlir_interface_library(InferTypeOpInterface)
diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
new file mode 100644
index 0000000000000..f3c12aed8df84
--- /dev/null
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -0,0 +1,125 @@
+//===- IndexingMapOpInterface.cpp -- IndexingMapOpInterface impl ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/IndexingMapOpInterface.cpp.inc"
+} // namespace mlir
+/// Verify the indexing-map invariants common to IndexingMapOpInterface ops.
+LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
+ // There must be exactly one indexing map per operand.
+ if (static_cast<int64_t>(getIndexingMapsArray().size()) !=
+ getOperation()->getNumOperands())
+ return this->emitOpError("expected the number of indexing_map (")
+ << getIndexingMapsArray().size()
+ << ") to be equal to the number of input/output operands ("
+ << getOperation()->getNumOperands() << ")";
+
+ AffineMap invertedMap = getShapesToLoopsMap(); // Null => non-invertible.
+ if (!invertedMap) {
+ std::string str;
+ llvm::raw_string_ostream os(str);
+ getLoopsToShapesMap().print(os);
+ return this->emitOpError("invalid indexing maps are non-invertible: ")
+ << "(" << str << ")";
+ }
+
+ SmallVector<int64_t> endLoopRangeValues = getStaticLoopRanges(); // Entries may be dynamic.
+
+ // Per-operand structural checks: no symbols, one map dim per loop, and one
+ // map result per dimension of the operand's static shape.
+ for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+ AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
+
+ // Indexing maps must not use symbols.
+ if (indexingMap.getNumSymbols() != 0)
+ return getOperation()->emitOpError("unexpected symbols in indexing_map #")
+ << opOperand.getOperandNumber();
+
+ // The map must have one dimension per loop.
+ if (indexingMap.getNumDims() != endLoopRangeValues.size())
+ return getOperation()->emitOpError("expected indexing_map #")
+ << opOperand.getOperandNumber() << " to have "
+ << endLoopRangeValues.size()
+ << " dim(s) to match the number of loops";
+
+ SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
+ int64_t rank = shape.size();
+ // The map must have one result per operand dimension.
+ if (indexingMap.getNumResults() != rank)
+ return getOperation()->emitOpError("expected operand rank (")
+ << rank << ") to match the result rank of indexing_map #"
+ << opOperand.getOperandNumber() << " ("
+ << indexingMap.getNumResults() << ")";
+ }
+
+ // Check the static operand shapes against the index ranges the maps produce.
+ SmallVector<int64_t> startLoopRangeValues(endLoopRangeValues.size(), 0);
+ // This is only done when every loop range is static; otherwise the extent of
+ // the iteration space is not known at verification time.
+ if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
+ // Turn exclusive upper bounds into the last iteration index (size - 1).
+ for (int64_t &range : endLoopRangeValues)
+ range -= 1;
+ for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+ AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
+ SmallVector<int64_t> startIndices =
+ indexingMap.compose(startLoopRangeValues);
+ SmallVector<int64_t> endIndices = indexingMap.compose(endLoopRangeValues);
+ SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
+ for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
+ // Skip dynamic and zero-sized operand dimensions.
+ if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
+ continue;
+
+ // Only two iteration points are evaluated: the first one (all zeros) and
+ // the last one. Their images bound the accessed indices only when the
+ // expression is non-decreasing in all loops (or non-increasing in all).
+ // For a plain dimension expression the operand size must equal the
+ // inferred extent (max + 1). For a compound expression such as
+ // `d0 * 3 + d1` the operand size must only be at least that large, i.e.
+ // the accessed indices must stay in bounds. Expressions mixing directions,
+ // e.g. `(d0, d1) -> (d1 - d0)`, can reach indices outside [min, max] at
+ // interior iteration points; such out-of-bounds accesses are not diagnosed
+ // here.
+ int64_t inferredDimSize =
+ std::max(startIndices[dim], endIndices[dim]) + 1;
+ if (std::min(startIndices[dim], endIndices[dim]) < 0) {
+ std::string mapStr;
+ {
+ llvm::raw_string_ostream os(mapStr);
+ os << indexingMap;
+ }
+ return this->emitOpError(
+ "unexpected result less than 0 at expression #")
+ << dim << " in " << mapStr;
+ }
+ if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
+ if (inferredDimSize != shape[dim]) {
+ return this->emitOpError("inferred input/output operand #")
+ << opOperand.getOperandNumber() << " has shape's dimension #"
+ << dim << " to be " << inferredDimSize << ", but found "
+ << shape[dim];
+ }
+ } else {
+ if (inferredDimSize > shape[dim]) {
+ return this->emitOpError("inferred input/output operand #")
+ << opOperand.getOperandNumber() << " has shape's dimension #"
+ << dim << " to be greater than or equal to "
+ << inferredDimSize << ", but found " << shape[dim];
+ }
+ }
+ }
+ }
+ }
+
+ return success();
+}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c0c5f785e856b..ca40301f04fa1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -151,7 +151,7 @@ func.func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off
// -----
func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
- // expected-error @+1 {{expected the shape-to-loops map to be non-null}}
+ // expected-error @+1 {{invalid indexing maps are non-invertible: ((d0, d1) -> (d0 + d1, d0 + d1))}}
linalg.generic {
indexing_maps = [
affine_map<(i, j) -> (i + j)>,
More information about the Mlir-commits
mailing list