[Mlir-commits] [mlir] cd85f5d - [mlir] canonicalizer: shape_cast(poison) -> poison (#133988)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 11 07:13:07 PDT 2025


Author: James Newling
Date: 2025-04-11T15:13:03+01:00
New Revision: cd85f5dbdf135347a9912dde148ec9fd325ba8c1

URL: https://github.com/llvm/llvm-project/commit/cd85f5dbdf135347a9912dde148ec9fd325ba8c1
DIFF: https://github.com/llvm/llvm-project/commit/cd85f5dbdf135347a9912dde148ec9fd325ba8c1.diff

LOG: [mlir] canonicalizer: shape_cast(poison) -> poison (#133988)

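Based on the ShapeCastConstantFolder pattern, this change adds a fold to ShapeCastOp::fold that replaces (the splat-constant folding previously done by ShapeCastConstantFolder is also moved into ShapeCastOp::fold, and that pattern is removed)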

%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 : vector<2x3xf32> to vector<6xf32>

with 

%1 = ub.poison : vector<6xf32>

---------

Signed-off-by: James Newling <james.newling at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..59f3b788cebed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -42,6 +42,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 
 #include <cassert>
 #include <cstdint>
@@ -5611,18 +5612,20 @@ LogicalResult ShapeCastOp::verify() {
 }
 
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
+
   // No-op shape cast.
-  if (getSource().getType() == getResult().getType())
+  if (getSource().getType() == getType())
     return getSource();
 
+  VectorType resultType = getType();
+
   // Canceling shape casts.
   if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-    if (getResult().getType() == otherOp.getSource().getType())
-      return otherOp.getSource();
 
-    // Only allows valid transitive folding.
-    VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
-    VectorType resultType = llvm::cast<VectorType>(getResult().getType());
+    // Only allows valid transitive folding (expand/collapse dimensions).
+    VectorType srcType = otherOp.getSource().getType();
+    if (resultType == srcType)
+      return otherOp.getSource();
     if (srcType.getRank() < resultType.getRank()) {
       if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
         return {};
@@ -5632,43 +5635,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     } else {
       return {};
     }
-
     setOperand(otherOp.getSource());
     return getResult();
   }
 
   // Cancelling broadcast and shape cast ops.
   if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
-    if (bcastOp.getSourceType() == getType())
+    if (bcastOp.getSourceType() == resultType)
       return bcastOp.getSource();
   }
 
+  // shape_cast(constant) -> constant
+  if (auto splatAttr =
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+    return DenseElementsAttr::get(resultType,
+                                  splatAttr.getSplatValue<Attribute>());
+  }
+
+  // shape_cast(poison) -> poison
+  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+    return ub::PoisonAttr::get(getContext());
+  }
+
   return {};
 }
 
 namespace {
-// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
-class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto constantOp =
-        shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
-    if (!constantOp)
-      return failure();
-    // Only handle splat for now.
-    auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
-    if (!dense)
-      return failure();
-    auto newAttr =
-        DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
-                               dense.getSplatValue<Attribute>());
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
-    return success();
-  }
-};
 
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
@@ -5828,8 +5820,9 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
-              ShapeCastBroadcastFolder>(context);
+  results
+      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..72064fb42741a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
 
 // -----
 
+// CHECK-LABEL: shape_cast_poison
+//       CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
+//       CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
+//       CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
+func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+  %poison = ub.poison : vector<5x4x2xf32>
+  %poison_1 = ub.poison : vector<12x2xi32>
+  %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
+  %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
+  return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_strided_constant
 //       CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
 //       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>


        


More information about the Mlir-commits mailing list