[Mlir-commits] [mlir] [MLIR] support dynamic indexing of `vector.maskedload` in `VectorEmulateNarrowTypes` (PR #115070)
llvmlistbot at llvm.org
Mon Nov 11 17:14:26 PST 2024
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/115070
>From 5eebcc0daa1f1594955159e3d3ea13512dcacb41 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Wed, 30 Oct 2024 19:37:11 +0000
Subject: [PATCH 1/7] Implement dynamic indexing for MaskedLoads
---
.../Transforms/VectorEmulateNarrowType.cpp | 101 ++++++++++++------
.../vector-emulate-narrow-type-unaligned.mlir | 52 +++++++++
2 files changed, 120 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f169dab3bdd9af..3c94e992d695c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
int origElements, int scale,
int intraDataOffset = 0) {
+ assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
Operation *maskOp = mask.getDefiningOp();
@@ -182,6 +183,27 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
return dest;
}
+/// Inserts a 1-D subvector into a 1-D `dest` vector at index `offset`.
+static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
+ TypedValue<VectorType> source,
+ Value dest, OpFoldResult destOffsetVar,
+ int64_t length) {
+ assert(length > 0 && "length must be greater than 0");
+ for (int i = 0; i < length; ++i) {
+ Value insertLoc;
+ if (i == 0) {
+ insertLoc = destOffsetVar.dyn_cast<Value>();
+ } else {
+ insertLoc = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
+ rewriter.create<arith::ConstantIndexOp>(loc, i));
+ }
+ auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
+ dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
+ }
+ return dest;
+}
+
/// Returns the op sequence for an emulated sub-byte data type vector load.
/// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
/// The load location is given by `base` and `linearizedIndices`, and the
@@ -199,7 +221,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
return rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
newLoad);
-};
+}
namespace {
@@ -546,29 +568,30 @@ struct ConvertVectorMaskedLoad final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- if (!foldedIntraVectorOffset) {
- // unimplemented case for dynamic intra vector offset
- return failure();
- }
-
- FailureOr<Operation *> newMask =
- getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
- *foldedIntraVectorOffset);
+ auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
+ FailureOr<Operation *> newMask = getCompressedMaskOp(
+ rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
if (failed(newMask))
return failure();
+ Value passthru = op.getPassThru();
+
auto numElements =
- llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+ llvm::divideCeil(maxIntraDataOffset + origElements, scale);
auto loadType = VectorType::get(numElements, newElementType);
auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
- Value passthru = op.getPassThru();
- if (isUnalignedEmulation) {
- // create an empty vector of the new type
- auto emptyVector = rewriter.create<arith::ConstantOp>(
- loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
- passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
- *foldedIntraVectorOffset);
+ auto emptyVector = rewriter.create<arith::ConstantOp>(
+ loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ passthru = staticallyInsertSubvector(
+ rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
+ }
+ } else {
+ passthru = dynamicallyInsertSubVector(
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
+ emptyVector, linearizedInfo.intraDataOffset, origElements);
}
auto newPassThru =
rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -585,23 +608,36 @@ struct ConvertVectorMaskedLoad final
rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
Value mask = op.getMask();
- if (isUnalignedEmulation) {
- auto newSelectMaskType =
- VectorType::get(numElements * scale, rewriter.getI1Type());
- // TODO: can fold if op's mask is constant
- auto emptyVector = rewriter.create<arith::ConstantOp>(
- loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
- mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
- *foldedIntraVectorOffset);
+ auto newSelectMaskType =
+ VectorType::get(numElements * scale, rewriter.getI1Type());
+ // TODO: try to fold if op's mask is constant
+ auto emptyMask = rewriter.create<arith::ConstantOp>(
+ loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
+ *foldedIntraVectorOffset);
+ }
+ } else {
+ mask = dynamicallyInsertSubVector(
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+ linearizedInfo.intraDataOffset, origElements);
}
Value result =
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
-
- if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ result =
+ staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+ *foldedIntraVectorOffset, origElements);
+ }
+ } else {
+ auto resultVector = rewriter.create<arith::ConstantOp>(
+ loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+ result = dynamicallyExtractSubVector(
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+ linearizedInfo.intraDataOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -659,10 +695,9 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- auto maxIntraVectorOffset =
- foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
+ auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
auto numElements =
- llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
+ llvm::divideCeil(maxIntraDataOffset + origElements, scale);
auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 0cecaddc5733e2..efa31b8bf5ac7d 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -183,3 +183,55 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+// -----
+
+func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %cst = arith.constant dense<0> : vector<3x3xi2>
+ %c2 = arith.constant 2 : index
+ %mask = vector.constant_mask [3] : vector<3xi1>
+ %1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
+ memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
+// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
+// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
+// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
+// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
+// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index
+// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2>
+// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
+// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
+// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
+// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// extracts:
+// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
+// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
+// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
+// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
+// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
+// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2>
+// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
+// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
+// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
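The two affine maps in the CHECK lines above split the dynamic i2 element index into an emulated i8 byte index (#MAP) and an intra-byte offset (#MAP1). A minimal standalone C++ sketch of that arithmetic (illustration only, not part of the patch; it assumes row-major linearization of memref<3x3xi2> and scale = 8 / 2 = 4):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t scale = 8 / 2; // four i2 elements per emulated i8 element
  for (int64_t idx = 0; idx < 3; ++idx) {
    int64_t linearI2 = idx * 3 + 2;         // element (%idx, 2) of memref<3x3xi2>
    int64_t byteIndex = linearI2 / scale;   // #MAP:  (s0 * 3 + 2) floordiv 4
    int64_t intraOffset = linearI2 % scale; // #MAP1: s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2
    assert(byteIndex * scale + intraOffset == linearI2);
    assert(intraOffset < scale); // the intra-data offset is always less than scale
  }
  return 0;
}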
>From f3c2d3ac5a1ba10c1a571c6bed01ed83161b8fea Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Tue, 5 Nov 2024 21:32:23 +0000
Subject: [PATCH 2/7] Small update
---
.../Vector/Transforms/VectorEmulateNarrowType.cpp | 14 ++++++--------
.../vector-emulate-narrow-type-unaligned.mlir | 1 -
2 files changed, 6 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 3c94e992d695c0..56273ac2899d7e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -190,14 +190,12 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
int64_t length) {
assert(length > 0 && "length must be greater than 0");
for (int i = 0; i < length; ++i) {
- Value insertLoc;
- if (i == 0) {
- insertLoc = destOffsetVar.dyn_cast<Value>();
- } else {
- insertLoc = rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
- rewriter.create<arith::ConstantIndexOp>(loc, i));
- }
+ Value insertLoc =
+ 1 == 0
+ ? destOffsetVar.dyn_cast<Value>()
+ : rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
+ rewriter.create<arith::ConstantIndexOp>(loc, i));
auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
}
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index efa31b8bf5ac7d..6a10a2f9ed32fe 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -219,7 +219,6 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
-// extracts:
// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
>From 94ab287240cf9d347b322fffc5ac878c4a558431 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Tue, 5 Nov 2024 22:49:27 +0000
Subject: [PATCH 3/7] fix
---
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 56273ac2899d7e..dabb137351601d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -191,7 +191,7 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
assert(length > 0 && "length must be greater than 0");
for (int i = 0; i < length; ++i) {
Value insertLoc =
- 1 == 0
+ i == 0
? destOffsetVar.dyn_cast<Value>()
: rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
>From 21bd52c3aa08ae4b98a4d7059f2d2d3d0c453a58 Mon Sep 17 00:00:00 2001
From: hasekawa-takumi <167335845+hasekawa-takumi at users.noreply.github.com>
Date: Thu, 7 Nov 2024 23:10:35 -0500
Subject: [PATCH 4/7] Update
---
.../Transforms/VectorEmulateNarrowType.cpp | 6 ++----
.../vector-emulate-narrow-type-unaligned.mlir | 18 ++++++++++++++++--
2 files changed, 18 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dabb137351601d..9c565c6881c4e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -631,11 +631,9 @@ struct ConvertVectorMaskedLoad final
*foldedIntraVectorOffset, origElements);
}
} else {
- auto resultVector = rewriter.create<arith::ConstantOp>(
- loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
- linearizedInfo.intraDataOffset, origElements);
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
+ op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 6a10a2f9ed32fe..6d37493d174a21 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -205,6 +205,8 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
+
+// extract passthru vector, and insert into zero vector, this is for constructing a new passthru
// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -215,21 +217,33 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
+
+// bitcast the new passthru vector to emulated i8 vector
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>
+
+// use the emulated i8 vector to masked load from the memory
// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+
+// bitcast back to i2 vector
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+
// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
+
+// create a mask vector and select passthru part from the loaded vector.
+// note that if indices are known then we can fold the part generating mask.
// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
+
// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>
-// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2>
+
+// finally, insert the selected parts into actual passthru vector.
// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
-// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2>
+// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[PTH]] [0] : i2 into vector<3xi2>
// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
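This patch reuses op.getPassThru() as the destination of the final dynamic extraction, which matches the semantics of vector.maskedload: lanes whose mask bit is false must carry the passthru value. A rough scalar model of that contract (illustration only; maskedLoadModel is a made-up name, not an MLIR API):

#include <cstdint>
#include <vector>

std::vector<int8_t> maskedLoadModel(const std::vector<int8_t> &mem, size_t base,
                                    const std::vector<bool> &mask,
                                    const std::vector<int8_t> &passthru) {
  std::vector<int8_t> result(mask.size());
  for (size_t i = 0; i < mask.size(); ++i)
    result[i] = mask[i] ? mem[base + i] : passthru[i]; // false lanes keep passthru
  return result;
}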
>From d6437e9e82a92b5653d50624a3add9a15a7cb68c Mon Sep 17 00:00:00 2001
From: Alan Li <alan.li at me.com>
Date: Mon, 11 Nov 2024 09:34:42 -0500
Subject: [PATCH 5/7] update comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 6 +++--
.../vector-emulate-narrow-type-unaligned.mlir | 27 ++++++++++---------
2 files changed, 18 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ef072638af26ef..58b799b028694c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -43,7 +43,9 @@ using namespace mlir;
///
/// %mask = [1, 1, 0, 0, 0, 0]
///
-/// will first be padded with number of `intraDataOffset` zeros:
+/// will first be padded in the front with number of `intraDataOffset` zeros,
+/// and pad zeros in the back to make the number of elements a multiple of
+/// `scale` (just to make it easier to compute). The new mask will be:
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
///
/// then it will return the following new compressed mask:
@@ -54,7 +56,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
int origElements, int scale,
int intraDataOffset = 0) {
assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
- auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
+ auto numElements = llvm::divideCeil(intraDataOffset + origElements, scale);
Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 6d37493d174a21..7ed75ff7f1579c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -206,7 +206,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
-// extract passthru vector, and insert into zero vector, this is for constructing a new passthru
+// Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -216,32 +216,33 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
-// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
+// CHECK: %[[NEW_PASSTHRU:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
-// bitcast the new passthru vector to emulated i8 vector
-// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>
+// Bitcast the new passthru vector to emulated i8 vector
+// CHECK: %[[BCAST_PASSTHRU:.+]] = vector.bitcast %[[NEW_PASSTHRU]] : vector<8xi2> to vector<2xi8>
-// use the emulated i8 vector to masked load from the memory
-// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
+// Use the emulated i8 vector for masked load from the source memory
+// CHECK: %[[SOURCE:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BCAST_PASSTHRU]]
// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
-// bitcast back to i2 vector
-// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// Bitcast back to i2 vector
+// CHECK: %[[BCAST_MASKLOAD:.+]] = vector.bitcast %[[SOURCE]] : vector<2xi8> to vector<8xi2>
// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
-// create a mask vector and select passthru part from the loaded vector.
-// note that if indices are known then we can fold the part generating mask.
+// Create a mask vector
+// Note that if indices are known then we can fold the part generating mask.
// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
-// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
+// CHECK: %[[NEW_MASK:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
-// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>
+// Select the effective part from the source and passthru vectors
+// CHECK: %[[SELECT:.+]] = arith.select %[[NEW_MASK]], %[[BCAST_MASKLOAD]], %[[NEW_PASSTHRU]] : vector<8xi1>, vector<8xi2>
-// finally, insert the selected parts into actual passthru vector.
+// Finally, insert the selected parts into actual passthru vector.
// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[PTH]] [0] : i2 into vector<3xi2>
// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
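To make the updated doc comment concrete, here is a rough scalar model of the mask compression (illustration only, not the actual getCompressedMaskOp lowering; compressMaskModel is a made-up helper): pad the mask in front with intraDataOffset zeros, pad the back up to a multiple of scale, then mark a compressed element active if any bit in its group of scale bits is set.

#include <cassert>
#include <vector>

std::vector<bool> compressMaskModel(const std::vector<bool> &mask,
                                    int intraDataOffset, int scale) {
  assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
  int origElements = static_cast<int>(mask.size());
  int numElements = (intraDataOffset + origElements + scale - 1) / scale;
  std::vector<bool> compressed(numElements, false);
  for (int i = 0; i < origElements; ++i)
    if (mask[i])
      compressed[(intraDataOffset + i) / scale] = true;
  return compressed;
}

With mask = [1, 1, 0, 0, 0, 0], intraDataOffset = 1 and scale = 4, the padded mask is [0, 1, 1, 0, 0, 0, 0, 0] and this model compresses it to [1, 0].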
>From 9e7aedfe8c88ad2be2426ac394456d83697149ed Mon Sep 17 00:00:00 2001
From: Alan Li <alan.li at me.com>
Date: Mon, 11 Nov 2024 19:43:11 -0500
Subject: [PATCH 6/7] fix according to comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 73 ++++++++-----------
1 file changed, 32 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58b799b028694c..604d261b4513d7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -194,13 +194,14 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
Value dest, OpFoldResult destOffsetVar,
int64_t length) {
assert(length > 0 && "length must be greater than 0");
+ Value destOffsetVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
for (int i = 0; i < length; ++i) {
- Value insertLoc =
- i == 0
- ? destOffsetVar.dyn_cast<Value>()
- : rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
- rewriter.create<arith::ConstantIndexOp>(loc, i));
+ auto insertLoc = i == 0
+ ? destOffsetVal
+ : rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), destOffsetVal,
+ rewriter.create<arith::ConstantIndexOp>(loc, i));
auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
}
@@ -465,18 +466,16 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
numElements, oldElementType, newElementType);
- if (foldedIntraVectorOffset) {
- if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
- }
- } else {
+ if (!foldedIntraVectorOffset) {
auto resultVector = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
+ } else if (isUnalignedEmulation) {
+ result =
+ staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+ *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
return success();
@@ -571,7 +570,7 @@ struct ConvertVectorMaskedLoad final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
+ int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
FailureOr<Operation *> newMask = getCompressedMaskOp(
rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
if (failed(newMask))
@@ -586,15 +585,13 @@ struct ConvertVectorMaskedLoad final
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
- if (foldedIntraVectorOffset) {
- if (isUnalignedEmulation) {
- passthru = staticallyInsertSubvector(
- rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
- }
- } else {
+ if (!foldedIntraVectorOffset) {
passthru = dynamicallyInsertSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
emptyVector, linearizedInfo.intraDataOffset, origElements);
+ } else if (isUnalignedEmulation) {
+ passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
+ *foldedIntraVectorOffset);
}
auto newPassThru =
rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -616,29 +613,25 @@ struct ConvertVectorMaskedLoad final
// TODO: try to fold if op's mask is constant
auto emptyMask = rewriter.create<arith::ConstantOp>(
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
- if (foldedIntraVectorOffset) {
- if (isUnalignedEmulation) {
- mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
- *foldedIntraVectorOffset);
- }
- } else {
+ if (!foldedIntraVectorOffset) {
mask = dynamicallyInsertSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
linearizedInfo.intraDataOffset, origElements);
+ } else if (isUnalignedEmulation) {
+ mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
+ *foldedIntraVectorOffset);
}
Value result =
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
- if (foldedIntraVectorOffset) {
- if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
- }
- } else {
+ if (!foldedIntraVectorOffset) {
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
+ } else if (isUnalignedEmulation) {
+ result =
+ staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+ *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -696,7 +689,7 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
+ int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
auto numElements =
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
@@ -709,18 +702,16 @@ struct ConvertVectorTransferRead final
loc, VectorType::get(numElements * scale, oldElementType), newRead);
Value result = bitCast->getResult(0);
- if (foldedIntraVectorOffset) {
- if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
- }
- } else {
+ if (!foldedIntraVectorOffset) {
auto zeros = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
linearizedInfo.intraDataOffset,
origElements);
+ } else if (isUnalignedEmulation) {
+ result =
+ staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+ *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
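After this patch, all three patterns (ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorTransferRead) share the same dispatch shape: handle the fully dynamic intra-vector offset first, then the static-but-unaligned case, and leave aligned results untouched. A condensed sketch of that shape (illustration only; the names below are placeholders, not the actual MLIR code):

#include <functional>
#include <string>

using Value = std::string; // stand-in for mlir::Value in this sketch

Value adjustForOffset(bool hasStaticOffset, bool isUnalignedEmulation, Value result,
                      const std::function<Value(Value)> &dynamicPath,
                      const std::function<Value(Value)> &staticPath) {
  if (!hasStaticOffset)
    return dynamicPath(result); // dynamicallyExtract/InsertSubVector path
  if (isUnalignedEmulation)
    return staticPath(result);  // staticallyExtract/InsertSubvector path
  return result;                // aligned: use the result as is
}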
>From f72ac5c339de0a3ae065fb7e35f69e9e56760476 Mon Sep 17 00:00:00 2001
From: Alan Li <alan.li at me.com>
Date: Mon, 11 Nov 2024 20:14:06 -0500
Subject: [PATCH 7/7] another update according to comments.
---
.../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 604d261b4513d7..c1324e4f3a8ea7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -188,15 +188,15 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
return dest;
}
-/// Inserts a 1-D subvector into a 1-D `dest` vector at index `offset`.
+/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
TypedValue<VectorType> source,
Value dest, OpFoldResult destOffsetVar,
- int64_t length) {
+ size_t length) {
assert(length > 0 && "length must be greater than 0");
Value destOffsetVal =
getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
- for (int i = 0; i < length; ++i) {
+ for (size_t i = 0; i < length; ++i) {
auto insertLoc = i == 0
? destOffsetVal
: rewriter.create<arith::AddIOp>(