[Mlir-commits] [mlir] [MLIR] support dynamic indexing in `VectorEmulateNarrowTypes` (PR #114169)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 30 06:12:44 PDT 2024


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/114169

>From 2effa6ab20f5a7d15fdc3faf6710fd970e6e8219 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Tue, 29 Oct 2024 14:03:23 +0000
Subject: [PATCH] Implement VectorLoadOp

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 116 ++++++++++++++----
 ...emulate-narrow-type-unaligned-dynamic.mlir |  53 ++++++++
 2 files changed, 143 insertions(+), 26 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1d6f8a991d9b5b..09bc7256fc3cf0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -149,6 +150,61 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
                                                        dest, offsets, strides);
 }
 
+/// Extracts `loopSize` consecutive elements from the 1-D vector `srcVec`,
+/// starting at the (possibly runtime) position `srcOffsetVar`, and inserts
+/// them into positions [0, loopSize) of `destVec`. Returns the updated vector.
+///
+/// `vector.insert` yields a new SSA value instead of modifying its destination
+/// in place, so every insert is chained on the previous one and callers must
+/// use the returned value rather than the `destVec` they passed in.
+///
+/// Both vectors are expected to be 1-D with the same element type, and
+/// `srcOffsetVar + loopSize` must not exceed the length of `srcVec`; neither
+/// condition is verified here.
+///
+/// The extraction is fully unrolled. TODO: consider emitting a loop (e.g.
+/// `affine.for` with the destination as an iter_arg) for large `loopSize`.
+static Value dynamicallyExtractElementsToVector(
+    RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
+    Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
+  // Materialize the offset once; this also covers a constant (attribute)
+  // offset, for which `dyn_cast<Value>()` would produce a null value.
+  Value baseOffset =
+      getValueOrCreateConstantIndexOp(rewriter, loc, srcOffsetVar);
+  for (int64_t i = 0; i < loopSize; ++i) {
+    // Element 0 sits at the base offset itself; no addition is needed.
+    Value extractLoc = baseOffset;
+    if (i != 0) {
+      extractLoc = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), baseOffset,
+          rewriter.create<arith::ConstantIndexOp>(loc, i));
+    }
+    auto extractOp =
+        rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
+    destVec = rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
+  }
+  return destVec;
+}
+
+/// Loads `numEmulatedElements` elements of `newElementType` from `base` at
+/// `linearizedIndices` and bitcasts the loaded vector to a vector of
+/// `numEmulatedElements * scale` elements of `oldElementType`.
+///
+/// Note: `scale` is expected to be the number of `oldElementType` elements
+/// packed into one `newElementType` element; this is not verified here.
+static TypedValue<VectorType>
+emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
+                   Value base, OpFoldResult linearizedIndices,
+                   int64_t numEmulatedElements, int64_t scale,
+                   Type oldElementType, Type newElementType) {
+  auto newLoad = rewriter.create<vector::LoadOp>(
+      loc, VectorType::get(numEmulatedElements, newElementType), base,
+      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+  return rewriter.create<vector::BitCastOp>(
+      loc, VectorType::get(numEmulatedElements * scale, oldElementType),
+      newLoad);
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -380,26 +436,29 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic intra vector offset
-      return failure();
-    }
-
+    // Always load enough emulated elements to cover the original elements.
+    auto maxIntraDataOffset =
+        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
-    auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
-
-    Value result = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newLoad);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    Value result =
+        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
+                           numElements, scale, oldElementType, newElementType);
 
-    if (isUnalignedEmulation) {
-      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                      *foldedIntraVectorOffset, origElements);
+      }
+      rewriter.replaceOp(op, result);
+    } else {
+      auto zeros = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      result = dynamicallyExtractElementsToVector(
+          rewriter, loc, cast<TypedValue<VectorType>>(result), zeros,
+          linearizedInfo.intraDataOffset, origElements);
+      rewriter.replaceOp(op, result);
     }
-
-    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -604,13 +663,10 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic inra-vector offset
-      return failure();
-    }
-
+    auto maxIntraVectorOffset =
+        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -621,9 +677,17 @@ struct ConvertVectorTransferRead final
         loc, VectorType::get(numElements * scale, oldElementType), newRead);
 
     Value result = bitCast->getResult(0);
-    if (isUnalignedEmulation) {
-      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                      *foldedIntraVectorOffset, origElements);
+      }
+    } else {
+      auto zeros = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      result = dynamicallyExtractElementsToVector(
+          rewriter, loc, bitCast, zeros, linearizedInfo.intraDataOffset,
+          origElements);
     }
     rewriter.replaceOp(op, result);
 
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
new file mode 100644
index 00000000000000..a92e62538c5332
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+
+// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
+// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant dense<0> : vector<3x3xi2>
+    %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
+    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+    return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %[[ALLOC]][%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OFFSET2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+
+// -----
+
+func.func @vector_transfer_read_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0i2 = arith.constant 0 : i2
+ %1 = vector.transfer_read %0[%arg1, %arg2], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>



More information about the Mlir-commits mailing list