[Mlir-commits] [mlir] 1253c40 - [mlir] Add FuncOp::eraseResults

Sean Silva llvmlistbot at llvm.org
Fri Oct 23 11:04:04 PDT 2020


Author: Sean Silva
Date: 2020-10-23T11:03:42-07:00
New Revision: 1253c40727d2fae9398fc63c86de75db88fb5124

URL: https://github.com/llvm/llvm-project/commit/1253c40727d2fae9398fc63c86de75db88fb5124
DIFF: https://github.com/llvm/llvm-project/commit/1253c40727d2fae9398fc63c86de75db88fb5124.diff

LOG: [mlir] Add FuncOp::eraseResults

I just found I needed this in an upcoming patch, and it seems generally
useful to have.

Differential Revision: https://reviews.llvm.org/D90000

Added: 
    mlir/test/IR/test-func-erase-result.mlir

Modified: 
    mlir/include/mlir/IR/Function.h
    mlir/lib/IR/Function.cpp
    mlir/test/lib/IR/TestFunc.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 60ae8497d8d7..5c57b754828c 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -65,6 +65,12 @@ class FuncOp
   /// `argIndices` is allowed to have duplicates and can be in any order.
   void eraseArguments(ArrayRef<unsigned> argIndices);
 
+  /// Erase the single result at `resultIndex` (shorthand for eraseResults).
+  void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
+  /// Erases the results listed in `resultIndices` (any order, duplicates OK).
+  /// Only the type and result attrs change; callers must update return ops.
+  void eraseResults(ArrayRef<unsigned> resultIndices);
+
   /// Create a deep copy of this function and all of its blocks, remapping
   /// any operands that use values outside of the function using the map that is
   /// provided (leaving them alone if no entry is present). If the mapper

diff  --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index 09c9441bb110..fb525d86912d 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -132,6 +132,31 @@ void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
       entry.eraseArgument(originalNumArgs - i - 1);
 }
 
+void FuncOp::eraseResults(ArrayRef<unsigned> resultIndices) {
+  auto oldType = getType();
+  int originalNumResults = oldType.getNumResults();
+  llvm::BitVector eraseIndices(originalNumResults);
+  for (auto index : resultIndices)
+    eraseIndices.set(index);
+  auto shouldEraseResult = [&](int i) { return eraseIndices.test(i); };
+
+  // The bit vector above makes duplicate/unsorted indices harmless. Only the
+  // function type and result attrs are updated; nothing here rewrites the
+  // body, so callers must update any return ops to match the new results.
+
+  // Keep the type and attribute dict of each surviving result, in order.
+  SmallVector<Type, 4> newResultTypes;
+  SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
+  for (int i = 0; i < originalNumResults; i++) {
+    if (shouldEraseResult(i))
+      continue;
+    newResultTypes.emplace_back(oldType.getResult(i));
+    newResultAttrs.emplace_back(getResultAttrDict(i));
+  }
+  setType(FunctionType::get(oldType.getInputs(), newResultTypes, getContext()));
+  setAllResultAttrs(newResultAttrs);
+}
+
 /// Clone the internal blocks from this function into dest and all attributes
 /// from this function to dest.
 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {

diff  --git a/mlir/test/IR/test-func-erase-result.mlir b/mlir/test/IR/test-func-erase-result.mlir
new file mode 100644
index 000000000000..bdd08475ef12
--- /dev/null
+++ b/mlir/test/IR/test-func-erase-result.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -test-func-erase-result -split-input-file | FileCheck %s
+// Results tagged test.erase_this_result must be erased; other attrs survive.
+// CHECK: func @f(){{$}}
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (f32 {test.erase_this_result})
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.erase_this_result},
+  f32 {test.A}
+)
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.A},
+  f32 {test.erase_this_result}
+)
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A}, f32 {test.B})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.A},
+  f32 {test.erase_this_result},
+  f32 {test.B}
+)
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A}, f32 {test.B})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.A},
+  f32 {test.erase_this_result},
+  f32 {test.erase_this_result},
+  f32 {test.B}
+)
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A}, f32 {test.B}, f32 {test.C})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.A},
+  f32 {test.erase_this_result},
+  f32 {test.B},
+  f32 {test.erase_this_result},
+  f32 {test.C}
+)
+
+// -----
+
+// CHECK: func @f() -> (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>)
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  tensor<1xf32>,
+  f32 {test.erase_this_result},
+  tensor<2xf32>,
+  f32 {test.erase_this_result},
+  tensor<3xf32>
+)

diff  --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index 637864e049fd..f6b8294258f9 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -36,6 +36,30 @@ struct TestFuncEraseArg
   }
 };
 
+/// Tests FuncOp::eraseResults by erasing results marked test.erase_this_result.
+struct TestFuncEraseResult
+    : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
+  void runOnOperation() override {
+    auto module = getOperation();
+
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      SmallVector<unsigned, 4> indicesToErase;
+      for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
+        if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
+          // Push each index twice so eraseResults must tolerate duplicate
+          // indices.
+          indicesToErase.push_back(resultIndex);
+          indicesToErase.push_back(resultIndex);
+        }
+      }
+      // Indices were collected in ascending order; reverse them so that
+      // eraseResults is exercised with a non-ascending index list.
+      std::reverse(indicesToErase.begin(), indicesToErase.end());
+      func.eraseResults(indicesToErase);
+    }
+  }
+};
+
 /// This is a test pass for verifying FuncOp's setType method.
 struct TestFuncSetType
     : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
@@ -55,10 +79,13 @@ struct TestFuncSetType
 
 namespace mlir {
 void registerTestFunc() {
-  PassRegistration<TestFuncEraseArg> pass("test-func-erase-arg",
-                                          "Test erasing func args.");
+  PassRegistration<TestFuncEraseArg>("test-func-erase-arg",
+                                     "Test erasing func args.");
+  // Driven by the RUN line of mlir/test/IR/test-func-erase-result.mlir.
+  PassRegistration<TestFuncEraseResult>("test-func-erase-result",
+                                        "Test erasing func results.");
 
-  PassRegistration<TestFuncSetType> pass2("test-func-set-type",
-                                          "Test FuncOp::setType.");
+  PassRegistration<TestFuncSetType>("test-func-set-type",
+                                    "Test FuncOp::setType.");
 }
 } // namespace mlir


        


More information about the Mlir-commits mailing list