[Mlir-commits] [mlir] 7f312f6 - [mlir] Avoid folding in OpBuilder::tryFold when types change
River Riddle
llvmlistbot at llvm.org
Wed Nov 3 13:35:53 PDT 2021
Author: River Riddle
Date: 2021-11-03T20:35:46Z
New Revision: 7f312f6d790113f282fe336d7c501638cea392c8
URL: https://github.com/llvm/llvm-project/commit/7f312f6d790113f282fe336d7c501638cea392c8
DIFF: https://github.com/llvm/llvm-project/commit/7f312f6d790113f282fe336d7c501638cea392c8.diff
LOG: [mlir] Avoid folding in OpBuilder::tryFold when types change
This case was missed when the fold restrictions were tightened in https://reviews.llvm.org/D95991.
Differential Revision: https://reviews.llvm.org/D113138
Added:
Modified:
mlir/lib/IR/Builders.cpp
mlir/test/Transforms/test-legalizer.mlir
Removed:
################################################################################
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 775d0c40c53c5..68471a5fbf7da 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -392,9 +392,11 @@ Operation *OpBuilder::createOperation(const OperationState &state) {
/// Note: This function does not erase the operation on a successful fold.
LogicalResult OpBuilder::tryFold(Operation *op,
SmallVectorImpl<Value> &results) {
- results.reserve(op->getNumResults());
+ ResultRange opResults = op->getResults();
+
+ results.reserve(opResults.size());
auto cleanupFailure = [&] {
- results.assign(op->result_begin(), op->result_end());
+ results.assign(opResults.begin(), opResults.end());
return failure();
};
@@ -405,7 +407,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
// Check to see if any operands to the operation is constant and whether
// the operation knows how to constant fold itself.
SmallVector<Attribute, 4> constOperands(op->getNumOperands());
- for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+ for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
// Try to fold the operation.
@@ -419,9 +421,14 @@ LogicalResult OpBuilder::tryFold(Operation *op,
// Populate the results with the folded results.
Dialect *dialect = op->getDialect();
- for (auto &it : llvm::enumerate(foldResults)) {
+ for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
+ Type expectedType = std::get<1>(it);
+
// Normal values get pushed back directly.
- if (auto value = it.value().dyn_cast<Value>()) {
+ if (auto value = std::get<0>(it).dyn_cast<Value>()) {
+ if (value.getType() != expectedType)
+ return cleanupFailure();
+
results.push_back(value);
continue;
}
@@ -431,9 +438,9 @@ LogicalResult OpBuilder::tryFold(Operation *op,
return cleanupFailure();
// Ask the dialect to materialize a constant operation for this value.
- Attribute attr = it.value().get<Attribute>();
- auto *constOp = dialect->materializeConstant(
- cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
+ Attribute attr = std::get<0>(it).get<Attribute>();
+ auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
+ op->getLoc());
if (!constOp) {
// Erase any generated constants.
for (Operation *cst : generatedConstants)
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3342402740209..556e820465da5 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -307,3 +307,13 @@ builtin.module {
}
}
+
+// -----
+
+// The "passthrough_fold" folder will naively return its operand, but we don't
+// want to fold here because of the type mismatch.
+func @typemismatch(%arg: f32) -> i32 {
+ // expected-remark@+1 {{op 'test.passthrough_fold' is not legalizable}}
+ %0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
+ "test.return"(%0) : (i32) -> ()
+}
More information about the Mlir-commits mailing list