[Mlir-commits] [mlir] 8d615a2 - [mlir][sparse] fix crash on sparse_tensor.foreach operation on tensors with complex<T> elements.

Peiming Liu llvmlistbot at llvm.org
Thu Nov 17 11:36:21 PST 2022


Author: Peiming Liu
Date: 2022-11-17T19:36:15Z
New Revision: 8d615a23ef781c91da06e45337229c19f2c777d7

URL: https://github.com/llvm/llvm-project/commit/8d615a23ef781c91da06e45337229c19f2c777d7
DIFF: https://github.com/llvm/llvm-project/commit/8d615a23ef781c91da06e45337229c19f2c777d7.diff

LOG: [mlir][sparse] fix crash on sparse_tensor.foreach operation on tensors with complex<T> elements.

Reviewed By: aartbik, bixia

Differential Revision: https://reviews.llvm.org/D138223

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 0db86dd2e8c16..fa36d083e285e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -1024,3 +1024,36 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
     ;
   return op;
 }
+
+void sparse_tensor::foreachInSparseConstant(
+    Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
+    function_ref<void(ArrayRef<Value>, Value)> callback) {
+  int64_t rank = attr.getType().getRank();
+  // Foreach on constant.
+  DenseElementsAttr indicesAttr = attr.getIndices();
+  DenseElementsAttr valuesAttr = attr.getValues();
+
+  SmallVector<Value> coords;
+  for (int i = 0, e = valuesAttr.size(); i < e; i++) {
+    coords.clear();
+    for (int j = 0; j < rank; j++) {
+      auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
+      auto coord =
+          rewriter.create<arith::ConstantIndexOp>(loc, coordAttr.getInt());
+      // Remaps coordinates.
+      coords.push_back(coord);
+    }
+    Value val;
+    if (attr.getElementType().isa<ComplexType>()) {
+      auto valAttr = valuesAttr.getValues<ArrayAttr>()[i];
+      val = rewriter.create<complex::ConstantOp>(loc, attr.getElementType(),
+                                                 valAttr);
+    } else {
+      auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
+      // Remaps value.
+      val = rewriter.create<arith::ConstantOp>(loc, valAttr);
+    }
+    assert(val);
+    callback(coords, val);
+  }
+}

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index ecd4135be8e67..373817eb13a47 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -183,6 +183,26 @@ void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
 /// Scans to top of generated loop.
 Operation *getTop(Operation *op);
 
+/// Iterates over a sparse constant; generates constantOp for value and indices.
+/// E.g.,
+/// sparse<[ [0], [28], [31] ],
+///          [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] >
+/// =>
+/// %c1 = arith.constant 0
+/// %v1 = complex.constant (-5.13, 2.0)
+/// callback({%c1}, %v1)
+///
+/// %c2 = arith.constant 28
+/// %v2 = complex.constant (3.0, 4.0)
+/// callback({%c2}, %v2)
+///
+/// %c3 = arith.constant 31
+/// %v3 = complex.constant (5.0, 6.0)
+/// callback({%c3}, %v3)
+void foreachInSparseConstant(
+    Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
+    function_ref<void(ArrayRef<Value>, Value)> callback);
+
 //===----------------------------------------------------------------------===//
 // Inlined constant generators.
 //
@@ -197,9 +217,9 @@ Operation *getTop(Operation *op);
 //===----------------------------------------------------------------------===//
 
 /// Generates a 0-valued constant of the given type.  In addition to
-/// the scalar types (`ComplexType`, ``FloatType`, `IndexType`, `IntegerType`),
-/// this also works for `RankedTensorType` and `VectorType` (for which it
-/// generates a constant `DenseElementsAttr` of zeros).
+/// the scalar types (`ComplexType`, `FloatType`, `IndexType`,
+/// `IntegerType`), this also works for `RankedTensorType` and `VectorType`
+/// (for which it generates a constant `DenseElementsAttr` of zeros).
 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
   if (auto ctp = tp.dyn_cast<ComplexType>()) {
     auto zeroe = builder.getZeroAttr(ctp.getElementType());

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 5fa2b4ebbfd2c..8a3cc0ae5fea3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -170,6 +170,35 @@ static void getDynamicSizes(RankedTensorType tp,
   }
 }
 
+static LogicalResult genForeachOnSparseConstant(ForeachOp op,
+                                                RewriterBase &rewriter,
+                                                SparseElementsAttr attr) {
+  auto loc = op.getLoc();
+  SmallVector<Value> reduc = op.getInitArgs();
+
+  // Foreach on constant.
+  foreachInSparseConstant(
+      loc, rewriter, attr,
+      [&reduc, &rewriter, op](ArrayRef<Value> coords, Value v) mutable {
+        SmallVector<Value> args;
+        args.append(coords.begin(), coords.end());
+        args.push_back(v);
+        args.append(reduc);
+        // Clones the foreach op to get a copy of the loop body.
+        auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
+        assert(args.size() == cloned.getBody()->getNumArguments());
+        Operation *yield = cloned.getBody()->getTerminator();
+        rewriter.mergeBlockBefore(cloned.getBody(), op, args);
+        // clean up
+        rewriter.eraseOp(cloned);
+        reduc = yield->getOperands();
+        rewriter.eraseOp(yield);
+      });
+
+  rewriter.replaceOp(op, reduc);
+  return success();
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse tensor rewriting rules.
 //===---------------------------------------------------------------------===//
@@ -752,36 +781,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     // rule.
     if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
       if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
-        // Foreach on constant.
-        DenseElementsAttr indicesAttr = attr.getIndices();
-        DenseElementsAttr valuesAttr = attr.getValues();
-
-        SmallVector<Value> args;
-        for (int i = 0, e = valuesAttr.size(); i < e; i++) {
-          auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
-          for (int j = 0; j < rank; j++) {
-            auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
-            auto coord = rewriter.create<arith::ConstantIndexOp>(
-                loc, coordAttr.getInt());
-            // Remaps coordinates.
-            args.push_back(coord);
-          }
-          // Remaps value.
-          auto val = rewriter.create<arith::ConstantOp>(loc, valAttr);
-          args.push_back(val);
-          // Remaps iteration args.
-          args.append(reduc);
-          auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
-          Operation *yield = cloned.getBody()->getTerminator();
-          rewriter.mergeBlockBefore(cloned.getBody(), op, args);
-          // clean up
-          args.clear();
-          rewriter.eraseOp(cloned);
-          reduc = yield->getOperands();
-          rewriter.eraseOp(yield);
-        }
-        rewriter.replaceOp(op, reduc);
-        return success();
+        return genForeachOnSparseConstant(op, rewriter, attr);
       }
     }
 

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
index aab48131605ed..97ecf509ea552 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -1,8 +1,15 @@
-// RUN: mlir-opt %s --sparse-compiler | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void  \
-// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{command}
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 


        


More information about the Mlir-commits mailing list