[Mlir-commits] [mlir] ff6e550 - [mlir] Structured transforms: introduce op splitting

Alex Zinenko llvmlistbot at llvm.org
Thu Jul 7 04:19:56 PDT 2022


Author: Alex Zinenko
Date: 2022-07-07T13:19:44+02:00
New Revision: ff6e5508d686395dfb5f26085fadeae174847d52

URL: https://github.com/llvm/llvm-project/commit/ff6e5508d686395dfb5f26085fadeae174847d52
DIFF: https://github.com/llvm/llvm-project/commit/ff6e5508d686395dfb5f26085fadeae174847d52.diff

LOG: [mlir] Structured transforms: introduce op splitting

Introduce a new transformation on structured ops that splits the iteration
space into two parts along the specified dimension. The index at which the
splitting happens may be static or dynamic. This transformation can be seen as
a rudimentary form of index-set splitting that only supports the splitting
along hyperplanes parallel to the iteration space hyperplanes, and is therefore
decomposable into per-dimension application.

It is a key low-level transformation that enables independent scheduling for
different parts of the iteration space of the same op, which hasn't been
possible previously. It may be used to implement, e.g., multi-sized tiling. In
the future, peeling can be implemented as a combination of split-off amount
computation and splitting.

The transformation is conceptually close to tiling in its separation of the
iteration and data spaces, but cannot be currently implemented on top of
TilingInterface as the latter does not properly support `linalg.index`
offsetting.

Note that the transformation intentionally bypasses folding of
`tensor.extract_slice` operations when creating them as this folding was found
to prevent repeated splitting of the same operation due to internal
assumptions about extract/insert_slice combinations in dialect utilities.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129090

Added: 
    mlir/lib/Dialect/Linalg/Transforms/Split.cpp
    mlir/test/Dialect/Linalg/transform-op-split.mlir

Modified: 
    mlir/include/mlir-c/BuiltinTypes.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/lib/CAPI/IR/BuiltinTypes.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 495591464a494..d1083f9323bf0 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -150,10 +150,19 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
 /// in shaped types.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size);
 
+/// Returns the value indicating a dynamic size in a shaped type. Prefer
+/// mlirShapedTypeIsDynamicSize to direct comparisons with this value.
+MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void);
+
 /// Checks whether the given value is used as a placeholder for dynamic strides
 /// and offsets in shaped types.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
 
+/// Returns the value indicating a dynamic stride or offset in a shaped type.
+/// Prefer mlirShapedTypeIsDynamicStrideOrOffset to direct comparisons with
+/// this value.
+MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);
+
 //===----------------------------------------------------------------------===//
 // Vector type.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 3e47a0751eaca..6c1d1fef4ee50 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -25,6 +25,7 @@ namespace mlir {
 class AffineApplyOp;
 class AffineBound;
 class AffineValueMap;
+class IRRewriter;
 
 /// TODO: These should be renamed if they are on the mlir namespace.
 ///       Ideally, they should go in a mlir::affine:: namespace.
@@ -384,6 +385,12 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                        AffineMap map, ValueRange values);
 
+/// Returns the values obtained by applying `map` to `values`, which may be
+/// known constants; results that do not fold become `affine.apply` results.
+SmallVector<OpFoldResult> applyMapToValues(IRRewriter &b, Location loc,
+                                           AffineMap map,
+                                           ArrayRef<OpFoldResult> values);
+
 /// Given an affine map `map` and its input `operands`, this method composes
 /// into `map`, maps of AffineApplyOps whose results are the values in
 /// `operands`, iteratively until no more of `operands` are the result of an

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 461388dd61af3..13ef4656d5077 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -153,6 +153,38 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+def SplitOp : Op<Transform_Dialect, "structured.split",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Indicates that the given `target` op should be split into two complementary
+    parts, which combined cover the entire iteration domain of the original op.
+    The split is performed along the iteration space dimension given by the
+    `dimension` attribute; if the op has no such dimension, the transformation
+    fails. The split is performed at the dimension iterator value specified
+    either as the static split point attribute, when it is known at transform
+    IR construction time, or as the handle to an operation producing a single
+    index-typed value, when it is computed by payload IR. In the latter case,
+    the static split point must be set to `ShapedType::kDynamicSize` and the
+    `dynamic_split_point` handle must point to as many value-producing
+    operations as there are operations pointed to by the target handle.
+
+    The operation consumes the target handle, but preserves the split point
+    handle if provided. It produces two new handles pointing to the two parts
+    of the structured op after splitting, in the same order as the target
+    operand, with the first handle corresponding to the part with lower
+    iteration space indices.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       I64Attr:$dimension,
+                       Optional<PDL_Operation>:$dynamic_split_point,
+                       I64Attr:$static_split_point);
+  let results = (outs PDL_Operation:$first, PDL_Operation:$second);
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
         TransformEachOpTrait, TransformOpInterface]> {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6b3230ade0033..748f5440457ad 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -106,6 +106,37 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 /// Patterns that are used to bubble up extract slice op above linalg op.
 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 
+/// Split the given `op` into two parts along the given iteration space
+/// `dimension` at the specified `splitPoint`, and return the two parts.
+/// The split point is clamped to the size of `dimension`, so the second part
+/// may have an empty iteration space. If `dimension` is not a valid iteration
+/// space dimension of `op`, returns `op` itself and a null op.
+///
+/// For example, the following op:
+///
+///   linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
+///                 outs(%2 : tensor<128x64xf32>)
+///
+/// split along the first dimension at position 42 will result in:
+///
+///   %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
+///   %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
+///   %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
+///                      outs(%4 : tensor<42x64xf32>)
+///   %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
+///
+///   %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
+///   %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
+///   %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
+///                      outs(%8 : tensor<86x64xf32>)
+///   %10 = tensor.insert_slice %9 into %6[42, 0][86, 64][1, 1]
+///
+/// Note that there is no simplification other than constant propagation applied
+/// to slice extraction and insertion.
+std::pair<LinalgOp, LinalgOp> splitOp(RewriterBase &rewriter, LinalgOp op,
+                                      unsigned dimension,
+                                      OpFoldResult splitPoint);
+
 /// Perform standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
 /// The permutation is expressed as a list of integers that specify

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index a68701b1b04a1..9914f0a32872a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -177,12 +177,18 @@ bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
 bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
                    Value consumedView, LinalgOp producer);
 
-/// Compute tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
+/// Creates either a memref.subview or a tensor.extract_slice of `value`, as
+/// appropriate for its type, with the given offsets/sizes/strides.
+Value createSlice(OpBuilder &builder, Location loc, Value value,
+                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+                  ArrayRef<OpFoldResult> strides);
+
+/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
 /// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes);
 
-/// Compute tile sizes, given a list of `tileSizes` and dimension
+/// Computes tile sizes, given a list of `tileSizes` and dimension
 /// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
 /// corresponding result size is the corresponding value from `sizeBounds`.
 /// Note: The returned tile sizes are closed intervals.
@@ -190,6 +196,20 @@ SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
                                     ValueRange tileSizes,
                                     ArrayRef<Value> sizeBounds);
 
+/// Returns the list of tensor output types produced when the given structured
+/// operation `op` is applied to the given `operands`, which need not be the
+/// actual operands of `op` but are indexed like them (by operand number).
+SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands);
+
+/// Creates `insert_slice` ops that insert `results` back into the larger
+/// tensors they were originally extracted from with `extract_slice` before
+/// being passed as `operands` to the given structured operation `op` or its
+/// clone. `operands` are not necessarily the actual operands of `op`; the
+/// operation only provides operand types and positions.
+SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
+                                    LinalgOp op, ValueRange operands,
+                                    ValueRange results);
+
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
 /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index d93d9f66b159b..153664d0771dd 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -301,6 +301,15 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
           return shape;
         },
         "Returns the shape of the ranked shaped type as a list of integers.");
+    // Static accessors for the C API sentinel values that mark dynamic
+    // dimensions and dynamic strides/offsets.
+    c.def_static("_get_dynamic_size", &mlirShapedTypeGetDynamicSize,
+                 "Returns the value used to indicate dynamic dimensions in "
+                 "shaped types.");
+    c.def_static("_get_dynamic_stride_or_offset",
+                 &mlirShapedTypeGetDynamicStrideOrOffset,
+                 "Returns the value used to indicate dynamic strides or "
+                 "offsets in shaped types.");
   }
 
 private:

diff  --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 446f9c4d4889c..be44b76e8c615 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -149,6 +149,10 @@ int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
   return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
 }
 
+int64_t mlirShapedTypeGetDynamicSize() {
+  return ShapedType::kDynamicSize;
+}
+
 bool mlirShapedTypeIsDynamicSize(int64_t size) {
   return ShapedType::isDynamic(size);
 }
@@ -157,6 +161,10 @@ bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
   return ShapedType::isDynamicStrideOrOffset(val);
 }
 
+int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
+  return ShapedType::kDynamicStrideOrOffset;
+}
+
 //===----------------------------------------------------------------------===//
 // Vector type.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 9986a9235101c..075a6cc11459c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -748,6 +748,82 @@ SmallVector<Value, 4> mlir::applyMapToValues(OpBuilder &b, Location loc,
   return res;
 }
 
+SmallVector<OpFoldResult>
+mlir::applyMapToValues(IRRewriter &b, Location loc, AffineMap map,
+                       ArrayRef<OpFoldResult> values) {
+  // Materialize constants and keep track of produced operations so we can clean
+  // them up later.
+  SmallVector<Operation *> constants;
+  SmallVector<Value> actualValues;
+  actualValues.reserve(values.size());
+  auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
+  for (OpFoldResult ofr : values) {
+    if (auto value = ofr.dyn_cast<Value>()) {
+      actualValues.push_back(value);
+      continue;
+    }
+    Operation *constant = dialect->materializeConstant(
+        b, ofr.get<Attribute>(), b.getIndexType(), loc);
+    assert(constant && "failed to materialize an index constant");
+    constants.push_back(constant);
+    actualValues.push_back(constant->getResult(0));
+  }
+
+  // Compose, fold and construct maps for each result independently because they
+  // may simplify more effectively.
+  SmallVector<OpFoldResult> results;
+  results.reserve(map.getNumResults());
+  for (auto i : llvm::seq<unsigned>(0, map.getNumResults())) {
+    AffineMap submap = map.getSubMap({i});
+    SmallVector<Value> operands = actualValues;
+    fullyComposeAffineMapAndOperands(&submap, &operands);
+    canonicalizeMapAndOperands(&submap, &operands);
+
+    // Identify the constant operands and extract their values as attributes.
+    // Note that we cannot use the original values directly because the list of
+    // operands may have changed due to canonicalization and composition.
+    SmallVector<Attribute> constantOperands;
+    constantOperands.reserve(operands.size());
+    for (Value operand : operands) {
+      IntegerAttr attr;
+      if (matchPattern(operand, m_Constant(&attr)))
+        constantOperands.push_back(attr);
+      else
+        constantOperands.push_back(nullptr);
+    }
+
+    // Create an apply operation and immediately attempt to fold it. On success,
+    // erase the operation and return the folded result. On failure, keep the
+    // operation and return its result.
+    // TODO: arguably, the main folder (createOrFold) API should support this
+    // use case instead of indiscriminately materializing constants.
+    auto apply = b.create<AffineApplyOp>(loc, submap, operands);
+    SmallVector<OpFoldResult, 1> foldResult;
+    if (succeeded(apply->fold(constantOperands, foldResult))) {
+      assert(foldResult.size() == 1 && "expected single-result map");
+      b.eraseOp(apply);
+      results.push_back(foldResult.front());
+    } else {
+      results.push_back(apply.getResult());
+    }
+  }
+
+  // Erase the materialized constants that are no longer needed. Checking each
+  // constant individually also cleans up after maps that fold only partially,
+  // while keeping constants that feed a remaining affine.apply or that a fold
+  // returned directly as a result value.
+  for (Operation *constant : constants) {
+    Value constantValue = constant->getResult(0);
+    bool isReturned = llvm::any_of(results, [&](OpFoldResult ofr) {
+      return ofr.dyn_cast<Value>() == constantValue;
+    });
+    if (constant->use_empty() && !isReturned)
+      b.eraseOp(constant);
+  }
+
+  return results;
+}
+
 // A symbol may appear as a dim in affine.apply operations. This function
 // canonicalizes dims that are valid symbols into actual symbols.
 template <class MapOrSet>

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b644848c53172..26e5db6b2e91a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -399,6 +399,165 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
   return result->op;
 }
 
+//===----------------------------------------------------------------------===//
+// SplitOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
+                                           TransformState &state) {
+  // Collect the dynamic split points if provided.
+  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
+  SimpleRewriter rewriter(getContext());
+  SmallVector<OpFoldResult> splitPoints;
+  splitPoints.reserve(payload.size());
+  if (getDynamicSplitPoint()) {
+    for (Operation *op : state.getPayloadOps(getDynamicSplitPoint())) {
+      // Validate the op before touching its results: calling getResult(0) on
+      // an op without results is invalid. Report the first offending op and
+      // stop, rather than overwriting an already-emitted failure.
+      if (op->getNumResults() != 1 ||
+          !op->getResult(0).getType().isIndex()) {
+        auto diag = emitSilenceableError()
+                    << "expected dynamic split point handle to point to a "
+                       "single-result index-typed op";
+        diag.attachNote(op->getLoc()) << "dynamic split point";
+        return diag;
+      }
+      splitPoints.push_back(OpFoldResult(op->getResult(0)));
+    }
+
+    if (splitPoints.size() != payload.size()) {
+      emitError() << "expected the dynamic split point handle to point to as "
+                     "many operations ("
+                  << splitPoints.size() << ") as the target handle ("
+                  << payload.size() << ")";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+  } else {
+    splitPoints.resize(payload.size(),
+                       rewriter.getIndexAttr(getStaticSplitPoint()));
+  }
+
+  // Check all targets before splitting any of them so that a failure does not
+  // leave the payload IR partially transformed.
+  for (Operation *target : payload) {
+    auto linalgOp = dyn_cast<LinalgOp>(target);
+    if (!linalgOp) {
+      auto diag = emitSilenceableError() << "only applies to structured ops";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    if (getDimension() >= linalgOp.getNumLoops()) {
+      auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                         << " does not exist in target op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+  }
+
+  // Split each target operation.
+  SmallVector<Operation *> first, second;
+  for (const auto &pair : llvm::zip(payload, splitPoints)) {
+    auto linalgOp = cast<LinalgOp>(std::get<0>(pair));
+    rewriter.setInsertionPoint(linalgOp);
+    std::tie(first.emplace_back(), second.emplace_back()) =
+        linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
+  }
+
+  results.set(getFirst().cast<OpResult>(), first);
+  results.set(getSecond().cast<OpResult>(), second);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void SplitOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // The target handle is consumed.
+  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
+                       TransformMappingResource::get());
+  effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
+                       TransformMappingResource::get());
+
+  // The dynamic split point handle is not consumed.
+  if (getDynamicSplitPoint()) {
+    effects.emplace_back(MemoryEffects::Read::get(), getDynamicSplitPoint(),
+                         TransformMappingResource::get());
+  }
+
+  // The resulting handles are produced.
+  for (Value result : getResults()) {
+    effects.emplace_back(MemoryEffects::Allocate::get(), result,
+                         TransformMappingResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), result,
+                         TransformMappingResource::get());
+  }
+
+  effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
+}
+
+ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
+  IntegerAttr staticSplitPoint;
+  auto pdlOperationType =
+      pdl::OperationType::get(parser.getBuilder().getContext());
+  if (parser.parseOperand(target) ||
+      parser.resolveOperand(target, pdlOperationType, result.operands) ||
+      parser.parseKeyword("after"))
+    return failure();
+
+  OptionalParseResult dynamicPointParseResult =
+      parser.parseOptionalOperand(dynamicSplitPoint);
+  if (!dynamicPointParseResult.hasValue()) {
+    int64_t staticSplitPointValue;
+    if (failed(parser.parseInteger(staticSplitPointValue)))
+      return failure();
+
+    staticSplitPoint =
+        parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
+  } else {
+    if (failed(*dynamicPointParseResult) ||
+        parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
+                              result.operands)) {
+      return failure();
+    }
+
+    staticSplitPoint =
+        parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
+  }
+
+  result.addAttribute(
+      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
+      staticSplitPoint);
+  if (failed(parser.parseOptionalAttrDict(result.attributes)))
+    return failure();
+
+  result.addTypes({pdlOperationType, pdlOperationType});
+  return success();
+}
+
+void SplitOp::print(OpAsmPrinter &printer) {
+  printer << " " << getTarget() << " after ";
+  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
+  if (staticSplitSize != ShapedType::kDynamicSize)
+    printer << staticSplitSize;
+  else
+    printer << getDynamicSplitPoint();
+  printer << " ";
+  printer.printOptionalAttrDict(getOperation()->getAttrs(),
+                                {getStaticSplitPointAttrName()});
+}
+
+LogicalResult SplitOp::verify() {
+  if ((static_cast<int64_t>(getStaticSplitPoint()) !=
+       ShapedType::kDynamicSize) ^
+      (getDynamicSplitPoint() == nullptr)) {
+    return emitOpError()
+           << "expects either a dynamic or a static split point to be provided";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplitReductionOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d6f4061a71785..8015edeb59a92 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   NamedOpConversions.cpp
   Promotion.cpp
   SparseTensorRewriting.cpp
+  Split.cpp
   SplitReduction.cpp
   Tiling.cpp
   TilingInterfaceImpl.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
new file mode 100644
index 0000000000000..7d6fb66041d3a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -0,0 +1,169 @@
+//===- Split.cpp - Structured op splitting --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Turns an OpFoldResult into a value, creating an index-typed constant if
+/// necessary.
+static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
+                                     OpFoldResult opFoldResult) {
+  if (opFoldResult.is<Value>())
+    return opFoldResult.get<Value>();
+  auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
+}
+
+/// Extract the slices of `operands` supplied to the given operation `op` such
+/// that they are sufficient to execute the op for the subset of its iteration
+/// space defined by `splitIterationSpace`. The subset is a part of the original
+/// iteration space split at the given `dimension`. If `offset` is provided, it
+/// indicates the iterator value at which the dimension has been split and
+/// requires the "high" part starting at the given offset of the operands to be
+/// generated; otherwise, the "low" part with no offset is generated. Note that
+/// `operands` are not necessarily the actual operands of `op`.
+static SmallVector<Value>
+getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op,
+                 ValueRange splitIterationSpace, ValueRange operands,
+                 unsigned dimension, Value offset = nullptr) {
+  SmallVector<Value> slices;
+  slices.reserve(op.getNumInputsAndOutputs());
+  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
+    auto type = opOperand->get().getType().dyn_cast<ShapedType>();
+    AffineMap indexing = op.getTiedIndexingMap(opOperand);
+
+    // If the type is not sliceable, or the slice is requested along the
+    // dimension that is not used in indexing this type, use the entire value
+    // from `operands`, which may differ from the original operand of `op`.
+    if (!type || dimension >= indexing.getNumDims() ||
+        !indexing.isFunctionOfDim(dimension)) {
+      slices.push_back(operands[opOperand->getOperandNumber()]);
+      continue;
+    }
+
+    SmallVector<Value, 4> sizes =
+        applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace);
+    SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1));
+
+    if (offset) {
+      // `dimension` indexes the iteration space rather than the operand, so
+      // map iteration space offsets to operand offsets via the indexing map.
+      SmallVector<OpFoldResult> iterationOffsets(indexing.getNumDims(),
+                                                 builder.getIndexAttr(0));
+      iterationOffsets[dimension] = offset;
+      IRRewriter rewriter(builder);
+      offsets = applyMapToValues(rewriter, builder.getLoc(), indexing,
+                                 iterationOffsets);
+    }
+
+    slices.push_back(createSlice(builder, op.getLoc(),
+                                 operands[opOperand->getOperandNumber()],
+                                 offsets, getAsOpFoldResult(sizes), strides));
+  }
+
+  return slices;
+}
+
+/// Creates a part of the given `op` split along the iteration space `dimension`
+/// with the given `size` and an optional `offset` (default 0). Makes slices
+/// of operands, using the input operands of the original op and the output
+/// operands provided as `resultOperands`. Expects `splitIterationSpace` to be
+/// a list of values representing the shape of the iteration space of the
+/// original op and updates it to be the iteration space of the current part.
+/// Returns the split-out op as well as the output operand values updated with
+/// the partial results produced by this op through `results`.
+static LinalgOp createSplitPart(
+    ImplicitLocOpBuilder &builder, LinalgOp op, ValueRange resultOperands,
+    llvm::MutableArrayRef<Value> splitIterationSpace, unsigned dimension,
+    Value size, SmallVectorImpl<Value> &results, Value offset = nullptr) {
+  splitIterationSpace[dimension] = size;
+  SmallVector<Value> operands = llvm::to_vector(
+      llvm::map_range(op.getInputOperands(),
+                      [](OpOperand *opOperand) { return opOperand->get(); }));
+  llvm::append_range(operands, resultOperands);
+  operands = getOperandSlices(builder, op, splitIterationSpace, operands,
+                              dimension, offset);
+  Operation *part = op.clone(builder, op.getLoc(),
+                             getTensorOutputTypes(op, operands), operands);
+  results = insertSlicesBack(builder, builder.getLoc(), op, operands,
+                             part->getResults());
+  return cast<LinalgOp>(part);
+}
+
+std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
+                                              LinalgOp op, unsigned dimension,
+                                              OpFoldResult splitPoint) {
+  // Bail out on dimension overflow.
+  if (dimension >= op.getNumLoops())
+    return std::make_pair(op, LinalgOp());
+
+  // Compute the iteration space size as values.
+  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+  SmallVector<Value, 4> allShapes =
+      op.createFlatListOfOperandDims(builder, op.getLoc());
+  AffineMap shapesToLoops = op.getShapesToLoopsMap();
+  SmallVector<Value, 4> iterationSpaceShapes =
+      applyMapToValues(builder, op.getLoc(), shapesToLoops, allShapes);
+
+  // Update the iteration space to have `splitPoint` as the size of `dimension`
+  // and use it to slice operands and results for a new, smaller instance of the
+  // `op`. Adjust the size if necessary to prevent overflows. Insert the partial
+  // results back.
+  Value splitPointValue = materializeOpFoldResult(builder, splitPoint);
+  splitPointValue = builder.createOrFold<AffineMinOp>(
+      builder.getIndexType(),
+      AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()),
+      ValueRange({splitPointValue, iterationSpaceShapes[dimension]}));
+  SmallVector<Value> splitIterationSpace =
+      llvm::to_vector(iterationSpaceShapes);
+  SmallVector<Value> originalResults = llvm::to_vector(
+      llvm::map_range(op.getOutputOperands(),
+                      [](OpOperand *opOperand) { return opOperand->get(); }));
+  SmallVector<Value> firstResults;
+  LinalgOp first =
+      createSplitPart(builder, op, originalResults, splitIterationSpace,
+                      dimension, splitPointValue, firstResults);
+
+  // Update the iteration space to cover the remaining part of the original
+  // space, then create another instance of the `op` in that space. The size of
+  // the remaining part may become zero, but is never negative because of the
+  // adjustment above.
+  AffineExpr d0 = builder.getAffineDimExpr(0);
+  AffineExpr d1 = builder.getAffineDimExpr(1);
+  SmallVector<Value, 4> remainingSizes = applyMapToValues(
+      builder, op.getLoc(), AffineMap::inferFromExprList({d0 - d1}).front(),
+      {iterationSpaceShapes[dimension], splitPointValue});
+  // With tensor semantics, the second part consumes the first part's results.
+  // Otherwise, outputs are updated in place and the first part has no tensor
+  // results to forward, so the original outputs are reused.
+  ValueRange secondOutputs = op.hasTensorSemantics()
+                                 ? ValueRange(firstResults)
+                                 : ValueRange(originalResults);
+  SmallVector<Value> secondResults;
+  LinalgOp second = createSplitPart(builder, op, secondOutputs,
+                                    splitIterationSpace, dimension,
+                                    remainingSizes.front(), secondResults,
+                                    splitPointValue);
+
+  // Fixup the linalg.index results in the second part.
+  SmallVector<Value> ivAdditions;
+  ivAdditions.resize(splitIterationSpace.size());
+  ivAdditions[dimension] = splitPointValue;
+  linalg::addTileLoopIvsToIndexOpResults(builder, second, ivAdditions);
+
+  // Replace the original op with the results of the two newly created ops.
+  rewriter.replaceOp(op, secondResults);
+  return std::make_pair(first, second);
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index f3ed408cddfba..a7524b7ebfc80 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -182,32 +182,14 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
         makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
                         sizeBounds, /*omitPartialTileCheck=*/false);
 
-    // TODO: use an interface/adaptor to avoid leaking position in
-    // `tiledOperands`.
-    SmallVector<Type, 4> resultTensorTypes;
-    for (OpOperand *opOperand : op.getOutputTensorOperands())
-      resultTensorTypes.push_back(
-          tiledOperands[opOperand->getOperandNumber()].getType());
-
+    SmallVector<Type> resultTensorTypes =
+        getTensorOutputTypes(op, tiledOperands);
     res = op.clone(b, loc, resultTensorTypes, tiledOperands);
-
-    // Insert a insert_slice for each output tensor.
-    unsigned resultIdx = 0;
-    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
-      // TODO: use an interface/adaptor to avoid leaking position in
-      // `tiledOperands`.
-      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
-      // TODO: Propagate RewriterBase everywhere.
-      IRRewriter rewriter(b);
-      if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
-        tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
-                                                      res->getResult(resultIdx),
-                                                      sliceOp.getSource()));
-      } else {
-        tensorResults.push_back(res->getResult(resultIdx));
-      }
-      ++resultIdx;
-    }
+    // Insert the tiled results back into the tensors they were sliced from.
+    // Use `b`, like the clone above, so that the insert_slice ops are created
+    // at the same insertion point as the clone (as the replaced code did).
+    tensorResults =
+        insertSlicesBack(b, loc, op, tiledOperands, res->getResults());
     return scf::ValueVector(tensorResults.begin(), tensorResults.end());
   };
   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index a2ffc691a6f0e..4b68c5cd96db9 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -913,6 +913,21 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
   return sliceOp->getResult(0);
 }
 
+Value createSlice(OpBuilder &builder, Location loc, Value value,
+                  ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+                  ArrayRef<OpFoldResult> strides) {
+  // Memrefs are sliced with memref.subview; ranked tensors with a fresh
+  // tensor.extract_slice that is intentionally not composed with any
+  // extract_slice producing `value`.
+  Type valueType = value.getType();
+  if (valueType.isa<MemRefType>())
+    return builder.create<memref::SubViewOp>(loc, value, offsets, sizes,
+                                             strides);
+  assert(valueType.isa<RankedTensorType>() && "expected a ranked tensor type");
+  return builder.create<tensor::ExtractSliceOp>(loc, value, offsets, sizes,
+                                                strides);
+}
+
 SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
                                       ValueRange ivs, ValueRange tileSizes) {
   SmallVector<Value> offsets;
@@ -943,6 +958,41 @@ SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
   return sizes;
 }
 
+SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
+  // Collects the types of the output tensor operands, looked up in `operands`
+  // by operand number. TODO: use an interface/adaptor instead of positions.
+  SmallVector<Type> types;
+  for (OpOperand *outputOperand : op.getOutputTensorOperands())
+    types.push_back(operands[outputOperand->getOperandNumber()].getType());
+  return types;
+}
+
+SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
+                                    LinalgOp op, ValueRange operands,
+                                    ValueRange results) {
+  // Writes `results[i]` (produced for the i-th output tensor operand) back
+  // into the source of that operand's tensor.extract_slice, if it has one;
+  // otherwise forwards it. TODO: avoid indexing `operands` by position.
+  SmallVector<Value> tensorResults;
+  tensorResults.reserve(results.size());
+  for (const auto &en : llvm::enumerate(op.getOutputTensorOperands())) {
+    Value result = results[en.index()];
+    Value outputTensor = operands[en.value()->getOperandNumber()];
+    auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>();
+    if (!sliceOp) {
+      tensorResults.push_back(result);
+      continue;
+    }
+    Value inserted = builder.create<tensor::InsertSliceOp>(
+        loc, sliceOp.source().getType(), result, sliceOp.source(),
+        sliceOp.offsets(), sliceOp.sizes(), sliceOp.strides(),
+        sliceOp.static_offsets(), sliceOp.static_sizes(),
+        sliceOp.static_strides());
+    tensorResults.push_back(inserted);
+  }
+  return tensorResults;
+}
+
 SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
                                       LinalgOp linalgOp,
                                       ArrayRef<Value> valuesToTile,

diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index e5a2a473150cc..beef9240d8e35 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -15,6 +15,12 @@
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
 
 
+def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr:
+  if isinstance(value, int):
+    return IntegerAttr.get(IntegerType.get_signless(64), value)
+  return value
+
+
 def _get_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
   """Creates an array attribute from its operand."""
@@ -41,13 +47,7 @@ def _get_int_array_attr(
   if isinstance(values, ArrayAttr):
     return values
 
-  attributes = []
-  for value in values:
-    if isinstance(value, IntegerAttr):
-      attributes.append(value)
-    else:
-      attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
-  return ArrayAttr.get(attributes)
+  return ArrayAttr.get([_get_int64_attr(v) for v in values])
 
 
 def _get_int_int_array_attr(
@@ -152,6 +152,39 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
         pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
+class SplitOp:
+  """Specialization for SplitOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               dimension: Union[int, Attribute],
+               split_point: Union[int, Operation, Value, Attribute],
+               *,
+               loc=None,
+               ip=None):
+    dimension = _get_int64_attr(dimension)
+    if isinstance(split_point, int):
+      split_point = _get_int64_attr(split_point)
+
+    if isinstance(split_point, Attribute):
+      static_split_point = split_point
+      dynamic_split_point = None
+    else:
+      static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
+      dynamic_split_point = _get_op_result_or_value(split_point)
+
+    pdl_operation_type = pdl.OperationType.get()
+    super().__init__(
+        pdl_operation_type,
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        dimension=dimension,
+        static_split_point=static_split_point,
+        dynamic_split_point=dynamic_split_point,
+        loc=loc,
+        ip=ip)
+
+
 class TileOp:
   """Specialization for TileOp class."""
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
new file mode 100644
index 0000000000000..2eef84c82b4dd
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -0,0 +1,366 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  // Split each matched linalg.generic along dimension 0 after index 42.
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
+  }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// CHECK: #[[$ADD_42_MAP:.+]] = affine_map<(d0) -> (d0 + 42)>
+// CHECK: #[[$ADD_10_MAP:.+]] = affine_map<(d0) -> (d0 + 10)>
+
+// CHECK-LABEL: @one_d_static
+// CHECK-SAME:  %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
+func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
+  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
+  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_LOW]]
+  // CHECK:   outs(%[[OUT_SLICE_LOW]]
+  // CHECK:   linalg.index 0
+  // CHECK:   func.call @elem
+  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [42] [1]
+  // Second part: iterations [42, 100), with linalg.index offset by 42.
+  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
+  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
+  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_HIGH]]
+  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
+  // CHECK:   %[[IDX:.+]] = linalg.index 0
+  // CHECK:   affine.apply #[[$ADD_42_MAP]](%[[IDX]])
+  // CHECK:   func.call @elem
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][42] [58] [1]
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    %i = linalg.index 0 : index
+    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<100xf32>
+
+  // CHECK: return %[[RES]]
+  return %0 : tensor<100xf32>
+}
+
+// CHECK-LABEL: @one_d_static_overflow
+// CHECK-SAME:  %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
+func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
+  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
+  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_LOW]]
+  // CHECK:   outs(%[[OUT_SLICE_LOW]]
+  // CHECK:   linalg.index 0
+  // CHECK:   func.call @elem
+  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1]
+  // Second part: split point 42 is clamped to the extent 10, so slices are empty.
+  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
+  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
+  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_HIGH]]
+  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
+  // CHECK:   %[[IDX:.+]] = linalg.index 0
+  // CHECK:   affine.apply #[[$ADD_10_MAP]](%[[IDX]])
+  // CHECK:   func.call @elem
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][10] [0] [1]
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    %i = linalg.index 0 : index
+    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  // Split each matched generic at the index produced by the matched func.call.
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1 = transform.pdl_match @func_call in %arg1
+    transform.structured.split %0 after %1 { dimension = 0 }
+  }
+}
+
+func.func private @get_size() -> index
+
+// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<(d0, d1) -> (d0, 100)>
+// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>
+
+// CHECK-LABEL: @dynamic
+func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // CHECK: %[[SPLIT:.+]] = call @get_size
+  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]](%[[SPLIT]]
+  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_LOW]]
+  // CHECK:   outs(%[[OUT_SLICE_LOW]]
+  // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
+  // Second part: from the clamped split point to the end (size 100 - split).
+  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
+  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
+  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
+  // CHECK:   ins(%[[IN_SLICE_HIGH]]
+  // CHECK:   outs(%[[OUT_SLICE_HIGH]]
+  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1]
+  %0 = func.call @get_size() : () -> index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%3: f32, %4: f32):
+    linalg.yield %3 : f32
+  } -> tensor<100xf32>
+  return %1 : tensor<100xf32>
+}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  // Split dimension 0 after 4, then split the second part (%1#1) along dimension 1 after 16.
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1:2 = transform.structured.split %0 after 4 { dimension = 0}
+    %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 }
+  }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// CHECK-LABEL: @two_d
+func.func @two_d(%arg0: tensor<10x34xf32>,
+                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
+  // Check the overall structure: split along the dimension 0, and then split
+  // the second half only along the dimension 1.
+  // CHECK:      %[[IN_1:.+]] = tensor.extract_slice %[[IN:.+]][0, 0]
+  // CHECK:      %[[OUT_1:.+]] = tensor.extract_slice %[[OUT:.+]][0, 0]
+  // CHECK:      %[[RES_1:.+]] = linalg.generic
+  // CHECK-SAME:   ins(%[[IN_1]] : tensor<4x34xf32>)
+  // CHECK-SAME:   outs(%[[OUT_1]] : tensor<4x34xf32>)
+  // CHECK:      %[[PARTIAL_1:.+]] = tensor.insert_slice %[[RES_1]] into %[[OUT]]
+  // Second part: rows [4, 10), split again along dimension 1 after 16.
+  // CHECK:      %[[IN_2:.+]] = tensor.extract_slice %[[IN]]
+  // CHECK:      %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]]
+  // CHECK:      %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]]
+  // CHECK:      %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]]
+  // CHECK:      %[[RES_21:.+]] = linalg.generic
+  // CHECK-SAME:   ins(%[[IN_21]] : tensor<6x16xf32>)
+  // CHECK-SAME:   outs(%[[OUT_21]] : tensor<6x16xf32>)
+  // CHECK:      %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]]
+  // Remaining columns [16, 34) of the second part.
+  // CHECK:      %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]]
+  // CHECK:      %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]]
+  // CHECK:      %[[RES_22:.+]] = linalg.generic
+  // CHECK-SAME:   ins(%[[IN_22]] : tensor<6x18xf32>)
+  // CHECK-SAME:   outs(%[[OUT_22]] : tensor<6x18xf32>)
+  // CHECK:      %[[PARTIAL_22:.+]] = tensor.insert_slice %[[RES_22]] into %[[PARTIAL_21]]
+  // CHECK:      %[[PARTIAL_2:.+]] = tensor.insert_slice %[[PARTIAL_22]] into %[[PARTIAL_1]]
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%arg0: tensor<10x34xf32>)
+  outs(%arg1: tensor<10x34xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    %i = linalg.index 0 : index
+    %j = linalg.index 1 : index
+    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10x34xf32>
+  return %0 : tensor<10x34xf32>
+}
+
+// -----
+// A static_split_point of -1 with no dynamic operand provides no split point at all.
+transform.sequence {
+^bb1(%arg1: !pdl.operation):
+  // expected-error @below {{expects either a dynamic or a static split point to be provided}}
+  %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -1 } : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+}
+
+// -----
+// The dynamic split point must come from a single-result index-typed op; @get_size returns i64.
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1 = transform.pdl_match @func_call in %arg1
+    // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
+    transform.structured.split %0 after %1 { dimension = 0 }
+  }
+}
+
+func.func private @get_size() -> i64
+
+func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-note @below {{dynamic split point}}
+  %0 = func.call @get_size() : () -> i64
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%3: f32, %4: f32):
+    linalg.yield %3 : f32
+  } -> tensor<100xf32>
+  return %1 : tensor<100xf32>
+}
+
+// -----
+// @func_call matches nothing here, while the target handle maps to one generic op.
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_call : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    %1 = transform.pdl_match @func_call in %arg1
+    // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
+    transform.structured.split %0 after %1 { dimension = 0 }
+  }
+}
+
+func.func private @get_size() -> i64
+
+func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%3: f32, %4: f32):
+    linalg.yield %3 : f32
+  } -> tensor<100xf32>
+  return %1 : tensor<100xf32>
+}
+
+// -----
+// Targeting a non-structured op (func.return) is rejected.
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @func_return : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @func_return in %arg1
+    // expected-error @below {{only applies to structured ops}}
+    transform.structured.split %0 after 16 { dimension = 1 }
+  }
+}
+
+func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-note @below {{target op}}
+  return %arg0 : tensor<100xf32>
+}
+
+// -----
+// The target generic op is 1-D, so dimension 1 does not exist.
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @linalg_generic : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.pdl_match @linalg_generic in %arg1
+    // expected-error @below {{dimension 1 does not exist in target op}}
+    transform.structured.split %0 after 16 { dimension = 1 }
+  }
+}
+
+func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    linalg.yield %0 : f32
+  } -> tensor<100xf32>
+  return %0 : tensor<100xf32>
+}
+

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index a34b03fb9d0bc..f7838f7e2adbc 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -84,6 +84,19 @@ def testScalarize():
   # CHECK: transform.structured.scalarize
 
 
+ at run
+def testSplit():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
+    structured.SplitOp(
+        split.results[0], dimension=3, split_point=split.results[1])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testSplit
+  # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
+  # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
+
+
 @run
 def testTileCompact():
   sequence = transform.SequenceOp()


        


More information about the Mlir-commits mailing list