[Mlir-commits] [mlir] [MLIR] support dynamic indexing in `VectorEmulateNarrowTypes` (PR #114169)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 4 12:01:29 PST 2024


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/114169

>From 2effa6ab20f5a7d15fdc3faf6710fd970e6e8219 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Tue, 29 Oct 2024 14:03:23 +0000
Subject: [PATCH 1/8] Implement VectorLoadOp

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 116 ++++++++++++++----
 ...emulate-narrow-type-unaligned-dynamic.mlir |  53 ++++++++
 2 files changed, 143 insertions(+), 26 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1d6f8a991d9b5b..09bc7256fc3cf0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -149,6 +150,61 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
                                                        dest, offsets, strides);
 }
 
+static void dynamicallyExtractElementsToVector(
+    RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
+    Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
+  /*
+  // Create affine maps for the lower and upper bounds
+  AffineMap lowerBoundMap = AffineMap::getConstantMap(0, rewriter.getContext());
+  AffineMap upperBoundMap =
+      AffineMap::getConstantMap(loopSize, rewriter.getContext());
+
+  auto forLoop = rewriter.create<affine::AffineForOp>(
+      loc, ValueRange{}, lowerBoundMap, ValueRange{}, upperBoundMap, 1,
+      ArrayRef<Value>(destVec));
+
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(forLoop.getBody(), rewriter.getListener());
+
+  auto iv = forLoop.getInductionVar();
+
+  auto loopDestVec = forLoop.getRegionIterArgs()[0];
+  auto extractLoc = builder.create<arith::AddIOp>(
+      loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(), iv);
+  auto extractElemOp = builder.create<vector::ExtractElementOp>(
+      loc, elemType, srcVec, extractLoc);
+  auto insertElemOp = builder.create<vector::InsertElementOp>(
+      loc, extractElemOp, loopDestVec, iv);
+  builder.create<affine::AffineYieldOp>(loc,
+                                        ValueRange{insertElemOp->getResult(0)});
+  return forLoop->getResult(0);
+  */
+  for (int i = 0; i < loopSize; ++i) {
+    Value extractLoc;
+    if (i == 0) {
+      extractLoc = srcOffsetVar.dyn_cast<Value>();
+    } else {
+      extractLoc = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(),
+          rewriter.create<arith::ConstantIndexOp>(loc, i));
+    }
+    auto extractOp =
+        rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
+    rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
+  }
+}
+
+static TypedValue<VectorType>
+emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
+                   Value base, OpFoldResult linearizedIndices, int64_t numBytes,
+                   int64_t scale, Type oldElememtType, Type newElementType) {
+  auto newLoad = rewriter.create<vector::LoadOp>(
+      loc, VectorType::get(numBytes, newElementType), base,
+      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+  return rewriter.create<vector::BitCastOp>(
+      loc, VectorType::get(numBytes * scale, oldElememtType), newLoad);
+};
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -380,26 +436,29 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic intra vector offset
-      return failure();
-    }
-
+    // Always load enough elements to cover the original elements
+    auto maxintraDataOffset =
+        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
-    auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
-
-    Value result = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newLoad);
+        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    Value result =
+        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
+                           numElements, scale, oldElementType, newElementType);
 
-    if (isUnalignedEmulation) {
-      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                      *foldedIntraVectorOffset, origElements);
+      }
+      rewriter.replaceOp(op, result);
+    } else {
+      auto resultVector = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      dynamicallyExtractElementsToVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          linearizedInfo.intraDataOffset, origElements);
+      rewriter.replaceOp(op, resultVector);
     }
-
-    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -604,13 +663,10 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic inra-vector offset
-      return failure();
-    }
-
+    auto maxIntraVectorOffset =
+        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -621,9 +677,17 @@ struct ConvertVectorTransferRead final
         loc, VectorType::get(numElements * scale, oldElementType), newRead);
 
     Value result = bitCast->getResult(0);
-    if (isUnalignedEmulation) {
-      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                      *foldedIntraVectorOffset, origElements);
+      }
+    } else {
+      result = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      dynamicallyExtractElementsToVector(rewriter, loc, bitCast, result,
+                                         linearizedInfo.intraDataOffset,
+                                         origElements);
     }
     rewriter.replaceOp(op, result);
 
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
new file mode 100644
index 00000000000000..a92e62538c5332
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+
+// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
+// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant dense<0> : vector<3x3xi2>
+    %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
+    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+    return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2
+// CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8>
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+
+//-----
+
+func.func @vector_transfer_read_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0i2 = arith.constant 0 : i2
+ %1 = vector.transfer_read %0[%arg1, %arg2], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>

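For context, the `emulatedVectorLoad` helper added here always emits a load of the wider container type followed by a bitcast back to the narrow type; the dynamic-offset handling then picks the requested elements out of that wider vector. A hand-written MLIR sketch of the load-and-bitcast part for the i2-in-i8 test case above (illustrative SSA names, not actual pass output):

    %byte_idx = affine.apply #map()[%arg0, %arg1]   // linearized container index
    %bit_off  = affine.apply #map1()[%arg0, %arg1]  // intra-container element offset
    %wide     = vector.load %alloc[%byte_idx] : memref<3xi8>, vector<2xi8>
    %narrow   = vector.bitcast %wide : vector<2xi8> to vector<8xi2>
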
>From 2580b469e4a1e3fcb62a0eafca9e0b128d93ca4f Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Wed, 30 Oct 2024 17:28:39 +0000
Subject: [PATCH 2/8] Update tests

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 52 +++++--------
 ...emulate-narrow-type-unaligned-dynamic.mlir |  5 +-
 .../vector-emulate-narrow-type-unaligned.mlir | 73 ++++++++++++++++++-
 3 files changed, 93 insertions(+), 37 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 09bc7256fc3cf0..dbd9b2caccd3fe 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -130,6 +130,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
+/// A wrapper function for emitting `vector.extract_strided_slice`.
 static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
                                   VectorType extractType, Value vector,
                                   int64_t frontOffset, int64_t subvecSize) {
@@ -142,6 +143,7 @@ static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
       ->getResult(0);
 }
 
+/// A wrapper function for emitting `vector.insert_strided_slice`.
 static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
                                  Value src, Value dest, int64_t offset) {
   auto offsets = rewriter.getI64ArrayAttr({offset});
@@ -150,36 +152,14 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
                                                        dest, offsets, strides);
 }
 
+/// Extracts `lengthSubvec` elements from `srcVec` into `destVec` starting at
+/// the offset specified by `srcOffsetVar`. Use this function when
+/// `srcOffsetVar` is not a constant, making it impossible to use
+/// vector.extract_strided_slice, as it requires constant offsets.
 static void dynamicallyExtractElementsToVector(
     RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
-    Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
-  /*
-  // Create affine maps for the lower and upper bounds
-  AffineMap lowerBoundMap = AffineMap::getConstantMap(0, rewriter.getContext());
-  AffineMap upperBoundMap =
-      AffineMap::getConstantMap(loopSize, rewriter.getContext());
-
-  auto forLoop = rewriter.create<affine::AffineForOp>(
-      loc, ValueRange{}, lowerBoundMap, ValueRange{}, upperBoundMap, 1,
-      ArrayRef<Value>(destVec));
-
-  OpBuilder builder =
-      OpBuilder::atBlockEnd(forLoop.getBody(), rewriter.getListener());
-
-  auto iv = forLoop.getInductionVar();
-
-  auto loopDestVec = forLoop.getRegionIterArgs()[0];
-  auto extractLoc = builder.create<arith::AddIOp>(
-      loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(), iv);
-  auto extractElemOp = builder.create<vector::ExtractElementOp>(
-      loc, elemType, srcVec, extractLoc);
-  auto insertElemOp = builder.create<vector::InsertElementOp>(
-      loc, extractElemOp, loopDestVec, iv);
-  builder.create<affine::AffineYieldOp>(loc,
-                                        ValueRange{insertElemOp->getResult(0)});
-  return forLoop->getResult(0);
-  */
-  for (int i = 0; i < loopSize; ++i) {
+    Value destVec, OpFoldResult srcOffsetVar, int64_t lengthSubvec) {
+  for (int i = 0; i < lengthSubvec; ++i) {
     Value extractLoc;
     if (i == 0) {
       extractLoc = srcOffsetVar.dyn_cast<Value>();
@@ -194,15 +174,21 @@ static void dynamicallyExtractElementsToVector(
   }
 }
 
+/// Load `numLoadedElements` of `newElementType` from `base` at
+/// `linearizedIndices`, then bitcast the result into a vector of
+/// `oldElementType`.
 static TypedValue<VectorType>
 emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
-                   Value base, OpFoldResult linearizedIndices, int64_t numBytes,
-                   int64_t scale, Type oldElememtType, Type newElementType) {
+                   Value base, OpFoldResult linearizedIndices,
+                   int64_t numLoadedElements, Type oldElememtType,
+                   Type newElementType) {
+  auto scale = newElementType.getIntOrFloatBitWidth() /
+               oldElememtType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
-      loc, VectorType::get(numBytes, newElementType), base,
+      loc, VectorType::get(numLoadedElements, newElementType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numBytes * scale, oldElememtType), newLoad);
+      loc, VectorType::get(numLoadedElements * scale, oldElememtType), newLoad);
 };
 
 namespace {
@@ -443,7 +429,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         llvm::divideCeil(maxintraDataOffset + origElements, scale);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
-                           numElements, scale, oldElementType, newElementType);
+                           numElements, oldElementType, newElementType);
 
     if (foldedIntraVectorOffset) {
       if (isUnalignedEmulation) {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
index a92e62538c5332..2e7ec43df31d10 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
@@ -2,14 +2,13 @@
 
 // CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
 // CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
-func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
     %0 = memref.alloc() : memref<3x3xi2>
     %c0 = arith.constant 0 : index
     %c2 = arith.constant 2 : index
     %cst = arith.constant dense<0> : vector<3x3xi2>
     %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
-    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
-    return %2 : vector<3x3xi2>
+    return %1 : vector<3xi2>
 }
 
 // CHECK: func @vector_load_i2
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ecbad7968225d..6cfe623c8af42c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -19,6 +19,25 @@ func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
 
 //-----
 
+func.func @vector_load_i2_unaligned(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %cst = arith.constant dense<0> : vector<3x3xi2>
+    %1 = vector.load %0[%c0, %c1] : memref<3x3xi2>, vector<3xi2>
+    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+    return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2_unaligned
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = arith.constant 0 : index
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<1xi8>
+// CHECK: %[[VEC_I2:.+]] = vector.bitcast %[[VEC]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[EXCTRACT:.+]] = vector.extract_strided_slice %[[VEC_I2]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+
+//-----
+
 func.func @vector_transfer_read_i2() -> vector<3xi2> {
  %0 = memref.alloc() : memref<3x3xi2>
  %c0i2 = arith.constant 0 : i2
@@ -37,6 +56,26 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
 
 //-----
 
+func.func @vector_transfer_read_i2_unaligned() -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0i2 = arith.constant 0 : i2
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = vector.transfer_read %0[%c0, %c1], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2_unaligned
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[PAD:.+]] = arith.constant 0 : i2
+// CHECK: %[[EXT:.+]] = arith.extui %[[PAD]] : i2 to i8
+// CHECK: %[[INDEX:.+]] = arith.constant 0 : index
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[EXT]] : memref<3xi8>, vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<1xi8> to vector<4xi2>
+// CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+
+//-----
+
 func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
     %0 = memref.alloc() : memref<3x5xi2>
     %cst = arith.constant dense<0> : vector<3x5xi2>
@@ -64,4 +103,36 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi1> into vector<8xi1>
 // CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
-// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2> 
+// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+
+//-----
+
+func.func @vector_cst_maskedload_i2_unaligned(%passthru: vector<5xi2>) -> vector<3x5xi2> {
+    %0 = memref.alloc() : memref<3x5xi2>
+    %cst = arith.constant dense<0> : vector<3x5xi2>
+    %mask = vector.constant_mask [3] : vector<5xi1>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %1 = vector.maskedload %0[%c0, %c1], %mask, %passthru :
+      memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+    %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
+    return %2 : vector<3x5xi2>
+}
+
+
+// CHECK: func @vector_cst_maskedload_i2_unaligned
+// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
+// CHECK: %[[NEWMASK:.+]] = arith.constant dense<[true, false]> : vector<2xi1>
+// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %arg0, %[[VESSEL]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
+// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %alloc[%[[C0]]], %[[NEWMASK:.+]], %[[BITCAST1]]
+// CHECK-SAME: : memref<4xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
+// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>

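When the intra-byte offset folds to a constant instead (the `_unaligned` tests added above), the rewrite can keep using `vector.extract_strided_slice`; a sketch of that path, mirroring the CHECK lines for the constant-offset load (assumed names, not actual pass output):

    %c0     = arith.constant 0 : index
    %wide   = vector.load %alloc[%c0] : memref<3xi8>, vector<1xi8>
    %narrow = vector.bitcast %wide : vector<1xi8> to vector<4xi2>
    %res    = vector.extract_strided_slice %narrow
              {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
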
>From 0b9bdce70f0a4179488b130e91ac08c4483d4fd2 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Wed, 30 Oct 2024 22:07:03 +0000
Subject: [PATCH 3/8] fix bugs

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dbd9b2caccd3fe..55a3a191b2ccc4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -156,9 +156,11 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
 /// the offset specified by `srcOffsetVar`. Use this function when
 /// `srcOffsetVar` is not a constant, making it impossible to use
 /// vector.extract_strided_slice, as it requires constant offsets.
-static void dynamicallyExtractElementsToVector(
-    RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
-    Value destVec, OpFoldResult srcOffsetVar, int64_t lengthSubvec) {
+static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
+                                         TypedValue<VectorType> srcVec,
+                                         Value destVec,
+                                         OpFoldResult srcOffsetVar,
+                                         int64_t lengthSubvec) {
   for (int i = 0; i < lengthSubvec; ++i) {
     Value extractLoc;
     if (i == 0) {
@@ -170,8 +172,9 @@ static void dynamicallyExtractElementsToVector(
     }
     auto extractOp =
         rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
-    rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
+    destVec = rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
   }
+  return destVec;
 }
 
 /// Load `numLoadedElements` of `newElementType` from `base` at
@@ -436,15 +439,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
         result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
                                       *foldedIntraVectorOffset, origElements);
       }
-      rewriter.replaceOp(op, result);
     } else {
       auto resultVector = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
-      dynamicallyExtractElementsToVector(
+      result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-      rewriter.replaceOp(op, resultVector);
     }
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -669,11 +671,11 @@ struct ConvertVectorTransferRead final
                                       *foldedIntraVectorOffset, origElements);
       }
     } else {
-      result = rewriter.create<arith::ConstantOp>(
+      auto zeros = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
-      dynamicallyExtractElementsToVector(rewriter, loc, bitCast, result,
-                                         linearizedInfo.intraDataOffset,
-                                         origElements);
+      result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
+                                           linearizedInfo.intraDataOffset,
+                                           origElements);
     }
     rewriter.replaceOp(op, result);
 

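With this fix the partially filled destination vector is threaded through the chain of `vector.insert` ops, so the value produced by the last insert carries all extracted elements. For a `vector<3xi2>` result the dynamic-offset path now produces roughly the following (sketch; `%narrow` is the `vector<8xi2>` bitcast result and `%bit_off` the dynamic intra-byte offset):

    %zero = arith.constant dense<0> : vector<3xi2>
    %c1   = arith.constant 1 : index
    %c2   = arith.constant 2 : index
    %e0   = vector.extract %narrow[%bit_off] : i2 from vector<8xi2>
    %v0   = vector.insert %e0, %zero [0] : i2 into vector<3xi2>
    %p1   = arith.addi %bit_off, %c1 : index
    %e1   = vector.extract %narrow[%p1] : i2 from vector<8xi2>
    %v1   = vector.insert %e1, %v0 [1] : i2 into vector<3xi2>
    %p2   = arith.addi %bit_off, %c2 : index
    %e2   = vector.extract %narrow[%p2] : i2 from vector<8xi2>
    %v2   = vector.insert %e2, %v1 [2] : i2 into vector<3xi2>   // final result
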
>From a9d72602a9b50cb4322d05e5ed705647e9605cdf Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 1 Nov 2024 02:11:17 +0000
Subject: [PATCH 4/8] Refactor and fixes

---
 .../Transforms/VectorEmulateNarrowType.cpp    |  66 ++++++-----
 ...emulate-narrow-type-unaligned-dynamic.mlir |  52 --------
 .../vector-emulate-narrow-type-unaligned.mlir | 111 ++++++++----------
 3 files changed, 84 insertions(+), 145 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 55a3a191b2ccc4..4fa001a95f8cab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -38,16 +38,17 @@ using namespace mlir;
 
 /// Returns a compressed mask. The mask value is set only if any mask is present
 /// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
-/// equals to 2, the following mask:
+/// equals to 1 (intraDataOffset strictly smaller than scale), the following
+/// mask:
 ///
-///   %mask = [1, 1, 1, 0, 0, 0]
+///   %mask = [1, 1, 0, 0, 0, 0]
 ///
 /// will first be padded with number of `intraDataOffset` zeros:
-///   %mask = [0, 0, 1, 1, 1, 0, 0, 0]
+///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
 ///
-///   %mask = [0, 1, 1, 0]
+///   %mask = [1, 1, 0, 0]
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int origElements, int scale,
@@ -76,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   shape.back() = numElements;
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
   if (createMaskOp) {
-    // TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
-    if (intraDataOffset != 0)
-      return failure();
     OperandRange maskOperands = createMaskOp.getOperands();
     size_t numMaskOperands = maskOperands.size();
     AffineExpr s0;
@@ -130,10 +128,18 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
-/// A wrapper function for emitting `vector.extract_strided_slice`.
+/// A wrapper function for emitting `vector.extract_strided_slice`. The vector
+/// has to be of 1-D shape.
 static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
                                   VectorType extractType, Value vector,
                                   int64_t frontOffset, int64_t subvecSize) {
+  // get vector's vector type:
+  auto vectorType = dyn_cast<VectorType>(vector.getType());
+  assert(vectorType && "expected vector type");
+  assert(vectorType.getShape().size() == 1 && "expected 1-D vector type");
+  assert(extractType.getShape().size() == 1 &&
+         "extractType must be 1-D vector type");
+
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
@@ -143,9 +149,17 @@ static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
       ->getResult(0);
 }
 
-/// A wrapper function for emitting `vector.insert_strided_slice`.
+/// A wrapper function for emitting `vector.insert_strided_slice`. The source
+/// and dest vectors must be of 1-D shape.
 static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
                                  Value src, Value dest, int64_t offset) {
+  auto srcType = dyn_cast<VectorType>(src.getType());
+  assert(srcType && "expected vector type");
+  assert(srcType.getShape().size() == 1 && "expected 1-D vector type");
+  auto destType = dyn_cast<VectorType>(dest.getType());
+  assert(destType && "expected vector type");
+  assert(destType.getShape().size() == 1 && "expected 1-D vector type");
+
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -157,24 +171,20 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
 /// `srcOffsetVar` is not a constant, making it impossible to use
 /// vector.extract_strided_slice, as it requires constant offsets.
 static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
-                                         TypedValue<VectorType> srcVec,
-                                         Value destVec,
-                                         OpFoldResult srcOffsetVar,
-                                         int64_t lengthSubvec) {
-  for (int i = 0; i < lengthSubvec; ++i) {
-    Value extractLoc;
-    if (i == 0) {
-      extractLoc = srcOffsetVar.dyn_cast<Value>();
-    } else {
-      extractLoc = rewriter.create<arith::AddIOp>(
-          loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(),
-          rewriter.create<arith::ConstantIndexOp>(loc, i));
-    }
+                                         TypedValue<VectorType> source,
+                                         Value dest, OpFoldResult offset,
+                                         int64_t numElementsToExtract) {
+  for (int i = 0; i < numElementsToExtract; ++i) {
+    Value extractLoc =
+        (i == 0) ? offset.dyn_cast<Value>()
+                 : rewriter.create<arith::AddIOp>(
+                       loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
+                       rewriter.create<arith::ConstantIndexOp>(loc, i));
     auto extractOp =
-        rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
-    destVec = rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
+        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
+    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
   }
-  return destVec;
+  return dest;
 }
 
 /// Load `numLoadedElements` of `newElementType` from `base` at
@@ -183,15 +193,15 @@ static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
 static TypedValue<VectorType>
 emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
                    Value base, OpFoldResult linearizedIndices,
-                   int64_t numLoadedElements, Type oldElememtType,
+                   int64_t numElementsToLoad, Type oldElememtType,
                    Type newElementType) {
   auto scale = newElementType.getIntOrFloatBitWidth() /
                oldElememtType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
-      loc, VectorType::get(numLoadedElements, newElementType), base,
+      loc, VectorType::get(numElementsToLoad, newElementType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numLoadedElements * scale, oldElememtType), newLoad);
+      loc, VectorType::get(numElementsToLoad * scale, oldElememtType), newLoad);
 };
 
 namespace {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
deleted file mode 100644
index 2e7ec43df31d10..00000000000000
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
-
-// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
-// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
-func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
-    %0 = memref.alloc() : memref<3x3xi2>
-    %c0 = arith.constant 0 : index
-    %c2 = arith.constant 2 : index
-    %cst = arith.constant dense<0> : vector<3x3xi2>
-    %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
-    return %1 : vector<3xi2>
-}
-
-// CHECK: func @vector_load_i2
-// CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8>
-// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
-// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
-// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
-// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
-// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
-// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index
-// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
-
-//-----
-
-func.func @vector_transfer_read_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
- %0 = memref.alloc() : memref<3x3xi2>
- %c0i2 = arith.constant 0 : i2
- %1 = vector.transfer_read %0[%arg1, %arg2], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
- return %1 : vector<3xi2>
-}
-
-// CHECK: func @vector_transfer_read_i2
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
-// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
-// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
-// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
-// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
-// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
-// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
-// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 6cfe623c8af42c..ae6db8a0198aa7 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
 
+// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
+// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
+
 func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
     %0 = memref.alloc() : memref<3x3xi2>
     %c0 = arith.constant 0 : index
@@ -19,25 +22,6 @@ func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
 
 //-----
 
-func.func @vector_load_i2_unaligned(%arg1: index, %arg2: index) -> vector<3x3xi2> {
-    %0 = memref.alloc() : memref<3x3xi2>
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %cst = arith.constant dense<0> : vector<3x3xi2>
-    %1 = vector.load %0[%c0, %c1] : memref<3x3xi2>, vector<3xi2>
-    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
-    return %2 : vector<3x3xi2>
-}
-
-// CHECK: func @vector_load_i2_unaligned
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
-// CHECK: %[[INDEX:.+]] = arith.constant 0 : index
-// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<1xi8>
-// CHECK: %[[VEC_I2:.+]] = vector.bitcast %[[VEC]] : vector<1xi8> to vector<4xi2>
-// CHECK: %[[EXCTRACT:.+]] = vector.extract_strided_slice %[[VEC_I2]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
-
-//-----
-
 func.func @vector_transfer_read_i2() -> vector<3xi2> {
  %0 = memref.alloc() : memref<3x3xi2>
  %c0i2 = arith.constant 0 : i2
@@ -56,26 +40,6 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
 
 //-----
 
-func.func @vector_transfer_read_i2_unaligned() -> vector<3xi2> {
- %0 = memref.alloc() : memref<3x3xi2>
- %c0i2 = arith.constant 0 : i2
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %1 = vector.transfer_read %0[%c0, %c1], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
- return %1 : vector<3xi2>
-}
-
-// CHECK: func @vector_transfer_read_i2_unaligned
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
-// CHECK: %[[PAD:.+]] = arith.constant 0 : i2
-// CHECK: %[[EXT:.+]] = arith.extui %[[PAD]] : i2 to i8
-// CHECK: %[[INDEX:.+]] = arith.constant 0 : index
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[EXT]] : memref<3xi8>, vector<1xi8>
-// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<1xi8> to vector<4xi2>
-// CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
-
-//-----
-
 func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
     %0 = memref.alloc() : memref<3x5xi2>
     %cst = arith.constant dense<0> : vector<3x5xi2>
@@ -107,32 +71,49 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 
 //-----
 
-func.func @vector_cst_maskedload_i2_unaligned(%passthru: vector<5xi2>) -> vector<3x5xi2> {
-    %0 = memref.alloc() : memref<3x5xi2>
-    %cst = arith.constant dense<0> : vector<3x5xi2>
-    %mask = vector.constant_mask [3] : vector<5xi1>
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %1 = vector.maskedload %0[%c0, %c1], %mask, %passthru :
-      memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
-    %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
-    return %2 : vector<3x5xi2>
+func.func @vector_load_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
 }
 
+// CHECK: func @vector_load_i2_dynamic_indexing
+// CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8>
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
 
-// CHECK: func @vector_cst_maskedload_i2_unaligned
-// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
-// CHECK: %[[NEWMASK:.+]] = arith.constant dense<[true, false]> : vector<2xi1>
-// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
-// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %arg0, %[[VESSEL]]
-// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
-// CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8>
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %alloc[%[[C0]]], %[[NEWMASK:.+]], %[[BITCAST1]]
-// CHECK-SAME: : memref<4xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
-// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
-// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
-// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
-// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
-// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
-// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+//-----
+
+func.func @vector_transfer_read_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %pad = arith.constant 0 : i2
+  %1 = vector.transfer_read %0[%arg1, %arg2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2_dynamic_indexing
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>

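The `vector.transfer_read` path is handled analogously; the only extra step is widening the padding value to the container type before the wide read. A sketch mirroring the dynamic-indexing transfer_read CHECK lines above (assumed names, not actual pass output):

    %pad      = arith.extui %c0_i2 : i2 to i8
    %byte_idx = affine.apply #map()[%arg0, %arg1]
    %bit_off  = affine.apply #map1()[%arg0, %arg1]
    %wide     = vector.transfer_read %alloc[%byte_idx], %pad : memref<3xi8>, vector<2xi8>
    %narrow   = vector.bitcast %wide : vector<2xi8> to vector<8xi2>
    // ...followed by the same per-element extract/insert chain as for vector.load
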
>From b777a607bec5999a5c8355fc0597237bca819ee5 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 1 Nov 2024 03:05:52 +0000
Subject: [PATCH 5/8] updates

---
 .../Transforms/VectorEmulateNarrowType.cpp    |  1 -
 .../vector-emulate-narrow-type-unaligned.mlir | 58 ++++++++++---------
 2 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 4fa001a95f8cab..0815f51487a716 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -133,7 +133,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
                                   VectorType extractType, Value vector,
                                   int64_t frontOffset, int64_t subvecSize) {
-  // get vector's vector type:
   auto vectorType = dyn_cast<VectorType>(vector.getType());
   assert(vectorType && "expected vector type");
   assert(vectorType.getShape().size() == 1 && "expected 1-D vector type");
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index ae6db8a0198aa7..6d9d5da707ae81 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -1,16 +1,20 @@
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
 
+// TODO: remove memref.alloc() from the tests to reduce noise.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
 // CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
 // CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
 
-func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
-    %0 = memref.alloc() : memref<3x3xi2>
-    %c0 = arith.constant 0 : index
-    %c2 = arith.constant 2 : index
-    %cst = arith.constant dense<0> : vector<3x3xi2>
-    %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
-    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
-    return %2 : vector<3x3xi2>
+func.func @vector_load_i2() -> vector<3x3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+  %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+  return %2 : vector<3x3xi2>
 }
 
 // CHECK: func @vector_load_i2
@@ -23,12 +27,12 @@ func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
 //-----
 
 func.func @vector_transfer_read_i2() -> vector<3xi2> {
- %0 = memref.alloc() : memref<3x3xi2>
- %c0i2 = arith.constant 0 : i2
- %c0 = arith.constant 0 : index
- %c2 = arith.constant 2 : index
- %1 = vector.transfer_read %0[%c2, %c0], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
- return %1 : vector<3xi2>
+  %0 = memref.alloc() : memref<3x3xi2>
+  %pad = arith.constant 0 : i2
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.transfer_read %0[%c2, %c0], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
 }
 
 // CHECK: func @vector_transfer_read_i2
@@ -41,15 +45,15 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
 //-----
 
 func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
-    %0 = memref.alloc() : memref<3x5xi2>
-    %cst = arith.constant dense<0> : vector<3x5xi2>
-    %mask = vector.constant_mask [3] : vector<5xi1>
-    %c0 = arith.constant 0 : index
-    %c2 = arith.constant 2 : index
-    %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
-      memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
-    %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
-    return %2 : vector<3x5xi2>
+  %0 = memref.alloc() : memref<3x5xi2>
+  %cst = arith.constant dense<0> : vector<3x5xi2>
+  %mask = vector.constant_mask [3] : vector<5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
+  return %2 : vector<3x5xi2>
 }
 
 // CHECK: func @vector_cst_maskedload_i2
@@ -71,10 +75,10 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 
 //-----
 
-func.func @vector_load_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector<3xi2> {
+func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %cst = arith.constant dense<0> : vector<3x3xi2>
-  %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
+  %1 = vector.load %0[%idx1, %idx2] : memref<3x3xi2>, vector<3xi2>
   return %1 : vector<3xi2>
 }
 
@@ -95,10 +99,10 @@ func.func @vector_load_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector
 
 //-----
 
-func.func @vector_transfer_read_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector<3xi2> {
+func.func @vector_transfer_read_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
   %pad = arith.constant 0 : i2
-  %1 = vector.transfer_read %0[%arg1, %arg2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  %1 = vector.transfer_read %0[%idx1, %idx2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
   return %1 : vector<3xi2>
 }
 

>From fc292425d133f6009c65871f20481a93d8b6f7df Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Mon, 4 Nov 2024 02:03:21 +0000
Subject: [PATCH 6/8] update according to comments

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 61 +++++++++----------
 .../vector-emulate-narrow-type-unaligned.mlir | 19 +++---
 2 files changed, 41 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0815f51487a716..de44bb6299dcc1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -128,47 +128,44 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
-/// A wrapper function for emitting `vector.extract_strided_slice`. The vector
-/// has to be of 1-D shape.
+/// Extracts a 1-D subvector from a 1-D vector. It is a wrapper function for
+/// emitting `vector.extract_strided_slice`.
 static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
-                                  VectorType extractType, Value vector,
+                                  VectorType extractType, Value source,
                                   int64_t frontOffset, int64_t subvecSize) {
-  auto vectorType = dyn_cast<VectorType>(vector.getType());
-  assert(vectorType && "expected vector type");
-  assert(vectorType.getShape().size() == 1 && "expected 1-D vector type");
-  assert(extractType.getShape().size() == 1 &&
-         "extractType must be 1-D vector type");
-
+  auto vectorType = dyn_cast<VectorType>(source.getType());
+  assert(
+      (vectorType && vectorType.getRank() == 1 && extractType.getRank() == 1) &&
+      "expected 1-D source and destination types");
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, vector, offsets,
+      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
                                              sizes, strides)
       ->getResult(0);
 }
 
-/// A wrapper function for emitting `vector.insert_strided_slice`. The source
-/// and dest vectors must be of 1-D shape.
+/// Inserts a 1-D subvector into a 1-D vector by overwriting the elements starting
+/// at `offset`. It is a wrapper function for emitting
+/// `vector.insert_strided_slice`.
 static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
                                  Value src, Value dest, int64_t offset) {
   auto srcType = dyn_cast<VectorType>(src.getType());
-  assert(srcType && "expected vector type");
-  assert(srcType.getShape().size() == 1 && "expected 1-D vector type");
   auto destType = dyn_cast<VectorType>(dest.getType());
-  assert(destType && "expected vector type");
-  assert(destType.getShape().size() == 1 && "expected 1-D vector type");
-
+  assert(srcType && srcType.getRank() == 1 && destType &&
+         destType.getRank() == 1 &&
+         "expected 1-D source and dest vector types");
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
                                                        dest, offsets, strides);
 }
 
-/// Extracts `lengthSubvec` elements from `srcVec` into `destVec` starting at
-/// the offset specified by `srcOffsetVar`. Use this function when
-/// `srcOffsetVar` is not a constant, making it impossible to use
-/// vector.extract_strided_slice, as it requires constant offsets.
+/// Extracts a 1-D subvector from a 1-D `source` vector, starting at `offset`
+/// with size `numElementsToExtract`, and inserts it into the `dest` vector. This
+/// function emits multiple `vector.extract` and `vector.insert` ops, so only
+/// use it when `offset` cannot be folded into a constant value.
 static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
                                          TypedValue<VectorType> source,
                                          Value dest, OpFoldResult offset,
@@ -186,21 +183,23 @@ static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
   return dest;
 }
 
-/// Load `numLoadedElements` of `newElementType` from `base` at
-/// `linearizedIndices`, then bitcast the result into a vector of
-/// `oldElementType`.
+/// Returns the op sequence for an emulated sub-byte datatype vector load.
+/// Specifically, use `emulatedElemType` for loading a vector of `origElemType`.
+/// The load location is given by `base` and `linearizedIndices`, and the
+/// load size is given by `numEmulatedElementsToLoad`.
 static TypedValue<VectorType>
 emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
                    Value base, OpFoldResult linearizedIndices,
-                   int64_t numElementsToLoad, Type oldElememtType,
-                   Type newElementType) {
-  auto scale = newElementType.getIntOrFloatBitWidth() /
-               oldElememtType.getIntOrFloatBitWidth();
+                   int64_t numEmulatedElementsToLoad, Type origElemType,
+                   Type emulatedElemType) {
+  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
+               origElemType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
-      loc, VectorType::get(numElementsToLoad, newElementType), base,
+      loc, VectorType::get(numEmulatedElementsToLoad, emulatedElemType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numElementsToLoad * scale, oldElememtType), newLoad);
+      loc, VectorType::get(numEmulatedElementsToLoad * scale, origElemType),
+      newLoad);
 };
 
 namespace {
@@ -435,7 +434,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             : 0;
 
     // Always load enough elements to cover the original elements
-    auto maxintraDataOffset =
+    int64_t maxintraDataOffset =
         foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
     auto numElements =
         llvm::divideCeil(maxintraDataOffset + origElements, scale);
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 6d9d5da707ae81..4800fe9fd83bcd 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -17,7 +17,7 @@ func.func @vector_load_i2() -> vector<3x3xi2> {
   return %2 : vector<3x3xi2>
 }
 
-// CHECK: func @vector_load_i2
+// CHECK-LABEL: func @vector_load_i2
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
 // CHECK: %[[INDEX:.+]] = arith.constant 1 : index
 // CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
@@ -35,7 +35,7 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
   return %1 : vector<3xi2>
 }
 
-// CHECK: func @vector_transfer_read_i2
+// CHECK-LABEL: func @vector_transfer_read_i2
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
 // CHECK: %[[INDEX:.+]] = arith.constant 1 : index
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %0 : memref<3xi8>, vector<2xi8>
@@ -56,11 +56,12 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
   return %2 : vector<3x5xi2>
 }
 
-// CHECK: func @vector_cst_maskedload_i2
+// CHECK-LABEL: func @vector_cst_maskedload_i2(
+// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
 // CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
-// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %arg0, %[[VESSEL]]
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
 // CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
@@ -82,7 +83,8 @@ func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector
   return %1 : vector<3xi2>
 }
 
-// CHECK: func @vector_load_i2_dynamic_indexing
+// CHECK-LABEL: func @vector_load_i2_dynamic_indexing(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index) -> vector<3xi2>
 // CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8>
 // CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
 // CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
@@ -106,11 +108,12 @@ func.func @vector_transfer_read_i2_dynamic_indexing(%idx1: index, %idx2: index)
   return %1 : vector<3xi2>
 }
 
-// CHECK: func @vector_transfer_read_i2_dynamic_indexing
+// CHECK-LABEL: func @vector_transfer_read_i2_dynamic_indexing(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index) -> vector<3xi2>
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
 // CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
-// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
-// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%[[ARG0]], %[[ARG1]]]
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
 // CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>

>From 2b86a23ea401658adbc6433f7419df4a6593e2bc Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Mon, 4 Nov 2024 15:29:12 +0000
Subject: [PATCH 7/8] another update to resolve comments

---
 .../Transforms/VectorEmulateNarrowType.cpp    |  8 +--
 .../vector-emulate-narrow-type-unaligned.mlir | 59 ++++++++++++++++++-
 2 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index de44bb6299dcc1..fe84a204009d24 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -134,9 +134,9 @@ static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
                                   VectorType extractType, Value source,
                                   int64_t frontOffset, int64_t subvecSize) {
   auto vectorType = dyn_cast<VectorType>(source.getType());
-  assert(
-      (vectorType && vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-      "expected 1-D source and destination types");
+  assert(vectorType && vectorType.getRank() == 1 &&
+         extractType.getRank() == 1 &&
+         "expected 1-D source and destination types");
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
@@ -183,7 +183,7 @@ static Value dynamicallyExtractSubVector(RewriterBase &rewriter, Location loc,
   return dest;
 }
 
-/// Returns the op sequence for an emulated sub-byte datatype vector load.
+/// Returns the op sequence for an emulated sub-byte data type vector load.
 /// Specifically, use `emulatedElemType` for loading a vector of `origElemType`.
 /// The load location is given by `base` and `linearizedIndices`, and the
 /// load size is given by `numEmulatedElementsToLoad`.
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 4800fe9fd83bcd..5a9d6e3d9e4a52 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -6,6 +6,8 @@
 
 // CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
 // CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
+// CHECK: #map2 = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
+// CHECK: #map3 = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
 
 func.func @vector_load_i2() -> vector<3x3xi2> {
   %0 = memref.alloc() : memref<3x3xi2>
@@ -86,8 +88,34 @@ func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector
 // CHECK-LABEL: func @vector_load_i2_dynamic_indexing(
 // CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index) -> vector<3xi2>
 // CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8>
-// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
-// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OFFSET2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+
+// -----
+
+func.func @vector_load_i2_dynamic_indexing_mixed(%idx: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c2 = arith.constant 2 : index
+  %cst = arith.constant dense<1> : vector<3x3xi2>
+  %1 = vector.load %0[%idx, %c2] : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK-LABEL: func @vector_load_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[ARG0:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map2()[%[[ARG0]]]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map3()[%[[ARG0]]]
 // CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
 // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
@@ -124,3 +152,30 @@ func.func @vector_transfer_read_i2_dynamic_indexing(%idx1: index, %idx2: index)
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
 // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+
+// -----
+
+func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c2 = arith.constant 2 : index
+  %pad = arith.constant 0 : i2
+  %1 = vector.transfer_read %0[%idx1, %c2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK-LABEL: func @vector_transfer_read_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[ARG0:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map2()[%[[ARG0]]]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map3()[%[[ARG0]]]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
\ No newline at end of file

>From 8225f7221bdb4f7bb8ca7c038638a035227cc4db Mon Sep 17 00:00:00 2001
From: lialan <alan.li at me.com>
Date: Mon, 4 Nov 2024 15:01:18 -0500
Subject: [PATCH 8/8] Update
 mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Co-authored-by: Han-Chung Wang <hanhan0912 at gmail.com>
---
 .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp   | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index fe84a204009d24..6232b0aa978f55 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -133,9 +133,9 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
                                   VectorType extractType, Value source,
                                   int64_t frontOffset, int64_t subvecSize) {
-  auto vectorType = dyn_cast<VectorType>(source.getType());
-  assert(vectorType && vectorType.getRank() == 1 &&
-         extractType.getRank() == 1 &&
+  auto vectorType = cast<VectorType>(source.getType());
+  assert((vectorType.getRank() == 1 &&
+         extractType.getRank() == 1) &&
          "expected 1-D source and destination types");
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});



More information about the Mlir-commits mailing list