[Mlir-commits] [mlir] 6506355 - [mlir][arith] Only fold splats for static shape result types (#93102)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 10:58:58 PDT 2024


Author: Kavan Bickerstaff
Date: 2024-05-23T13:58:54-04:00
New Revision: 650635586220aa8878397579744b71effb35938e

URL: https://github.com/llvm/llvm-project/commit/650635586220aa8878397579744b71effb35938e
DIFF: https://github.com/llvm/llvm-project/commit/650635586220aa8878397579744b71effb35938e.diff

LOG: [mlir][arith] Only fold splats for static shape result types (#93102)

This prevents an assertion failure when constructing the DenseElementsAttr
result, because the passed-in result type is expected to have a static shape.

Fixes https://github.com/llvm/llvm-project/issues/92057

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/CommonFolders.h
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 7dabc781cd595..6f497a259262a 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -298,7 +298,10 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
         calculate(op.getSplatValue<ElementValueT>(), castStatus);
     if (!castStatus)
       return {};
-    return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult);
+    auto shapedResType = cast<ShapedType>(resType);
+    if (!shapedResType.hasStaticShape())
+      return {};
+    return DenseElementsAttr::get(shapedResType, elementResult);
   }
   if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
     // Operand is ElementsAttr-derived; perform an element-wise fold by

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e4f95bb0545a2..1a387c20c4b29 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2950,6 +2950,14 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
   return %ext : tensor<i16>
 }
 
+// Just checks that this doesn't crash.
+// CHECK-LABEL: @signedExtendSplatAsDynamicShape
+func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
+  %splat = arith.constant dense<5> : tensor<2xi16>
+  %extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
+  return %extsplat : tensor<?xi64>
+}
+
 // CHECK-LABEL: @extsi_i0
 //       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
 //       CHECK:   return %[[ZERO]] : i16


        


More information about the Mlir-commits mailing list