[Mlir-commits] [mlir] [MLIR] support dynamic indexing of `vector.maskedload` in `VectorEmulateNarrowTypes` (PR #115070)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 5 13:33:42 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: lialan (lialan)

<details>
<summary>Changes</summary>

Based on the existing emulation scheme, this patch extends `vector.maskedload` emulation to support dynamic indexing. It dynamically creates the intermediate mask and pass-through vector, and dynamically extracts the result elements into the destination vector.

The dynamic parts are built from sequences of `vector.extract` and `vector.insert` ops that rearrange the original mask/pass-through vector. This is necessary because `vector.insert_strided_slice` and `vector.extract_strided_slice` only accept static offsets and indices.

---
Full diff: https://github.com/llvm/llvm-project/pull/115070.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+66-33) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+51) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f169dab3bdd9af..56273ac2899d7e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int origElements, int scale,
                                                   int intraDataOffset = 0) {
+  assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
   auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
 
   Operation *maskOp = mask.getDefiningOp();
@@ -182,6 +183,25 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
   return dest;
 }
 
+/// Inserts the first `length` elements of the 1-D vector `source` into the
+/// 1-D vector `dest` starting at the (possibly dynamic) index `destOffsetVar`.
+static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
+                                        TypedValue<VectorType> source,
+                                        Value dest, OpFoldResult destOffsetVar,
+                                        int64_t length) {
+  assert(length > 0 && "length must be greater than 0");
+  Value base = getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
+  for (int64_t i = 0; i < length; ++i) {
+    Value insertLoc = base;
+    if (i != 0)
+      insertLoc = rewriter.create<arith::AddIOp>(
+          loc, base, rewriter.create<arith::ConstantIndexOp>(loc, i));
+    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
+    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
+  }
+  return dest;
+}
+
 /// Returns the op sequence for an emulated sub-byte data type vector load.
 /// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
 /// The load location is given by `base` and `linearizedIndices`, and the
@@ -199,7 +219,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
   return rewriter.create<vector::BitCastOp>(
       loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
       newLoad);
-};
+}
 
 namespace {
 
@@ -546,29 +566,30 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic intra vector offset
-      return failure();
-    }
-
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
-                            *foldedIntraVectorOffset);
+    auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
+    Value passthru = op.getPassThru();
+
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
     auto loadType = VectorType::get(numElements, newElementType);
     auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
 
-    Value passthru = op.getPassThru();
-    if (isUnalignedEmulation) {
-      // create an empty vector of the new type
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
-      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
-                                           *foldedIntraVectorOffset);
+    auto emptyVector = rewriter.create<arith::ConstantOp>(
+        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        passthru = staticallyInsertSubvector(
+            rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
+      }
+    } else {
+      passthru = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
+          emptyVector, linearizedInfo.intraDataOffset, origElements);
     }
     auto newPassThru =
         rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -585,23 +606,36 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    if (isUnalignedEmulation) {
-      auto newSelectMaskType =
-          VectorType::get(numElements * scale, rewriter.getI1Type());
-      // TODO: can fold if op's mask is constant
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
-      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
-                                       *foldedIntraVectorOffset);
+    auto newSelectMaskType =
+        VectorType::get(numElements * scale, rewriter.getI1Type());
+    // TODO: try to fold if op's mask is constant
+    auto emptyMask = rewriter.create<arith::ConstantOp>(
+        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
+                                         *foldedIntraVectorOffset);
+      }
+    } else {
+      mask = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+          linearizedInfo.intraDataOffset, origElements);
     }
 
     Value result =
         rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
-
-    if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result =
+            staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                       *foldedIntraVectorOffset, origElements);
+      }
+    } else {
+      auto resultVector = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          linearizedInfo.intraDataOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -659,10 +693,9 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    auto maxIntraVectorOffset =
-        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
+    auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     auto numElements =
-        llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 0cecaddc5733e2..6a10a2f9ed32fe 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -183,3 +183,54 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
 // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+// -----
+
+func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  // %idx is dynamic, so the i2 offset within a byte is only known at runtime.
+  %c2 = arith.constant 2 : index
+  %mask = vector.constant_mask [3] : vector<3xi1>
+  %1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
+    memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
+// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
+// CHECK: %[[LINEAR1:.+]] = affine.apply #[[MAP]]()[%[[IDX]]]
+// CHECK: %[[LINEAR2:.+]] = affine.apply #[[MAP1]]()[%[[IDX]]]
+// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
+// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index
+// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2>
+// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
+// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
+// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
+// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
+// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
+// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
+// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
+// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
+// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2>
+// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
+// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
+// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>

``````````

</details>


https://github.com/llvm/llvm-project/pull/115070


More information about the Mlir-commits mailing list