[Mlir-commits] [mlir] [MLIR] support dynamic indexing in `VectorEmulateNarrowTypes` (PR #114169)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 29 20:10:20 PDT 2024
https://github.com/lialan created https://github.com/llvm/llvm-project/pull/114169
None
>From cc7b19b56afe66e4376765e0b3385cefba2ed754 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Tue, 29 Oct 2024 14:03:23 +0000
Subject: [PATCH] Implement VectorLoadOp
---
.../Transforms/VectorEmulateNarrowType.cpp | 116 ++++++++++++++----
...emulate-narrow-type-unaligned-dynamic.mlir | 53 ++++++++
2 files changed, 143 insertions(+), 26 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1d6f8a991d9b5b..04514725c3aeee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -149,6 +150,61 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
dest, offsets, strides);
}
+/// Copies `loopSize` consecutive elements of the 1-D vector `srcVec`, starting
+/// at the (possibly dynamic) position `srcOffsetVar`, into positions
+/// [0, loopSize) of `destVec`, and returns the resulting vector.
+///
+/// This is the fallback used by the load/transfer_read emulation patterns when
+/// the intra-vector offset is not a compile-time constant.
+///
+/// `vector.insert` does not update its destination operand in place: every
+/// insert yields a new vector value. The inserts are therefore chained and the
+/// final value is returned; `destVec` itself is left unchanged, so callers
+/// must use the return value rather than `destVec`.
+///
+/// The copy is fully unrolled: one extract/insert pair (plus an index add for
+/// every element after the first) is emitted per element, so `loopSize`
+/// should be small.
+///
+/// `srcVec` must hold at least `offset + loopSize` elements for every runtime
+/// value of the offset; this is not checked here.
+static Value dynamicallyExtractElementsToVector(
+    RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
+    Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
+  // Materialize the offset as an SSA value once. Calling `dyn_cast<Value>()`
+  // on the OpFoldResult would yield a null Value for a constant offset.
+  Value srcOffset =
+      getValueOrCreateConstantIndexOp(rewriter, loc, srcOffsetVar);
+  Value result = destVec;
+  for (int64_t i = 0; i < loopSize; ++i) {
+    // Element i is read from position `srcOffset + i`; the first element
+    // reuses `srcOffset` directly to avoid emitting an add of zero.
+    Value extractLoc = srcOffset;
+    if (i != 0) {
+      Value step = rewriter.create<arith::ConstantIndexOp>(loc, i);
+      extractLoc = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), srcOffset, step);
+    }
+    Value element =
+        rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
+    // Thread the partially filled vector through the chain of inserts.
+    result = rewriter.create<vector::InsertOp>(loc, element, result, i);
+  }
+  return result;
+}
+
+/// Loads `numBytes` elements of `newElementType` from `base` at the linearized
+/// index `linearizedIndices`, then bitcasts the loaded vector to a vector of
+/// `numBytes * scale` elements of `oldElementType`, which is returned.
+static TypedValue<VectorType>
+emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
+                   Value base, OpFoldResult linearizedIndices, int64_t numBytes,
+                   int64_t scale, Type oldElementType, Type newElementType) {
+  auto newLoad = rewriter.create<vector::LoadOp>(
+      loc, VectorType::get(numBytes, newElementType), base,
+      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+  return rewriter.create<vector::BitCastOp>(
+      loc, VectorType::get(numBytes * scale, oldElementType), newLoad);
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -380,26 +436,29 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- if (!foldedIntraVectorOffset) {
- // unimplemented case for dynamic intra vector offset
- return failure();
- }
-
+ // always load enough elements which can cover the original elements
+ auto maxintraVectorOffset =
+ foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
auto numElements =
- llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
- auto newLoad = rewriter.create<vector::LoadOp>(
- loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
-
- Value result = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements * scale, oldElementType), newLoad);
+ llvm::divideCeil(maxintraVectorOffset + origElements, scale);
+ Value result =
+ emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
+ numElements, scale, oldElementType, newElementType);
- if (isUnalignedEmulation) {
- result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+ *foldedIntraVectorOffset, origElements);
+ }
+ rewriter.replaceOp(op, result);
+ } else {
+ auto resultVector = rewriter.create<arith::ConstantOp>(
+ loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+ dynamicallyExtractElementsToVector(
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+ linearizedInfo.intraVectorOffset, origElements);
+ rewriter.replaceOp(op, resultVector);
}
-
- rewriter.replaceOp(op, result);
return success();
}
};
@@ -604,13 +663,10 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- if (!foldedIntraVectorOffset) {
- // unimplemented case for dynamic inra-vector offset
- return failure();
- }
-
+ auto maxIntraVectorOffset =
+ foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
auto numElements =
- llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+ llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -621,9 +677,17 @@ struct ConvertVectorTransferRead final
loc, VectorType::get(numElements * scale, oldElementType), newRead);
Value result = bitCast->getResult(0);
- if (isUnalignedEmulation) {
- result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+ *foldedIntraVectorOffset, origElements);
+ }
+ } else {
+ result = rewriter.create<arith::ConstantOp>(
+ loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+ dynamicallyExtractElementsToVector(rewriter, loc, bitCast, result,
+ linearizedInfo.intraVectorOffset,
+ origElements);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
new file mode 100644
index 00000000000000..a92e62538c5332
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+
+// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
+// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
+  %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+  return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %[[ALLOC]][%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OFFSET2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+
+// -----
+
+func.func @vector_transfer_read_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0i2 = arith.constant 0 : i2
+  %1 = vector.transfer_read %0[%arg1, %arg2], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
More information about the Mlir-commits
mailing list