[Mlir-commits] [mlir] [mlir][sparse][gpu] add CSC to libgen GPU sparsification using cuSparse (PR #67713)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 28 10:36:04 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

Adds CSC support, and also introduces BSR as a future format (coming soon).

---

Patch is 26.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67713.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (+91-64) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir (+46-2) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-lib.mlir (+30-4) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 91b346c8a9b4c4d..ba3fd751ce1c39e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -33,6 +33,15 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
+// Sparse formats supported by cuSparse.
+enum class CuSparseFormat {
+  kNone,
+  kCOO,
+  kCSR,
+  kCSC,
+  kBSR,  // TODO: coming soon!
+};
+
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
@@ -385,73 +394,92 @@ static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
   return false;
 }
 
-/// Determines if the given value is a dense tensor instead of a sparse one.
+/// Test for dense tensor.
 static bool isDenseTensor(Value v) {
-  return (sparse_tensor::getSparseTensorType(v).isAllDense());
+  auto sTp = getSparseTensorType(v);
+  return sTp.getDimRank() == sTp.getLvlRank() && sTp.isAllDense();
+}
+
+/// Test for suitable positions/coordinates width.
+static bool isAdmissibleMetaData(SparseTensorType &aTp) {
+  return (aTp.getPosWidth() == 0 || aTp.getPosWidth() >= 16) &&
+         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() >= 16);
 }
 
-/// Test for sorted COO with suitable data and coordinates types.
+/// Test for sorted COO matrix with suitable metadata.
 static bool isAdmissibleCOO(SparseTensorType &aTp) {
-  return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
+  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
+         aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
          aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
-         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
-         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
-          aTp.getCrdWidth() == 64);
+         isAdmissibleMetaData(aTp);
 }
 
-/// Test for CSR with suitable data and coordinates types.
+/// Test for CSR matrix with suitable metadata.
 static bool isAdmissibleCSR(SparseTensorType &aTp) {
-  return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
-         aTp.isUniqueLvl(1) &&
-         (aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
-         (aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
-          aTp.getCrdWidth() == 64);
+  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && aTp.isIdentity() &&
+         aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
+         aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
 }
 
-/// Test for admissible types on operands (with output parameter `isCOO`).
-static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp,
-                               SparseTensorType cTp, bool enableRT,
-                               bool isMatVec, bool &isCOO) {
+/// Test for CSC matrix with suitable metadata.
+static bool isAdmissibleCSC(SparseTensorType &aTp) {
+  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 2 && !aTp.isIdentity() &&
+         aTp.isPermutation() && aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) &&
+         aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
+}
+
+/// Returns a suitable sparse format for the operation and given operand
+/// types with cuSparse, or kNone if none is available.
+static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
+                                        SparseTensorType bTp,
+                                        SparseTensorType cTp, bool enableRT,
+                                        bool isMatVec) {
+  // The other operands have a dense type.
   if (bTp.hasEncoding() || cTp.hasEncoding())
-    return false;
-  if (isAdmissibleCOO(aTp)) {
-    isCOO = true;
+    return CuSparseFormat::kNone;
+  // Now check for suitable operand type for the main operand.
+  if (isAdmissibleCOO(aTp))
 #ifdef CUSPARSE_COO_AOS
-    return isMatVec;
+    return isMatVec ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
 #else
-    return enableRT;
+    return enableRT ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
 #endif
-  }
-  return isAdmissibleCSR(aTp);
+  if (isAdmissibleCSR(aTp))
+    return CuSparseFormat::kCSR;
+  if (isAdmissibleCSC(aTp))
+    return CuSparseFormat::kCSC;
+  return CuSparseFormat::kNone;
 }
 
 /// Generates the first positions/coordinates of a sparse matrix.
 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
-                               bool isCOO, bool enableRT) {
-  if (isCOO) {
+                               CuSparseFormat format, bool enableRT) {
+  if (format == CuSparseFormat::kCOO) {
     // Library uses SoA COO, direct IR uses AoS COO.
     if (enableRT)
       return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
     return genToCoordinatesBuffer(builder, loc, a);
   }
-  // CSR uses positions.
+  // Formats CSR/CSC and BSR use positions at 1.
   return genToPositions(builder, loc, a, 1);
 }
 
 /// Generates the second coordinates of a sparse matrix.
 static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
-                           bool isCOO, bool enableRT) {
+                           CuSparseFormat format, bool enableRT) {
+  bool isCOO = format == CuSparseFormat::kCOO;
   if (isCOO && !enableRT)
     return Value(); // nothing needed
+  // Formats CSR/CSC and BSR use coordinates at 1.
   return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
 }
 
-/// Generates the sparse matrix multiplication.
+/// Generates the sparse matrix handle.
 static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
                            Type tokenTp, Value token, Value sz1, Value sz2,
                            Value nseA, Value rowA, Value colA, Value valA,
-                           bool isCOO, bool enableRT) {
-  if (isCOO) {
+                           CuSparseFormat format, bool enableRT) {
+  if (format == CuSparseFormat::kCOO) {
     // Library uses SoA COO, direct IR uses AoS COO.
     if (enableRT) {
       assert(colA);
@@ -467,7 +495,11 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
 #endif
   }
   assert(colA);
-  return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
+  if (format == CuSparseFormat::kCSR)
+    return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
+                                            sz2, nseA, rowA, colA, valA);
+  assert(format == CuSparseFormat::kCSC);
+  return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
                                           sz2, nseA, rowA, colA, valA);
 }
 
@@ -484,12 +516,12 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   bool isZeroCopy =
       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
 
-  // Only admissible sparse matrix format and dense vectors.
-  bool isCOO = false;
+  // Only admissible sparse matrix format and dense vectors (no BSR).
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType xTp = getSparseTensorType(x);
   SparseTensorType yTp = getSparseTensorType(y);
-  if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, /*isMatVec=*/true, isCOO))
+  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, /*isMatVec=*/true);
+  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
     return failure();
 
   // Start sparse kernel and copy data from host to device.
@@ -499,8 +531,8 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
   Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
-  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
   Value memV = genToValues(rewriter, loc, a);
   Value memX, memY;
   Value castR, castC, castV, castX, castY;
@@ -535,7 +567,7 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value token = genFirstWait(rewriter, loc);
   Operation *spGenA =
       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
-               rowA, colA, valA, isCOO, enableRT);
+               rowA, colA, valA, format, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
@@ -546,7 +578,6 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
       loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
   Value dnY = dvecY.getResult(0);
   token = dvecY.getAsyncToken();
-
   auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
 
   // Precompute buffersize for SpMV.
@@ -610,12 +641,12 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   bool isZeroCopy =
       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
 
-  // Only admissible sparse matrix format and dense matrices.
-  bool isCOO = false;
+  // Only admissible sparse matrix format and dense matrices (no BSR).
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
   SparseTensorType cTp = getSparseTensorType(c);
-  if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, /*isMatVec=*/false, isCOO))
+  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
+  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
     return failure();
 
   // Start sparse kernel and copy data from host to device.
@@ -626,8 +657,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
-  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
   Value memV = genToValues(rewriter, loc, a);
   Value bufB, bufC;
   Value castR, castC, castV, castB, castBufC;
@@ -661,7 +692,7 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value token = genFirstWait(rewriter, loc);
   Operation *spGenA =
       genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
-               rowA, colA, valA, isCOO, enableRT);
+               rowA, colA, valA, format, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
@@ -674,7 +705,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
       SmallVector<Value>{szm, szn});
   Value dnC = dmatC.getResult(0);
   token = dmatC.getAsyncToken();
-
   auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
 
   // Precompute buffersize for SpMM.
@@ -686,7 +716,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
   Value buffer = buf.getResult(0);
   token = buf.getAsyncToken();
-
   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
 
   // Perform the SpMM.
@@ -738,7 +767,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   SmallVector<Value> tokens;
 
   // Only CSR <- CSR x CSR supported.
-  bool isCOO = false;
+  auto format = CuSparseFormat::kCSR;
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
   SparseTensorType cTp = getSparseTensorType(c);
@@ -755,11 +784,11 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
-  Value amemR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
-  Value amemC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
+  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
   Value amemV = genToValues(rewriter, loc, a);
-  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, isCOO, enableRT);
-  Value bmemC = genSecondCrds(rewriter, loc, b, isCOO, enableRT);
+  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
+  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
   Value bmemV = genToValues(rewriter, loc, b);
   Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
   Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
@@ -778,12 +807,12 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value token = genFirstWait(rewriter, loc);
   Operation *spGenA =
       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
-               rowA, colA, valA, isCOO, enableRT);
+               rowA, colA, valA, format, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   Operation *spGenB =
       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
-               rowB, colB, valB, isCOO, enableRT);
+               rowB, colB, valB, format, enableRT);
   Value spMatB = spGenB->getResult(0);
   token = spGenB->getResult(1);
 
@@ -802,7 +831,7 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   token = e3.getAsyncToken();
   Operation *spGenC =
       genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
-               rowC, colC, valC, isCOO, enableRT);
+               rowC, colC, valC, format, enableRT);
   Value spMatC = spGenC->getResult(0);
   token = spGenC->getResult(1);
 
@@ -1045,14 +1074,13 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   bool isZeroCopy =
       gpuDataTransferStrategy == GPUDataTransferStrategy::kZeroCopy;
 
-  // Only admissible sparse matrix format and dense matrices, no COO.
-  bool isCOO = false;
+  // Only admissible sparse matrix format (no COO/CSC) and dense matrices.
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
   SparseTensorType cTp = getSparseTensorType(c);
-  if (!areAdmissibleTypes(cTp, bTp, aTp, enableRT, false, isCOO))
-    return failure();
-  if (isCOO)
+  auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, /*isMatVec=*/false);
+  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
+      format == CuSparseFormat::kCSC)
     return failure();
 
   // The SDDMM does the in-place operation.
@@ -1071,8 +1099,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value bufB = genTensorToMemref(rewriter, loc, b);
   if (!isZeroCopy)
     matB = isZeroCopy ? bufB : genAllocCopy(rewriter, loc, bufB, tokens);
-  Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
-  Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
+  Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, c, format, enableRT);
   Value memV = genToValues(rewriter, loc, c);
   Value castB, castA, castR, castC, castV;
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
@@ -1107,10 +1135,9 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
       loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
   Value dnB = dmatB.getResult(0);
   token = dmatB.getAsyncToken();
-
   Operation *spGenC =
       genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
-               rowC, colC, valC, isCOO, enableRT);
+               rowC, colC, valC, format, enableRT);
   Value spMatC = spGenC->getResult(0);
   token = spGenC->getResult(1);
   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
old mode 100644
new mode 100755
index 6782f2d0e2014d7..9e397e0ad5b5dc2
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir
@@ -34,6 +34,12 @@
   crdWidth = 32
 }>
 
+#CSC = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d1 : dense, d0 : compressed),
+  posWidth = 64,
+  crdWidth = 64
+}>
+
 module {
   llvm.func @mgpuCreateSparseEnv()
   llvm.func @mgpuDestroySparseEnv()
@@ -58,6 +64,16 @@ module {
     return %D: tensor<8x8xf32>
   }
 
+  // Computes C = A x B with A sparse CSC.
+  func.func @matmulCSC(%A: tensor<8x8xf32, #CSC>,
+                       %B: tensor<8x8xf32>,
+                       %C: tensor<8x8xf32>) -> tensor<8x8xf32> {
+    %D = linalg.matmul
+      ins(%A, %B: tensor<8x8xf32, #CSC>, tensor<8x8xf32>)
+      outs(%C: tensor<8x8xf32>) -> tensor<8x8xf32>
+    return %D: tensor<8x8xf32>
+  }
+
   func.func @dump(%mat: tensor<8x8xf32>) {
     %f0 = arith.constant 0.0 : f32
     %c0 = arith.constant 0   : index
@@ -107,6 +123,7 @@ module {
     // Convert to a "sparse" matrix A.
     %Acoo = sparse_tensor.convert %DA : tensor<8x8xf32> to tensor<8x8xf32, #SortedCOO>
     %Acsr = sparse_tensor.convert %DA : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
+    %Acsc = sparse_tensor.convert %DA : tensor<8x8xf32> to tensor<8x8xf32, #CSC>
 
     // Initial C matrices.
     %C0 = tensor.generate {
@@ -125,10 +142,16 @@ module {
     %1 = call @matmulCSR(%Acsr, %DA, %C0) : (tensor<8x8xf32, #CSR>,
                                              tensor<8x8xf32>,
 					     tensor<8x8xf32>) -> tensor<8x8xf32>
-    %2 = call @matmulCOO(%Acoo, %DA, %C1) : (tensor<8x8xf32, #SortedCOO>,
+    %2 = call @matmulCSC(%Acsc, %DA, %C0) : (tensor<8x8xf32, #CSC>,
+                                             tensor<8x8xf32>,
+					     tensor<8x8xf32>) -> tensor<8x8xf32>
+    %3 = call @matmulCOO(%Acoo, %DA, %C1) : (tensor<8x8xf32, #SortedCOO>,
+                                             tensor<8x8xf32>,
+					     tensor<8x8xf32>) -> tensor<8x8xf32>
+    %4 = call @matmulCSR(%Acsr, %DA, %C1) : (tensor<8x8xf32, #CSR>,
                                              tensor<8x8xf32>,
 					     tensor<8x8xf32>) -> tensor<8x8xf32>
-    %3 = call @matmulCSR(%Acsr, %DA, %C1) : (tensor<8x8xf32, #CSR>,
+    %5 = call @matmulCSC(%Acsc, %DA, %C1) : (tensor<8x8xf32, #CSC>,
                                              tensor<8x8xf32>,
 					     tensor<8x8xf32>) -> tensor<8x8xf32>
 
@@ -153,6 +176,24 @@ module {
     // CHECK-NEXT: ( 308, 384, 460, 536, 612, 688, 764, 840 )
     // CHECK-NEXT: ( 336, 420, 504, 588, 672, 756, 840, 924 )
     //
+    // CHECK:      ( 140, 168, 196, 224, 252, 280, 308, 336 )
+    // CHECK-NEXT: ( 168, 204, 240, 276, 312, 348, 384, 420 )
+    // CHECK-NEXT: ( 196, 240, 284, 328, 372, 416, 460, 504 )
+    // CHECK-NEXT: ( 224, 276, 328, 380, 432, 484, 536, 588 )
+    // CHECK-NEXT: ( 252, 312, 372, 432, 492, 552, 612, 672 )
+    // CHECK-NEXT: ( 280, 348, 416, 484, 552, 620, 688, 756 )
+    // CHECK-NEXT: ( 308, 384, 460, 536, 612, 688, 764, 840 )
+    // CHECK-NEXT: ( 336, 420, 504, 588, 672, 756, 840, 924 )
+    //
+    // CHECK:      ( 141, 169, 197, 225, 253, 281, 309, 337 )
+    // CHECK-NEXT: ( 169, 205, 241, 277, 313, 349, 385, 421 )
+    // CHECK-NEXT: ( 197, 241, 285, 329, 373, 417, 461, 505 )
+    // CHECK-NEXT: ( 225, 277, 329, 381, 433, 485, 537, 589 )
+    // CHECK-NEXT: ( 253, 313, 373, 433, 493, 553, 613, 673 )
+    // CHECK-NEXT: ( 281, 349, 417, 485, 553, 621, 689, 757 )
+    // CHECK-NEXT: ( 309, 385, 461, 537, 613, 689, 765, 841 )
+    // CHECK-NEXT: ( 337, 421, 505, 589, 673, 757, 841, 925 )
+    //
     // CHECK:      ( 141, 169, 197, 225, 253, 281, 309, 337 )
     // CHECK-NEXT: ( 169...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/67713


More information about the Mlir-commits mailing list