[Mlir-commits] [mlir] d31ba52 - [mlir][Interface] Factor out common IndexingMapOpInterface behavior in a new generic interface (#145313)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 23 22:56:36 PDT 2025


Author: Nicolas Vasilache
Date: 2025-06-24T07:56:32+02:00
New Revision: d31ba5256327d30f264c2f671bf197877b242cde

URL: https://github.com/llvm/llvm-project/commit/d31ba5256327d30f264c2f671bf197877b242cde
DIFF: https://github.com/llvm/llvm-project/commit/d31ba5256327d30f264c2f671bf197877b242cde.diff

LOG: [mlir][Interface] Factor out common IndexingMapOpInterface behavior in a new generic interface (#145313)

Refactor the verifiers to use the common bits, and make `vector.contract`
also use this interface.
In the process, the confusingly named `getStaticShape` method has been removed.

Note: the verifier for IndexingMapOpInterface is currently called
manually from other verifiers, as it was unclear how to prevent it from
taking precedence over more meaningful error messages.

Added: 
    mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
    mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
    mlir/lib/Interfaces/IndexingMapOpInterface.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/Interfaces/CMakeLists.txt
    mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/IR/CMakeLists.txt
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Interfaces/CMakeLists.txt
    mlir/test/Dialect/Linalg/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0f960fb5ad795..0ebbeea937554 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/RawOstreamExtras.h"

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 74c4c0a8835f2..ca1cba8747bd8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -14,6 +14,7 @@
 #define LINALG_IR_LINALGINTERFACES
 
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/IndexingMapOpInterface.td"
 include "mlir/IR/OpBase.td"
 
 // The 'LinalgContractionOpInterface' provides access to the
@@ -222,59 +223,11 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
   ];
 }
 
-def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
-  let description = [{
-    Interface for operations that connect an iteration domain to operands via
-    affine maps. Provides methods to access indexing maps between iteration
-    domain and operand index spaces.
-  }];
-  let cppNamespace = "::mlir::linalg";
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the indexing maps attribute within the current operation.
-      }],
-      /*retTy=*/"ArrayAttr",
-      /*methodName=*/"getIndexingMaps"
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the indexing maps within the current operation.
-      }],
-      /*retTy=*/"SmallVector<AffineMap>",
-      /*methodName=*/"getIndexingMapsArray",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto range = $_op.getIndexingMaps()
-          .template getAsValueRange<AffineMapAttr>();
-        return {range.begin(), range.end()};
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the input or output indexing map for `opOperand`.
-      }],
-      /*retTy=*/"AffineMap",
-      /*methodName=*/"getMatchingIndexingMap",
-      /*args=*/(ins "OpOperand*":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
-        auto indexingMaps =
-          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
-        return *(indexingMaps.begin() + opOperand->getOperandNumber());
-      }]
-    >,
-  ];
-}
-
 // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
 def LinalgStructuredInterface
-    : OpInterface<"LinalgOp", [
-      DestinationStyleOpInterface,
-      IndexingMapOpInterface
-  ]> {
+    : OpInterface<"LinalgOp", 
+      [DestinationStyleOpInterface, IndexingMapOpInterface]
+  > {
   let cppNamespace = "::mlir::linalg";
   let methods = [
     //===------------------------------------------------------------------===//
@@ -464,30 +417,6 @@ def LinalgStructuredInterface
         return getBlock()->getArguments().take_back($_op.getNumDpsInits());
       }]
     >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the `opOperand` shape or an empty vector for scalars or vectors
-        not wrapped within a tensor or a memref.
-      }],
-      /*retTy=*/"ArrayRef<int64_t>",
-      /*methodName=*/"getShape",
-      /*args=*/(ins "OpOperand*":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
-        Type t = opOperand->get().getType();
-        // A VectorType is an elemental type, do not consider its rank for the operand.
-        if (isa<VectorType>(t))
-          return {};
-        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
-          // Failsafe.
-          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
-                 "expected a ranked tensor or memref in LinalgInterface::getRank");
-          return shapedType.getShape();
-        }
-        return {};
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the block argument for an `opOperand`.
@@ -620,7 +549,12 @@ def LinalgStructuredInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::any_of(getStaticShape(), ShapedType::isDynamic);
+        // Unranked shaped operands have no shape to query: treat as dynamic.
+        for (Type t : this->getOperation()->getOperandTypes()) {
+          auto shapedType = dyn_cast<ShapedType>(t);
+          if (shapedType && !shapedType.hasStaticShape()) return true;
+        }
+        return false;
       }]
     >,
     InterfaceMethod<
@@ -738,53 +672,6 @@ def LinalgStructuredInterface
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
-    InterfaceMethod<
-      /*desc=*/[{
-        Hook to provide a custom AffineMap used to compute all the operand
-        subshapes given loop bounds. This is used to answer the question: "given
-        an iteration space over the codomain, what are the subshapes of the
-        operands involved in the computation".
-        The default behavior is to just concatenate all the indexing maps.
-        A custom AffineMap allows providing a map that can be used to
-        compute subshapes even in cases where the concatenation of indexing maps
-        (i.e. the data traversal order) is not a simple permutation of the loop
-        traversal order. It is then possible to define ops with skewed data
-        traversal order for which we can still easily compute hyperrectangular
-        loop bounds and subviews.
-      }],
-      /*retTy=*/"AffineMap",
-      /*methodName=*/"getLoopsToShapesMap",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto maps =  $_op.getIndexingMapsArray();
-        return concatAffineMaps(maps, $_op.getContext());
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Hook to provide a custom AffineMap used to construct the
-        hyperrectangular loop iteration space given all the operand subshapes.
-        This is used to answer the question:
-        "Given a list of operand ranges, what is the subportion of the iteration
-        space involved in the computation".
-        This is the inverse problem of `getLoopsToShapesMap`.
-        Return the empty AffineMap when such an AffineMap cannot be constructed.
-        The default behavior is based on a very simple inference procedure that
-        only works with permutation affine maps.
-        A more advanced Tensor-Comprehension like inference is possible but has
-        proven to be ambiguous in unfavorable case.
-        A safer and more robust alternative is to allow each op to define
-        its own AffineMap.
-      }],
-      /*retTy=*/"AffineMap",
-      /*methodName=*/"getShapesToLoopsMap",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return inversePermutation(getLoopsToShapesMap());
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Checks if the given operands can be dropped, and the remaining
@@ -798,39 +685,30 @@ def LinalgStructuredInterface
         return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands);
       }]
     >,
+    //===------------------------------------------------------------------===//
+    // IndexingMapOpInterface interface methods implementation.
+    //===------------------------------------------------------------------===//
     InterfaceMethod<
       /*desc=*/[{
-        Like `getShape`, but only returns statically-known information, without
-        generating any new IR. For each shape dimension, returns >=0 if that
-        dimension is statically known, or ShapedType::kDynamic otherwise.
-      }],
-      /*retTy=*/"SmallVector<int64_t>",
-      /*methodName=*/"getStaticShape",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        SmallVector<int64_t> res;
-        for (OpOperand &opOperand : this->getOperation()->getOpOperands())
-          llvm::append_range(res, getShape(&opOperand));
-        return res;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Returns the statically-known loop ranges. Composes
-        `getShapesToLoopsMap()` with the result of `getStaticShape`.
-        Returns ShapedType::kDynamic for non-statically-known loop ranges.
-        This is expected to be called by a valid Linalg op
+        Return the `opOperand` shape or an empty vector for scalars or vectors
+        not wrapped within a tensor or a memref.
       }],
-      /*retTy=*/"SmallVector<int64_t, 4>",
-      /*methodName=*/"getStaticLoopRanges",
-      /*args=*/(ins),
+      /*retTy=*/"ArrayRef<int64_t>",
+      /*methodName=*/"getShape",
+      /*args=*/(ins "OpOperand*":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        SmallVector<int64_t> viewSizes = getStaticShape();
-        AffineMap invertedMap = getShapesToLoopsMap();
-        assert(invertedMap && "expected a valid Linalg op to call the method");
-        return invertedMap.compose(viewSizes);
+        assert(opOperand->getOwner() == this->getOperation());
+        Type t = opOperand->get().getType();
+        // A VectorType is an elemental type: its rank is not the operand's.
+        if (isa<VectorType>(t))
+          return {};
+        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+                 "expected a ranked tensor or memref in LinalgOp::getShape");
+          return shapedType.getShape();
+        }
+        return {};
       }]
     >,
     //===------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 98fb6075cbf32..364c1728715e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -25,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 926a92eff2ebb..02e62930a742d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -21,6 +21,7 @@ include "mlir/Dialect/Vector/IR/Vector.td"
 include "mlir/Dialect/Vector/IR/VectorAttributes.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/IndexingMapOpInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -33,6 +34,7 @@ include "mlir/IR/EnumAttr.td"
 // than the current set: {*, +}.
 def Vector_ContractionOp :
   Vector_Op<"contract", [
+      IndexingMapOpInterface,
       Pure,
       PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
       PredOpTrait<"third operand acc and result have same element type",
@@ -207,6 +209,16 @@ def Vector_ContractionOp :
               .template getAsValueRange<IteratorTypeAttr, IteratorType>();
       return {range.begin(), range.end()};
     }
+
+    //===------------------------------------------------------------------===//
+    // IndexingMapOpInterface interface methods implementation.
+    //===------------------------------------------------------------------===//
+    ArrayRef<int64_t> getShape(OpOperand * opOperand) {
+      Type t = opOperand->get().getType();
+      if (auto vt = dyn_cast<VectorType>(t))
+        return vt.getShape();
+      return {};
+    }
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index d81298bb4daf0..067e0511e4e75 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_interface(CopyOpInterface)
 add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(DestinationStyleOpInterface)
 add_mlir_interface(FunctionInterfaces)
+add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)

diff  --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
new file mode 100644
index 0000000000000..40252613a21f4
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.h
@@ -0,0 +1,27 @@
+//===- IndexingMapOpInterface.h ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
+#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+/// Verify that `op` conforms to the invariants of IndexingMapOpInterface.
+LogicalResult verifyIndexingMapOpInterface(Operation *op);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/IndexingMapOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE_H_

diff  --git a/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
new file mode 100644
index 0000000000000..fdcc183d99215
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/IndexingMapOpInterface.td
@@ -0,0 +1,153 @@
+//===- IndexingMapOpInterface.td - Interface Declaration -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the IndexingMapOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
+#define MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
+  let description = [{
+    Interface for operations that connect an iteration domain to operands via
+    affine maps. Provides methods to access indexing maps between iteration
+    domain and operand index spaces.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing maps attribute within the current operation.
+      }],
+      /*retTy=*/"ArrayAttr",
+      /*methodName=*/"getIndexingMaps"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing maps within the current operation.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getIndexingMapsArray",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = $_op.getIndexingMaps()
+          .template getAsValueRange<AffineMapAttr>();
+        return {range.begin(), range.end()};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the input or output indexing map for `opOperand`.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getMatchingIndexingMap",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        auto indexingMaps =
+          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
+        return *(indexingMaps.begin() + opOperand->getOperandNumber());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Hook to provide a custom AffineMap used to compute all the operand
+        subshapes given loop bounds. This is used to answer the question: "given
+        an iteration space over the codomain, what are the subshapes of the
+        operands involved in the computation".
+        The default behavior is to just concatenate all the indexing maps.
+        A custom AffineMap allows providing a map that can be used to
+        compute subshapes even in cases where the concatenation of indexing maps
+        (i.e. the data traversal order) is not a simple permutation of the loop
+        traversal order. It is then possible to define ops with skewed data
+        traversal order for which we can still easily compute hyperrectangular
+        loop bounds and subviews.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getLoopsToShapesMap",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto maps =  $_op.getIndexingMapsArray();
+        return concatAffineMaps(maps, $_op.getContext());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Hook to provide a custom AffineMap used to construct the
+        hyperrectangular loop iteration space given all the operand subshapes.
+        This is used to answer the question:
+        "Given a list of operand ranges, what is the subportion of the iteration
+        space involved in the computation".
+        This is the inverse problem of `getLoopsToShapesMap`.
+        Return the empty AffineMap when such an AffineMap cannot be constructed.
+        The default behavior is based on a very simple inference procedure that
+        only works with permutation affine maps.
+        A more advanced Tensor-Comprehension like inference is possible but has
+        proven to be ambiguous in unfavorable case.
+        A safer and more robust alternative is to allow each op to define
+        its own AffineMap.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getShapesToLoopsMap",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return inversePermutation($_op.getLoopsToShapesMap());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the static shape of `opOperand`, with ShapedType::kDynamic for
+        dimensions that are not statically known. This is op-specific behavior;
+        the default copies the result of the op's `getShape(opOperand)`.
+      }],
+      /*retTy=*/"SmallVector<int64_t>",
+      /*methodName=*/"getStaticOperandShape",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t> res;
+        llvm::append_range(res, $_op.getShape(opOperand));
+        return res;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns loop ranges by composing `getShapesToLoopsMap()` with the
+        flattened list of operand shapes.
+        Returns ShapedType::kDynamic for non-statically-known loop ranges.
+      }],
+      /*retTy=*/"SmallVector<int64_t>",
+      /*methodName=*/"getStaticLoopRanges",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t> allShapesSizes;
+        for (OpOperand &opOperand : this->getOperation()->getOpOperands())
+          llvm::append_range(allShapesSizes, $_op.getShape(&opOperand));
+        AffineMap invertedMap = $_op.getShapesToLoopsMap();
+        assert(invertedMap && "expected a valid op");
+        return invertedMap.compose(allShapesSizes);
+      }]
+    >
+  ];
+  let extraClassDeclaration = [{
+    // Verifier implementation for IndexingMapOpInterface.
+    // This must be called manually as part of other verifiers so that the
+    // verification order, and meaningful error messages, are not preempted.
+    LogicalResult verifyImpl();
+  }];
+}
+
+#endif // MLIR_INTERFACES_INDEXING_MAP_OP_INTERFACE

diff  --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index b4aeb44ac8faf..ec433284e17ad 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
   MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRFunctionInterfaces
+  MLIRIndexingMapOpInterface
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRParser

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 139e9901b0a29..ca7f31dd6b518 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1251,38 +1251,20 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
     if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
       return failure();
 
-  // All input/output operands must be indexed.
-  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
-      linalgOp->getNumOperands())
-    return op->emitOpError("expected the number of indexing_map (")
-           << linalgOp.getIndexingMapsArray().size()
-           << ") to be equal to the number of input/output operands ("
-           << linalgOp->getNumOperands() << ")";
+  // Delayed calling of IndexingMapOpInterface::verifyImpl.
+  if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
+    return failure();
 
   // Set this flag if this op has user defined maps. This is required to guard
   // the below error condition which assume default indexing maps.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
-
-    // Symbols disallowed.
-    if (indexingMap.getNumSymbols() != 0)
-      return op->emitOpError("unexpected symbols in indexing_map #")
-             << opOperand.getOperandNumber();
-
     // Domain must be consistent.
     unsigned numLoops = linalgOp.getNumLoops();
     if (indexingMap.getNumDims() != numLoops)
       return op->emitOpError("expected indexing_map #")
              << opOperand.getOperandNumber() << " to have " << numLoops
              << " dim(s) to match the number of loops";
-
-    int64_t rank = linalgOp.getRank(&opOperand);
-
-    if (indexingMap.getNumResults() != rank)
-      return op->emitOpError("expected operand rank (")
-             << rank << ") to match the result rank of indexing_map #"
-             << opOperand.getOperandNumber() << " ("
-             << indexingMap.getNumResults() << ")";
   }
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
@@ -1290,67 +1272,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   if (!linalgOp.getShapesToLoopsMap())
     return op->emitOpError("expected the shape-to-loops map to be non-null");
 
-  // Check if given shapes match to inferred shapes.
-  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
-  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-  // Verify only static cases since we can't get exact dimension sizes and
-  // loop ranges for dynamic cases in this stage.
-  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
-    for (int64_t &range : endLoopRangeValues)
-      range -= 1;
-    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
-      SmallVector<int64_t, 4> startIndices =
-          indexingMap.compose(startLoopRangeValues);
-      SmallVector<int64_t, 4> endIndices =
-          indexingMap.compose(endLoopRangeValues);
-      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
-      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
-        // Ignore dynamic dimension or the case that the dimension size is 0
-        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
-          continue;
-
-        // The first index or last index should be the maximum or the minimum in
-        // the inferred index ranges since the range is increasing or
-        // decreasing. The size of dimensions of input/output operands and the
-        // maximum value + 1 in the inferred range should be the same. But, for
-        // now we check if the inferred ranges are in boundary of input/output
-        // operands' size or not in case that Affine Expressions are complicated
-        // such as d0 * 3
-        // + d1 since it is not easy to handle the issues.
-        // Found the case that this solution can't check, for example, (d0, d1)
-        // -> (d1 - d0)
-        int64_t inferredDimSize =
-            std::max(startIndices[dim], endIndices[dim]) + 1;
-        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
-          std::string mapStr;
-          {
-            llvm::raw_string_ostream os(mapStr);
-            os << indexingMap;
-          }
-          return op->emitOpError(
-                     "unexpected result less than 0 at expression #")
-                 << dim << " in " << mapStr;
-        }
-        if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
-          if (inferredDimSize != shape[dim]) {
-            return op->emitOpError("inferred input/output operand #")
-                   << opOperand.getOperandNumber() << " has shape's dimension #"
-                   << dim << " to be " << inferredDimSize << ", but found "
-                   << shape[dim];
-          }
-        } else {
-          if (inferredDimSize > shape[dim]) {
-            return op->emitOpError("inferred input/output operand #")
-                   << opOperand.getOperandNumber() << " has shape's dimension #"
-                   << dim << " to be greater than or equal to "
-                   << inferredDimSize << ", but found " << shape[dim];
-          }
-        }
-      }
-    }
-  }
-
   // Check the region has exactly one block.
   if (linalgOp->getNumRegions() != 1 ||
       !llvm::hasSingleElement(linalgOp->getRegion(0)))

diff  --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
index f56ef485069f8..8d6d9dc690b55 100644
--- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 5e6dde36d7f9f..c5d9a729a4136 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -398,7 +398,10 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
     return rewriter.notifyMatchFailure(genericOp,
                                        "invalid indexing maps for operation");
   }
-  SmallVector<int64_t> dims = genericOp.getStaticShape();
+
+  SmallVector<int64_t> allShapesSizes;
+  for (OpOperand &opOperand : genericOp->getOpOperands())
+    llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand));
 
   // 1a. Get the allowed list of dimensions to drop from the `options`.
   SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
@@ -411,7 +414,7 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
   llvm::SmallDenseSet<unsigned> unitDims;
   for (const auto &expr : enumerate(invertedMap.getResults())) {
     if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
-      if (dims[dimExpr.getPosition()] == 1 &&
+      if (allShapesSizes[dimExpr.getPosition()] == 1 &&
           unitDimsFilter.count(expr.index()))
         unitDims.insert(expr.index());
     }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ff28bd7c48342..ff8e0b8977ae8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -31,6 +31,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/STLExtras.h"
@@ -2217,7 +2218,9 @@ static LogicalResult vectorizeLinalgOpPrecondition(
     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
   // tensor with dimension of 0 cannot be vectorized.
-  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
+  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
+        return llvm::is_contained(linalgOp.getShape(&operand), 0);
+      }))
     return failure();
   // Check API contract for input vector sizes.
   if (!inputVectorSizes.empty() &&

diff  --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 204462ffd047c..d464230c87723 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRVectorDialect
   MLIRDataLayoutInterfaces
   MLIRDestinationStyleOpInterface
   MLIRDialectUtils
+  MLIRIndexingMapOpInterface
   MLIRIR
   MLIRMaskableOpInterface
   MLIRMaskingOpInterface

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ee9ab61b670c4..5e0f36064be3b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1063,7 +1063,8 @@ LogicalResult ContractionOp::verify() {
   if (!isSupportedCombiningKind(getKind(), elementType))
     return emitOpError("unsupported contraction type");
 
-  return success();
+  // Delayed calling of IndexingMapOpInterface::verifyImpl.
+  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
 }
 
 // MaskableOpInterface methods.

diff  --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index a25694cfff5f2..af923d98c76ff 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ set(LLVM_OPTIONAL_SOURCES
   DestinationStyleOpInterface.cpp
   FunctionImplementation.cpp
   FunctionInterfaces.cpp
+  IndexingMapOpInterface.cpp
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
@@ -62,6 +63,7 @@ add_mlir_library(MLIRFunctionInterfaces
   MLIRIR
 )
 
+add_mlir_interface_library(IndexingMapOpInterface)
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 

diff  --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
new file mode 100644
index 0000000000000..f3c12aed8df84
--- /dev/null
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -0,0 +1,125 @@
+//===- IndexingMapOpInterface.cpp -- IndexingMapOpInterface impl ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/IndexingMapOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/IndexingMapOpInterface.cpp.inc"
+} // namespace mlir
+
+LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
+  // Require one indexing map per input/output operand (counts must match).
+  if (static_cast<int64_t>(getIndexingMapsArray().size()) !=
+      getOperation()->getNumOperands())
+    return this->emitOpError("expected the number of indexing_map (")
+           << getIndexingMapsArray().size()
+           << ") to be equal to the number of input/output operands ("
+           << getOperation()->getNumOperands() << ")";
+
+  AffineMap invertedMap = getShapesToLoopsMap();
+  if (!invertedMap) {
+    std::string str;
+    llvm::raw_string_ostream os(str);
+    getLoopsToShapesMap().print(os);
+    return this->emitOpError("invalid indexing maps are non-invertible: ")
+           << "(" << str << ")";
+  }
+
+  SmallVector<int64_t> endLoopRangeValues = getStaticLoopRanges();
+
+  // Per-operand checks: no symbols, one dim per loop, rank == #map results.
+  // NOTE(review): these run after getStaticLoopRanges(); ensure that is safe.
+  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+    AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
+
+    // Indexing maps may not use symbols.
+    if (indexingMap.getNumSymbols() != 0)
+      return getOperation()->emitOpError("unexpected symbols in indexing_map #")
+             << opOperand.getOperandNumber();
+
+    // The map's domain must have one dim per loop.
+    if (indexingMap.getNumDims() != endLoopRangeValues.size())
+      return getOperation()->emitOpError("expected indexing_map #")
+             << opOperand.getOperandNumber() << " to have "
+             << endLoopRangeValues.size()
+             << " dim(s) to match the number of loops";
+
+    SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
+    int64_t rank = shape.size();
+
+    if (indexingMap.getNumResults() != rank)
+      return getOperation()->emitOpError("expected operand rank (")
+             << rank << ") to match the result rank of indexing_map #"
+             << opOperand.getOperandNumber() << " ("
+             << indexingMap.getNumResults() << ")";
+  }
+
+  // Check the static operand shapes against the inferred loop ranges.
+  SmallVector<int64_t> startLoopRangeValues(endLoopRangeValues.size(), 0);
+  // Only checkable when every loop range is static; dynamic sizes are not
+  // known at verification time.
+  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
+    // Convert each loop size into its last (inclusive) index.
+    for (int64_t &range : endLoopRangeValues)
+      range -= 1;
+    for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+      AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
+      SmallVector<int64_t> startIndices =
+          indexingMap.compose(startLoopRangeValues);
+      SmallVector<int64_t> endIndices = indexingMap.compose(endLoopRangeValues);
+      SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
+      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
+        // Skip dynamic and zero-sized operand dimensions.
+        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
+          continue;
+
+        // Each result is evaluated only at the first and last loop indices;
+        // for expressions that only increase or only decrease with the loops,
+        // those two points bound the result's range. For a plain dim
+        // expression, the inferred size (larger endpoint + 1) must equal the
+        // operand dim; for a compound one such as d0 * 3 + d1 it only has to
+        // fit within the operand dim. A negative endpoint is always rejected.
+        // Known gap: results that mix increasing and decreasing terms, such
+        // as (d0, d1) -> (d1 - d0), reach their extremes away from those two
+        // points, so out-of-range values there are not detected.
+        // NOTE(review): confirm that gap is acceptable for all implementers.
+        int64_t inferredDimSize =
+            std::max(startIndices[dim], endIndices[dim]) + 1;
+        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
+          std::string mapStr;
+          {
+            llvm::raw_string_ostream os(mapStr);
+            os << indexingMap;
+          }
+          return this->emitOpError(
+                     "unexpected result less than 0 at expression #")
+                 << dim << " in " << mapStr;
+        }
+        if (isa<AffineDimExpr>(indexingMap.getResult(dim))) {
+          if (inferredDimSize != shape[dim]) {
+            return this->emitOpError("inferred input/output operand #")
+                   << opOperand.getOperandNumber() << " has shape's dimension #"
+                   << dim << " to be " << inferredDimSize << ", but found "
+                   << shape[dim];
+          }
+        } else {
+          if (inferredDimSize > shape[dim]) {
+            return this->emitOpError("inferred input/output operand #")
+                   << opOperand.getOperandNumber() << " has shape's dimension #"
+                   << dim << " to be greater than or equal to "
+                   << inferredDimSize << ", but found " << shape[dim];
+          }
+        }
+      }
+    }
+  }
+
+  return success();
+}

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c0c5f785e856b..ca40301f04fa1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -151,7 +151,7 @@ func.func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off
 // -----
 
 func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>, %arg1: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{expected the shape-to-loops map to be non-null}}
+  // expected-error @+1 {{invalid indexing maps are non-invertible: ((d0, d1) -> (d0 + d1, d0 + d1))}}
   linalg.generic {
     indexing_maps =  [
       affine_map<(i, j) -> (i + j)>,


        


More information about the Mlir-commits mailing list