[Mlir-commits] [mlir] 060c9dd - [mlir] [VectorOps] Improve SIMD compares with narrower indices

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 3 21:43:55 PDT 2020


Author: aartbik
Date: 2020-09-03T21:43:38-07:00
New Revision: 060c9dd1cc467cbeb6cf1c29dd44d07f562606b4

URL: https://github.com/llvm/llvm-project/commit/060c9dd1cc467cbeb6cf1c29dd44d07f562606b4
DIFF: https://github.com/llvm/llvm-project/commit/060c9dd1cc467cbeb6cf1c29dd44d07f562606b4.diff

LOG: [mlir] [VectorOps] Improve SIMD compares with narrower indices

When allowed, use 32-bit indices rather than 64-bit indices in the
SIMD computation of masks. This runs up to 2x and 4x faster on
a number of AVX2 and AVX512 microbenchmarks.

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D87116

Added: 
    mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6686e2865813..1b27a7308c7a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -358,7 +358,10 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
   let options = [
     Option<"reassociateFPReductions", "reassociate-fp-reductions",
            "bool", /*default=*/"false",
-           "Allows llvm to reassociate floating-point reductions for speed">
+           "Allows llvm to reassociate floating-point reductions for speed">,
+    Option<"enableIndexOptimizations", "enable-index-optimizations",
+           "bool", /*default=*/"false",
+           "Allows compiler to assume indices fit in 32-bit if that yields faster code">
   ];
 }
 

diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 82aa8287d90f..81ffa6328135 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -22,8 +22,13 @@ class OperationPass;
 /// ConvertVectorToLLVM pass in include/mlir/Conversion/Passes.td
 struct LowerVectorToLLVMOptions {
   bool reassociateFPReductions = false;
-  LowerVectorToLLVMOptions &setReassociateFPReductions(bool r) {
-    reassociateFPReductions = r;
+  bool enableIndexOptimizations = false;
+  LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
+    reassociateFPReductions = b;
+    return *this;
+  }
+  LowerVectorToLLVMOptions &setEnableIndexOptimizations(bool b) {
+    enableIndexOptimizations = b;
     return *this;
   }
 };
@@ -37,7 +42,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool reassociateFPReductions = false);
+    bool reassociateFPReductions = false,
+    bool enableIndexOptimizations = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index ecb047a1ad14..dfa204d17389 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -117,6 +117,49 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
   return res;
 }
 
+// Helper that builds and returns the 1-D mask comparison (signed less-than):
+//     mask = [0,1,..,dim-1] + [o,o,..,o] < [b,b,..,b]
+// where o is `*off` (the addition is skipped when `off` is null).
+// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
+//       much more compact, IR for this operation, but LLVM eventually
+//       generates more elaborate instructions for this intrinsic since it
+//       is very conservative on the boundary conditions.
+static Value buildVectorComparison(ConversionPatternRewriter &rewriter,
+                                   Operation *op, bool enableIndexOptimizations,
+                                   int64_t dim, Value b, Value *off = nullptr) {
+  auto loc = op->getLoc();
+  // With `enableIndexOptimizations`, indices are assumed to fit in 32-bit and
+  // the comparison uses i32 lanes for a higher degree of SIMD parallelism.
+  // Otherwise the comparison uses i64 lanes.
+  Value indices;
+  Type idxType;
+  if (enableIndexOptimizations) {
+    SmallVector<int32_t, 4> values(dim);
+    for (int64_t d = 0; d < dim; d++)
+      values[d] = d;
+    indices =
+        rewriter.create<ConstantOp>(loc, rewriter.getI32VectorAttr(values));
+    idxType = rewriter.getI32Type();
+  } else {
+    SmallVector<int64_t, 4> values(dim);
+    for (int64_t d = 0; d < dim; d++)
+      values[d] = d;
+    indices =
+        rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
+    idxType = rewriter.getI64Type();
+  }
+  // Add the splatted offset (cast to the lane type) to the linear indices.
+  if (off) {
+    Value o = rewriter.create<IndexCastOp>(loc, idxType, *off);
+    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
+    indices = rewriter.create<AddIOp>(loc, ov, indices);
+  }
+  // Cast and splat the bound, then compare: indices < bounds (signed).
+  Value bound = rewriter.create<IndexCastOp>(loc, idxType, b);
+  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+  return rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, indices, bounds);
+}
+
 // Helper that returns data layout alignment of an operation with memref.
 template <typename T>
 LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
@@ -512,10 +555,10 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorReductionOpConversion(MLIRContext *context,
                                        LLVMTypeConverter &typeConverter,
-                                       bool reassociateFP)
+                                       bool reassociateFPRed)
       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
                              typeConverter),
-        reassociateFPReductions(reassociateFP) {}
+        reassociateFPReductions(reassociateFPRed) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -589,6 +632,34 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
   const bool reassociateFPReductions;
 };
 
+/// Lowers rank-1 vector.create_mask to a vector comparison; else no match.
+class VectorCreateMaskOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
+                                        LLVMTypeConverter &typeConverter,
+                                        bool enableIndexOpt)
+      : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context,
+                             typeConverter),
+        enableIndexOptimizations(enableIndexOpt) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = op->getResult(0).getType().cast<VectorType>();
+    int64_t rank = dstType.getRank();
+    if (rank == 1) {
+      rewriter.replaceOp(
+          op, buildVectorComparison(rewriter, op, enableIndexOptimizations,
+                                    dstType.getDimSize(0), operands[0]));
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  const bool enableIndexOptimizations;
+};
+
 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1121,17 +1192,19 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
 
 /// Conversion pattern that converts a 1-D vector transfer read/write op in a
 /// sequence of:
-/// 1. Bitcast or addrspacecast to vector form.
-/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-/// 3. Create a mask where offsetVector is compared against memref upper bound.
-/// 4. Rewrite op as a masked read or write.
+/// 1. Get the source/dst address as an LLVM vector pointer.
+/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
+/// 4. Create a mask where offsetVector is compared against memref upper bound.
+/// 5. Rewrite op as a masked read or write.
 template <typename ConcreteOp>
 class VectorTransferConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorTransferConversion(MLIRContext *context,
-                                    LLVMTypeConverter &typeConv)
-      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
-                             typeConv) {}
+                                    LLVMTypeConverter &typeConv,
+                                    bool enableIndexOpt)
+      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv),
+        enableIndexOptimizations(enableIndexOpt) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -1155,7 +1228,6 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
 
     Location loc = op->getLoc();
-    Type i64Type = rewriter.getIntegerType(64);
     MemRefType memRefType = xferOp.getMemRefType();
 
     if (auto memrefVectorElementType =
@@ -1202,41 +1274,26 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
                                               xferOp, operands, vectorDataPtr);
 
     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-    unsigned vecWidth = vecTy.getVectorNumElements();
-    VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
-    SmallVector<int64_t, 8> indices;
-    indices.reserve(vecWidth);
-    for (unsigned i = 0; i < vecWidth; ++i)
-      indices.push_back(i);
-    Value linearIndices = rewriter.create<ConstantOp>(
-        loc, vectorCmpType,
-        DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
-    linearIndices = rewriter.create<LLVM::DialectCastOp>(
-        loc, toLLVMTy(vectorCmpType), linearIndices);
-
     // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-    // TODO: when the leaf transfer rank is k > 1 we need the last
-    // `k` dimensions here.
-    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
-    Value offsetIndex = *(xferOp.indices().begin() + lastIndex);
-    offsetIndex = rewriter.create<IndexCastOp>(loc, i64Type, offsetIndex);
-    Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
-    Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
-
     // 4. Let dim the memref dimension, compute the vector comparison mask:
     //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+    //
+    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
+    //       dimensions here.
+    unsigned vecWidth = vecTy.getVectorNumElements();
+    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
+    Value off = *(xferOp.indices().begin() + lastIndex);
     Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), lastIndex);
-    dim = rewriter.create<IndexCastOp>(loc, i64Type, dim);
-    dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
-    Value mask =
-        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
-    mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
-                                                mask);
+    Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations,
+                                       vecWidth, dim, &off);
 
     // 5. Rewrite as a masked read / write.
     return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp,
                                        operands, vectorDataPtr, mask);
   }
+
+private:
+  const bool enableIndexOptimizations;
 };
 
 class VectorPrintOpConversion : public ConvertToLLVMPattern {
@@ -1444,7 +1501,7 @@ class VectorExtractStridedSliceOpConversion
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
-    bool reassociateFPReductions) {
+    bool reassociateFPReductions, bool enableIndexOptimizations) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   // clang-format off
   patterns.insert<VectorFMAOpNDRewritePattern,
@@ -1453,6 +1510,10 @@ void mlir::populateVectorToLLVMConversionPatterns(
                   VectorExtractStridedSliceOpConversion>(ctx);
   patterns.insert<VectorReductionOpConversion>(
       ctx, converter, reassociateFPReductions);
+  patterns.insert<VectorCreateMaskOpConversion,
+                  VectorTransferConversion<TransferReadOp>,
+                  VectorTransferConversion<TransferWriteOp>>(
+      ctx, converter, enableIndexOptimizations);
   patterns
       .insert<VectorShuffleOpConversion,
               VectorExtractElementOpConversion,
@@ -1461,8 +1522,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
               VectorInsertElementOpConversion,
               VectorInsertOpConversion,
               VectorPrintOpConversion,
-              VectorTransferConversion<TransferReadOp>,
-              VectorTransferConversion<TransferWriteOp>,
               VectorTypeCastOpConversion,
               VectorMaskedLoadOpConversion,
               VectorMaskedStoreOpConversion,
@@ -1485,6 +1544,7 @@ struct LowerVectorToLLVMPass
     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
     this->reassociateFPReductions = options.reassociateFPReductions;
+    this->enableIndexOptimizations = options.enableIndexOptimizations;
   }
   void runOnOperation() override;
 };
@@ -1505,15 +1565,14 @@ void LowerVectorToLLVMPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext());
   OwningRewritePatternList patterns;
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(converter, patterns,
-                                         reassociateFPReductions);
+  populateVectorToLLVMConversionPatterns(
+      converter, patterns, reassociateFPReductions, enableIndexOptimizations);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateStdToLLVMConversionPatterns(converter, patterns);
 
   LLVMConversionTarget target(getContext());
-  if (failed(applyPartialConversion(getOperation(), target, patterns))) {
+  if (failed(applyPartialConversion(getOperation(), target, patterns)))
     signalPassFailure();
-  }
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 16d10e558b5e..332bfbe2f457 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1347,7 +1347,8 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
     auto eltType = dstType.getElementType();
     auto dimSizes = op.mask_dim_sizes();
     int64_t rank = dimSizes.size();
-    int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
+    int64_t trueDim = std::min(dstType.getDimSize(0),
+                               dimSizes[0].cast<IntegerAttr>().getInt());
 
     if (rank == 1) {
       // Express constant 1-D case in explicit vector form:
@@ -1402,21 +1403,8 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
     int64_t rank = dstType.getRank();
     Value idx = op.getOperand(0);
 
-    if (rank == 1) {
-      // Express dynamic 1-D case in explicit vector form:
-      //   mask = [0,1,..,n-1] < [a,a,..,a]
-      SmallVector<int64_t, 4> values(dim);
-      for (int64_t d = 0; d < dim; d++)
-        values[d] = d;
-      Value indices =
-          rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
-      Value bound =
-          rewriter.create<IndexCastOp>(loc, rewriter.getI64Type(), idx);
-      Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
-      rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, indices,
-                                          bounds);
-      return success();
-    }
+    if (rank == 1)
+      return failure(); // leave for lowering
 
     VectorType lowType =
         VectorType::get(dstType.getShape().drop_front(), eltType);

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
new file mode 100644
index 000000000000..ec05e349897a
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=1' | FileCheck %s --check-prefix=CMP32
+// RUN: mlir-opt %s --convert-vector-to-llvm='enable-index-optimizations=0' | FileCheck %s --check-prefix=CMP64
+
+// CMP32-LABEL: llvm.func @genbool_var_1d(
+// CMP32-SAME: %[[A:.*]]: !llvm.i64)
+// CMP32: %[[T0:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>) : !llvm.vec<11 x i32>
+// CMP32: %[[T1:.*]] = llvm.trunc %[[A]] : !llvm.i64 to !llvm.i32
+// CMP32: %[[T2:.*]] = llvm.mlir.undef : !llvm.vec<11 x i32>
+// CMP32: %[[T3:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CMP32: %[[T4:.*]] = llvm.insertelement %[[T1]], %[[T2]][%[[T3]] : !llvm.i32] : !llvm.vec<11 x i32>
+// CMP32: %[[T5:.*]] = llvm.shufflevector %[[T4]], %[[T2]] [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm.vec<11 x i32>, !llvm.vec<11 x i32>
+// CMP32: %[[T6:.*]] = llvm.icmp "slt" %[[T0]], %[[T5]] : !llvm.vec<11 x i32>
+// CMP32: llvm.return %[[T6]] : !llvm.vec<11 x i1>
+
+// CMP64-LABEL: llvm.func @genbool_var_1d(
+// CMP64-SAME: %[[A:.*]]: !llvm.i64)
+// CMP64: %[[T0:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>) : !llvm.vec<11 x i64>
+// CMP64: %[[T1:.*]] = llvm.mlir.undef : !llvm.vec<11 x i64>
+// CMP64: %[[T2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CMP64: %[[T3:.*]] = llvm.insertelement %[[A]], %[[T1]][%[[T2]] : !llvm.i32] : !llvm.vec<11 x i64>
+// CMP64: %[[T4:.*]] = llvm.shufflevector %[[T3]], %[[T1]] [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm.vec<11 x i64>, !llvm.vec<11 x i64>
+// CMP64: %[[T5:.*]] = llvm.icmp "slt" %[[T0]], %[[T4]] : !llvm.vec<11 x i64>
+// CMP64: llvm.return %[[T5]] : !llvm.vec<11 x i1>
+
+func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
+  %0 = vector.create_mask %arg0 : vector<11xi1>
+  return %0 : vector<11xi1>
+}
+
+// CMP32-LABEL: llvm.func @transfer_read_1d
+// CMP32: %[[C:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>) : !llvm.vec<16 x i32>
+// CMP32: %[[A:.*]] = llvm.add %{{.*}}, %[[C]] : !llvm.vec<16 x i32>
+// CMP32: %[[M:.*]] = llvm.icmp "slt" %[[A]], %{{.*}} : !llvm.vec<16 x i32>
+// CMP32: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
+// CMP32: llvm.return %[[L]] : !llvm.vec<16 x float>
+
+// CMP64-LABEL: llvm.func @transfer_read_1d
+// CMP64: %[[C:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi64>) : !llvm.vec<16 x i64>
+// CMP64: %[[A:.*]] = llvm.add %{{.*}}, %[[C]] : !llvm.vec<16 x i64>
+// CMP64: %[[M:.*]] = llvm.icmp "slt" %[[A]], %{{.*}} : !llvm.vec<16 x i64>
+// CMP64: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
+// CMP64: llvm.return %[[L]] : !llvm.vec<16 x float>
+
+func @transfer_read_1d(%A : memref<?xf32>, %i: index) -> vector<16xf32> {
+  %d = constant -1.0: f32
+  %f = vector.transfer_read %A[%i], %d {permutation_map = affine_map<(d0) -> (d0)>} : memref<?xf32>, vector<16xf32>
+  return %f : vector<16xf32>
+}

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d35c7fa645b7..e0800c2fd227 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -749,10 +749,12 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //  CHECK-SAME: (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
 //       CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
 //  CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<vec<17 x float>>
+//       CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
+//  CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(
-//  CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+//       CHECK: %[[linearIndex:.*]] = llvm.mlir.constant(dense
+//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 //  CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
 //
 // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
@@ -770,8 +772,6 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //
 // 4. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-//       CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
-//  CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
 //       CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
 //       CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
 //       CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :
@@ -799,9 +799,9 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //  CHECK-SAME: !llvm.ptr<float> to !llvm.ptr<vec<17 x float>>
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(
-//  CHECK-SAME: dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
-//  CHECK-SAME:  vector<17xi64>) : !llvm.vec<17 x i64>
+//       CHECK: %[[linearIndex_b:.*]] = llvm.mlir.constant(dense
+//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+//  CHECK-SAME: vector<17xi64>) : !llvm.vec<17 x i64>
 //
 // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 //       CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32,
@@ -832,6 +832,8 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
 }
 // CHECK-LABEL: func @transfer_read_2d_to_1d
 //  CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: !llvm.i64, %[[BASE_1:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm.vec<17 x float>
+//       CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 1] :
+//  CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 //
 // Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 //       CHECK: %[[offsetVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
@@ -847,8 +849,6 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
 // Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 // Here we check we properly use %DIM[1]
-//       CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 1] :
-//  CHECK-SAME: !llvm.struct<(ptr<float>, ptr<float>, i64, array<2 x i64>, array<2 x i64>)>
 //       CHECK: %[[dimVec:.*]] = llvm.mlir.undef : !llvm.vec<17 x i64>
 //       CHECK: %[[c01:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
 //       CHECK: %[[dimVec2:.*]] = llvm.insertelement %[[DIM]], %[[dimVec]][%[[c01]] :

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index e34e3428c185..aaaa7adf6472 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -785,43 +785,63 @@ func @genbool_3d() -> vector<2x3x4xi1> {
   return %v: vector<2x3x4xi1>
 }
 
-// CHECK-LABEL: func @genbool_var_1d
-// CHECK-SAME: %[[A:.*]]: index
-// CHECK:      %[[C1:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
-// CHECK:      %[[T0:.*]] = index_cast %[[A]] : index to i64
-// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
-// CHECK:      %[[T2:.*]] = cmpi "slt", %[[C1]], %[[T1]] : vector<3xi64>
-// CHECK:      return %[[T2]] : vector<3xi1>
+// CHECK-LABEL: func @genbool_var_1d(
+// CHECK-SAME: %[[A:.*]]: index)
+// CHECK:      %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1>
+// CHECK:      return %[[T0]] : vector<3xi1>
 
 func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
   %0 = vector.create_mask %arg0 : vector<3xi1>
   return %0 : vector<3xi1>
 }
 
-// CHECK-LABEL: func @genbool_var_2d
-// CHECK-SAME: %[[A:.*0]]: index
-// CHECK-SAME: %[[B:.*1]]: index
-// CHECK:      %[[CI:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
-// CHECK:      %[[CF:.*]] = constant dense<false> : vector<3xi1>
+// CHECK-LABEL: func @genbool_var_2d(
+// CHECK-SAME: %[[A:.*0]]: index,
+// CHECK-SAME: %[[B:.*1]]: index)
+// CHECK:      %[[C1:.*]] = constant dense<false> : vector<3xi1>
 // CHECK:      %[[C2:.*]] = constant dense<false> : vector<2x3xi1>
 // CHECK:      %[[c0:.*]] = constant 0 : index
 // CHECK:      %[[c1:.*]] = constant 1 : index
-// CHECK:      %[[T0:.*]] = index_cast %[[B]] : index to i64
-// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
-// CHECK:      %[[T2:.*]] = cmpi "slt", %[[CI]], %[[T1]] : vector<3xi64>
-// CHECK:      %[[T3:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
-// CHECK:      %[[T4:.*]] = select %[[T3]], %[[T2]], %[[CF]] : vector<3xi1>
-// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
-// CHECK:      %[[T6:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
-// CHECK:      %[[T7:.*]] = select %[[T6]], %[[T2]], %[[CF]] : vector<3xi1>
-// CHECK:      %[[T8:.*]] = vector.insert %[[T7]], %[[T5]] [1] : vector<3xi1> into vector<2x3xi1>
-// CHECK:      return %[[T8]] : vector<2x3xi1>
+// CHECK:      %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
+// CHECK:      %[[T1:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
+// CHECK:      %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
+// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
+// CHECK:      %[[T4:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
+// CHECK:      %[[T5:.*]] = select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
+// CHECK:      %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
+// CHECK:      return %[[T6]] : vector<2x3xi1>
 
 func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
   %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
   return %0 : vector<2x3xi1>
 }
 
+// CHECK-LABEL: func @genbool_var_3d(
+// CHECK-SAME: %[[A:.*0]]: index,
+// CHECK-SAME: %[[B:.*1]]: index,
+// CHECK-SAME: %[[C:.*2]]: index)
+// CHECK:      %[[C1:.*]] = constant dense<false> : vector<7xi1>
+// CHECK:      %[[C2:.*]] = constant dense<false> : vector<1x7xi1>
+// CHECK:      %[[C3:.*]] = constant dense<false> : vector<2x1x7xi1>
+// CHECK:      %[[c0:.*]] = constant 0 : index
+// CHECK:      %[[c1:.*]] = constant 1 : index
+// CHECK:      %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
+// CHECK:      %[[T1:.*]] = cmpi "slt", %[[c0]], %[[B]] : index
+// CHECK:      %[[T2:.*]] = select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
+// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
+// CHECK:      %[[T4:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
+// CHECK:      %[[T5:.*]] = select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
+// CHECK:      %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
+// CHECK:      %[[T7:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
+// CHECK:      %[[T8:.*]] = select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
+// CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
+// CHECK:      return %[[T9]] : vector<2x1x7xi1>
+
+func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> {
+  %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1>
+  return %0 : vector<2x1x7xi1>
+}
+
 #matmat_accesses_0 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,


        


More information about the Mlir-commits mailing list