[Mlir-commits] [mlir] 5cde5a5 - [mlir] add interchange, pad and scalarize to structured transform dialect

Alex Zinenko llvmlistbot at llvm.org
Mon May 30 02:42:46 PDT 2022


Author: Alex Zinenko
Date: 2022-05-30T11:42:40+02:00
New Revision: 5cde5a5739069a4be7f86a17bd20cc6e8f2daf68

URL: https://github.com/llvm/llvm-project/commit/5cde5a5739069a4be7f86a17bd20cc6e8f2daf68
DIFF: https://github.com/llvm/llvm-project/commit/5cde5a5739069a4be7f86a17bd20cc6e8f2daf68.diff

LOG: [mlir] add interchange, pad and scalarize to structured transform dialect

Add ops to the structured transform extension of the transform dialect that
perform interchange, padding and scalarization on structured ops. Together
with the already-defined tiling op, this provides a minimal set of
transformations necessary to build vectorizable code for a single structured
op.
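
As an illustration, the new ops compose into a pipeline over operation
handles. The following sketch mirrors the transform-op-scalarize.mlir test
added below (the @pdl_target pattern and the linalg.matmul payload are
assumed to exist as in that test):

    transform.with_pdl_patterns {
    ^bb0(%arg0: !pdl.operation):
      pdl.pattern @pdl_target : benefit(1) {
        %args = operands
        %results = types
        %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
        rewrite %0 with "transform.dialect"
      }

      transform.sequence %arg0 {
      ^bb1(%arg1: !pdl.operation):
        %0 = pdl_match @pdl_target in %arg1
        // Tile the first dimension by 10, producing a dynamically sized tail
        // tile, then scalarize the resulting dynamic dimension down to 1.
        %1, %loops = transform.structured.tile %0 {sizes = [10, 0, 0]}
        %2 = transform.structured.scalarize %1
      }
    }

structured.pad and structured.interchange consume and produce handles in the
same way, so they can be inserted at any point of such a sequence.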

Define two helper traits: one that implements TransformOpInterface by applying
a function to each payload op independently and another that provides a simple
"functional-style" producer/consumer list of memory effects for the transform
ops.
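
Concretely, an op that attaches both traits only needs to declare an
applyToOne method: TransformEachOpTrait supplies the TransformOpInterface
apply() implementation and FunctionalStyleTransformOpTrait supplies
getEffects(). A minimal ODS sketch (MyTransformOp is hypothetical and merely
mirrors the InterchangeOp definition in this patch, assuming the same
includes as LinalgTransformOps.td):

    def MyTransformOp : Op<Transform_Dialect, "structured.my_transform",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
         TransformOpInterface, TransformEachOpTrait]> {
      let arguments = (ins PDL_Operation:$target);
      let results = (outs PDL_Operation:$transformed);
      let assemblyFormat = "$target attr-dict";
      let extraClassDeclaration = [{
        // Per-payload-op rewrite; returns the replacement op or failure.
        ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
            ::mlir::linalg::LinalgOp target);
      }];
    }

The C++ side then implements only MyTransformOp::applyToOne; iteration over
payload ops, result handle bookkeeping, and memory effects come from the
traits.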

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D126374

Added: 
    mlir/test/Dialect/Linalg/transform-op-interchange.mlir
    mlir/test/Dialect/Linalg/transform-op-pad.mlir
    mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
    mlir/test/Dialect/Linalg/transform-op-tile.mlir
    mlir/test/Dialect/Linalg/transform-ops-invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 9213e0ddef72..ddd9aa2aef0d 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -9,9 +9,16 @@
 #ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
 #define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
 
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
 
+namespace mlir {
+namespace linalg {
+class LinalgOp;
+} // namespace linalg
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // Linalg Transform Operations
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 34a5e9f7140a..3557f3323779 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -16,6 +16,81 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+    TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Interchanges the iterators of the operations pointed to by the target handle
+    using the iterator interchange attribute.
+  }];
+
+  let arguments =
+    (ins PDL_Operation:$target,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
+def PadOp : Op<Transform_Dialect, "structured.pad",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Pads the operations pointed to by the target handle using the options
+    provided as operation attributes.
+  }];
+
+  let arguments =
+    (ins PDL_Operation:$target,
+         DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$hoist_paddings,
+         DefaultValuedAttr<
+          TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
+          "{}">:$transpose_paddings);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
+def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Indicates that ops of a specific kind in the given function should be
+    scalarized (i.e. their dynamic dimensions tiled by 1).
+
+    This operation returns the tiled op but not the loops.
+
+    We make this design choice because it is hard to know ahead of time the
+    number of loops that will be produced (it depends on the number of dynamic
+    dimensions after multiple transformations have been applied).
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index d029c214b49e..d08267c72c4d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -397,6 +397,31 @@ class PossibleTopLevelTransformOpTrait
   }
 };
 
+/// Trait implementing the TransformOpInterface for operations applying a
+/// transformation to a single operation handle and producing a single operation
+/// handle. The op must implement a method with one of the following signatures:
+///   - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
+///   - LogicalResult applyToOne(OpTy)
+/// to perform a transformation that is applied in turn to all payload IR
+/// operations that correspond to the handle of the transform IR operation.
+/// In the functions above, OpTy is either Operation * or a concrete payload IR
+/// Op class that the transformation is applied to (NOT the class of the
+/// transform IR op). The op is expected to have one operand and zero or one
+/// results.
+template <typename OpTy>
+class TransformEachOpTrait
+    : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
+public:
+  /// Calls `applyToOne` for every payload operation associated with the operand
+  /// of this transform IR op. If `applyToOne` returns ops, associates them with
+  /// the result of this transform op.
+  LogicalResult apply(TransformResults &transformResults,
+                      TransformState &state);
+
+  /// Checks that the op matches the expectations of this trait.
+  static LogicalResult verifyTrait(Operation *op);
+};
+
 /// Side effect resource corresponding to the mapping between Transform IR
 /// values and Payload IR operations. An Allocate effect from this resource
 /// means creating a new mapping entry, it is always accompanied by a Write
@@ -426,9 +451,150 @@ struct PayloadIRResource
   StringRef getName() override { return "transform.payload_ir"; }
 };
 
+/// Trait implementing the MemoryEffectOpInterface for single-operand
+/// single-result operations that "consume" their operand and produce a new
+/// result.
+template <typename OpTy>
+class FunctionalStyleTransformOpTrait
+    : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
+public:
+  /// This op "consumes" the operand by reading and freeing it, "produces" the
+  /// result by allocating and writing it and reads/writes the payload IR in the
+  /// process.
+  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+    effects.emplace_back(MemoryEffects::Read::get(),
+                         this->getOperation()->getOperand(0),
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Free::get(),
+                         this->getOperation()->getOperand(0),
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Allocate::get(),
+                         this->getOperation()->getResult(0),
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(),
+                         this->getOperation()->getResult(0),
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+  }
+
+  /// Checks that the op matches the expectations of this trait.
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
+                  "expected single-operand op");
+    static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
+                  "expected single-result op");
+    if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
+      return op->emitError()
+             << "FunctionalStyleTransformOpTrait should only be attached to "
+                "ops that implement MemoryEffectOpInterface";
+    }
+    return success();
+  }
+};
+
 } // namespace transform
 } // namespace mlir
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
 
+namespace mlir {
+namespace transform {
+namespace detail {
+/// Appends `result` to the vector assuming it corresponds to the success state
+/// in `FailureOr<convertible-to-Operation*>`. If `result` is just a
+/// `LogicalResult`, does nothing.
+template <typename Ty>
+std::enable_if_t<std::is_same<Ty, LogicalResult>::value, LogicalResult>
+appendTransformResultToVector(Ty result,
+                              SmallVectorImpl<Operation *> &results) {
+  return result;
+}
+template <typename Ty>
+std::enable_if_t<!std::is_same<Ty, LogicalResult>::value, LogicalResult>
+appendTransformResultToVector(Ty result,
+                              SmallVectorImpl<Operation *> &results) {
+  static_assert(
+      std::is_convertible<typename Ty::value_type, Operation *>::value,
+      "expected transform function to return operations");
+  if (failed(result))
+    return failure();
+
+  results.push_back(*result);
+  return success();
+}
+
+/// Applies a one-to-one transform to each of the given targets. Puts the
+/// results of transforms, if any, in `results` in the same order. Fails if any
+/// of the applications fails. Individual transforms must be callable with
+/// one of the following signatures:
+///   - FailureOr<convertible-to-Operation*>(OpTy)
+///   - LogicalResult(OpTy)
+/// where OpTy is either
+///   - Operation *, in which case the transform is always applied;
+///   - a concrete Op class, in which case a check is performed whether
+///   `targets` contains operations of the same class and a failure is reported
+///   if it does not.
+template <typename FnTy>
+LogicalResult applyTransformToEach(ArrayRef<Operation *> targets,
+                                   SmallVectorImpl<Operation *> &results,
+                                   FnTy transform) {
+  using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
+  static_assert(std::is_convertible<OpTy, Operation *>::value,
+                "expected transform function to take an operation");
+  using RetTy = typename llvm::function_traits<FnTy>::result_t;
+  static_assert(std::is_convertible<RetTy, LogicalResult>::value,
+                "expected transform function to return LogicalResult or "
+                "FailureOr<convertible-to-Operation*>");
+  for (Operation *target : targets) {
+    auto specificOp = dyn_cast<OpTy>(target);
+    if (!specificOp)
+      return failure();
+
+    auto result = transform(specificOp);
+    if (failed(appendTransformResultToVector(result, results)))
+      return failure();
+  }
+  return success();
+}
+} // namespace detail
+} // namespace transform
+} // namespace mlir
+
+template <typename OpTy>
+mlir::LogicalResult mlir::transform::TransformEachOpTrait<OpTy>::apply(
+    TransformResults &transformResults, TransformState &state) {
+  using TransformOpType = typename llvm::function_traits<
+      decltype(&OpTy::applyToOne)>::template arg_t<0>;
+  ArrayRef<Operation *> targets =
+      state.getPayloadOps(this->getOperation()->getOperand(0));
+  SmallVector<Operation *> results;
+  if (failed(detail::applyTransformToEach(
+          targets, results, [&](TransformOpType specificOp) {
+            return static_cast<OpTy *>(this)->applyToOne(specificOp);
+          })))
+    return failure();
+  if (OpTy::template hasTrait<OpTrait::OneResult>()) {
+    transformResults.set(
+        this->getOperation()->getResult(0).template cast<OpResult>(), results);
+  }
+  return success();
+}
+
+template <typename OpTy>
+mlir::LogicalResult
+mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {
+  static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
+                "expected single-operand op");
+  static_assert(OpTy::template hasTrait<OpTrait::OneResult>() ||
+                    OpTy::template hasTrait<OpTrait::ZeroResults>(),
+                "expected zero- or single-result op");
+  if (!op->getName().getInterface<TransformOpInterface>()) {
+    return op->emitError() << "TransformEachOpTrait should only be attached to "
+                              "ops that implement TransformOpInterface";
+  }
+
+  return success();
+}
+
 #endif // DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index 5b8d4202f0b4..fad845cc7f51 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -49,4 +49,13 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
   ];
 }
 
+def FunctionalStyleTransformOpTrait
+    : NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
+  let cppNamespace = "::mlir::transform";
+}
+
+def TransformEachOpTrait : NativeOpTrait<"TransformEachOpTrait"> {
+  let cppNamespace = "::mlir::transform";
+}
+
 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 89c5815c6295..689157d01a83 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -49,6 +49,179 @@ class SimpleRewriter : public PatternRewriter {
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// InterchangeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
+  SmallVector<unsigned> interchangeVector =
+      extractUIntArray(getIteratorInterchange());
+  // Exit early if no transformation is needed.
+  if (interchangeVector.empty())
+    return target;
+
+  auto genericTarget = dyn_cast<GenericOp>(target.getOperation());
+  if (!genericTarget) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "applies to " << GenericOp::getOperationName()
+                              << " ops";
+    diag.attachNote(target.getLoc()) << "attempted to apply to this op";
+    return diag;
+  }
+
+  GenericOpInterchangePattern pattern(getContext(), interchangeVector);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<GenericOp> result =
+      pattern.returningMatchAndRewrite(genericTarget, rewriter);
+  if (failed(result))
+    return failure();
+
+  return cast<LinalgOp>(result->getOperation());
+}
+
+LogicalResult transform::InterchangeOp::verify() {
+  SmallVector<unsigned> permutation =
+      extractUIntArray(getIteratorInterchange());
+  auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
+  if (!std::is_permutation(sequence.begin(), sequence.end(),
+                           permutation.begin(), permutation.end())) {
+    return emitOpError()
+           << "expects iterator_interchange to be a permutation, found "
+           << getIteratorInterchange();
+  }
+  return success();
+}
+
+//===---------------------------------------------------------------------===//
+// PadOp
+//===---------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
+  // Convert the integer packing flags to booleans.
+  SmallVector<bool> packPaddings;
+  for (int64_t packPadding : extractI64Array(getPackPaddings()))
+    packPaddings.push_back(static_cast<bool>(packPadding));
+
+  // Convert the padding values to attributes.
+  SmallVector<Attribute> paddingValues;
+  for (auto const &it :
+       llvm::zip(getPaddingValues(), target->getOperandTypes())) {
+    Attribute attr = std::get<0>(it);
+    Type elementType = getElementTypeOrSelf(std::get<1>(it));
+    // Try to parse string attributes to obtain an attribute of element type.
+    if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+      paddingValues.push_back(
+          parseAttribute(attr.cast<StringAttr>(), elementType));
+      if (!paddingValues.back()) {
+        InFlightDiagnostic diag = emitOpError()
+                                  << "expects a padding value that parses to "
+                                  << elementType << ", got " << std::get<0>(it);
+        diag.attachNote(target.getLoc()) << "when applied to this op";
+        return diag;
+      }
+      continue;
+    }
+    // Otherwise, add the attribute directly.
+    if (attr.getType() != elementType) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "expects a padding value of type "
+                                << elementType << ", got " << attr;
+      diag.attachNote(target.getLoc()) << "when applied to this op";
+      return diag;
+    }
+    paddingValues.push_back(attr);
+  }
+
+  // Extract the transpose vectors.
+  SmallVector<SmallVector<int64_t>> transposePaddings;
+  for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
+    transposePaddings.push_back(
+        extractI64Array(transposeVector.cast<ArrayAttr>()));
+
+  LinalgPaddingOptions paddingOptions;
+  paddingOptions.setPaddingValues(paddingValues);
+  paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
+  paddingOptions.setPackPaddings(packPaddings);
+  paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
+  paddingOptions.setTransposePaddings(transposePaddings);
+
+  LinalgPaddingPattern pattern(getContext(), paddingOptions);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<LinalgOp> patternResult =
+      pattern.returningMatchAndRewrite(target, rewriter);
+  if (failed(patternResult)) {
+    InFlightDiagnostic diag = emitError()
+                              << "failed to apply pattern to target op";
+    diag.attachNote(target.getLoc()) << "target op";
+    return diag;
+  }
+  return patternResult;
+}
+
+LogicalResult transform::PadOp::verify() {
+  SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
+  if (any_of(packPaddings, [](int64_t packPadding) {
+        return packPadding != 0 && packPadding != 1;
+      })) {
+    return emitOpError()
+           << "expects pack_paddings to contain booleans (0/1), found "
+           << getPackPaddings();
+  }
+
+  SmallVector<int64_t> paddingDimensions =
+      extractI64Array(getPaddingDimensions());
+  if (any_of(paddingDimensions,
+             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
+    return emitOpError()
+           << "expects padding_dimensions to contain positive integers, found "
+           << getPaddingDimensions();
+  }
+
+  SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
+  if (any_of(hoistPaddings,
+             [](int64_t hoistPadding) { return hoistPadding < 0; })) {
+    return emitOpError()
+           << "expects hoist_paddings to contain positive integers, found "
+           << getHoistPaddings();
+  }
+
+  ArrayAttr transposes = getTransposePaddings();
+  for (Attribute attr : transposes) {
+    SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
+    auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
+    if (!std::is_permutation(sequence.begin(), sequence.end(),
+                             transpose.begin(), transpose.end())) {
+      return emitOpError()
+             << "expects transpose_paddings to be a permutation, found "
+             << attr;
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ScalarizeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
+  LinalgTilingOptions tilingOptions;
+  tilingOptions.scalarizeDynamicDims();
+  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
+  // sizes and asserts that it is not already set.
+  SmallVector<int64_t> emptyTileSizes;
+  LinalgTilingPattern pattern(getContext(), tilingOptions);
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<TiledLinalgOp> result =
+      pattern.returningMatchAndRewrite(target, rewriter);
+  if (failed(result))
+    return failure();
+
+  return result->op;
+}
+
 //===----------------------------------------------------------------------===//
 // TileOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
new file mode 100644
index 000000000000..cb8badec2004
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+//       CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @interchange_generic
+func.func @interchange_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  //      CHECK:   linalg.generic
+  // CHECK-SAME:   indexing_maps = [#[[$MAP]], #[[$MAP]]
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    %1 = math.exp %arg2 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_generic : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_generic in %arg1
+    transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+  }
+}
+
+// -----
+
+func.func @interchange_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-note @below {{attempted to apply to this op}}
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @match_generic : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @match_generic in %arg1
+    // expected-error @below {{applies to linalg.generic ops}}
+    transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+  }
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
new file mode 100644
index 000000000000..2a9025aa1f6a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter -split-input-file -verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
+// CHECK-LABEL: @static_sizes_output_divisible
+func.func @static_sizes_output_divisible(%arg0: tensor<24x12xf32>,
+                                         %arg1: tensor<12x25xf32>,
+                                         %arg2: tensor<24x25xf32>,
+                                         %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map()[%iv2]
+
+  //      CHECK: %[[T0:.*]] = tensor.extract_slice %
+  //      CHECK: %[[T1:.*]] = tensor.extract_slice %
+  //      CHECK: %[[T2:.*]] = tensor.extract_slice %
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+  //  CHECK-DAG: %[[CST:.*]] = arith.constant 0.
+  //  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  //      CHECK: %[[T3:.*]] = tensor.pad %[[T0]] nofold
+  //      CHECK: tensor.yield %[[CST]]
+  //      CHECK: %[[T4:.*]] = tensor.pad %[[T1]] nofold
+
+  //      CHECK: %[[T5:.*]] = linalg.matmul
+  // CHECK-SAME:              ins(%[[T3]], %[[T4]] : tensor<4x7xf32>, tensor<7x5xf32>)
+  // CHECK-SAME:              outs(%[[T2]] : tensor<4x5xf32>)
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  func.return %5 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+  }
+}
+
+// -----
+
+func.func @pad(%arg0: tensor<24x12xf32>,
+               %arg1: tensor<12x25xf32>,
+               %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // expected-note @below {{when applied to this op}} 
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    // expected-error @below {{op expects a padding value of type 'f32', got 0 : i32}}
+    %1 = transform.structured.pad %0 {padding_values=[0: i32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+  }
+}
+
+// -----
+
+func.func @pad(%arg0: tensor<24x12xf32>,
+               %arg1: tensor<12x25xf32>,
+               %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // expected-note @below {{when applied to this op}}
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    // expected-error @below {{expects a padding value that parses to 'f32', got "foo"}}
+    %1 = transform.structured.pad %0 {padding_values=["foo", 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+  }
+}
+
+// -----
+
+func.func @pad(%arg0: tensor<24x12xf32>,
+               %arg1: tensor<12x25xf32>,
+               %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    // expected-error @below {{failed to apply pattern to target op}}
+    %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+  }
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
new file mode 100644
index 000000000000..ab25777adeef
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter %s | FileCheck %s
+
+func.func @scalarize(%arg0: tensor<24x12xf32>,
+                     %arg1: tensor<12x25xf32>,
+                     %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // The op is first tiled by 10 in the first dimension, which creates a
+  // dynamic size, and then scalarized, which brings the dimension to static 1.
+  // CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x12
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1, %loops = transform.structured.tile %0 {sizes = [10, 0, 0]}
+    %2 = transform.structured.scalarize %1
+  }
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
new file mode 100644
index 000000000000..a5310944357d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+    ^bb0(%arg1: !pdl.operation):
+      %0 = pdl_match @pdl_target in %arg1
+      %1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+// CHECK-LABEL: func @tile_linalg_matmul(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:  -> tensor<128x128xf32> {
+func.func @tile_linalg_matmul(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32> {
+//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
+//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
+//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
+//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
+//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
+//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
+//      CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
+// CHECK-SAME:                                   outs(%[[sTC]] : tensor<4x4xf32>)  -> tensor<4x4xf32>
+//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<4x4xf32> into tensor<128x128xf32>
+//      CHECK:       scf.yield %[[TD]] : tensor<128x128xf32>
+//      CHECK:     scf.yield %[[TD2]] : tensor<128x128xf32>
+//      CHECK:   scf.yield %[[TD1]] : tensor<128x128xf32>
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+
+//      CHECK: return %[[TD0]] : tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+

diff  --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
new file mode 100644
index 000000000000..9b80b2300f77
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}}
+  transform.structured.interchange %arg0 {iterator_interchange = [1, 1]}
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects padding_dimensions to contain positive integers, found [1, -7]}}
+  transform.structured.pad %arg0 {padding_dimensions=[1, -7]}
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects pack_paddings to contain booleans (0/1), found [1, 7]}}
+  transform.structured.pad %arg0 {pack_paddings=[1, 7]}
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects hoist_paddings to contain positive integers, found [1, -7]}}
+  transform.structured.pad %arg0 {hoist_paddings=[1, -7]}
+}
+
+// -----
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}}
+  transform.structured.pad %arg0 {transpose_paddings=[[1, 1]]}
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
index a5310944357d..ae01f3d571d3 100644
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -1,46 +1,31 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
-transform.with_pdl_patterns {
-^bb0(%arg0: !pdl.operation):
-  sequence %arg0 {
-    ^bb0(%arg1: !pdl.operation):
-      %0 = pdl_match @pdl_target in %arg1
-      %1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
-  }
-
-  pdl.pattern @pdl_target : benefit(1) {
-    %args = operands
-    %results = types
-    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
-    rewrite %0 with "transform.dialect"
-  }
+transform.sequence {
+^bb1(%arg0: !pdl.operation):
+  // CHECK: %{{.*}}, %{{.*}}:2 = transform.structured.tile
+  %0, %1:2 = transform.structured.tile %arg0 { sizes = [2, 0, 3] }
 }
 
-// CHECK-LABEL: func @tile_linalg_matmul(
-// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-SAME:  -> tensor<128x128xf32> {
-func.func @tile_linalg_matmul(
-  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
-    -> tensor<128x128xf32> {
-//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
-//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
-//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
-//      CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-//      CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-//      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<4x4xf32>
-//      CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<4x4xf32>, tensor<4x4xf32>)
-// CHECK-SAME:                                   outs(%[[sTC]] : tensor<4x4xf32>)  -> tensor<4x4xf32>
-//      CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<4x4xf32> into tensor<128x128xf32>
-//      CHECK:       scf.yield %[[TD]] : tensor<128x128xf32>
-//      CHECK:     scf.yield %[[TD2]] : tensor<128x128xf32>
-//      CHECK:   scf.yield %[[TD1]] : tensor<128x128xf32>
-  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
-                     outs(%arg2: tensor<128x128xf32>)
-    -> tensor<128x128xf32>
+//===----------------------------------------------------------------------===//
+// Check that operations are registered correctly through the extension
+// mechanism. Their syntax is generated and requires no additional testing since
+// we test the generator.
+//===----------------------------------------------------------------------===//
+
+transform.sequence {
+^bb1(%arg0: !pdl.operation):
+  // CHECK: transform.structured.pad
+  %0 = transform.structured.pad %arg0
+}
 
-//      CHECK: return %[[TD0]] : tensor<128x128xf32>
-  return %0 : tensor<128x128xf32>
+transform.sequence {
+^bb1(%arg0: !pdl.operation):
+  // CHECK: transform.structured.interchange
+  %0 = transform.structured.interchange %arg0
 }
 
+transform.sequence {
+^bb1(%arg0: !pdl.operation):
+  // CHECK: transform.structured.scalarize
+  %0 = transform.structured.scalarize %arg0
+}


        

