[Mlir-commits] [mlir] 9a7677d - [mlir] Narrow bitwidth emulation for vector.load
Hanhan Wang
llvmlistbot at llvm.org
Tue Jul 11 13:38:21 PDT 2023
Author: yzhang93
Date: 2023-07-11T13:38:15-07:00
New Revision: 9a7677d8ee34e71a27834ab2886d79dfb9e71dae
URL: https://github.com/llvm/llvm-project/commit/9a7677d8ee34e71a27834ab2886d79dfb9e71dae
DIFF: https://github.com/llvm/llvm-project/commit/9a7677d8ee34e71a27834ab2886d79dfb9e71dae.diff
LOG: [mlir] Narrow bitwidth emulation for vector.load
This patch is a follow-up to the previous patch https://reviews.llvm.org/D151519.
With this patch, a vector.load op on a narrow bitwidth (e.g., i4) can be converted to
an equivalent op on a supported wider bitwidth (e.g., i8).
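For illustration (mirroring the example in the new pattern's comments and test below),
an i4 load such as

  %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>

is emulated on an i8-typed memref as

  %1 = vector.load %0[%linear_index] : memref<12xi8>, vector<2xi8>
  %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>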
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D154178
Added:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 12254bc215db57..c644090d8c78cd 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -22,6 +22,10 @@
namespace mlir {
class RewritePatternSet;
+namespace arith {
+class NarrowTypeEmulationConverter;
+} // namespace arith
+
namespace vector {
struct VectorTransformsOptions;
@@ -291,6 +295,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
bool force32BitVectorIndices,
PatternBenefit benefit = 1);
+/// Appends patterns for emulating vector operations over narrow types with ops
+/// over wider types.
+void populateVectorNarrowTypeEmulationPatterns(
+ arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index deba91573e0ff1..ef6eebd987c6db 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorTranspose.cpp
VectorDistribute.cpp
VectorDropLeadUnitDim.cpp
+ VectorEmulateNarrowType.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorTransferOpTransforms.cpp
VectorTransferSplitRewritePatterns.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
new file mode 100644
index 00000000000000..597a2c8d5d17ca
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -0,0 +1,200 @@
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// The emulation only works on 1-D memref types. To make it work on N-D
+/// memrefs, the offset has to be linearized.
+///
+/// For example, to emulate i4 with i8, the following op:
+///
+/// %0 = vector.load %arg0[%v0, %v1] :
+///   memref<?x?xi4, strided<[?, ?], offset: ?>>, vector<8xi4>
+///
+/// can be replaced with
+///
+/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
+///
+/// %linearized_offset = %v0 * %strides#0 + %v1 * %strides#1
+/// %linearized_size = %sizes#0 * %sizes#1
+/// %scaled_linear_offset = %linearized_offset / (8 / 4)
+/// %scaled_base_offset = %offset / (8 / 4)
+///
+/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
+///   sizes = [%linearized_size], strides = [%strides#1]
+///
+/// %new_load = vector.load %linearized[%scaled_linear_offset] :
+///   memref<?xi8, strided<[?], offset: ?>>, vector<4xi8>
+
+static Value
+linearizeVectorLoad(Location loc, MemRefType sourceType, int srcBits,
+ int dstBits, SmallVector<Value> indices,
+ memref::ExtractStridedMetadataOp stridedMetadata,
+ int numElements, OpBuilder &builder) {
+ auto srcElementType = sourceType.getElementType();
+ unsigned sourceRank = indices.size();
+
+ Value baseBuffer = stridedMetadata.getBaseBuffer();
+ SmallVector<Value> baseSizes = stridedMetadata.getSizes();
+ SmallVector<Value> baseStrides = stridedMetadata.getStrides();
+ Value baseOffset = stridedMetadata.getOffset();
+  assert(indices.size() == baseStrides.size() &&
+         "expected as many indices as strides");
+
+ // Create the affine symbols and values for linearization.
+ SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+ bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+ symbols[0] = builder.getAffineSymbolExpr(0);
+ AffineExpr addMulMap = symbols.front();
+ AffineExpr mulMap = symbols.front();
+
+ SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
+ offsetValues[0] = builder.getIndexAttr(0);
+ SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
+ sizeValues[0] = builder.getIndexAttr(1);
+
+ for (unsigned i = 0; i < sourceRank; ++i) {
+ unsigned offsetIdx = 2 * i + 1;
+ addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
+ offsetValues[offsetIdx] = indices[i];
+ offsetValues[offsetIdx + 1] = baseStrides[i];
+
+ unsigned sizeIdx = i + 1;
+ mulMap = mulMap * symbols[sizeIdx];
+ sizeValues[sizeIdx] = baseSizes[i];
+ }
+
+ // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
+ OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
+ AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
+ offsetValues.back() = scaler;
+
+ OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
+ builder, loc, scaledAddMulMap, offsetValues);
+ OpFoldResult linearizedSize =
+ affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+
+ // Adjust baseOffset by the scale factor (dstBits / srcBits).
+ AffineExpr s0, s1;
+ bindSymbols(builder.getContext(), s0, s1);
+ OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
+ builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
+
+ // Flatten n-D MemRef to 1-D MemRef.
+ auto layoutAttr = StridedLayoutAttr::get(sourceType.getContext(),
+ ShapedType::kDynamic, {1});
+ int64_t staticShape = sourceType.hasStaticShape()
+ ? sourceType.getNumElements()
+ : ShapedType::kDynamic;
+ auto flattenMemrefType = MemRefType::get(
+ staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
+
+ auto reinterpret = builder.create<memref::ReinterpretCastOp>(
+ loc, flattenMemrefType, baseBuffer,
+ getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
+ getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
+ baseStrides.back());
+
+ return builder.create<vector::LoadOp>(
+ loc, VectorType::get(numElements, srcElementType),
+ reinterpret.getResult(),
+ getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertVectorLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto loc = op.getLoc();
+ auto sourceType = cast<MemRefType>(adaptor.getBase().getType());
+ Type oldElementType = op.getType().getElementType();
+ Type newElementType = sourceType.getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = newElementType.getIntOrFloatBitWidth();
+
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+ int scale = dstBits / srcBits;
+
+    // Adjust the number of elements to load when emulating narrow types,
+    // and then cast back to the original type with a vector.bitcast op.
+    // Only 1-D vector loads are handled here; N-D memref types are
+    // linearized by the helper above.
+    // For example, to emulate i4 with i8, the following op:
+    //
+    // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
+    //
+    // can be replaced with
+    //
+    // %1 = vector.load %0[%linear_index] : memref<12xi8>, vector<2xi8>
+    // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
+    //
+    // TODO: Currently, only vectors whose element count is a multiple of the
+    // scale factor (dstBits / srcBits) are supported. To handle other counts,
+    // one has to extract the subvector at the proper offset after
+    // bit-casting.
+
+ auto origElements = op.getVectorType().getNumElements();
+ if (origElements % scale != 0)
+ return failure();
+
+ auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, adaptor.getBase());
+
+    // origElements is a multiple of scale per the check above.
+    auto numElements = origElements / scale;
+ auto newLoad = linearizeVectorLoad(loc, sourceType, srcBits, dstBits,
+ adaptor.getIndices(), stridedMetadata,
+ numElements, rewriter);
+
+ numElements *= scale;
+ auto castType = VectorType::get(numElements, oldElementType);
+ auto bitCast = rewriter.create<vector::BitCastOp>(loc, castType, newLoad);
+
+ rewriter.replaceOp(op, bitCast->getResult(0));
+ return success();
+ }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void vector::populateVectorNarrowTypeEmulationPatterns(
+ arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns) {
+
+ // Populate `vector.*` conversion patterns.
+ patterns.add<ConvertVectorLoad>(typeConverter, patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
new file mode 100644
index 00000000000000..71c133778087d6
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions; i8 is supported.
+// CHECK-LABEL: func @vector_load_i8
+// CHECK-SAME: (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index)
+// CHECK-NEXT: [[L:%.+]] = vector.load %[[ARG]][%[[IDX0]], %[[IDX1]]] : memref<3x4xi8>, vector<4xi8>
+// CHECK-NEXT: return
+func.func @vector_load_i8(%arg0: memref<3x4xi8>, %arg1: index, %arg2: index) {
+ %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_load_i4
+// CHECK-SAME: (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index)
+// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<0> : vector<3x4xi4>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x4xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[IDX0]], %[[STRIDES]]#0, %[[IDX1]], %[[STRIDES]]#1]
+// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<12xi8, strided<[1], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[CAST]][%[[INDEX]]] : memref<12xi8, strided<[1], offset: ?>>, vector<2xi8>
+// CHECK-NEXT: %[[BITCAST:.*]] = vector.bitcast %[[LOAD]] : vector<2xi8> to vector<4xi4>
+// CHECK-NEXT: %[[INSERT:.*]] = vector.insert %[[BITCAST]], %[[CST]] [0] : vector<4xi4> into vector<3x4xi4>
+// CHECK-NEXT: return
+func.func @vector_load_i4(%arg0: memref<3x4xi4>, %arg1: index, %arg2: index) {
+ %cst = arith.constant dense<0> : vector<3x4xi4>
+ %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi4>, vector<4xi4>
+ %1 = vector.insert %0, %cst [0] : vector<4xi4> into vector<3x4xi4>
+ return
+}
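Note that the pattern currently bails out when the vector's element count is not a
multiple of the scale factor (dstBits / srcBits); for instance, a hypothetical load
like the one below is rejected and left to the TODO in ConvertVectorLoad:

  // Hypothetical example: not yet handled, since 3 elements % 2 (scale) != 0.
  %0 = vector.load %arg0[%idx0, %idx1] : memref<3x3xi4>, vector<3xi4>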
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index b1f23084449c2b..64646b01b9f515 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -95,6 +96,7 @@ struct TestEmulateNarrowTypePass
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+ vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();