[Mlir-commits] [mlir] [mlir][sparse] fold explicit value during sparsification (PR #90530)

Aart Bik llvmlistbot at llvm.org
Mon Apr 29 15:59:39 PDT 2024


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/90530

From d0d80ed841a43d02b04e2e07d1c21fe1fe4feef3 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Mon, 29 Apr 2024 15:33:03 -0700
Subject: [PATCH 1/2] [mlir][sparse] fold explicit value during sparsification

This ensures the explicit value is generated directly (rather than
loaded from the values array). Note that eliding the values array
altogether is still TBD; this is just the very first step.
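
As a rough sketch of the intended effect (the encoding below mirrors the
new test; the kernel IR is illustrative, not taken verbatim from this
patch): given an annotation such as

  #CSR_ones = #sparse_tensor.encoding<{
    map = (d0, d1) -> (d0 : dense, d1 : compressed),
    explicitVal = 1.0 : f32,
    implicitVal = 0.0 : f32
  }>

sparsification can materialize

  %one = arith.constant 1.0 : f32

for every stored entry instead of emitting a memref.load from the values
buffer. Standard arith folding then simplifies x * 1.0 to x, which is why
the matmul tests below expect only an addf/addi in the innermost loop.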
---
 .../Transforms/Sparsification.cpp             | 12 ++-
 .../Transforms/Utils/CodegenUtils.h           | 10 +++
 .../SparseTensor/sparse_matmul_one.mlir       | 75 +++++++++++++++++++
 3 files changed, 95 insertions(+), 2 deletions(-)
 create mode 100755 mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0a9bb40b458d68..b04ca11f714ba1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -498,9 +498,17 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   Value val = env.exp(exp).val;
   if (val)
     return val;
-  // Load during insertion.
+  // Get tensor operand.
   linalg::GenericOp op = env.op();
+  Location loc = op.getLoc();
   OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
+  // Fold binary-valued tensor into explicit value.
+  const auto stt = getSparseTensorType(t->get());
+  if (stt.hasEncoding()) {
+    if (auto explVal = stt.getExplicitVal())
+      return genValFromAttr(builder, loc, explVal);
+  }
+  // Load during insertion.
   if (env.isSparseOutput(t)) {
     if (env.isCustomReduc())
       return genInsertionLoadReduce(env, builder, t);
@@ -509,7 +517,7 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   // Actual load.
   SmallVector<Value> args;
   Value ptr = genSubscript(env, builder, t, args);
-  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
+  return builder.create<memref::LoadOp>(loc, ptr, args);
 }
 
 /// Generates a store on a dense or sparse tensor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index ce5831d999e9a4..cf3c35f5fa4c78 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -399,6 +399,16 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
   return constantI64(builder, loc, static_cast<uint64_t>(lt));
 }
 
+// Generates a constant from a validated value-carrying attribute.
+inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    Type tp = cast<TypedAttr>(arrayAttr[0]).getType();
+    return builder.create<complex::ConstantOp>(loc, tp, arrayAttr);
+  }
+  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
+}
+
+// TODO: is this in the right place?
 inline bool isZeroRankedTensorOrScalar(Type type) {
   auto rtp = dyn_cast<RankedTensorType>(type);
   return !rtp || rtp.getRank() == 0;
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
new file mode 100755
index 00000000000000..09ec43b393d52d
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --sparsification-and-bufferization | FileCheck %s
+
+#CSR_ones_complex = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed)
+// explicitVal = (1.0, 0.0) : complex<f32>,
+// implicitVal = (1.0, 0.0) : complex<f32>
+}>
+
+#CSR_ones_fp = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = 1.0 : f32,
+  implicitVal = 0.0 : f32
+}>
+
+#CSR_ones_int = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = 1 : i32,
+  implicitVal = 0 : i32
+}>
+
+// CHECK-LABEL:   func.func @matmul_complex
+//
+// TODO: make this work
+//
+func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
+                          %b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
+                          %c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xcomplex<f32>>, tensor<20x30xcomplex<f32>,#CSR_ones_complex>)
+    outs(%c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>>
+  return %0 : tensor<10x30xcomplex<f32>>
+}
+
+// CHECK-LABEL:   func.func @matmul_fp
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             %[[X:.*]] = memref.load
+// CHECK:             scf.for
+// CHECK:               %[[I:.*]] = memref.load
+// CHECK:               %[[Y:.*]] = memref.load
+// CHECK:               %[[M:.*]] = arith.addf %[[Y]], %[[X]] : f32
+// CHECK:               memref.store %[[M]]
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+func.func @matmul_fp(%a: tensor<10x20xf32>,
+                     %b: tensor<20x30xf32, #CSR_ones_fp>,
+                     %c: tensor<10x30xf32>) -> tensor<10x30xf32> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xf32>, tensor<20x30xf32,#CSR_ones_fp>)
+    outs(%c: tensor<10x30xf32>) -> tensor<10x30xf32>
+  return %0 : tensor<10x30xf32>
+}
+
+// CHECK-LABEL:   func.func @matmul_int
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             %[[X:.*]] = memref.load
+// CHECK:             scf.for
+// CHECK:               %[[I:.*]] = memref.load
+// CHECK:               %[[Y:.*]] = memref.load
+// CHECK:               %[[M:.*]] = arith.addi %[[Y]], %[[X]] : i32
+// CHECK:               memref.store %[[M]]
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+func.func @matmul_int(%a: tensor<10x20xi32>,
+                      %b: tensor<20x30xi32, #CSR_ones_int>,
+                      %c: tensor<10x30xi32>) -> tensor<10x30xi32> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xi32>, tensor<20x30xi32,#CSR_ones_int>)
+    outs(%c: tensor<10x30xi32>) -> tensor<10x30xi32>
+  return %0 : tensor<10x30xi32>
+}
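
For reference, a minimal sketch of the constants the new genValFromAttr
helper materializes for the two attribute kinds it handles (the f32 and
complex<f32> payloads shown are illustrative, not from the patch):

  // TypedAttr, e.g. 1.0 : f32:
  %c = arith.constant 1.0 : f32

  // ArrayAttr holding the real and imaginary components:
  %z = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>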

From 63d3c90546d7c1253eb9a3a5a85e0f7bb745e452 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Mon, 29 Apr 2024 15:59:17 -0700
Subject: [PATCH 2/2] reviewer feedback

---
 mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 6 ++----
 mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir       | 2 +-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index b04ca11f714ba1..0c8e431d8c9964 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -504,10 +504,8 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
   // Fold binary-valued tensor into explicit value.
   const auto stt = getSparseTensorType(t->get());
-  if (stt.hasEncoding()) {
-    if (auto explVal = stt.getExplicitVal())
-      return genValFromAttr(builder, loc, explVal);
-  }
+  if (auto explVal = stt.getExplicitVal())
+    return genValFromAttr(builder, loc, explVal);
   // Load during insertion.
   if (env.isSparseOutput(t)) {
     if (env.isCustomReduc())
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
index 09ec43b393d52d..82f3147d3206bd 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
@@ -4,7 +4,7 @@
 #CSR_ones_complex = #sparse_tensor.encoding<{
   map = (d0, d1) -> (d0 : dense, d1 : compressed)
 // explicitVal = (1.0, 0.0) : complex<f32>,
-// implicitVal = (1.0, 0.0) : complex<f32>
+// implicitVal = (0.0, 0.0) : complex<f32>
 }>
 
 #CSR_ones_fp = #sparse_tensor.encoding<{


