[Mlir-commits] [mlir] Users/nico/indexing map op interface (PR #145332)
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jun 23 06:49:12 PDT 2025
https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/145332
None
>From 91f7ea29a737e1ecd959ec64820264c4ecaedddc Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Mon, 23 Jun 2025 13:41:33 +0200
Subject: [PATCH 1/2] [mlir][vector] NFC - Add more structured interface
support to vector.contract
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 33 +++++++++++++++++++
1 file changed, 33 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 926a92eff2ebb..12362be4d7d30 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -207,6 +207,39 @@ def Vector_ContractionOp :
.template getAsValueRange<IteratorTypeAttr, IteratorType>();
return {range.begin(), range.end()};
}
+
+ //===------------------------------------------------------------------===//
+ // The code below is shared with LinalgStructuredInterface.
+ // vector.contract is really a linalg.generic on vectors without region.
+ // TODO: factor out in a common interface to inherit from ince identified.
+ //===------------------------------------------------------------------===//
+ ArrayRef<int64_t> getShape(OpOperand * opOperand) {
+ assert(opOperand->getOwner() == this->getOperation());
+ Type t = opOperand->get().getType();
+ return cast<VectorType>(t).getShape();
+ }
+
+ AffineMap getLoopsToShapesMap() {
+ auto maps = getIndexingMapsArray();
+ return concatAffineMaps(maps, getContext());
+ }
+
+ AffineMap getShapesToLoopsMap() {
+ return inversePermutation(getLoopsToShapesMap());
+ }
+
+ SmallVector<int64_t> getStaticShape(){
+ SmallVector<int64_t> res;
+ for (OpOperand &opOperand : this->getOperation()->getOpOperands())
+ llvm::append_range(res, getShape(&opOperand));
+ return res;
+ }
+
+ SmallVector<int64_t> getStaticLoopRanges() {
+ SmallVector<int64_t> viewSizes = getStaticShape();
+ AffineMap invertedMap = getShapesToLoopsMap();
+ return invertedMap.compose(viewSizes);
+ }
}];
let hasCanonicalizer = 1;
>From 888b4441917d2f9f7920127dea9f1c0eb3201a7d Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Mon, 23 Jun 2025 15:48:11 +0200
Subject: [PATCH 2/2] WIP - [mlir][Interface] Factor out common
IndexingMapOpInterface behavior in a new generic interface
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 1 +
.../Dialect/Linalg/IR/LinalgInterfaces.td | 179 +++---------------
.../mlir/Dialect/Vector/IR/VectorOps.td | 26 +--
mlir/include/mlir/Interfaces/CMakeLists.txt | 1 +
.../mlir/Interfaces/IndexingMapOpInterface.h | 20 ++
.../mlir/Interfaces/IndexingMapOpInterface.td | 162 ++++++++++++++++
mlir/lib/Dialect/Linalg/IR/CMakeLists.txt | 1 +
.../Linalg/IR/ValueBoundsOpInterfaceImpl.cpp | 1 +
mlir/lib/Interfaces/CMakeLists.txt | 2 +
.../DestinationStyleOpInterface.cpp | 53 +-----
.../lib/Interfaces/IndexingMapOpInterface.cpp | 62 ++++++
11 files changed, 283 insertions(+), 225 deletions(-)
create mode 100644 mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
create mode 100644 mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
create mode 100644 mlir/lib/Interfaces/IndexingMapOpInterface.cpp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index df32cafd2d024..3b2557cb9c165 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -20,6 +20,7 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/RawOstreamExtras.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 74c4c0a8835f2..f5fe3a5cd3394 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -14,6 +14,7 @@
#define LINALG_IR_LINALGINTERFACES
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/IR/OpBase.td"
// The 'LinalgContractionOpInterface' provides access to the
@@ -222,58 +223,11 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
];
}
-def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
- let description = [{
- Interface for operations that connect an iteration domain to operands via
- affine maps. Provides methods to access indexing maps between iteration
- domain and operand index spaces.
- }];
- let cppNamespace = "::mlir::linalg";
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Return the indexing maps attribute within the current operation.
- }],
- /*retTy=*/"ArrayAttr",
- /*methodName=*/"getIndexingMaps"
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the indexing maps within the current operation.
- }],
- /*retTy=*/"SmallVector<AffineMap>",
- /*methodName=*/"getIndexingMapsArray",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto range = $_op.getIndexingMaps()
- .template getAsValueRange<AffineMapAttr>();
- return {range.begin(), range.end()};
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the input or output indexing map for `opOperand`.
- }],
- /*retTy=*/"AffineMap",
- /*methodName=*/"getMatchingIndexingMap",
- /*args=*/(ins "OpOperand*":$opOperand),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(opOperand->getOwner() == this->getOperation());
- auto indexingMaps =
- $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
- return *(indexingMaps.begin() + opOperand->getOperandNumber());
- }]
- >,
- ];
-}
-
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface
: OpInterface<"LinalgOp", [
DestinationStyleOpInterface,
- IndexingMapOpInterface
+ DeclareOpInterfaceMethods<IndexingMapOpInterface, ["getShape"]>
]> {
let cppNamespace = "::mlir::linalg";
let methods = [
@@ -464,30 +418,6 @@ def LinalgStructuredInterface
return getBlock()->getArguments().take_back($_op.getNumDpsInits());
}]
>,
- InterfaceMethod<
- /*desc=*/[{
- Return the `opOperand` shape or an empty vector for scalars or vectors
- not wrapped within a tensor or a memref.
- }],
- /*retTy=*/"ArrayRef<int64_t>",
- /*methodName=*/"getShape",
- /*args=*/(ins "OpOperand*":$opOperand),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(opOperand->getOwner() == this->getOperation());
- Type t = opOperand->get().getType();
- // A VectorType is an elemental type, do not consider its rank for the operand.
- if (isa<VectorType>(t))
- return {};
- if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
- // Failsafe.
- assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
- "expected a ranked tensor or memref in LinalgInterface::getRank");
- return shapedType.getShape();
- }
- return {};
- }]
- >,
InterfaceMethod<
/*desc=*/[{
Return the block argument for an `opOperand`.
@@ -620,7 +550,12 @@ def LinalgStructuredInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
+ for (OpOperand &opOperand : this->getOperation()->getOpOperands()) {
+ auto shapedType = dyn_cast<ShapedType>(opOperand.get().getType());
+ if (shapedType && ShapedType::isDynamicShape(shapedType.getShape()))
+ return true;
+ }
+ return false;
}]
>,
InterfaceMethod<
@@ -738,53 +673,6 @@ def LinalgStructuredInterface
//===------------------------------------------------------------------===//
// Linalg generalization hooks.
//===------------------------------------------------------------------===//
- InterfaceMethod<
- /*desc=*/[{
- Hook to provide a custom AffineMap used to compute all the operand
- subshapes given loop bounds. This is used to answer the question: "given
- an iteration space over the codomain, what are the subshapes of the
- operands involved in the computation".
- The default behavior is to just concatenate all the indexing maps.
- A custom AffineMap allows providing a map that can be used to
- compute subshapes even in cases where the concatenation of indexing maps
- (i.e. the data traversal order) is not a simple permutation of the loop
- traversal order. It is then possible to define ops with skewed data
- traversal order for which we can still easily compute hyperrectangular
- loop bounds and subviews.
- }],
- /*retTy=*/"AffineMap",
- /*methodName=*/"getLoopsToShapesMap",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto maps = $_op.getIndexingMapsArray();
- return concatAffineMaps(maps, $_op.getContext());
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Hook to provide a custom AffineMap used to construct the
- hyperrectangular loop iteration space given all the operand subshapes.
- This is used to answer the question:
- "Given a list of operand ranges, what is the subportion of the iteration
- space involved in the computation".
- This is the inverse problem of `getLoopsToShapesMap`.
- Return the empty AffineMap when such an AffineMap cannot be constructed.
- The default behavior is based on a very simple inference procedure that
- only works with permutation affine maps.
- A more advanced Tensor-Comprehension like inference is possible but has
- proven to be ambiguous in unfavorable case.
- A safer and more robust alternative is to allow each op to define
- its own AffineMap.
- }],
- /*retTy=*/"AffineMap",
- /*methodName=*/"getShapesToLoopsMap",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return inversePermutation(getLoopsToShapesMap());
- }]
- >,
InterfaceMethod<
/*desc=*/[{
Checks if the given operands can be dropped, and the remaining
@@ -798,39 +686,30 @@ def LinalgStructuredInterface
return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
}]
>,
+ //===------------------------------------------------------------------===//
+ // IndexingMapOpInterface interface methods implementation.
+ //===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
- Like `getShape`, but only returns statically-known information, without
- generating any new IR. For each shape dimension, returns >=0 if that
- dimension is statically known, or ShapedType::kDynamic otherwise.
- }],
- /*retTy=*/"SmallVector<int64_t>",
- /*methodName=*/"getStaticShape",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- SmallVector<int64_t> res;
- for (OpOperand &opOperand : this->getOperation()->getOpOperands())
- llvm::append_range(res, getShape(&opOperand));
- return res;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Returns the statically-known loop ranges. Composes
- `getShapesToLoopsMap()` with the result of `getStaticShape`.
- Returns ShapedType::kDynamic for non-statically-known loop ranges.
- This is expected to be called by a valid Linalg op
+ Return the `opOperand` shape or an empty vector for scalars or vectors
+ not wrapped within a tensor or a memref.
}],
- /*retTy=*/"SmallVector<int64_t, 4>",
- /*methodName=*/"getStaticLoopRanges",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- SmallVector<int64_t> viewSizes = getStaticShape();
- AffineMap invertedMap = getShapesToLoopsMap();
- assert(invertedMap && "expected a valid Linalg op to call the method");
- return invertedMap.compose(viewSizes);
+ /*retTy=*/"ArrayRef<int64_t>",
+ /*methodName=*/"getShape",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/[{
+ assert(opOperand->getOwner() == $_op.getOperation());
+ Type t = opOperand->get().getType();
+ // A VectorType is an elemental type, do not consider its rank for the operand.
+ if (isa<VectorType>(t))
+ return {};
+ if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+ // Failsafe.
+ assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+ "expected a ranked tensor or memref in LinalgInterface::getShape");
+ return shapedType.getShape();
+ }
+ return {};
}]
>,
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 12362be4d7d30..91c3fc528029e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -209,37 +209,13 @@ def Vector_ContractionOp :
}
//===------------------------------------------------------------------===//
- // The code below is shared with LinalgStructuredInterface.
- // vector.contract is really a linalg.generic on vectors without region.
- // TODO: factor out in a common interface to inherit from ince identified.
+ // IndexingMapOpInterface interface methods implementation.
//===------------------------------------------------------------------===//
ArrayRef<int64_t> getShape(OpOperand * opOperand) {
assert(opOperand->getOwner() == this->getOperation());
Type t = opOperand->get().getType();
return cast<VectorType>(t).getShape();
}
-
- AffineMap getLoopsToShapesMap() {
- auto maps = getIndexingMapsArray();
- return concatAffineMaps(maps, getContext());
- }
-
- AffineMap getShapesToLoopsMap() {
- return inversePermutation(getLoopsToShapesMap());
- }
-
- SmallVector<int64_t> getStaticShape(){
- SmallVector<int64_t> res;
- for (OpOperand &opOperand : this->getOperation()->getOpOperands())
- llvm::append_range(res, getShape(&opOperand));
- return res;
- }
-
- SmallVector<int64_t> getStaticLoopRanges() {
- SmallVector<int64_t> viewSizes = getStaticShape();
- AffineMap invertedMap = getShapesToLoopsMap();
- return invertedMap.compose(viewSizes);
- }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index d81298bb4daf0..067e0511e4e75 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_interface(CopyOpInterface)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
+add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
diff --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
new file mode 100644
index 0000000000000..d5978b6fe9b78
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
@@ -0,0 +1,20 @@
+//===- IndexingMapOpInterface.h ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
+#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/IndexingMapOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
@@ -0,0 +1,161 @@
+//===- IndexingMapOpInterface.td - Interface Declaration -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the IndexingMapOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
+#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
+ let description = [{
+ Interface for operations that connect an iteration domain to operands via
+ affine maps. Provides methods to access indexing maps between iteration
+ domain and operand index spaces.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps attribute within the current operation.
+ }],
+ /*retTy=*/"ArrayAttr",
+ /*methodName=*/"getIndexingMaps"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps within the current operation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMapsArray",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto range = $_op.getIndexingMaps()
+ .template getAsValueRange<AffineMapAttr>();
+ return {range.begin(), range.end()};
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the input or output indexing map for `opOperand`.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getMatchingIndexingMap",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(opOperand->getOwner() == this->getOperation());
+ auto indexingMaps =
+ $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
+ return *(indexingMaps.begin() + opOperand->getOperandNumber());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Hook to provide a custom AffineMap used to compute all the operand
+ subshapes given loop bounds. This is used to answer the question: "given
+ an iteration space over the codomain, what are the subshapes of the
+ operands involved in the computation".
+ The default behavior is to just concatenate all the indexing maps.
+ A custom AffineMap allows providing a map that can be used to
+ compute subshapes even in cases where the concatenation of indexing maps
+ (i.e. the data traversal order) is not a simple permutation of the loop
+ traversal order. It is then possible to define ops with skewed data
+ traversal order for which we can still easily compute hyperrectangular
+ loop bounds and subviews.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getLoopsToShapesMap",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto maps = $_op.getIndexingMapsArray();
+ return concatAffineMaps(maps, $_op.getContext());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the `opOperand` shape or an empty vector if not shaped.
+ Must be overridden by implementing ops or derived interfaces; the default implementation asserts.
+ }],
+ /*retTy=*/"ArrayRef<int64_t>",
+ /*methodName=*/"getShape",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(false && "Need to override");
+ return {};
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Like `getShape`, but only returns statically-known information, without
+ generating any new IR. For each shape dimension, returns >=0 if that
+ dimension is statically known, or ShapedType::kDynamic otherwise.
+ }],
+ /*retTy=*/"SmallVector<int64_t>",
+ /*methodName=*/"getStaticShape",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<int64_t> res;
+ for (OpOperand &opOperand : this->getOperation()->getOpOperands())
+ llvm::append_range(res, $_op.getShape(&opOperand));
+ return res;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Hook to provide a custom AffineMap used to construct the
+ hyperrectangular loop iteration space given all the operand subshapes.
+ This is used to answer the question:
+ "Given a list of operand ranges, what is the subportion of the iteration
+ space involved in the computation".
+ This is the inverse problem of `getLoopsToShapesMap`.
+ Return the empty AffineMap when such an AffineMap cannot be constructed.
+ The default behavior is based on a very simple inference procedure that
+ only works with permutation affine maps.
+ A more advanced Tensor-Comprehension like inference is possible but has
+ proven to be ambiguous in unfavorable cases.
+ A safer and more robust alternative is to allow each op to define
+ its own AffineMap.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getShapesToLoopsMap",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return inversePermutation($_op.getLoopsToShapesMap());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the statically-known loop ranges. Composes
+ `getShapesToLoopsMap()` with the result of `getStaticShape`.
+ Returns ShapedType::kDynamic for non-statically-known loop ranges.
+ This is expected to be called on an op for which `getShapesToLoopsMap()` returns a non-empty map.
+ }],
+ /*retTy=*/"SmallVector<int64_t, 4>",
+ /*methodName=*/"getStaticLoopRanges",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<int64_t> viewSizes = $_op.getStaticShape();
+ AffineMap invertedMap = $_op.getShapesToLoopsMap();
+ assert(invertedMap && "expected a valid IndexingMapOpInterface op to call the method");
+ return invertedMap.compose(viewSizes);
+ }]
+ >,
+ ];
+}
+
+#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index b4aeb44ac8faf..ec433284e17ad 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRFunctionInterfaces
+ MLIRIndexingMapOpInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRParser
diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
index f56ef485069f8..8d6d9dc690b55 100644
--- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index a25694cfff5f2..af923d98c76ff 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
DestinationStyleOpInterface.cpp
FunctionImplementation.cpp
FunctionInterfaces.cpp
+ IndexingMapOpInterface.cpp
InferIntRangeInterface.cpp
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
@@ -62,6 +63,7 @@ add_mlir_library(MLIRFunctionInterfaces
MLIRIR
)
+add_mlir_interface_library(IndexingMapOpInterface)
add_mlir_interface_library(InferIntRangeInterface)
add_mlir_interface_library(InferTypeOpInterface)
diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -0,0 +1,15 @@
+//===- IndexingMapOpInterface.cpp -- IndexingMapOpInterface ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/IndexingMapOpInterface.cpp.inc"
+} // namespace mlir
More information about the Mlir-commits
mailing list