[Mlir-commits] [mlir] aef20f5 - [mlir][sparse] move from by-value to by-reference for data types

Aart Bik llvmlistbot at llvm.org
Fri Jun 17 08:39:40 PDT 2022


Author: Aart Bik
Date: 2022-06-17T08:39:25-07:00
New Revision: aef20f59a5210406a0b7aafd0a75eee708b8fcab

URL: https://github.com/llvm/llvm-project/commit/aef20f59a5210406a0b7aafd0a75eee708b8fcab
DIFF: https://github.com/llvm/llvm-project/commit/aef20f59a5210406a0b7aafd0a75eee708b8fcab.diff

LOG: [mlir][sparse] move from by-value to by-reference for data types

This fixes all sorts of ABI issues caused by passing values by-value
(values are now passed by-reference, exclusively through memrefs).

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D128018

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
    mlir/test/Dialect/SparseTensor/conversion.mlir
    mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
    mlir/test/Dialect/SparseTensor/sparse_index.mlir
    mlir/test/Dialect/SparseTensor/sparse_out.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index dddb25618f5c0..c3dc294e9b0b4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -192,7 +192,7 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
 
     ```mlir
     sparse_tensor.lex_insert %tensor, %indices, %val
-      : tensor<1024x1024xf64, #CSR>, memref<?xindex>, f64
+      : tensor<1024x1024xf64, #CSR>, memref<?xindex>, memref<f64>
     ```
   }];
   let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"

diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
index 7f2bbae0e71bb..fba5012860725 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -88,16 +88,8 @@ enum class PrimaryType : uint32_t {
   kC32 = 10
 };
 
-// This x-macro only specifies the non-complex `V` types, because the ABI
-// for complex types has compiler-/architecture-dependent details we need
-// to work around.  Namely, when a function takes a parameter of C/C++
-// type `complex32` (per se), then there is additional padding that causes
-// it not to match the LLVM type `!llvm.struct<(f32, f32)>`.  This only
-// happens with the `complex32` type itself, not with pointers/arrays
-// of complex values.  We also exclude `complex64` because it is in
-// principle susceptible to analogous ABI issues (even though we haven't
-// yet encountered them in practice).
-#define FOREVERY_SIMPLEX_V(DO)                                                 \
+// This x-macro includes all `V` types.
+#define FOREVERY_V(DO)                                                         \
   DO(F64, double)                                                              \
   DO(F32, float)                                                               \
   DO(F16, f16)                                                                 \
@@ -105,12 +97,7 @@ enum class PrimaryType : uint32_t {
   DO(I64, int64_t)                                                             \
   DO(I32, int32_t)                                                             \
   DO(I16, int16_t)                                                             \
-  DO(I8, int8_t)
-
-// This x-macro includes all `V` types, for when the aforementioned ABI
-// issues don't apply (e.g., because the functions take pointers/arrays).
-#define FOREVERY_V(DO)                                                         \
-  FOREVERY_SIMPLEX_V(DO)                                                       \
+  DO(I8, int8_t)                                                               \
   DO(C64, complex64)                                                           \
   DO(C32, complex32)
 
@@ -195,18 +182,11 @@ FOREVERY_O(DECL_SPARSEINDICES)
 /// Coordinate-scheme method for adding a new element.
 #define DECL_ADDELT(VNAME, V)                                                  \
   MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_addElt##VNAME(                   \
-      void *coo, V value, StridedMemRefType<index_type, 1> *iref,              \
+      void *coo,                                                               \
+      StridedMemRefType<V, 0> *vref, StridedMemRefType<index_type, 1> *iref,   \
       StridedMemRefType<index_type, 1> *pref);
-FOREVERY_SIMPLEX_V(DECL_ADDELT)
-DECL_ADDELT(C64, complex64)
+FOREVERY_V(DECL_ADDELT)
 #undef DECL_ADDELT
-// Explicitly unpack the `complex32` into a pair of `float` arguments,
-// to work around ABI issues.
-// TODO: cleaner way to avoid ABI padding problem?
-MLIR_CRUNNERUTILS_EXPORT void *
-_mlir_ciface_addEltC32(void *coo, float r, float i,
-                       StridedMemRefType<index_type, 1> *iref,
-                       StridedMemRefType<index_type, 1> *pref);
 
 /// Coordinate-scheme method for getting the next element while iterating.
 #define DECL_GETNEXT(VNAME, V)                                                 \
@@ -219,16 +199,10 @@ FOREVERY_V(DECL_GETNEXT)
 /// Tensor-storage method to insert elements in lexicographical index order.
 #define DECL_LEXINSERT(VNAME, V)                                               \
   MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_lexInsert##VNAME(                 \
-      void *tensor, StridedMemRefType<index_type, 1> *cref, V val);
-FOREVERY_SIMPLEX_V(DECL_LEXINSERT)
-DECL_LEXINSERT(C64, complex64)
+      void *tensor, StridedMemRefType<index_type, 1> *cref,                    \
+      StridedMemRefType<V, 0> *vref);
+FOREVERY_V(DECL_LEXINSERT)
 #undef DECL_LEXINSERT
-// Explicitly unpack the `complex32` into a pair of `float` arguments,
-// to work around ABI issues.
-// TODO: cleaner way to avoid ABI padding problem?
-MLIR_CRUNNERUTILS_EXPORT void
-_mlir_ciface_lexInsertC32(void *tensor, StridedMemRefType<index_type, 1> *cref,
-                          float r, float i);
 
 /// Tensor-storage method to insert using expansion.
 #define DECL_EXPINSERT(VNAME, V)                                               \

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 3ca7ff1c62cab..06168d5ef2c7f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -266,11 +266,11 @@ static void genDelCOOCall(OpBuilder &builder, Operation *op, Type elemTp,
 /// In particular, this generates code like the following:
 ///   val = a[i1,..,ik];
 ///   if val != 0
-///     t->add(val, [i1,..,ik], [p1,..,pk]);
+///     t->add(&val, [i1,..,ik], [p1,..,pk]);
 static void genAddEltCall(OpBuilder &builder, Operation *op, Type eltType,
-                          Value ptr, Value val, Value ind, Value perm) {
+                          Value ptr, Value valPtr, Value ind, Value perm) {
   SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
-  SmallVector<Value, 4> params{ptr, val, ind, perm};
+  SmallVector<Value, 4> params{ptr, valPtr, ind, perm};
   Type pTp = getOpaquePointerType(builder);
   createFuncCall(builder, op, name, pTp, params, EmitCInterface::On);
 }
@@ -674,6 +674,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       }
     }
     Type eltType = stp.getElementType();
+    Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
     scf::buildLoopNest(
         rewriter, op.getLoc(), lo, hi, st, {},
         [&](OpBuilder &builder, Location loc, ValueRange ivs,
@@ -684,7 +685,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
                                             ivs, rank);
           else
             val = genIndexAndValueForDense(rewriter, loc, src, ind, ivs);
-          genAddEltCall(rewriter, op, eltType, coo, val, ind, perm);
+          builder.create<memref::StoreOp>(loc, val, elemPtr);
+          genAddEltCall(rewriter, op, eltType, coo, elemPtr, ind, perm);
           return {};
         });
     // Final call to construct sparse tensor storage.

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d0b6758d00ab8..590e925a02d01 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -56,8 +56,8 @@ struct CodeGen {
         highs(numTensors, std::vector<Value>(numLoops)),
         pidxs(numTensors, std::vector<Value>(numLoops)),
         idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op),
-        outerParNest(nest), lexIdx(), expValues(), expFilled(), expAdded(),
-        expCount(), curVecMask() {}
+        outerParNest(nest), lexIdx(), lexVal(), expValues(), expFilled(),
+        expAdded(), expCount(), curVecMask() {}
   /// Sparsification options.
   SparsificationOptions options;
   /// Universal dense indices and upper bounds (by index). The loops array
@@ -89,6 +89,7 @@ struct CodeGen {
   OpOperand *sparseOut;
   unsigned outerParNest;
   Value lexIdx;
+  Value lexVal;
   Value expValues;
   Value expFilled;
   Value expAdded;
@@ -543,6 +544,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       auto dynShape = {ShapedType::kDynamicSize};
       auto memTp = MemRefType::get(dynShape, builder.getIndexType());
       codegen.lexIdx = builder.create<memref::AllocaOp>(loc, memTp, rank);
+      codegen.lexVal = builder.create<memref::AllocaOp>(
+          loc, MemRefType::get({}, elementType));
     } else {
       // Annotated sparse tensors.
       auto dynShape = {ShapedType::kDynamicSize};
@@ -723,7 +726,8 @@ static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
   Location loc = op.getLoc();
   // Direct insertion in lexicographic index order.
   if (!codegen.expValues) {
-    builder.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
+    builder.create<memref::StoreOp>(loc, rhs, codegen.lexVal);
+    builder.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, codegen.lexVal);
     return;
   }
   // Generates insertion code along expanded access pattern.

diff  --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 266ad5db74b68..b69bec2d5cc4b 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -1717,10 +1717,10 @@ FOREVERY_O(IMPL_SPARSEINDICES)
 #undef IMPL_GETOVERHEAD
 
 #define IMPL_ADDELT(VNAME, V)                                                  \
-  void *_mlir_ciface_addElt##VNAME(void *coo, V value,                         \
+  void *_mlir_ciface_addElt##VNAME(void *coo, StridedMemRefType<V, 0> *vref,   \
                                    StridedMemRefType<index_type, 1> *iref,     \
                                    StridedMemRefType<index_type, 1> *pref) {   \
-    assert(coo &&iref &&pref);                                                 \
+    assert(coo &&vref &&iref &&pref);                                          \
     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
     assert(iref->sizes[0] == pref->sizes[0]);                                  \
     const index_type *indx = iref->data + iref->offset;                        \
@@ -1729,25 +1729,12 @@ FOREVERY_O(IMPL_SPARSEINDICES)
     std::vector<index_type> indices(isize);                                    \
     for (uint64_t r = 0; r < isize; r++)                                       \
       indices[perm[r]] = indx[r];                                              \
-    static_cast<SparseTensorCOO<V> *>(coo)->add(indices, value);               \
+    V *value = vref->data + vref->offset;                                      \
+    static_cast<SparseTensorCOO<V> *>(coo)->add(indices, *value);              \
     return coo;                                                                \
   }
-FOREVERY_SIMPLEX_V(IMPL_ADDELT)
-IMPL_ADDELT(C64, complex64)
-// Marked static because it's not part of the public API.
-// NOTE: the `static` keyword confuses clang-format here, causing
-// the strange indentation of the `_mlir_ciface_addEltC32` prototype.
-// In C++11 we can add a semicolon after the call to `IMPL_ADDELT`
-// and that will correct clang-format.  Alas, this file is compiled
-// in C++98 mode where that semicolon is illegal (and there's no portable
-// macro magic to license a no-op semicolon at the top level).
-static IMPL_ADDELT(C32ABI, complex32)
+FOREVERY_V(IMPL_ADDELT)
 #undef IMPL_ADDELT
-    void *_mlir_ciface_addEltC32(void *coo, float r, float i,
-                                 StridedMemRefType<index_type, 1> *iref,
-                                 StridedMemRefType<index_type, 1> *pref) {
-  return _mlir_ciface_addEltC32ABI(coo, complex32(r, i), iref, pref);
-}
 
 #define IMPL_GETNEXT(VNAME, V)                                                 \
   bool _mlir_ciface_getNext##VNAME(void *coo,                                  \
@@ -1771,25 +1758,18 @@ FOREVERY_V(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
 
 #define IMPL_LEXINSERT(VNAME, V)                                               \
-  void _mlir_ciface_lexInsert##VNAME(                                          \
-      void *tensor, StridedMemRefType<index_type, 1> *cref, V val) {           \
-    assert(tensor &&cref);                                                     \
+  void _mlir_ciface_lexInsert##VNAME(void *tensor,                             \
+                                     StridedMemRefType<index_type, 1> *cref,   \
+                                     StridedMemRefType<V, 0> *vref) {          \
+    assert(tensor &&cref &&vref);                                              \
     assert(cref->strides[0] == 1);                                             \
     index_type *cursor = cref->data + cref->offset;                            \
     assert(cursor);                                                            \
-    static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, val);    \
+    V *value = vref->data + vref->offset;                                      \
+    static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, *value); \
   }
-FOREVERY_SIMPLEX_V(IMPL_LEXINSERT)
-IMPL_LEXINSERT(C64, complex64)
-// Marked static because it's not part of the public API.
-// NOTE: see the note for `_mlir_ciface_addEltC32ABI`
-static IMPL_LEXINSERT(C32ABI, complex32)
+FOREVERY_V(IMPL_LEXINSERT)
 #undef IMPL_LEXINSERT
-    void _mlir_ciface_lexInsertC32(void *tensor,
-                                   StridedMemRefType<index_type, 1> *cref,
-                                   float r, float i) {
-  _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
-}
 
 #define IMPL_EXPINSERT(VNAME, V)                                               \
   void _mlir_ciface_expInsert##VNAME(                                          \

diff  --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 5bc04016a15a3..2a85d012b98e3 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -190,12 +190,14 @@ func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32
 //       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[EmptyCOO]], %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<1xindex>
 //       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
+//       CHECK: %[[BUF:.*]] = memref.alloca() : memref<i32>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
 //       CHECK:   %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xi32>
 //       CHECK:   %[[N:.*]] = arith.cmpi ne, %[[E]], %[[I0]] : i32
 //       CHECK:   scf.if %[[N]] {
 //       CHECK:     memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
-//       CHECK:     call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]])
+//       CHECK:     memref.store %[[E]], %[[BUF]][] : memref<i32>
+//       CHECK:     call @addEltI32(%[[C]], %[[BUF]], %[[T]], %[[Z]])
 //       CHECK:   }
 //       CHECK: }
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
@@ -274,12 +276,14 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
 //       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[EmptyCOO]], %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
 //       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
+//       CHECK: %[[BUF:.*]] = memref.alloca() : memref<f64>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
 //       CHECK:   scf.for %[[J:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
 //       CHECK:     %[[E:.*]] = tensor.extract %[[A]][%[[I]], %[[J]]] : tensor<2x4xf64>
 //       CHECK:     memref.store %[[I]], %[[M]][%[[C0]]] : memref<2xindex>
 //       CHECK:     memref.store %[[J]], %[[M]][%[[C1]]] : memref<2xindex>
-//       CHECK:     call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
+//       CHECK:     memref.store %[[E]], %[[BUF]][] : memref<f64>
+//       CHECK:     call @addEltF64(%[[C]], %[[BUF]], %[[T]], %[[Z]])
 //       CHECK:   }
 //       CHECK: }
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
@@ -306,11 +310,13 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseM
 //       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[EmptyCOO]], %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
 //       CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
+//       CHECK: %[[BUF:.*]] = memref.alloca() : memref<f32>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
 //       CHECK:   memref.store %{{.*}}, %[[M]][%[[C0]]] : memref<2xindex>
 //       CHECK:   memref.store %{{.*}}, %[[M]][%[[C1]]] : memref<2xindex>
 //       CHECK:   %[[V:.*]] = tensor.extract %{{.*}}[%[[I]]] : tensor<2xf32>
-//       CHECK:   call @addEltF32(%{{.*}}, %[[V]], %[[N]], %{{.*}})
+//       CHECK:   memref.store %[[V]], %[[BUF]][] : memref<f32>
+//       CHECK:   call @addEltF32(%{{.*}}, %[[BUF]], %[[N]], %{{.*}})
 //       CHECK: }
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]])
 //       CHECK: call @delSparseTensorCOOF32(%[[C]])
@@ -343,6 +349,7 @@ func.func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
 //       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[EmptyCOO]], %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<3xindex>
 //       CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<3xindex> to memref<?xindex>
+//       CHECK: %[[BUF:.*]] = memref.alloca() : memref<f64>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U1]] step %[[C1]] {
 //       CHECK:   scf.for %[[J:.*]] = %[[C0]] to %[[U2]] step %[[C1]] {
 //       CHECK:     scf.for %[[K:.*]] = %[[C0]] to %[[U3]] step %[[C1]] {
@@ -350,7 +357,8 @@ func.func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
 //       CHECK:       memref.store %[[I]], %[[M]][%[[C0]]] : memref<3xindex>
 //       CHECK:       memref.store %[[J]], %[[M]][%[[C1]]] : memref<3xindex>
 //       CHECK:       memref.store %[[K]], %[[M]][%[[C2]]] : memref<3xindex>
-//       CHECK:       call @addEltF64(%[[C]], %[[E]], %[[N]], %[[Z]])
+//       CHECK:       memref.store %[[E]], %[[BUF]][] : memref<f64>
+//       CHECK:       call @addEltF64(%[[C]], %[[BUF]], %[[N]], %[[Z]])
 //       CHECK:     }
 //       CHECK:   }
 //       CHECK: }
@@ -493,13 +501,13 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
 // CHECK-LABEL: func @sparse_insert(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
 //  CHECK-SAME: %[[B:.*]]: memref<?xindex>,
-//  CHECK-SAME: %[[C:.*]]: f32) {
-//       CHECK: call @lexInsertF32(%[[A]], %[[B]], %[[C]]) : (!llvm.ptr<i8>, memref<?xindex>, f32) -> ()
+//  CHECK-SAME: %[[C:.*]]: memref<f32>) {
+//       CHECK: call @lexInsertF32(%[[A]], %[[B]], %[[C]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f32>) -> ()
 //       CHECK: return
 func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
                     %arg1: memref<?xindex>,
-                    %arg2: f32) {
-  sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf32, #SparseVector>, memref<?xindex>, f32
+                    %arg2: memref<f32>) {
+  sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf32, #SparseVector>, memref<?xindex>, memref<f32>
   return
 }
 

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
index 9b7506c0a9064..1fe0905aad613 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -360,6 +360,7 @@ func.func @divbyc(%arga: tensor<32xf64, #SV>,
 // CHECK:         %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_1]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:         %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
 // CHECK:         %[[VAL_8:.*]] = memref.alloca(%[[VAL_2]]) : memref<?xindex>
+// CHECK:         %[[BUF:.*]] = memref.alloca() : memref<f64>
 // CHECK:         %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
 // CHECK:         %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:         scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_2]] {
@@ -374,7 +375,8 @@ func.func @divbyc(%arga: tensor<32xf64, #SV>,
 // CHECK:           %[[VAL_19:.*]] = math.log1p %[[VAL_18]] : f64
 // CHECK:           %[[VAL_20:.*]] = math.sin %[[VAL_19]] : f64
 // CHECK:           %[[VAL_21:.*]] = math.tanh %[[VAL_20]] : f64
-// CHECK:           sparse_tensor.lex_insert %[[VAL_4]], %[[VAL_8]], %[[VAL_21]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref<?xindex>, f64
+// CHECK:           memref.store %[[VAL_21]], %[[BUF]][] : memref<f64>
+// CHECK:           sparse_tensor.lex_insert %[[VAL_4]], %[[VAL_8]], %[[BUF]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref<?xindex>, memref<f64>
 // CHECK:         }
 // CHECK:         %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_4]] hasInserts : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:         return %[[VAL_22]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
index bafeafaefbf0b..c90651f578c86 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_index.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -82,6 +82,7 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
 // CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK:           %[[VAL_12:.*]] = memref.alloca(%[[VAL_3]]) : memref<?xindex>
+// CHECK:           %[[BUF:.*]] = memref.alloca() : memref<i64>
 // CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_1]]] : memref<?xindex>
 // CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_2]] {
@@ -98,7 +99,8 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
 // CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xi64>
 // CHECK:               %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_24]] : i64
 // CHECK:               %[[VAL_26:.*]] = arith.muli %[[VAL_22]], %[[VAL_25]] : i64
-// CHECK:               sparse_tensor.lex_insert %[[VAL_6]], %[[VAL_12]], %[[VAL_26]] : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK:               memref.store %[[VAL_26]], %[[BUF]][] : memref<i64>
+// CHECK:               sparse_tensor.lex_insert %[[VAL_6]], %[[VAL_12]], %[[BUF]] : tensor<?x?xi64, #sparse_tensor.encoding
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_27:.*]] = sparse_tensor.load %[[VAL_6]] hasInserts : tensor<?x?xi64, #sparse_tensor.encoding

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 167b778b17f8e..96409e1271a85 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -111,6 +111,7 @@ func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR> {linalg.inplac
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:           %[[VAL_11:.*]] = memref.alloca(%[[VAL_5]]) : memref<?xindex>
+// CHECK:           %[[BUF:.*]] = memref.alloca() : memref<f32>
 // CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_2]] step %[[VAL_4]] {
 // CHECK:             memref.store %[[VAL_12]], %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref<?xindex>
 // CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xindex>
@@ -121,7 +122,8 @@ func.func @sparse_simply_dynamic2(%argx: tensor<32x16xf32, #DCSR> {linalg.inplac
 // CHECK:               memref.store %[[VAL_17]], %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<?xf32>
 // CHECK:               %[[VAL_19:.*]] = arith.mulf %[[VAL_18]], %[[VAL_1]] : f32
-// CHECK:               sparse_tensor.lex_insert %[[VAL_7]], %[[VAL_11]], %[[VAL_19]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:               memref.store %[[VAL_19]], %[[BUF]][] : memref<f32>
+// CHECK:               sparse_tensor.lex_insert %[[VAL_7]], %[[VAL_11]], %[[BUF]] : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
 // CHECK:             }
 // CHECK:           }
 // CHECK:           %[[VAL_20:.*]] = sparse_tensor.load %[[VAL_7]] hasInserts : tensor<10x20xf32, #sparse_tensor.encoding<{{.*}}>>
@@ -175,6 +177,7 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK:           %[[VAL_21:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
 // CHECK:           %[[VAL_22:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xi32>
 // CHECK:           %[[VAL_23:.*]] = memref.alloca(%[[VAL_4]]) : memref<?xindex>
+// CHECK:           %[[BUF:.*]] = memref.alloca() : memref<i32>
 // CHECK:           %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
 // CHECK:           %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK:           %[[VAL_26:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_2]]] : memref<?xindex>
@@ -255,7 +258,8 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK:                     %[[VAL_97:.*]] = arith.select %[[VAL_95]], %[[VAL_96]], %[[VAL_78]] : index
 // CHECK:                     scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_98:.*]] : index, index, i32
 // CHECK:                   }
-// CHECK:                   sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[VAL_99:.*]]#2 : tensor<?x?xi32, #{{.*}}>, memref<?xindex>, i32
+// CHECK:                   memref.store %[[VAL_70]]#2, %[[BUF]][] : memref<i32>
+// CHECK:                   sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[BUF]] : tensor<?x?xi32, #{{.*}}>, memref<?xindex>, memref<i32>
 // CHECK:                 } else {
 // CHECK:                 }
 // CHECK:                 %[[VAL_100:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index


        


More information about the Mlir-commits mailing list