[Mlir-commits] [mlir] [mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass (PR #68330)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Oct 6 02:41:49 PDT 2023


================
@@ -0,0 +1,188 @@
+//===- TransformInterpreterUtils.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lightweight transform dialect interpreter utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+// Debug category for this file's LLVM_DEBUG output.
+#define DEBUG_TYPE "transform-dialect-interpreter-utils"
+// Debug stream helper: prefixes each message with "[<DEBUG_TYPE>]: ".
+// `"[" DEBUG_TYPE` relies on adjacent string-literal concatenation.
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+/// Parses the file `transformFileName` into `transformModule` and verifies the
+/// resulting module.
+///
+/// An empty `transformFileName` is not an error: `transformModule` is left
+/// untouched and success is returned (the transform IR is then expected to be
+/// found elsewhere, e.g. next to the payload).
+///
+/// Returns failure if the file cannot be opened, cannot be parsed into a
+/// `ModuleOp`, or the parsed module fails verification; an error diagnostic
+/// is emitted in each case.
+LogicalResult transform::detail::parseTransformInterpreterModule(
+    MLIRContext *context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule =
+      OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
+  // On a parse error the parser has already reported a diagnostic and yields a
+  // null module, which must not be dereferenced by the verifier call below.
+  if (!transformModule)
+    return failure();
+  return mlir::verify(*transformModule);
+}
+
+/// Returns the first library module held by the transform dialect of
+/// `context`, or a null `ModuleOp` if the dialect holds no library modules.
+/// The transform dialect is loaded into `context` if it is not loaded yet.
+ModuleOp transform::detail::getPreloadedTransformInterpreterModule(
+    MLIRContext *context) {
+  auto *transformDialect =
+      context->getOrLoadDialect<transform::TransformDialect>();
+  auto libraryModules = transformDialect->getLibraryModules();
+  if (libraryModules.empty())
+    return ModuleOp();
+  return *libraryModules.begin();
+}
+
+/// Looks for a `NamedSequenceOp` whose symbol name equals `entryPoint`.
+///
+/// `root` is searched first and then, if it is non-null, `module`. Each is
+/// walked in pre-order and the first match is returned, so a match anywhere in
+/// `root` takes precedence over one in `module`. If nothing matches, an error
+/// is emitted on `root` and null is returned.
+transform::TransformOpInterface
+transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
+                                           StringRef entryPoint) {
+  SmallVector<Operation *, 2> searchRoots{root};
+  if (module)
+    searchRoots.push_back(module);
+  for (Operation *searchRoot : searchRoots) {
+    transform::TransformOpInterface found = nullptr;
+    searchRoot->walk<WalkOrder::PreOrder>(
+        [&](transform::NamedSequenceOp namedSequenceOp) {
+          if (namedSequenceOp.getSymName() != entryPoint)
+            return WalkResult::advance();
+          found = cast<transform::TransformOpInterface>(
+              namedSequenceOp.getOperation());
+          return WalkResult::interrupt();
+        });
+    if (found)
+      return found;
+  }
+  // The in-flight diagnostic is reported when this temporary is destroyed at
+  // the end of the statement; there is no need to keep it in a local.
+  root->emitError() << "could not find a nested named sequence with name: "
+                    << entryPoint;
+  return nullptr;
+}
+
+/// Replaces symbol declarations in `block` with clones of their definitions
+/// found at the top level of `definitions`.
+///
+/// A symbol operation in `block` counts as a declaration unless it has exactly
+/// one region and that region is non-empty. A declaration is replaced only if
+/// `definitions` has a same-named symbol with a single non-empty region and
+/// both operations implement `FunctionOpInterface`; otherwise it is left as is.
+///
+/// Argument consumption annotations are reconciled per argument. If the
+/// definition has neither the "consumed" nor the "readonly" annotation, the
+/// declaration's annotation (if any) is copied onto the definition, so the ops
+/// in `definitions` may be modified. If the definition is annotated, the
+/// declaration must carry the same annotation.
+///
+/// Returns failure, with an error on the declaration, when the function types
+/// differ or the consumption annotations are incompatible.
+LogicalResult transform::detail::defineDeclaredSymbols(Block &block,
+                                                       ModuleOp definitions) {
+  MLIRContext &ctx = *definitions->getContext();
+  auto consumedName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
+  auto readOnlyName =
+      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
+
+  // The loop below only updates argument attributes of ops in `definitions`;
+  // it never adds or removes symbols there, so build the lookup table once
+  // instead of once per declaration.
+  SymbolTable symbolTable(definitions);
+
+  for (Operation &op : llvm::make_early_inc_range(block)) {
+    LLVM_DEBUG(DBGS() << op << "\n");
+    auto symbol = dyn_cast<SymbolOpInterface>(op);
+    if (!symbol)
+      continue;
+    // A single non-empty region means this symbol is already defined.
+    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
+      continue;
+
+    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
+                      << symbol.getNameAttr() << ":");
+    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
+        externalSymbol->getRegion(0).empty()) {
+      LLVM_DEBUG(llvm::dbgs() << "not found\n");
+      continue;
+    }
+
+    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
+    if (!symbolFunc || !externalSymbolFunc) {
+      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+      continue;
+    }
+
+    // Print the symbol's name; streaming the Operation pointer would only
+    // print its address.
+    LLVM_DEBUG(llvm::dbgs() << "found @" << symbol.getNameAttr().getValue()
+                            << "\n");
+    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
+      return symbolFunc.emitError()
+             << "external definition has a mismatching signature ("
+             << externalSymbolFunc.getFunctionType() << ")";
+    }
+
+    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
+      bool isExternalConsumed =
+          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isExternalReadonly =
+          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
+      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
+      // Unannotated definition: adopt the declaration's annotation, if any.
+      if (!isExternalConsumed && !isExternalReadonly) {
+        if (isConsumed)
+          externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
+        else if (isReadonly)
+          externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+        continue;
+      }
+
+      // Annotated definition: the declaration must agree with it.
+      if ((isExternalConsumed && !isConsumed) ||
+          (isExternalReadonly && !isReadonly)) {
+        return symbolFunc.emitError()
+               << "external definition has mismatching consumption annotations "
+                  "for argument #"
+               << i;
+      }
+    }
+
+    // OpBuilder(Operation *) already places the insertion point right before
+    // `op`, so the clone takes the declaration's position in `block`.
+    OpBuilder builder(&op);
+    builder.clone(*externalSymbol);
+    symbol->erase();
+  }
+
+  return success();
+}
+
+LogicalResult transform::applyTransformNamedSequence(
+    Operation *payload, ModuleOp transformModule,
+    const TransformOptions &options, StringRef entryPoint) {
+  Operation *transformRoot =
+      detail::findTransformEntryPoint(payload, transformModule, entryPoint);
+  if (!transformRoot)
+    return failure();
+
+  // `sharedTransformModule` may not be modified.
----------------
ftynse wrote:

There's no sharedTransformModule here.

https://github.com/llvm/llvm-project/pull/68330


More information about the Mlir-commits mailing list