[flang-commits] [flang] [mlir] [mlir] Require folders to produce Values of same type (PR #75887)

Matthias Springer via flang-commits flang-commits at lists.llvm.org
Mon Dec 18 22:44:31 PST 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/75887

>From ef7badf671bb18407482d4a297ae0234d5437849 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 19 Dec 2023 12:17:51 +0900
Subject: [PATCH] [mlir] Require folders to produce Values of same type

This commit adds extra assertions to `OperationFolder` and `OpBuilder` to ensure that the types of the folded SSA values match the result types of the op. Previously, there were checks that silently discarded the folded results if the types did not match. This commit makes these checks stricter by turning them into assertions.

Discarding folded results with the wrong type can hide bugs in op folders. Two such bugs became apparent and are fixed with this change.

Note: The existing type checks were introduced in https://reviews.llvm.org/D95991.
---
 flang/lib/Optimizer/Dialect/FIROps.cpp      |  6 ++++--
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp      |  4 +++-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp    |  5 +----
 mlir/lib/IR/Builders.cpp                    |  5 ++---
 mlir/lib/Transforms/Utils/FoldUtils.cpp     |  6 ++----
 mlir/test/Transforms/test-canonicalize.mlir | 13 -------------
 mlir/test/Transforms/test-legalizer.mlir    | 10 ----------
 mlir/test/lib/Dialect/Test/TestDialect.cpp  |  4 ----
 mlir/test/lib/Dialect/Test/TestOps.td       |  7 -------
 9 files changed, 12 insertions(+), 48 deletions(-)

diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index ab1e9c0ec7adc8..6d62e470706e53 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -625,11 +625,13 @@ void fir::BoxAddrOp::build(mlir::OpBuilder &builder,
 mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) {
   if (auto *v = getVal().getDefiningOp()) {
     if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) {
-      if (!box.getSlice()) // Fold only if not sliced
+      // Fold only if not sliced
+      if (!box.getSlice() && box.getMemref().getType() == getType())
         return box.getMemref();
     }
     if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v))
-      return box.getMemref();
+      if (box.getMemref().getType() == getType())
+        return box.getMemref();
   }
   return {};
 }
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 56d5e0fed76185..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
       setOperand(src);
       return getResult();
     }
+
     // trunci(zexti(a)) -> a
     // trunci(sexti(a)) -> a
-    return src;
+    if (srcType == dstType)
+      return src;
   }
 
   // trunci(trunci(a)) -> trunci(a))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..ac9485326a32ed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1600,11 +1600,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                        : 0;
   };
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0)
-    return source;
 
+  unsigned broadcastSrcRank = getRank(source.getType());
   unsigned extractResultRank = getRank(extractOp.getType());
   if (extractResultRank >= broadcastSrcRank)
     return Value();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2cabfcd24d3559..c28cbe109c3ffd 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -491,9 +491,8 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
     // Normal values get pushed back directly.
     if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
-      if (value.getType() != expectedType)
-        return cleanupFailure();
-
+      assert(value.getType() == expectedType &&
+             "folder produced value of incorrect type");
       results.push_back(value);
       continue;
     }
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 90ee5ba51de3ad..34b7117a035748 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -247,10 +247,8 @@ OperationFolder::processFoldResults(Operation *op,
 
     // Check if the result was an SSA value.
     if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
-      if (repl.getType() != op->getResult(i).getType()) {
-        results.clear();
-        return failure();
-      }
+      assert(repl.getType() == op->getResult(i).getType() &&
+             "folder produced value of incorrect type");
       results.emplace_back(repl);
       continue;
     }
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index bc463fefe65342..4f0095ed7e8cf4 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -70,19 +70,6 @@ func.func @test_commutative_multi_cst(%arg0: i32, %arg1: i32) -> (i32, i32) {
   return %y, %z: i32, i32
 }
 
-// CHECK-LABEL: func @typemismatch
-
-func.func @typemismatch() -> i32 {
-  %c42 = arith.constant 42.0 : f32
-
-  // The "passthrough_fold" folder will naively return its operand, but we don't
-  // want to fold here because of the type mismatch.
-
-  // CHECK: "test.passthrough_fold"
-  %0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
-  return %0 : i32
-}
-
 // CHECK-LABEL: test_dialect_canonicalizer
 func.func @test_dialect_canonicalizer() -> (i32) {
   %0 = "test.dialect_canonicalizable"() : () -> (i32)
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 6897b6f95f0d05..d8cf6e4719cede 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -310,16 +310,6 @@ builtin.module {
 
 // -----
 
-// The "passthrough_fold" folder will naively return its operand, but we don't
-// want to fold here because of the type mismatch.
-func.func @typemismatch(%arg: f32) -> i32 {
-  // expected-remark at +1 {{op 'test.passthrough_fold' is not legalizable}}
-  %0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
-  "test.return"(%0) : (i32) -> ()
-}
-
-// -----
-
 // expected-remark @below {{applyPartialConversion failed}}
 module {
   func.func private @callee(%0 : f32) -> f32
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 21400a60e65321..a1b30705f16a98 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -542,10 +542,6 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) {
-  return getOperand();
-}
-
 OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
   int64_t sum = 0;
   if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 96f66c2ca06ecf..70ccc71883e3c1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1363,13 +1363,6 @@ def TestOpFoldWithFoldAdaptor
   let hasFolder = 1;
 }
 
-// An op that always fold itself.
-def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
-  let arguments = (ins AnyType:$op);
-  let results = (outs AnyType);
-  let hasFolder = 1;
-}
-
 def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
   let arguments = (ins);
   let results = (outs I32);



More information about the flang-commits mailing list