[Mlir-commits] [mlir] 8345b86 - [mlir][Vector] Add lowering of 1-D vector transfer_read/write to masked load/store

Nicolas Vasilache llvmlistbot at llvm.org
Thu Apr 9 13:20:35 PDT 2020


Author: Nicolas Vasilache
Date: 2020-04-09T16:17:05-04:00
New Revision: 8345b86d9ac20c112c6f66b1bfbcf9c5c4158996

URL: https://github.com/llvm/llvm-project/commit/8345b86d9ac20c112c6f66b1bfbcf9c5c4158996
DIFF: https://github.com/llvm/llvm-project/commit/8345b86d9ac20c112c6f66b1bfbcf9c5c4158996.diff

LOG: [mlir][Vector] Add lowering of 1-D vector transfer_read/write to masked load/store

Summary:
This revision adds support to lower 1-D vector transfers to LLVM.
A mask of the vector length is created by comparing (base offset + linear index) against the dimension of the memref.
In each position where the access stays in bounds (i.e. offset + vector index < dim), the mask is set to 1.

Notably, the lowering uses llvm.dialect_cast (LLVM::DialectCastOp). This allows the code to be written in the simplest form:
it targets the simplest mix of vector and LLVM dialects and lets other conversions kick in.

Differential Revision: https://reviews.llvm.org/D77703

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/IR/Builders.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/IR/Builders.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index fffdf3947a2e..555666753360 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -398,6 +398,29 @@ class ConvertToLLVMPattern : public ConversionPattern {
   Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                             uint64_t value) const;
 
+  // Given subscript indices and array sizes in row-major order,
+  //   i_n, i_{n-1}, ..., i_1
+  //   s_n, s_{n-1}, ..., s_1
+  // obtain a value that corresponds to the linearized subscript
+  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
+  // by accumulating the running linearized value.
+  // Note that `indices` and `allocSizes` are passed in the same order as they
+  // appear in load/store operations and memref type declarations.
+  Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
+                            ArrayRef<Value> indices,
+                            ArrayRef<Value> allocSizes) const;
+
+  // This is a strided getElementPtr variant that linearizes subscripts as:
+  //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
+  Value getStridedElementPtr(Location loc, Type elementTypePtr,
+                             Value descriptor, ArrayRef<Value> indices,
+                             ArrayRef<int64_t> strides, int64_t offset,
+                             ConversionPatternRewriter &rewriter) const;
+
+  Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
+                   ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
+                   llvm::Module &module) const;
+
 protected:
   /// Reference to the type converter, with potential extensions.
   LLVMTypeConverter &typeConverter;

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 1ff2d6ae28bb..b33b38971249 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -73,6 +73,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
 
   /// Vector type utilities.
   LLVMType getVectorElementType();
+  unsigned getVectorNumElements();
   bool isVectorTy();
 
   /// Function type utilities.

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 75f49e86d10a..bceba806024d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -111,6 +111,7 @@ class Builder {
   IntegerAttr getI16IntegerAttr(int16_t value);
   IntegerAttr getI32IntegerAttr(int32_t value);
   IntegerAttr getI64IntegerAttr(int64_t value);
+  IntegerAttr getIndexAttr(int64_t value);
 
   /// Signed and unsigned integer attribute getters.
   IntegerAttr getSI32IntegerAttr(int32_t value);

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index a746af7cce61..57663a39e132 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -735,6 +735,61 @@ Value ConvertToLLVMPattern::createIndexConstant(
   return createIndexAttrConstant(builder, loc, getIndexType(), value);
 }
 
+Value ConvertToLLVMPattern::linearizeSubscripts(
+    ConversionPatternRewriter &builder, Location loc, ArrayRef<Value> indices,
+    ArrayRef<Value> allocSizes) const {
+  assert(indices.size() == allocSizes.size() &&
+         "mismatching number of indices and allocation sizes");
+  assert(!indices.empty() && "cannot linearize a 0-dimensional access");
+
+  Value linearized = indices.front();
+  for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
+    linearized = builder.create<LLVM::MulOp>(
+        loc, this->getIndexType(), ArrayRef<Value>{linearized, allocSizes[i]});
+    linearized = builder.create<LLVM::AddOp>(
+        loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
+  }
+  return linearized;
+}
+
+Value ConvertToLLVMPattern::getStridedElementPtr(
+    Location loc, Type elementTypePtr, Value descriptor,
+    ArrayRef<Value> indices, ArrayRef<int64_t> strides, int64_t offset,
+    ConversionPatternRewriter &rewriter) const {
+  MemRefDescriptor memRefDescriptor(descriptor);
+
+  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+  Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
+                          ? memRefDescriptor.offset(rewriter, loc)
+                          : this->createIndexConstant(rewriter, loc, offset);
+
+  for (int i = 0, e = indices.size(); i < e; ++i) {
+    Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+                       ? memRefDescriptor.stride(rewriter, loc, i)
+                       : this->createIndexConstant(rewriter, loc, strides[i]);
+    Value additionalOffset =
+        rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
+    offsetValue =
+        rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
+  }
+  return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
+}
+
+Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
+                                       Value memRefDesc,
+                                       ArrayRef<Value> indices,
+                                       ConversionPatternRewriter &rewriter,
+                                       llvm::Module &module) const {
+  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto successStrides = getStridesAndOffset(type, strides, offset);
+  assert(succeeded(successStrides) && "unexpected non-strided memref");
+  (void)successStrides;
+  return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
+                              offset, rewriter);
+}
+
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
 /// attributes.
@@ -1913,70 +1968,6 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
     MemRefType type = cast<Derived>(op).getMemRefType();
     return isSupportedMemRefType(type) ? success() : failure();
   }
-
-  // Given subscript indices and array sizes in row-major order,
-  //   i_n, i_{n-1}, ..., i_1
-  //   s_n, s_{n-1}, ..., s_1
-  // obtain a value that corresponds to the linearized subscript
-  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
-  // by accumulating the running linearized value.
-  // Note that `indices` and `allocSizes` are passed in the same order as they
-  // appear in load/store operations and memref type declarations.
-  Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
-                            ArrayRef<Value> indices,
-                            ArrayRef<Value> allocSizes) const {
-    assert(indices.size() == allocSizes.size() &&
-           "mismatching number of indices and allocation sizes");
-    assert(!indices.empty() && "cannot linearize a 0-dimensional access");
-
-    Value linearized = indices.front();
-    for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
-      linearized = builder.create<LLVM::MulOp>(
-          loc, this->getIndexType(),
-          ArrayRef<Value>{linearized, allocSizes[i]});
-      linearized = builder.create<LLVM::AddOp>(
-          loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
-    }
-    return linearized;
-  }
-
-  // This is a strided getElementPtr variant that linearizes subscripts as:
-  //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
-  Value getStridedElementPtr(Location loc, Type elementTypePtr,
-                             Value descriptor, ArrayRef<Value> indices,
-                             ArrayRef<int64_t> strides, int64_t offset,
-                             ConversionPatternRewriter &rewriter) const {
-    MemRefDescriptor memRefDescriptor(descriptor);
-
-    Value base = memRefDescriptor.alignedPtr(rewriter, loc);
-    Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
-                            ? memRefDescriptor.offset(rewriter, loc)
-                            : this->createIndexConstant(rewriter, loc, offset);
-
-    for (int i = 0, e = indices.size(); i < e; ++i) {
-      Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
-                         ? memRefDescriptor.stride(rewriter, loc, i)
-                         : this->createIndexConstant(rewriter, loc, strides[i]);
-      Value additionalOffset =
-          rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
-      offsetValue =
-          rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
-    }
-    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
-  }
-
-  Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
-                   ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
-                   llvm::Module &module) const {
-    LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    auto successStrides = getStridesAndOffset(type, strides, offset);
-    assert(succeeded(successStrides) && "unexpected non-strided memref");
-    (void)successStrides;
-    return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
-                                offset, rewriter);
-  }
 };
 
 // Load operation is lowered to obtaining a pointer to the indexed element

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index eb4bf3b6d0ef..38822fa79458 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
@@ -894,6 +895,129 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
   }
 };
 
+template <typename ConcreteOp>
+void replaceTransferOp(ConversionPatternRewriter &rewriter,
+                       LLVMTypeConverter &typeConverter, Location loc,
+                       Operation *op, ArrayRef<Value> operands, Value dataPtr,
+                       Value mask);
+
+template <>
+void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
+                                       LLVMTypeConverter &typeConverter,
+                                       Location loc, Operation *op,
+                                       ArrayRef<Value> operands, Value dataPtr,
+                                       Value mask) {
+  auto xferOp = cast<TransferReadOp>(op);
+  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+  VectorType fillType = xferOp.getVectorType();
+  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
+  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
+
+  auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
+  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+      op, vecTy, dataPtr, mask, ValueRange{fill},
+      rewriter.getI32IntegerAttr(1));
+}
+
+template <>
+void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
+                                        LLVMTypeConverter &typeConverter,
+                                        Location loc, Operation *op,
+                                        ArrayRef<Value> operands, Value dataPtr,
+                                        Value mask) {
+  auto adaptor = TransferWriteOpOperandAdaptor(operands);
+  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+      op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
+}
+
+static TransferReadOpOperandAdaptor
+getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
+  return TransferReadOpOperandAdaptor(operands);
+}
+
+static TransferWriteOpOperandAdaptor
+getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
+  return TransferWriteOpOperandAdaptor(operands);
+}
+
+/// Conversion pattern that converts a 1-D vector transfer read/write op into a
+/// sequence of:
+/// 1. Bitcast to vector form.
+/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+/// 3. Create a mask where offsetVector is compared against memref upper bound.
+/// 4. Rewrite op as a masked read or write.
+template <typename ConcreteOp>
+class VectorTransferConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorTransferConversion(MLIRContext *context,
+                                    LLVMTypeConverter &typeConv)
+      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
+                             typeConv) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto xferOp = cast<ConcreteOp>(op);
+    auto adaptor = getTransferOpAdapter(xferOp, operands);
+    if (xferOp.getMemRefType().getRank() != 1)
+      return failure();
+    if (!xferOp.permutation_map().isIdentity())
+      return failure();
+
+    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+
+    Location loc = op->getLoc();
+    Type i64Type = rewriter.getIntegerType(64);
+    MemRefType memRefType = xferOp.getMemRefType();
+
+    // 1. Get the source/dst address as an LLVM vector pointer.
+    // TODO: support alignment when possible.
+    Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
+                               adaptor.indices(), rewriter, getModule());
+    auto vecTy =
+        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
+    auto vectorDataPtr =
+        rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
+
+    // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+    unsigned vecWidth = vecTy.getVectorNumElements();
+    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
+    SmallVector<int64_t, 8> indices;
+    indices.reserve(vecWidth);
+    for (unsigned i = 0; i < vecWidth; ++i)
+      indices.push_back(i);
+    Value linearIndices = rewriter.create<ConstantOp>(
+        loc, vectorCmpType,
+        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
+    linearIndices = rewriter.create<LLVM::DialectCastOp>(
+        loc, toLLVMTy(vectorCmpType), linearIndices);
+
+    // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+    Value offsetIndex = *(xferOp.indices().begin());
+    offsetIndex = rewriter.create<IndexCastOp>(
+        loc, vectorCmpType.getElementType(), offsetIndex);
+    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
+    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
+
+    // 4. Let dim be the memref dimension, compute the vector comparison mask:
+    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+    Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), 0);
+    dim =
+        rewriter.create<IndexCastOp>(loc, vectorCmpType.getElementType(), dim);
+    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
+    Value mask =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
+    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
+                                                mask);
+
+    // 5. Rewrite as a masked read / write.
+    replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, operands,
+                                  vectorDataPtr, mask);
+
+    return success();
+  }
+};
+
 class VectorPrintOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorPrintOpConversion(MLIRContext *context,
@@ -1079,16 +1203,25 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   MLIRContext *ctx = converter.getDialect()->getContext();
+  // clang-format off
   patterns.insert<VectorFMAOpNDRewritePattern,
                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
                   VectorInsertStridedSliceOpSameRankRewritePattern,
                   VectorStridedSliceOpConversion>(ctx);
-  patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
-                  VectorShuffleOpConversion, VectorExtractElementOpConversion,
-                  VectorExtractOpConversion, VectorFMAOp1DConversion,
-                  VectorInsertElementOpConversion, VectorInsertOpConversion,
-                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
-      ctx, converter);
+  patterns
+      .insert<VectorBroadcastOpConversion,
+              VectorReductionOpConversion,
+              VectorShuffleOpConversion,
+              VectorExtractElementOpConversion,
+              VectorExtractOpConversion,
+              VectorFMAOp1DConversion,
+              VectorInsertElementOpConversion,
+              VectorInsertOpConversion,
+              VectorPrintOpConversion,
+              VectorTransferConversion<TransferReadOp>,
+              VectorTransferConversion<TransferWriteOp>,
+              VectorTypeCastOpConversion>(ctx, converter);
+  // clang-format on
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c69530b28e29..82bbe18dd01e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1774,6 +1774,9 @@ bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
 LLVMType LLVMType::getVectorElementType() {
   return get(getContext(), getUnderlyingType()->getVectorElementType());
 }
+unsigned LLVMType::getVectorNumElements() {
+  return getUnderlyingType()->getVectorNumElements();
+}
 bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
 
 /// Function type utilities.

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index c8d5ea6b6ca9..84a883c64d24 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -93,6 +93,10 @@ DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
   return DictionaryAttr::get(value, context);
 }
 
+IntegerAttr Builder::getIndexAttr(int64_t value) {
+  return IntegerAttr::get(getIndexType(), APInt(64, value));
+}
+
 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
 }

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 0cc6789a0619..6a65b219b632 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -738,3 +738,95 @@ func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
 //       CHECK:   llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
 //  CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
 //  CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">
+
+func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
+  %f7 = constant 7.0: f32
+  %f = vector.transfer_read %A[%base], %f7
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<?xf32>, vector<17xf32>
+  vector.transfer_write %f, %A[%base]
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    vector<17xf32>, memref<?xf32>
+  return %f: vector<17xf32>
+}
+// CHECK-LABEL: func @transfer_read_1d
+//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
+//
+// 1. Bitcast to vector form.
+//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+//  CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+//       CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+//  CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
+//
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+//       CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(
+//  CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+//  CHECK-SAME: vector<17xi64>) : !llvm<"<17 x i64>">
+//
+// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+//       CHECK: %[[offsetVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
+//       CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+//       CHECK: %[[offsetVec2:.*]] = llvm.insertelement %[[BASE]], %[[offsetVec]][%[[c0]] :
+//  CHECK-SAME: !llvm.i32] : !llvm<"<17 x i64>">
+//       CHECK: %[[offsetVec3:.*]] = llvm.shufflevector %[[offsetVec2]], %{{.*}} [
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32] :
+//  CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
+//       CHECK: %[[offsetVec4:.*]] = llvm.add %[[offsetVec3]], %[[linearIndex]] :
+//  CHECK-SAME: !llvm<"<17 x i64>">
+//
+// 4. Let dim be the memref dimension, compute the vector comparison mask:
+//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+//       CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
+//  CHECK-SAME: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+//       CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
+//       CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+//       CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
+//  CHECK-SAME:  !llvm.i32] : !llvm<"<17 x i64>">
+//       CHECK: %[[dimVec3:.*]] = llvm.shufflevector %[[dimVec2]], %{{.*}} [
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32] :
+//  CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
+//       CHECK: %[[mask:.*]] = llvm.icmp "slt" %[[offsetVec4]], %[[dimVec3]] :
+//  CHECK-SAME: !llvm<"<17 x i64>">
+//
+// 5. Rewrite as a masked read.
+//       CHECK: %[[PASS_THROUGH:.*]] =  llvm.mlir.constant(dense<7.000000e+00> :
+//  CHECK-SAME:  vector<17xf32>) : !llvm<"<17 x float>">
+//       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
+//  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} :
+//  CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">
+
+//
+// 1. Bitcast to vector form.
+//       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
+//  CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+//       CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
+//  CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
+//
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+//       CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(
+//  CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+//  CHECK-SAME:  vector<17xi64>) : !llvm<"<17 x i64>">
+//
+// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+//       CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] :
+//  CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
+//       CHECK: llvm.add
+//
+// 4. Let dim be the memref dimension, compute the vector comparison mask:
+//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+//       CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+//  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] :
+//  CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
+//       CHECK: %[[mask_b:.*]] = llvm.icmp "slt" {{.*}} : !llvm<"<17 x i64>">
+//
+// 5. Rewrite as a masked write.
+//       CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
+//  CHECK-SAME: {alignment = 1 : i32} :
+//  CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">


        


More information about the Mlir-commits mailing list