[Mlir-commits] [mlir] [mlir] canonicalizer: shape_cast(poison) -> poison (PR #133988)

James Newling llvmlistbot at llvm.org
Wed Apr 9 13:01:50 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/133988

>From 2ddaa8ecb824e263730b1747b2762c2526b0773e Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 1 Apr 2025 14:25:16 -0700
Subject: [PATCH 1/4] add shape_cast(poison) -> poison canonicalizer

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 23 ++++++++++++++++++++--
 mlir/test/Dialect/Vector/canonicalize.mlir | 14 +++++++++++++
 2 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..68b4c26880141 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5670,6 +5670,23 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
+class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
+                                              shapeCastOp.getType());
+    return success();
+  }
+};
+
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
 ///
@@ -5828,8 +5845,10 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
-              ShapeCastBroadcastFolder>(context);
+  results
+      .add<ShapeCastConstantFolder, ShapeCastPoisonFolder,
+           ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..72064fb42741a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
 
 // -----
 
+// CHECK-LABEL: shape_cast_poison
+//       CHECK-DAG: %[[POISON_I32:.*]] = ub.poison : vector<3x4x2xi32>
+//       CHECK-DAG: %[[POISON_F32:.*]] = ub.poison : vector<20x2xf32>
+//       CHECK: return %[[POISON_F32]], %[[POISON_I32]] : vector<20x2xf32>, vector<3x4x2xi32>
+func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+  %poison_f32 = ub.poison : vector<5x4x2xf32>
+  %poison_i32 = ub.poison : vector<12x2xi32>
+  %cast_f32 = vector.shape_cast %poison_f32 : vector<5x4x2xf32> to vector<20x2xf32>
+  %cast_i32 = vector.shape_cast %poison_i32 : vector<12x2xi32> to vector<3x4x2xi32>
+  return %cast_f32, %cast_i32 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_strided_constant
 //       CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
 //       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>

>From 99ed6355272915d893db4a38fdfd7c3ba0fa0daa Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 7 Apr 2025 09:04:44 -0700
Subject: [PATCH 2/4] use folders where possible (replace 2 canonicalization patterns)

Signed-off-by: James Newling <james.newling at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 56 +++++++-----------------
 1 file changed, 15 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 68b4c26880141..53fc47a6d6ef5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5643,49 +5643,24 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
       return bcastOp.getSource();
   }
 
-  return {};
-}
-
-namespace {
-// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
-class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto constantOp =
-        shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
-    if (!constantOp)
-      return failure();
-    // Only handle splat for now.
-    auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
-    if (!dense)
-      return failure();
-    auto newAttr =
-        DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
-                               dense.getSplatValue<Attribute>());
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
-    return success();
+  // Replace shape_cast(arith.constant) with arith.constant. Currently only
+  // handles splat constants.
+  if (auto constantOp = getSource().getDefiningOp<arith::ConstantOp>()) {
+    if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue())) {
+      return DenseElementsAttr::get(cast<VectorType>(getType()),
+                                    dense.getSplatValue<Attribute>());
+    }
   }
-};
 
-// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
-class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
+  // Replace shape_cast(poison) with poison.
+  if (getSource().getDefiningOp<ub::PoisonOp>()) {
+    return ub::PoisonAttr::get(getContext());
+  }
 
-    if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
-      return failure();
+  return {};
+}
 
-    rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
-                                              shapeCastOp.getType());
-    return success();
-  }
-};
+namespace {
 
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
@@ -5846,8 +5821,7 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results
-      .add<ShapeCastConstantFolder, ShapeCastPoisonFolder,
-           ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
           context);
 }
 

>From eb5b9d7046b8c7ae0b30f8872ac4a593074af890 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 7 Apr 2025 10:16:20 -0700
Subject: [PATCH 3/4] simplify folding of splat constants and poison

Signed-off-by: James Newling <james.newling at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 38 +++++++++++-------------
 1 file changed, 18 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 53fc47a6d6ef5..0ac22969629da 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -42,6 +42,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 
 #include <cassert>
 #include <cstdint>
@@ -5611,28 +5612,27 @@ LogicalResult ShapeCastOp::verify() {
 }
 
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
+
   // No-op shape cast.
-  if (getSource().getType() == getResult().getType())
+  if (getSource().getType() == getType())
     return getSource();
 
   // Canceling shape casts.
   if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-    if (getResult().getType() == otherOp.getSource().getType())
-      return otherOp.getSource();
 
-    // Only allows valid transitive folding.
-    VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
-    VectorType resultType = llvm::cast<VectorType>(getResult().getType());
-    if (srcType.getRank() < resultType.getRank()) {
-      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
+    // Only allows valid transitive folding (expand/collapse dimensions).
+    VectorType srcType = otherOp.getSource().getType();
+    if (getType() == srcType)
+      return otherOp.getSource();
+    if (srcType.getRank() < getType().getRank()) {
+      if (!isValidShapeCast(srcType.getShape(), getType().getShape()))
         return {};
-    } else if (srcType.getRank() > resultType.getRank()) {
-      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
+    } else if (srcType.getRank() > getType().getRank()) {
+      if (!isValidShapeCast(getType().getShape(), srcType.getShape()))
         return {};
     } else {
       return {};
     }
-
     setOperand(otherOp.getSource());
     return getResult();
   }
@@ -5643,17 +5643,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
       return bcastOp.getSource();
   }
 
-  // Replace shape_cast(arith.constant) with arith.constant. Currently only
-  // handles splat constants.
-  if (auto constantOp = getSource().getDefiningOp<arith::ConstantOp>()) {
-    if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue())) {
-      return DenseElementsAttr::get(cast<VectorType>(getType()),
-                                    dense.getSplatValue<Attribute>());
-    }
+  // shape_cast(constant) -> constant
+  if (auto splatAttr =
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+    return DenseElementsAttr::get(getType(),
+                                  splatAttr.getSplatValue<Attribute>());
   }
 
-  // Replace shape_cast(poison) with poison.
-  if (getSource().getDefiningOp<ub::PoisonOp>()) {
+  // shape_cast(poison) -> poison
+  if (llvm::isa_and_present<ub::PoisonAttr>(adaptor.getSource())) {
     return ub::PoisonAttr::get(getContext());
   }
 

>From d47886c4c6516b877baeacade6a29b80403aad9b Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 9 Apr 2025 13:06:13 -0700
Subject: [PATCH 4/4] use a resultType local in ShapeCastOp::fold

Signed-off-by: James Newling <james.newling at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0ac22969629da..59f3b788cebed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5617,18 +5617,20 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   if (getSource().getType() == getType())
     return getSource();
 
+  VectorType resultType = getType(); // Shared by the folds below.
+
   // Canceling shape casts.
   if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
 
     // Only allows valid transitive folding (expand/collapse dimensions).
     VectorType srcType = otherOp.getSource().getType();
-    if (getType() == srcType)
+    if (resultType == srcType)
       return otherOp.getSource();
-    if (srcType.getRank() < getType().getRank()) {
-      if (!isValidShapeCast(srcType.getShape(), getType().getShape()))
+    if (srcType.getRank() < resultType.getRank()) {
+      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
         return {};
-    } else if (srcType.getRank() > getType().getRank()) {
-      if (!isValidShapeCast(getType().getShape(), srcType.getShape()))
+    } else if (srcType.getRank() > resultType.getRank()) {
+      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
         return {};
     } else {
       return {};
@@ -5639,14 +5641,14 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // Cancelling broadcast and shape cast ops.
   if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
-    if (bcastOp.getSourceType() == getType())
+    if (bcastOp.getSourceType() == resultType)
       return bcastOp.getSource();
   }
 
   // shape_cast(constant) -> constant
   if (auto splatAttr =
           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
-    return DenseElementsAttr::get(getType(),
+    return DenseElementsAttr::get(resultType,
                                   splatAttr.getSplatValue<Attribute>());
   }
 



More information about the Mlir-commits mailing list