[Mlir-commits] [mlir] eca7698 - [mlir][vector] NFC: Expose castAwayContractionLeadingOneDim

Lei Zhang llvmlistbot at llvm.org
Fri Apr 21 09:44:52 PDT 2023


Author: Lei Zhang
Date: 2023-04-21T09:41:14-07:00
New Revision: eca7698a979ebc2338c075d63ac52c4c21b19cb1

URL: https://github.com/llvm/llvm-project/commit/eca7698a979ebc2338c075d63ac52c4c21b19cb1
DIFF: https://github.com/llvm/llvm-project/commit/eca7698a979ebc2338c075d63ac52c4c21b19cb1.diff

LOG: [mlir][vector] NFC: Expose castAwayContractionLeadingOneDim

This commit exposes the transformation behind the pattern as a
standalone function. This is useful for targeted, one-off application
to a specific op, rather than only through the rewrite pattern.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D148758

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 4763b6525b934..ed25021e421f4 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -44,6 +44,7 @@ enum class AtomicRMWKind : uint64_t;
 } // namespace arith
 
 namespace vector {
+class ContractionOp;
 class TransferReadOp;
 class TransferWriteOp;
 class VectorDialect;
@@ -76,6 +77,11 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
 void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                       PatternBenefit benefit = 1);
 
+/// Cast away the leading unit dim, if it exists, from the given contract op.
+/// Return success and replace contractOp if applied; return failure otherwise.
+LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+                                               RewriterBase &rewriter);
+
 /// Collect a set of leading one dimension removal patterns.
 ///
 /// These patterns insert vector.shape_cast to remove leading one dimensions

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 849e0442bc7e1..6105e87573c23 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -279,6 +279,121 @@ struct CastAwayTransferWriteLeadingOneDim
   }
 };
 
+} // namespace
+
+LogicalResult
+mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+                                               RewriterBase &rewriter) {
+  VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
+  if (oldAccType == nullptr)
+    return failure();
+  if (oldAccType.getRank() < 2)
+    return failure();
+  if (oldAccType.getShape()[0] != 1)
+    return failure();
+  // Currently only one leading unit dim is dropped per call; the pattern can
+  // be applied greedily to drop more.
+  int64_t dropDim = 1;
+
+  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
+  SmallVector<AffineMap> newIndexingMaps;
+
+  auto oldIteratorTypes = contractOp.getIteratorTypes();
+  SmallVector<Attribute> newIteratorTypes;
+
+  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
+
+  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
+    // Only a parallel iterator can be dropped this way.
+    return failure();
+
+  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
+    int64_t currDim = it.index();
+    if (currDim == dimToDrop)
+      continue;
+    newIteratorTypes.push_back(it.value());
+  }
+
+  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
+                                 contractOp.getAcc()};
+  SmallVector<Value> newOperands;
+
+  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
+    // If the dim to drop is (or can be transposed to be) the leading dim of
+    // this operand, it is removed with vector.extract below.
+    bool validExtract = false;
+    SmallVector<AffineExpr> results;
+    auto map = it.value();
+    int64_t orginalZeroDim = it.value().getDimPosition(0);
+    if (orginalZeroDim != dimToDrop) {
+      // We get here for one of two reasons: (1) the dim to drop exists but
+      // is not leading, so the operand must be transposed to bring it to
+      // the front; or (2) the dim to drop does not appear in this operand
+      // at all, in which case no transpose is emitted. All results must be
+      // scanned to tell these two cases apart.
+      bool tranposeNeeded = false;
+      SmallVector<int64_t> perm;
+      SmallVector<AffineExpr> transposeResults;
+
+      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+        int64_t currDim = map.getDimPosition(i);
+        if (currDim == dimToDrop) {
+          tranposeNeeded = true;
+          perm.insert(perm.begin(), i);
+          auto targetExpr = rewriter.getAffineDimExpr(currDim);
+          transposeResults.insert(transposeResults.begin(), targetExpr);
+        } else {
+          perm.push_back(i);
+          auto targetExpr = rewriter.getAffineDimExpr(currDim);
+          transposeResults.push_back(targetExpr);
+        }
+      }
+      // Emit the transpose now, if needed, so that the dim to drop becomes
+      // leading and can be removed with vector.extract below.
+      if (tranposeNeeded) {
+        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
+                             contractOp.getContext());
+        operands[it.index()] = rewriter.create<vector::TransposeOp>(
+            contractOp.getLoc(), operands[it.index()], perm);
+      }
+    }
+    // At this point the dim to drop, if this operand uses it at all, has
+    // been made leading (via the transpose above). If it is still not
+    // leading, this operand does not index it and therefore needs no
+    // vector.extract.
+    if (map.getDimPosition(0) == dimToDrop)
+      validExtract = true;
+
+    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
+      int64_t currDim = map.getDimPosition(i);
+      if (currDim == dimToDrop)
+        // Skip the dropped dim; dims after it are renumbered down by one.
+        continue;
+      auto targetExpr = rewriter.getAffineDimExpr(
+          currDim < dimToDrop ? currDim : currDim - 1);
+      results.push_back(targetExpr);
+    }
+    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
+                                             contractOp.getContext()));
+    // Strip the leading dim with vector.extract when this operand has it;
+    // otherwise pass the operand through unchanged.
+    newOperands.push_back(
+        validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
+                                                          operands[it.index()],
+                                                          splatZero(dropDim))
+                     : operands[it.index()]);
+  }
+  auto newContractOp = rewriter.create<vector::ContractionOp>(
+      contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
+      rewriter.getAffineMapArrayAttr(newIndexingMaps),
+      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
+  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+      contractOp, contractOp->getResultTypes()[0], newContractOp);
+  return success();
+}
+
+namespace {
+
 /// Turns vector.contract on vector with leading 1 dimensions into
 /// vector.extract followed by vector.contract on vector without leading
 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
@@ -289,112 +404,7 @@ struct CastAwayContractionLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-    VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
-    if (oldAccType == nullptr)
-      return failure();
-    if (oldAccType.getRank() < 2)
-      return failure();
-    if (oldAccType.getShape()[0] != 1)
-      return failure();
-    // currently we support only dropping one dim but the pattern can be applied
-    // greedily to drop more.
-    int64_t dropDim = 1;
-
-    auto oldIndexingMaps = contractOp.getIndexingMapsArray();
-    SmallVector<AffineMap> newIndexingMaps;
-
-    auto oldIteratorTypes = contractOp.getIteratorTypes();
-    SmallVector<Attribute> newIteratorTypes;
-
-    int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
-
-    if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
-      // only parallel type iterators can be dropped.
-      return failure();
-
-    for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
-      int64_t currDim = it.index();
-      if (currDim == dimToDrop)
-        continue;
-      newIteratorTypes.push_back(it.value());
-    }
-
-    SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
-                                   contractOp.getAcc()};
-    SmallVector<Value> newOperands;
-
-    for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
-      // Check if the dim to be dropped exists as a leading dim in the operand
-      // if it does then we use vector.extract to drop it.
-      bool validExtract = false;
-      SmallVector<AffineExpr> results;
-      auto map = it.value();
-      int64_t orginalZeroDim = it.value().getDimPosition(0);
-      if (orginalZeroDim != dimToDrop) {
-        // There are two reasons to be in this path, 1. We need to
-        // tranpose the operand to make the dim to be dropped
-        // leading. 2. The dim to be dropped does not exist and in
-        // that case we dont want to add a unit tranpose but we must
-        // check all the indices to make sure this is the case.
-        bool tranposeNeeded = false;
-        SmallVector<int64_t> perm;
-        SmallVector<AffineExpr> transposeResults;
-
-        for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-          int64_t currDim = map.getDimPosition(i);
-          if (currDim == dimToDrop) {
-            tranposeNeeded = true;
-            perm.insert(perm.begin(), i);
-            auto targetExpr = rewriter.getAffineDimExpr(currDim);
-            transposeResults.insert(transposeResults.begin(), targetExpr);
-          } else {
-            perm.push_back(i);
-            auto targetExpr = rewriter.getAffineDimExpr(currDim);
-            transposeResults.push_back(targetExpr);
-          }
-        }
-        // Do the tranpose now if needed so that we can drop the
-        // correct dim using extract later.
-        if (tranposeNeeded) {
-          map = AffineMap::get(map.getNumDims(), 0, transposeResults,
-                               contractOp.getContext());
-          operands[it.index()] = rewriter.create<vector::TransposeOp>(
-              contractOp.getLoc(), operands[it.index()], perm);
-        }
-      }
-      // We have taken care to have the dim to be dropped be
-      // the leading dim. If its still not leading that means it
-      // does not exist in this operand and hence we do not need
-      // an extract.
-      if (map.getDimPosition(0) == dimToDrop)
-        validExtract = true;
-
-      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-        int64_t currDim = map.getDimPosition(i);
-        if (currDim == dimToDrop)
-          // This is the dim we are dropping.
-          continue;
-        auto targetExpr = rewriter.getAffineDimExpr(
-            currDim < dimToDrop ? currDim : currDim - 1);
-        results.push_back(targetExpr);
-      }
-      newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
-                                               contractOp.getContext()));
-      // Extract if its a valid extraction, otherwise use the operand
-      // without extraction.
-      newOperands.push_back(validExtract
-                                ? rewriter.create<vector::ExtractOp>(
-                                      contractOp.getLoc(), operands[it.index()],
-                                      splatZero(dropDim))
-                                : operands[it.index()]);
-    }
-    auto newContractOp = rewriter.create<vector::ContractionOp>(
-        contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
-        rewriter.getAffineMapArrayAttr(newIndexingMaps),
-        rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        contractOp, contractOp->getResultTypes()[0], newContractOp);
-    return success();
+    return castAwayContractionLeadingOneDim(contractOp, rewriter);
   }
 };
 


        


More information about the Mlir-commits mailing list