[Mlir-commits] [mlir] cd9bacd - [mlir] generalize matchers to support batch matmul

Alex Zinenko llvmlistbot at llvm.org
Fri Jul 7 07:44:51 PDT 2023


Author: Alex Zinenko
Date: 2023-07-07T14:44:44Z
New Revision: cd9bacdf7fda6fcb1f06e96f39af7f537a2542ad

URL: https://github.com/llvm/llvm-project/commit/cd9bacdf7fda6fcb1f06e96f39af7f537a2542ad
DIFF: https://github.com/llvm/llvm-project/commit/cd9bacdf7fda6fcb1f06e96f39af7f537a2542ad.diff

LOG: [mlir] generalize matchers to support batch matmul

Mostly the same logic applies, with a different rank.

Additionally, expose the logic to identify contraction dimensions and
contraction-like bodies as independent transform ops. This allows us to
recognize "generic" operations and not only the named ones.

Rework the contraction body matching logic to no longer rely on
contraction operations being uniquely named.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D154498

Added: 
    mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
    mlir/test/Integration/Dialect/Transform/match_matmul_common.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
    mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
    mlir/test/Dialect/Linalg/match-ops-invalid.mlir
    mlir/test/Integration/Dialect/Transform/match_matmul.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0562f3779e08b8..a330d9cf9fc60d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -72,8 +72,28 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
 
 namespace detail {
 
+/// Returns true if the block contains a contraction of the following form:
+///
+///   %0 = <elemwise>(permutation-of(cu(block-argument-0),
+///                                  cu(block-argument-1)))
+///   %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
+///   return-like cu(%1)
+///
+/// where <elemwise> and <reduce> are binary operations constituting a
+/// contraction (in the canonical case, <elemwise> is a multiplication and
+/// <reduce> is an addition). The name and other properties of these operations
+/// are checked by `isaPair`. All operands of all operations may be supplied
+/// through a chain of side effect-free unary operations, such as casts, which
+/// is denoted as `cu` above.
+///
+/// When the body does not contain a contraction, a more precise description of
+/// the failed precondition is sent to the `errs` stream, if provided.
+bool isContractionBody(Block &block,
+                       function_ref<bool(Operation *, Operation *)> isaPair,
+                       llvm::raw_ostream &errs = llvm::nulls());
+
 /// Result of matching a Linalg generic against the predicates of it being a
-/// contractiom.
+/// contraction.
 enum class MatchContractionResult;
 
 /// Checks whether `op` conforms to ContractionOpInterface and populates

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index 6f458d4e2e3b2a..ad348d0ce89a64 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -106,6 +106,12 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
       * `passthrough`: the body of the structured payload op only forwards
         inputs to the outputs (copy or broadcast).
 
+      * `contraction`: the body of the structured payload op is a contraction
+        of the form `<red>(<elem>(bbarg0, bbarg1), bbarg2)` where `<elem>` and
+        `<red>` are binary operations whose names are specified in the attribute
+        and operands can be permuted and optionally forwarded through a chain of
+        unary side effect-free operations.
+
   }], StructuredPredicate.extraDescription, [{
 
     #### Return modes
@@ -116,12 +122,54 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
   }]);
   let arguments = (ins TransformHandleTypeInterface:$operand_handle,
                        OptionalAttr<I64Attr>:$reduction_position,
-                       UnitAttr:$passthrough);
+                       UnitAttr:$passthrough,
+                       OptionalAttr<StrArrayAttr>:$contraction);
   let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)";
   let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
   let hasVerifier = 1;
 }
 
+def MatchStructuredClassifyContractionDimsOp
+    : Op<Transform_Dialect, "match.structured.classify_contraction_dims", [
+    SingleOpMatcher,
+    StructuredPredicate,
+    MatchOpInterface,
+    MemoryEffectsOpInterface]> {
+  let summary =
+      "Checks if an operation has contraction-like dimensions and returns them";
+  let description = !strconcat([{
+    Checks if the structured payload op has contraction-like dimensions as
+    follows:
+
+      C(batch, m, n) += A(batch, m, k) * B(batch, k, n)
+
+    That is:
+
+      - 'batch' are parallel dimensions used in inputs and result;
+      - 'm' are parallel dimensions used in the LHS and result;
+      - 'n' are parallel dimensions used in the RHS and result;
+      - 'k' are reduction dimensions present only in LHS and RHS.
+
+    Note that this doesn't check the operation in the body.
+
+  }], StructuredPredicate.extraDescription, [{
+
+    #### Return modes
+
+    Succeeds if the operation has the contraction-like dimensions, produces a
+    silenceable failure otherwise.
+  }]);
+
+  let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+  let results = (outs TransformParamTypeInterface:$batch,
+                      TransformParamTypeInterface:$m,
+                      TransformParamTypeInterface:$n,
+                      TransformParamTypeInterface:$k);
+  let assemblyFormat =
+    "$operand_handle attr-dict `:` functional-type(operands, results)";
+  let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
+}
+
 class StructuredDimDescription<string kind> {
   string description = !strconcat([{
      The following }], kind ,[{ specifications are supported:

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index b720b0502d08cd..08de32818f0047 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -52,68 +52,108 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//
 
-/// Return true if the use-def chain from `v` to `from` consists of 0 or more
-/// unary single-operand operations.
+/// If the value is defined by a chain of unary side effect-free operations, go
+/// up the use-def chain until the first value that isn't defined by such an op.
 // TODO: relax to multi-operands with constants, which are technically unary ops
 // as needed (e.g. add5).
-static bool isChainOfUnaryOpsFrom(Value v, Value from) {
-  while (true) {
-    if (v == from)
-      return true;
-    Operation *op = v.getDefiningOp();
-    if (!op || op->getNumOperands() != 1)
-      return false;
-    v = op->getOperand(0);
-  };
+static Value getSourceSkipUnary(Value value) {
+  Operation *op = value.getDefiningOp();
+  while (op && op->getNumOperands() == 1) {
+    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
+    if (!iface || !iface.hasNoEffect())
+      break;
+    value = op->getOperand(0);
+    op = value.getDefiningOp();
+  }
+  return value;
 }
 
-/// Return the unique instance of OpType in `block` if it is indeed unique.
-/// Return null if none or more than 1 instances exist.
-template <typename OpType>
-static OpType getSingleOpOfType(Block &block) {
-  OpType res = nullptr;
-  block.walk([&](OpType op) {
-    if (res) {
-      res = nullptr;
-      return WalkResult::interrupt();
-    }
-    res = op;
-    return WalkResult::advance();
-  });
-  return res;
-}
+bool mlir::linalg::detail::isContractionBody(
+    Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
+    llvm::raw_ostream &errs) {
+  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
+    errs << "no terminator in the block";
+    return false;
+  }
+
+  if (block.getNumArguments() != 3) {
+    errs << "expected block with 3 arguments";
+    return false;
+  }
+
+  Operation *terminator = block.getTerminator();
+  if (terminator->getNumOperands() != 1) {
+    errs << "expected terminator with 1 operand";
+    return false;
+  }
+
+  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
+  Operation *reductionOp = yielded.getDefiningOp();
+  if (!reductionOp || reductionOp->getNumResults() != 1 ||
+      reductionOp->getNumOperands() != 2) {
+    errs << "expected reduction op to be binary";
+    return false;
+  }
+
+  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
+  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
 
-/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
-/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
-/// unary operations that may change the type.
-template <typename AddOpType, typename MulOpType>
-static bool isAddMul(Block &block) {
-  if (block.getNumArguments() != 3)
+  if (reductionLHS != block.getArgument(2) &&
+      reductionRHS != block.getArgument(2)) {
+    errs << "expected reduction to take block argument #2 as one of the "
+            "operands (modulo unary casts)";
     return false;
-  Operation *yieldOp = block.getTerminator();
-  if (yieldOp->getNumOperands() != 1)
+  }
+
+  // The operand that is not the accumulator must be the elementwise result.
+  Value contributed = getSourceSkipUnary(
+      reductionLHS == block.getArgument(2) ? reductionRHS : reductionLHS);
+  Operation *elementwiseOp = contributed.getDefiningOp();
+  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
+      elementwiseOp->getNumOperands() != 2) {
+    errs << "expected elementwise op to be binary";
+    return false;
+  }
+
+  if (!isaPair(elementwiseOp, reductionOp)) {
+    errs << "expected reduction/elementwise op kind not satisfied";
     return false;
+  }
+
+  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
+  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
+  if ((elementwiseLHS == block.getArgument(0) &&
+       elementwiseRHS == block.getArgument(1)) ||
+      (elementwiseLHS == block.getArgument(1) &&
+       elementwiseRHS == block.getArgument(0))) {
+    return true;
+  }
 
-  AddOpType addOp = getSingleOpOfType<AddOpType>(block);
-  MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
-  if (!addOp || !mulOp)
+  errs << "expected elementwise op to apply to block arguments (modulo unary "
+          "casts)";
+  return false;
+}
+
+/// Returns true if the two operations are of the kinds specified by a pair of
+/// consecutive template arguments: (elementwise op kind, reduction op kind).
+template <typename ElemOpTy, typename RedOpTy, typename... Args>
+static bool isPairTemplateImpl(Operation *elem, Operation *red) {
+  static_assert(sizeof...(Args) % 2 == 0,
+                "expected an even number of template arguments");
+  if (isa<ElemOpTy>(elem) && isa<RedOpTy>(red))
+    return true;
+
+  if constexpr (sizeof...(Args) > 0)
+    return isPairTemplateImpl<Args...>(elem, red);
+  else
     return false;
+}
 
-  Value argA = block.getArgument(0), argB = block.getArgument(1);
-  Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
-  Value mul = mulOp->getResult(0);
-  Value argC = block.getArgument(2);
-  Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
-  Value add = addOp->getResult(0);
-  Value res = yieldOp->getOperand(0);
-  // Result traces back to add.
-  auto un = isChainOfUnaryOpsFrom;
-  bool success = un(res, add);
-  // One of the operands of add traces back to argC, the other to the mul.
-  success |= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC));
-  // One of the operands of mul traces back to argA, the other to argB.
-  success |= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA));
-  return success;
+/// Returns true if the block is a body of a contraction with the kinds of
+/// operations given pairwise by template arguments.
+template <typename... Args>
+static bool isContractionBody(Block &block) {
+  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
 }
 
 /// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
@@ -231,12 +271,17 @@ mlir::linalg::detail::isContractionInterfaceImpl(
                    [](AffineMap m) { return !m.isProjectedPermutation(); }))
     return MatchContractionResult::NotProjectedPermutations;
   // TODO: more fields than add/mul.
-  if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
-      !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
-      !isAddMul<complex::AddOp, complex::MulOp>(
-          linalgOp->getRegion(0).front()) &&
-      !isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
+  // Template arguments come in (elementwise, reduction) op pairs.
+  // clang-format off
+  if (!::isContractionBody<
+        arith::MulFOp, arith::AddFOp,
+        arith::MulIOp, arith::AddIOp,
+        complex::MulOp, complex::AddOp,
+        arith::AndIOp, arith::OrIOp>(
+      *linalgOp.getBlock())) {
     return MatchContractionResult::NotAddMul;
+  }
+  // clang-format on
 
   if (dimensions) {
     FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 24debc1dffae4b..2220930e8f4dd0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -186,17 +187,79 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
     }
     return DiagnosedSilenceableFailure::success();
   }
+  if (std::optional<ArrayAttr> contractionOps = getContraction()) {
+    Block &body = linalgOp->getRegion(0).front();
+    // The verifier guarantees two names: elementwise op, then reduction op.
+    StringRef elemName = (*contractionOps)[0].cast<StringAttr>().getValue();
+    StringRef redName = (*contractionOps)[1].cast<StringAttr>().getValue();
+    std::string message;
+    llvm::raw_string_ostream os(message);
+    bool result = linalg::detail::isContractionBody(
+        body,
+        [&](Operation *elem, Operation *red) {
+          return elem->getName().getStringRef() == elemName &&
+                 red->getName().getStringRef() == redName;
+        },
+        os);
+    if (result)
+      return DiagnosedSilenceableFailure::success();
+    return emitSilenceableError() << "contraction: " << os.str();
+  }
   return emitDefiniteFailure() << "unknown body condition";
 }
 
 LogicalResult transform::MatchStructuredBodyOp::verify() {
-  if (getReductionPosition() && getPassthrough()) {
-    return emitOpError() << "reduction position and passthrough conditions are "
-                            "mutually exclusive";
+  int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
+                       getContraction().has_value();
+
+  if (numOptions > 1) {
+    std::string attributeNames;
+    llvm::raw_string_ostream os(attributeNames);
+    llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
+                                               getPassthroughAttrName(),
+                                               getContractionAttrName()},
+                          os);
+    return emitOpError() << "only one of {" << os.str() << "} is allowed";
+  }
+
+  if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
+    if (contractionAttr->size() != 2) {
+      return emitOpError() << "expects " << getContractionAttrName()
+                           << " to contain two elements";
+    }
   }
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MatchStructuredClassifyContractionDimsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
+    Operation *current, transform::TransformResults &results,
+    transform::TransformState &state) {
+  FailureOr<linalg::ContractionDimensions> contractionDims =
+      linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
+  if (failed(contractionDims))
+    return emitSilenceableError() << "could not infer contraction dimensions";
+
+  MLIRContext *context = current->getContext();
+  Builder builder(context);
+  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
+    return llvm::to_vector(
+        llvm::map_range(values, [&](unsigned value) -> Attribute {
+          return builder.getI64IntegerAttr(value);
+        }));
+  };
+  results.setParams(getBatch().cast<OpResult>(),
+                    makeI64Attrs(contractionDims->batch));
+  results.setParams(getM().cast<OpResult>(), makeI64Attrs(contractionDims->m));
+  results.setParams(getN().cast<OpResult>(), makeI64Attrs(contractionDims->n));
+  results.setParams(getK().cast<OpResult>(), makeI64Attrs(contractionDims->k));
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Utilities for structured match predicates.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index e4b75e567ee843..d6f85b5218277e 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -777,7 +777,6 @@ module attributes { transform.with_named_sequence } {
 
 // -----
 
-
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @match_input_indexing_map(%arg0: !transform.any_op {transform.readonly})
       -> (!transform.affine_map, !transform.any_op) {
@@ -831,3 +830,80 @@ module attributes { transform.with_named_sequence } {
     return
   }
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @match_contraction(%arg0: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+    %1:4 = transform.match.structured %arg0 : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+    ^bb0(%struct: !transform.any_op):
+      transform.match.structured.body %struct { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
+      %0:4 = transform.match.structured.classify_contraction_dims %struct
+        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
+      transform.match.structured.yield %0#0, %0#1, %0#2, %0#3
+        : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+    }
+    transform.yield %arg0, %1#0, %1#1, %1#2, %1#3 : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+  }
+
+  transform.named_sequence @print_contraction(
+      %op: !transform.any_op {transform.readonly},
+      %batch: !transform.param<i64> {transform.readonly},
+      %m: !transform.param<i64> {transform.readonly},
+      %n: !transform.param<i64> {transform.readonly},
+      %k: !transform.param<i64> {transform.readonly}) {
+    transform.test_print_remark_at_operand %op, "contraction" : !transform.any_op
+    transform.test_print_param %batch, "batch dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %m, "m dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %n, "n dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %k, "k dims" at %op : !transform.param<i64>, !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+  ^bb0(%arg0: !transform.any_op):
+    %3 = transform.foreach_match in %arg0 @match_contraction -> @print_contraction : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// Payload: a named matmul and a generic contraction with batch dims d0 and d2.
+module attributes { transform.target_tag = "start_here" } {
+  func.func @matmul_simple(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf64> {
+    %cst = arith.constant 0.0 : f64
+    %empty = tensor.empty() : tensor<10x15xf64>
+    %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf64>) -> tensor<10x15xf64>
+    // expected-remark @below {{contraction}}
+    // expected-remark @below {{batch dims}}
+    // expected-remark @below {{m dims 0}}
+    // expected-remark @below {{n dims 1}}
+    // expected-remark @below {{k dims 2}}
+    %result = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf32>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf64>) -> tensor<10x15xf64>
+    return %result : tensor<10x15xf64>
+  }
+
+  func.func @double_batch(%lhs: tensor<40x10x50x20xf32>, %rhs: tensor<40x20x50x15xf32>) -> tensor<40x10x50x15xf32> {
+    %cst = arith.constant 0.0 : f32
+    %empty = tensor.empty() : tensor<40x10x50x15xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<40x10x50x15xf32>) -> tensor<40x10x50x15xf32>
+    // expected-remark @below {{contraction}}
+    // expected-remark @below {{batch dims 0 : i64, 2 : i64}}
+    // expected-remark @below {{m dims 1}}
+    // expected-remark @below {{n dims 3}}
+    // expected-remark @below {{k dims 4}}
+    %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>,
+                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2, d3)>,
+                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]
+    } ins(%lhs, %rhs : tensor<40x10x50x20xf32>, tensor<40x20x50x15xf32>)
+      outs(%fill : tensor<40x10x50x15xf32>) {
+    ^bb(%arg0: f32, %arg1: f32, %arg2: f32):
+      %0 = arith.mulf %arg0, %arg1 : f32
+      %1 = arith.addf %arg2, %0 : f32
+      linalg.yield %1 : f32
+    } -> tensor<40x10x50x15xf32>
+    return %result : tensor<40x10x50x15xf32>
+  }
+}

diff  --git a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
index 6dade865c64c16..ec99e205090c4c 100644
--- a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
@@ -64,7 +64,7 @@ transform.sequence failures(suppress) {
 ^bb0(%arg0: !transform.any_op):
   transform.match.structured %arg0 : !transform.any_op {
   ^bb1(%arg1: !transform.any_op):
-    // expected-error @below {{reduction position and passthrough conditions are mutually exclusive}}
+    // expected-error @below {{only one of {"reduction_position", "passthrough", "contraction"} is allowed}}
     transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op
     transform.match.structured.yield
   }

diff  --git a/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
new file mode 100644
index 00000000000000..73bc243ad76060
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @_match_matmul_like(
+      %entry: !transform.any_op {transform.readonly},
+      %rank: !transform.param<i64> {transform.readonly})
+      -> (!transform.any_op, !transform.any_op, !transform.param<i64>,
+          !transform.type, !transform.type, !transform.type,
+          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
+
+  transform.named_sequence @match_bmm(%entry: !transform.any_op {transform.readonly})
+      -> (!transform.any_op, !transform.any_op, !transform.param<i64>,
+          !transform.type, !transform.type, !transform.type, !transform.param<i64>) {
+    transform.match.operation_name %entry ["linalg.batch_matmul", "linalg.generic"] : !transform.any_op
+    // Rank of the iteration space: batch, m, n, k.
+    %c4 = transform.param.constant 4 : i64 -> !transform.param<i64>
+    %fill, %bmm, %dims, %lhs_type, %rhs_type, %res_type, %batch, %m, %n, %k =
+      transform.include @_match_matmul_like failures(propagate) (%entry, %c4)
+        : (!transform.any_op, !transform.param<i64>)
+        -> (!transform.any_op, !transform.any_op, !transform.param<i64>,
+            !transform.type, !transform.type, !transform.type,
+            !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
+
+    transform.yield %fill, %bmm, %dims, %lhs_type, %rhs_type, %res_type, %batch
+        : !transform.any_op, !transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type, !transform.param<i64>
+  }
+
+  transform.named_sequence @print_bmm(
+      %fill: !transform.any_op {transform.readonly},
+      %bmm: !transform.any_op {transform.readonly},
+      %dims: !transform.param<i64> {transform.readonly},
+      %lhs_type: !transform.type {transform.readonly},
+      %rhs_type: !transform.type {transform.readonly},
+      %res_type: !transform.type {transform.readonly},
+      %batch: !transform.param<i64> {transform.readonly}) {
+    transform.test_print_remark_at_operand %fill, "fill" : !transform.any_op
+    transform.test_print_remark_at_operand %bmm, "batch matmul" : !transform.any_op
+    transform.test_print_param %dims, "dimensions" at %bmm : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %lhs_type, "LHS type" at %bmm : !transform.type, !transform.any_op
+    transform.test_print_param %rhs_type, "RHS type" at %bmm : !transform.type, !transform.any_op
+    transform.test_print_param %res_type, "result type" at %bmm : !transform.type, !transform.any_op
+    transform.test_print_param %batch, "batch dimension" at %bmm : !transform.param<i64>, !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb(%root: !transform.any_op):
+    foreach_match in %root
+      @match_bmm -> @print_bmm
+      : (!transform.any_op) -> !transform.any_op
+  }
+}
+
+func.func @bmm_simple(%lhs: tensor<40x10x20xf16>, %rhs: tensor<40x20x15xf32>) -> tensor<40x10x15xf64> {
+  %cst = arith.constant 0.0 : f64
+  %empty = tensor.empty() : tensor<40x10x15xf64>
+  // expected-remark @below {{fill}}
+  %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<40x10x15xf64>) -> tensor<40x10x15xf64>
+  // expected-remark @below {{batch matmul}}
+  // expected-remark @below {{dimensions 40 : i64, 10 : i64, 15 : i64, 20 : i64}}
+  // expected-remark @below {{LHS type f16}}
+  // expected-remark @below {{RHS type f32}}
+  // expected-remark @below {{result type f64}}
+  // expected-remark @below {{batch dimension 0}}
+  %result = linalg.batch_matmul ins(%lhs, %rhs: tensor<40x10x20xf16>, tensor<40x20x15xf32>) outs(%fill: tensor<40x10x15xf64>) -> tensor<40x10x15xf64>
+  return %result : tensor<40x10x15xf64>
+}

diff  --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index 8f6fb8b3a50757..f164a3d1bd99dd 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -1,36 +1,27 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='transform-library-file-name=%p/match_matmul_common.mlir' --verify-diagnostics
 
 module attributes { transform.with_named_sequence } {
+  transform.named_sequence @_match_matmul_like(
+      %entry: !transform.any_op {transform.readonly},
+      %rank: !transform.param<i64> {transform.readonly})
+      -> (!transform.any_op, !transform.any_op, !transform.param<i64>,
+          !transform.type, !transform.type, !transform.type,
+          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
+
   transform.named_sequence @match_matmul(%entry: !transform.any_op {transform.readonly})
       -> (!transform.any_op, !transform.any_op, !transform.param<i64>,
           !transform.type, !transform.type, !transform.type) {
-    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
-    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
-    %capture:5 = transform.match.structured %entry : (!transform.any_op)
-        -> (!transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type) {
-    ^bb0(%struct: !transform.any_op):
-      transform.match.operation_name %struct ["linalg.matmul"] : !transform.any_op
-      %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param<i64>
-      
-      %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param<i64>
-      %n_inits = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
-      transform.match.param.cmpi eq %n_inputs, %c2 : !transform.param<i64>
-      transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>
-      
-      %lhs = transform.match.structured.input %struct[0] : (!transform.any_op) -> !transform.any_value
-      %rhs = transform.match.structured.input %struct[1] : (!transform.any_op) -> !transform.any_value
-      %res = transform.match.structured.result %struct[0] : (!transform.any_op) -> !transform.any_value
-      %lhs_type = transform.get_type elemental %lhs : (!transform.any_value) -> !transform.type
-      %rhs_type = transform.get_type elemental %rhs : (!transform.any_value) -> !transform.type
-      %res_type = transform.get_type elemental %res : (!transform.any_value) -> !transform.type
+    transform.match.operation_name %entry ["linalg.matmul", "linalg.generic"] : !transform.any_op
+    // Rank of the iteration space: m, n, k.
+    %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
+    %fill, %matmul, %dims, %lhs_type, %rhs_type, %res_type, %kinds:4 =
+      transform.include @_match_matmul_like failures(propagate) (%entry, %c3)
+        : (!transform.any_op, !transform.param<i64>)
+        -> (!transform.any_op, !transform.any_op, !transform.param<i64>,
+            !transform.type, !transform.type, !transform.type,
+            !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
 
-      %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_op
-      transform.match.operation_name %init ["linalg.fill"] : !transform.any_op
-
-      transform.match.structured.yield %init, %dims, %lhs_type, %rhs_type, %res_type
-          : !transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type
-    }
-    transform.yield %capture#0, %entry, %capture#1, %capture#2, %capture#3, %capture#4
+    transform.yield %fill, %matmul, %dims, %lhs_type, %rhs_type, %res_type
         : !transform.any_op, !transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type
   }
 
@@ -90,3 +81,29 @@ func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<2
   %result = linalg.matmul ins(%real_lhs, %rhs: tensor<10x20xf32>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf32>) -> tensor<10x15xf32>
   return %result : tensor<10x15xf32>
 }
+
+func.func @matmul_generic(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf64> {
+  %cst = arith.constant 0.0 : f64
+  %empty = tensor.empty() : tensor<10x15xf64>
+  // expected-remark @below {{fill}}
+  %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf64>) -> tensor<10x15xf64>
+  // expected-remark @below {{matmul}}
+  // expected-remark @below {{dimensions 10 : i64, 15 : i64, 20 : i64}}
+  // expected-remark @below {{LHS type f16}}
+  // expected-remark @below {{RHS type f32}}
+  // expected-remark @below {{result type f64}}
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                     affine_map<(d0, d1, d2) -> (d2, d1)>,
+                     affine_map<(d0, d1, d2) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction"]
+  } ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf64>) {
+  ^bb(%arg0: f16, %arg1: f32, %arg2: f64):
+    %0 = arith.extf %arg0 : f16 to f32
+    %1 = arith.mulf %0, %arg1 : f32
+    %2 = arith.extf %1 : f32 to f64
+    %3 = arith.addf %2, %arg2 : f64
+    linalg.yield %3 : f64
+  } -> tensor<10x15xf64>
+  return %result : tensor<10x15xf64>
+}

diff  --git a/mlir/test/Integration/Dialect/Transform/match_matmul_common.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul_common.mlir
new file mode 100644
index 00000000000000..78f92efae54fbb
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul_common.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @_match_matmul_like(
+      %entry: !transform.any_op {transform.readonly},
+      %rank: !transform.param<i64> {transform.readonly})
+      -> (!transform.any_op, !transform.any_op, !transform.param<i64>,
+          !transform.type, !transform.type, !transform.type,
+          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
+    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+    %capture:9 = transform.match.structured %entry : (!transform.any_op)
+        -> (!transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type,
+            !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+    ^bb0(%struct: !transform.any_op):
+      %op_rank = transform.match.structured.rank %struct : (!transform.any_op) -> !transform.param<i64>
+      transform.match.param.cmpi eq %rank, %op_rank : !transform.param<i64>
+      %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param<i64>
+      
+      %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param<i64>
+      %n_inits = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
+      transform.match.param.cmpi eq %n_inputs, %c2 : !transform.param<i64>
+      transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>
+      
+      %lhs = transform.match.structured.input %struct[0] : (!transform.any_op) -> !transform.any_value
+      %rhs = transform.match.structured.input %struct[1] : (!transform.any_op) -> !transform.any_value
+      %res = transform.match.structured.result %struct[0] : (!transform.any_op) -> !transform.any_value
+      %lhs_type = transform.get_type elemental %lhs : (!transform.any_value) -> !transform.type
+      %rhs_type = transform.get_type elemental %rhs : (!transform.any_value) -> !transform.type
+      %res_type = transform.get_type elemental %res : (!transform.any_value) -> !transform.type
+
+      %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_op
+      transform.match.operation_name %init ["linalg.fill"] : !transform.any_op
+
+      transform.match.structured.body %struct { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
+      %dim_kinds:4 = transform.match.structured.classify_contraction_dims %struct
+        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
+
+      transform.match.structured.yield %init, %dims, %lhs_type, %rhs_type, %res_type, %dim_kinds#0, %dim_kinds#1, %dim_kinds#2, %dim_kinds#3
+          : !transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type,
+            !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+    }
+    transform.yield %capture#0, %entry, %capture#1, %capture#2, %capture#3, %capture#4, 
+                    %capture#5, %capture#6, %capture#7, %capture#8
+        : !transform.any_op, !transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type,
+          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+  }
+}


        


More information about the Mlir-commits mailing list