[Mlir-commits] [mlir] 652b39b - [mlir][sparse][linalg] add linalg rewriting specific to sparse tensors

Aart Bik llvmlistbot at llvm.org
Wed Feb 23 17:29:51 PST 2022


Author: Aart Bik
Date: 2022-02-23T17:29:41-08:00
New Revision: 652b39b46f85ad826a20d3e0cec5d0db91b43daf

URL: https://github.com/llvm/llvm-project/commit/652b39b46f85ad826a20d3e0cec5d0db91b43daf
DIFF: https://github.com/llvm/llvm-project/commit/652b39b46f85ad826a20d3e0cec5d0db91b43daf.diff

LOG: [mlir][sparse][linalg] add linalg rewriting specific to sparse tensors

Now that sparse tensor types are first-class citizens and the sparse compiler
is taking shape, it is time to make sure other compiler optimizations compose
well with sparse tensors. Mostly, this should be completely transparent (i.e.,
dense and sparse take the same path). However, in some cases, optimizations
only make sense in the context of sparse tensors. This is a first example of
such an optimization, where fusing a sampled element-wise multiplication only
makes sense when the resulting kernel has a potentially lower asymptotic
complexity due to the sparsity.

As an extreme example, running SDDMM with 1024x1024 matrices and a sparse
sampling matrix with only two elements runs in 463.55ms in the unfused
case but just 0.032ms in the fused case, with a speedup of 14485x that
is only possible in the exciting world of sparse computations!

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D120429

Added: 
    mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
    mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 80ec20ac617ac..a8cd374ab0b84 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -59,6 +59,9 @@ void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
 /// parallel loops.
 void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns that are only useful in the context of sparse tensors.
+void populateSparseTensorRewriting(RewritePatternSet &patterns);
+
 /// Function type which is used to control when to stop fusion. It is expected
 /// that OpOperand is not modified in the callback. The OpOperand is not marked
 /// as const to allow callers to use non-const methods.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index f758546bb9afc..ec8c8c438635e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   NamedOpConversions.cpp
   PadOpInterchange.cpp
   Promotion.cpp
+  SparseTensorRewriting.cpp
   Tiling.cpp
   Transforms.cpp
   Vectorization.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 7e0e857643eb6..3493f4e3c7598 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -49,7 +49,7 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
   AffineMap invProducerResultIndexMap =
       inversePermutation(producerResultIndexMap);
   assert(invProducerResultIndexMap &&
-         "expected producer result indexig map to be invertible");
+         "expected producer result indexing map to be invertible");
 
   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
   // argMap is a map from producer loop -> producer arg tensor index.
@@ -2264,6 +2264,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
                FoldConstantTranspose>(context,
                                       options.controlElementwiseOpsFusionFn);
   patterns.add<RemoveOutsDependency>(context);
+  populateSparseTensorRewriting(patterns);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
new file mode 100644
index 0000000000000..3958ab3baf178
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp
@@ -0,0 +1,213 @@
+//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements linalg dialect rewriting specific to sparse tensors.
+//
+// Sparsity should be mostly transparent to the linalg dialect optimizations
+// (i.e., dense and sparse tensors take the same path). However, in some cases,
+// optimizations only make sense in the context of sparse tensors. This file
+// implements such sparsity-specific rewriting rules.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::sparse_tensor;
+
+//===---------------------------------------------------------------------===//
+// Helper methods for the actual rewriting rules.
+//===---------------------------------------------------------------------===//
+
+// Returns true when the operand's type carries a sparse tensor encoding
+// in which at least one dimension has the Compressed dim level type.
+static bool isSparseTensor(OpOperand *op) {
+  auto enc = getSparseTensorEncoding(op->get().getType());
+  if (!enc)
+    return false; // no sparse encoding at all
+  for (SparseTensorEncodingAttr::DimLevelType dlt : enc.getDimLevelType())
+    if (dlt == SparseTensorEncodingAttr::DimLevelType::Compressed)
+      return true;
+  return false;
+}
+
+// Returns true when the operand is recognized as a zero or empty
+// initialization: either a constant accepted by the m_Zero/m_AnyZeroFloat
+// matchers, or the result of an InitTensorOp or InitOp.
+static bool isEmptyInit(OpOperand *op) {
+  Value init = op->get();
+  // Constant zero, as recognized by the zero matchers.
+  const bool isZeroConstant =
+      matchPattern(init, m_Zero()) || matchPattern(init, m_AnyZeroFloat());
+  // Value produced directly by one of the tensor initialization ops.
+  const bool isFreshTensor =
+      init.getDefiningOp<InitTensorOp>() || init.getDefiningOp<InitOp>();
+  return isZeroConstant || isFreshTensor;
+}
+
+// Returns true when the body of `op` yields the product (floating-point
+// or integer) of exactly its first two block arguments, in either order.
+static bool isSampling(GenericOp op) {
+  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
+  Operation *mul = yieldOp.getOperand(0).getDefiningOp();
+  if (!mul || !isa<arith::MulFOp, arith::MulIOp>(mul))
+    return false;
+  // The two scalar inputs must be precisely the operands of the multiply.
+  Value lhs = mul->getOperand(0);
+  Value rhs = mul->getOperand(1);
+  Value in0 = op.getBlock()->getArgument(0);
+  Value in1 = op.getBlock()->getArgument(1);
+  return (lhs == in0 && rhs == in1) || (rhs == in0 && lhs == in1);
+}
+
+// Returns true when `val` is a tree of multiplications whose leaves are
+// block arguments other than `x` (a bare block argument counts as a chain).
+static bool isMulChain(Value val, Value x) {
+  // Leaf: any block argument is acceptable except x itself.
+  if (val.isa<BlockArgument>())
+    return val != x;
+  // Inner node: a multiplication whose two factors are chains themselves.
+  Operation *def = val.getDefiningOp();
+  return def && isa<arith::MulFOp, arith::MulIOp>(def) &&
+         isMulChain(def->getOperand(0), x) && isMulChain(def->getOperand(1), x);
+}
+
+// Returns true when the body of `op` yields x + <multiplications> (operands
+// in either order), where x is the last block argument of the body.
+static bool isSumOfMul(GenericOp op) {
+  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
+  Operation *add = yieldOp.getOperand(0).getDefiningOp();
+  if (!add || !isa<arith::AddFOp, arith::AddIOp>(add))
+    return false;
+  Value x = op.getBlock()->getArguments().back();
+  Value lhs = add->getOperand(0);
+  Value rhs = add->getOperand(1);
+  return (lhs == x && isMulChain(rhs, x)) || (rhs == x && isMulChain(lhs, x));
+}
+
+//===---------------------------------------------------------------------===//
+// The actual sparse tensor rewriting rules.
+//===---------------------------------------------------------------------===//
+
+namespace {
+/// Rewriting rule that converts two kernels:
+///
+///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
+///      X(i,j) = S(i,j) * T(i,j)
+///
+/// into a single kernel, using distributive law:
+///
+///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
+///
+/// This kind of fusion (merging two ops into one but using arithmetic
+/// equalities that may not hold for floating-point computations) would
+/// be undesirable in the dense case, since we distribute the multiplication
+/// into the reduction loop. However, for sparse sampling tensor S, such
+/// a fusion may actually reduce the asymptotic complexity of the kernel,
+/// since intermediate results may be nullified.
+struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check consumer.
+    if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
+        op.getNumResults() != 1)
+      return failure();
+    if (op.getNumParallelLoops() != op.getNumLoops())
+      return failure();
+    if (!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
+        !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
+        !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
+      return failure();
+    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
+    // operand can be sparse or dense, since the point of this rewriting rule
+    // is detecting a situation in which *more* sparsity is introduced into
+    // a computation, be it already sparse or still dense.
+    unsigned other = 0;
+    if (isSparseTensor(op.getInputOperand(0)))
+      other = 1;
+    else if (!isSparseTensor(op.getInputOperand(1)))
+      return failure();
+    // Check producer.
+    auto prod = dyn_cast_or_null<GenericOp>(
+        op.getInputOperand(other)->get().getDefiningOp());
+    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
+        !prod.getResult(0).hasOneUse())
+      return failure();
+    // Sampling consumer and sum of multiplication chain producer.
+    if (!isEmptyInit(op.getOutputOperand(0)) ||
+        !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
+        !isSumOfMul(prod))
+      return failure();
+    // Modify operand structure of producer and consumer.
+    Location loc = prod.getLoc();
+    SmallVector<Value> inputOps = prod.getInputOperands();
+    SmallVector<Value> outputOps = op.getOutputOperands();
+    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMaps();
+    inputOps.push_back(op.getInputOperand(1 - other)->get());
+    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
+    // Fuse producer and consumer into a new generic op.
+    auto fusedOp = rewriter.create<GenericOp>(
+        loc, op.getResult(0).getType(), inputOps, outputOps,
+        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
+        /*doc=*/nullptr, /*library_call=*/nullptr);
+    Block &prodBlock = prod.region().front();
+    Block &consBlock = op.region().front();
+    BlockAndValueMapping mapper;
+    Block *fusedBlock = new Block();
+    fusedOp.region().push_back(fusedBlock);
+    unsigned num = prodBlock.getNumArguments();
+    for (unsigned i = 0; i < num - 1; i++)
+      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
+    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
+    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
+    // Clone bodies of the producer and consumer in new evaluation order.
+    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
+    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
+    // The summand is the accumulation operand that is not the output argument.
+    Value accArg = prodBlock.getArgument(num - 1);
+    Value summand = acc->getOperand(acc->getOperand(0) == accArg ? 1 : 0);
+    rewriter.setInsertionPointToStart(fusedBlock);
+    for (Operation &bodyOp : prodBlock.without_terminator())
+      if (&bodyOp != acc)
+        rewriter.clone(bodyOp, mapper);
+    // Look up the summand explicitly; it need not be the last cloned op
+    // (it may even be a block argument when nothing precedes the sum).
+    mapper.map(consBlock.getArgument(other), mapper.lookup(summand));
+    mapper.map(summand, rewriter.clone(*sampler, mapper)->getResult(0));
+    Value result = rewriter.clone(*acc, mapper)->getResult(0);
+    rewriter.create<linalg::YieldOp>(loc, result);
+    // Replace consumer with fused operation. Old producer
+    // and consumer ops will be removed by DCE.
+    rewriter.replaceOp(op, fusedOp->getResults());
+    return success();
+  }
+
+private:
+  // Helper to add argument and record the mapping.
+  static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
+    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
+  }
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods that add patterns described in this file to a pattern list.
+//===---------------------------------------------------------------------===//
+
+void mlir::linalg::populateSparseTensorRewriting(RewritePatternSet &patterns) {
+  // Register the sparse-specific fusion rule defined in this file.
+  patterns.add<FuseSparseMultiplyOverAdd>(patterns.getContext());
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
index 017950391e39f..0cbdd7a7f4d0e 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
@@ -5,7 +5,7 @@
 //
 // Do the same run, but now with SIMDization as well. This should not change the outcome.
 //
-// RUN: mlir-opt %s -sparse-compiler="vectorization-strategy=2 vl=8" | \
+// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=2 vl=8" | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
@@ -46,7 +46,8 @@
 //
 module {
   //
-  // A kernel that computes a direct sampled matrix matrix multiplication.
+  // A kernel that computes a direct sampled matrix matrix multiplication
+  // (with dense result).
   //
   func @sampled_dd(%args: tensor<8x8xf64, #SM>,
                    %arga: tensor<8x8xf64>,
@@ -66,11 +67,13 @@ module {
   }
 
   //
-  // A kernel that computes an unfused sampled matrix matrix multiplication.
+  // A kernel that computes an unfused sampled matrix matrix multiplication
+  // (with dense result).
   //
   func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                            %arga: tensor<8x8xf64>,
-                           %argb: tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8x8xf64>) {
+                           %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+    // Perform dense-dense matrix matrix multiplication.
     %1 = arith.constant dense<0.0> : tensor<8x8xf64>
     %2 = linalg.generic #trait_matmul
       ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
@@ -80,17 +83,68 @@ module {
           %q = arith.addf %x, %p : f64
           linalg.yield %q : f64
     } -> tensor<8x8xf64>
-
-    %3 = arith.constant dense<0.0> : tensor<8x8xf64>
-    %4 = linalg.generic #trait_scale
+    // Sample the result with element-wise multiplication with a sparse matrix.
+    %3 = linalg.generic #trait_scale
       ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
-      outs(%3 : tensor<8x8xf64>) {
+      outs(%1 : tensor<8x8xf64>) {
         ^bb0(%t: f64, %s: f64, %x: f64):
           %r = arith.mulf %t, %s : f64
           linalg.yield %r : f64
     } -> tensor<8x8xf64>
+    return %3 : tensor<8x8xf64>
+  }
 
-    return %4, %2 : tensor<8x8xf64>, tensor<8x8xf64>
+  //
+  // A kernel that computes a direct sampled matrix matrix multiplication
+  // (with sparse result).
+  //
+  func @sparse_sampled_dd(%args: tensor<8x8xf64, #SM>,
+                          %arga: tensor<8x8xf64>,
+                          %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+    %c8 = arith.constant 8 : index
+    %1 = sparse_tensor.init [%c8, %c8] : tensor<8x8xf64, #SM>
+    %2 = linalg.generic #trait_sampled_dense_dense
+      ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,
+                               tensor<8x8xf64>, tensor<8x8xf64>)
+      outs(%1: tensor<8x8xf64, #SM>) {
+        ^bb(%s: f64, %a: f64, %b: f64, %x: f64):
+          %p = arith.mulf %a, %b : f64
+          %q = arith.mulf %s, %p : f64
+          %r = arith.addf %x, %q : f64
+          linalg.yield %r : f64
+    } -> tensor<8x8xf64, #SM>
+    return %2 : tensor<8x8xf64, #SM>
+  }
+
+  //
+  // A kernel that computes an unfused sampled matrix matrix multiplication
+  // (with sparse result).
+  //
+  func @sparse_sampled_dd_unfused(
+        %args: tensor<8x8xf64, #SM>,
+        %arga: tensor<8x8xf64>,
+        %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+    // Perform dense-dense matrix matrix multiplication.
+    %1 = arith.constant dense<0.0> : tensor<8x8xf64>
+    %2 = linalg.generic #trait_matmul
+      ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
+      outs(%1 : tensor<8x8xf64>) {
+        ^bb0(%a: f64, %b: f64, %x: f64):
+          %p = arith.mulf %a, %b : f64
+          %q = arith.addf %x, %p : f64
+          linalg.yield %q : f64
+    } -> tensor<8x8xf64>
+    // Sample the result with element-wise multiplication with a sparse matrix.
+    %c8 = arith.constant 8 : index
+    %3 = sparse_tensor.init [%c8, %c8] : tensor<8x8xf64, #SM>
+    %4 = linalg.generic #trait_scale
+      ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
+      outs(%3 : tensor<8x8xf64, #SM>) {
+        ^bb0(%t: f64, %s: f64, %x: f64):
+          %r = arith.mulf %t, %s : f64
+          linalg.yield %r : f64
+    } -> tensor<8x8xf64, #SM>
+    return %4 : tensor<8x8xf64, #SM>
   }
 
   //
@@ -112,9 +166,15 @@ module {
     %0 = call @sampled_dd(%s, %a, %b)
       : (tensor<8x8xf64, #SM>,
          tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64>
-    %1, %2 = call @sampled_dd_unfused(%s, %a, %b)
+    %1 = call @sampled_dd_unfused(%s, %a, %b)
       : (tensor<8x8xf64, #SM>,
-         tensor<8x8xf64>, tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8x8xf64>)
+         tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64>
+    %2 = call @sparse_sampled_dd(%s, %a, %b)
+      : (tensor<8x8xf64, #SM>,
+         tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64, #SM>
+    %3 = call @sparse_sampled_dd_unfused(%s, %a, %b)
+      : (tensor<8x8xf64, #SM>,
+         tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64, #SM>
 
     // Verify the outputs.
     //
@@ -128,21 +188,31 @@ module {
     // CHECK-SAME: ( 0, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ),
     // CHECK-SAME: ( 0, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 192 ) )
     //
+    // CHECK-NEXT: ( 96, 192, 0, 0 )
+    //
+    // CHECK-NEXT: ( 96, 192, 0, 0 )
+    //
     %m0 = bufferization.to_memref %0 : memref<8x8xf64>
     %m1 = bufferization.to_memref %1 : memref<8x8xf64>
-    %m2 = bufferization.to_memref %2 : memref<8x8xf64>
+    %m2 = sparse_tensor.values %2 : tensor<8x8xf64, #SM> to memref<?xf64>
+    %m3 = sparse_tensor.values %3 : tensor<8x8xf64, #SM> to memref<?xf64>
     %v0 = vector.transfer_read %m0[%c0, %c0], %d0
         : memref<8x8xf64>, vector<8x8xf64>
     %v1 = vector.transfer_read %m1[%c0, %c0], %d0
         : memref<8x8xf64>, vector<8x8xf64>
+    %v2 = vector.transfer_read %m2[%c0], %d0 : memref<?xf64>, vector<4xf64>
+    %v3 = vector.transfer_read %m3[%c0], %d0 : memref<?xf64>, vector<4xf64>
     vector.print %v0 : vector<8x8xf64>
     vector.print %v1 : vector<8x8xf64>
+    vector.print %v2 : vector<4xf64>
+    vector.print %v3 : vector<4xf64>
 
     // Release the resources.
     sparse_tensor.release %s : tensor<8x8xf64, #SM>
     memref.dealloc %m0 : memref<8x8xf64>
     memref.dealloc %m1 : memref<8x8xf64>
-    memref.dealloc %m2 : memref<8x8xf64>
+    sparse_tensor.release %2 : tensor<8x8xf64, #SM>
+    sparse_tensor.release %3 : tensor<8x8xf64, #SM>
 
     return
   }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
index 876e6bd073a09..9f017ad157783 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
@@ -33,8 +33,9 @@
 
 # Alternative way to define SDDMM kernel. Since this performs the reduction as
 #   sum(k, A[i, k] * B[k, j]) * S[i, j]
-# the MLIR lowering results in two separate tensor index expressions that
-# need to be fused properly to guarantee proper asymptotic complexity.
+# the MLIR lowering results in two separate tensor index expressions that are
+# fused prior to running the sparse compiler in order to guarantee proper
+# asymptotic complexity.
 Y[i, j] = A[i, k] * B[k, j] * S[i, j]
 
 expected = """; extended FROSTT format


        


More information about the Mlir-commits mailing list