[Mlir-commits] [mlir] 7c5ecc8 - [mlir][vector] Insert/extract element can accept index

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 18 14:40:34 PST 2021


Author: Mogball
Date: 2021-11-18T22:40:29Z
New Revision: 7c5ecc8b7e1bcd1b02eafeba9bbf3d5bc50d72c5

URL: https://github.com/llvm/llvm-project/commit/7c5ecc8b7e1bcd1b02eafeba9bbf3d5bc50d72c5
DIFF: https://github.com/llvm/llvm-project/commit/7c5ecc8b7e1bcd1b02eafeba9bbf3d5bc50d72c5.diff

LOG: [mlir][vector] Insert/extract element can accept index

The `position` operand of `vector::InsertElementOp` and `vector::ExtractElementOp` has been
changed to accept `AnySignlessIntegerOrIndex`. This improves interoperability with operations
that produce `index` values, such as affine loops.

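LLVM's `extractelement` and `insertelement` also accept `i64` positions, so these ops can be
lowered directly without inserting explicit index casts. SPIR-V's equivalent ops can also accept
`i64`. However, the SPIR-V conversion does not yet legalize ops with an `index` position, as the
new `*_index` tests expect.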

Reviewed By: nicolasvasilache, jpienaar

Differential Revision: https://reviews.llvm.org/D114139

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
    mlir/test/Conversion/VectorToSPIRV/simple.mlir
    mlir/test/Dialect/SparseTensor/sparse_vector.mlir
    mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
    mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 06bb0a52e9d5f..bbd45b78ecaf2 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -482,7 +482,7 @@ def Vector_ExtractElementOp :
      TypesMatchWith<"result type matches element type of vector operand",
                     "vector", "result",
                     "$_self.cast<ShapedType>().getElementType()">]>,
-    Arguments<(ins AnyVector:$vector, AnySignlessInteger:$position)>,
+    Arguments<(ins AnyVector:$vector, AnySignlessIntegerOrIndex:$position)>,
     Results<(outs AnyType:$result)> {
   let summary = "extractelement operation";
   let description = [{
@@ -504,7 +504,6 @@ def Vector_ExtractElementOp :
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$source, "int64_t":$position)>,
     OpBuilder<(ins "Value":$source, "Value":$position)>
   ];
   let extraClassDeclaration = [{
@@ -658,7 +657,7 @@ def Vector_InsertElementOp :
                     "$_self.cast<ShapedType>().getElementType()">,
      AllTypesMatch<["dest", "result"]>]>,
      Arguments<(ins AnyType:$source, AnyVector:$dest,
-                    AnySignlessInteger:$position)>,
+                    AnySignlessIntegerOrIndex:$position)>,
      Results<(outs AnyVector:$result)> {
   let summary = "insertelement operation";
   let description = [{
@@ -683,7 +682,6 @@ def Vector_InsertElementOp :
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "Value":$position)>
   ];
   let extraClassDeclaration = [{

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 1398bd7aac3c6..e088d037f950e 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -502,6 +502,10 @@ def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index",
                  "::mlir::IndexType">,
             BuildableType<"$_builder.getIndexType()">;
 
+// Any signless integer type or index type.
+def AnySignlessIntegerOrIndex : Type<CPred<"$_self.isSignlessIntOrIndex()">,
+                                     "signless integer or index">;
+
 // Floating point types.
 
 // Any float type irrespective of its width.
@@ -823,9 +827,9 @@ def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
 // Type constraint for signless-integer-like types: signless integers, indices,
 // vectors of signless integers or indices, tensors of signless integers.
 def SignlessIntegerLike : TypeConstraint<Or<[
-        AnySignlessInteger.predicate, Index.predicate,
-        VectorOf<[AnySignlessInteger, Index]>.predicate,
-        TensorOf<[AnySignlessInteger, Index]>.predicate]>,
+        AnySignlessIntegerOrIndex.predicate,
+        VectorOf<[AnySignlessIntegerOrIndex]>.predicate,
+        TensorOf<[AnySignlessIntegerOrIndex]>.predicate]>,
     "signless-integer-like">;
 
 // Type constraint for float-like types: floats, vectors or tensors thereof.

diff  --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2fd4959884981..6d2c91f19bb68 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -123,9 +123,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
     return Value();
 
   Location loc = xferOp.getLoc();
-  Value ivI32 = b.create<arith::IndexCastOp>(
-      loc, IntegerType::get(b.getContext(), 32), iv);
-  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
+  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), iv);
 }
 
 /// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -1115,8 +1113,6 @@ struct Strategy1d<TransferReadOp> {
                                   ValueRange loopState) {
     SmallVector<Value, 8> indices;
     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
-    Value ivI32 = b.create<arith::IndexCastOp>(
-        loc, IntegerType::get(b.getContext(), 32), iv);
     auto vec = loopState[0];
 
     // In case of out-of-bounds access, leave `vec` as is (was initialized with
@@ -1126,7 +1122,7 @@ struct Strategy1d<TransferReadOp> {
         /*inBoundsCase=*/
         [&](OpBuilder &b, Location loc) {
           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
-          return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
+          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
         },
         /*outOfBoundsCase=*/
         [&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1148,15 +1144,13 @@ struct Strategy1d<TransferWriteOp> {
                                   ValueRange /*loopState*/) {
     SmallVector<Value, 8> indices;
     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
-    Value ivI32 = b.create<arith::IndexCastOp>(
-        loc, IntegerType::get(b.getContext(), 32), iv);
 
     // Nothing to do in case of out-of-bounds access.
     generateInBoundsCheck(
         b, xferOp, iv, dim,
         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
           auto val =
-              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
+              b.create<vector::ExtractElementOp>(loc, xferOp.vector(), iv);
           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
         });
     b.create<scf::YieldOp>(loc);

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 99e7a6684a909..7c63fce6919be 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -880,6 +880,11 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
   valueTypes.append(op->operand_type_begin(), op->operand_type_end());
   valueTypes.append(op->result_type_begin(), op->result_type_end());
 
+  // Ensure that all types have been converted to SPIRV types.
+  if (llvm::any_of(valueTypes,
+                   [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
+    return false;
+
   // Special treatment for global variables, whose type requirements are
   // conveyed by type attributes.
   if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 31d3ee520fbd7..676c86ec2f711 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -391,7 +391,8 @@ static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
     // Initialize reduction vector to: | 0 | .. | 0 | r |
     Attribute zero = rewriter.getZeroAttr(vtp);
     Value vec = rewriter.create<arith::ConstantOp>(loc, vtp, zero);
-    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
+    return rewriter.create<vector::InsertElementOp>(
+        loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0));
   }
   case kProduct: {
     // Initialize reduction vector to: | 1 | .. | 1 | r |
@@ -403,7 +404,8 @@ static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
       one = rewriter.getIntegerAttr(etp, 1);
     Value vec = rewriter.create<arith::ConstantOp>(
         loc, vtp, DenseElementsAttr::get(vtp, one));
-    return rewriter.create<vector::InsertElementOp>(loc, r, vec, 0);
+    return rewriter.create<vector::InsertElementOp>(
+        loc, r, vec, rewriter.create<arith::ConstantIndexOp>(loc, 0));
   }
   case kAnd:
   case kOr:

diff  --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
index 84102f0fe2a50..0f1c4011b231c 100644
--- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -347,8 +347,9 @@ struct TwoDimMultiReductionToReduction
           loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
           rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
           ValueRange{});
-      result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
-                                                        result, i);
+      result = rewriter.create<vector::InsertElementOp>(
+          loc, reducedValue, result,
+          rewriter.create<arith::ConstantIndexOp>(loc, i));
     }
     rewriter.replaceOp(multiReductionOp, result);
     return success();

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index b03c4ecf867ae..3ca8fa0dcf0be 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -838,13 +838,6 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
   result.addTypes(source.getType().cast<VectorType>().getElementType());
 }
 
-void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
-                                     Value source, int64_t position) {
-  Value pos =
-      builder.create<arith::ConstantIntOp>(result.location, position, 32);
-  build(builder, result, source, pos);
-}
-
 static LogicalResult verify(vector::ExtractElementOp op) {
   VectorType vectorType = op.getVectorType();
   if (vectorType.getRank() != 1)
@@ -1505,13 +1498,6 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result,
   result.addTypes(dest.getType());
 }
 
-void InsertElementOp::build(OpBuilder &builder, OperationState &result,
-                            Value source, Value dest, int64_t position) {
-  Value pos =
-      builder.create<arith::ConstantIntOp>(result.location, position, 32);
-  build(builder, result, source, dest, pos);
-}
-
 static LogicalResult verify(InsertElementOp op) {
   auto dstVectorType = op.getDestVectorType();
   if (dstVectorType.getRank() != 1)

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7f775161545d3..d5d8509cfa61f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -431,6 +431,20 @@ func @extract_element(%arg0: vector<16xf32>) -> f32 {
 
 // -----
 
+func @extract_element_index(%arg0: vector<16xf32>) -> f32 {
+  %0 = arith.constant 15 : index
+  %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32>
+  return %1 : f32
+}
+// CHECK-LABEL: @extract_element_index(
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>)
+//       CHECK:   %[[c:.*]] = arith.constant 15 : index
+//       CHECK:   %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
+//       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[i]] : i64] : vector<16xf32>
+//       CHECK:   return %[[x]] : f32
+
+// -----
+
 func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 {
   %0 = vector.extract %arg0[15]: vector<16xf32>
   return %0 : f32
@@ -502,6 +516,21 @@ func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
 
 // -----
 
+func @insert_element_index(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
+  %0 = arith.constant 3 : index
+  %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32>
+  return %1 : vector<4xf32>
+}
+// CHECK-LABEL: @insert_element_index(
+//  CHECK-SAME: %[[A:.*]]: f32,
+//  CHECK-SAME: %[[B:.*]]: vector<4xf32>)
+//       CHECK:   %[[c:.*]] = arith.constant 3 : index
+//       CHECK:   %[[i:.*]] = builtin.unrealized_conversion_cast %[[c]] : index to i64
+//       CHECK:   %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[i]] : i64] : vector<4xf32>
+//       CHECK:   return %[[x]] : vector<4xf32>
+
+// -----
+
 func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32>
   return %0 : vector<4xf32>

diff  --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 9279c4791085a..08b3ffbdb688e 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -8,16 +8,14 @@ func @vector_transfer_ops_0d(%M: memref<f32>) {
 
 //  CHECK: %[[V0:.*]] = arith.constant dense<0{{.*}}> : vector<1xf32>
 //  CHECK: %[[R0:.*]] = scf.for %[[I:.*]] = {{.*}} iter_args(%[[V0_ITER:.*]] = %[[V0]]) -> (vector<1xf32>) {
-//  CHECK:   %[[IDX:.*]] = arith.index_cast %[[I]] : index to i32
 //  CHECK:   %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-//  CHECK:   %[[R_ITER:.*]] = vector.insertelement %[[S]], %[[V0_ITER]][%[[IDX]] : i32] : vector<1xf32>
+//  CHECK:   %[[R_ITER:.*]] = vector.insertelement %[[S]], %[[V0_ITER]][%[[I]] : index] : vector<1xf32>
 //  CHECK:   scf.yield %[[R_ITER]] : vector<1xf32>
     %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
       memref<f32>, vector<1xf32>
 
 //  CHECK: scf.for %[[J:.*]] = %{{.*}}
-//  CHECK:   %[[JDX:.*]] = arith.index_cast %[[J]] : index to i32
-//  CHECK:   %[[SS:.*]] = vector.extractelement %[[R0]][%[[JDX]] : i32] : vector<1xf32>
+//  CHECK:   %[[SS:.*]] = vector.extractelement %[[R0]][%[[J]] : index] : vector<1xf32>
 //  CHECK:   memref.store %[[SS]], %[[MEM]][] : memref<f32>
     vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
       vector<1xf32>, memref<f32>
@@ -107,10 +105,9 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
   // CHECK:                   scf.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
   // CHECK:                     %[[VEC:.*]] = scf.for %[[I6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {{.*}} -> (vector<3xf32>) {
   // CHECK:                       %[[L0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
-  // CHECK:                       %[[VIDX:.*]] = arith.index_cast %[[I6]]
   // CHECK:                       scf.if {{.*}} -> (vector<3xf32>) {
   // CHECK-NEXT:                    %[[SCAL:.*]] = memref.load %{{.*}}[%[[L0]], %[[I1]], %[[I2]], %[[L3]]] : memref<?x?x?x?xf32>
-  // CHECK-NEXT:                    %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[VIDX]] : i32] : vector<3xf32>
+  // CHECK-NEXT:                    %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[I6]] : index] : vector<3xf32>
   // CHECK-NEXT:                    scf.yield
   // CHECK-NEXT:                  } else {
   // CHECK-NEXT:                    scf.yield
@@ -181,9 +178,8 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
   // CHECK:                      %[[VEC:.*]] = memref.load %[[VECTOR_VIEW2]][%[[I4]], %[[I5]]] : memref<5x4xvector<3xf32>>
   // CHECK:                      scf.for %[[I6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
   // CHECK:                        %[[S0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
-  // CHECK:                        %[[VIDX:.*]] = arith.index_cast %[[I6]]
   // CHECK:                        scf.if
-  // CHECK:                          %[[SCAL:.*]] = vector.extractelement %[[VEC]][%[[VIDX]] : i32] : vector<3xf32>
+  // CHECK:                          %[[SCAL:.*]] = vector.extractelement %[[VEC]][%[[I6]] : index] : vector<3xf32>
   // CHECK:                          memref.store %[[SCAL]], {{.*}}[%[[S0]], %[[S1]], %[[I2]], %[[S3]]] : memref<?x?x?x?xf32>
   // CHECK:                        }
   // CHECK:                      }

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 4a471a4108531..a253fc7fbbcbe 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -80,6 +80,14 @@ func @extract_element(%arg0 : vector<4xf32>, %id : i32) {
 
 // -----
 
+func @extract_element_index(%arg0 : vector<4xf32>, %id : index) {
+// expected-error @+1 {{failed to legalize operation 'vector.extractelement'}}
+  %0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
+  spv.ReturnValue %0: f32
+}
+
+// -----
+
 func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
 // expected-error @+1 {{failed to legalize operation 'vector.extractelement'}}
   %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
@@ -110,6 +118,14 @@ func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) {
 
 // -----
 
+func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) {
+// expected-error @+1 {{failed to legalize operation 'vector.insertelement'}}
+  %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
+  spv.ReturnValue %0: vector<4xf32>
+}
+
+// -----
+
 func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
 // expected-error @+1 {{failed to legalize operation 'vector.insertelement'}}
   %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
index 3752cdb8ed5b7..7f27922ecdfe4 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -210,12 +210,11 @@ func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %ar
 //
 // CHECK-VEC1-LABEL: func @reduction_d
 // CHECK-VEC1-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-VEC1-DAG:   %[[i0:.*]] = arith.constant 0 : i32
 // CHECK-VEC1-DAG:   %[[c16:.*]] = arith.constant 16 : index
 // CHECK-VEC1-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC1-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC1:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC1:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[i0]] : i32] : vector<16xf32>
+// CHECK-VEC1:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
 // CHECK-VEC1:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC1:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC1:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -228,12 +227,11 @@ func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %ar
 //
 // CHECK-VEC2-LABEL: func @reduction_d
 // CHECK-VEC2-DAG:   %[[c0:.*]] = arith.constant 0 : index
-// CHECK-VEC2-DAG:   %[[i0:.*]] = arith.constant 0 : i32
 // CHECK-VEC2-DAG:   %[[c16:.*]] = arith.constant 16 : index
 // CHECK-VEC2-DAG:   %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC2-DAG:   %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC2:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC2:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[i0]] : i32] : vector<16xf32>
+// CHECK-VEC2:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
 // CHECK-VEC2:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC2:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC2:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index ab694417a38c3..030eff2bf9d2d 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -23,7 +23,6 @@
 // CHECK-SAME:      %[[VAL_2:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<f64> {
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<8xf64>
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 64 : index
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
@@ -89,7 +88,7 @@
 // CHECK:               %[[VAL_63:.*]] = select %[[VAL_61]], %[[VAL_62]], %[[VAL_34]] : index
 // CHECK:               scf.yield %[[VAL_60]], %[[VAL_63]], %[[VAL_64:.*]] : index, index, f64
 // CHECK:             }
-// CHECK:             %[[VAL_65:.*]] = vector.insertelement %[[VAL_66:.*]]#2, %[[VAL_3]]{{\[}}%[[VAL_5]] : i32] : vector<8xf64>
+// CHECK:             %[[VAL_65:.*]] = vector.insertelement %[[VAL_66:.*]]#2, %[[VAL_3]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
 // CHECK:             %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_66]]#0 to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_69:.*]] = %[[VAL_65]]) -> (vector<8xf64>) {
 // CHECK:               %[[VAL_70:.*]] = affine.min #map(%[[VAL_22]], %[[VAL_68]])
 // CHECK:               %[[VAL_71:.*]] = vector.create_mask %[[VAL_70]] : vector<8xi1>

diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 42a90f96e9dbd..7e6c1713d455e 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -7,14 +7,14 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
 // CHECK-LABEL: func @vector_multi_reduction
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>
 //       CHECK:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<2xf32>
-//       CHECK:       %[[C0:.+]] = arith.constant 0 : i32
-//       CHECK:       %[[C1:.+]] = arith.constant 1 : i32
+//       CHECK:       %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:       %[[C1:.+]] = arith.constant 1 : index
 //       CHECK:       %[[V0:.+]] = vector.extract %[[INPUT]][0]
 //       CHECK:       %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<4xf32> into f32
-//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<2xf32>
+//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<2xf32>
 //       CHECK:       %[[V1:.+]] = vector.extract %[[INPUT]][1]
 //       CHECK:       %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<4xf32> into f32
-//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32>
+//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
 //       CHECK:       return %[[RESULT_VEC]]
 
 func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>) -> f32 {
@@ -36,31 +36,31 @@ func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
 // CHECK-LABEL: func @vector_reduction_inner
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<2x3x4x5xi32>
 //       CHECK:       %[[FLAT_RESULT_VEC_0:.+]] = arith.constant dense<0> : vector<6xi32>
-//   CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : i32
-//   CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : i32
-//   CHECK-DAG:       %[[C2:.+]] = arith.constant 2 : i32
-//   CHECK-DAG:       %[[C3:.+]] = arith.constant 3 : i32
-//   CHECK-DAG:       %[[C4:.+]] = arith.constant 4 : i32
-//   CHECK-DAG:       %[[C5:.+]] = arith.constant 5 : i32
+//   CHECK-DAG:       %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:       %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:       %[[C3:.+]] = arith.constant 3 : index
+//   CHECK-DAG:       %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:       %[[C5:.+]] = arith.constant 5 : index
 //       CHECK:       %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
 //       CHECK:       %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32>
 //       CHECK:       %[[V0R:.+]] = vector.reduction "add", %[[V0]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : i32] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : index] : vector<6xi32>
 //       CHECK:       %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32>
 //       CHECK:       %[[V1R:.+]] = vector.reduction "add", %[[V1]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : i32] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : index] : vector<6xi32>
 //       CHECK:       %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32>
 //       CHECK:       %[[V2R:.+]] = vector.reduction "add", %[[V2]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : i32] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : index] : vector<6xi32>
 //       CHECK:       %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32>
 //       CHECK:       %[[V3R:.+]] = vector.reduction "add", %[[V3]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : i32] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : index] : vector<6xi32>
 //       CHECK:       %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32>
 //       CHECK:       %[[V4R:.+]] = vector.reduction "add", %[[V4]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : i32] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : index] : vector<6xi32>
 ///       CHECK:      %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32>
 //       CHECK:       %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32
-//       CHECK:       %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32>
+//       CHECK:       %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : index] : vector<6xi32>
 //       CHECK:       %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
 //       CHECK:       return %[[RESULT]]
 
@@ -84,38 +84,38 @@ func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf3
 // CHECK-LABEL: func @vector_multi_reduction_ordering
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<3x2x4xf32>
 //       CHECK:       %[[RESULT_VEC_0:.+]] = arith.constant dense<{{.*}}> : vector<8xf32>
-//       CHECK:       %[[C0:.+]] = arith.constant 0 : i32
-//       CHECK:       %[[C1:.+]] = arith.constant 1 : i32
-//       CHECK:       %[[C2:.+]] = arith.constant 2 : i32
-//       CHECK:       %[[C3:.+]] = arith.constant 3 : i32
-//       CHECK:       %[[C4:.+]] = arith.constant 4 : i32
-//       CHECK:       %[[C5:.+]] = arith.constant 5 : i32
-//       CHECK:       %[[C6:.+]] = arith.constant 6 : i32
-//       CHECK:       %[[C7:.+]] = arith.constant 7 : i32
+//       CHECK:       %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:       %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:       %[[C2:.+]] = arith.constant 2 : index
+//       CHECK:       %[[C3:.+]] = arith.constant 3 : index
+//       CHECK:       %[[C4:.+]] = arith.constant 4 : index
+//       CHECK:       %[[C5:.+]] = arith.constant 5 : index
+//       CHECK:       %[[C6:.+]] = arith.constant 6 : index
+//       CHECK:       %[[C7:.+]] = arith.constant 7 : index
 //       CHECK:       %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32>
 //       CHECK:       %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
 //       CHECK:       %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : index] : vector<8xf32>
 //       CHECK:       %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
 //       CHECK:       %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<8xf32>
 //       CHECK:       %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2]
 //       CHECK:       %[[RV2:.+]] = vector.reduction "mul", %[[V2]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : i32] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : index] : vector<8xf32>
 //       CHECK:       %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3]
 //       CHECK:       %[[RV3:.+]] = vector.reduction "mul", %[[V3]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : i32] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : index] : vector<8xf32>
 //       CHECK:       %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
 //       CHECK:       %[[RV4:.+]] = vector.reduction "mul", %[[V4]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : i32] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : index] : vector<8xf32>
 //       CHECK:       %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
 //       CHECK:       %[[RV5:.+]] = vector.reduction "mul", %[[V5]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : i32] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : index] : vector<8xf32>
 //       CHECK:       %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2]
 //       CHECK:       %[[RV6:.+]] = vector.reduction "mul", %[[V6]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : i32] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : index] : vector<8xf32>
 //       CHECK:       %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3]
 //       CHECK:       %[[RV7:.+]] = vector.reduction "mul", %[[V7]] : vector<3xf32> into f32
-//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : i32] : vector<8xf32>
+//       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32>
 //       CHECK:       %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
 //       CHECK:       return %[[RESHAPED_VEC]]


        


More information about the Mlir-commits mailing list