[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #66648)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 18 07:31:49 PDT 2023


================
@@ -359,93 +487,93 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     if (!truncOp)
       return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
 
+    // Set up the BitCastRewriter and verify the precondition.
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
-    if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
-      return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
-    // TODO: consider relaxing this restriction in the future if we find ways
-    // to really work with subbyte elements across the MLIR/LLVM boundary.
-    int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
-    if (resultBitwidth % 8 != 0)
-      return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+    BitCastRewriter bcr(sourceVectorType, targetVectorType);
+    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+      return failure();
 
-    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
-    BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
-    LDBG("\n" << be.sourceElementRanges);
-
-    Value initialValue = truncOp.getIn();
-    auto initalVectorType = initialValue.getType().cast<VectorType>();
-    auto initalElementType = initalVectorType.getElementType();
-    auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth();
-
-    Value res;
-    for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
-         ++shuffleIdx) {
-      SmallVector<int64_t> shuffles;
-      SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
-
-      // Create the attribute quantities for the shuffle / mask / shift ops.
-      for (auto &srcEltRangeList : be.sourceElementRanges) {
-        bool idxContributesBits =
-            (shuffleIdx < (int64_t)srcEltRangeList.size());
-        int64_t sourceElementIdx =
-            idxContributesBits ? srcEltRangeList[shuffleIdx].sourceElementIdx
-                               : 0;
-        shuffles.push_back(sourceElementIdx);
-
-        int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitBegin
-                            : 0;
-        int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitEnd
-                            : 0;
-        IntegerAttr mask = IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth),
-            llvm::APInt::getBitsSet(initalElementBitWidth, bitLo, bitHi));
-        masks.push_back(mask);
-
-        int64_t shiftRight = bitLo;
-        shiftRightAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftRight));
-
-        int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
-        shiftLeftAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftLeft));
-      }
-
-      // Create vector.shuffle #shuffleIdx.
-      auto shuffleOp = rewriter.create<vector::ShuffleOp>(
-          bitCastOp.getLoc(), initialValue, initialValue, shuffles);
-      // And with the mask.
-      VectorType vt = VectorType::Builder(initalVectorType)
-                          .setDim(initalVectorType.getRank() - 1, masks.size());
-      auto constOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks));
-      Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(),
-                                                      shuffleOp, constOp);
-      // Align right on 0.
-      auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts));
-      Value shiftedRight = rewriter.create<arith::ShRUIOp>(
-          bitCastOp.getLoc(), andValue, shiftRightConstantOp);
-
-      auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftLeftAmounts));
-      Value shiftedLeft = rewriter.create<arith::ShLIOp>(
-          bitCastOp.getLoc(), shiftedRight, shiftLeftConstantOp);
-
-      res = res ? rewriter.create<arith::OrIOp>(bitCastOp.getLoc(), res,
-                                                shiftedLeft)
-                : shiftedLeft;
+    // Perform the rewrite.
+    Value truncValue = truncOp.getIn();
+    auto shuffledElementType =
+        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
+    Value runningResult;
+    for (const BitCastRewriter ::Metadata &metadata :
+         bcr.precomputeMetadata(shuffledElementType)) {
+      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+                                      runningResult, metadata);
     }
 
-    bool narrowing = resultBitwidth <= initalElementBitWidth;
+    // Finalize the rewrite.
+    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
+                     shuffledElementType.getIntOrFloatBitWidth();
     if (narrowing) {
       rewriter.replaceOpWithNewOp<arith::TruncIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), res);
+          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
     } else {
       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), res);
+          bitCastOp, bitCastOp.getResultVectorType(), runningResult);
     }
+
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
----------------
qcolombet wrote:

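Nit: this doc comment should say "ext(bitcast)" rather than "bitcast(trunci)", since it is in the RewriteExtOfBitCast section.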

https://github.com/llvm/llvm-project/pull/66648


More information about the Mlir-commits mailing list