[Mlir-commits] [mlir] 650f04f - [mlir][vector] Add pattern to break down vector.bitcast

Quinn Dawkins llvmlistbot at llvm.org
Tue Apr 25 17:22:16 PDT 2023


Author: Quinn Dawkins
Date: 2023-04-25T20:18:02-04:00
New Revision: 650f04feda9039e170de513dab261c672fa847cd

URL: https://github.com/llvm/llvm-project/commit/650f04feda9039e170de513dab261c672fa847cd
DIFF: https://github.com/llvm/llvm-project/commit/650f04feda9039e170de513dab261c672fa847cd.diff

LOG: [mlir][vector] Add pattern to break down vector.bitcast

The pattern added here is intended as a last resort for targets like
SPIR-V where there are vector size restrictions and we need to be able
to break down large vector types. Vectorizing loads/stores for small
bitwidths (e.g. i8) relies on bitcasting to a larger element type and
patterns to bubble bitcast ops to where they can cancel.
This fails for cases such as
```
%1 = arith.trunci %0 : vector<2x32xi32> to vector<2x32xi8>
vector.transfer_write %1, %destination[%c0, %c0] {in_bounds = [true, true]} : vector<2x32xi8>, memref<2x32xi8>
```
where the `arith.trunci` op essentially does the job of one of the
bitcasts, leading to a bitcast that needs to be further broken down
```
vector.bitcast %0 : vector<16xi8> to vector<4xi32>
```

Differential Revision: https://reviews.llvm.org/D149065

Added: 
    mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a79bbd0be097..325860079b3d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -181,6 +181,22 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
     PatternBenefit benefit = 1);
 
+/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
+/// based on the destination vector shape. Bitcasts from a lower bitwidth
+/// element type to a higher bitwidth one are extracted from the lower bitwidth
+/// based on the native destination vector shape and inserted based on the ratio
+/// of the bitwidths.
+///
+/// This acts as a last resort way to break down vector.bitcast ops to smaller
+/// vector sizes. Because this pattern composes until it is bitcasting to a
+/// single element of the higher bitwidth, there is an optional control
+/// function. If `controlFn` is not nullptr, the pattern will only apply to ops
+/// where `controlFn` returns true; otherwise it applies to all bitcast ops.
+void populateBreakDownVectorBitCastOpPatterns(
+    RewritePatternSet &patterns,
+    std::function<bool(BitCastOp)> controlFn = nullptr,
+    PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d412b157e284..44f3a10c4da5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -800,6 +800,91 @@ struct BubbleUpBitCastForStridedSliceInsert
   }
 };
 
+// Breaks down vector.bitcast op
+//
+// This transforms IR like:
+//   %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+// Into:
+//   %cst = vector.splat %c0_f32 : vector<4xf32>
+//   %1 = vector.extract_strided_slice %0 {
+//          offsets = [0], sizes = [4], strides = [1]
+//        } : vector<8xf16> to vector<4xf16>
+//   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
+//   %4 = vector.insert_strided_slice %2, %cst {
+//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//   %5 = vector.extract_strided_slice %0 {
+//          offsets = [4], sizes = [4], strides = [1]
+//        } : vector<8xf16> to vector<4xf16>
+//   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
+//   %7 = vector.insert_strided_slice %6, %cst {
+//          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  BreakDownVectorBitCast(MLIRContext *context,
+                         std::function<bool(vector::BitCastOp)> controlFn,
+                         PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (controlFn && !controlFn(bitcastOp))
+      return failure();
+
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    // Only support rank 1 case for now.
+    if (castSrcType.getRank() != 1)
+      return failure();
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to fewer elements for now; other cases to be implemented.
+    if (castSrcLastDim < castDstLastDim)
+      return failure();
+
+    assert(castSrcLastDim % castDstLastDim == 0);
+    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
+    // Nothing to do if it is already bitcasting to a single element.
+    if (castSrcLastDim == shrinkRatio)
+      return failure();
+
+    Location loc = bitcastOp.getLoc();
+    Type elemType = castDstType.getElementType();
+    assert(elemType.isSignlessIntOrIndexOrFloat());
+
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, elemType, rewriter.getZeroAttr(elemType));
+    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
+
+    SmallVector<int64_t> sliceShape{castDstLastDim};
+    SmallVector<int64_t> strides{1};
+    VectorType newCastDstType =
+        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
+                        castDstType.getElementType());
+
+    for (int i = 0, e = shrinkRatio; i < e; ++i) {
+      Value extracted = rewriter.create<ExtractStridedSliceOp>(
+          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
+          sliceShape, strides);
+      Value bitcast =
+          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
+      res = rewriter.create<InsertStridedSliceOp>(
+          loc, bitcast, res,
+          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
+    }
+    rewriter.replaceOp(bitcastOp, res);
+    return success();
+  }
+
+private:
+  std::function<bool(BitCastOp)> controlFn;
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -1151,6 +1236,13 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
                                                      benefit);
 }
 
+void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
+    RewritePatternSet &patterns,
+    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
+  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
+                                       std::move(controlFn), benefit);
+}
+
 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
     RewritePatternSet &patterns,
     std::function<LogicalResult(vector::ContractionOp)> constraint,

diff  --git a/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir b/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir
new file mode 100644
index 000000000000..fbb2f7605e64
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt -split-input-file -test-vector-break-down-bitcast %s | FileCheck %s
+
+// CHECK-LABEL: func.func @bitcast_f16_to_f32
+//  CHECK-SAME: (%[[INPUT:.+]]: vector<8xf16>)
+func.func @bitcast_f16_to_f32(%input: vector<8xf16>) -> vector<4xf32> {
+  %0 = vector.bitcast %input : vector<8xf16> to vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[EXTRACT0:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[CAST0:.+]] = vector.bitcast %[[EXTRACT0]] : vector<4xf16> to vector<2xf32>
+// CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[CAST0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<4xf16> to vector<2xf32>
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST1]], %[[INSERT0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[INSERT1]]
+
+// -----
+
+// CHECK-LABEL: func.func @bitcast_i8_to_i32
+//  CHECK-SAME: (%[[INPUT:.+]]: vector<16xi8>)
+func.func @bitcast_i8_to_i32(%input: vector<16xi8>) -> vector<4xi32> {
+  %0 = vector.bitcast %input : vector<16xi8> to vector<4xi32>
+  return %0: vector<4xi32>
+}
+
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: %[[EXTRACT0:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi8> to vector<4xi8>
+// CHECK: %[[CAST0:.+]] = vector.bitcast %[[EXTRACT0]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[CAST0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<4xi32>
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi8> to vector<4xi8>
+// CHECK: %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST1]], %[[INSERT0]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi8> to vector<4xi8>
+// CHECK: %[[CAST2:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi8> to vector<4xi8>
+// CHECK: %[[CAST3:.+]] = vector.bitcast %[[EXTRACT3]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[INSERT3:.+]] = vector.insert_strided_slice %[[CAST3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
+// CHECK: return %[[INSERT3]]

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 55825b32e442..dd853aa1dc3c 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -604,6 +604,26 @@ struct TestVectorExtractStridedSliceLowering
   }
 };
 
+struct TestVectorBreakDownBitCast
+    : public PassWrapper<TestVectorBreakDownBitCast,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)
+
+  StringRef getArgument() const final {
+    return "test-vector-break-down-bitcast";
+  }
+  StringRef getDescription() const final {
+    return "Test pattern that breaks down vector.bitcast ops ";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
+      return op.getSourceVectorType().getShape().back() > 4;
+    });
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestCreateVectorBroadcast
     : public PassWrapper<TestCreateVectorBroadcast,
                          OperationPass<func::FuncOp>> {
@@ -688,6 +708,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorExtractStridedSliceLowering>();
 
+  PassRegistration<TestVectorBreakDownBitCast>();
+
   PassRegistration<TestCreateVectorBroadcast>();
 
   PassRegistration<TestVectorGatherLowering>();


        


More information about the Mlir-commits mailing list