[Mlir-commits] [mlir] [MLIR] support dynamic indexing of `vector.maskedload` in `VectorEmulateNarrowTypes` (PR #115070)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 5 13:33:03 PST 2024


https://github.com/lialan created https://github.com/llvm/llvm-project/pull/115070

Based on the existing emulation scheme, this patch adds support for dynamic indexing: it dynamically creates the intermediate new mask and new passthru vector, and dynamically inserts the result into the destination vector.

The dynamic parts are constructed from multiple `vector.extract` and `vector.insert` ops that rearrange the original mask/passthru vectors, since `vector.insert_strided_slice` and `vector.extract_strided_slice` only accept static offsets and indices.
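As a concrete illustration, consider loading 3 `i2` elements with a runtime offset, as in the `@vector_maskedload_i2_dynamic_indexing_mixed` test below. With 4 `i2` elements per `i8` byte (scale = 4) and a worst-case intra-byte offset of scale - 1 = 3, the emulation needs ceil((3 + 3) / 4) = 2 bytes, i.e. a `vector<2xi8>` masked load over a widened `vector<8xi2>`. The following hand-written sketch (not compiler output; `%passthru`, `%off`, and the intermediate names are made up) shows the per-element scatter of the passthru vector to the dynamic offset `%off`:

  %zero = arith.constant dense<0> : vector<8xi2>
  %e0 = vector.extract %passthru[0] : i2 from vector<3xi2>
  %v0 = vector.insert %e0, %zero [%off] : i2 into vector<8xi2>
  %c1 = arith.constant 1 : index
  %off1 = arith.addi %off, %c1 : index
  %e1 = vector.extract %passthru[1] : i2 from vector<3xi2>
  %v1 = vector.insert %e1, %v0 [%off1] : i2 into vector<8xi2>
  %c2 = arith.constant 2 : index
  %off2 = arith.addi %off, %c2 : index
  %e2 = vector.extract %passthru[2] : i2 from vector<3xi2>
  %v2 = vector.insert %e2, %v1 [%off2] : i2 into vector<8xi2>

The mask is scattered the same way, and after the select a mirrored per-element extract gathers the original elements back out of the widened vector.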

From 5eebcc0daa1f1594955159e3d3ea13512dcacb41 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Wed, 30 Oct 2024 19:37:11 +0000
Subject: [PATCH 1/2] Implement dynamic indexing for MaskedLoads

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 101 ++++++++++++------
 .../vector-emulate-narrow-type-unaligned.mlir |  52 +++++++++
 2 files changed, 120 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f169dab3bdd9af..3c94e992d695c0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int origElements, int scale,
                                                   int intraDataOffset = 0) {
+  assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
   auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
 
   Operation *maskOp = mask.getDefiningOp();
@@ -182,6 +183,27 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
   return dest;
 }
 
+/// Inserts a 1-D subvector `source` of size `length` into a 1-D `dest`
+/// vector at the dynamic index given by `destOffsetVar`.
+static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
+                                        TypedValue<VectorType> source,
+                                        Value dest, OpFoldResult destOffsetVar,
+                                        int64_t length) {
+  assert(length > 0 && "length must be greater than 0");
+  for (int i = 0; i < length; ++i) {
+    Value insertLoc;
+    if (i == 0) {
+      insertLoc = destOffsetVar.dyn_cast<Value>();
+    } else {
+      insertLoc = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
+          rewriter.create<arith::ConstantIndexOp>(loc, i));
+    }
+    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
+    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
+  }
+  return dest;
+}
+
 /// Returns the op sequence for an emulated sub-byte data type vector load.
 /// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
 /// The load location is given by `base` and `linearizedIndices`, and the
@@ -199,7 +221,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
   return rewriter.create<vector::BitCastOp>(
       loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
       newLoad);
-};
+}
 
 namespace {
 
@@ -546,29 +568,30 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic intra vector offset
-      return failure();
-    }
-
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
-                            *foldedIntraVectorOffset);
+    auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
+    Value passthru = op.getPassThru();
+
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
     auto loadType = VectorType::get(numElements, newElementType);
     auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
 
-    Value passthru = op.getPassThru();
-    if (isUnalignedEmulation) {
-      // create an empty vector of the new type
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
-      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
-                                           *foldedIntraVectorOffset);
+    auto emptyVector = rewriter.create<arith::ConstantOp>(
+        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
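+    // The passthru elements are inserted into this zero vector at the
+    // (static or dynamic) intra-vector offset, so that the subsequent
+    // bitcast lines up with the bytes produced by the masked load.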
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        passthru = staticallyInsertSubvector(
+            rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
+      }
+    } else {
+      passthru = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
+          emptyVector, linearizedInfo.intraDataOffset, origElements);
     }
     auto newPassThru =
         rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -585,23 +608,36 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    if (isUnalignedEmulation) {
-      auto newSelectMaskType =
-          VectorType::get(numElements * scale, rewriter.getI1Type());
-      // TODO: can fold if op's mask is constant
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
-      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
-                                       *foldedIntraVectorOffset);
+    auto newSelectMaskType =
+        VectorType::get(numElements * scale, rewriter.getI1Type());
+    // TODO: try to fold if op's mask is constant
+    auto emptyMask = rewriter.create<arith::ConstantOp>(
+        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
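+    // Rearrange the mask the same way as the passthru: statically when the
+    // offset is known at compile time, otherwise element by element.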
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
+                                         *foldedIntraVectorOffset);
+      }
+    } else {
+      mask = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+          linearizedInfo.intraDataOffset, origElements);
     }
 
     Value result =
         rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
-
-    if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result =
+            staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                       *foldedIntraVectorOffset, origElements);
+      }
+    } else {
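+      // The selected elements sit at a runtime offset inside the widened
+      // vector, so gather them one by one into a zero vector of the
+      // original result type.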
+      auto resultVector = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          linearizedInfo.intraDataOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -659,10 +695,9 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    auto maxIntraVectorOffset =
-        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
+    auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     auto numElements =
-        llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 0cecaddc5733e2..efa31b8bf5ac7d 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -183,3 +183,55 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
 // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+// -----
+
+func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %c2 = arith.constant 2 : index
+  %mask = vector.constant_mask [3] : vector<3xi1>
+  %1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
+    memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
+// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
+// CHECK: %[[LINEAR1:.+]] = affine.apply #[[MAP]]()[%[[IDX]]]
+// CHECK: %[[LINEAR2:.+]] = affine.apply #[[MAP1]]()[%[[IDX]]]
+// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
+// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index
+// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2>
+// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
+// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
+// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
+// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// extracts:
+// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
+// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
+// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
+// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
+// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
+// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2>
+// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
+// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
+// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>

From f3c2d3ac5a1ba10c1a571c6bed01ed83161b8fea Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Tue, 5 Nov 2024 21:32:23 +0000
Subject: [PATCH 2/2] Small update

---
 .../Vector/Transforms/VectorEmulateNarrowType.cpp  | 14 ++++++--------
 .../vector-emulate-narrow-type-unaligned.mlir      |  1 -
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 3c94e992d695c0..56273ac2899d7e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -190,14 +190,12 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                         int64_t length) {
   assert(length > 0 && "length must be greater than 0");
   for (int i = 0; i < length; ++i) {
-    Value insertLoc;
-    if (i == 0) {
-      insertLoc = destOffsetVar.dyn_cast<Value>();
-    } else {
-      insertLoc = rewriter.create<arith::AddIOp>(
-          loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
-          rewriter.create<arith::ConstantIndexOp>(loc, i));
-    }
+    Value insertLoc =
+        i == 0
+            ? destOffsetVar.dyn_cast<Value>()
+            : rewriter.create<arith::AddIOp>(
+                  loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
+                  rewriter.create<arith::ConstantIndexOp>(loc, i));
     auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
     dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
   }
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index efa31b8bf5ac7d..6a10a2f9ed32fe 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -219,7 +219,6 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
 // CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
 // CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
-// extracts:
 // CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
 // CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
 // CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>


