[Mlir-commits] [mlir] [mlir] only fold llvm.mlir.constant when types match (PR #70318)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Thu Oct 26 04:33:51 PDT 2023


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/70318

From 9171bad250f67333bc68725fb3951a2e52befe6d Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Thu, 26 Oct 2023 11:27:42 +0000
Subject: [PATCH] [mlir] only fold llvm.mlir.constant when types match

The llvm.mlir.constant operation may use different types for the
attribute holding the constant value and for its result. Only
constant-fold the operation when the two types match, as a change of
type during folded-constant propagation is surprising.

To support this case, also relax the assertion in the `m_Constant`
matcher that a ConstantLike operation must always constant-fold. The
matcher could already fail when looking for a constant with a more
specific type or value, so callers made no further assumptions about it
always succeeding.

Fixes #70278.
---
 mlir/include/mlir/IR/Matchers.h               |  5 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 13 ++++-
 .../FuncToLLVM/calling-convention.mlir        |  4 +-
 mlir/test/Dialect/LLVMIR/canonicalize.mlir    |  1 +
 .../Dialect/LLVMIR/fold-cross-dialect.mlir    | 50 +++++++++++++++++++
 5 files changed, 67 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Dialect/LLVMIR/fold-cross-dialect.mlir

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c0..ba309a4fb83e1b8 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -87,9 +87,8 @@ struct constant_op_binder {
 
     // Fold the constant to an attribute.
     SmallVector<OpFoldResult, 1> foldedOp;
-    LogicalResult result = op->fold(/*operands=*/std::nullopt, foldedOp);
-    (void)result;
-    assert(succeeded(result) && "expected ConstantLike op to be foldable");
+    if (failed(op->fold(/*operands=*/std::nullopt, foldedOp)))
+      return false;
 
     if (auto attr = llvm::dyn_cast<AttrT>(foldedOp.front().get<Attribute>())) {
       if (bind_value)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 95c04098d05fc2f..49b6abe87b15ddf 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2757,7 +2757,18 @@ ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
 }
 
 // Constant op constant-folds to its value.
-OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
+OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) {
+  if (auto integerValue = llvm::dyn_cast<IntegerAttr>(getValue());
+      integerValue && integerValue.getType() == getType()) {
+    return getValue();
+  }
+  if (auto floatValue = llvm::dyn_cast<FloatAttr>(getValue());
+      floatValue && floatValue.getType() == getType()) {
+    return getValue();
+  }
+  // Mismatched attribute/result types (e.g. index vs. i64): do not fold.
+  return {};
+}
 
 //===----------------------------------------------------------------------===//
 // AtomicRMWOp
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index 1ed67708875604d..710f70401589d00 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -127,7 +127,7 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
@@ -159,7 +159,7 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
 
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 5e26fa37b681d71..f19dcc4c8fe6b0a 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -208,3 +208,4 @@ llvm.func @volatile_load(%x : !llvm.ptr) {
   %3 = llvm.load %x  atomic unordered { alignment = 1 } : !llvm.ptr -> i8
   llvm.return
 }
+
diff --git a/mlir/test/Dialect/LLVMIR/fold-cross-dialect.mlir b/mlir/test/Dialect/LLVMIR/fold-cross-dialect.mlir
new file mode 100644
index 000000000000000..2c111c26814d767
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/fold-cross-dialect.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s --test-constant-fold --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @fold_constant_mismatching_type
+llvm.func @fold_constant_mismatching_type() -> i64 {
+  // CHECK-DAG: llvm.mlir.constant(42
+  // CHECK-DAG: llvm.mlir.constant(2
+  %0 = llvm.mlir.constant(42 : index) : i64
+  %1 = llvm.mlir.constant(2 : i64) : i64
+  // Using arith.addi as there is no folder for llvm.add.
+  // Not expected to fold: %0's attribute type (index) differs from i64.
+  // CHECK: arith.addi
+  %2 = arith.addi %0, %1 : i64
+  llvm.return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @fold_constant_matching_type
+llvm.func @fold_constant_matching_type() -> i64 {
+  // CHECK-NOT: llvm.mlir.constant
+  %0 = llvm.mlir.constant(42 : i64) : i64
+  %1 = llvm.mlir.constant(2 : i64) : i64
+  // Using arith.addi as there is no folder for llvm.add.
+  // Attribute and result types match, so this is expected to fold.
+  // CHECK: %[[V:.+]] = arith.constant 44 : i64
+  // CHECK: llvm.return %[[V]]
+  %2 = arith.addi %0, %1 : i64
+  llvm.return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @fold_constant_vector_mismatching_type
+func.func @fold_constant_vector_mismatching_type() {
+  // CHECK: llvm.mlir.constant(0
+  %zero = llvm.mlir.constant(0 : index) : i64
+  // CHECK: vector.broadcast
+  %splat = vector.broadcast %zero : i64 to vector<26xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @fold_constant_vector_matching_type
+func.func @fold_constant_vector_matching_type() -> vector<26xi64> {
+  // CHECK: arith.constant dense<0> : vector<26xi64>
+  %zero = llvm.mlir.constant(0 : i64) : i64
+  %splat = vector.broadcast %zero : i64 to vector<26xi64>
+  return %splat : vector<26xi64>
+}



More information about the Mlir-commits mailing list