[Mlir-commits] [mlir] 1253c40 - [mlir] Add FuncOp::eraseResults
Sean Silva
llvmlistbot at llvm.org
Fri Oct 23 11:04:04 PDT 2020
Author: Sean Silva
Date: 2020-10-23T11:03:42-07:00
New Revision: 1253c40727d2fae9398fc63c86de75db88fb5124
URL: https://github.com/llvm/llvm-project/commit/1253c40727d2fae9398fc63c86de75db88fb5124
DIFF: https://github.com/llvm/llvm-project/commit/1253c40727d2fae9398fc63c86de75db88fb5124.diff
LOG: [mlir] Add FuncOp::eraseResults
I just found I needed this in an upcoming patch, and it seems generally
useful to have.
Differential Revision: https://reviews.llvm.org/D90000
Added:
mlir/test/IR/test-func-erase-result.mlir
Modified:
mlir/include/mlir/IR/Function.h
mlir/lib/IR/Function.cpp
mlir/test/lib/IR/TestFunc.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 60ae8497d8d7..5c57b754828c 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -65,6 +65,12 @@ class FuncOp
/// `argIndices` is allowed to have duplicates and can be in any order.
void eraseArguments(ArrayRef<unsigned> argIndices);
+ /// Erase a single result at `resultIndex`.
+ void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
+ /// Erases the results listed in `resultIndices`.
+ /// `resultIndices` is allowed to have duplicates and can be in any order.
+ /// Only the function type and the per-result attribute dictionaries are
+ /// rewritten; return operations in the body and call sites are NOT
+ /// modified, so the caller is responsible for keeping them consistent.
+ void eraseResults(ArrayRef<unsigned> resultIndices);
+
/// Create a deep copy of this function and all of its blocks, remapping
/// any operands that use values outside of the function using the map that is
/// provided (leaving them alone if no entry is present). If the mapper
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index 09c9441bb110..fb525d86912d 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -132,6 +132,31 @@ void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
entry.eraseArgument(originalNumArgs - i - 1);
}
+void FuncOp::eraseResults(ArrayRef<unsigned> resultIndices) {
+  auto oldType = getType();
+  unsigned originalNumResults = oldType.getNumResults();
+
+  // Collapse `resultIndices` into a bit set so that duplicate indices and
+  // arbitrary ordering are handled uniformly.
+  llvm::BitVector eraseIndices(originalNumResults);
+  for (unsigned index : resultIndices) {
+    assert(index < originalNumResults && "result index out of range");
+    eraseIndices.set(index);
+  }
+
+  // There are 2 things that need to be updated:
+  // - Function type.
+  // - Result attrs.
+  // Surviving results keep their relative order and their attribute
+  // dictionaries. Return ops in the body are intentionally left untouched.
+  SmallVector<Type, 4> newResultTypes;
+  SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
+  for (unsigned i = 0; i < originalNumResults; ++i) {
+    if (eraseIndices.test(i))
+      continue;
+    newResultTypes.push_back(oldType.getResult(i));
+    newResultAttrs.emplace_back(getResultAttrDict(i));
+  }
+  setType(FunctionType::get(oldType.getInputs(), newResultTypes, getContext()));
+  setAllResultAttrs(newResultAttrs);
+}
+
/// Clone the internal blocks from this function into dest and all attributes
/// from this function to dest.
void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
diff --git a/mlir/test/IR/test-func-erase-result.mlir b/mlir/test/IR/test-func-erase-result.mlir
new file mode 100644
index 000000000000..bdd08475ef12
--- /dev/null
+++ b/mlir/test/IR/test-func-erase-result.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -test-func-erase-result -split-input-file | FileCheck %s
+// Erasing the only result leaves a signature with no results.
+// CHECK: func @f(){{$}}
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (f32 {test.erase_this_result})
+
+// -----
+// Erase the first of two results; the survivor keeps its attribute.
+// CHECK: func @f() -> (f32 {test.A})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+ f32 {test.erase_this_result},
+ f32 {test.A}
+)
+
+// -----
+// Erase the last of two results.
+// CHECK: func @f() -> (f32 {test.A})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+ f32 {test.A},
+ f32 {test.erase_this_result}
+)
+
+// -----
+// Erase a middle result; neighbours and their attributes stay in order.
+// CHECK: func @f() -> (f32 {test.A}, f32 {test.B})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+ f32 {test.A},
+ f32 {test.erase_this_result},
+ f32 {test.B}
+)
+
+// -----
+// Erase two adjacent results.
+// CHECK: func @f() -> (f32 {test.A}, f32 {test.B})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+ f32 {test.A},
+ f32 {test.erase_this_result},
+ f32 {test.erase_this_result},
+ f32 {test.B}
+)
+
+// -----
+// Erase two non-adjacent results.
+// CHECK: func @f() -> (f32 {test.A}, f32 {test.B}, f32 {test.C})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+ f32 {test.A},
+ f32 {test.erase_this_result},
+ f32 {test.B},
+ f32 {test.erase_this_result},
+ f32 {test.C}
+)
+
+// -----
+// Erased results interleaved with results that carry no attributes.
+// CHECK: func @f() -> (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>)
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+ tensor<1xf32>,
+ f32 {test.erase_this_result},
+ tensor<2xf32>,
+ f32 {test.erase_this_result},
+ tensor<3xf32>
+)
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index 637864e049fd..f6b8294258f9 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -36,6 +36,30 @@ struct TestFuncEraseArg
}
};
+/// This is a test pass for verifying FuncOp's eraseResults method.
+///
+/// Every function result carrying the `test.erase_this_result` attribute is
+/// erased. The index list is deliberately built with duplicates and in
+/// reverse order to exercise both documented relaxations of eraseResults.
+struct TestFuncEraseResult
+    : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
+  void runOnOperation() override {
+    auto module = getOperation();
+
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      SmallVector<unsigned, 4> indicesToErase;
+      // Use unsigned indices to match getResultAttr/eraseResults and the
+      // element type of `indicesToErase`.
+      for (unsigned resultIndex :
+           llvm::seq<unsigned>(0, func.getNumResults())) {
+        if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
+          // Push back twice to test that duplicate indices are handled
+          // correctly.
+          indicesToErase.push_back(resultIndex);
+          indicesToErase.push_back(resultIndex);
+        }
+      }
+      // Reverse the order to test that unsorted index lists are handled
+      // correctly.
+      std::reverse(indicesToErase.begin(), indicesToErase.end());
+      func.eraseResults(indicesToErase);
+    }
+  }
+};
+
/// This is a test pass for verifying FuncOp's setType method.
struct TestFuncSetType
: public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
@@ -55,10 +79,13 @@ struct TestFuncSetType
namespace mlir {
void registerTestFunc() {
- PassRegistration<TestFuncEraseArg> pass("test-func-erase-arg",
- "Test erasing func args.");
+ // NOTE(review): the registrations below are unnamed temporaries, which
+ // presumes PassRegistration registers the pass in its constructor and
+ // need not be kept alive afterwards -- confirm against PassRegistry.h.
+ PassRegistration<TestFuncEraseArg>("test-func-erase-arg",
+ "Test erasing func args.");
+
+ PassRegistration<TestFuncEraseResult>("test-func-erase-result",
+ "Test erasing func results.");
- PassRegistration<TestFuncSetType> pass2("test-func-set-type",
- "Test FuncOp::setType.");
+ PassRegistration<TestFuncSetType>("test-func-set-type",
+ "Test FuncOp::setType.");
}
} // namespace mlir
More information about the Mlir-commits mailing list