[Mlir-commits] [mlir] [mlir] only fold llvm.mlir.constant when types match (PR #70318)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 26 04:33:49 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

The llvm.mlir.constant operation may use different types for the attribute holding the constant value and for its result. Only constant-fold the value when the two types match, since a silent change of type while propagating folded constants is surprising.

To support this case, also relax the assertion in the `m_Constant` matcher that a ConstantLike operation must always constant-fold. The matcher could already fail when looking for a constant with a more specific type or value, so no caller relied on it always succeeding.

Fixes #70278.

---
Full diff: https://github.com/llvm/llvm-project/pull/70318.diff


5 Files Affected:

- (modified) mlir/include/mlir/IR/Matchers.h (+2-3) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+12-1) 
- (modified) mlir/test/Conversion/FuncToLLVM/calling-convention.mlir (+2-2) 
- (modified) mlir/test/Dialect/LLVMIR/canonicalize.mlir (+18) 
- (added) mlir/test/Dialect/LLVMIR/fold-cross-dialect.mlir (+50) 


``````````diff
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c0..ba309a4fb83e1b8 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -87,9 +87,8 @@ struct constant_op_binder {
 
     // Fold the constant to an attribute.
     SmallVector<OpFoldResult, 1> foldedOp;
-    LogicalResult result = op->fold(/*operands=*/std::nullopt, foldedOp);
-    (void)result;
-    assert(succeeded(result) && "expected ConstantLike op to be foldable");
+    if (failed(op->fold(/*operands=*/std::nullopt, foldedOp)))
+      return false;
 
     if (auto attr = llvm::dyn_cast<AttrT>(foldedOp.front().get<Attribute>())) {
       if (bind_value)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 95c04098d05fc2f..49b6abe87b15ddf 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2757,7 +2757,18 @@ ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
 }
 
 // Constant op constant-folds to its value.
-OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
+OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) {
+  if (auto integerValue = getValue().dyn_cast<IntegerAttr>();
+      integerValue && integerValue.getType() == getType()) {
+    return getValue();
+  }
+  if (auto floatValue = getValue().dyn_cast<FloatAttr>();
+      floatValue && floatValue.getType() == getType()) {
+    return getValue();
+  }
+
+  return {};
+}
 
 //===----------------------------------------------------------------------===//
 // AtomicRMWOp
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index 1ed67708875604d..710f70401589d00 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -127,7 +127,7 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
@@ -159,7 +159,7 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
 
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
index 5e26fa37b681d71..d515ea3d5d87959 100644
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -208,3 +208,21 @@ llvm.func @volatile_load(%x : !llvm.ptr) {
   %3 = llvm.load %x  atomic unordered { alignment = 1 } : !llvm.ptr -> i8
   llvm.return
 }
+
+// -----
+
+llvm.func @fold_constant_mismatching_type() -> i64 {
+  %0 = llvm.mlir.constant(42 : index) : i64
+  %1 = llvm.mlir.constant(2 : i64) : i64
+  %2 = arith.addi %0, %1 : i64
+  llvm.return %2 : i64
+}
+
+// -----
+
+llvm.func @fold_constant_matching_type() -> i64 {
+  %0 = llvm.mlir.constant(42 : i64) : i64
+  %1 = llvm.mlir.constant(2 : i64) : i64
+  %2 = arith.addi %0, %1 : i64
+  llvm.return %2 : i64
+}
diff --git a/mlir/test/Dialect/LLVMIR/fold-cross-dialect.mlir b/mlir/test/Dialect/LLVMIR/fold-cross-dialect.mlir
new file mode 100644
index 000000000000000..2c111c26814d767
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/fold-cross-dialect.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt %s --test-constant-fold --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @fold_constant_mismatching_type
+llvm.func @fold_constant_mismatching_type() -> i64 {
+  // CHECK-DAG: llvm.mlir.constant(42
+  // CHECK-DAG: llvm.mlir.constant(2
+  %0 = llvm.mlir.constant(42 : index) : i64
+  %1 = llvm.mlir.constant(2 : i64) : i64
+  // Using arith.add as there is no folder for llvm.add.
+  // This is not expected to fold because attribute types differ.
+  // CHECK: arith.add
+  %2 = arith.addi %0, %1 : i64
+  llvm.return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @fold_constant_matching_type
+llvm.func @fold_constant_matching_type() -> i64 {
+  // CHECK-NOT: llvm.mlir.constant
+  %0 = llvm.mlir.constant(42 : i64) : i64
+  %1 = llvm.mlir.constant(2 : i64) : i64
+  // Using arith.add as there is no folder for llvm.add.
+  // This is expected to fold.
+  // CHECK: %[[V:.+]] = arith.constant 44 : i64
+  // CHECK: llvm.return %[[V]]
+  %2 = arith.addi %0, %1 : i64
+  llvm.return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @fold_constant_vector_mismatching_type
+func.func @fold_constant_vector_mismatching_type() {
+  // CHECK: llvm.mlir.constant(0
+  %220 = llvm.mlir.constant(0 : index) : i64
+  // CHECK: vector.broadcast
+  %365 = vector.broadcast %220 : i64 to vector<26xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @fold_constant_vector_matching_type
+func.func @fold_constant_vector_matching_type() -> vector<26xi64>{
+  // CHECK: arith.constant dense<0> : vector<26xi64>
+  %220 = llvm.mlir.constant(0 : i64) : i64
+  %365 = vector.broadcast %220 : i64 to vector<26xi64>
+  return %365 : vector<26xi64>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/70318


More information about the Mlir-commits mailing list