[Mlir-commits] [mlir] [mlir][vector] Refactor parts of `VectorToSCF.cpp` (NFC) (PR #75855)

Rik Huijzer llvmlistbot at llvm.org
Mon Dec 18 13:13:34 PST 2023


https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/75855

>From 898b3b05960cecc32c1c4e15be4278c4f01ee22a Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 18 Dec 2023 21:57:08 +0100
Subject: [PATCH 1/3] [mlir][vector] Refactor parts of `VectorToSCF.cpp`

---
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 135 +++++++++---------
 1 file changed, 70 insertions(+), 65 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..39bf5fb07f1a70 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -321,7 +321,7 @@ static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
   // It may be possible to support these in future by using dynamic memref dims.
   if (vectorType.getScalableDims().front())
     return failure();
-  auto memrefShape = type.getShape();
+  ArrayRef<int64_t> memrefShape = type.getShape();
   SmallVector<int64_t, 8> newMemrefShape;
   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
   newMemrefShape.push_back(vectorType.getDimSize(0));
@@ -350,7 +350,7 @@ struct Strategy<TransferReadOp> {
   /// result to the temporary buffer allocation.
   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
-    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
+    auto storeOp = cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
     return storeOp;
   }
@@ -369,8 +369,8 @@ struct Strategy<TransferReadOp> {
   /// Retrieve the indices of the current StoreOp that stores into the buffer.
   static void getBufferIndices(TransferReadOp xferOp,
                                SmallVector<Value, 8> &indices) {
-    auto storeOp = getStoreOp(xferOp);
-    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
+    memref::StoreOp storeOp = getStoreOp(xferOp);
+    ValueRange prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
     indices.append(prevIndices.begin(), prevIndices.end());
   }
 
@@ -413,8 +413,8 @@ struct Strategy<TransferReadOp> {
     getXferIndices(b, xferOp, iv, xferIndices);
 
     Location loc = xferOp.getLoc();
-    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
-    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
+    auto bufferType = cast<ShapedType>(buffer.getType());
+    auto vecType = cast<VectorType>(bufferType.getElementType());
     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
     auto newXferOp = b.create<vector::TransferReadOp>(
         loc, vecType, xferOp.getSource(), xferIndices,
@@ -437,8 +437,8 @@ struct Strategy<TransferReadOp> {
     storeIndices.push_back(iv);
 
     Location loc = xferOp.getLoc();
-    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
-    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
+    auto bufferType = cast<ShapedType>(buffer.getType());
+    auto vecType = cast<VectorType>(bufferType.getElementType());
     auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
 
@@ -476,7 +476,7 @@ struct Strategy<TransferWriteOp> {
   static void getBufferIndices(TransferWriteOp xferOp,
                                SmallVector<Value, 8> &indices) {
     auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
-    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
+    ValueRange prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
     indices.append(prevIndices.begin(), prevIndices.end());
   }
 
@@ -591,11 +591,11 @@ struct PrepareTransferReadConversion
     if (checkPrepareXferOp(xferOp, options).failed())
       return failure();
 
-    auto buffers = allocBuffers(rewriter, xferOp);
+    BufferAllocs buffers = allocBuffers(rewriter, xferOp);
     auto *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
     if (xferOp.getMask()) {
-      dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
+      cast<TransferReadOp>(newXfer).getMaskMutable().assign(
           buffers.maskBuffer);
     }
 
@@ -641,7 +641,7 @@ struct PrepareTransferWriteConversion
       return failure();
 
     Location loc = xferOp.getLoc();
-    auto buffers = allocBuffers(rewriter, xferOp);
+    BufferAllocs buffers = allocBuffers(rewriter, xferOp);
     rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
                                      buffers.dataBuffer);
     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
@@ -707,8 +707,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     if (vectorType.getRank() > 1 && vectorType.isScalable())
       return failure();
 
-    auto loc = printOp.getLoc();
-    auto value = printOp.getSource();
+    Location loc = printOp.getLoc();
+    Value value = printOp.getSource();
 
     if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
       // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
@@ -719,27 +719,27 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
       auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                          intTy.getSignedness());
       // arith can only take signless integers, so we must cast back and forth.
-      auto signlessSourceVectorType =
+      VectorType signlessSource =
           vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
-      auto signlessTargetVectorType =
+      VectorType signlessTarget =
           vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
-      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
-      value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
+      VectorType target = vectorType.cloneWith({}, legalIntTy);
+      value = rewriter.create<vector::BitCastOp>(loc, signlessSource,
                                                  value);
-      if (value.getType() != signlessTargetVectorType) {
+      if (value.getType() != signlessTarget) {
         if (width == 1 || intTy.isUnsigned())
-          value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
+          value = rewriter.create<arith::ExtUIOp>(loc, signlessTarget,
                                                   value);
         else
-          value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
+          value = rewriter.create<arith::ExtSIOp>(loc, signlessTarget,
                                                   value);
       }
-      value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
-      vectorType = targetVectorType;
+      value = rewriter.create<vector::BitCastOp>(loc, target, value);
+      vectorType = target;
     }
 
-    auto scalableDimensions = vectorType.getScalableDims();
-    auto shape = vectorType.getShape();
+    ArrayRef<bool> scalableDimensions = vectorType.getScalableDims();
+    ArrayRef<int64_t> shape = vectorType.getShape();
     constexpr int64_t singletonShape[] = {1};
     if (vectorType.getRank() == 0)
       shape = singletonShape;
@@ -748,11 +748,11 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
       // Flatten n-D vectors to 1D. This is done to allow indexing with a
       // non-constant value (which can currently only be done via
       // vector.extractelement for 1D vectors).
-      auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
+      int flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                         std::multiplies<int64_t>());
-      auto flatVectorType =
+      VectorType flat =
           VectorType::get({flatLength}, vectorType.getElementType());
-      value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
+      value = rewriter.create<vector::ShapeCastOp>(loc, flat, value);
     }
 
     vector::PrintOp firstClose;
@@ -866,6 +866,37 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     this->setHasBoundedRewriteRecursion();
   }
 
+  Value createNewTransferOp(OpBuilder &b, OpTy xferOp, Value castedDataBuffer,
+                            Value iv, ValueRange loopState,
+                            Value castedMaskBuffer, PatternRewriter &rewriter,
+                            Location loc) const {
+    // Create new transfer op.
+    OpTy newXfer = Strategy<OpTy>::rewriteOp(b, this->options, xferOp,
+                                             castedDataBuffer, iv, loopState);
+
+    // If old transfer op has a mask: Set mask on new transfer op.
+    // Special case: If the mask of the old transfer op is 1D and
+    // the unpacked dim is not a broadcast, no mask is needed on the
+    // new transfer op.
+    if (xferOp.getMask() &&
+        (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() > 1)) {
+      OpBuilder::InsertionGuard guard(b);
+      b.setInsertionPoint(newXfer); // Insert load before newXfer.
+
+      SmallVector<Value, 8> loadIndices;
+      Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+      // In case of broadcast: Use same indices to load from memref
+      // as before.
+      if (!xferOp.isBroadcastDim(0))
+        loadIndices.push_back(iv);
+
+      auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer, loadIndices);
+      rewriter.updateRootInPlace(
+          newXfer, [&]() { newXfer.getMaskMutable().assign(mask); });
+    }
+    return loopState.empty() ? Value() : newXfer->getResult(0);
+  }
+
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
     if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +904,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
     // Find and cast data buffer. How the buffer can be found depends on OpTy.
     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
-    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
-    auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
-    auto castedDataType = unpackOneDim(dataBufferType);
+    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    auto dataBufferType = cast<MemRefType>(dataBuffer.getType());
+    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
     if (failed(castedDataType))
       return failure();
 
@@ -885,8 +916,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
     if (xferOp.getMask()) {
-      auto maskBuffer = getMaskBuffer(xferOp);
-      auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+      Value maskBuffer = getMaskBuffer(xferOp);
       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
         // Do not unpack a dimension of the mask, if:
         // * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +927,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto castedMaskType = *unpackOneDim(maskBufferType);
+        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
+        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -910,7 +941,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     auto step = locB.create<arith::ConstantIndexOp>(1);
     // TransferWriteOps that operate on tensors return the modified tensor and
     // require a loop state.
-    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
+    Value loopState = Strategy<OpTy>::initialLoopState(xferOp);
 
     // Generate for loop.
     auto result = locB.create<scf::ForOp>(
@@ -918,40 +949,14 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
           Type stateType = loopState.empty() ? Type() : loopState[0].getType();
 
-          auto result = generateInBoundsCheck(
+          Value result = generateInBoundsCheck(
               b, xferOp, iv, unpackedDim(xferOp),
               stateType ? TypeRange(stateType) : TypeRange(),
               /*inBoundsCase=*/
               [&](OpBuilder &b, Location loc) {
-                // Create new transfer op.
-                OpTy newXfer = Strategy<OpTy>::rewriteOp(
-                    b, this->options, xferOp, castedDataBuffer, iv, loopState);
-
-                // If old transfer op has a mask: Set mask on new transfer op.
-                // Special case: If the mask of the old transfer op is 1D and
-                // the
-                //               unpacked dim is not a broadcast, no mask is
-                //               needed on the new transfer op.
-                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
-                                         xferOp.getMaskType().getRank() > 1)) {
-                  OpBuilder::InsertionGuard guard(b);
-                  b.setInsertionPoint(newXfer); // Insert load before newXfer.
-
-                  SmallVector<Value, 8> loadIndices;
-                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-                  // In case of broadcast: Use same indices to load from memref
-                  // as before.
-                  if (!xferOp.isBroadcastDim(0))
-                    loadIndices.push_back(iv);
-
-                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
-                                                       loadIndices);
-                  rewriter.updateRootInPlace(newXfer, [&]() {
-                    newXfer.getMaskMutable().assign(mask);
-                  });
-                }
-
-                return loopState.empty() ? Value() : newXfer->getResult(0);
+                return createNewTransferOp(b, xferOp, castedDataBuffer, iv,
+                                           loopState, castedMaskBuffer,
+                                           rewriter, loc);
               },
               /*outOfBoundsCase=*/
               [&](OpBuilder &b, Location /*loc*/) {

>From 91032d429262d0c1c5b665d58002fb72d239c653 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 18 Dec 2023 22:08:04 +0100
Subject: [PATCH 2/3] Remove outdated comment

---
 mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 39bf5fb07f1a70..87e22703982e93 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -870,7 +870,6 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
                             Value iv, ValueRange loopState,
                             Value castedMaskBuffer, PatternRewriter &rewriter,
                             Location loc) const {
-    // Create new transfer op.
     OpTy newXfer = Strategy<OpTy>::rewriteOp(b, this->options, xferOp,
                                              castedDataBuffer, iv, loopState);
 

>From c66dd8f5cca9357a70ebe23a746c6c3a39129858 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 18 Dec 2023 22:13:22 +0100
Subject: [PATCH 3/3] Apply `clang-format`

---
 mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 87e22703982e93..d6509844215337 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -595,8 +595,7 @@ struct PrepareTransferReadConversion
     auto *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
     if (xferOp.getMask()) {
-      cast<TransferReadOp>(newXfer).getMaskMutable().assign(
-          buffers.maskBuffer);
+      cast<TransferReadOp>(newXfer).getMaskMutable().assign(buffers.maskBuffer);
     }
 
     Location loc = xferOp.getLoc();
@@ -724,15 +723,12 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
       VectorType signlessTarget =
           vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
       VectorType target = vectorType.cloneWith({}, legalIntTy);
-      value = rewriter.create<vector::BitCastOp>(loc, signlessSource,
-                                                 value);
+      value = rewriter.create<vector::BitCastOp>(loc, signlessSource, value);
       if (value.getType() != signlessTarget) {
         if (width == 1 || intTy.isUnsigned())
-          value = rewriter.create<arith::ExtUIOp>(loc, signlessTarget,
-                                                  value);
+          value = rewriter.create<arith::ExtUIOp>(loc, signlessTarget, value);
         else
-          value = rewriter.create<arith::ExtSIOp>(loc, signlessTarget,
-                                                  value);
+          value = rewriter.create<arith::ExtSIOp>(loc, signlessTarget, value);
       }
       value = rewriter.create<vector::BitCastOp>(loc, target, value);
       vectorType = target;
@@ -749,7 +745,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
       // non-constant value (which can currently only be done via
       // vector.extractelement for 1D vectors).
       int flatLength = std::accumulate(shape.begin(), shape.end(), 1,
-                                        std::multiplies<int64_t>());
+                                       std::multiplies<int64_t>());
       VectorType flat =
           VectorType::get({flatLength}, vectorType.getElementType());
       value = rewriter.create<vector::ShapeCastOp>(loc, flat, value);



More information about the Mlir-commits mailing list