[Mlir-commits] [mlir] f4bef78 - Add narrow type emulation pattern for vector.transfer_read

Hanhan Wang llvmlistbot at llvm.org
Tue Aug 29 13:15:47 PDT 2023


Author: yzhang93
Date: 2023-08-29T13:15:19-07:00
New Revision: f4bef787bcb2536c674ff0967f22ac2bd74bd571

URL: https://github.com/llvm/llvm-project/commit/f4bef787bcb2536c674ff0967f22ac2bd74bd571
DIFF: https://github.com/llvm/llvm-project/commit/f4bef787bcb2536c674ff0967f22ac2bd74bd571.diff

LOG: Add narrow type emulation pattern for vector.transfer_read

Reviewed By: mravishankar, hanchung

Differential Revision: https://reviews.llvm.org/D158757

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
    mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 01ee43354a711e..b2b7bfc5e4437c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -36,9 +36,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     auto loc = op.getLoc();
-    auto sourceType = cast<MemRefType>(adaptor.getBase().getType());
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
-    Type newElementType = sourceType.getElementType();
+    Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
 
@@ -81,16 +81,73 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto srcElementType = sourceType.getElementType();
-    auto numElements =
-        static_cast<int>(std::ceil(static_cast<double>(origElements) / scale));
+    auto numElements = (origElements + scale - 1) / scale;
     auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, srcElementType), adaptor.getBase(),
+        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
-    numElements *= scale;
-    auto castType = VectorType::get(numElements, oldElementType);
-    auto bitCast = rewriter.create<vector::BitCastOp>(loc, castType, newLoad);
+    auto bitCast =
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+
+    rewriter.replaceOp(op, bitCast->getResult(0));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertVectorTransferRead
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorTransferRead final
+    : OpConversionPattern<vector::TransferReadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
+    Type oldElementType = op.getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    auto origElements = op.getVectorType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
+                                                      adaptor.getPadding());
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newReadType = VectorType::get(numElements, newElementType);
+
+    auto newRead = rewriter.create<vector::TransferReadOp>(
+        loc, newReadType, adaptor.getSource(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
+        newPadding);
+
+    auto bitCast =
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
 
     rewriter.replaceOp(op, bitCast->getResult(0));
     return success();
@@ -107,5 +164,6 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad>(typeConverter, patterns.getContext());
+  patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
+      typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index e3c6b098b70ba4..6fcea33ddc952f 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -79,3 +79,32 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
 //      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK32:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
 //      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+
+// -----
+
+func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
+    %c0 = arith.constant 0 : i4
+    %0 = memref.alloc() : memref<3x8xi4>
+    %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+      memref<3x8xi4>, vector<8xi4>
+    return %1 : vector<8xi4>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_transfer_read_i4
+// CHECK-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK:   %[[CONST:.+]] = arith.constant 0 : i4
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+//      CHECK:   %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i8
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+//      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_transfer_read_i4
+// CHECK32-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK32:   %[[CONST:.+]] = arith.constant 0 : i4
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i32
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+//      CHECK32:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>


        


More information about the Mlir-commits mailing list