[Mlir-commits] [mlir] f10302e - [mlir] Require folders to produce Values of same type (#75887)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 19 21:39:27 PST 2023


Author: Matthias Springer
Date: 2023-12-20T14:39:22+09:00
New Revision: f10302e3fa468a22a43e7d6bd6ec75919c60d72d

URL: https://github.com/llvm/llvm-project/commit/f10302e3fa468a22a43e7d6bd6ec75919c60d72d
DIFF: https://github.com/llvm/llvm-project/commit/f10302e3fa468a22a43e7d6bd6ec75919c60d72d.diff

LOG: [mlir] Require folders to produce Values of same type (#75887)

This commit adds extra assertions to `OperationFolder` and `OpBuilder`
to ensure that the types of the folded SSA values match the result
types of the op. There used to be checks that discard the folded results
if the types do not match. This commit makes these checks stricter and
turns them into assertions.

Discarding folded results with the wrong type (without failing
explicitly) can hide bugs in op folders. Two such bugs became apparent
in MLIR (and some more in downstream projects) and are fixed with this
change.

Note: The existing type checks were introduced in
https://reviews.llvm.org/D95991.

Migration guide: If you see failing assertions (`folder produced value
of incorrect type`; make sure to run with assertions enabled!), run with
`-debug` or dump the operation right before the failing assertion. This
will point you to the op that has the broken folder. A common mistake is
a mismatch between static/dynamic dimensions (e.g., input has a static
dimension but folded result has a dynamic dimension).

Added: 
    

Modified: 
    flang/lib/Optimizer/Dialect/FIROps.cpp
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/test/Transforms/test-canonicalize.mlir
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index ab1e9c0ec7adc8..6d62e470706e53 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -625,11 +625,13 @@ void fir::BoxAddrOp::build(mlir::OpBuilder &builder,
 mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) {
   if (auto *v = getVal().getDefiningOp()) {
     if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) {
-      if (!box.getSlice()) // Fold only if not sliced
+      // Fold only if not sliced
+      if (!box.getSlice() && box.getMemref().getType() == getType())
         return box.getMemref();
     }
     if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v))
-      return box.getMemref();
+      if (box.getMemref().getType() == getType())
+        return box.getMemref();
   }
   return {};
 }

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 56d5e0fed76185..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
       setOperand(src);
       return getResult();
     }
+
     // trunci(zexti(a)) -> a
     // trunci(sexti(a)) -> a
-    return src;
+    if (srcType == dstType)
+      return src;
   }
 
   // trunci(trunci(a)) -> trunci(a))

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 7444f70a46e935..26c39ff3523434 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -771,6 +771,8 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
     ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
     if (!inputTy.hasRank())                                                    \
       return {};                                                               \
+    if (inputTy != getType())                                                  \
+      return {};                                                               \
     if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)          \
       return getInput();                                                       \
     return {};                                                                 \

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1d3200bf5c8217..8a23adae3c00e6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1602,9 +1602,10 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                        : 0;
   };
+
   // If splat or broadcast from a scalar, just return the source scalar.
   unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0)
+  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
     return source;
 
   unsigned extractResultRank = getRank(extractOp.getType());

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2cabfcd24d3559..d1565047658776 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -486,14 +486,11 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
   // Populate the results with the folded results.
   Dialect *dialect = op->getDialect();
-  for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
+  for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
     Type expectedType = std::get<1>(it);
 
     // Normal values get pushed back directly.
     if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
-      if (value.getType() != expectedType)
-        return cleanupFailure();
-
       results.push_back(value);
       continue;
     }

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 87be08712ea35b..a726790391a0c5 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -606,13 +606,30 @@ void Operation::setSuccessor(Block *block, unsigned index) {
   getBlockOperands()[index].set(block);
 }
 
+#ifndef NDEBUG
+/// Assert that the folded results (in case of values) have the same type as
+/// the results of the given op.
+static void checkFoldResultTypes(Operation *op,
+                                 SmallVectorImpl<OpFoldResult> &results) {
+  if (!results.empty())
+    for (auto [ofr, opResult] : llvm::zip_equal(results, op->getResults()))
+      if (auto value = ofr.dyn_cast<Value>())
+        assert(value.getType() == opResult.getType() &&
+               "folder produced value of incorrect type");
+}
+#endif // NDEBUG
+
 /// Attempt to fold this operation using the Op's registered foldHook.
 LogicalResult Operation::fold(ArrayRef<Attribute> operands,
                               SmallVectorImpl<OpFoldResult> &results) {
   // If we have a registered operation definition matching this one, use it to
   // try to constant fold the operation.
-  if (succeeded(name.foldHook(this, operands, results)))
+  if (succeeded(name.foldHook(this, operands, results))) {
+#ifndef NDEBUG
+    checkFoldResultTypes(this, results);
+#endif // NDEBUG
     return success();
+  }
 
   // Otherwise, fall back on the dialect hook to handle it.
   Dialect *dialect = getDialect();
@@ -623,7 +640,12 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
   if (!interface)
     return failure();
 
-  return interface->fold(this, operands, results);
+  LogicalResult status = interface->fold(this, operands, results);
+#ifndef NDEBUG
+  if (succeeded(status))
+    checkFoldResultTypes(this, results);
+#endif // NDEBUG
+  return status;
 }
 
 LogicalResult Operation::fold(SmallVectorImpl<OpFoldResult> &results) {

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 90ee5ba51de3ad..eb4dcb251a2280 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -247,10 +247,6 @@ OperationFolder::processFoldResults(Operation *op,
 
     // Check if the result was an SSA value.
     if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
-      if (repl.getType() != op->getResult(i).getType()) {
-        results.clear();
-        return failure();
-      }
       results.emplace_back(repl);
       continue;
     }

diff  --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index bc463fefe65342..4f0095ed7e8cf4 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -70,19 +70,6 @@ func.func @test_commutative_multi_cst(%arg0: i32, %arg1: i32) -> (i32, i32) {
   return %y, %z: i32, i32
 }
 
-// CHECK-LABEL: func @typemismatch
-
-func.func @typemismatch() -> i32 {
-  %c42 = arith.constant 42.0 : f32
-
-  // The "passthrough_fold" folder will naively return its operand, but we don't
-  // want to fold here because of the type mismatch.
-
-  // CHECK: "test.passthrough_fold"
-  %0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
-  return %0 : i32
-}
-
 // CHECK-LABEL: test_dialect_canonicalizer
 func.func @test_dialect_canonicalizer() -> (i32) {
   %0 = "test.dialect_canonicalizable"() : () -> (i32)

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 6897b6f95f0d05..d8cf6e4719cede 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -310,16 +310,6 @@ builtin.module {
 
 // -----
 
-// The "passthrough_fold" folder will naively return its operand, but we don't
-// want to fold here because of the type mismatch.
-func.func @typemismatch(%arg: f32) -> i32 {
-  // expected-remark@+1 {{op 'test.passthrough_fold' is not legalizable}}
-  %0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
-  "test.return"(%0) : (i32) -> ()
-}
-
-// -----
-
 // expected-remark @below {{applyPartialConversion failed}}
 module {
   func.func private @callee(%0 : f32) -> f32

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 21400a60e65321..a1b30705f16a98 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -542,10 +542,6 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) {
-  return getOperand();
-}
-
 OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
   int64_t sum = 0;
   if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 96f66c2ca06ecf..70ccc71883e3c1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1363,13 +1363,6 @@ def TestOpFoldWithFoldAdaptor
   let hasFolder = 1;
 }
 
-// An op that always fold itself.
-def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
-  let arguments = (ins AnyType:$op);
-  let results = (outs AnyType);
-  let hasFolder = 1;
-}
-
 def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
   let arguments = (ins);
   let results = (outs I32);


        


More information about the Mlir-commits mailing list