[Mlir-commits] [mlir] 060c9dd - [mlir] [VectorOps] Improve SIMD compares with narrower indices
llvmlistbot at llvm.org
Thu Sep 3 21:43:55 PDT 2020
Author: aartbik
Date: 2020-09-03T21:43:38-07:00
New Revision: 060c9dd1cc467cbeb6cf1c29dd44d07f562606b4
URL: https://github.com/llvm/llvm-project/commit/060c9dd1cc467cbeb6cf1c29dd44d07f562606b4
DIFF: https://github.com/llvm/llvm-project/commit/060c9dd1cc467cbeb6cf1c29dd44d07f562606b4.diff
LOG: [mlir] [VectorOps] Improve SIMD compares with narrower indices
When allowed (via the new enable-index-optimizations option), use 32-bit
rather than 64-bit indices in the SIMD computation of masks. This runs up
to 2x and 4x faster on a number of AVX2 and AVX512 microbenchmarks,
respectively.
Reviewed By: bkramer
Differential Revision: https://reviews.llvm.org/D87116
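For reference, a rough standard-dialect sketch of the comparison the new
32-bit path builds for a 1-D mask (the names %b, %indices, etc. are
illustrative, not taken from the patch):

  %m = vector.create_mask %b : vector<4xi1>

  // With enable-index-optimizations=1, before the final LLVM conversion:
  %indices = constant dense<[0, 1, 2, 3]> : vector<4xi32>
  %bound = index_cast %b : index to i32
  %bounds = splat %bound : vector<4xi32>
  %mask = cmpi "slt", %indices, %bounds : vector<4xi32>

With the option disabled, the same comparison is built on vector<4xi64>,
consuming twice the SIMD register width.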
Added:
mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6686e2865813..1b27a7308c7a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -358,7 +358,10 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
let options = [
Option<"reassociateFPReductions", "reassociate-fp-reductions",
"bool", /*default=*/"false",
- "Allows llvm to reassociate floating-point reductions for speed">
+ "Allows llvm to reassociate floating-point reductions for speed">,
+ Option<"enableIndexOptimizations", "enable-index-optimizations",
+ "bool", /*default=*/"false",
+ "Allows compiler to assume indices fit in 32-bit if that yields faster code">
];
}
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 82aa8287d90f..81ffa6328135 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -22,8 +22,13 @@ class OperationPass;
/// ConvertVectorToLLVM pass in include/mlir/Conversion/Passes.td
struct LowerVectorToLLVMOptions {
bool reassociateFPReductions = false;
- LowerVectorToLLVMOptions &setReassociateFPReductions(bool r) {
- reassociateFPReductions = r;
+ bool enableIndexOptimizations = false;
+ LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
+ reassociateFPReductions = b;
+ return *this;
+ }
+ LowerVectorToLLVMOptions &setEnableIndexOptimizations(bool b) {
+ enableIndexOptimizations = b;
return *this;
}
};
@@ -37,7 +42,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- bool reassociateFPReductions = false);
+ bool reassociateFPReductions = false,
+ bool enableIndexOptimizations = false);
/// Create a pass to convert vector operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index ecb047a1ad14..dfa204d17389 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -117,6 +117,49 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
return res;
}
+// Helper that returns a vector comparison that constructs a mask:
+// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
+//
+// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
+// much more compact, IR for this operation, but LLVM eventually
+// generates more elaborate instructions for this intrinsic since it
+// is very conservative on the boundary conditions.
+static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
+ Operation *op, bool enableIndexOptimizations,
+ int64_t dim, Value b, Value *off = nullptr) {
+ auto loc = op->getLoc();
+ // If we can assume all indices fit in 32-bit, we perform the vector
+ // comparison in 32-bit to get a higher degree of SIMD parallelism.
+ // Otherwise we perform the vector comparison using 64-bit indices.
+ Value indices;
+ Type idxType;
+ if (enableIndexOptimizations) {
+ SmallVector<int32_t, 4> values(dim);
+ for (int64_t d = 0; d < dim; d++)
+ values[d] = d;
+ indices =
+ rewriter.create<ConstantOp>(loc, rewriter.getI32VectorAttr(values));
+ idxType = rewriter.getI32Type();
+ } else {
+ SmallVector<int64_t, 4> values(dim);
+ for (int64_t d = 0; d < dim; d++)
+ values[d] = d;
+ indices =
+ rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
+ idxType = rewriter.getI64Type();
+ }
+ // Add in an offset if requested.
+ if (off) {
+ Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
+ Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
+ indices = rewriter.create<AddIOp>(loc, ov, indices);
+ }
+ // Construct the vector comparison.
+ Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
+ Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+ return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
+}
+
// Helper that returns data layout alignment of an operation with memref.
template <typename T>
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
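In standard-dialect terms, the offset variant of this new helper builds
roughly the following sequence (a sketch; %off and %b stand for the
offset and bound operands):

  %indices = constant dense<[0, 1, 2, 3]> : vector<4xi32>
  %o = index_cast %off : index to i32
  %ov = splat %o : vector<4xi32>
  %offs = addi %ov, %indices : vector<4xi32>
  %bound = index_cast %b : index to i32
  %bounds = splat %bound : vector<4xi32>
  %mask = cmpi "slt", %offs, %bounds : vector<4xi32>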
@@ -512,10 +555,10 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorReductionOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter,
- bool reassociateFP)
+ bool reassociateFPRed)
: ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
typeConverter),
- reassociateFPReductions(reassociateFP) {}
+ reassociateFPReductions(reassociateFPRed) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -589,6 +632,34 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
const bool reassociateFPReductions;
};
+/// Conversion pattern for a vector.create_mask (1-D only).
+class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
+public:
+ explicit VectorCreateMaskOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter,
+ bool enableIndexOpt)
+ : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
+ typeConverter),
+ enableIndexOptimizations(enableIndexOpt) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto dstType = op->getResult(0).getType().cast<VectorType>();
+ int64_t rank = dstType.getRank();
+ if (rank == 1) {
+ rewriter.replaceOp(
+ op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
+ dstType.getDimSize(0), operands[0]));
+ return success();
+ }
+ return failure();
+ }
+
+private:
+ const bool enableIndexOptimizations;
+};
+
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1121,17 +1192,19 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
/// Conversion pattern that converts a 1-D vector transfer read/write op in a
/// sequence of:
-/// 1. Bitcast or addrspacecast to vector form.
-/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-/// 3. Create a mask where offsetVector is compared against memref upper bound.
-/// 4. Rewrite op as a masked read or write.
+/// 1. Get the source/dst address as an LLVM vector pointer.
+/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+/// 4. Create a mask where offsetVector is compared against memref upper bound.
+/// 5. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
explicit VectorTransferConversion(MLIRContext *context,
- LLVMTypeConverter &typeConv)
- : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
- typeConv) {}
+ LLVMTypeConverter &typeConv,
+ bool enableIndexOpt)
+ : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
+ enableIndexOptimizations(enableIndexOpt) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1155,7 +1228,6 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
Location loc = op->getLoc();
- Type i64Type = rewriter.getIntegerType(64);
MemRefType memRefType = xferOp.getMemRefType();
if (auto memrefVectorElementType =
@@ -1202,41 +1274,26 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
xferOp, operands, vectorDataPtr);
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
- unsigned vecWidth = vecTy.getVectorNumElements();
- VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
- SmallVector<int64_t, 8> indices;
- indices.reserve(vecWidth);
- for (unsigned i = 0; i < vecWidth; ++i)
- indices.push_back(i);
- Value linearIndices = rewriter.create<ConstantOp>(
- loc, vectorCmpType,
- DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
- linearIndices = rewriter.create<LLVM::DialectCastOp>(
- loc, toLLVMTy(vectorCmpType), linearIndices);
-
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
- // TODO: when the leaf transfer rank is k > 1 we need the last
- // `k` dimensions here.
- unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
- Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
- offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
- Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
- Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
-
// 4. Let dim the memref dimension, compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+ //
+ // TODO: when the leaf transfer rank is k > 1, we need the last `k`
+ // dimensions here.
+ unsigned vecWidth = vecTy.getVectorNumElements();
+ unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
+ Value off = *(xferOp.indices().begin() + lastIndex);
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
- dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
- dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
- Value mask =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
- mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
- mask);
+ Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
+ vecWidth, dim, &off);
// 5. Rewrite as a masked read / write.
return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
operands, vectorDataPtr, mask);
}
+
+private:
+ const bool enableIndexOptimizations;
};
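Putting steps 1-5 together for a 1-D transfer_read, the conversion now
emits roughly the following (a sketch; %A and %i are illustrative):

  %f = vector.transfer_read %A[%i], %d : memref<?xf32>, vector<16xf32>

  // 1. bitcast the element pointer to !llvm.ptr<vec<16 x float>>
  // 2.-4. mask = ([0 .. 15] + splat(%i)) < splat(dim(%A, 0)), computed
  //       in i32 when enable-index-optimizations is set, via the
  //       buildVectorComparison helper above
  // 5. %f = llvm.intr.masked.load of the vector pointer under that mask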
class VectorPrintOpConversion : public ConvertToLLVMPattern {
@@ -1444,7 +1501,7 @@ class VectorExtractStridedSliceOpConversion
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
- bool reassociateFPReductions) {
+ bool reassociateFPReductions, bool enableIndexOptimizations) {
MLIRContext *ctx = converter.getDialect()->getContext();
// clang-format off
patterns.insert<VectorFMAOpNDRewritePattern,
@@ -1453,6 +1510,10 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorExtractStridedSliceOpConversion>(ctx);
patterns.insert<VectorReductionOpConversion>(
ctx, converter, reassociateFPReductions);
+ patterns.insert<VectorCreateMaskOpConversion,
+ VectorTransferConversion<TransferReadOp>,
+ VectorTransferConversion<TransferWriteOp>>(
+ ctx, converter, enableIndexOptimizations);
patterns
.insert<VectorShuffleOpConversion,
VectorExtractElementOpConversion,
@@ -1461,8 +1522,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertElementOpConversion,
VectorInsertOpConversion,
VectorPrintOpConversion,
- VectorTransferConversion<TransferReadOp>,
- VectorTransferConversion<TransferWriteOp>,
VectorTypeCastOpConversion,
VectorMaskedLoadOpConversion,
VectorMaskedStoreOpConversion,
@@ -1485,6 +1544,7 @@ struct LowerVectorToLLVMPass
: public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
this->reassociateFPReductions = options.reassociateFPReductions;
+ this->enableIndexOptimizations = options.enableIndexOptimizations;
}
void runOnOperation() override;
};
@@ -1505,15 +1565,14 @@ void LowerVectorToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
OwningRewritePatternList patterns;
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
- populateVectorToLLVMConversionPatterns(converter, patterns,
- reassociateFPReductions);
+ populateVectorToLLVMConversionPatterns(
+ converter, patterns, reassociateFPReductions, enableIndexOptimizations);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
- if (failed(applyPartialConversion(getOperation(), target, patterns))) {
+ if (failed(applyPartialConversion(getOperation(), target, patterns)))
signalPassFailure();
- }
}
std::unique_ptr<OperationPass<ModuleOp>>
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 16d10e558b5e..332bfbe2f457 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1347,7 +1347,8 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
auto eltType = dstType.getElementType();
auto dimSizes = op.mask_dim_sizes();
int64_t rank = dimSizes.size();
- int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
+ int64_t trueDim = std::min(dstType.getDimSize(0),
+ dimSizes[0].cast<IntegerAttr>().getInt());
if (rank == 1) {
// Express constant 1-D case in explicit vector form:
@@ -1402,21 +1403,8 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
int64_t rank = dstType.getRank();
Value idx = op.getOperand(0);
- if (rank == 1) {
- // Express dynamic 1-D case in explicit vector form:
- // mask = [0,1,..,n-1] < [a,a,..,a]
- SmallVector<int64_t, 4> values(dim);
- for (int64_t d = 0; d < dim; d++)
- values[d] = d;
- Value indices =
- rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
- Value bound =
- rewriter.create<IndexCastOp>(loc, rewriter.getI64Type(), idx);
- Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
- rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, indices,
- bounds);
- return success();
- }
+ if (rank == 1)
+ return failure(); // leave for lowering
VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
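The net effect on the vector-dialect lowering: a rank-1 create_mask is now
left intact so that the LLVM conversion, which knows about
enable-index-optimizations, can pick the index width, while higher ranks
still unroll into rank-1 masks combined with select/insert. For example:

  func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
    %0 = vector.create_mask %arg0 : vector<3xi1>  // unchanged by this pass
    return %0 : vector<3xi1>
  }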
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
new file mode 100644
index 000000000000..ec05e349897a
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=1' | FileCheck %s --check-prefix=CMP32
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=0' | FileCheck %s --check-prefix=CMP64
+
+// CMP32-LABEL: llvm.func @genbool_var_1d(
+// CMP32-SAME: %[[A:.*]]: !llvm.i64)
+// CMP32: %[[T0:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>) : !llvm.vec<11 x i32>
+// CMP32: %[[T1:.*]] = llvm.trunc %[[A]] : !llvm.i64 to !llvm.i32
+// CMP32: %[[T2:.*]] = llvm.mlir.undef : !llvm.vec<11 x i32>
+// CMP32: %[[T3:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CMP32: %[[T4:.*]] = llvm.insertelement %[[T1]], %[[T2]][%[[T3]] : !llvm.i32] : !llvm.vec<11 x i32>
+// CMP32: %[[T5:.*]] = llvm.shufflevector %[[T4]], %[[T2]] [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm.vec<11 x i32>, !llvm.vec<11 x i32>
+// CMP32: %[[T6:.*]] = llvm.icmp "slt" %[[T0]], %[[T5]] : !llvm.vec<11 x i32>
+// CMP32: llvm.return %[[T6]] : !llvm.vec<11 x i1>
+
+// CMP64-LABEL: llvm.func @genbool_var_1d(
+// CMP64-SAME: %[[A:.*]]: !llvm.i64)
+// CMP64: %[[T0:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>) : !llvm.vec<11 x i64>
+// CMP64: %[[T1:.*]] = llvm.mlir.undef : !llvm.vec<11 x i64>
+// CMP64: %[[T2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CMP64: %[[T3:.*]] = llvm.insertelement %[[A]], %[[T1]][%[[T2]] : !llvm.i32] : !llvm.vec<11 x i64>
+// CMP64: %[[T4:.*]] = llvm.shufflevector %[[T3]], %[[T1]] [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm.vec<11 x i64>, !llvm.vec<11 x i64>
+// CMP64: %[[T5:.*]] = llvm.icmp "slt" %[[T0]], %[[T4]] : !llvm.vec<11 x i64>
+// CMP64: llvm.return %[[T5]] : !llvm.vec<11 x i1>
+
+func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
+ %0 = vector.create_mask %arg0 : vector<11xi1>
+ return %0 : vector<11xi1>
+}
+
+// CMP32-LABEL: llvm.func @transfer_read_1d
+// CMP32: %[[C:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>) : !llvm.vec<16 x i32>
+// CMP32: %[[A:.*]] = llvm.add %{{.*}}, %[[C]] : !llvm.vec<16 x i32>
+// CMP32: %[[M:.*]] = llvm.icmp "slt" %[[A]], %{{.*}} : !llvm.vec<16 x i32>
+// CMP32: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
+// CMP32: llvm.return %[[L]] : !llvm.vec<16 x float>
+
+// CMP64-LABEL: llvm.func @transfer_read_1d
+// CMP64: %[[C:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi64>) : !llvm.vec<16 x i64>
+// CMP64: %[[A:.*]] = llvm.add %{{.*}}, %[[C]] : !llvm.vec<16 x i64>
+// CMP64: %[[M:.*]] = llvm.icmp "slt" %[[A]], %{{.*}} : !llvm.vec<16 x i64>
+// CMP64: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
+// CMP64: llvm.return %[[L]] : !llvm.vec<16 x float>
+
+func @transfer_read_1d(%A : memref<?xf32>, %i: index) -> vector<16xf32> {
+ %d = constant -1.0: f32
+ %f = vector.transfer_read %A[%i], %d {permutation_map = affine_map<(d0) -> (d0)>} : memref<?xf32>, vector<16xf32>
+ return %f : vector<16xf32>
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d35c7fa645b7..e0800c2fd227 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -749,10 +749,12 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK-SAME: (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
// CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<vec<17 x float>>
+// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
+// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(
-// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(dense
+// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
@@ -770,8 +772,6 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
//
// 4. Let dim the memref dimension, compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
-// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
// CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
@@ -799,9 +799,9 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
// CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<vec<17 x float>>
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-// CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(
-// CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
-// CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
+// CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(dense
+// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+// CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
//
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
@@ -832,6 +832,8 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
}
// CHECK-LABEL: func @transfer_read_2d_to_1d
// CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: !llvm.i64, %[[BASE_1:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm.vec<17 x float>
+// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 1] :
+// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
//
// Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
// CHECK: %[[offsetVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
@@ -847,8 +849,6 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
// Let dim the memref dimension, compute the vector comparison mask:
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
// Here we check we properly use %DIM[1]
-// CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 1] :
-// CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
// CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index e34e3428c185..aaaa7adf6472 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -785,43 +785,63 @@ func @genbool_3d() -> vector<2x3x4xi1> {
return %v: vector<2x3x4xi1>
}
-// CHECK-LABEL: func @genbool_var_1d
-// CHECK-SAME: %[[A:.*]]: index
-// CHECK: %[[C1:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
-// CHECK: %[[T0:.*]] = index_cast %[[A]] : index to i64
-// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
-// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[T1]] : vector<3xi64>
-// CHECK: return %[[T2]] : vector<3xi1>
+// CHECK-LABEL: func @genbool_var_1d(
+// CHECK-SAME: %[[A:.*]]: index)
+// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1>
+// CHECK: return %[[T0]] : vector<3xi1>
func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
%0 = vector.create_mask %arg0 : vector<3xi1>
return %0 : vector<3xi1>
}
-// CHECK-LABEL: func @genbool_var_2d
-// CHECK-SAME: %[[A:.*0]]: index
-// CHECK-SAME: %[[B:.*1]]: index
-// CHECK: %[[CI:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
-// CHECK: %[[CF:.*]] = constant dense<false> : vector<3xi1>
+// CHECK-LABEL: func @genbool_var_2d(
+// CHECK-SAME: %[[A:.*0]]: index,
+// CHECK-SAME: %[[B:.*1]]: index)
+// CHECK: %[[C1:.*]] = constant dense<false> : vector<3xi1>
// CHECK: %[[C2:.*]] = constant dense<false> : vector<2x3xi1>
// CHECK: %[[c0:.*]] = constant 0 : index
// CHECK: %[[c1:.*]] = constant 1 : index
-// CHECK: %[[T0:.*]] = index_cast %[[B]] : index to i64
-// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
-// CHECK: %[[T2:.*]] = cmpi "slt", %[[CI]], %[[T1]] : vector<3xi64>
-// CHECK: %[[T3:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
-// CHECK: %[[T4:.*]] = select %[[T3]], %[[T2]], %[[CF]] : vector<3xi1>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
-// CHECK: %[[T6:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
-// CHECK: %[[T7:.*]] = select %[[T6]], %[[T2]], %[[CF]] : vector<3xi1>
-// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T5]] [1] : vector<3xi1> into vector<2x3xi1>
-// CHECK: return %[[T8]] : vector<2x3xi1>
+// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
+// CHECK: %[[T1:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
+// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
+// CHECK: %[[T4:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
+// CHECK: %[[T5:.*]] = select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
+// CHECK: return %[[T6]] : vector<2x3xi1>
func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
%0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
return %0 : vector<2x3xi1>
}
+// CHECK-LABEL: func @genbool_var_3d(
+// CHECK-SAME: %[[A:.*0]]: index,
+// CHECK-SAME: %[[B:.*1]]: index,
+// CHECK-SAME: %[[C:.*2]]: index)
+// CHECK: %[[C1:.*]] = constant dense<false> : vector<7xi1>
+// CHECK: %[[C2:.*]] = constant dense<false> : vector<1x7xi1>
+// CHECK: %[[C3:.*]] = constant dense<false> : vector<2x1x7xi1>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: %[[c1:.*]] = constant 1 : index
+// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
+// CHECK: %[[T1:.*]] = cmpi "slt", %[[c0]], %[[B]] : index
+// CHECK: %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
+// CHECK: %[[T4:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
+// CHECK: %[[T5:.*]] = select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
+// CHECK: %[[T7:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
+// CHECK: %[[T8:.*]] = select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
+// CHECK: return %[[T9]] : vector<2x1x7xi1>
+
+func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> {
+ %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1>
+ return %0 : vector<2x1x7xi1>
+}
+
#matmat_accesses_0 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,