[Mlir-commits] [mlir] 8066f22 - [MLIR] Add argument insertion helpers for FunctionLike

Fabian Schuiki llvmlistbot at llvm.org
Thu Jul 1 00:22:32 PDT 2021


Author: Fabian Schuiki
Date: 2021-07-01T09:18:57+02:00
New Revision: 8066f22c4663d9ee6c763d9108c89448e5c19848

URL: https://github.com/llvm/llvm-project/commit/8066f22c4663d9ee6c763d9108c89448e5c19848
DIFF: https://github.com/llvm/llvm-project/commit/8066f22c4663d9ee6c763d9108c89448e5c19848.diff

LOG: [MLIR] Add argument insertion helpers for FunctionLike

Add helpers to facilitate adding arguments and results to operations
that implement the `FunctionLike` trait. These operations already have a
convenient argument and result *erasure* mechanism, but a corresponding
utility for insertion is missing. This introduces such a utility.

Added: 
    mlir/test/IR/test-func-insert-arg.mlir
    mlir/test/IR/test-func-insert-result.mlir

Modified: 
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/FunctionSupport.h
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/lib/IR/FunctionSupport.cpp
    mlir/test/lib/IR/TestFunc.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4edf72667bd6e..edbd1ea2ae912 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -154,6 +154,13 @@ def Builtin_Function : Builtin_Type<"Function", [
     unsigned getNumResults() const;
     Type getResult(unsigned i) const { return getResults()[i]; }
 
+    /// Returns a new function type with the specified arguments and results
+    /// inserted.
+    FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+                                       TypeRange argTypes,
+                                       ArrayRef<unsigned> resultIndices,
+                                       TypeRange resultTypes);
+
     /// Returns a new function type without the specified arguments and results.
     FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                           ArrayRef<unsigned> resultIndices);

diff  --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index f27c857a175a6..c7ee9429d583e 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -68,6 +68,19 @@ inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
   return resultDict ? resultDict.getValue() : llvm::None;
 }
 
+/// Insert the specified arguments and update the function type attribute.
+void insertFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
+                             TypeRange argTypes,
+                             ArrayRef<DictionaryAttr> argAttrs,
+                             ArrayRef<Optional<Location>> argLocs,
+                             unsigned originalNumArgs, Type newType);
+
+/// Insert the specified results and update the function type attribute.
+void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
+                           TypeRange resultTypes,
+                           ArrayRef<DictionaryAttr> resultAttrs,
+                           unsigned originalNumResults, Type newType);
+
 /// Erase the specified arguments and update the function type attribute.
 void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
                             unsigned originalNumArgs, Type newType);
@@ -208,6 +221,22 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
     return function_like_impl::getFunctionType(this->getOperation());
   }
 
+  /// Return the type of this function with the specified arguments and results
+  /// inserted. This is used to update the function's signature in the
+  /// `insertArguments` and `insertResults` methods. The index arrays must be
+  /// sorted by increasing index and pair up one-to-one with the type ranges.
+  ///
+  /// Note that the concrete class must define a method with the same name to
+  /// hide this one if the concrete class does not use FunctionType for the
+  /// function type under the hood.
+  FunctionType getTypeWithArgsAndResults(ArrayRef<unsigned> argIndices,
+                                         TypeRange argTypes,
+                                         ArrayRef<unsigned> resultIndices,
+                                         TypeRange resultTypes) {
+    return getType().getWithArgsAndResults(argIndices, argTypes, resultIndices,
+                                           resultTypes);
+  }
+
   /// Return the type of this function without the specified arguments and
   /// results. This is used to update the function's signature in the
   /// `eraseArguments` and `eraseResults` methods. The arrays of indices are
@@ -267,6 +296,48 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
     return getBody().getArgumentTypes();
   }
 
+  /// Insert one argument at `argIndex` with type `argType`, attributes
+  /// `argAttrs`, and optional location `argLoc` (see `insertArguments`).
+  void insertArgument(unsigned argIndex, Type argType, DictionaryAttr argAttrs,
+                      Optional<Location> argLoc = {}) {
+    insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc});
+  }
+
+  /// Inserts arguments with the given types, attributes, and locations before
+  /// the listed positions of the original argument list. `argIndices` must be
+  /// sorted; entries sharing an index keep their listed order. `argAttrs` and
+  /// `argLocs` must each be empty or have one entry per index.
+  void insertArguments(ArrayRef<unsigned> argIndices, TypeRange argTypes,
+                       ArrayRef<DictionaryAttr> argAttrs,
+                       ArrayRef<Optional<Location>> argLocs) {
+    unsigned originalNumArgs = getNumArguments();
+    Type newType = getTypeWithArgsAndResults(
+        argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{});
+    function_like_impl::insertFunctionArguments(
+        this->getOperation(), argIndices, argTypes, argAttrs, argLocs,
+        originalNumArgs, newType);
+  }
+
+  /// Insert one result (`resultType`, `resultAttrs`) at `resultIndex`.
+  void insertResult(unsigned resultIndex, Type resultType,
+                    DictionaryAttr resultAttrs) {
+    insertResults({resultIndex}, {resultType}, {resultAttrs});
+  }
+
+  /// Inserts results with the given types and attributes before the listed
+  /// positions of the original result list. `resultIndices` must be sorted;
+  /// entries sharing an index keep their listed order. `resultAttrs` must be
+  /// empty or have one entry per index.
+  void insertResults(ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
+                     ArrayRef<DictionaryAttr> resultAttrs) {
+    unsigned originalNumResults = getNumResults();
+    Type newType = getTypeWithArgsAndResults(/*argIndices=*/{}, /*argTypes=*/{},
+                                             resultIndices, resultTypes);
+    function_like_impl::insertFunctionResults(
+        this->getOperation(), resultIndices, resultTypes, resultAttrs,
+        originalNumResults, newType);
+  }
+
   /// Erase a single argument at `argIndex`.
   void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
 

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d5fd1eadbb69f..f350596384a90 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -172,6 +172,45 @@ inline void iterateIndicesExcept(unsigned totalIndices,
       callback(i);
 }
 
+/// Returns a new function type with the given types inserted before the listed
+/// positions of the original inputs/results; index lists must be sorted.
+FunctionType FunctionType::getWithArgsAndResults(
+    ArrayRef<unsigned> argIndices, TypeRange argTypes,
+    ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
+  assert(argIndices.size() == argTypes.size());
+  assert(resultIndices.size() == resultTypes.size());
+
+  ArrayRef<Type> newInputTypes = getInputs();
+  SmallVector<Type, 4> newInputTypesBuffer;
+  if (!argIndices.empty()) {
+    const auto *fromIt = newInputTypes.begin();
+    for (auto it : llvm::zip(argIndices, argTypes)) {
+      const auto *toIt = newInputTypes.begin() + std::get<0>(it);
+      newInputTypesBuffer.append(fromIt, toIt);
+      newInputTypesBuffer.push_back(std::get<1>(it));
+      fromIt = toIt;
+    }
+    newInputTypesBuffer.append(fromIt, newInputTypes.end());
+    newInputTypes = newInputTypesBuffer;
+  }
+
+  ArrayRef<Type> newResultTypes = getResults();
+  SmallVector<Type, 4> newResultTypesBuffer;
+  if (!resultIndices.empty()) {
+    const auto *fromIt = newResultTypes.begin();
+    for (auto it : llvm::zip(resultIndices, resultTypes)) {
+      const auto *toIt = newResultTypes.begin() + std::get<0>(it);
+      newResultTypesBuffer.append(fromIt, toIt);
+      newResultTypesBuffer.push_back(std::get<1>(it));
+      fromIt = toIt;
+    }
+    newResultTypesBuffer.append(fromIt, newResultTypes.end());
+    newResultTypes = newResultTypesBuffer;
+  }
+
+  return FunctionType::get(getContext(), newInputTypes, newResultTypes);
+}
+
 /// Returns a new function type without the specified arguments and results.
 FunctionType
 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,

diff  --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
index b8a0ebc3f4a53..4f6f76cfbcfb9 100644
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -121,6 +121,95 @@ void mlir::function_like_impl::setAllResultAttrDicts(
                         llvm::to_vector<8>(wrappedAttrs));
 }
 
+void mlir::function_like_impl::insertFunctionArguments(
+    Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
+    ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Optional<Location>> argLocs,
+    unsigned originalNumArgs, Type newType) {
+  assert(argIndices.size() == argTypes.size());
+  assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
+  assert(argIndices.size() == argLocs.size() || argLocs.empty());
+  if (argIndices.empty())
+    return;
+
+  // Three things must stay in sync, all using indices into the original list:
+  // - Function type (passed in by the caller as `newType`).
+  // - Arg attrs.
+  // - Block arguments of entry block.
+  Block &entry = op->getRegion(0).front();
+
+  // Merge the old argument attrs with the inserted ones, in index order.
+  auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
+  if (oldArgAttrs || !argAttrs.empty()) {
+    SmallVector<DictionaryAttr, 4> newArgAttrs;
+    newArgAttrs.reserve(originalNumArgs + argIndices.size());
+    unsigned oldIdx = 0;
+    auto migrate = [&](unsigned untilIdx) {
+      if (!oldArgAttrs) {
+        newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
+      } else {
+        auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
+        newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
+                           oldArgAttrRange.begin() + untilIdx);
+      }
+      oldIdx = untilIdx;
+    };
+    for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
+      migrate(argIndices[i]);
+      newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
+    }
+    migrate(originalNumArgs);
+    setAllArgAttrDicts(op, newArgAttrs);
+  }
+
+  // Set the type, then add block args; `+ i` skips the i args just inserted.
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
+    entry.insertArgument(argIndices[i] + i, argTypes[i],
+                         argLocs.empty() ? Optional<Location>{} : argLocs[i]);
+}
+
+void mlir::function_like_impl::insertFunctionResults(
+    Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
+    ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
+    Type newType) {
+  assert(resultIndices.size() == resultTypes.size());
+  assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
+  if (resultIndices.empty())
+    return;
+
+  // There are 2 things that need to be updated:
+  // - Function type (passed in by the caller as `newType`).
+  // - Result attrs (rebuilt only if old or new attrs exist).
+
+  // Merge the old result attrs with the inserted ones, in index order.
+  auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
+  if (oldResultAttrs || !resultAttrs.empty()) {
+    SmallVector<DictionaryAttr, 4> newResultAttrs;
+    newResultAttrs.reserve(originalNumResults + resultIndices.size());
+    unsigned oldIdx = 0;
+    auto migrate = [&](unsigned untilIdx) {
+      if (!oldResultAttrs) {
+        newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
+      } else {
+        auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
+        newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
+                              oldResultAttrsRange.begin() + untilIdx);
+      }
+      oldIdx = untilIdx;
+    };
+    for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
+      migrate(resultIndices[i]);
+      newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
+                                                   : resultAttrs[i]);
+    }
+    migrate(originalNumResults);
+    setAllResultAttrDicts(op, newResultAttrs);
+  }
+
+  // Update the function type.
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+}
+
 void mlir::function_like_impl::eraseFunctionArguments(
     Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
     Type newType) {

diff  --git a/mlir/test/IR/test-func-insert-arg.mlir b/mlir/test/IR/test-func-insert-arg.mlir
new file mode 100644
index 0000000000000..2de6c666d0d31
--- /dev/null
+++ b/mlir/test/IR/test-func-insert-arg.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -test-func-insert-arg -split-input-file | FileCheck %s
+
+// CHECK: func @f(%arg0: f32 {test.A})
+func @f() attributes {test.insert_args = [
+  [0, f32, {test.A}]]} {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B})
+func @f(%arg0: f32 {test.B}) attributes {test.insert_args = [
+  [0, f32, {test.A}]]} {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B})
+func @f(%arg0: f32 {test.A}) attributes {test.insert_args = [
+  [1, f32, {test.B}]]} {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B}, %arg2: f32 {test.C})
+func @f(%arg0: f32 {test.A}, %arg1: f32 {test.C}) attributes {test.insert_args = [
+  [1, f32, {test.B}]]} {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B}, %arg2: f32 {test.C})
+func @f(%arg0: f32 {test.B}) attributes {test.insert_args = [
+  [0, f32, {test.A}],
+  [1, f32, {test.C}]]} {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B}, %arg2: f32 {test.C})
+func @f(%arg0: f32 {test.C}) attributes {test.insert_args = [
+  [0, f32, {test.A}],
+  [0, f32, {test.B}]]} {
+  return
+}

diff  --git a/mlir/test/IR/test-func-insert-result.mlir b/mlir/test/IR/test-func-insert-result.mlir
new file mode 100644
index 0000000000000..129fff4b56fe9
--- /dev/null
+++ b/mlir/test/IR/test-func-insert-result.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -test-func-insert-result -split-input-file | FileCheck %s
+
+// CHECK: func private @f() -> (f32 {test.A})
+func private @f() attributes {test.insert_results = [
+  [0, f32, {test.A}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, f32 {test.B})
+func private @f() -> (f32 {test.B}) attributes {test.insert_results = [
+  [0, f32, {test.A}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, f32 {test.B})
+func private @f() -> (f32 {test.A}) attributes {test.insert_results = [
+  [1, f32, {test.B}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, f32 {test.B}, f32 {test.C})
+func private @f() -> (f32 {test.A}, f32 {test.C}) attributes {test.insert_results = [
+  [1, f32, {test.B}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, f32 {test.B}, f32 {test.C})
+func private @f() -> (f32 {test.B}) attributes {test.insert_results = [
+  [0, f32, {test.A}],
+  [1, f32, {test.C}]]}
+
+// -----
+
+// CHECK: func private @f() -> (f32 {test.A}, f32 {test.B}, f32 {test.C})
+func private @f() -> (f32 {test.C}) attributes {test.insert_results = [
+  [0, f32, {test.A}],
+  [0, f32, {test.B}]]}

diff  --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index 4f2b45628ed99..a2d0c796c66b5 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -12,6 +12,72 @@
 using namespace mlir;
 
 namespace {
+/// Test pass exercising FuncOp::insertArguments via `test.insert_args` attrs.
+struct TestFuncInsertArg
+    : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-func-insert-arg"; }
+  StringRef getDescription() const final { return "Test inserting func args."; }
+  void runOnOperation() override {
+    auto module = getOperation();
+
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_args");
+      if (!inserts || inserts.empty())
+        continue;
+      SmallVector<unsigned, 4> indicesToInsert;
+      SmallVector<Type, 4> typesToInsert;
+      SmallVector<DictionaryAttr, 4> attrsToInsert;
+      SmallVector<Optional<Location>, 4> locsToInsert;
+      for (auto insert : inserts.getAsRange<ArrayAttr>()) {
+        indicesToInsert.push_back(
+            insert[0].cast<IntegerAttr>().getValue().getZExtValue());
+        typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
+        attrsToInsert.push_back(insert.size() > 2
+                                    ? insert[2].cast<DictionaryAttr>()
+                                    : DictionaryAttr::get(&getContext()));
+        locsToInsert.push_back(
+            insert.size() > 3
+                ? Optional<Location>(insert[3].cast<LocationAttr>())
+                : Optional<Location>{});
+      }
+      func->removeAttr("test.insert_args");
+      func.insertArguments(indicesToInsert, typesToInsert, attrsToInsert,
+                           locsToInsert);
+    }
+  }
+};
+
+/// Test pass exercising FuncOp::insertResults via `test.insert_results` attrs.
+struct TestFuncInsertResult
+    : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-func-insert-result"; }
+  StringRef getDescription() const final {
+    return "Test inserting func results.";
+  }
+  void runOnOperation() override {
+    auto module = getOperation();
+
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      auto inserts = func->getAttrOfType<ArrayAttr>("test.insert_results");
+      if (!inserts || inserts.empty())
+        continue;
+      SmallVector<unsigned, 4> indicesToInsert;
+      SmallVector<Type, 4> typesToInsert;
+      SmallVector<DictionaryAttr, 4> attrsToInsert;
+      for (auto insert : inserts.getAsRange<ArrayAttr>()) {
+        indicesToInsert.push_back(
+            insert[0].cast<IntegerAttr>().getValue().getZExtValue());
+        typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
+        attrsToInsert.push_back(insert.size() > 2
+                                    ? insert[2].cast<DictionaryAttr>()
+                                    : DictionaryAttr::get(&getContext()));
+      }
+      func->removeAttr("test.insert_results");
+      func.insertResults(indicesToInsert, typesToInsert, attrsToInsert);
+    }
+  }
+};
+
 /// This is a test pass for verifying FuncOp's eraseArgument method.
 struct TestFuncEraseArg
     : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
@@ -51,18 +117,15 @@ struct TestFuncEraseResult
     for (FuncOp func : module.getOps<FuncOp>()) {
       SmallVector<unsigned, 4> indicesToErase;
       for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
-        if (func.getResultAttr(resultIndex, "test.erase_this_"
-                                            "result")) {
-          // Push back twice to test
-          // that duplicate indices
-          // are handled correctly.
+        if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
+          // Push back twice to test that duplicate indices are handled
+          // correctly.
           indicesToErase.push_back(resultIndex);
           indicesToErase.push_back(resultIndex);
         }
       }
-      // Reverse the order to test
-      // that unsorted index lists are
-      // handled correctly.
+      // Reverse the order to test that unsorted index lists are handled
+      // correctly.
       std::reverse(indicesToErase.begin(), indicesToErase.end());
       func.eraseResults(indicesToErase);
     }
@@ -90,6 +153,10 @@ struct TestFuncSetType
 
 namespace mlir {
 void registerTestFunc() {
+  PassRegistration<TestFuncInsertArg>();
+
+  PassRegistration<TestFuncInsertResult>();
+
   PassRegistration<TestFuncEraseArg>();
 
   PassRegistration<TestFuncEraseResult>();


        


More information about the Mlir-commits mailing list