[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #65774)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 12 08:28:17 PDT 2023


llvmbot wrote:

@llvm/pr-subscribers-mlir-vector

Changes

…cast) expansion

This revision adds a rewrite for sequences of vector ext(maybe_broadcast(bitcast)) to use a more efficient sequence of vector operations comprising shuffles, shifts and bitwise logical ops. The rewrite uses an intermediate bitwidth equal to the lcm of the bitwidths of the source and result element types of `bitCastOp`. This intermediate type may be smaller or greater than the desired element type of the extOp, in which case appropriate ext or trunc operations are inserted. The rewrite fails if the intermediate bitwidth is greater than 64 or if the involved vector types fail to meet basic divisibility requirements. In other words, this rewrite does not handle partial vector boundaries and leaves that heavy lifting to LLVM.
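To make the shape of the rewrite concrete, consider a small unsigned example (a hand-written sketch, not verbatim pattern output; value names and constant ordering are illustrative). For a vector<2xi8> source bitcast to i4 and extended to i8, the heuristic picks an interim bitwidth of 16: the lcm of the 8-bit source and 4-bit bitcast element types, scaled up as far as it still divides the 16 bits of the innermost source dimension while staying within 64. The input

  %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
  %1 = arith.extui %0 : vector<4xi4> to vector<4xi8>

is rewritten along the lines of

  // Nibble masks 0x000F, 0x00F0, 0x0F00, 0xF000 and shift amounts 0, 4, 8, 12.
  %masks = arith.constant dense<[15, 240, 3840, 61440]> : vector<4xi16>
  %shifts = arith.constant dense<[0, 4, 8, 12]> : vector<4xi16>
  // Bitcast to the interim type, then replicate each interim element
  // interimBitWidth / targetBitWidth = 4 times.
  %2 = vector.bitcast %a : vector<2xi8> to vector<1xi16>
  %3 = vector.shuffle %2, %2 [0, 0, 0, 0] : vector<1xi16>, vector<1xi16>
  // Isolate each nibble, move it to the low bits, then adjust to the result
  // element type (trunc here; an ext is emitted when the result element is
  // wider than the interim type).
  %4 = arith.andi %3, %masks : vector<4xi16>
  %5 = arith.shrui %4, %shifts : vector<4xi16>
  %6 = arith.trunci %5 : vector<4xi16> to vector<4xi8>

When a vector.broadcast sits between the bitcast and the ext, the same sequence is emitted on the pre-broadcast shape and a broadcast to the ext's result type is re-created at the end.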
--

Patch is 29.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65774.diff

7 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+13) 
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+21-1) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+10) 
- (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+276-4) 
- (modified) mlir/test/Dialect/LLVM/transform-e2e.mlir (-21) 
- (added) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+205) 


diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 2b8c95a94257e6c..d1cef91f8e27525 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -281,6 +281,19 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.rewrite_narrow_types",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector narrow type rewrite patterns should be applied.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplySplitTransferFullPartialPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.split_transfer_full_partial",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index c644090d8c78cd0..20c33921f9de24e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -23,6 +23,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arith {
+class AndIOp;
 class NarrowTypeEmulationConverter;
 } // namespace arith
 
@@ -143,7 +144,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 
 /// Patterns that remove redundant vector broadcasts.
 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+                                         PatternBenefit benefit = 1);
 
 /// Populate `patterns` with the following patterns.
 ///
@@ -301,6 +302,25 @@ void populateVectorNarrowTypeEmulationPatterns(
     arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns);
 
+/// Rewrite vector ext(maybe_broadcast(bitcast)) to use a more efficient
+/// sequence of vector operations comprising shuffles, shifts and bitwise
+/// logical ops. The rewrite uses an intermediate bitwidth equal to the lcm of
+/// the bitwidths of the source and result element types of `bitCastOp`. This
+/// intermediate type may be smaller or greater than the desired element type
+/// of the extOp, in which case appropriate ext or trunc operations are
+/// inserted. The rewrite fails if the intermediate bitwidth is greater than
+/// 64 or if the involved vector types fail to meet basic divisibility
+/// requirements. In other words, this rewrite does not handle partial vector
+/// boundaries and leaves that heavy lifting to LLVM.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                     vector::BitCastOp bitCastOp,
+                                     vector::BroadcastOp maybeBroadcastOp);
+
+/// Appends patterns for rewriting vector operations over narrow types with
+/// ops over wider types.
+void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f031eb0a5c30ce9..9df5548cd5d939c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -357,6 +357,16 @@ class VectorType::Builder {
     return *this;
   }
 
+  /// Set the dim of the shape at position `pos` to `val`.
+  Builder &setDim(unsigned pos, int64_t val) {
+    if (storage.empty())
+      storage.append(shape.begin(), shape.end());
+    assert(pos < storage.size() && "overflow");
+    storage[pos] = val;
+    shape = {storage.data(), storage.size()};
+    return *this;
+  }
+
   operator VectorType() {
     return VectorType::get(shape, elementType, scalableDims);
   }
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 94f19e59669eafd..0fdeded436a9773 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -154,6 +154,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
   }
 }
 
+void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  populateVectorNarrowTypeRewritePatterns(patterns);
+}
+
 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index b2b7bfc5e4437c1..d10994f3709e390 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,7 +7,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -15,13 +14,29 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
-#include <cassert>
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
+#include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "vector-narrow-type-emulation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -155,6 +170,256 @@ struct ConvertVectorTransferRead final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+/// Create a vector of bit masks: `idx .. idx + step - 1` and broadcast it
+/// `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x3, 0xc, 0x3, 0xc, 0x3, 0xc].
+static SmallVector<Attribute> computeExtOfBitCastMasks(MLIRContext *ctx,
+                                                       int64_t bitwidth,
+                                                       int64_t step,
+                                                       int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpMasks;
+  tmpMasks.reserve(bitwidth / step);
+  // Create a vector of bit masks: `idx .. idx + step - 1`.
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    LDBG("Mask bits " << idx << " .. " << idx + step - 1 << " out of "
+                      << bitwidth);
+    IntegerAttr mask = IntegerAttr::get(
+        interimIntType, llvm::APInt::getBitsSet(bitwidth, idx, idx + step));
+    tmpMasks.push_back(mask);
+  }
+  // Replicate the vector of bit masks to the desired size.
+  SmallVector<Attribute> masks;
+  masks.reserve(numOccurrences * tmpMasks.size());
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(masks, tmpMasks);
+  return masks;
+}
+
+/// Create a vector of bit shifts `0, step, 2 * step, ...` and broadcast it
+/// `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x2, 0x0, 0x2, 0x0, 0x2].
+static SmallVector<Attribute>
+computeExtOfBitCastShifts(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                          int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpShifts;
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    IntegerAttr shift = IntegerAttr::get(interimIntType, idx);
+    tmpShifts.push_back(shift);
+  }
+  SmallVector<Attribute> shifts;
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shifts, tmpShifts);
+  return shifts;
+}
+
+/// Create a vector of shuffle indices in which each index `idx` in
+/// `[0, numOccurrences)` is repeated `bitwidth / step` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x0, 0x1, 0x1, 0x2, 0x2].
+static SmallVector<int64_t>
+computeExtOfBitCastShuffles(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                            int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  SmallVector<int64_t> shuffles;
+  int64_t n = floorDiv(bitwidth, step);
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shuffles, SmallVector<int64_t>(n, idx));
+  return shuffles;
+}
+
+/// Compute the bitwidth of the intermediate vector type. Its element type must
+/// be an integer with a bitwidth that:
+///   1. does not exceed 64 (TODO: in the future we may want target-specific
+///   control).
+///   2. divides sourceBitWidth * mostMinorSourceDim
+static int64_t computeExtOfBitCastBitWidth(int64_t sourceBitWidth,
+                                           int64_t mostMinorSourceDim,
+                                           int64_t targetBitWidth) {
+  for (int64_t mult : {32, 16, 8, 4, 2, 1}) {
+    int64_t interimBitWidth =
+        std::lcm(mult, std::lcm(sourceBitWidth, targetBitWidth));
+    if (interimBitWidth > 64)
+      continue;
+    if ((sourceBitWidth * mostMinorSourceDim) % interimBitWidth != 0)
+      continue;
+    return interimBitWidth;
+  }
+  return 0;
+}
+
+FailureOr<Value>
+mlir::vector::rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                  vector::BitCastOp bitCastOp,
+                                  vector::BroadcastOp maybeBroadcastOp) {
+  assert(
+      (llvm::isa<arith::ExtSIOp>(extOp) || llvm::isa<arith::ExtUIOp>(extOp)) &&
+      "unsupported op");
+
+  // The bitcast op is the load-bearing part, capture the source and bitCast
+  // types as well as bitwidth and most minor dimension.
+  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+  LDBG("sourceVectorType: " << sourceVectorType);
+
+  VectorType bitCastVectorType = bitCastOp.getResultVectorType();
+  int64_t targetBitWidth = bitCastVectorType.getElementTypeBitWidth();
+  LDBG("bitCastVectorType: " << bitCastVectorType);
+
+  int64_t interimBitWidth = computeExtOfBitCastBitWidth(
+      sourceBitWidth, mostMinorSourceDim, targetBitWidth);
+  LDBG("interimBitWidth: " << interimBitWidth);
+  if (!interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "heuristic could not find a reasonable interim bitwidth");
+  }
+  if (sourceBitWidth == interimBitWidth || targetBitWidth == interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "interim bitwidth is equal to source or target, nothing to do");
+  }
+
+  int64_t interimMostMinorDim =
+      sourceBitWidth * mostMinorSourceDim / interimBitWidth;
+  LDBG("interimMostMinorDim: " << interimMostMinorDim);
+
+  Location loc = extOp->getLoc();
+  MLIRContext *ctx = extOp->getContext();
+
+  VectorType interimVectorType =
+      VectorType::Builder(sourceVectorType)
+          .setDim(sourceVectorType.getRank() - 1, interimMostMinorDim)
+          .setElementType(IntegerType::get(ctx, interimBitWidth));
+  LDBG("interimVectorType: " << interimVectorType);
+
+  IntegerType interimIntType = IntegerType::get(ctx, interimBitWidth);
+  VectorType vt =
+      VectorType::Builder(bitCastVectorType).setElementType(interimIntType);
+
+  // Rewrite the original bitcast to the interim vector type and shuffle to
+  // broadcast to the desired size.
+  auto newBitCastOp = rewriter.create<vector::BitCastOp>(loc, interimVectorType,
+                                                         bitCastOp.getSource());
+  SmallVector<int64_t> shuffles = computeExtOfBitCastShuffles(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(loc, newBitCastOp,
+                                                      newBitCastOp, shuffles);
+  LDBG("shuffle: " << shuffleOp);
+
+  // Compute the constants for masking.
+  SmallVector<Attribute> masks = computeExtOfBitCastMasks(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(vt, masks));
+  LDBG("maskConstant: " << maskConstantOp);
+  auto andOp = rewriter.create<arith::AndIOp>(loc, shuffleOp, maskConstantOp);
+  LDBG("andOp: " << andOp);
+
+  // Preserve the intermediate type: this may have serious consequences on the
+  // backend's ability to generate efficient vector operations.
+  // For instance on x86, converting to f16 without going through i32 has severe
+  // performance implications.
+  // As a consequence, this pattern must preserve the original behavior.
+  VectorType resultType = cast<VectorType>(extOp->getResultTypes().front());
+  Type resultElementType = getElementTypeOrSelf(resultType);
+  SmallVector<Attribute> shifts = computeExtOfBitCastShifts(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto shiftConstantOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(vt, shifts));
+  LDBG("shiftConstant: " << shiftConstantOp);
+  Value newResult =
+      TypeSwitch<Operation *, Value>(extOp)
+          .template Case<arith::ExtSIOp>([&](arith::ExtSIOp op) {
+            Value shifted =
+                rewriter.create<arith::ShRSIOp>(loc, andOp, shiftConstantOp);
+            auto vt = shifted.getType().cast<VectorType>();
+            VectorType extVt =
+                VectorType::Builder(vt).setElementType(resultElementType);
+            Operation *res =
+                (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+                    ? rewriter.create<arith::ExtSIOp>(loc, extVt, shifted)
+                    : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+            return res->getResult(0);
+          })
+          .template Case<arith::ExtUIOp>([&](arith::ExtUIOp op) {
+            Value shifted =
+                rewriter.create<arith::ShRUIOp>(loc, andOp, shiftConstantOp);
+            auto vt = shifted.getType().cast<VectorType>();
+            VectorType extVt =
+                VectorType::Builder(vt).setElementType(resultElementType);
+            Operation *res =
+                (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+                    ? rewriter.create<arith::ExtUIOp>(loc, extVt, shifted)
+                    : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+            return res->getResult(0);
+          })
+          .Default([&](Operation *op) {
+            llvm_unreachable("unexpected op type");
+            return nullptr;
+          });
+
+  if (maybeBroadcastOp) {
+    newResult =
+        rewriter.create<vector::BroadcastOp>(loc, resultType, newResult);
+  }
+
+  return newResult;
+}
+
+namespace {
+template <typename ExtOpType>
+struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
+  using OpRewritePattern<ExtOpType>::OpRewritePattern;
+
+  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<ExtOpType>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(ExtOpType extOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultTy = dyn_cast<VectorType>(extOp.getType());
+    if (!resultTy)
+      return rewriter.notifyMatchFailure(extOp, "not a vector type");
+
+    int64_t elementalBitWidth = resultTy.getElementTypeBitWidth();
+    if (elementalBitWidth & (elementalBitWidth - 1)) {
+      return rewriter.notifyMatchFailure(
+          extOp, "result bitwidth must be a power of 2");
+    }
+
+    // Provision for a potential broadcast op that will be rewritten later.
+    auto maybeBroadcastOp =
+        extOp.getIn().template getDefiningOp<vector::BroadcastOp>();
+
+    // The source must be a bitcast op.
+    auto bitCastOp =
+        maybeBroadcastOp
+            ? maybeBroadcastOp.getSource()
+                  .template getDefiningOp<vector::BitCastOp>()
+            : extOp.getIn().template getDefiningOp<vector::BitCastOp>();
+    if (!bitCastOp)
+      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
+
+    // Try to rewrite.
+    FailureOr<Value> result =
+        rewriteExtOfBitCast(rewriter, extOp, bitCastOp, maybeBroadcastOp);
+    if (failed(result))
+      return failure();
+
+    rewriter.replaceOp(extOp, *result);
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
@@ -167,3 +432,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
   patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
 }
+
+void vector::populateVectorNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RewriteExtOfBitCast<arith::ExtSIOp>,
+               RewriteExtOfBitCast<arith::ExtUIOp>>(patterns.getContext(),
+                                                    benefit);
+}
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 777de75b1a47acc..2cb753a3d7fb8f3 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -29,33 +29,12 @@ transform.sequence failures(propagate) {
   // lowering TD macros.
   transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.transfer_permutation_patterns
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_shape_cast
-  } : !transform.any_op
-
-  transform.apply_pattern...
<truncated>
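For reference, the new patterns can be requested from a transform script in the same consolidated apply_patterns block as the other vector lowerings above (a minimal sketch; %f is assumed to be a handle obtained earlier in the enclosing transform.sequence):

  transform.apply_patterns to %f {
    transform.apply_patterns.vector.rewrite_narrow_types
  } : !transform.any_op

C++ users can opt in by calling populateVectorNarrowTypeRewritePatterns on a RewritePatternSet and running it through a greedy pattern driver.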


https://github.com/llvm/llvm-project/pull/65774

