[Mlir-commits] [mlir] 8345b86 - [mlir][Vector] Add lowering of 1-D vector transfer_read/write to masked load/store
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Apr 9 13:20:35 PDT 2020
Author: Nicolas Vasilache
Date: 2020-04-09T16:17:05-04:00
New Revision: 8345b86d9ac20c112c6f66b1bfbcf9c5c4158996
URL: https://github.com/llvm/llvm-project/commit/8345b86d9ac20c112c6f66b1bfbcf9c5c4158996
DIFF: https://github.com/llvm/llvm-project/commit/8345b86d9ac20c112c6f66b1bfbcf9c5c4158996.diff
LOG: [mlir][Vector] Add lowering of 1-D vector transfer_read/write to masked load/store
Summary:
This revision adds support for lowering 1-D vector transfers to LLVM.
A mask of the vector length is created by comparing base offset + linear index against the memref dimension:
each position that stays in bounds (i.e. offset + vector index < dim) has its mask bit set to 1.
Notably, the lowering uses llvm.dialect_cast so the pattern can be written in the simplest form, targeting the
simplest mix of vector, standard, and LLVM dialect ops and letting the existing conversions kick in.
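For illustration, here is a condensed sketch of the final IR for a 1-D transfer_read, abridged from the new test at the end of this commit (SSA names are illustrative, the dense<[0, 1, ..., 16]> constant is shortened, and the insertelement/shufflevector expansion of the base and dim broadcasts is elided):

  %linear = llvm.mlir.constant(dense<[0, 1, ..., 16]> : vector<17xi64>) : !llvm<"<17 x i64>">
  %offs   = llvm.add %baseSplat, %linear : !llvm<"<17 x i64>">
  %mask   = llvm.icmp "slt" %offs, %dimSplat : !llvm<"<17 x i64>">
  %read   = llvm.intr.masked.load %vecPtr, %mask, %pad {alignment = 1 : i32} :
            (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">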
Differential Revision: https://reviews.llvm.org/D77703
Added:
Modified:
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/IR/Builders.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/IR/Builders.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index fffdf3947a2e..555666753360 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -398,6 +398,29 @@ class ConvertToLLVMPattern : public ConversionPattern {
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
uint64_t value) const;
+ // Given subscript indices and array sizes in row-major order,
+ // i_n, i_{n-1}, ..., i_1
+ // s_n, s_{n-1}, ..., s_1
+ // obtain a value that corresponds to the linearized subscript
+ // \sum_k i_k * \prod_{j=1}^{k-1} s_j
+ // by accumulating the running linearized value.
+ // Note that `indices` and `allocSizes` are passed in the same order as they
+ // appear in load/store operations and memref type declarations.
+ Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
+ ArrayRef<Value> indices,
+ ArrayRef<Value> allocSizes) const;
+
+ // This is a strided getElementPtr variant that linearizes subscripts as:
+ // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
+ Value getStridedElementPtr(Location loc, Type elementTypePtr,
+ Value descriptor, ArrayRef<Value> indices,
+ ArrayRef<int64_t> strides, int64_t offset,
+ ConversionPatternRewriter &rewriter) const;
+
+ Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
+ ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
+ llvm::Module &module) const;
+
protected:
/// Reference to the type converter, with potential extensions.
LLVMTypeConverter &typeConverter;
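As a worked example of the two addressing schemes documented above (sizes are hypothetical): for a row-major allocation with sizes (10, 20), linearizeSubscripts folds subscripts (i, j) by the running multiply-accumulate into

  i * 20 + j

while for a strided memref with strides [20, 1] and static offset 4, getStridedElementPtr computes the element offset

  4 + i * 20 + j * 1

and emits a single GEP off the aligned base pointer.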
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 1ff2d6ae28bb..b33b38971249 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -73,6 +73,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
/// Vector type utilities.
LLVMType getVectorElementType();
+ unsigned getVectorNumElements();
bool isVectorTy();
/// Function type utilities.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 75f49e86d10a..bceba806024d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -111,6 +111,7 @@ class Builder {
IntegerAttr getI16IntegerAttr(int16_t value);
IntegerAttr getI32IntegerAttr(int32_t value);
IntegerAttr getI64IntegerAttr(int64_t value);
+ IntegerAttr getIndexAttr(int64_t value);
/// Signed and unsigned integer attribute getters.
IntegerAttr getSI32IntegerAttr(int32_t value);
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index a746af7cce61..57663a39e132 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -735,6 +735,61 @@ Value ConvertToLLVMPattern::createIndexConstant(
return createIndexAttrConstant(builder, loc, getIndexType(), value);
}
+Value ConvertToLLVMPattern::linearizeSubscripts(
+ ConversionPatternRewriter &builder, Location loc, ArrayRef<Value> indices,
+ ArrayRef<Value> allocSizes) const {
+ assert(indices.size() == allocSizes.size() &&
+ "mismatching number of indices and allocation sizes");
+ assert(!indices.empty() && "cannot linearize a 0-dimensional access");
+
+ Value linearized = indices.front();
+ for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
+ linearized = builder.create<LLVM::MulOp>(
+ loc, this->getIndexType(), ArrayRef<Value>{linearized, allocSizes[i]});
+ linearized = builder.create<LLVM::AddOp>(
+ loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
+ }
+ return linearized;
+}
+
+Value ConvertToLLVMPattern::getStridedElementPtr(
+ Location loc, Type elementTypePtr, Value descriptor,
+ ArrayRef<Value> indices, ArrayRef<int64_t> strides, int64_t offset,
+ ConversionPatternRewriter &rewriter) const {
+ MemRefDescriptor memRefDescriptor(descriptor);
+
+ Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+ Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.offset(rewriter, loc)
+ : this->createIndexConstant(rewriter, loc, offset);
+
+ for (int i = 0, e = indices.size(); i < e; ++i) {
+ Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+ ? memRefDescriptor.stride(rewriter, loc, i)
+ : this->createIndexConstant(rewriter, loc, strides[i]);
+ Value additionalOffset =
+ rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
+ offsetValue =
+ rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
+ }
+ return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
+}
+
+Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
+ Value memRefDesc,
+ ArrayRef<Value> indices,
+ ConversionPatternRewriter &rewriter,
+ llvm::Module &module) const {
+ LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(type, strides, offset);
+ assert(succeeded(successStrides) && "unexpected non-strided memref");
+ (void)successStrides;
+ return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
+ offset, rewriter);
+}
+
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
@@ -1913,70 +1968,6 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
MemRefType type = cast<Derived>(op).getMemRefType();
return isSupportedMemRefType(type) ? success() : failure();
}
-
- // Given subscript indices and array sizes in row-major order,
- // i_n, i_{n-1}, ..., i_1
- // s_n, s_{n-1}, ..., s_1
- // obtain a value that corresponds to the linearized subscript
- // \sum_k i_k * \prod_{j=1}^{k-1} s_j
- // by accumulating the running linearized value.
- // Note that `indices` and `allocSizes` are passed in the same order as they
- // appear in load/store operations and memref type declarations.
- Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
- ArrayRef<Value> indices,
- ArrayRef<Value> allocSizes) const {
- assert(indices.size() == allocSizes.size() &&
- "mismatching number of indices and allocation sizes");
- assert(!indices.empty() && "cannot linearize a 0-dimensional access");
-
- Value linearized = indices.front();
- for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
- linearized = builder.create<LLVM::MulOp>(
- loc, this->getIndexType(),
- ArrayRef<Value>{linearized, allocSizes[i]});
- linearized = builder.create<LLVM::AddOp>(
- loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
- }
- return linearized;
- }
-
- // This is a strided getElementPtr variant that linearizes subscripts as:
- // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
- Value getStridedElementPtr(Location loc, Type elementTypePtr,
- Value descriptor, ArrayRef<Value> indices,
- ArrayRef<int64_t> strides, int64_t offset,
- ConversionPatternRewriter &rewriter) const {
- MemRefDescriptor memRefDescriptor(descriptor);
-
- Value base = memRefDescriptor.alignedPtr(rewriter, loc);
- Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
- ? memRefDescriptor.offset(rewriter, loc)
- : this->createIndexConstant(rewriter, loc, offset);
-
- for (int i = 0, e = indices.size(); i < e; ++i) {
- Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
- ? memRefDescriptor.stride(rewriter, loc, i)
- : this->createIndexConstant(rewriter, loc, strides[i]);
- Value additionalOffset =
- rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
- offsetValue =
- rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
- }
- return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
- }
-
- Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
- ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
- llvm::Module &module) const {
- LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto successStrides = getStridesAndOffset(type, strides, offset);
- assert(succeeded(successStrides) && "unexpected non-strided memref");
- (void)successStrides;
- return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
- offset, rewriter);
- }
};
// Load operation is lowered to obtaining a pointer to the indexed element
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index eb4bf3b6d0ef..38822fa79458 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
@@ -894,6 +895,129 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
}
};
+template <typename ConcreteOp>
+void replaceTransferOp(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter, Location loc,
+ Operation *op, ArrayRef<Value> operands, Value dataPtr,
+ Value mask);
+
+template <>
+void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter,
+ Location loc, Operation *op,
+ ArrayRef<Value> operands, Value dataPtr,
+ Value mask) {
+ auto xferOp = cast<TransferReadOp>(op);
+ auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+ VectorType fillType = xferOp.getVectorType();
+ Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
+ fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
+
+ auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
+ rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+ op, vecTy, dataPtr, mask, ValueRange{fill},
+ rewriter.getI32IntegerAttr(1));
+}
+
+template <>
+void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
+ LLVMTypeConverter &typeConverter,
+ Location loc, Operation *op,
+ ArrayRef<Value> operands, Value dataPtr,
+ Value mask) {
+ auto adaptor = TransferWriteOpOperandAdaptor(operands);
+ rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+ op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
+}
+
+static TransferReadOpOperandAdaptor
+getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
+ return TransferReadOpOperandAdaptor(operands);
+}
+
+static TransferWriteOpOperandAdaptor
+getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
+ return TransferWriteOpOperandAdaptor(operands);
+}
+
+/// Conversion pattern that converts a 1-D vector transfer read/write op into a
+/// sequence of:
+/// 1. Bitcast to vector form.
+/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+/// 3. Create a mask where offsetVector is compared against memref upper bound.
+/// 4. Rewrite op as a masked read or write.
+template <typename ConcreteOp>
+class VectorTransferConversion : public ConvertToLLVMPattern {
+public:
+ explicit VectorTransferConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConv)
+ : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
+ typeConv) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto xferOp = cast<ConcreteOp>(op);
+ auto adaptor = getTransferOpAdapter(xferOp, operands);
+ if (xferOp.getMemRefType().getRank() != 1)
+ return failure();
+ if (!xferOp.permutation_map().isIdentity())
+ return failure();
+
+ auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+
+ Location loc = op->getLoc();
+ Type i64Type = rewriter.getIntegerType(64);
+ MemRefType memRefType = xferOp.getMemRefType();
+
+ // 1. Get the source/dst address as an LLVM vector pointer.
+ // TODO: support alignment when possible.
+ Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
+ adaptor.indices(), rewriter, getModule());
+ auto vecTy =
+ toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
+ auto vectorDataPtr =
+ rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
+
+ // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+ unsigned vecWidth = vecTy.getVectorNumElements();
+ VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
+ SmallVector<int64_t, 8> indices;
+ indices.reserve(vecWidth);
+ for (unsigned i = 0; i < vecWidth; ++i)
+ indices.push_back(i);
+ Value linearIndices = rewriter.create<ConstantOp>(
+ loc, vectorCmpType,
+ DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
+ linearIndices = rewriter.create<LLVM::DialectCastOp>(
+ loc, toLLVMTy(vectorCmpType), linearIndices);
+
+ // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+ Value offsetIndex = *(xferOp.indices().begin());
+ offsetIndex = rewriter.create<IndexCastOp>(
+ loc, vectorCmpType.getElementType(), offsetIndex);
+ Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
+ Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
+
+    // 4. Let `dim` be the memref dimension; compute the vector comparison mask:
+ // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+ Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), 0);
+ dim =
+ rewriter.create<IndexCastOp>(loc, vectorCmpType.getElementType(), dim);
+ dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
+ Value mask =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
+ mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
+ mask);
+
+ // 5. Rewrite as a masked read / write.
+ replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, operands,
+ vectorDataPtr, mask);
+
+ return success();
+ }
+};
+
class VectorPrintOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorPrintOpConversion(MLIRContext *context,
@@ -1079,16 +1203,25 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
MLIRContext *ctx = converter.getDialect()->getContext();
+ // clang-format off
patterns.insert<VectorFMAOpNDRewritePattern,
VectorInsertStridedSliceOpDifferentRankRewritePattern,
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorStridedSliceOpConversion>(ctx);
- patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
- VectorShuffleOpConversion, VectorExtractElementOpConversion,
- VectorExtractOpConversion, VectorFMAOp1DConversion,
- VectorInsertElementOpConversion, VectorInsertOpConversion,
- VectorTypeCastOpConversion, VectorPrintOpConversion>(
- ctx, converter);
+ patterns
+ .insert<VectorBroadcastOpConversion,
+ VectorReductionOpConversion,
+ VectorShuffleOpConversion,
+ VectorExtractElementOpConversion,
+ VectorExtractOpConversion,
+ VectorFMAOp1DConversion,
+ VectorInsertElementOpConversion,
+ VectorInsertOpConversion,
+ VectorPrintOpConversion,
+ VectorTransferConversion<TransferReadOp>,
+ VectorTransferConversion<TransferWriteOp>,
+ VectorTypeCastOpConversion>(ctx, converter);
+ // clang-format on
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
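To make the mixed-dialect approach of VectorTransferConversion concrete, the mask computation it emits before the standard-to-LLVM conversion runs looks roughly as follows (a sketch for a hypothetical vector<4xf32> transfer from memref<?xf32>, in circa-2020 standard dialect syntax; the llvm.dialect_cast ops bridging to LLVM types are elided):

  %idx  = constant dense<[0, 1, 2, 3]> : vector<4xi64>  // 2. linear indices
  %off  = index_cast %base : index to i64
  %offs = splat %off : vector<4xi64>
  %ov   = addi %offs, %idx : vector<4xi64>              // 3. offsetVector
  %d    = dim %A, 0 : memref<?xf32>
  %di   = index_cast %d : index to i64
  %ds   = splat %di : vector<4xi64>
  %mask = cmpi "slt", %ov, %ds : vector<4xi64>          // 4. in-bounds mask, vector<4xi1>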
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c69530b28e29..82bbe18dd01e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1774,6 +1774,9 @@ bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
LLVMType LLVMType::getVectorElementType() {
return get(getContext(), getUnderlyingType()->getVectorElementType());
}
+unsigned LLVMType::getVectorNumElements() {
+ return getUnderlyingType()->getVectorNumElements();
+}
bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
/// Function type utilities.
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index c8d5ea6b6ca9..84a883c64d24 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -93,6 +93,10 @@ DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
return DictionaryAttr::get(value, context);
}
+IntegerAttr Builder::getIndexAttr(int64_t value) {
+ return IntegerAttr::get(getIndexType(), APInt(64, value));
+}
+
IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
return IntegerAttr::get(getIntegerType(64), APInt(64, value));
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 0cc6789a0619..6a65b219b632 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -738,3 +738,95 @@ func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
// CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">
+
+func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
+ %f7 = constant 7.0: f32
+ %f = vector.transfer_read %A[%base], %f7
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<17xf32>
+ vector.transfer_write %f, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<17xf32>, memref<?xf32>
+ return %f: vector<17xf32>
+}
+// CHECK-LABEL: func @transfer_read_1d
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
+//
+// 1. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
+//
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(
+// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK-SAME: vector<17xi64>) : !llvm<"<17 x i64>">
+//
+// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+// CHECK: %[[offsetVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
+// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[offsetVec2:.*]] = llvm.insertelement %[[BASE]], %[[offsetVec]][%[[c0]] :
+// CHECK-SAME: !llvm.i32] : !llvm<"<17 x i64>">
+// CHECK: %[[offsetVec3:.*]] = llvm.shufflevector %[[offsetVec2]], %{{.*}} [
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32] :
+// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
+// CHECK: %[[offsetVec4:.*]] = llvm.add %[[offsetVec3]], %[[linearIndex]] :
+// CHECK-SAME: !llvm<"<17 x i64>">
+//
+// 4. Let `dim` be the memref dimension; compute the vector comparison mask:
+// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
+// CHECK-SAME: !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+// CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm<"<17 x i64>">
+// CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
+// CHECK-SAME: !llvm.i32] : !llvm<"<17 x i64>">
+// CHECK: %[[dimVec3:.*]] = llvm.shufflevector %[[dimVec2]], %{{.*}} [
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32] :
+// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
+// CHECK: %[[mask:.*]] = llvm.icmp "slt" %[[offsetVec4]], %[[dimVec3]] :
+// CHECK-SAME: !llvm<"<17 x i64>">
+//
+// 5. Rewrite as a masked read.
+// CHECK: %[[PASS_THROUGH:.*]] = llvm.mlir.constant(dense<7.000000e+00> :
+// CHECK-SAME: vector<17xf32>) : !llvm<"<17 x float>">
+// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
+// CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} :
+// CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">
+
+//
+// 1. Bitcast to vector form.
+// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
+// CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
+//
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(
+// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK-SAME: vector<17xi64>) : !llvm<"<17 x i64>">
+//
+// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] :
+// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
+// CHECK: llvm.add
+//
+// 4. Let `dim` be the memref dimension; compute the vector comparison mask:
+// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
+// CHECK-SAME: 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] :
+// CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
+// CHECK: %[[mask_b:.*]] = llvm.icmp "slt" {{.*}} : !llvm<"<17 x i64>">
+//
+// 5. Rewrite as a masked write.
+// CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
+// CHECK-SAME: {alignment = 1 : i32} :
+// CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">
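For reference, this test file is exercised through the pass that wraps populateVectorToLLVMConversionPatterns; the RUN line sits outside this diff's context, but presumably has the usual form:

  // RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s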