[Mlir-commits] [mlir] [mlir][sparse] Add verification for explicit/implicit value (PR #90111)

Yinying Li llvmlistbot at llvm.org
Mon Apr 29 12:54:57 PDT 2024


https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/90111

>From d037f5c94e07bef46404ef058da0fa5d4019758d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 25 Apr 2024 00:39:00 +0000
Subject: [PATCH 1/4] add verification for explicit/implicit values

---
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 35 +++++++++
 .../SparseTensor/invalid_encoding.mlir        | 72 +++++++++++++++++++
 2 files changed, 107 insertions(+)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 028a69da10c1e1..b7567173341eed 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -907,6 +907,41 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     return emitError()
            << "dimension-rank mismatch between encoding and tensor shape: "
            << getDimRank() << " != " << dimRank;
+  Type expType, impType;
+  if (getExplicitVal()) {
+    auto fVal = llvm::dyn_cast<FloatAttr>(getExplicitVal());
+    auto intVal = llvm::dyn_cast<IntegerAttr>(getExplicitVal());
+    if (fVal && fVal.getType() != elementType) {
+      expType = fVal.getType();
+    } else if (intVal && intVal.getType() != elementType) {
+      expType = intVal.getType();
+    }
+    if (expType) {
+      return emitError() << "explicit value type mismatch between encoding and "
+                         << "tensor element type: " << expType
+                         << " != " << elementType;
+    }
+  }
+
+  if (getImplicitVal()) {
+    auto impFVal = llvm::dyn_cast<FloatAttr>(getImplicitVal());
+    auto impIntVal = llvm::dyn_cast<IntegerAttr>(getImplicitVal());
+    if (impFVal && impFVal.getType() != elementType) {
+      impType = impFVal.getType();
+    } else if (impIntVal && impIntVal.getType() != elementType) {
+      impType = impIntVal.getType();
+    }
+    if (impType) {
+      return emitError() << "implicit value type mismatch between encoding and "
+                         << "tensor element type: " << impType
+                         << " != " << elementType;
+    }
+    // Currently, we only support zero as the implicit value.
+    if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
+        (impIntVal && impIntVal.getInt() != 0)) {
+      return emitError() << "implicit value must be zero";
+    }
+  }
   return success();
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 8096c010ac935a..19e8fc95e22813 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -443,3 +443,75 @@ func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
 func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
   return
 }
+
+// -----
+
+#CSR_ExpType = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  explicitVal = 1 : i32,
+  implicitVal = 0.0 : f32
+}>
+
+// expected-error@+1 {{explicit value type mismatch between encoding and tensor element type: 'i32' != 'f32'}}
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_ExpType>)
+
+// -----
+
+#CSR_ImpType = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  explicitVal = 1 : i32,
+  implicitVal = 0.0 : f32
+}>
+
+// expected-error@+1 {{implicit value type mismatch between encoding and tensor element type: 'f32' != 'i32'}}
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
+
+// -----
+
+// expected-error@+1 {{expected a numeric value for explicitVal}}
+#CSR_ExpType = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  explicitVal = "str"
+}>
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ExpType>)
+
+// -----
+
+// expected-error@+1 {{expected a numeric value for implicitVal}}
+#CSR_ImpType = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  implicitVal = "str"
+}>
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
+
+// -----
+
+#CSR_ImpVal = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  implicitVal = 1 : i32
+}>
+
+// expected-error@+1 {{implicit value must be zero}}
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpVal>)
+
+// -----
+
+#CSR_ImpVal = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 32,
+  crdWidth = 32,
+  implicitVal = 1.0 : f32
+}>
+
+// expected-error@+1 {{implicit value must be zero}}
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_ImpVal>)

>From 806f7f81834679b03689dcae7edcc252795761bc Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 25 Apr 2024 17:37:28 +0000
Subject: [PATCH 2/4] new function

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  5 +++
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 41 ++++++++++---------
 2 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index eefa4c71bbd2ca..37fa4913aa6a60 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -512,6 +512,11 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     void printSymbols(AffineMap &map, AsmPrinter &printer) const;
     void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
     void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
+
+    //
+    // Explicit/implicit value methods.
+    //
+    Type getMismatchedValueType(Type elementType, Attribute val) const;
   }];
 
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b7567173341eed..7c938ecaed5abe 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -888,6 +888,19 @@ LogicalResult SparseTensorEncodingAttr::verify(
   return success();
 }
 
+Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
+                                                      Attribute val) const {
+  Type type;
+  auto fVal = llvm::dyn_cast<FloatAttr>(val);
+  auto intVal = llvm::dyn_cast<IntegerAttr>(val);
+  if (fVal && fVal.getType() != elementType) {
+    type = fVal.getType();
+  } else if (intVal && intVal.getType() != elementType) {
+    type = intVal.getType();
+  }
+  return type;
+}
+
 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     ArrayRef<Size> dimShape, Type elementType,
     function_ref<InFlightDiagnostic()> emitError) const {
@@ -907,36 +920,24 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     return emitError()
            << "dimension-rank mismatch between encoding and tensor shape: "
            << getDimRank() << " != " << dimRank;
-  Type expType, impType;
+  Type type;
   if (getExplicitVal()) {
-    auto fVal = llvm::dyn_cast<FloatAttr>(getExplicitVal());
-    auto intVal = llvm::dyn_cast<IntegerAttr>(getExplicitVal());
-    if (fVal && fVal.getType() != elementType) {
-      expType = fVal.getType();
-    } else if (intVal && intVal.getType() != elementType) {
-      expType = intVal.getType();
-    }
-    if (expType) {
+    if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
       return emitError() << "explicit value type mismatch between encoding and "
-                         << "tensor element type: " << expType
+                         << "tensor element type: " << type
                          << " != " << elementType;
     }
   }
-
   if (getImplicitVal()) {
-    auto impFVal = llvm::dyn_cast<FloatAttr>(getImplicitVal());
-    auto impIntVal = llvm::dyn_cast<IntegerAttr>(getImplicitVal());
-    if (impFVal && impFVal.getType() != elementType) {
-      impType = impFVal.getType();
-    } else if (impIntVal && impIntVal.getType() != elementType) {
-      impType = impIntVal.getType();
-    }
-    if (impType) {
+    auto impVal = getImplicitVal();
+    if ((type = getMismatchedValueType(elementType, impVal))) {
       return emitError() << "implicit value type mismatch between encoding and "
-                         << "tensor element type: " << impType
+                         << "tensor element type: " << type
                          << " != " << elementType;
     }
     // Currently, we only support zero as the implicit value.
+    auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
+    auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
     if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
         (impIntVal && impIntVal.getInt() != 0)) {
       return emitError() << "implicit value must be zero";

>From b9527a0f8253454c338d02bfdb411ebf02110d4d Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 25 Apr 2024 23:49:03 +0000
Subject: [PATCH 3/4] use TypedAttr

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  5 ---
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 40 +++++++++----------
 2 files changed, 18 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 37fa4913aa6a60..eefa4c71bbd2ca 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -512,11 +512,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     void printSymbols(AffineMap &map, AsmPrinter &printer) const;
     void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
     void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
-
-    //
-    // Explicit/implicit value methods.
-    //
-    Type getMismatchedValueType(Type elementType, Attribute val) const;
   }];
 
   let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 7c938ecaed5abe..cd3d697fef673d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -888,19 +888,6 @@ LogicalResult SparseTensorEncodingAttr::verify(
   return success();
 }
 
-Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
-                                                      Attribute val) const {
-  Type type;
-  auto fVal = llvm::dyn_cast<FloatAttr>(val);
-  auto intVal = llvm::dyn_cast<IntegerAttr>(val);
-  if (fVal && fVal.getType() != elementType) {
-    type = fVal.getType();
-  } else if (intVal && intVal.getType() != elementType) {
-    type = intVal.getType();
-  }
-  return type;
-}
-
 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     ArrayRef<Size> dimShape, Type elementType,
     function_ref<InFlightDiagnostic()> emitError) const {
@@ -920,20 +907,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     return emitError()
            << "dimension-rank mismatch between encoding and tensor shape: "
            << getDimRank() << " != " << dimRank;
-  Type type;
   if (getExplicitVal()) {
-    if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
-      return emitError() << "explicit value type mismatch between encoding and "
-                         << "tensor element type: " << type
-                         << " != " << elementType;
+    if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal())) {
+      Type attrType = typedAttr.getType();
+      if (attrType != elementType) {
+        return emitError()
+               << "explicit value type mismatch between encoding and "
+               << "tensor element type: " << attrType << " != " << elementType;
+      }
+    } else {
+      return emitError() << "expected typed explicit value";
     }
   }
   if (getImplicitVal()) {
     auto impVal = getImplicitVal();
-    if ((type = getMismatchedValueType(elementType, impVal))) {
-      return emitError() << "implicit value type mismatch between encoding and "
-                         << "tensor element type: " << type
-                         << " != " << elementType;
+    if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getImplicitVal())) {
+      Type attrType = typedAttr.getType();
+      if (attrType != elementType) {
+        return emitError()
+               << "implicit value type mismatch between encoding and "
+               << "tensor element type: " << attrType << " != " << elementType;
+      }
+    } else {
+      return emitError() << "expected typed implicit value";
     }
     // Currently, we only support zero as the implicit value.
     auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);

>From e19fc6b62703418f7a4ee729e2ccc534be1551c9 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Mon, 29 Apr 2024 19:37:43 +0000
Subject: [PATCH 4/4] remove redundant call

---
 mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index cd3d697fef673d..64b1f9ff5a28c5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -921,7 +921,7 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
   }
   if (getImplicitVal()) {
     auto impVal = getImplicitVal();
-    if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getImplicitVal())) {
+    if (auto typedAttr = llvm::dyn_cast<TypedAttr>(impVal)) {
       Type attrType = typedAttr.getType();
       if (attrType != elementType) {
         return emitError()



More information about the Mlir-commits mailing list