[Mlir-commits] [mlir] Users/nico/indexing map op interface (PR #145332)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jun 23 06:49:12 PDT 2025


https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/145332

None

>From 91f7ea29a737e1ecd959ec64820264c4ecaedddc Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Mon, 23 Jun 2025 13:41:33 +0200
Subject: [PATCH 1/2] [mlir][vector] NFC - Add more structured interface
 support to vector.contract

Mirror the LinalgStructuredInterface helpers getShape, getLoopsToShapesMap,
getShapesToLoopsMap, getStaticShape and getStaticLoopRanges as extra class
declarations on vector.contract, as a first step towards a common interface.
Unlike LinalgOp::getShape, which returns an empty shape for vector operands,
the vector.contract version returns each operand's vector shape.
NOTE(review): getShape casts every operand to VectorType; confirm this holds
for every accumulator type vector.contract accepts.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 33 +++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 926a92eff2ebb..12362be4d7d30 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -207,6 +207,39 @@ def Vector_ContractionOp :
               .template getAsValueRange<IteratorTypeAttr, IteratorType>();
       return {range.begin(), range.end()};
     }
+
+    //===------------------------------------------------------------------===//
+    // The code below is shared with LinalgStructuredInterface.
+    // vector.contract is really a linalg.generic on vectors without region.
+    // TODO: factor out in a common interface to inherit from ince identified.
+    //===------------------------------------------------------------------===//
+    ArrayRef<int64_t> getShape(OpOperand * opOperand) {
+      assert(opOperand->getOwner() == this->getOperation());
+      Type t = opOperand->get().getType();
+      return cast<VectorType>(t).getShape();
+    }
+
+    AffineMap getLoopsToShapesMap() {
+      auto maps = getIndexingMapsArray();
+      return concatAffineMaps(maps, getContext());
+    }
+
+    AffineMap getShapesToLoopsMap() {
+      return inversePermutation(getLoopsToShapesMap());
+    }
+
+    SmallVector<int64_t> getStaticShape(){
+      SmallVector<int64_t> res;
+      for (OpOperand &opOperand : this->getOperation()->getOpOperands())
+        llvm::append_range(res, getShape(&opOperand));
+      return res;
+    }
+
+    SmallVector<int64_t> getStaticLoopRanges() {
+      SmallVector<int64_t> viewSizes = getStaticShape();
+      AffineMap invertedMap = getShapesToLoopsMap();
+      return invertedMap.compose(viewSizes);
+    }
   }];

   let hasCanonicalizer = 1;

>From 888b4441917d2f9f7920127dea9f1c0eb3201a7d Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Mon, 23 Jun 2025 15:48:11 +0200
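Subject: [PATCH 2/2] WIP - [mlir][Interface] Factor out common
 IndexingMapOpInterface behavior in a new generic interface

Move IndexingMapOpInterface out of the Linalg dialect into mlir/Interfaces
and give it the getLoopsToShapesMap, getShapesToLoopsMap, getShape,
getStaticShape and getStaticLoopRanges hooks previously defined on LinalgOp,
so that other ops such as vector.contract can share them.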
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |   1 +
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 179 +++---------------
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  26 +--
 mlir/include/mlir/Interfaces/CMakeLists.txt   |   1 +
 .../mlir/Interfaces/IndexingMapOpInterface.h  |  20 ++
 .../mlir/Interfaces/IndexingMapOpInterface.td | 162 ++++++++++++++++
 mlir/lib/Dialect/Linalg/IR/CMakeLists.txt     |   1 +
 .../Linalg/IR/ValueBoundsOpInterfaceImpl.cpp  |   1 +
 mlir/lib/Interfaces/CMakeLists.txt            |   2 +
 .../DestinationStyleOpInterface.cpp           |  53 +-----
 .../lib/Interfaces/IndexingMapOpInterface.cpp |  62 ++++++
 11 files changed, 283 insertions(+), 225 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
 create mode 100644 mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
 create mode 100644 mlir/lib/Interfaces/IndexingMapOpInterface.cpp

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index df32cafd2d024..3b2557cb9c165 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/RawOstreamExtras.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 74c4c0a8835f2..f5fe3a5cd3394 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -14,6 +14,7 @@
 #define LINALG_IR_LINALGINTERFACES

 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/IndexingMapOpInterface.td"
 include "mlir/IR/OpBase.td"

 // The 'LinalgContractionOpInterface' provides access to the
@@ -222,58 +223,11 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
   ];
 }

-def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
-  let description = [{
-    Interface for operations that connect an iteration domain to operands via
-    affine maps. Provides methods to access indexing maps between iteration
-    domain and operand index spaces.
-  }];
-  let cppNamespace = "::mlir::linalg";
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the indexing maps attribute within the current operation.
-      }],
-      /*retTy=*/"ArrayAttr",
-      /*methodName=*/"getIndexingMaps"
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the indexing maps within the current operation.
-      }],
-      /*retTy=*/"SmallVector<AffineMap>",
-      /*methodName=*/"getIndexingMapsArray",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto range = $_op.getIndexingMaps()
-          .template getAsValueRange<AffineMapAttr>();
-        return {range.begin(), range.end()};
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the input or output indexing map for `opOperand`.
-      }],
-      /*retTy=*/"AffineMap",
-      /*methodName=*/"getMatchingIndexingMap",
-      /*args=*/(ins "OpOperand*":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
-        auto indexingMaps =
-          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
-        return *(indexingMaps.begin() + opOperand->getOperandNumber());
-      }]
-    >,
-  ];
-}
-
 // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
 def LinalgStructuredInterface
     : OpInterface<"LinalgOp", [
       DestinationStyleOpInterface,
-      IndexingMapOpInterface
+      DeclareOpInterfaceMethods<IndexingMapOpInterface, ["getShape"]>
   ]> {
   let cppNamespace = "::mlir::linalg";
   let methods = [
@@ -464,30 +418,6 @@ def LinalgStructuredInterface
         return getBlock()->getArguments().take_back($_op.getNumDpsInits());
       }]
     >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the `opOperand` shape or an empty vector for scalars or vectors
-        not wrapped within a tensor or a memref.
-      }],
-      /*retTy=*/"ArrayRef<int64_t>",
-      /*methodName=*/"getShape",
-      /*args=*/(ins "OpOperand*":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
-        Type t = opOperand->get().getType();
-        // A VectorType is an elemental type, do not consider its rank for the operand.
-        if (isa<VectorType>(t))
-          return {};
-        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
-          // Failsafe.
-          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
-                 "expected a ranked tensor or memref in LinalgInterface::getRank");
-          return shapedType.getShape();
-        }
-        return {};
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the block argument for an `opOperand`.
@@ -620,7 +550,12 @@ def LinalgStructuredInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
+        return llvm::any_of(
+            this->getOperation()->getOperandTypes(), [](Type type) {
+              auto shapedType = dyn_cast<ShapedType>(type);
+              return shapedType &&
+                     ShapedType::isDynamicShape(shapedType.getShape());
+            });
       }]
     >,
     InterfaceMethod<
@@ -738,53 +673,6 @@ def LinalgStructuredInterface
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
-    InterfaceMethod<
-      /*desc=*/[{
-        Hook to provide a custom AffineMap used to compute all the operand
-        subshapes given loop bounds. This is used to answer the question: "given
-        an iteration space over the codomain, what are the subshapes of the
-        operands involved in the computation".
-        The default behavior is to just concatenate all the indexing maps.
-        A custom AffineMap allows providing a map that can be used to
-        compute subshapes even in cases where the concatenation of indexing maps
-        (i.e. the data traversal order) is not a simple permutation of the loop
-        traversal order. It is then possible to define ops with skewed data
-        traversal order for which we can still easily compute hyperrectangular
-        loop bounds and subviews.
-      }],
-      /*retTy=*/"AffineMap",
-      /*methodName=*/"getLoopsToShapesMap",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto maps =  $_op.getIndexingMapsArray();
-        return concatAffineMaps(maps, $_op.getContext());
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Hook to provide a custom AffineMap used to construct the
-        hyperrectangular loop iteration space given all the operand subshapes.
-        This is used to answer the question:
-        "Given a list of operand ranges, what is the subportion of the iteration
-        space involved in the computation".
-        This is the inverse problem of `getLoopsToShapesMap`.
-        Return the empty AffineMap when such an AffineMap cannot be constructed.
-        The default behavior is based on a very simple inference procedure that
-        only works with permutation affine maps.
-        A more advanced Tensor-Comprehension like inference is possible but has
-        proven to be ambiguous in unfavorable case.
-        A safer and more robust alternative is to allow each op to define
-        its own AffineMap.
-      }],
-      /*retTy=*/"AffineMap",
-      /*methodName=*/"getShapesToLoopsMap",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return inversePermutation(getLoopsToShapesMap());
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Checks if the given operands can be dropped, and the remaining
@@ -798,39 +686,30 @@ def LinalgStructuredInterface
         return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
       }]
     >,
+    //===------------------------------------------------------------------===//
+    // IndexingMapOpInterface interface methods implementation.
+    //===------------------------------------------------------------------===//
     InterfaceMethod<
       /*desc=*/[{
-        Like `getShape`, but only returns statically-known information, without
-        generating any new IR. For each shape dimension, returns >=0 if that
-        dimension is statically known, or ShapedType::kDynamic otherwise.
-      }],
-      /*retTy=*/"SmallVector<int64_t>",
-      /*methodName=*/"getStaticShape",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        SmallVector<int64_t> res;
-        for (OpOperand &opOperand : this->getOperation()->getOpOperands())
-          llvm::append_range(res, getShape(&opOperand));
-        return res;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Returns the statically-known loop ranges. Composes
-        `getShapesToLoopsMap()` with the result of `getStaticShape`.
-        Returns ShapedType::kDynamic for non-statically-known loop ranges.
-        This is expected to be called by a valid Linalg op
+        Return the `opOperand` shape or an empty vector for scalars or vectors
+        not wrapped within a tensor or a memref.
       }],
-      /*retTy=*/"SmallVector<int64_t, 4>",
-      /*methodName=*/"getStaticLoopRanges",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        SmallVector<int64_t> viewSizes = getStaticShape();
-        AffineMap invertedMap = getShapesToLoopsMap();
-        assert(invertedMap && "expected a valid Linalg op to call the method");
-        return invertedMap.compose(viewSizes);
+      /*retTy=*/"ArrayRef<int64_t>",
+      /*methodName=*/"getShape",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        Type t = opOperand->get().getType();
+        // A VectorType is an elemental type: it contributes no dimensions.
+        if (isa<VectorType>(t))
+          return {};
+        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+          // Failsafe.
+          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+                 "expected a ranked tensor or memref in LinalgOp::getShape");
+          return shapedType.getShape();
+        }
+        return {};
       }]
     >,
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 12362be4d7d30..91c3fc528029e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -209,37 +209,14 @@ def Vector_ContractionOp :
     }

     //===------------------------------------------------------------------===//
-    // The code below is shared with LinalgStructuredInterface.
-    // vector.contract is really a linalg.generic on vectors without region.
-    // TODO: factor out in a common interface to inherit from ince identified.
+    // IndexingMapOpInterface interface methods implementation.
     //===------------------------------------------------------------------===//
     ArrayRef<int64_t> getShape(OpOperand * opOperand) {
       assert(opOperand->getOwner() == this->getOperation());
-      Type t = opOperand->get().getType();
-      return cast<VectorType>(t).getShape();
+      // A scalar accumulator (e.g. of a full reduction) has an empty shape.
+      auto vectorType = dyn_cast<VectorType>(opOperand->get().getType());
+      return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
     }
-
-    AffineMap getLoopsToShapesMap() {
-      auto maps = getIndexingMapsArray();
-      return concatAffineMaps(maps, getContext());
-    }
-
-    AffineMap getShapesToLoopsMap() {
-      return inversePermutation(getLoopsToShapesMap());
-    }
-
-    SmallVector<int64_t> getStaticShape(){
-      SmallVector<int64_t> res;
-      for (OpOperand &opOperand : this->getOperation()->getOpOperands())
-        llvm::append_range(res, getShape(&opOperand));
-      return res;
-    }
-
-    SmallVector<int64_t> getStaticLoopRanges() {
-      SmallVector<int64_t> viewSizes = getStaticShape();
-      AffineMap invertedMap = getShapesToLoopsMap();
-      return invertedMap.compose(viewSizes);
-    }
   }];

   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index d81298bb4daf0..067e0511e4e75 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_interface(CopyOpInterface)
 add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(DestinationStyleOpInterface)
 add_mlir_interface(FunctionInterfaces)
+add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
diff --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
new file mode 100644
index 0000000000000..d5978b6fe9b78
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
@@ -0,0 +1,20 @@
+//===- IndexingMapOpInterface.h - Indexing map op interface -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
+#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/IndexingMapOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
new file mode 100644
index 0000000000000..82dc420c5437a
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
@@ -0,0 +1,162 @@
+//===- IndexingMapOpInterface.td - Interface Declaration -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the IndexingMapOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
+#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
+  let description = [{
+    Interface for operations that connect an iteration domain to operands via
+    affine maps. Provides methods to access indexing maps between iteration
+    domain and operand index spaces.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing maps attribute within the current operation.
+      }],
+      /*retTy=*/"ArrayAttr",
+      /*methodName=*/"getIndexingMaps"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing maps within the current operation.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getIndexingMapsArray",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = $_op.getIndexingMaps()
+          .template getAsValueRange<AffineMapAttr>();
+        return {range.begin(), range.end()};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing map for `opOperand`, selected by operand number.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getMatchingIndexingMap",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        unsigned operandNumber = opOperand->getOperandNumber();
+        ArrayAttr indexingMaps = $_op.getIndexingMaps();
+        assert(operandNumber < indexingMaps.size() && "missing indexing map");
+        return cast<AffineMapAttr>(indexingMaps[operandNumber]).getValue();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Hook to provide a custom AffineMap used to compute all the operand
+        subshapes given loop bounds. This is used to answer the question: "given
+        an iteration space over the codomain, what are the subshapes of the
+        operands involved in the computation".
+        The default behavior is to just concatenate all the indexing maps.
+        A custom AffineMap allows providing a map that can be used to
+        compute subshapes even in cases where the concatenation of indexing maps
+        (i.e. the data traversal order) is not a simple permutation of the loop
+        traversal order. It is then possible to define ops with skewed data
+        traversal order for which we can still easily compute hyperrectangular
+        loop bounds and subviews.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getLoopsToShapesMap",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto maps = $_op.getIndexingMapsArray();
+        return concatAffineMaps(maps, $_op.getContext());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the `opOperand` shape or an empty vector if not shaped.
+        Must be overridden by implementing ops or derived interfaces: the
+        default implementation is unreachable.
+      }],
+      /*retTy=*/"ArrayRef<int64_t>",
+      /*methodName=*/"getShape",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        llvm_unreachable("getShape must be overridden");
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the concatenation of `getShape` over all operands, in operand
+        order. No new IR is generated: each dimension is >=0 if statically
+        known, or ShapedType::kDynamic otherwise.
+      }],
+      /*retTy=*/"SmallVector<int64_t>",
+      /*methodName=*/"getStaticShape",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t> res;
+        for (OpOperand &opOperand : this->getOperation()->getOpOperands())
+          llvm::append_range(res, getShape(&opOperand));
+        return res;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Hook to provide a custom AffineMap used to construct the
+        hyperrectangular loop iteration space given all the operand subshapes.
+        This is used to answer the question:
+        "Given a list of operand ranges, what is the subportion of the iteration
+        space involved in the computation".
+        This is the inverse problem of `getLoopsToShapesMap`.
+        Return the empty AffineMap when such an AffineMap cannot be constructed.
+        The default behavior is based on a very simple inference procedure that
+        only works with permutation affine maps.
+        A more advanced Tensor-Comprehension like inference is possible but has
+        proven to be ambiguous in unfavorable case.
+        A safer and more robust alternative is to allow each op to define
+        its own AffineMap.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getShapesToLoopsMap",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return inversePermutation(getLoopsToShapesMap());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the statically-known loop ranges. Composes
+        `getShapesToLoopsMap()` with the result of `getStaticShape`.
+        Returns ShapedType::kDynamic for non-statically-known loop ranges.
+        Expects `getShapesToLoopsMap()` to return a non-null map.
+      }],
+      /*retTy=*/"SmallVector<int64_t, 4>",
+      /*methodName=*/"getStaticLoopRanges",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t> viewSizes = getStaticShape();
+        AffineMap invertedMap = getShapesToLoopsMap();
+        assert(invertedMap && "expected a non-null shapes-to-loops map");
+        return invertedMap.compose(viewSizes);
+      }]
+    >,
+  ];
+}
+
+#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index b4aeb44ac8faf..ec433284e17ad 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
   MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRFunctionInterfaces
+  MLIRIndexingMapOpInterface
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRParser
diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
index f56ef485069f8..8d6d9dc690b55 100644
--- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"

 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"

 using namespace mlir;
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index a25694cfff5f2..af923d98c76ff 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
   DestinationStyleOpInterface.cpp
   FunctionImplementation.cpp
   FunctionInterfaces.cpp
+  IndexingMapOpInterface.cpp
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
@@ -62,6 +63,7 @@ add_mlir_library(MLIRFunctionInterfaces
   MLIRIR
 )

+add_mlir_interface_library(IndexingMapOpInterface)
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -0,0 +1,15 @@
+//===- IndexingMapOpInterface.cpp -- Indexing map op interface ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/IndexingMapOpInterface.cpp.inc"
+} // namespace mlir



More information about the Mlir-commits mailing list