[Mlir-commits] [mlir] [MLIR][Linalg] Introduce linalg.contract (PR #123618)

llvmlistbot at llvm.org
Mon Jan 20 06:13:32 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Rolf Morel (rolfmorel)

<details>
<summary>Changes</summary>

Adds `linalg.contract`, a new op that represents arbitrary contractions on operands of arbitrary rank, with arbitrary transposes and broadcasts specified through its `indexing_maps` attribute.

The op supports the expected lowerings to `linalg.generic` and to `vector.contract`.

Corresponding RFC is here: https://discourse.llvm.org/t/mlir-rfc-introduce-linalg-contract/83589
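
To make the contraction semantics concrete, here is a minimal standalone C++ sketch of what the op computes. It does not use any MLIR API. `IndexingMap`, `Tensor`, `offsetOf`, `contract`, and `matmulExample` are hypothetical names introduced only for illustration, and each projected-permutation `affine_map` is modeled as a plain list of iteration-space dim positions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a projected-permutation affine_map: result r of
// the map reads iteration-space dim `map[r]`.
using IndexingMap = std::vector<std::size_t>;

// Dense row-major tensor of floats (illustrative only, not an MLIR type).
struct Tensor {
  std::vector<std::size_t> shape;
  std::vector<float> data;
};

// Row-major linear offset of the element that `map` selects at point `iv`.
static std::size_t offsetOf(const Tensor &t, const IndexingMap &map,
                            const std::vector<std::size_t> &iv) {
  assert(map.size() == t.shape.size() && "map results must match rank");
  std::size_t offset = 0;
  for (std::size_t r = 0; r < map.size(); ++r)
    offset = offset * t.shape[r] + iv[map[r]];
  return offset;
}

// Visits every point of the iteration space (per-dim sizes in `extents`) and
// accumulates A[I] * B[J] into C[H]. Dims absent from `mapC` are reduced.
void contract(const Tensor &a, const IndexingMap &mapA, const Tensor &b,
              const IndexingMap &mapB, Tensor &c, const IndexingMap &mapC,
              const std::vector<std::size_t> &extents) {
  std::size_t numPoints = 1;
  for (std::size_t extent : extents)
    numPoints *= extent;

  std::vector<std::size_t> iv(extents.size(), 0);
  for (std::size_t point = 0; point < numPoints; ++point) {
    // Decode the linear point number into one index per iteration dim.
    std::size_t rem = point;
    for (std::size_t d = extents.size(); d-- > 0;) {
      iv[d] = rem % extents[d];
      rem /= extents[d];
    }
    c.data[offsetOf(c, mapC, iv)] +=
        a.data[offsetOf(a, mapA, iv)] * b.data[offsetOf(b, mapB, iv)];
  }
}

// 2x2 matmul with maps (m, n, k) -> (m, k), (k, n), (m, n).
std::vector<float> matmulExample() {
  Tensor a{{2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}};
  Tensor b{{2, 2}, {5.0f, 6.0f, 7.0f, 8.0f}};
  Tensor c{{2, 2}, {0.0f, 0.0f, 0.0f, 0.0f}};
  contract(a, {0, 2}, b, {2, 1}, c, {0, 1}, {2, 2, 2});
  return c.data;
}
```

In this model, passing `{2, 0}` instead of `{0, 2}` for `mapA` would read `A` transposed, which is the sense in which permuting a map's results expresses a transpose.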

---

Patch is 46.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123618.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+118) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+198-26) 
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+165-3) 
- (modified) mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir (+50) 
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+97) 
- (modified) mlir/test/Dialect/Linalg/loops.mlir (+47) 
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+21-2) 
- (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+14-3) 
- (modified) mlir/test/Dialect/Linalg/tile-tensors.mlir (+46) 
- (modified) mlir/test/Dialect/Linalg/transform-op-vectorize.mlir (+33) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fff4048ee125e0..d4277bd34f3946 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -680,6 +680,124 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
     }];
 }
 
+//===----------------------------------------------------------------------===//
+// Contract op.
+//===----------------------------------------------------------------------===//
+
+def ContractOp : LinalgStructuredBase_Op<"contract", [
+               AttrSizedOperandSegments,
+               LinalgContractionOpInterface]> {
+  let summary = [{
+    Perform a contraction on two inputs, accumulating on top of a third.
+  }];
+  let description = [{
+    The semantics of contracting inputs `A` and `B` on top of `C` to produce
+    output `D` is given by
+
+      `D[H] = (SUM_{(I ∪ J) \ H} A[I] * B[J]) + C[H]`
+
+    where `I`, `J`, and `H` are multi-indices, i.e. sequences/ordered sets of
+    dimension identifiers (meant to range over valid indices), corresponding to
+    the co-domains of the (projected permutation) `indexing_maps` of `A`, `B`
+    and `C`, respectively. `SUM_{dims}` means reduce over all valid indices for
+    the dimensions in the set `dims`.
+
+    The iteration space consists of all dimensions in `I`, `J` and `H`, i.e. the
+    domain of each of the `affine_map`s. Like for einsums, the iteration type of
+    each dim is inferred and is either:
+
+    - reduction: the dim occurs in (the multi-index of) `A` and `B` but not `C`.
+      Per the above semantics, these dims will be contracted, i.e. reduced over.
+
+    - parallel: the dim occurs in `C` and at least one of `A` and `B`, and -
+      deriving from matmul terminology - is either an "M-like" dim (if in `A`
+      and `C`), an "N-like" dim (if in `B` and `C`) or a "batch"-dim (if in `A`,
+      `B`, and `C`).
+
+    For example, batch-matmul is given by `I = ⟨ b, m, k ⟩`, `J = ⟨ b, k, n ⟩`,
+    `H = ⟨ b, m, n ⟩` (with `k` as a contracting reduction-dimension while `m`,
+    `n` and `b` are of parallel iteration-type) and gets represented as:
+
+    ```
+    %0 = linalg.contract
+        indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                         affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                         affine_map<(batch, m, n, k) -> (batch, m, n)>]
+        ins(%arg0, %arg1: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+        outs(%arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+    ```
+
+    Note that by permuting the dims in the co-domains of the `affine_map`s, we
+    can apply arbitrary transposes to the inputs and output. Similarly,
+    arbitrary broadcasts can be achieved through leaving out dims on either
+    input operand.
+
+    Numeric casting is performed on the operands to the inner multiplication,
+    promoting them to the same data type as the accumulator/output.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyType>:$inputs,
+    Variadic<AnyShaped>:$outputs,
+    AffineMapArrayAttr:$indexing_maps
+  );
+  let results = (outs Variadic<AnyShaped>:$result_tensors);
+  let regions = (region SizedRegion<1>:$combiner);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, resultTensorTypes, inputs,
+                          outputs, attributes, regionBuilder);
+      }]>,
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
+      "ArrayAttr":$indexingMaps,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
+      [{
+        $_state.addAttribute("indexing_maps", indexingMaps);
+        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
+                          attributes, regionBuilder);
+      }]>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare/implement functions necessary for LinalgStructuredInterface.
+    /// Infer iterator types for each dim in the domain of IndexingMaps.
+    SmallVector<utils::IteratorType> getIteratorTypesArray();
+
+    /// IndexingMaps always depends on attr associated to current Op instance.
+    bool hasDynamicIndexingMaps() { return true; };
+    bool hasUserDefinedMaps() { return true; };
+
+    static unsigned getNumRegionArgs();
+
+    static void regionBuilder(ImplicitLocOpBuilder &b,
+                              Block &block, ArrayRef<NamedAttribute> attrs);
+
+    static std::function<void(ImplicitLocOpBuilder &,
+                              Block &, ArrayRef<NamedAttribute>)>
+    getRegionBuilder() {
+      return regionBuilder;
+    }
+
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement function necessary for DestinationStyleOpInterface.
+    ::mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputsMutable();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..355ed2d269291c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3528,44 +3528,45 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
   return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
 }
 
-ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<Attribute, 3> indexingMapsAttr;
-  Attribute mapAttr;
-  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
-    if (parser.parseEqual())
-      return failure();
+FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
+  if (parser.parseOptionalKeyword("indexing_maps"))
+    return {nullptr}; // Success in case indexing_maps was not provided.
 
-    if (parser.parseLSquare())
+  SmallVector<Attribute> indexingMaps;
+
+  auto parseIndexingMap = [&]() -> ParseResult {
+    AffineMapAttr affineMapAttr;
+    if (parser.parseAttribute(affineMapAttr))
       return failure();
+    indexingMaps.push_back(affineMapAttr);
+    return success();
+  };
 
-    do {
-      if (parser.parseAttribute(mapAttr))
-        return failure();
-      if (!isa<AffineMapAttr>(mapAttr)) {
-        return parser.emitError(parser.getCurrentLocation(),
-                                "expected affine map attribute");
-      }
-      indexingMapsAttr.push_back(mapAttr);
+  if (parser.parseEqual() ||
+      parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
+                                     parseIndexingMap))
+    return failure();
 
-      if (parser.parseOptionalComma())
-        break;
-    } while (true);
+  return parser.getBuilder().getArrayAttr(indexingMaps);
+}
 
-    if (parser.parseRSquare())
-      return failure();
-  }
-  // Initialize indexingMaps, if not supplied explicitly.
-  if (indexingMapsAttr.empty()) {
-    indexingMapsAttr = llvm::map_to_vector(
+ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr))
+    return failure();
+
+  if (*indexingMapsAttr == nullptr) {
+    auto indexingMapAttrs = llvm::map_to_vector(
         MatmulOp::getDefaultIndexingMaps(parser.getContext()),
         [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
   }
-  result.addAttribute("indexing_maps",
-                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
 
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
   return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                 MatmulOp::getRegionBuilder());
 }
+
 void MatmulOp::print(OpAsmPrinter &p) {
   SmallVector<StringRef, 3> elidedAttrs = {
       "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
@@ -3599,6 +3600,7 @@ LogicalResult MatmulOp::verify() {
 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   return memref::foldMemRefCast(*this);
 }
+
 void MatmulOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -3611,5 +3613,175 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
   return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
 }
 
+//===----------------------------------------------------------------------===//
+// ContractOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
+  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
+  /// On well-formed IR, indexing_maps is non-empty, contained affine_maps'
+  /// domains are all the same, and each implements a projected permutation.
+  /// Each dim in the domain must occur for at least one operand and is
+  /// classified as either batch, N-like, M-like, or K-like. Only the latter
+  /// corresponds to a reduction _and_ it is the only dim-kind which does not
+  /// occur for the output operand. We use this fact for fast inference:
+  // NB: In case we allow dims to occur solely for one input, the above still
+  //     holds: per the einsum semantics, these are reduction dims as well.
+  auto dimsInOutput = SmallVector<bool>(outAffineMap.getNumDims(), false);
+  for (auto result : outAffineMap.getResults()) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(result);
+    assert(dimExpr && "affine_map is a projected permutation");
+    dimsInOutput[dimExpr.getPosition()] = true;
+  }
+
+  SmallVector<utils::IteratorType> iteratorTypes;
+  for (auto dimOccursInOutput : dimsInOutput)
+    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
+                                              : utils::IteratorType::reduction);
+
+  return iteratorTypes;
+}
+
+unsigned ContractOp::getNumRegionArgs() { return 3; }
+
+/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
+void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
+                               ArrayRef<NamedAttribute> attrs) {
+  assert(block.getNumArguments() == 3 &&
+         "ContractOp regionBuilder expects 3 args");
+  RegionBuilderHelper helper(b, block);
+
+  TypeFn castSignedness = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castSignedness = attr.getValue();
+  }
+
+  // TODO: Support fields with operators besides mult & add.
+  Type outType = block.getArgument(2).getType();
+  Value lhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
+  Value rhsAtOutType =
+      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
+  Value productAtOutType =
+      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
+  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+                                      productAtOutType);
+  helper.yieldOutputs({result});
+}
+
+ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
+  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
+  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
+    return parser.emitError(parser.getCurrentLocation(),
+                            "expected 'indexing_maps' attribute");
+  result.addAttribute("indexing_maps", *indexingMapsAttr);
+
+  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
+                                regionBuilder);
+}
+
+void ContractOp::print(OpAsmPrinter &p) {
+  p << " indexing_maps = [";
+  llvm::interleaveComma(getIndexingMaps(), p,
+                        [&](Attribute attr) { p.printAttribute(attr); });
+  p << "]";
+  printNamedStructuredOp(
+      p, getOperation(), getInputs(), getOutputs(),
+      /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
+}
+
+LogicalResult ContractOp::verify() {
+  int iterationSpaceDims = -1;
+  // Maps iter space dim (as index) to num of occurrences in inputs and output.
+  SmallVector<size_t> inOccurrences;
+  SmallVector<size_t> outOccurrences;
+
+  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
+                                   bool isInput) -> LogicalResult {
+    if (iterationSpaceDims == -1) {
+      iterationSpaceDims = affineMap.getNumDims();
+      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
+    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
+      return emitError("iteration spaces of provided affine_maps differ");
+    }
+
+    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
+      if (affineMap.getNumResults() != shapedType.getRank())
+        return emitError("ranks of shaped operand and co-domain of "
+                         "corresponding affine_map differ");
+    } else if (affineMap.getNumResults() != 0) {
+      return emitError("affine_map specifies shaped access while operand has "
+                       "non-shaped type");
+    }
+
+    if (!affineMap.isProjectedPermutation())
+      return emitError("provided affine_map is not a projected permutation");
+
+    for (AffineExpr affineExpr : affineMap.getResults()) {
+      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
+      if (!affineDimExpr)
+        llvm_unreachable("affine_map is a projected permutation");
+
+      if (isInput)
+        inOccurrences[affineDimExpr.getPosition()] += 1;
+      else
+        outOccurrences[affineDimExpr.getPosition()] += 1;
+    }
+
+    return success();
+  };
+
+  for (auto &&[affineMap, operandType, isInput] :
+       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
+                 SmallVector<bool>{true, true, false}))
+    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
+      return failure(); // NOTE: checking lambda will emit error.
+
+  bool hasContractingDim = false;
+  for (auto &&[inOccCount, outOccCount] : zip(inOccurrences, outOccurrences)) {
+    hasContractingDim |= inOccCount == 2 && outOccCount == 0;
+
+    if (inOccCount == 0)
+      return emitError("iteration space dim not used by either input");
+
+    // NB: A dim which occurs for only one input operand and not for the output.
+    //     In terms of einsum semantics, such dims have a sensible meaning -
+    //     namely an additional reduction per such dim - though this can also
+    //     always be expressed through an additional op. Additionally, at time
+    //     of writing, vector.contract's verifier accepts these dims but many of
+    //     its lowerings do not handle these kinds of dims. Hence...
+    // TODO: Remove following once we have comprehensive support for input-only
+    //       reduction dims, at both the linalg- and vector-dialect levels.
+    if (inOccCount == 1 && outOccCount != 1)
+      return emitError("iter type of dim is not one of M, N, K or batch");
+  }
+
+  if (!hasContractingDim)
+    return emitError("'indexing_maps' do not specify a contracting dimension");
+
+  return success();
+}
+
+LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+void ContractOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (hasPureTensorSemantics())
+    return;
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
+}
+
+Speculation::Speculatability ContractOp::getSpeculatability() {
+  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
+}
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index aba26c35931fd3..f7e570d5ce38f0 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -943,7 +943,6 @@ func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -969,7 +968,6 @@ func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5
                       ]
                       ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
@@ -996,9 +994,173 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
                       ]
                       ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                       outs(%arg2: memref<3x7xf32>)
-                      
   return
 }
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @contract_matmul(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_matmul(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d0, d2)>,
+                    affine_map<(d0, d1, d2) -> (d2, d1)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<3x5xf32>, memref<5x7xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @contract_matmul_transpose_a_b(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: memref<5x3xf32>,
+// CHECK-SAME:                                             %[[VAL_1:.*]]: memref<7x5xf32>,
+// CHECK-SAME:                                             %[[VAL_2:.*]]: memref<3x7xf32>) {
+
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-NEXT:     ^{{.+}}(
+// CHECK-NEXT:      arith.mulf
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+
+func.func @contract_matmul_transpose_a_b(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.contract indexing_maps = [
+                    affine_map<(d0, d1, d2) -> (d2, d0)>,
+                    affine_map<(d0, d1, d2) -> (d1, d2)>,
+                    affine_map<(d0, d1, d2) -> (d0, d1)>
+                  ]
+                  ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
+                  outs(%arg2: memref<3x7xf32>)
+
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHEC...
[truncated]

``````````
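
The iterator-type inference and the dim classification that the patch describes can be sketched outside of MLIR as follows. This is an illustrative approximation under stated assumptions, not the patch's code: `IndexingMap`, `inferIteratorTypes`, `contains`, and `classifyDims` are hypothetical names, maps are modeled as lists of dim positions, and the `"rejected"` label stands in for the cases where the verifier emits an error.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a projected-permutation affine_map: the list of
// iteration-space dim positions appearing in the map's results.
using IndexingMap = std::vector<std::size_t>;

// Same idea as ContractOp::getIteratorTypesArray in the patch: a dim is
// parallel if it occurs in the output map and a reduction otherwise.
std::vector<std::string> inferIteratorTypes(const IndexingMap &outMap,
                                            std::size_t numDims) {
  std::vector<bool> dimsInOutput(numDims, false);
  for (std::size_t dim : outMap) {
    assert(dim < numDims && "map result must name an iteration-space dim");
    dimsInOutput[dim] = true;
  }
  std::vector<std::string> iteratorTypes;
  for (bool occursInOutput : dimsInOutput)
    iteratorTypes.push_back(occursInOutput ? "parallel" : "reduction");
  return iteratorTypes;
}

static bool contains(const IndexingMap &map, std::size_t dim) {
  return std::find(map.begin(), map.end(), dim) != map.end();
}

// Labels each dim using the matmul-derived terminology of the op description:
// batch (A, B, C), M (A, C), N (B, C), K (A, B). Anything else is a dim the
// verifier in the patch would reject.
std::vector<std::string> classifyDims(const IndexingMap &mapA,
                                      const IndexingMap &mapB,
                                      const IndexingMap &mapC,
                                      std::size_t numDims) {
  std::vector<std::string> kinds;
  for (std::size_t dim = 0; dim < numDims; ++dim) {
    bool inA = contains(mapA, dim);
    bool inB = contains(mapB, dim);
    bool inC = contains(mapC, dim);
    if (inA && inB)
      kinds.push_back(inC ? "batch" : "K");
    else if (inA && inC)
      kinds.push_back("M");
    else if (inB && inC)
      kinds.push_back("N");
    else
      kinds.push_back("rejected");
  }
  return kinds;
}
```

For the batch-matmul maps from the op description, with dims ordered `(batch, m, n, k)`, `classifyDims({0, 1, 3}, {0, 3, 2}, {0, 1, 2}, 4)` labels the dims batch, M, N, and K, and `inferIteratorTypes({0, 1, 2}, 4)` marks only the last dim as a reduction.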

</details>


https://github.com/llvm/llvm-project/pull/123618

