[Mlir-commits] [mlir] 7f312f6 - [mlir] Avoid folding in OpBuilder::tryFold when types change

River Riddle llvmlistbot at llvm.org
Wed Nov 3 13:35:53 PDT 2021


Author: River Riddle
Date: 2021-11-03T20:35:46Z
New Revision: 7f312f6d790113f282fe336d7c501638cea392c8

URL: https://github.com/llvm/llvm-project/commit/7f312f6d790113f282fe336d7c501638cea392c8
DIFF: https://github.com/llvm/llvm-project/commit/7f312f6d790113f282fe336d7c501638cea392c8.diff

LOG: [mlir] Avoid folding in OpBuilder::tryFold when types change

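OpBuilder::tryFold now rejects a fold (leaving the operation unfolded) when a folded Value's type differs from the type of the corresponding operation result. This case was missed when the fold restrictions were tightened in https://reviews.llvm.org/D95991.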

Differential Revision: https://reviews.llvm.org/D113138

Added: 
    

Modified: 
    mlir/lib/IR/Builders.cpp
    mlir/test/Transforms/test-legalizer.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 775d0c40c53c5..68471a5fbf7da 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -392,9 +392,11 @@ Operation *OpBuilder::createOperation(const OperationState &state) {
 /// Note: This function does not erase the operation on a successful fold.
 LogicalResult OpBuilder::tryFold(Operation *op,
                                  SmallVectorImpl<Value> &results) {
-  results.reserve(op->getNumResults());
+  ResultRange opResults = op->getResults();
+
+  results.reserve(opResults.size());
   auto cleanupFailure = [&] {
-    results.assign(op->result_begin(), op->result_end());
+    results.assign(opResults.begin(), opResults.end());
     return failure();
   };
 
@@ -405,7 +407,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
   // Check to see if any operands to the operation is constant and whether
   // the operation knows how to constant fold itself.
   SmallVector<Attribute, 4> constOperands(op->getNumOperands());
-  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
     matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
 
   // Try to fold the operation.
@@ -419,9 +421,14 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
   // Populate the results with the folded results.
   Dialect *dialect = op->getDialect();
-  for (auto &it : llvm::enumerate(foldResults)) {
+  for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
+    Type expectedType = std::get<1>(it);
+
     // Normal values get pushed back directly.
-    if (auto value = it.value().dyn_cast<Value>()) {
+    if (auto value = std::get<0>(it).dyn_cast<Value>()) {
+      if (value.getType() != expectedType)
+        return cleanupFailure();
+
       results.push_back(value);
       continue;
     }
@@ -431,9 +438,9 @@ LogicalResult OpBuilder::tryFold(Operation *op,
       return cleanupFailure();
 
     // Ask the dialect to materialize a constant operation for this value.
-    Attribute attr = it.value().get<Attribute>();
-    auto *constOp = dialect->materializeConstant(
-        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
+    Attribute attr = std::get<0>(it).get<Attribute>();
+    auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
+                                                 op->getLoc());
     if (!constOp) {
       // Erase any generated constants.
       for (Operation *cst : generatedConstants)

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3342402740209..556e820465da5 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -307,3 +307,13 @@ builtin.module {
   }
 
 }
+
+// -----
+
+// The "passthrough_fold" folder will naively return its operand, but we don't
+// want to fold here because of the type mismatch.
+func @typemismatch(%arg: f32) -> i32 {
+  // expected-remark at +1 {{op 'test.passthrough_fold' is not legalizable}}
+  %0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
+  "test.return"(%0) : (i32) -> ()
+}


        


More information about the Mlir-commits mailing list