[Mlir-commits] [mlir] 9a7677d - [mlir] Narrow bitwidth emulation for vector.load

Hanhan Wang llvmlistbot at llvm.org
Tue Jul 11 13:38:21 PDT 2023


Author: yzhang93
Date: 2023-07-11T13:38:15-07:00
New Revision: 9a7677d8ee34e71a27834ab2886d79dfb9e71dae

URL: https://github.com/llvm/llvm-project/commit/9a7677d8ee34e71a27834ab2886d79dfb9e71dae
DIFF: https://github.com/llvm/llvm-project/commit/9a7677d8ee34e71a27834ab2886d79dfb9e71dae.diff

LOG: [mlir] Narrow bitwidth emulation for vector.load

This patch is a follow-up to the previous patch https://reviews.llvm.org/D151519.
With this patch, vector.load op with narrow bitwidth (e.g., i4) can be converted to
supported wider bitwidth (e.g., i8).

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D154178

Added: 
    mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
    mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 12254bc215db57..c644090d8c78cd 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -22,6 +22,10 @@
 namespace mlir {
 class RewritePatternSet;
 
+namespace arith {
+class NarrowTypeEmulationConverter;
+} // namespace arith
+
 namespace vector {
 struct VectorTransformsOptions;
 
@@ -291,6 +295,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                bool force32BitVectorIndices,
                                                PatternBenefit benefit = 1);
 
+/// Appends patterns for emulating vector operations over narrow types with ops
+/// over wider types.
+///
+/// `typeConverter` supplies the wider type used in place of each narrow type.
+/// Currently only `vector.load` is handled: a 1-D load of narrow elements is
+/// rewritten as a load of wider elements followed by a `vector.bitcast` back
+/// to the original vector type.
+void populateVectorNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index deba91573e0ff1..ef6eebd987c6db 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorTranspose.cpp
   VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
+  VectorEmulateNarrowType.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
   VectorTransferOpTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
new file mode 100644
index 00000000000000..597a2c8d5d17ca
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -0,0 +1,200 @@
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Emits a 1-D `vector.load` of `numElements` wide elements that covers a
+/// narrow-type vector load at `indices`.
+///
+/// The memref described by `stridedMetadata` is flattened to a 1-D memref
+/// with `memref.reinterpret_cast`. The N-D `indices` are linearized with the
+/// memref strides. Both the linearized index and the base offset are then
+/// divided (floordiv) by the scale factor `dstBits / srcBits` to turn
+/// narrow-element positions into wide-element positions. NOTE: if the
+/// linearized index is not a multiple of the scale factor, the floor division
+/// makes the load start at the enclosing wide element.
+///
+/// For example, when emulating i4 with i8 (scale factor 8 / 4 = 2):
+///
+///   %0 = vector.load %arg0[%v0, %v1] :
+///          memref<?x?xi4, strided<[?, ?], offset: ?>>, vector<4xi4>
+///
+/// is emitted (on the converted i8 memref) as
+///
+///   %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
+///
+///   %linearized_offset = (%v0 * %strides#0 + %v1 * %strides#1) floordiv 2
+///   %linearized_size = %sizes#0 * %sizes#1
+///   %scaled_base_offset = %offset floordiv 2
+///
+///   %linearized = memref.reinterpret_cast %b to offset: [%scaled_base_offset],
+///                   sizes: [%linearized_size], strides: [%strides#1]
+///
+///   %new_load = vector.load %linearized[%linearized_offset] :
+///                 memref<?xi8, strided<[1], offset: ?>>, vector<2xi8>
+///
+/// \param sourceType  memref type of the already-converted base; its element
+///                    type is the wide type used for the new load.
+/// \param srcBits     bitwidth of the narrow element type being emulated.
+/// \param dstBits     bitwidth of the wide element type.
+/// \param indices     one index per memref dimension.
+/// \param stridedMetadata  strided metadata of the converted base memref.
+/// \param numElements number of wide elements to load.
+/// \returns the result of the newly created 1-D `vector.load`.
+static Value
+linearizeVectorLoad(Location loc, MemRefType sourceType, int srcBits,
+                    int dstBits, ValueRange indices,
+                    memref::ExtractStridedMetadataOp stridedMetadata,
+                    int numElements, OpBuilder &builder) {
+  // `sourceType` is the converted memref type, so this is the wide element
+  // type that the flattened memref and the new load are built with.
+  Type wideElementType = sourceType.getElementType();
+  unsigned sourceRank = indices.size();
+
+  Value baseBuffer = stridedMetadata.getBaseBuffer();
+  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
+  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
+  Value baseOffset = stridedMetadata.getOffset();
+  assert(indices.size() == baseStrides.size() &&
+         "expected one index per memref dimension");
+  // The innermost stride is used for the flattened memref below.
+  assert(!baseStrides.empty() && "0-D memrefs are not supported");
+
+  // Symbol layout: s0 seeds both expressions; for dimension i the offset
+  // expression uses (s[2i+1], s[2i+2]) = (index_i, stride_i) and the size
+  // expression uses s[i+1] = size_i; the last symbol is the scale factor.
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+  AffineExpr addMulMap = symbols.front();
+  AffineExpr mulMap = symbols.front();
+
+  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
+  offsetValues[0] = builder.getIndexAttr(0);
+  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
+  sizeValues[0] = builder.getIndexAttr(1);
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    unsigned offsetIdx = 2 * i + 1;
+    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
+    offsetValues[offsetIdx] = indices[i];
+    offsetValues[offsetIdx + 1] = baseStrides[i];
+
+    unsigned sizeIdx = i + 1;
+    mulMap = mulMap * symbols[sizeIdx];
+    sizeValues[sizeIdx] = baseSizes[i];
+  }
+
+  // Convert the linearized index from narrow to wide elements by dividing it
+  // by the scale factor (dstBits / srcBits).
+  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
+  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
+  offsetValues.back() = scaler;
+
+  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, scaledAddMulMap, offsetValues);
+  OpFoldResult linearizedSize =
+      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+
+  // Convert the base offset by the same scale factor.
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
+
+  // Flatten the N-D memref to a 1-D memref with unit stride and a dynamic
+  // offset. The element count stays static when the source shape is static.
+  auto layoutAttr = StridedLayoutAttr::get(sourceType.getContext(),
+                                           ShapedType::kDynamic, {1});
+  int64_t staticShape = sourceType.hasStaticShape()
+                            ? sourceType.getNumElements()
+                            : ShapedType::kDynamic;
+  auto flattenMemrefType = MemRefType::get(
+      staticShape, wideElementType, layoutAttr, sourceType.getMemorySpace());
+
+  // NOTE(review): the result type declares a unit stride, but the innermost
+  // stride value is passed as the runtime stride. The two only agree when the
+  // innermost dimension is contiguous.
+  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
+      loc, flattenMemrefType, baseBuffer,
+      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
+      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
+      baseStrides.back());
+
+  return builder.create<vector::LoadOp>(
+      loc, VectorType::get(numElements, wideElementType),
+      reinterpret.getResult(),
+      getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertVectorLoad
+//===----------------------------------------------------------------------===//
+
+/// Converts a 1-D `vector.load` of narrow elements (e.g. i4) into a load of
+/// the wider emulation type (e.g. i8), followed by a `vector.bitcast` back to
+/// the original vector type.
+///
+/// For example, to emulate i4 with i8:
+///
+///   %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
+///
+/// becomes
+///
+///   %1 = vector.load %0[%linear_index] : memref<12xi8>, vector<2xi8>
+///   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
+///
+/// N-D memrefs are flattened by `linearizeVectorLoad`.
+///
+/// TODO: Only loads whose element count is a multiple of the scale factor
+/// (dstBits / srcBits) are supported. Other element counts need the subvector
+/// to be extracted at the proper offset after bit-casting.
+struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto sourceType = cast<MemRefType>(adaptor.getBase().getType());
+    VectorType origVectorType = op.getVectorType();
+    Type oldElementType = origVectorType.getElementType();
+    Type newElementType = sourceType.getElementType();
+
+    // `getIntOrFloatBitWidth` asserts on any other type (e.g. `index`).
+    if (!oldElementType.isIntOrFloat() || !newElementType.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "only int or float element types supported");
+    // The replacement is a fixed-size 1-D vector, so it can only stand in for
+    // a fixed-size 1-D load result.
+    if (origVectorType.getRank() != 1 || origVectorType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "only fixed-size 1-D vector loads supported");
+    // Linearization reads the innermost stride of the memref.
+    if (sourceType.getRank() == 0)
+      return rewriter.notifyMatchFailure(op, "0-D memrefs not supported");
+
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    int64_t origElements = origVectorType.getNumElements();
+    if (origElements % scale != 0)
+      return rewriter.notifyMatchFailure(
+          op, "number of loaded elements must be a multiple of the scale");
+
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, adaptor.getBase());
+
+    // Exact division: divisibility was checked above.
+    int numElements = static_cast<int>(origElements / scale);
+    Value newLoad = linearizeVectorLoad(loc, sourceType, srcBits, dstBits,
+                                        adaptor.getIndices(), stridedMetadata,
+                                        numElements, rewriter);
+
+    // numElements * scale == origElements, so the bitcast restores the
+    // original narrow vector type.
+    auto castType = VectorType::get(origElements, oldElementType);
+    auto bitCast = rewriter.create<vector::BitCastOp>(loc, castType, newLoad);
+
+    rewriter.replaceOp(op, bitCast->getResult(0));
+    return success();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void vector::populateVectorNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  // Register the `vector.*` conversion patterns; `vector.load` is the only
+  // op handled so far.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ConvertVectorLoad>(typeConverter, context);
+}

diff  --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
new file mode 100644
index 00000000000000..71c133778087d6
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversion: i8 already matches memref-load-bitwidth=8, so the
+// vector.load must be left untouched.
+// CHECK-LABEL: func @vector_load_i8
+// CHECK-SAME:  (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index)
+// CHECK-NEXT:  [[L:%.+]] = vector.load %[[ARG]][%[[IDX0]], %[[IDX1]]] : memref<3x4xi8>, vector<4xi8>
+// CHECK-NEXT:  return
+func.func @vector_load_i8(%arg0: memref<3x4xi8>, %arg1: index, %arg2: index) {
+    %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
+    return
+}
+
+// -----
+
+// i4 is emulated with i8. The 2-D index is linearized with the strides and
+// divided by 2 (= 8 / 4), the base is reinterpreted as a 1-D i8 memref, and
+// the vector<4xi4> is loaded as vector<2xi8> and bitcast back to vector<4xi4>.
+// CHECK-LABEL: func @vector_load_i4
+// CHECK-SAME:  (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index)
+// CHECK-NEXT:  %[[CST:.*]] = arith.constant dense<0> : vector<3x4xi4>
+// CHECK-NEXT:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x4xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT:  %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[IDX0]], %[[STRIDES]]#0, %[[IDX1]], %[[STRIDES]]#1]
+// CHECK-NEXT:  %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT:  %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<12xi8, strided<[1], offset: ?>>
+// CHECK-NEXT:  %[[LOAD:.*]] = vector.load %[[CAST]][%[[INDEX]]] : memref<12xi8, strided<[1], offset: ?>>, vector<2xi8>
+// CHECK-NEXT:  %[[BITCAST:.*]] = vector.bitcast %[[LOAD]] : vector<2xi8> to vector<4xi4>
+// CHECK-NEXT:  %[[INSERT:.*]] = vector.insert %[[BITCAST]], %[[CST]] [0] : vector<4xi4> into vector<3x4xi4>
+// CHECK-NEXT:  return
+func.func @vector_load_i4(%arg0: memref<3x4xi4>, %arg1: index, %arg2: index) {
+    %cst = arith.constant dense<0> : vector<3x4xi4>
+    %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi4>, vector<4xi4>
+    %1 = vector.insert %0, %cst [0] : vector<4xi4> into vector<3x4xi4>
+    return
+}

diff  --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index b1f23084449c2b..64646b01b9f515 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -95,6 +96,7 @@ struct TestEmulateNarrowTypePass
 
     arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
     memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
 
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();


        


More information about the Mlir-commits mailing list