[Mlir-commits] [mlir] 1c4b04c - [mlir] Fix crash in `InsertOpConstantFolder` when vector.insert operand is from a llvm.mlir.constant op (#88314)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 8 16:44:33 PDT 2024


Author: Rajveer Singh Bharadwaj
Date: 2024-09-09T01:44:30+02:00
New Revision: 1c4b04ce1f022ec12b0111ad75d8f2f0eec3b054

URL: https://github.com/llvm/llvm-project/commit/1c4b04ce1f022ec12b0111ad75d8f2f0eec3b054
DIFF: https://github.com/llvm/llvm-project/commit/1c4b04ce1f022ec12b0111ad75d8f2f0eec3b054.diff

LOG: [mlir] Fix crash in `InsertOpConstantFolder` when vector.insert operand is from a llvm.mlir.constant op (#88314)

In cases where llvm.mlir.constant has an attribute with a different type than the returned type,
the folder used to create an incorrect DenseElementsAttr and crash.

Resolves #74236

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/LLVMIR/constant-folding.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 44bd4aa76ffbd6..b62f1fa5992958 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2896,10 +2896,17 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
         linearize(completePositions, computeStrides(destTy.getShape()));
 
     SmallVector<Attribute> insertedValues;
-    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
-      llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
-    else
-      insertedValues.push_back(sourceCst);
+    Type destEltType = destTy.getElementType();
+
+    // The `convertIntegerAttr` method specifically handles the case
+    // for `llvm.mlir.constant` which can hold an attribute with a
+    // different type than the return type.
+    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
+      for (auto value : denseSource.getValues<Attribute>())
+        insertedValues.push_back(convertIntegerAttr(value, destEltType));
+    } else {
+      insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
+    }
 
     auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
     copy(insertedValues, allValues.begin() + insertBeginPosition);
@@ -2908,6 +2915,17 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
     return success();
   }
+
+private:
+  /// Converts the expected type to an IntegerAttr if there's
+  /// a mismatch.
+  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
+    if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
+      if (intAttr.getType() != expectedType)
+        return IntegerAttr::get(expectedType, intAttr.getInt());
+    }
+    return attr;
+  }
 };
 
 } // namespace

diff --git a/mlir/test/Dialect/LLVMIR/constant-folding.mlir b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
index 497d679a12a09f..a3570a11d4a2d0 100644
--- a/mlir/test/Dialect/LLVMIR/constant-folding.mlir
+++ b/mlir/test/Dialect/LLVMIR/constant-folding.mlir
@@ -169,3 +169,16 @@ llvm.func @null_pointer_select(%cond: i1) -> !llvm.ptr {
   // CHECK-NEXT: llvm.return %[[NULLPTR]]
   llvm.return %result : !llvm.ptr
 }
+
+// -----
+
+llvm.func @malloc(i64) -> !llvm.ptr
+
+// CHECK-LABEL: func.func @insert_op
+func.func @insert_op(%arg0: index, %arg1: memref<13x13xi64>, %arg2: index) {
+  %cst_7 = arith.constant dense<1526248407> : vector<1xi64>
+  %1 = llvm.mlir.constant(1 : index) : i64
+  %101 = vector.insert %1, %cst_7 [0] : i64 into vector<1xi64>
+  vector.print %101 : vector<1xi64>
+  return
+}


        


More information about the Mlir-commits mailing list