[Mlir-commits] [mlir] [mlir] Exclude masked ops in vector transformation (PR #76468)

Jerry Wu llvmlistbot at llvm.org
Wed Dec 27 12:20:14 PST 2023


https://github.com/pzread created https://github.com/llvm/llvm-project/pull/76468

None

>From 4bba0664d8e320e037cc90321e2b3ed9af9ba637 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 22 Dec 2023 00:00:42 +0000
Subject: [PATCH] Handle masked op in VectorDropLeadUnitDim patterns

---
 .../Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp  | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 84294e4552a607..65517295aa72d2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -223,6 +223,9 @@ struct CastAwayTransferReadLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
+    // Masked ops are not supported yet.
+    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
+      return failure();
     // TODO: support 0-d corner case.
     if (read.getTransferRank() == 0)
       return failure();
@@ -274,6 +277,9 @@ struct CastAwayTransferWriteLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
+    // Masked ops are not supported yet.
+    if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
+      return failure();
     // TODO: support 0-d corner case.
     if (write.getTransferRank() == 0)
       return failure();
@@ -325,6 +331,9 @@ struct CastAwayTransferWriteLeadingOneDim
 LogicalResult
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                                RewriterBase &rewriter) {
+  // Masked ops are not supported yet.
+  if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
+    return failure();
   VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
   if (oldAccType == nullptr)
     return failure();



More information about the Mlir-commits mailing list