[Mlir-commits] [mlir] ce112a7 - [MLIR] support dynamic indexing in `VectorEmulateNarrowTypes` (#114169)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 5 09:21:07 PST 2024


Author: lialan
Date: 2024-11-05T09:21:03-08:00
New Revision: ce112a7f44ca0776d1192f6183a33e0c9f69df53

URL: https://github.com/llvm/llvm-project/commit/ce112a7f44ca0776d1192f6183a33e0c9f69df53
DIFF: https://github.com/llvm/llvm-project/commit/ce112a7f44ca0776d1192f6183a33e0c9f69df53.diff

LOG: [MLIR] support dynamic indexing in `VectorEmulateNarrowTypes` (#114169)

* Supports `vector.load` and `vector.transfer_read` ops.
* In the case of dynamic indexing, use per-element insertion/extraction
to build the desired narrow-type vectors (see the sketch below).
* Fixed the incorrect function comment of `getCompressedMaskOp`.
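
A minimal hand-written sketch of the dynamic-offset lowering (the SSA names
%alloc, %lin_idx and %intra_off are placeholders; the exact IR is covered by
the new tests below). Loading a vector<3xi2> from an i8-emulated memref
becomes, roughly:

  %bytes = vector.load %alloc[%lin_idx] : memref<3xi8>, vector<2xi8>
  %bits  = vector.bitcast %bytes : vector<2xi8> to vector<8xi2>
  %zero  = arith.constant dense<0> : vector<3xi2>
  %e0    = vector.extract %bits[%intra_off] : i2 from vector<8xi2>
  %v0    = vector.insert %e0, %zero [0] : i2 into vector<3xi2>
  %c1    = arith.constant 1 : index
  %off1  = arith.addi %intra_off, %c1 : index
  %e1    = vector.extract %bits[%off1] : i2 from vector<8xi2>
  %v1    = vector.insert %e1, %v0 [1] : i2 into vector<3xi2>
  // ... and likewise for the remaining element.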

---------

Co-authored-by: Han-Chung Wang <hanhan0912 at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
    mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1d6f8a991d9b5b..f169dab3bdd9af 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -37,16 +38,17 @@ using namespace mlir;
 
 /// Returns a compressed mask. The mask value is set only if any mask is present
 /// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
-/// equals to 2, the following mask:
+/// equals to 1 (intraDataOffset strictly smaller than scale), the following
+/// mask:
 ///
-///   %mask = [1, 1, 1, 0, 0, 0]
+///   %mask = [1, 1, 0, 0, 0, 0]
 ///
 /// will first be padded with number of `intraDataOffset` zeros:
-///   %mask = [0, 0, 1, 1, 1, 0, 0, 0]
+///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]
 ///
 /// then it will return the following new compressed mask:
 ///
-///   %mask = [0, 1, 1, 0]
+///   %mask = [1, 1, 0, 0]
 static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int origElements, int scale,
@@ -75,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   shape.back() = numElements;
   auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
   if (createMaskOp) {
-    // TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
-    if (intraDataOffset != 0)
-      return failure();
     OperandRange maskOperands = createMaskOp.getOperands();
     size_t numMaskOperands = maskOperands.size();
     AffineExpr s0;
@@ -129,26 +128,79 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   return newMask;
 }
 
-static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
-                                  VectorType extractType, Value vector,
-                                  int64_t frontOffset, int64_t subvecSize) {
+/// Extracts a 1-D subvector from a 1-D vector. It is a wrapper function for
+/// emitting `vector.extract_strided_slice`.
+static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
+                                        VectorType extractType, Value source,
+                                        int64_t frontOffset,
+                                        int64_t subvecSize) {
+  auto vectorType = cast<VectorType>(source.getType());
+  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
+         "expected 1-D source and destination types");
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, vector, offsets,
+      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
                                              sizes, strides)
       ->getResult(0);
 }
 
-static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
-                                 Value src, Value dest, int64_t offset) {
+/// Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
+/// at `offset`. It is a wrapper function for emitting
+/// `vector.insert_strided_slice`.
+static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
+                                       Value src, Value dest, int64_t offset) {
+  auto srcType = cast<VectorType>(src.getType());
+  auto destType = cast<VectorType>(dest.getType());
+  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
+         "expected 1-D source and dest vectors");
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
                                                        dest, offsets, strides);
 }
 
+/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
+/// and size `numElementsToExtract`, and inserts into the `dest` vector. This
+/// function emits multiple `vector.extract` and `vector.insert` ops, so only
+/// use it when `offset` cannot be folded into a constant value.
+static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
+                                         TypedValue<VectorType> source,
+                                         Value dest, OpFoldResult offset,
+                                         int64_t numElementsToExtract) {
+  for (int i = 0; i < numElementsToExtract; ++i) {
+    Value extractLoc =
+        (i == 0) ? offset.dyn_cast<Value>()
+                 : rewriter.create<arith::AddIOp>(
+                       loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
+                       rewriter.create<arith::ConstantIndexOp>(loc, i));
+    auto extractOp =
+        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
+    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
+  }
+  return dest;
+}
+
+/// Returns the op sequence for an emulated sub-byte data type vector load.
+/// Specifically, use `emulatedElemType` for loading a vector of `origElemType`.
+/// The load location is given by `base` and `linearizedIndices`, and the
+/// load size is given by `numEmulatedElementsToLoad`.
+static TypedValue<VectorType>
+emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
+                   OpFoldResult linearizedIndices,
+                   int64_t numEmulatedElementsToLoad, Type origElemType,
+                   Type emulatedElemType) {
+  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
+               origElemType.getIntOrFloatBitWidth();
+  auto newLoad = rewriter.create<vector::LoadOp>(
+      loc, VectorType::get(numEmulatedElementsToLoad, emulatedElemType), base,
+      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+  return rewriter.create<vector::BitCastOp>(
+      loc, VectorType::get(numEmulatedElementsToLoad * scale, origElemType),
+      newLoad);
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -380,25 +432,27 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic intra vector offset
-      return failure();
-    }
-
+    // Always load enough elements to cover the original elements.
+    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
-    auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
-
-    Value result = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newLoad);
-
-    if (isUnalignedEmulation) {
-      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    Value result =
+        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
+                           numElements, oldElementType, newElementType);
+
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result =
+            staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                       *foldedIntraVectorOffset, origElements);
+      }
+    } else {
+      auto resultVector = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          linearizedInfo.intraDataOffset, origElements);
     }
-
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -513,8 +567,8 @@ struct ConvertVectorMaskedLoad final
       // create an empty vector of the new type
       auto emptyVector = rewriter.create<arith::ConstantOp>(
           loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
-      passthru = insertSubvectorInto(rewriter, loc, passthru, emptyVector,
-                                     *foldedIntraVectorOffset);
+      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
+                                           *foldedIntraVectorOffset);
     }
     auto newPassThru =
         rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -537,16 +591,17 @@ struct ConvertVectorMaskedLoad final
       // TODO: can fold if op's mask is constant
       auto emptyVector = rewriter.create<arith::ConstantOp>(
           loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
-      mask = insertSubvectorInto(rewriter, loc, op.getMask(), emptyVector,
-                                 *foldedIntraVectorOffset);
+      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
+                                       *foldedIntraVectorOffset);
     }
 
     Value result =
         rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
 
     if (isUnalignedEmulation) {
-      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+      result =
+          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                     *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -604,13 +659,10 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic inra-vector offset
-      return failure();
-    }
-
+    auto maxIntraVectorOffset =
+        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -621,9 +673,18 @@ struct ConvertVectorTransferRead final
         loc, VectorType::get(numElements * scale, oldElementType), newRead);
 
     Value result = bitCast->getResult(0);
-    if (isUnalignedEmulation) {
-      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result =
+            staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                       *foldedIntraVectorOffset, origElements);
+      }
+    } else {
+      auto zeros = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
+                                           linearizedInfo.intraDataOffset,
+                                           origElements);
     }
     rewriter.replaceOp(op, result);
 

diff  --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ecbad7968225d..0cecaddc5733e2 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -1,59 +1,65 @@
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
 
-func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
-    %0 = memref.alloc() : memref<3x3xi2>
-    %c0 = arith.constant 0 : index
-    %c2 = arith.constant 2 : index
-    %cst = arith.constant dense<0> : vector<3x3xi2>
-    %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
-    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
-    return %2 : vector<3x3xi2>
+// TODO: remove memref.alloc() in the tests to reduce noise.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+
+func.func @vector_load_i2() -> vector<3x3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+  %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+  return %2 : vector<3x3xi2>
 }
 
-// CHECK: func @vector_load_i2
+// CHECK-LABEL: func @vector_load_i2
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
 // CHECK: %[[INDEX:.+]] = arith.constant 1 : index
 // CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi8>, vector<2xi8>
 // CHECK: %[[VEC_I2:.+]] = vector.bitcast %[[VEC]] : vector<2xi8> to vector<8xi2>
 // CHECK: %[[EXCTRACT:.+]] = vector.extract_strided_slice %[[VEC_I2]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
 
-//-----
+// -----
 
 func.func @vector_transfer_read_i2() -> vector<3xi2> {
- %0 = memref.alloc() : memref<3x3xi2>
- %c0i2 = arith.constant 0 : i2
- %c0 = arith.constant 0 : index
- %c2 = arith.constant 2 : index
- %1 = vector.transfer_read %0[%c2, %c0], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
- return %1 : vector<3xi2>
+  %0 = memref.alloc() : memref<3x3xi2>
+  %pad = arith.constant 0 : i2
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.transfer_read %0[%c2, %c0], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
 }
 
-// CHECK: func @vector_transfer_read_i2
+// CHECK-LABEL: func @vector_transfer_read_i2
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
 // CHECK: %[[INDEX:.+]] = arith.constant 1 : index
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %0 : memref<3xi8>, vector<2xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
 // CHECK: vector.extract_strided_slice %[[BITCAST]] {offsets = [2], sizes = [3], strides = [1]} : vector<8xi2> to vector<3xi2>
 
-//-----
+// -----
 
 func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
-    %0 = memref.alloc() : memref<3x5xi2>
-    %cst = arith.constant dense<0> : vector<3x5xi2>
-    %mask = vector.constant_mask [3] : vector<5xi1>
-    %c0 = arith.constant 0 : index
-    %c2 = arith.constant 2 : index
-    %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
-      memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
-    %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
-    return %2 : vector<3x5xi2>
+  %0 = memref.alloc() : memref<3x5xi2>
+  %cst = arith.constant dense<0> : vector<3x5xi2>
+  %mask = vector.constant_mask [3] : vector<5xi1>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  %2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
+  return %2 : vector<3x5xi2>
 }
 
-// CHECK: func @vector_cst_maskedload_i2
+// CHECK-LABEL: func @vector_cst_maskedload_i2(
+// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
 // CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
 // CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
 // CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
-// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %arg0, %[[VESSEL]]
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
 // CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8>
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
@@ -64,4 +70,116 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
 // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]]
 // CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi1> into vector<8xi1>
 // CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BITCAST2]], %[[INSERT1]] : vector<8xi1>, vector<8xi2>
-// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2> 
+// CHECK: vector.extract_strided_slice %[[SELECT]] {offsets = [2], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+
+// -----
+
+func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %1 = vector.load %0[%idx1, %idx2] : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
+// CHECK: func @vector_load_i2_dynamic_indexing(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OFFSET2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+
+// -----
+
+func.func @vector_load_i2_dynamic_indexing_mixed(%idx: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c2 = arith.constant 2 : index
+  %cst = arith.constant dense<1> : vector<3x3xi2>
+  %1 = vector.load %0[%idx, %c2] : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
+// CHECK: func @vector_load_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[ARG0:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OFFSET2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+
+// -----
+
+func.func @vector_transfer_read_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %pad = arith.constant 0 : i2
+  %1 = vector.transfer_read %0[%idx1, %idx2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
+// CHECK: func @vector_transfer_read_i2_dynamic_indexing(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+
+// -----
+
+func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c2 = arith.constant 2 : index
+  %pad = arith.constant 0 : i2
+  %1 = vector.transfer_read %0[%idx1, %c2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
+// CHECK: func @vector_transfer_read_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[ARG0:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>

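A quick check of the conservative sizing behind these tests (a worked example
only, assuming the i2-in-i8 emulation used above): scale = 8 / 2 = 4, and with
a dynamic intra-vector offset the patterns assume the worst case
maxIntraDataOffset = scale - 1 = 3, so they load
ceil((3 + 3) / 4) = 2 emulated i8 elements, which bitcast to the vector<8xi2>
that the per-element extractions index into.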

        

