[Mlir-commits] [mlir] c8fab80 - [mlir][Transform] NFC - Add custom builders for some useful transforms.

Nicolas Vasilache llvmlistbot at llvm.org
Fri Nov 4 10:04:36 PDT 2022


Author: Nicolas Vasilache
Date: 2022-11-04T10:04:28-07:00
New Revision: c8fab80d64119ffcde78f0e9a70c5babb0da0467

URL: https://github.com/llvm/llvm-project/commit/c8fab80d64119ffcde78f0e9a70c5babb0da0467
DIFF: https://github.com/llvm/llvm-project/commit/c8fab80d64119ffcde78f0e9a70c5babb0da0467.diff

LOG: [mlir][Transform] NFC - Add custom builders for some useful transforms.

Differential Revision: https://reviews.llvm.org/D137443

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index f7952db7e2a23..2583875e2d0ea 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -20,6 +20,12 @@ namespace linalg {
 class GenericOp;
 class LinalgOp;
 } // namespace linalg
+
+namespace transform {
+// Empty tag types that select between TileToForeachThreadOp builder overloads.
+struct TileSizesSpec {};
+struct NumThreadsSpec {};
+} // namespace transform
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6cb14acb1b089..347def6c9d1b5 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -112,6 +112,11 @@ def FuseIntoContainingOp :
                           [TransformMappingAlloc,
                            TransformMappingWrite]>:$fused_op);
   let assemblyFormat = "$producer_op `into` $containing_op attr-dict";
+
+  // Creates a single `pdl::OperationType` result for `fused_op`.
+  let builders = [
+    OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)>
+  ];
 }
 
 def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
@@ -226,6 +231,11 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
   // TODO: variadic results when needed.
   let results = (outs PDL_Operation:$results);
 
+  // Sets only the `ops` filter; the result is a `pdl::OperationType`.
+  let builders = [
+    OpBuilder<(ins "Value":$target, "ArrayRef<StringRef>":$opNames)>
+  ];
+
   let assemblyFormat = [{
     (`ops` `{` $ops^ `}`)?
     (`interface` `{` $interface^ `}`)?
@@ -600,6 +610,16 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
 
   let assemblyFormat = "$target attr-dict";
 
+  // Boolean flags are attached as unit attributes only when true.
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "int64_t":$splitFactor,
+                   "int64_t":$insertSplitDimension,
+                   CArg<"bool", "false">:$innerParallel,
+                   CArg<"bool", "false">:$useScalingAlgorithm,
+                   CArg<"bool", "false">:$useAlloc)>
+  ];
+
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::linalg::LinalgOp target, 
@@ -818,6 +838,33 @@ def TileToForeachThreadOp :
                    OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$tiled_op);
+
+  // The unnamed TileSizesSpec / NumThreadsSpec tag selects how the sizes are
+  // interpreted. Pass it explicitly: when omitted, the tile-size and
+  // num-threads overloads taking the same size type are ambiguous.
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$staticTileSizes,
+                   CArg<"::mlir::transform::TileSizesSpec",
+                        "::mlir::transform::TileSizesSpec()">,
+                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedTileSizes,
+                   CArg<"::mlir::transform::TileSizesSpec",
+                        "::mlir::transform::TileSizesSpec()">,
+                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$staticNumThreads,
+                   CArg<"::mlir::transform::NumThreadsSpec",
+                        "::mlir::transform::NumThreadsSpec()">,
+                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedNumThreads,
+                   CArg<"::mlir::transform::NumThreadsSpec",
+                        "::mlir::transform::NumThreadsSpec()">,
+                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+  ];
+
   let assemblyFormat = [{
     $target oilist(
         `num_threads` custom<DynamicIndexList>($num_threads,
@@ -943,6 +990,12 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   let results = (outs PDL_Operation:$transformed);
 
   let assemblyFormat = "$target attr-dict";
+
+  // `vectorizePadding` is attached as a unit attribute only when true.
+  let builders = [
+    OpBuilder<(ins "Value":$target, CArg<"bool", "false">:$vectorizePadding)>
+  ];
+
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::Operation *target, 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 4b1bb02ee757a..42f8d5cb27698 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -253,6 +253,12 @@ def SplitHandlesOp : TransformDialectOp<"split_handles",
   let arguments = (ins TransformTypeInterface:$handle,
                        I64Attr:$num_result_handles);
   let results = (outs Variadic<TransformTypeInterface>:$results);
+
+  // Creates `numResultHandles` results, each typed as `pdl::OperationType`.
+  let builders = [
+    OpBuilder<(ins "Value":$handle, "int64_t":$numResultHandles)>
+  ];
+
   let assemblyFormat = [{
     $handle `in` `[` $num_result_handles `]` 
     attr-dict `:` functional-type(operands, results)
@@ -305,6 +311,14 @@ def PrintOp : TransformDialectOp<"print",
   let arguments = (ins Optional<TransformTypeInterface>:$target,
                        OptionalAttr<StrAttr>:$name);
   let results = (outs);
+
+  // Builders without and with a target operand; an empty `name` leaves the
+  // optional `name` attribute unset.
+  let builders = [
+    OpBuilder<(ins CArg<"StringRef", "StringRef()">:$name)>,
+    OpBuilder<(ins "Value":$target, CArg<"StringRef", "StringRef()">:$name)>
+  ];
+
   let assemblyFormat = "$target attr-dict (`:` type($target)^)?";
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c8a3cb6946e3d..a35dd14483963 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -254,6 +254,16 @@ LogicalResult transform::FuseOp::verify() {
 // FuseIntoContainingOp
 //===----------------------------------------------------------------------===//
 
+/// Builds the op with operands `producerOp` and `containingOp` and a single
+/// `pdl::OperationType` result.
+void transform::FuseIntoContainingOp::build(OpBuilder &builder,
+                                            OperationState &result,
+                                            Value producerOp,
+                                            Value containingOp) {
+  result.addOperands({producerOp, containingOp});
+  result.addTypes(pdl::OperationType::get(builder.getContext()));
+}
+
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
@@ -628,6 +638,16 @@ LogicalResult transform::InterchangeOp::verify() {
 // MatchOp
 //===---------------------------------------------------------------------===//
 
+/// Builds a MatchOp on `target` whose `ops` attribute is set to `opNames`; the
+/// single result is a `pdl::OperationType`.
+void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
+                               Value target, ArrayRef<StringRef> opNames) {
+  result.addOperands(target);
+  result.addAttribute(MatchOp::getOpsAttrName(result.name),
+                      builder.getStrArrayAttr(opNames));
+  result.addTypes(pdl::OperationType::get(builder.getContext()));
+}
+
 DiagnosedSilenceableFailure
 transform::MatchOp::apply(transform::TransformResults &results,
                           transform::TransformState &state) {
@@ -1069,6 +1089,38 @@ LogicalResult SplitOp::verify() {
 // SplitReductionOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a SplitReductionOp on `target` with `splitFactor` and
+/// `insertSplitDimension` stored as i64 attributes. Each boolean flag is added
+/// as a unit attribute only when true. All four results are typed as
+/// `pdl::OperationType`.
+void transform::SplitReductionOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
+    bool useScalingAlgorithm, bool useAlloc) {
+  MLIRContext *ctx = builder.getContext();
+  result.addOperands(target);
+  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
+                      builder.getI64IntegerAttr(splitFactor));
+  result.addAttribute(
+      SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
+      builder.getI64IntegerAttr(insertSplitDimension));
+  if (innerParallel) {
+    result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
+                        builder.getUnitAttr());
+  }
+  if (useScalingAlgorithm) {
+    result.addAttribute(
+        SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
+        builder.getUnitAttr());
+  }
+  if (useAlloc) {
+    result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
+                        builder.getUnitAttr());
+  }
+  auto resultType = pdl::OperationType::get(ctx);
+  result.addTypes({resultType, resultType, resultType, resultType});
+}
+
 DiagnosedSilenceableFailure
 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
                                         SmallVectorImpl<Operation *> &results,
@@ -1277,13 +1329,85 @@ void transform::TileOp::getEffects(
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a TileToForeachThreadOp from static tile sizes by converting them to
+/// OpFoldResults and forwarding to the mixed tile-size overload.
+void transform::TileToForeachThreadOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    ArrayRef<int64_t> staticTileSizes, transform::TileSizesSpec,
+    ArrayRef<int64_t> threadDimMapping) {
+  return build(builder, result, target,
+               getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+               TileSizesSpec(), threadDimMapping);
+}
+
+/// Builds a TileToForeachThreadOp from mixed static/dynamic tile sizes. Static
+/// values populate `static_tile_sizes`, dynamic values become tile-size
+/// operands, and an empty `threadDimMapping` yields a null mapping attribute.
+void transform::TileToForeachThreadOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
+    ArrayRef<int64_t> threadDimMapping) {
+  SmallVector<int64_t> staticTileSizes;
+  SmallVector<Value> dynamicTileSizes;
+  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes,
+                             ShapedType::kDynamicSize);
+  // Delegate to the ODS-generated builder, which populates the operand
+  // segment sizes attribute required by the op's multiple variadic operands;
+  // filling the OperationState by hand risks getting that attribute wrong.
+  MLIRContext *ctx = builder.getContext();
+  auto operationType = pdl::OperationType::get(ctx);
+  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
+  ArrayAttr threadDimMappingAttr;
+  if (!threadDimMapping.empty())
+    threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping);
+  build(builder, result, TypeRange{operationType, operationType}, target,
+        /*numThreads=*/ValueRange{}, dynamicTileSizes,
+        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr,
+        threadDimMappingAttr);
+}
+
+/// Builds a TileToForeachThreadOp from a static number of threads by
+/// converting them to OpFoldResults and forwarding to the mixed overload.
+void transform::TileToForeachThreadOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    ArrayRef<int64_t> staticNumThreads, transform::NumThreadsSpec,
+    ArrayRef<int64_t> threadDimMapping) {
+  return build(builder, result, target,
+               getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
+               NumThreadsSpec(), threadDimMapping);
+}
+
+/// Builds a TileToForeachThreadOp from mixed static/dynamic thread counts.
+/// Static values populate `static_num_threads`, dynamic values become
+/// num-threads operands; an empty `threadDimMapping` yields a null attribute.
+void transform::TileToForeachThreadOp::build(
+    OpBuilder &builder, OperationState &result, Value target,
+    ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
+    ArrayRef<int64_t> threadDimMapping) {
+  SmallVector<int64_t> staticNumThreads;
+  SmallVector<Value> dynamicNumThreads;
+  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
+                             staticNumThreads, ShapedType::kDynamicSize);
+  // Delegate to the ODS-generated builder, which populates the operand
+  // segment sizes attribute required by the op's multiple variadic operands;
+  // filling the OperationState by hand risks getting that attribute wrong.
+  MLIRContext *ctx = builder.getContext();
+  auto operationType = pdl::OperationType::get(ctx);
+  auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
+  ArrayAttr threadDimMappingAttr;
+  if (!threadDimMapping.empty())
+    threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping);
+  build(builder, result, TypeRange{operationType, operationType}, target,
+        dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
+        /*staticTileSizes=*/ArrayAttr(), threadDimMappingAttr);
+}
+
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
     TransformOpInterface transformOp, ArrayRef<Operation *> targets,
     ArrayRef<OpFoldResult> mixedNumThreads,
     ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
     SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
-
   if (targets.empty())
     return DiagnosedSilenceableFailure(success());
 
@@ -1573,6 +1697,18 @@ void transform::TileToScfForOp::getEffects(
 // VectorizeOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a VectorizeOp on `target`. `vectorizePadding` is added as a unit
+/// attribute only when true; the single result is a `pdl::OperationType`.
+void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
+                                   Value target, bool vectorizePadding) {
+  result.addOperands(target);
+  if (vectorizePadding) {
+    result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
+                        builder.getUnitAttr());
+  }
+  result.addTypes(pdl::OperationType::get(builder.getContext()));
+}
+
 namespace {
 /// This is an helper only to call vectorize via a pattern inside of
 /// VectorizeOp::applyToOne.

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 5d84b7b0a6030..9b136cccbe6f1 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -314,11 +314,13 @@ transform::TransformResults::TransformResults(unsigned numSegments) {
 
+/// Associates `ops` with the result `value`. Asserts that `value` has a
+/// segment and that its results have not been set before.
 void transform::TransformResults::set(OpResult value,
                                       ArrayRef<Operation *> ops) {
-  unsigned position = value.getResultNumber();
-  assert(position < segments.size() &&
+  int64_t position = value.getResultNumber();
+  assert(position < static_cast<int64_t>(segments.size()) &&
          "setting results for a non-existent handle");
   assert(segments[position].data() == nullptr && "results already set");
-  unsigned start = operations.size();
+  int64_t start = operations.size();
   llvm::append_range(operations, ops);
   segments[position] = makeArrayRef(operations).drop_front(start);
 }

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 2be1bea91fbe9..5fe2d465ee51a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -472,6 +472,20 @@ OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
 // SplitHandlesOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a SplitHandlesOp that splits `handle` into `numResultHandles`
+/// handles, each typed as `pdl::OperationType`.
+void transform::SplitHandlesOp::build(OpBuilder &builder,
+                                      OperationState &result, Value handle,
+                                      int64_t numResultHandles) {
+  // A negative count would wrap around when used as a container size.
+  assert(numResultHandles >= 0 && "expected a non-negative number of handles");
+  result.addOperands(handle);
+  result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name),
+                      builder.getI64IntegerAttr(numResultHandles));
+  auto pdlOpType = pdl::OperationType::get(builder.getContext());
+  result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType));
+}
+
 DiagnosedSilenceableFailure
 transform::SplitHandlesOp::apply(transform::TransformResults &results,
                                  transform::TransformState &state) {
@@ -812,6 +826,25 @@ LogicalResult transform::WithPDLPatternsOp::verify() {
 // PrintOp
 //===----------------------------------------------------------------------===//
 
+/// Builds a PrintOp without a target operand. The optional `name` attribute is
+/// set only when `name` is non-empty.
+void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
+                               StringRef name) {
+  // `name` is declared as OptionalAttr<StrAttr>: attach a StringAttr, not an
+  // ArrayAttr of strings (which would fail the attribute constraint).
+  if (!name.empty()) {
+    result.addAttribute(PrintOp::getNameAttrName(result.name),
+                        builder.getStringAttr(name));
+  }
+}
+
+/// Builds a PrintOp that prints `target`, optionally labeled with `name`.
+void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
+                               Value target, StringRef name) {
+  result.addOperands({target});
+  build(builder, result, name);
+}
+
 DiagnosedSilenceableFailure
 transform::PrintOp::apply(transform::TransformResults &results,
                           transform::TransformState &state) {


        


More information about the Mlir-commits mailing list