[Mlir-commits] [mlir] Stashing my current changes (PR #86681)

Balaji V. Iyer. llvmlistbot at llvm.org
Tue Mar 26 08:29:23 PDT 2024


https://github.com/bviyer created https://github.com/llvm/llvm-project/pull/86681

None

>From fde30b40372810f8ae089ef5251839f589ce759a Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Mon, 18 Mar 2024 14:52:34 +0000
Subject: [PATCH] Stashing my current changes

---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    | 44 ++++-----
 .../Transforms/VectorTransferOpTransforms.cpp | 90 +++++++++++++++++--
 2 files changed, 106 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 41c7af4593c77c..c5176fc47c9f76 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -53,26 +53,28 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
       prodOfCollapsedDims *= sourceShape[sourceDim];
       currIndices.push_back(sourceDim++);
     }
+    if (sourceDim < sourceShape.size()) {
 
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
-      return std::nullopt;
+      // If the current expanded dimension is dynamic, then the collapsed
+      // dimensions should also be dynamic and product of all previous
+      // unprocessed dimensions of the expanded shape should be 1.
+      if (sourceShape[sourceDim] == ShapedType::kDynamic &&
+          (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+        return std::nullopt;
 
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
-      return std::nullopt;
+      // If the collapsed dim is dynamic, the current expanded dim should also
+      // be dynamic.
+      if (currTargetShape == ShapedType::kDynamic &&
+          sourceShape[sourceDim] != ShapedType::kDynamic)
+        return std::nullopt;
 
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
-      return std::nullopt;
+      // For static shapes, the product of the dimensions of the expanded
+      // shape should match the collapsed dimension shape.
+      if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+        return std::nullopt;
 
-    currIndices.push_back(sourceDim++);
+      currIndices.push_back(sourceDim++);
+    }
     reassociationMap.emplace_back(ReassociationIndices{});
     std::swap(reassociationMap.back(), currIndices);
     prodOfCollapsedDims = 1;
@@ -322,11 +324,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 0ffef6aabccc18..d1f86a92a8aff3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -353,6 +353,44 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
       .getResult();
 }
 
+static SmallVector<int64_t>
+findLinearizedShape(llvm::SmallVector<int64_t> currShape, unsigned elemBitWidth,
+                    unsigned targetBitWidth, int firstDimToCollapse) {
+  int idx = (int)currShape.size() - 1;
+  unsigned totalFlat = 1;
+  unsigned reqdBitWidth = targetBitWidth / elemBitWidth;
+  while (reqdBitWidth > 1 && idx >= firstDimToCollapse) {
+    unsigned curr = currShape[idx];
+    if (reqdBitWidth % curr == 0) {
+      reqdBitWidth /= curr;
+      totalFlat *= curr;
+      curr = 1;
+    } else if (curr % 2 == 0 and reqdBitWidth % 2 == 0) {
+      reqdBitWidth /= 2;
+      curr /= 2;
+      totalFlat *= 2;
+    } else {
+      // At this moment, bail when the shape is not perfectly divisible
+      // by the current required bit-width. For example, if the shape is
+      // <5x4x3x2xi8>, and targetBitWidth=128, then when we reach '3'
+      // the required bitWidth is 8, and dividing 8 by 3 leaves a remainder
+      // of 2, thus the shapes won't match. So bail here for now.
+      break;
+    }
+    currShape[idx] = curr;
+    if (curr == 1) {
+      idx--;
+    }
+  }
+  SmallVector<int64_t> newShape(currShape.begin(),
+                                currShape.begin() + (idx + 1));
+  newShape.push_back(totalFlat);
+  for (int i = firstDimToCollapse - 1; i >= 0; i--) {
+    newShape.push_back(currShape[i]);
+  }
+  return newShape;
+}
+
 namespace {
 
 /// Rewrites `vector.transfer_read` ops where the source has unit dims, by
@@ -493,18 +531,39 @@ class TransferWriteDropUnitDimsPattern
 
 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
 /// input starting at `firstDimToCollapse`.
-static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
-                               Value input, int64_t firstDimToCollapse) {
+static std::optional<Value> collapseInnerDims(PatternRewriter &rewriter,
+                                              mlir::Location loc, Value input,
+                                              int64_t firstDimToCollapse,
+                                              unsigned vectorBitWidth) {
   ShapedType inputType = cast<ShapedType>(input.getType());
   if (inputType.getRank() == 1)
     return input;
   SmallVector<ReassociationIndices> reassociation;
+  SmallVector<int64_t> inputShape(inputType.getShape());
+  auto newShape =
+      findLinearizedShape(inputShape, inputType.getElementTypeBitWidth(),
+                          vectorBitWidth, firstDimToCollapse);
+  // If inputShape == newShape there is no need to collapse; just use
+  // the transfer as-is.
+  if (inputShape == newShape) {
+    return std::nullopt;
+  }
+
+  auto reassocIndices =
+      mlir::getReassociationIndicesForCollapse(inputType.getShape(), newShape);
+  if (reassocIndices.has_value()) {
+    reassociation = reassocIndices.value();
+  } else {
+    return std::nullopt;
+  }
+#if 0
   for (int64_t i = 0; i < firstDimToCollapse; ++i)
     reassociation.push_back(ReassociationIndices{i});
   ReassociationIndices collapsedIndices;
   for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
     collapsedIndices.push_back(i);
   reassociation.push_back(collapsedIndices);
+#endif
   return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
 }
 
@@ -581,8 +640,11 @@ class FlattenContiguousRowMajorTransferReadPattern
     int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
 
     // 1. Collapse the source memref
-    Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
+    std::optional<Value> collapsedSrc = collapseInnerDims(
+        rewriter, loc, source, firstDimToCollapse, targetVectorBitwidth);
+    if (!collapsedSrc.has_value())
+      return failure();
+    Value collapsedSource = collapsedSrc.value();
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
@@ -715,24 +777,38 @@ class FlattenContiguousRowMajorTransferWritePattern
                                                 collapsedIndices)))
       return failure();
 
-    Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+    std::optional<Value> collapsedSrc = collapseInnerDims(
+        rewriter, loc, source, firstContiguousInnerDim, targetVectorBitwidth);
+    if (!collapsedSrc.has_value())
+      return failure();
+    Value collapsedSource = collapsedSrc.value();
     MemRefType collapsedSourceType =
         cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    // assert(collapsedRank == firstContiguousInnerDim + 1);
     SmallVector<AffineExpr, 1> dimExprs{
         getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+#if 0
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
+#else
+    VectorType flatVectorType = VectorType::get(
+        collapsedSourceType.getShape(), collapsedSourceType.getElementType());
+#endif
     Value flatVector =
+        // rewriter.create<vector::ReshapeOp>(loc, flatVectorType, vector);
         rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
     vector::TransferWriteOp flatWrite =
         rewriter.create<vector::TransferWriteOp>(
             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+    llvm::errs() << "flatVector, CollapsedSource ";
+    collapsedSource.dump();
+    flatVector.dump();
+    llvm::errs() << "flatWrite: ";
+    flatWrite.dump();
     rewriter.eraseOp(transferWriteOp);
     return success();
   }



More information about the Mlir-commits mailing list