[Mlir-commits] [mlir] a1d5bdf - Make the folder more robust against op fold() methods that generate a type mismatch

Mehdi Amini llvmlistbot at llvm.org
Wed Feb 3 17:59:09 PST 2021


Author: Mehdi Amini
Date: 2021-02-04T01:58:56Z
New Revision: a1d5bdf8192fccf5dddeb3c18a187e9ffe2c2dbd

URL: https://github.com/llvm/llvm-project/commit/a1d5bdf8192fccf5dddeb3c18a187e9ffe2c2dbd
DIFF: https://github.com/llvm/llvm-project/commit/a1d5bdf8192fccf5dddeb3c18a187e9ffe2c2dbd.diff

LOG: Make the folder more robust against op fold() methods that generate a type mismatch

We could extend this with an interface to allow dialects to perform a type
conversion, but that would make the folder create operations, which it does
not do at the moment and which is not necessarily always desirable.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D95991

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/test/Transforms/test-canonicalize.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 59af919b9335..52eee2cb5d2d 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -221,6 +221,8 @@ LogicalResult OperationFolder::tryToFold(
 
     // Check if the result was an SSA value.
     if (auto repl = foldResults[i].dyn_cast<Value>()) {
+      if (repl.getType() != op->getResult(i).getType())
+        return failure();
       results.emplace_back(repl);
       continue;
     }

diff  --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 0cd308d378c1..cc6af03a7818 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
 
 // CHECK-LABEL: func @remove_op_with_inner_ops_pattern
 func @remove_op_with_inner_ops_pattern() {
   // CHECK-NEXT: return
   "test.op_with_region_pattern"() ({
-    "foo.op_with_region_terminator"() : () -> ()
+    "test.op_with_region_terminator"() : () -> ()
   }) : () -> ()
   return
 }
@@ -13,7 +13,7 @@ func @remove_op_with_inner_ops_pattern() {
 func @remove_op_with_inner_ops_fold_no_side_effect() {
   // CHECK-NEXT: return
   "test.op_with_region_fold_no_side_effect"() ({
-    "foo.op_with_region_terminator"() : () -> ()
+    "test.op_with_region_terminator"() : () -> ()
   }) : () -> ()
   return
 }
@@ -23,7 +23,7 @@ func @remove_op_with_inner_ops_fold_no_side_effect() {
 func @remove_op_with_inner_ops_fold(%arg0 : i32) -> (i32) {
   // CHECK-NEXT: return %[[ARG_0]]
   %0 = "test.op_with_region_fold"(%arg0) ({
-    "foo.op_with_region_terminator"() : () -> ()
+    "test.op_with_region_terminator"() : () -> ()
   }) : (i32) -> (i32)
   return %0 : i32
 }
@@ -51,3 +51,14 @@ func @test_commutative_multi(%arg0: i32, %arg1: i32) -> (i32, i32) {
   // CHECK-NEXT: return %[[O0]], %[[O1]]
   return %y, %z: i32, i32
 }
+
+func @typemismatch() -> i32 {
+  %c42 = constant 42.0 : f32
+
+  // The "passthrough_fold" folder will naively return its operand, but we don't
+  // want to fold here because of the type mismatch.
+
+  // CHECK: "test.passthrough_fold"
+  %0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
+  return %0 : i32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a4139a5bf888..b13f4f44f1bd 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -628,6 +628,10 @@ OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
+  return getOperand();
+}
+
 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
     MLIRContext *, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ec836f2385b6..a36983864556 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -900,6 +900,13 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
   let hasFolder = 1;
 }
 
+// An op that always folds to its operand, even when the types mismatch.
+def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
+  let arguments = (ins AnyType:$op);
+  let results = (outs AnyType);
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Symbol Binding)
 


        


More information about the Mlir-commits mailing list