[Mlir-commits] [mlir] dd81c6b - [mlir] integration tests for transform dialect matchers

Alex Zinenko llvmlistbot at llvm.org
Wed Jul 5 03:43:37 PDT 2023


Author: Alex Zinenko
Date: 2023-07-05T10:43:30Z
New Revision: dd81c6b8d34fdd53942a1e2532c514723fb10067

URL: https://github.com/llvm/llvm-project/commit/dd81c6b8d34fdd53942a1e2532c514723fb10067
DIFF: https://github.com/llvm/llvm-project/commit/dd81c6b8d34fdd53942a1e2532c514723fb10067.diff

LOG: [mlir] integration tests for transform dialect matchers

Add integration tests exercising transform dialect matchers for slightly
larger compositions of structured ops, namely reductions and matrix
multiplications with optional leading and trailing elementwise
operations.
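
The added tests follow the foreach_match pattern: a matcher named sequence
yields handles and parameters on success, and an action named sequence
consumes them. A minimal sketch of that wiring, simplified from the added
match_matmul.mlir test (the ops and names here are illustrative, not the exact
test content):

```mlir
module attributes { transform.with_named_sequence } {
  // Matcher: succeeds only on linalg.matmul payload ops and forwards the handle.
  transform.named_sequence @match_matmul(
      %op: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %op ["linalg.matmul"] : !transform.any_op
    transform.yield %op : !transform.any_op
  }

  // Action: emit a remark at every payload op the matcher accepted.
  transform.named_sequence @print_matmul(%op: !transform.any_op {transform.readonly}) {
    transform.test_print_remark_at_operand %op, "matmul" : !transform.any_op
    transform.yield
  }

  // Entry point: walk the payload IR and run the action wherever the matcher succeeds.
  transform.sequence failures(propagate) {
  ^bb0(%root: !transform.any_op):
    transform.foreach_match in %root
        @match_matmul -> @print_matmul
        : (!transform.any_op) -> !transform.any_op
  }
}
```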

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D154440

Added: 
    mlir/test/Integration/Dialect/Transform/match_matmul.mlir
    mlir/test/Integration/Dialect/Transform/match_reduction.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
    mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
    mlir/test/Dialect/Linalg/transform-op-tile.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index c409adb05f4b30..6f458d4e2e3b2a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -244,7 +244,8 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
 
   // TODO: allow this to bind multiple inputs simultaneously after checking that
   // `transform.foreach` works well in matches.
-  let results = (outs Optional<TransformAnyHandle>:$result);
+  let results =
+      (outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
   let assemblyFormat =
       "$operand_handle `[`"
       "custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
@@ -262,12 +263,36 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
 
 def MatchStructuredInputOp : MatchStructuredOperandOp<"match.structured.input"> {
   let summary =
-    "Captures input operand(s) of a structured operation in an op or value handle";
+    "Captures input operand(s) of a structured operation";
   let description = !strconcat([{
-    Produces a transform dialect value handle associated with the payload value
-    supplied as input operand to the given structured payload operation, or an
-    operation handle to the structured payload operation producing said payload
-    value depending on the result type.
+    Produces a transform dialect value depending on the result type:
+    
+      - If the result type is a value handle, it will be associated with the input
+        operand(s) of the payload operation associated with the operand handle.
+      - If the result type is an operation handle, it will be associated with the
+        operation defining the input operand(s) of the payload operation associated
+        with the operand handle.
+      - If the result type is an affine map parameter type, it will be associated
+        with the indexing map that corresponds to the input operand(s) of the
+        payload operation associated with the operand handle.
+
+    For example, given the following operation:
+
+    ```mlir
+    %arg1 = some.op
+    linalg.matmul ins(%arg1, %arg2 : ...) outs(%arg3 : ...)
+    ```
+
+    in case of a successful match for operand 0, this operation will return, for
+    each of the respective cases above:
+
+      - A handle to `%arg1` if the result is a value handle.
+      - A handle to `some.op` if the result is an operation handle.
+      - A parameter containing the LHS map of the matrix multiplication, i.e.
+        `affine_map<(d0, d1, d2) -> (d0, d2)>` if the result is an affine
+        map parameter.
+
+    The match succeeds if the conditions specified as attributes succeed.
 
     }], 
     StructuredDimDescription<"input">.description,
@@ -288,12 +313,35 @@ def MatchStructuredInputOp : MatchStructuredOperandOp<"match.structured.input">
 
 def MatchStructuredInitOp : MatchStructuredOperandOp<"match.structured.init"> {
   let summary =
-    "Captures init operand(s) of a structured operation in an op or value handle";
+    "Captures init operand(s) of a structured operation";
   let description = !strconcat([{
-    Produces a transform dialect value handle associated with the payload value
-    supplied as init(outs) operand to the given structured payload operation,
-    or an operation handle to the structured payload operation producing said
-    payload value depending on the result type.
+    Produces a transform dialect value depending on the result type:
+      - If the result type is a value handle, it will be associated with the init
+        operand(s) of the payload operation associated with the operand handle.
+      - If the result type is an operation handle, it will be associated with the
+        operation defining the init operand(s) of the payload operation associated
+        with the operand handle.
+      - If the result type is an affine map parameter type, it will be associated
+        with the indexing map that corresponds to the init operand(s) of the
+        payload operation associated with the operand handle.
+
+    For example, given the following operation:
+
+    ```mlir
+    %arg3 = linalg.fill
+    linalg.matmul ins(%arg1, %arg2 : ...) outs(%arg3 : ...)
+    ```
+
+    in case of a successful match for init operand 0, this operation will return,
+    for each of the respective cases above:
+
+      - A handle to `%arg3` if the result is a value handle.
+      - A handle to `linalg.fill` if the result is an operation handle.
+      - A parameter containing the result map of the matrix multiplication, i.e.
+        `affine_map<(d0, d1, d2) -> (d0, d1)>` if the result is an affine
+        map parameter.
+
+    The match succeeds if the conditions specified as attributes succeed.
 
     }], 
     StructuredDimDescription<"init">.description,

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 0f8fdfd358858c..7a98a15ead3c98 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -443,6 +443,7 @@ def GetConsumersOfResult : TransformDialectOp<"get_consumers_of_result",
 
 def GetDefiningOp : TransformDialectOp<"get_defining_op",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
+     MatchOpInterface,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
   let summary = "Get handle to the defining op of a value";
   let description = [{
@@ -531,6 +532,25 @@ def GetResultOp : TransformDialectOp<"get_result",
                        "functional-type(operands, results)";
 }
 
+def GetTypeOp : TransformDialectOp<"get_type",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     MatchOpInterface,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Get a parameter containing the type of the given value";
+  let description = [{
+    This operation creates a new Transform parameter containing the
+    type(s) of the value(s) associated with the operand handle.
+
+    This transform never fails.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$value,
+                       UnitAttr:$elemental);
+  let results = (outs TransformParamTypeInterface:$type_param);
+  let assemblyFormat = "(`elemental` $elemental^)? $value attr-dict `:`"
+                       "functional-type(operands, results)";
+}
+
 def IncludeOp : TransformDialectOp<"include",
     [CallOpInterface,
      MatchOpInterface,
@@ -838,6 +858,7 @@ def SequenceOp : TransformDialectOp<"sequence",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getSuccessorEntryOperands", "getSuccessorRegions",
          "getRegionInvocationBounds"]>,
+     MatchOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      OpAsmOpInterface, PossibleTopLevelTransformOpTrait,

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
index d5e7dd03ecb779..1d88e2b5880a39 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
@@ -13,6 +13,16 @@ include "mlir/IR/AttrTypeBase.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 
+def Transform_AffineMapParamType : TypeDef<Transform_Dialect, "AffineMapParam",
+    [DeclareTypeInterfaceMethods<TransformParamTypeInterface>]> {
+  let description = [{
+    Transform IR parameter value that can be associated with a list of affine
+    map attributes.
+  }];
+  let mnemonic = "affine_map";
+  let assemblyFormat = "";
+}
+
 def Transform_AnyOpType : TypeDef<Transform_Dialect, "AnyOp",
     [DeclareTypeInterfaceMethods<TransformHandleTypeInterface>]> {
   let description = [{
@@ -23,6 +33,15 @@ def Transform_AnyOpType : TypeDef<Transform_Dialect, "AnyOp",
   let assemblyFormat = "";
 }
 
+def Transform_AnyValue : TypeDef<Transform_Dialect, "AnyValue",
+    [DeclareTypeInterfaceMethods<TransformValueHandleTypeInterface>]> {
+  let description = [{
+    Transform IR value that can be associated with a list of Payload IR values.
+  }];
+  let mnemonic = "any_value";
+  let assemblyFormat = "";
+}
+
 def Transform_OperationType : TypeDef<Transform_Dialect, "Operation",
     [DeclareTypeInterfaceMethods<TransformHandleTypeInterface>]> {
   let description = [{
@@ -52,12 +71,13 @@ def Transform_ParamType : TypeDef<Transform_Dialect, "Param",
   let genVerifyDecl = 1;
 }
 
-def Transform_AnyValue : TypeDef<Transform_Dialect, "AnyValue",
-    [DeclareTypeInterfaceMethods<TransformValueHandleTypeInterface>]> {
+def Transform_TypeParamType : TypeDef<Transform_Dialect, "TypeParam",
+    [DeclareTypeInterfaceMethods<TransformParamTypeInterface>]> {
   let description = [{
-    Transform IR value that can be associated with a list of Payload IR values.
+    Transform IR parameter value that can be associated with a list of type
+    attributes.
   }];
-  let mnemonic = "any_value";
+  let mnemonic = "type";
   let assemblyFormat = "";
 }
 

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index d788c2feaac8d5..24debc1dffae4b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -428,6 +428,11 @@ DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
     if (!getResult())
       continue;
 
+    if (isa<AffineMapParamType>(getResult().getType())) {
+      operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
+      continue;
+    }
+
     Value operand = linalgOp.getDpsInputOperand(position)->get();
     if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
       operandMapping.emplace_back(operand);
@@ -513,6 +518,11 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
     if (!getResult())
       continue;
 
+    if (isa<AffineMapParamType>(getResult().getType())) {
+      operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
+      continue;
+    }
+
     Value operand = linalgOp.getDpsInitOperand(position)->get();
     if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
       operandMapping.emplace_back(operand);

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3c60d45999d7b3..7720b6cc15b9e8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1049,6 +1049,37 @@ transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// GetTypeOp
+//===----------------------------------------------------------------------===//
+
+void transform::GetTypeOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getValue(), effects);
+  producesHandle(getResult(), effects);
+  onlyReadsPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
+                            transform::TransformResults &results,
+                            transform::TransformState &state) {
+  SmallVector<Attribute> params;
+  ArrayRef<Value> values = state.getPayloadValues(getValue());
+  params.reserve(values.size());
+  for (Value value : values) {
+    Type type = value.getType();
+    if (getElemental()) {
+      if (auto shaped = dyn_cast<ShapedType>(type)) {
+        type = shaped.getElementType();
+      }
+    }
+    params.push_back(TypeAttr::get(type));
+  }
+  results.setParams(getResult().cast<OpResult>(), params);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // IncludeOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
index 8fb6ed1e909940..1360f23c5018fd 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
@@ -38,6 +38,22 @@ void transform::TransformDialect::initializeTypes() {
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// transform::AffineMapParamType
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::AffineMapParamType::checkPayload(Location loc,
+                                            ArrayRef<Attribute> payload) const {
+  for (Attribute attr : payload) {
+    if (!attr.isa<AffineMapAttr>()) {
+      return emitSilenceableError(loc)
+             << "expected affine map attribute, got " << attr;
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // transform::AnyOpType
 //===----------------------------------------------------------------------===//
@@ -48,6 +64,16 @@ transform::AnyOpType::checkPayload(Location loc,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// transform::AnyValueType
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::AnyValueType::checkPayload(Location loc,
+                                      ArrayRef<Value> payload) const {
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // transform::OperationType
 //===----------------------------------------------------------------------===//
@@ -103,11 +129,17 @@ transform::ParamType::checkPayload(Location loc,
 }
 
 //===----------------------------------------------------------------------===//
-// transform::AnyValueType
+// transform::TypeParamType
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::AnyValueType::checkPayload(Location loc,
-                                      ArrayRef<Value> payload) const {
+transform::TypeParamType::checkPayload(Location loc,
+                                       ArrayRef<Attribute> payload) const {
+  for (Attribute attr : payload) {
+    if (!attr.isa<TypeAttr>()) {
+      return emitSilenceableError(loc)
+             << "expected type attribute, got " << attr;
+    }
+  }
   return DiagnosedSilenceableFailure::success();
 }

diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 2c8287130b8b45..e4b75e567ee843 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -17,7 +17,7 @@ module attributes { transform.with_named_sequence } {
   // Entry point. Match any structured operation and emit a remark.
   transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
   ^bb0(%arg0: !transform.any_op):
-    transform.foreach_match in %arg0 
+    transform.foreach_match in %arg0
         @match_structured_empty -> @print_structured
         : (!transform.any_op) -> !transform.any_op
   }
@@ -73,7 +73,7 @@ module attributes { transform.with_named_sequence } {
 
   transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
   ^bb0(%arg0: !transform.any_op):
-    transform.foreach_match in %arg0 
+    transform.foreach_match in %arg0
         @match_structured_suppress -> @do_nothing
         : (!transform.any_op) -> !transform.any_op
   }
@@ -118,7 +118,7 @@ module attributes { transform.with_named_sequence } {
 
   transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
   ^bb0(%arg0: !transform.any_op):
-    transform.foreach_match in %arg0 
+    transform.foreach_match in %arg0
         @match_structured_body_passthrough -> @print_passthrough
         : (!transform.any_op) -> !transform.any_op
   }
@@ -129,7 +129,7 @@ module attributes { transform.with_named_sequence } {
       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
       iterator_types = ["parallel"]
     } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) {
-    ^bb0(%arg0: f32, %arg1: f32):      
+    ^bb0(%arg0: f32, %arg1: f32):
       linalg.yield %arg0 : f32
     } -> tensor<2xf32>
 
@@ -137,7 +137,7 @@ module attributes { transform.with_named_sequence } {
       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
       iterator_types = ["parallel"]
     } ins(%in : tensor<2xf32>) outs(%out : tensor<2xf32>) {
-    ^bb0(%arg0: f32, %arg1: f32):      
+    ^bb0(%arg0: f32, %arg1: f32):
       %0 = arith.mulf %arg0, %arg1 : f32
       linalg.yield %0 : f32
     } -> tensor<2xf32>
@@ -168,7 +168,7 @@ module attributes { transform.with_named_sequence } {
 
   transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
   ^bb0(%arg0: !transform.any_op):
-    transform.foreach_match in %arg0 
+    transform.foreach_match in %arg0
         @match_structured_body_reduction -> @print_reduction
         : (!transform.any_op) -> !transform.any_op
   }
@@ -230,8 +230,8 @@ module attributes { transform.with_named_sequence } {
 
   transform.named_sequence @match_dimension_capture(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
     // Capture multiple dimension values. Suppress failures so we can print them anyway after the capture.
-    %0:9 = transform.match.structured failures(suppress) %arg0 
-      : (!transform.any_op) -> (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, 
+    %0:9 = transform.match.structured failures(suppress) %arg0
+      : (!transform.any_op) -> (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
             !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
     ^bb0(%arg1: !transform.any_op):
       // This also tests the positional specification used by other ops, which may not test it again.
@@ -243,8 +243,8 @@ module attributes { transform.with_named_sequence } {
       %6 = transform.match.structured.dim %arg1[except(-1)] : (!transform.any_op) -> !transform.param<i64>
       %7 = transform.match.structured.dim %arg1[except(0, -2)] : (!transform.any_op) -> !transform.param<i64>
       %8 = transform.match.structured.dim %arg1[0, -3] : (!transform.any_op) -> !transform.param<i64>
-      transform.match.structured.yield %arg1, %1, %2, %3, %4, %5, %6, %7, %8 
-          : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, 
+      transform.match.structured.yield %arg1, %1, %2, %3, %4, %5, %6, %7, %8
+          : !transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
             !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
     }
     transform.test_print_param %0#1, "dimensions all:" at %0#0 : !transform.param<i64>, !transform.any_op
@@ -280,7 +280,7 @@ module attributes { transform.with_named_sequence } {
   }
 
   func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
-    // The last does not emit anything because it fails to match 
+    // The last does not emit anything because it fails to match
     // due to 0 and -3 being the same dimension in the 3D case.
     // expected-remark @below {{dimensions all: 2 : i64, 3 : i64, 4 : i64}}
     // expected-remark @below {{dimension 0: 2 : i64}}
@@ -404,7 +404,7 @@ module attributes { transform.with_named_sequence } {
     }
     transform.yield %arg0, %bw : !transform.any_op, !transform.param<i64>
   }
-  
+
   transform.named_sequence @print_bitwidth(%arg0: !transform.any_op {transform.readonly}, %arg1: !transform.param<i64> {transform.readonly}) {
     transform.test_print_param %arg1, "bitwidth:" at %arg0 : !transform.param<i64>, !transform.any_op
     transform.yield
@@ -417,7 +417,7 @@ module attributes { transform.with_named_sequence } {
   }
 
   func.func @payload(%f32: f32, %tf32: tensor<?xf32>,
-                     %index: index, %tindex: tensor<?xindex>) 
+                     %index: index, %tindex: tensor<?xindex>)
             attributes { transform.target_tag = "start_here" }  {
     // expected-remark @below {{bitwidth: 32}}
     linalg.fill ins(%f32: f32) outs(%tf32: tensor<?xf32>) -> tensor<?xf32>
@@ -429,7 +429,7 @@ module attributes { transform.with_named_sequence } {
 // -----
 
 module attributes { transform.with_named_sequence } {
-  transform.named_sequence @match_init(%arg0: !transform.any_op {transform.readonly}) 
+  transform.named_sequence @match_init(%arg0: !transform.any_op {transform.readonly})
       -> (!transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op) {
     %outs:3 = transform.match.structured failures(suppress) %arg0
       : (!transform.any_op) -> (!transform.any_value, !transform.any_value, !transform.any_op) {
@@ -441,7 +441,7 @@ module attributes { transform.with_named_sequence } {
     }
     transform.yield %arg0, %outs#0, %outs#1, %outs#2 : !transform.any_op, !transform.any_value, !transform.any_value, !transform.any_op
   }
-  
+
   transform.named_sequence @print_init(%arg0: !transform.any_op {transform.readonly},
                                          %arg1: !transform.any_value {transform.readonly},
                                          %arg2: !transform.any_value {transform.readonly},
@@ -459,21 +459,21 @@ module attributes { transform.with_named_sequence } {
   }
 
 
-  func.func @payload(%f32: f32, 
+  func.func @payload(%f32: f32,
             // expected-remark @below {{output 0}}
             // expected-remark @below {{all output}}
             // expected-note @below {{value handle points to a block argument #1 in block #0 in region #0}}
             %tf32: tensor<?xf32>,
             // expected-remark @below {{all output}}
             // expected-note @below {{value handle points to a block argument #2 in block #0 in region #0}}
-            %tf32_2: tensor<?xf32>) 
+            %tf32_2: tensor<?xf32>)
             attributes { transform.target_tag = "start_here" }  {
     // expected-remark @below {{output 0}}
     // expected-remark @below {{output producer}}
     // expected-remark @below {{all output}}
     // expected-note @below {{value handle points to an op result #0}}
     %0 = linalg.fill ins(%f32: f32) outs(%tf32: tensor<?xf32>) -> tensor<?xf32>
-    
+
     linalg.generic {
       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
       iterator_types = ["parallel"]
@@ -488,7 +488,7 @@ module attributes { transform.with_named_sequence } {
 // -----
 
 module attributes { transform.with_named_sequence } {
-  transform.named_sequence @match_init_0_permutation(%arg0: !transform.any_op {transform.readonly}) 
+  transform.named_sequence @match_init_0_permutation(%arg0: !transform.any_op {transform.readonly})
       -> !transform.any_op {
     %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
     ^bb0(%arg1: !transform.any_op):
@@ -497,7 +497,7 @@ module attributes { transform.with_named_sequence } {
     }
     transform.yield %0 : !transform.any_op
   }
-  transform.named_sequence @match_init_1_permutation(%arg0: !transform.any_op {transform.readonly}) 
+  transform.named_sequence @match_init_1_permutation(%arg0: !transform.any_op {transform.readonly})
       -> !transform.any_op {
     %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
     ^bb0(%arg1: !transform.any_op):
@@ -506,7 +506,7 @@ module attributes { transform.with_named_sequence } {
     }
     transform.yield %0 : !transform.any_op
   }
-  transform.named_sequence @match_init_2_projected_permutation(%arg0: !transform.any_op {transform.readonly}) 
+  transform.named_sequence @match_init_2_projected_permutation(%arg0: !transform.any_op {transform.readonly})
       -> !transform.any_op {
     %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
     ^bb0(%arg1: !transform.any_op):
@@ -515,7 +515,7 @@ module attributes { transform.with_named_sequence } {
     }
     transform.yield %0 : !transform.any_op
   }
-  
+
   transform.named_sequence @print_init_0_permutation(%arg0: !transform.any_op {transform.readonly}) {
     transform.test_print_remark_at_operand %arg0, "matched output 0 permutation" : !transform.any_op
     transform.yield
@@ -537,10 +537,10 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 
-  func.func @payload(%f32: f32, 
+  func.func @payload(%f32: f32,
             %oned: tensor<?xf32>,
             %oned2: tensor<?xf32>,
-            %twod: tensor<?x?xf32>) 
+            %twod: tensor<?x?xf32>)
             attributes { transform.target_tag = "start_here" }  {
     // expected-remark @below {{matched output 2 projected permutation}}
     linalg.generic {
@@ -575,9 +575,9 @@ module attributes { transform.with_named_sequence } {
 
 
 module attributes { transform.with_named_sequence } {
-  transform.named_sequence @match_num_io(%arg0: !transform.any_op {transform.readonly}) 
+  transform.named_sequence @match_num_io(%arg0: !transform.any_op {transform.readonly})
       -> (!transform.param<i64>, !transform.param<i64>, !transform.any_op) {
-    %0:3 = transform.match.structured failures(propagate) %arg0 
+    %0:3 = transform.match.structured failures(propagate) %arg0
          : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.any_op) {
     ^bb0(%arg1: !transform.any_op):
       %1 = transform.match.structured.num_inputs %arg1 : (!transform.any_op) -> !transform.param<i64>
@@ -587,7 +587,7 @@ module attributes { transform.with_named_sequence } {
     transform.yield %0#0, %0#1, %0#2 : !transform.param<i64>, !transform.param<i64>, !transform.any_op
   }
 
-  
+
   transform.named_sequence @print_num_io(
       %arg0: !transform.param<i64> {transform.readonly},
       %arg1: !transform.param<i64> {transform.readonly},
@@ -604,10 +604,10 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 
-  func.func @payload(%f32: f32, 
+  func.func @payload(%f32: f32,
             %oned: tensor<?xf32>,
             %oned2: tensor<?xf32>,
-            %twod: tensor<?x?xf32>) 
+            %twod: tensor<?x?xf32>)
             attributes { transform.target_tag = "start_here" }  {
     // expected-remark @below {{inputs 1}}
     // expected-remark @below {{outputs 3}}
@@ -641,9 +641,9 @@ module attributes { transform.with_named_sequence } {
 // -----
 
 module attributes { transform.with_named_sequence } {
-  transform.named_sequence @match_rank(%arg0: !transform.any_op {transform.readonly}) 
+  transform.named_sequence @match_rank(%arg0: !transform.any_op {transform.readonly})
       -> (!transform.param<i64>, !transform.any_op) {
-    %0:2 = transform.match.structured failures(propagate) %arg0 
+    %0:2 = transform.match.structured failures(propagate) %arg0
          : (!transform.any_op) -> (!transform.param<i64>, !transform.any_op) {
     ^bb0(%arg1: !transform.any_op):
       %1 = transform.match.structured.rank %arg1 : (!transform.any_op) -> !transform.param<i64>
@@ -652,7 +652,7 @@ module attributes { transform.with_named_sequence } {
     transform.yield %0#0, %0#1 : !transform.param<i64>, !transform.any_op
   }
 
-  
+
   transform.named_sequence @print_rank(%arg0: !transform.param<i64> {transform.readonly},
                                        %arg2: !transform.any_op {transform.readonly}) {
     transform.test_print_param %arg0, "rank" at %arg2 : !transform.param<i64>, !transform.any_op
@@ -665,8 +665,8 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 
-  func.func @payload(%f32: f32, 
-            %twod: tensor<42x42xf32>) 
+  func.func @payload(%f32: f32,
+            %twod: tensor<42x42xf32>)
             attributes { transform.target_tag = "start_here" } {
     %0 = tensor.empty() : tensor<42x42xf32>
     // expected-remark @below {{rank 2}}
@@ -681,9 +681,9 @@ module attributes { transform.with_named_sequence } {
 // -----
 
 module attributes { transform.with_named_sequence } {
-  transform.named_sequence @match_single_result(%arg0: !transform.any_op {transform.readonly}) 
+  transform.named_sequence @match_single_result(%arg0: !transform.any_op {transform.readonly})
       -> (!transform.any_op, !transform.any_op) {
-    %0:2 = transform.match.structured failures(propagate) %arg0 
+    %0:2 = transform.match.structured failures(propagate) %arg0
          : (!transform.any_op) -> (!transform.any_op, !transform.any_op) {
     ^bb0(%arg1: !transform.any_op):
       %1 = transform.match.structured.result %arg1[0] { single } : (!transform.any_op) -> !transform.any_op
@@ -693,7 +693,7 @@ module attributes { transform.with_named_sequence } {
   }
   transform.named_sequence @match_result_value(%arg0: !transform.any_op {transform.readonly})
       -> (!transform.any_value, !transform.any_op) {
-    %0:2 = transform.match.structured failures(propagate) %arg0 
+    %0:2 = transform.match.structured failures(propagate) %arg0
          : (!transform.any_op) -> (!transform.any_value, !transform.any_op) {
     ^bb0(%arg1: !transform.any_op):
       %1 = transform.match.structured.result %arg1[0] : (!transform.any_op) -> !transform.any_value
@@ -701,9 +701,9 @@ module attributes { transform.with_named_sequence } {
     }
     transform.yield %0#0, %0#1 : !transform.any_value, !transform.any_op
   }
-  transform.named_sequence @match_any_result(%arg0: !transform.any_op {transform.readonly}) 
+  transform.named_sequence @match_any_result(%arg0: !transform.any_op {transform.readonly})
       -> (!transform.any_op) {
-    %0 = transform.match.structured failures(propagate) %arg0 
+    %0 = transform.match.structured failures(propagate) %arg0
          : (!transform.any_op) -> !transform.any_op {
     ^bb0(%arg1: !transform.any_op):
       %1 = transform.match.structured.result %arg1[-1] { any } : (!transform.any_op) -> !transform.any_op
@@ -711,7 +711,7 @@ module attributes { transform.with_named_sequence } {
     }
     transform.yield %0 : !transform.any_op
   }
-  
+
   transform.named_sequence @print_single_result(%arg0: !transform.any_op {transform.readonly},
                                                 %arg2: !transform.any_op {transform.readonly}) {
     transform.test_print_remark_at_operand %arg2, "matched single result" : !transform.any_op
@@ -738,7 +738,7 @@ module attributes { transform.with_named_sequence } {
   }
 
   func.func @payload(%f32: f32, %f322: f32, %f323: f32,
-            %twod: tensor<42x42xf32>) 
+            %twod: tensor<42x42xf32>)
             attributes { transform.target_tag = "start_here" } {
     %0 = tensor.empty() : tensor<42x42xf32>
 
@@ -774,3 +774,60 @@ module attributes { transform.with_named_sequence } {
     return
   }
 }
+
+// -----
+
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @match_input_indexing_map(%arg0: !transform.any_op {transform.readonly})
+      -> (!transform.affine_map, !transform.any_op) {
+    %0 = transform.match.structured failures(propagate) %arg0
+         : (!transform.any_op) -> !transform.affine_map {
+    ^bb0(%arg1: !transform.any_op):
+      %1 = transform.match.structured.input %arg1[0]  : (!transform.any_op) -> !transform.affine_map
+      transform.match.structured.yield %1 : !transform.affine_map
+    }
+    transform.yield %0, %arg0 : !transform.affine_map, !transform.any_op
+  }
+  transform.named_sequence @match_init_indexing_map(%arg0: !transform.any_op {transform.readonly})
+      -> (!transform.affine_map, !transform.any_op) {
+    %0 = transform.match.structured failures(propagate) %arg0
+         : (!transform.any_op) -> !transform.affine_map {
+    ^bb0(%arg1: !transform.any_op):
+      %1 = transform.match.structured.init %arg1[0]  : (!transform.any_op) -> !transform.affine_map
+      transform.match.structured.yield %1 : !transform.affine_map
+    }
+    transform.yield %0, %arg0 : !transform.affine_map, !transform.any_op
+  }
+
+  transform.named_sequence @print_indexing_map_1(%arg0: !transform.affine_map {transform.readonly},
+                                               %arg1: !transform.any_op {transform.readonly}) {
+    transform.test_print_param %arg0, "indexing map 1" at %arg1 : !transform.affine_map, !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence @print_indexing_map_2(%arg0: !transform.affine_map {transform.readonly},
+                                               %arg1: !transform.any_op {transform.readonly}) {
+    transform.test_print_param %arg0, "indexing map 2" at %arg1 : !transform.affine_map, !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+  ^bb0(%arg0: !transform.any_op):
+    %3 = transform.foreach_match in %arg0 @match_input_indexing_map -> @print_indexing_map_1 : (!transform.any_op) -> !transform.any_op
+    %4 = transform.foreach_match in %3 @match_init_indexing_map -> @print_indexing_map_2 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  func.func @payload(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>)
+            attributes { transform.target_tag = "start_here" } {
+    %out = tensor.empty() : tensor<32x32xf32>
+    %cst = arith.constant 1.0 : f32
+    // expected-remark @below {{indexing map 1 affine_map<(d0, d1) -> ()>}}
+    // expected-remark @below {{indexing map 2 affine_map<(d0, d1) -> (d0, d1)>}}
+    %res = linalg.fill ins(%cst : f32) outs(%out : tensor<32x32xf32>) -> tensor<32x32xf32>
+    // expected-remark @below {{indexing map 1 affine_map<(d0, d1, d2) -> (d0, d2)>}}
+    // expected-remark @below {{indexing map 2 affine_map<(d0, d1, d2) -> (d0, d1)>}}
+    linalg.matmul ins(%lhs, %rhs : tensor<32x32xf32>, tensor<32x32xf32>) outs(%res : tensor<32x32xf32>) -> tensor<32x32xf32>
+    return
+  }
+}

diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 8b449770ee8a1b..d4629dcb29c3ef 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -80,7 +80,7 @@ transform.sequence failures(propagate) {
 ^bb0(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-note @below {{for this parameter}}
-  %1 = transform.test_produce_integer_param_with_type i64 : !transform.param<i64>
+  %1 = transform.test_produce_param (0 : i64) : !transform.param<i64>
   // expected-error @below {{expected as many parameter values (0) as target ops (2)}}
   transform.structured.tile %0 [%1, %1, %1]
     : (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)

diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 36a6ff40f0b3a2..b6ea867df6fc64 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1037,7 +1037,7 @@ transform.sequence -> !transform.any_op failures(suppress) {
 
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
-  %0 = transform.test_produce_integer_param_with_type i32 : !transform.test_dialect_param
+  %0 = transform.test_produce_param (0 : i32) : !transform.test_dialect_param
   // expected-remark @below {{0 : i32}}
   transform.test_print_param %0 : !transform.test_dialect_param
 }
@@ -1047,7 +1047,7 @@ transform.sequence failures(propagate) {
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !transform.any_op):
   // expected-error @below {{expected the type of the parameter attribute ('i32') to match the parameter type ('i64')}}
-  transform.test_produce_integer_param_with_type i32 : !transform.param<i64>
+  transform.test_produce_param (0 : i32) : !transform.param<i64>
 }
 
 // -----
@@ -1860,3 +1860,58 @@ transform.sequence failures(propagate) {
   // expected-remark @below{{1}}
   test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
 }
+
+// -----
+
+func.func @cast(%arg0: f32) -> f64 {
+  // expected-remark @below{{f64}}
+  %0 = arith.extf %arg0 : f32 to f64
+  return %0 : f64
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["arith.extf"]} in %arg0 : (!transform.any_op) -> !transform.op<"arith.extf">
+  %1 = transform.get_result %0[0] : (!transform.op<"arith.extf">) -> !transform.any_value
+  %2 = transform.get_type %1 : (!transform.any_value) -> !transform.type
+  transform.test_print_param %2 at %0 : !transform.type, !transform.op<"arith.extf">
+  transform.yield
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expected type attribute, got 0 : i32}}
+  transform.test_produce_param (0 : i32) : !transform.type
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expected affine map attribute, got 0 : i32}}
+  transform.test_produce_param (0 : i32) : !transform.affine_map
+}
+
+// -----
+
+// CHECK-LABEL: @type_param_anchor
+func.func private @type_param_anchor()
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // CHECK: test_produce_param(f32) : !transform.type
+  transform.test_produce_param(f32) : !transform.type
+}
+
+// -----
+
+// CHECK-LABEL: @affine_map_param_anchor
+func.func private @affine_map_param_anchor()
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // CHECK: test_produce_param(#{{.*}}) : !transform.affine_map
+  transform.test_produce_param(affine_map<(d0) -> ()>) : !transform.affine_map
+}

diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
new file mode 100644
index 00000000000000..8f6fb8b3a50757
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --verify-diagnostics
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @match_matmul(%entry: !transform.any_op {transform.readonly})
+      -> (!transform.any_op, !transform.any_op, !transform.param<i64>,
+          !transform.type, !transform.type, !transform.type) {
+    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
+    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+    %capture:5 = transform.match.structured %entry : (!transform.any_op)
+        -> (!transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type) {
+    ^bb0(%struct: !transform.any_op):
+      transform.match.operation_name %struct ["linalg.matmul"] : !transform.any_op
+      %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param<i64>
+      
+      %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param<i64>
+      %n_inits = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
+      transform.match.param.cmpi eq %n_inputs, %c2 : !transform.param<i64>
+      transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>
+      
+      %lhs = transform.match.structured.input %struct[0] : (!transform.any_op) -> !transform.any_value
+      %rhs = transform.match.structured.input %struct[1] : (!transform.any_op) -> !transform.any_value
+      %res = transform.match.structured.result %struct[0] : (!transform.any_op) -> !transform.any_value
+      %lhs_type = transform.get_type elemental %lhs : (!transform.any_value) -> !transform.type
+      %rhs_type = transform.get_type elemental %rhs : (!transform.any_value) -> !transform.type
+      %res_type = transform.get_type elemental %res : (!transform.any_value) -> !transform.type
+
+      %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_op
+      transform.match.operation_name %init ["linalg.fill"] : !transform.any_op
+
+      transform.match.structured.yield %init, %dims, %lhs_type, %rhs_type, %res_type
+          : !transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type
+    }
+    transform.yield %capture#0, %entry, %capture#1, %capture#2, %capture#3, %capture#4
+        : !transform.any_op, !transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type
+  }
+
+  transform.named_sequence @print_matmul(
+      %fill: !transform.any_op {transform.readonly},
+      %matmul: !transform.any_op {transform.readonly},
+      %dims: !transform.param<i64> {transform.readonly},
+      %lhs_type: !transform.type {transform.readonly},
+      %rhs_type: !transform.type {transform.readonly},
+      %res_type: !transform.type {transform.readonly}) {
+    transform.test_print_remark_at_operand %fill, "fill" : !transform.any_op
+    transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op
+    transform.test_print_param %dims, "dimensions" at %matmul : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %lhs_type, "LHS type" at %matmul : !transform.type, !transform.any_op
+    transform.test_print_param %rhs_type, "RHS type" at %matmul : !transform.type, !transform.any_op
+    transform.test_print_param %res_type, "result type" at %matmul : !transform.type, !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb(%root: !transform.any_op):
+    foreach_match in %root
+      @match_matmul -> @print_matmul
+      : (!transform.any_op) -> !transform.any_op
+  }
+}
+
+func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf64>{
+  %cst = arith.constant 0.0 : f64
+  %empty = tensor.empty() : tensor<10x15xf64>
+  // expected-remark @below {{fill}}
+  %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf64>) -> tensor<10x15xf64>
+  // expected-remark @below {{matmul}}
+  // expected-remark @below {{dimensions 10 : i64, 15 : i64, 20 : i64}}
+  // expected-remark @below {{LHS type f16}}
+  // expected-remark @below {{RHS type f32}}
+  // expected-remark @below {{result type f64}}
+  %result = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf64>) -> tensor<10x15xf64>
+  return %result : tensor<10x15xf64>
+}
+
+func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
+  %cst = arith.constant 0.0 : f64
+  %empty = tensor.empty() : tensor<10x15xf32>
+
+  // expected-remark @below {{fill}}
+  %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
+
+  %real_lhs = linalg.elemwise_binary { fun = #linalg.binary_fn<mul> } 
+    ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>
+
+  // expected-remark @below {{matmul}}
+  // expected-remark @below {{dimensions 10 : i64, 15 : i64, 20 : i64}}
+  // expected-remark @below {{LHS type f32}}
+  // expected-remark @below {{RHS type f32}}
+  // expected-remark @below {{result type f32}}
+  %result = linalg.matmul ins(%real_lhs, %rhs: tensor<10x20xf32>, tensor<20x15xf32>) outs(%fill: tensor<10x15xf32>) -> tensor<10x15xf32>
+  return %result : tensor<10x15xf32>
+}

diff --git a/mlir/test/Integration/Dialect/Transform/match_reduction.mlir b/mlir/test/Integration/Dialect/Transform/match_reduction.mlir
new file mode 100644
index 00000000000000..c85547af4ef1dd
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Transform/match_reduction.mlir
@@ -0,0 +1,319 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --verify-diagnostics
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly})
+      -> (!transform.any_op) {
+    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
+
+    transform.match.structured %entry : !transform.any_op {
+    ^bb0(%struct: !transform.any_op):
+      transform.match.structured.dim %struct[all] {parallel} : !transform.any_op
+      transform.match.structured.input %struct[all] {projected_permutation} : !transform.any_op
+      transform.match.structured.init %struct[all] {permutation} : !transform.any_op
+      %ni = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
+      transform.match.param.cmpi eq %ni, %c1 : !transform.param<i64>
+    }
+    transform.yield %entry : !transform.any_op
+  }
+
+  transform.named_sequence @fill_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly})
+      -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
+          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
+    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
+    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+    %c4 = transform.param.constant 4 : i64 -> !transform.param<i64>
+
+    %rk, %dms, %bw, %operand_o, %init_v, %trailing_o = transform.match.structured failures(propagate) %entry 
+        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
+                                  !transform.any_op, !transform.any_value, !transform.any_op) {
+    ^bb0(%struct: !transform.any_op):
+      %rank = transform.match.structured.rank %struct : (!transform.any_op) -> !transform.param<i64>
+      transform.match.param.cmpi ge %rank, %c2 : !transform.param<i64>
+      transform.match.param.cmpi le %rank, %c4 : !transform.param<i64>
+      
+      transform.match.structured.dim %struct[-1] {reduction} : !transform.any_op
+      transform.match.structured.dim %struct[except(-1)] {parallel} : !transform.any_op
+      %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param<i64>
+
+      %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param<i64>
+      %n_outputs = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
+      transform.match.param.cmpi eq %n_inputs, %c1 : !transform.param<i64>
+      transform.match.param.cmpi eq %n_outputs, %c1 : !transform.param<i64>
+
+      transform.match.structured.input %struct[0] {projected_permutation} : !transform.any_op
+      transform.match.structured.init %struct[0] {projected_permutation} : !transform.any_op
+      %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_value
+      
+      // This dance is necessary to create an empty handle if there is no single
+      // user, without failing the entire match.
+      %trailing_optional = transform.sequence %struct : (!transform.any_op) -> !transform.any_op failures(suppress) {
+      ^bb0(%struct_inner: !transform.any_op):
+        %result = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op {
+        ^bb0(%struct_inner_inner: !transform.any_op):
+          %result_inner = transform.match.structured.result %struct_inner_inner[0] {single} : (!transform.any_op) -> !transform.any_op
+          %trailing = transform.include @_reduce_leading_trailing failures(propagate) (%result_inner) : (!transform.any_op) -> !transform.any_op
+          transform.match.structured.yield %trailing : !transform.any_op
+        }
+        transform.yield %result: !transform.any_op
+      }
+
+      // Suppress errors as a way to implement optionality. We cannot suppress them in
+      // the include because it keeps matching after "get_defining_op" fails, which
+      // breaks the single-op precondition of the following ops. We don't want to
+      // propagate that failure though.
+      //
+      // Additionally, we cannot put the sequence inside the call because its first
+      // operand must be an operation handle (the verifier asserts!) and there is
+      // no such handle available there.
+      //
+      // TODO: extend the structured matching to gracefully handle empty handles
+      // or provide the suppress-errors-but-stop failure mode for includes to
+      // implement optionality.
+      %operand_optional = transform.sequence %struct : (!transform.any_op) -> !transform.any_op failures(suppress) {
+      ^bb0(%struct_inner: !transform.any_op):
+        %operand3 = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op {
+        ^bb1(%struct_inner_inner: !transform.any_op):
+          %operand = transform.match.structured.input %struct_inner_inner[0] : (!transform.any_op) -> !transform.any_op
+          %operand2 = transform.include @_reduce_leading_trailing failures(propagate) (%operand) : (!transform.any_op) -> !transform.any_op
+          transform.match.structured.yield %operand2 : !transform.any_op
+        }
+        transform.yield %operand3 : !transform.any_op
+      }
+
+      %bitwidth = transform.match.structured.elemental_bitwidth %init : (!transform.any_value) -> !transform.param<i64>
+
+      transform.match.structured.body %struct { reduction_position = 0 } : !transform.any_op
+      transform.match.structured.yield %rank, %dims, %bitwidth, %operand_optional, %init, %trailing_optional
+        : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
+          !transform.any_op, !transform.any_value, !transform.any_op
+    }
+
+    %init_o = transform.get_defining_op %init_v : (!transform.any_value) -> !transform.any_op
+    transform.match.operation_name %init_o ["linalg.fill"] : !transform.any_op    
+
+    transform.yield %operand_o, %init_o, %entry, %trailing_o, %rk, %dms, %bw
+        : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
+          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+  }
+
+  transform.named_sequence @print_reduce_leading_trailing(
+      %leading: !transform.any_op {transform.readonly},
+      %fill: !transform.any_op {transform.readonly},
+      %reduction: !transform.any_op {transform.readonly},
+      %trailing: !transform.any_op {transform.readonly},
+      %rank: !transform.param<i64> {transform.readonly},
+      %dims: !transform.param<i64> {transform.readonly},
+      %bitwidth: !transform.param<i64> {transform.readonly}) {
+    transform.test_print_remark_at_operand %leading, "leading" : !transform.any_op
+    transform.test_print_remark_at_operand %fill, "fill" : !transform.any_op
+    transform.test_print_remark_at_operand %reduction, "reduction" : !transform.any_op
+    transform.test_print_remark_at_operand %trailing, "trailing" : !transform.any_op
+    transform.test_print_param %rank, "rank" at %reduction : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %dims, "dimensions" at %reduction : !transform.param<i64>, !transform.any_op
+    transform.test_print_param %bitwidth, "bitwidth" at %reduction : !transform.param<i64>, !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb(%root: !transform.any_op):
+    foreach_match in %root
+      @fill_reduce_leading_trailing -> @print_reduce_leading_trailing
+      : (!transform.any_op) -> !transform.any_op
+  }
+}
+
+!in_tensor_t = tensor<8x64xf32>
+!out_tensor_t = tensor<8xf32>
+
+func.func @eltwise_reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
+  %cst = arith.constant -0.000000e+00 : f32
+
+  %0 = tensor.empty() : !out_tensor_t
+  // expected-remark @below {{fill}}
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) ->  !out_tensor_t
+  %2 = tensor.empty() : !in_tensor_t
+  // expected-remark @below {{leading}}
+  %3 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = arith.addf %arg3, %arg3 : f32
+      %5 = arith.addf %4, %4 : f32
+      linalg.yield %5 : f32
+    } -> !in_tensor_t
+
+  // expected-remark @below {{reduction}}
+  // expected-remark @below {{rank 2}}
+  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
+  // expected-remark @below {{bitwidth 32 : i64}}
+  %6 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
+      ^bb0(%arg3: f32, %arg4: f32):
+        %4 = arith.addf %arg3, %arg4 : f32
+        linalg.yield %4 : f32
+      } -> !out_tensor_t
+
+  return %6 : !out_tensor_t
+}
+
+func.func @reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
+  %cst = arith.constant -0.000000e+00 : f32
+
+  %0 = tensor.empty() : !out_tensor_t
+  // expected-remark @below {{fill}}
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
+  // expected-remark @below {{reduction}}
+  // expected-remark @below {{rank 2}}
+  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
+  // expected-remark @below {{bitwidth 32 : i64}}
+  %5 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%arg : !in_tensor_t) outs(%1 : !out_tensor_t) {
+      ^bb0(%arg3: f32, %arg4: f32):
+        %4 = arith.addf %arg3, %arg4 : f32
+        linalg.yield %4 : f32
+      } -> !out_tensor_t
+
+  %6 = tensor.empty() : !out_tensor_t
+  // expected-remark @below {{trailing}}
+  %7 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]}
+    ins(%5 : !out_tensor_t) outs(%6 : !out_tensor_t) {  
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = math.sqrt %arg3 : f32
+      linalg.yield %4 : f32
+    } -> !out_tensor_t
+  return %7 : !out_tensor_t
+}
+
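+// Reduction with both a leading and a trailing elementwise operation; all
+// remarks are expected.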
+func.func @eltwise_reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
+  %cst = arith.constant -0.000000e+00 : f32
+
+  %0 = tensor.empty() : !out_tensor_t
+  // expected-remark @below {{fill}}
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
+  %2 = tensor.empty() : !in_tensor_t
+  // expected-remark @below {{leading}}
+  %3 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = arith.addf %arg3, %arg3 : f32
+      %5 = arith.addf %4, %4 : f32
+      linalg.yield %5 : f32
+    } -> !in_tensor_t
+
+  // expected-remark @below {{reduction}}
+  // expected-remark @below {{rank 2}}
+  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
+  // expected-remark @below {{bitwidth 32 : i64}}
+  %6 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
+      ^bb0(%arg3: f32, %arg4: f32):
+        %4 = arith.addf %arg3, %arg4 : f32
+        linalg.yield %4 : f32
+      } -> !out_tensor_t
+
+  %7 = tensor.empty() : !out_tensor_t
+  // expected-remark @below {{trailing}}
+  %8 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]}
+    ins(%6 : !out_tensor_t) outs(%7 : !out_tensor_t) {  
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = math.sqrt %arg3 : f32
+      linalg.yield %4 : f32
+    } -> !out_tensor_t
+
+  return %8 : !out_tensor_t
+}
+
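+// Same computation as above, but the fill is created after the leading
+// elementwise operation; the same remarks are expected, so matching should
+// not depend on the textual order of the producing operations.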
+func.func @eltwise_reduce_eltwise_swapped(%arg : !in_tensor_t) -> (!out_tensor_t) {
+  %cst = arith.constant -0.000000e+00 : f32
+
+  %2 = tensor.empty() : !in_tensor_t
+  // expected-remark @below {{leading}}
+  %3 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg : !in_tensor_t) outs(%2 : !in_tensor_t) {
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = arith.addf %arg3, %arg3 : f32
+      %5 = arith.addf %4, %4 : f32
+      linalg.yield %5 : f32
+    } -> !in_tensor_t
+
+  %0 = tensor.empty() : !out_tensor_t
+  // expected-remark @below {{fill}}
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : !out_tensor_t) -> !out_tensor_t
+  // expected-remark @below {{reduction}}
+  // expected-remark @below {{rank 2}}
+  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
+  // expected-remark @below {{bitwidth 32 : i64}}
+  %6 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%3 : !in_tensor_t) outs(%1 : !out_tensor_t) {
+      ^bb0(%arg3: f32, %arg4: f32):
+        %4 = arith.addf %arg3, %arg4 : f32
+        linalg.yield %4 : f32
+      } -> !out_tensor_t
+
+  %7 = tensor.empty() : !out_tensor_t
+  // expected-remark @below {{trailing}}
+  %8 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]}
+    ins(%6 : !out_tensor_t) outs(%7 : !out_tensor_t) {  
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = math.sqrt %arg3 : f32
+      linalg.yield %4 : f32
+    } -> !out_tensor_t
+
+  return %8 : !out_tensor_t
+}
+
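+// The function below also contains a fill that does not feed the reduction;
+// only the fill feeding the reduction is expected to be reported.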
+func.func @reduction_with_extra_op_in_func(%arg0: tensor<8x479xf32>, %arg1: tensor<32x32xf32>) -> (tensor<8xf32>, tensor<32xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %empty = tensor.empty() : tensor<8xf32>
+  // expected-remark @below {{fill}}
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<8xf32>) -> tensor<8xf32>
+  // expected-remark @below {{reduction}}
+  // expected-remark @below {{rank 2}}
+  // expected-remark @below {{dimensions 8 : i64, 479 : i64}}
+  // expected-remark @below {{bitwidth 32 : i64}}
+  %result = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]} 
+    ins(%arg0 : tensor<8x479xf32>)
+    outs(%fill : tensor<8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %6 = arith.addf %in, %out : f32
+    linalg.yield %6 : f32
+  } -> tensor<8xf32>
+
+  %empty2 = tensor.empty() : tensor<32xf32>
+  %fill2 = linalg.fill ins(%cst : f32) outs(%empty2 : tensor<32xf32>) -> tensor<32xf32>
+  return %result, %fill2 : tensor<8xf32>, tensor<32xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index f3e80060f5cec6..7aa632733b6c5b 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -633,21 +633,13 @@ mlir::test::TestProduceParamWithNumberOfTestOps::apply(
 }
 
 DiagnosedSilenceableFailure
-mlir::test::TestProduceIntegerParamWithTypeOp::apply(
-    transform::TransformRewriter &rewriter,
-    transform::TransformResults &results, transform::TransformState &state) {
-  Attribute zero = IntegerAttr::get(getType(), 0);
-  results.setParams(llvm::cast<OpResult>(getResult()), zero);
+mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter,
+                                      transform::TransformResults &results,
+                                      transform::TransformState &state) {
+  results.setParams(llvm::cast<OpResult>(getResult()), getAttr());
   return DiagnosedSilenceableFailure::success();
 }
 
-LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() {
-  if (!llvm::isa<IntegerType>(getType())) {
-    return emitOpError() << "expects an integer type";
-  }
-  return success();
-}
-
 void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::onlyReadsHandle(getIn(), effects);

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 02f4955c30cc6e..c1e82a5ad6172c 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -428,15 +428,14 @@ def TestProduceParamWithNumberOfTestOps
   let cppNamespace = "::mlir::test";
 }
 
-def TestProduceIntegerParamWithTypeOp
-  : Op<Transform_Dialect, "test_produce_integer_param_with_type",
+def TestProduceParamOp
+  : Op<Transform_Dialect, "test_produce_param",
        [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
         DeclareOpInterfaceMethods<TransformOpInterface>]> {
-  let arguments = (ins TypeAttr:$type);
+  let arguments = (ins AnyAttr:$attr);
   let results = (outs TransformParamTypeInterface:$result);
-  let assemblyFormat = "$type attr-dict `:` type($result)";
+  let assemblyFormat = "`(` $attr `)` attr-dict `:` type($result)";
   let cppNamespace = "::mlir::test";
-  let hasVerifier = 1;
 }
 
 def TestProduceTransformParamOrForwardOperandOp
