[Mlir-commits] [mlir] [MLIR][Linalg] Remove/update failing obsolete OpDSL tests for linalg.matmul. (PR #115319)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 7 05:49:55 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Md Asghar Ahmad Shahid (shahidact)

<details>
<summary>Changes</summary>

The earlier PR (https://github.com/llvm/llvm-project/pull/104783), which introduced
transpose and broadcast semantics to linalg.matmul, was reverted because two
OpDSL tests for linalg.matmul failed.

Since linalg.matmul is now defined using TableGen ODS instead of the Python-based OpDSL,
these tests started failing and need to be removed or updated.

This commit removes or updates the failing obsolete tests in the files below. All other
files were part of the earlier PR and are simply cherry-picked.
    "mlir/test/python/integration/dialects/linalg/opsrun.py"
    "mlir/test/python/integration/dialects/transform.py"

---

Patch is 67.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115319.diff


16 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+10) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-72) 
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+134) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+12-5) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+251-12) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp (+7) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5) 
- (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+6) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-17) 
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+111) 
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+159) 
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+243) 
- (modified) mlir/test/python/dialects/linalg/ops.py (-75) 
- (modified) mlir/test/python/integration/dialects/linalg/opsrun.py (-115) 
- (modified) mlir/test/python/integration/dialects/transform.py (+14-12) 
- (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+5-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index b81a4c9c8760cf..c0eff99c850752 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -708,6 +708,16 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the user has supplied explicit indexing maps for this op.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUserDefinedMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index bf2f26de26e9ed..ee88ca516de6ff 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1065,78 +1065,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: matmul
-  cpp_class_name: MatmulOp
-  doc: |-
-    Performs a matrix multiplication of two 2D inputs.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
-  iterator_types:
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                attr_name: cast
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_matmul
   cpp_class_name: QuantizedMatmulOp
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index c2fee8ea55c960..2b47414ff5e924 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -554,6 +554,140 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Op definition for MatmulOp
+//===----------------------------------------------------------------------===//
+
+def MatmulOp : LinalgStructuredBase_Op<"matmul", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+    
+  let summary = [{
+    Performs a matrix multiplication of two 2D inputs without broadcast or transpose.
+    }];
+  let description = [{
+    Numeric casting is performed on the operands to the inner multiply,
+    promoting them to the same data type as the accumulator/output.
+
+    Broadcast and transpose semantics can be applied by specifying the explicit attribute
+    'indexing_maps' as shown below. This is a list attribute, so if it is specified the list
+    must include all the maps.
+
+    Example Transpose:
+    ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                   ]
+                   ins(%arg0, %arg1 : memref<5x3xf32>,memref<5x7xf32>)
+                   outs(%arg2: memref<3x7xf32>)
+     ```
+
+    Example Broadcast:
+     ```
+    linalg.matmul indexing_maps = [
+                   affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                   affine_map<(d0, d1, d2) -> (d2, d1)>,
+                   affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+     ```
+
+     Example Broadcast and transpose:
+     ```
+     linalg.matmul indexing_maps = [
+                       affine_map<(d0, d1, d2) -> (d2, d0)>, // transpose
+                       affine_map<(d0, d1, d2) -> (d2)>,     // broadcast
+                       affine_map<(d0, d1, d2) -> (d0, d1)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5x3xf32>, memref<5xf32>) outs(%arg2: memref<3x7xf32>)
+    }];
+
+    let arguments = (ins
+      Variadic<AnyType>:$inputs,
+      Variadic<AnyShaped>:$outputs,
+      DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+      DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilder<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+            "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        buildStructuredOp($_builder, $_state, resultTensorTypes,
+          inputs, outputs, attributes, MatmulOp::getRegionBuilder());
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+            CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>,
+      OpBuilder<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+       "ValueRange":$outputs,
+       "Attribute":$cast, CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("cast", cast);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
+          attributes, MatmulOp::getRegionBuilder());
+      }]>
+
+    ];
+    let hasCustomAssemblyFormat = 1;
+    let hasFolder = 1;
+    let hasVerifier = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{
+      SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+      /// Implements the block region builder.
+      static void regionBuilder(ImplicitLocOpBuilder &b,
+                                Block &block, ArrayRef<NamedAttribute> attrs);
+
+      /// Returns a list of AffineMap with the typical matmul indexing characteristic.
+      SmallVector<AffineMap> getDefaultIndexingMaps();
+
+      /// Returns true if the given broadcast map \p bcastMap is valid for this op.
+      bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
+
+      static std::function<void(ImplicitLocOpBuilder &,
+                                Block &, ArrayRef<NamedAttribute>)>
+      getRegionBuilder() {
+        return regionBuilder;
+      }
+
+      ::mlir::MutableOperandRange getDpsInitsMutable() {
+        return getOutputsMutable();
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      bool hasDynamicIndexingMaps();
+      /// Check if the op has broadcast and/or transpose semantics. Returns true if the
+      /// user-defined indexing maps are not equal to the default maps.
+      bool hasUserDefinedMaps();
+    }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index bd77965194b27f..0cffadf8fb64a0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -15,14 +15,21 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <numeric>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1211,7 +1218,6 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
-
   // Mixed tensor/buffer operands are not allowed.
   if (!linalgOp.hasPureTensorSemantics() &&
       !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
@@ -1231,6 +1237,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
            << ") to be equal to the number of input/output operands ("
            << linalgOp->getNumOperands() << ")";
 
+  // Set this flag if this op has user-defined maps. This is required to guard
+  // the error condition below, which assumes default indexing maps.
   for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
@@ -1247,13 +1255,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
              << " dim(s) to match the number of loops";
 
     int64_t rank = linalgOp.getRank(&opOperand);
+
     if (indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
   }
-
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
@@ -1263,9 +1271,8 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
   SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
-
-  // Verify only static cases since we can't get exact dimension sizes and loop
-  // ranges for dynamic cases in this stage.
+  // Verify only static cases since we can't get exact dimension sizes and
+  // loop ranges for dynamic cases in this stage.
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 730c478c2883ef..c909d13e4314b4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -37,12 +38,17 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cassert>
 #include <optional>
 
 using namespace mlir;
@@ -149,15 +155,36 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   // iterator_types is an auto-generated method.
 }
 
+/// Helper to create a typical indexing map for MatmulOp. Returns a list of
+/// AffineMap.
+static SmallVector<AffineMap, 3>
+getDefaultIndexingMapsForMatmul(MLIRContext *context) {
+  AffineExpr d0, d1, d2;
+  SmallVector<AffineMap, 3> indexingMaps;
+  bindDims(context, d0, d1, d2);
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
+  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
+  return indexingMaps;
+}
+
+/// Wrapper to return the typical indexing map array attribute for MatmulOp.
+static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
+  return llvm::map_to_vector(
+      getDefaultIndexingMapsForMatmul(context),
+      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+}
+
 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
 /// The result types are derived automatically if `resultTensorTypes` is none.
 /// The body of the operation is filled using `regionBuilder`. All ods-gen
 /// created structured operations use the method to implement their builders.
-static void buildStructuredOp(OpBuilder &b, OperationState &state,
-                              std::optional<TypeRange> resultTensorTypes,
-                              ValueRange inputs, ValueRange outputs,
-                              ArrayRef<NamedAttribute> attributes,
-                              RegionBuilderFn regionBuilder) {
+static void buildStructuredOp(
+    OpBuilder &b, OperationState &state,
+    std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
+    ValueRange outputs, ArrayRef<NamedAttribute> attributes,
+    RegionBuilderFn regionBuilder,
+    std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
   // Derive the result types if needed.
   SmallVector<Type> derivedResultTypes =
       resultTensorTypes.value_or(TypeRange());
@@ -168,6 +195,20 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   state.addOperands(inputs);
   state.addOperands(outputs);
   state.addTypes(derivedResultTypes);
+
+  // Initialize indexingMaps for MatmulOp.
+  SmallVector<Attribute, 3> indexingMapsAttrVal;
+  if (indexingMaps.has_value()) {
+    for (mlir::AffineMap map : *indexingMaps) {
+      // Convert each AffineMap to an AffineMapAttr
+      indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
+    }
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  } else {
+    indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
+    state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
+  }
+
   state.addAttributes(attributes);
   state.addAttribute(
       "operandSegmentSizes",
@@ -299,11 +340,48 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result,
                                           unsigned numRegionArgs,
                                           RegionBuilderFn regionBuilder) {
+
+  SmallVector<Attribute, 3> indexingMapsAttr;
+  Attribute mapAttr;
+  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
+    if (parser.parseEqual())
+      return failure();
+
+    if (parser.parseLSquare())
+      return failure();
+
+    do {
+      if (parser.parseAttribute(mapAttr))
+        return failure();
+      if (!isa<AffineMapAttr>(mapAttr)) {
+        return parser.emitError(parser.getCurrentLocation(),
+                                "expected affine map attribute");
+      }
+      indexingMapsAttr.push_back(mapAttr);
+
+      if (parser.parseOptionalComma())
+        break;
+    } while (true);
+
+    if (parser.parseRSquare())
+      return failure();
+  }
+  // Initialize indexingMaps, if not supplied explicitly.
+  if (indexingMapsAttr.empty()) {
+    indexingMapsAttr = getDefaultIndexingMapAttr(result.getContext());
+  }
+  result.addAttribute("indexing_maps",
+                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
+
   // TODO: Enable when ods-gen supports captures.
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
 
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
   // TODO: consider merging results parsing into region parsing.
   // Need to wait for declarative assembly resolution to decide.
   SmallVector<Type, 1> outputTensorsTypes;
@@ -329,13 +407,9 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 }
 
 static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
-                                   ValueRange inputs, ValueRange outputs) {
-  p.printOptionalAttrDict(
-      op->getAttrs(),
-      /*elidedAttrs=*/{"operandSegmentSizes",
-                       // See generated code in
-                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
-                       "linalg.memoized_indexing_maps"});
+                                   ValueRange inputs, ValueRange outputs,
+                                   ArrayRef<StringRef> elidedAttrs = {}) {
+  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
   // Printing is shared with generic ops, except for the region and
   // attributes.
@@ -3382,3 +3456,168 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                               Location loc) {
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
+
+/// Returns true if every result AffineExpr of \p explictMap also appears in
+/// the results of \p defaultMap.
+static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
+  auto explicitRange = explictMap.getResults();
+  auto defaultRange = defaultMap.getResults();
+  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
+  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
+  llvm::set_union(explicitSet, defaultSet);
+  return explicitSet == defaultSet;
+}
+
+/// Returns true if the \p explictMap is broadcasted with respect to the
+/// \p defaultMap.
+static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
+  return explictMap.getNumResults() < defaultMap.getNumResults();
+}
+
+/// Verifies the broadcast and transpose semantics specified by the explicit
+/// indexing map of the MatmulOp \p matmulOp for the operand specified by \p
+/// opIndex.
+static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
+                                                  unsigned opIndex) {
+  SmallVector<AffineMap, 3> opInde...
[truncated]

``````````
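The `indexing_maps` attribute documented in the MatmulOp description determines how each point (d0, d1, d2) = (m, n, k) of the iteration space addresses A, B and C. Below is a minimal, self-contained C++ sketch of that reference semantics. It assumes dense row-major `float` buffers and projected-permutation maps only. `Tensor`, `DimList`, `applyMap` and `referenceMatmul` are hypothetical helpers, not MLIR APIs. The numeric cast step (`cast_signed` in the patch) is omitted because every operand is already `float`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of a projected-permutation indexing map: the loop dims
// (0 = d0 = m, 1 = d1 = n, 2 = d2 = k) named by the map's results, in order.
using DimList = std::vector<int>;

// Hypothetical dense row-major buffer of rank 1 or 2 (not an MLIR type).
struct Tensor {
  std::vector<std::size_t> shape;
  std::vector<float> data;

  // Row-major linearization of a multi-dimensional index.
  std::size_t offset(const std::vector<std::size_t> &idx) const {
    std::size_t off = 0;
    for (std::size_t i = 0; i < idx.size(); ++i)
      off = off * shape[i] + idx[i];
    return off;
  }
  float get(const std::vector<std::size_t> &idx) const { return data[offset(idx)]; }
  float &ref(const std::vector<std::size_t> &idx) { return data[offset(idx)]; }
};

// Apply a map to the induction variables (m, n, k) of one loop iteration.
static std::vector<std::size_t> applyMap(const DimList &map,
                                         const std::array<std::size_t, 3> &iv) {
  std::vector<std::size_t> idx;
  for (int d : map)
    idx.push_back(iv[static_cast<std::size_t>(d)]);
  return idx;
}

// C[mapC(m,n,k)] += A[mapA(m,n,k)] * B[mapB(m,n,k)] for all m < M, n < N, k < K.
static void referenceMatmul(const Tensor &a, const DimList &mapA,
                            const Tensor &b, const DimList &mapB, Tensor &c,
                            const DimList &mapC, std::size_t M, std::size_t N,
                            std::size_t K) {
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t k = 0; k < K; ++k) {
        const std::array<std::size_t, 3> iv = {m, n, k};
        c.ref(applyMap(mapC, iv)) +=
            a.get(applyMap(mapA, iv)) * b.get(applyMap(mapB, iv));
      }
}
```

With the default maps `{0, 2}`, `{2, 1}`, `{0, 1}` this is an ordinary matmul. Passing `{2, 0}` for A reads A as a K x M buffer (the transpose case). Passing `{2}` reads A as a length-K vector reused for every row m (the broadcast case), so the vector's length must equal K.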

</details>
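In the `LinalgOps.cpp` hunk of the patch, `isValidResultDimExprs` calls `llvm::set_union(explicitSet, defaultSet)`. This inserts the default map's results into the explicit set, and the function then compares that union with the default set. The comparison holds exactly when every result of the explicit map already appears in the default map. It is therefore a subset test rather than an equality test, which is what lets a transposed map `(d2, d0)` or a broadcast map `(d2)` pass against the default LHS map `(d0, d2)`.

The sketch below restates that logic with `std::set`. It assumes that each result expression can be modelled by the single loop dimension it names. `ResultDims`, `unionThenCompare`, `isSubsetOfDefault` and `isBroadcastedSketch` are hypothetical names, not MLIR functions.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>

// Hypothetical stand-in for the results of a projected-permutation AffineMap:
// each entry is the loop dimension a result names (0 = d0, 1 = d1, 2 = d2).
using ResultDims = std::vector<int>;

// Mirrors the union-then-compare check: explicit := explicit U default, then
// test whether the union is still equal to the default set.
static bool unionThenCompare(const ResultDims &explicitMap,
                             const ResultDims &defaultMap) {
  std::set<int> explicitSet(explicitMap.begin(), explicitMap.end());
  const std::set<int> defaultSet(defaultMap.begin(), defaultMap.end());
  explicitSet.insert(defaultSet.begin(), defaultSet.end());
  return explicitSet == defaultSet;
}

// Equivalent direct formulation: explicit's result dims are a subset of default's.
static bool isSubsetOfDefault(const ResultDims &explicitMap,
                              const ResultDims &defaultMap) {
  const std::set<int> explicitSet(explicitMap.begin(), explicitMap.end());
  const std::set<int> defaultSet(defaultMap.begin(), defaultMap.end());
  return std::includes(defaultSet.begin(), defaultSet.end(),
                       explicitSet.begin(), explicitSet.end());
}

// Mirrors isBroadcasted: the explicit map has fewer results than the default.
static bool isBroadcastedSketch(const ResultDims &explicitMap,
                                const ResultDims &defaultMap) {
  return explicitMap.size() < defaultMap.size();
}
```

Because `std::set` ignores order, this check alone does not distinguish `(d0, d2)` from `(d2, d0)`. Whether a map is a transpose is a separate question from whether its result dims are admissible. A map such as `(d1, d2)` on the LHS is rejected because `d1` never indexes the LHS in the default maps.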


https://github.com/llvm/llvm-project/pull/115319