[Mlir-commits] [mlir] [mlir][sparse] Add verification for explicit/implicit value (PR #90111)
Yinying Li
llvmlistbot at llvm.org
Thu Apr 25 16:49:22 PDT 2024
https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/90111
>From bedbe3e7a29baec4e54583520a03b457065f48f7 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 25 Apr 2024 00:39:00 +0000
Subject: [PATCH 1/3] add verification for explicit/implicit values
---
.../SparseTensor/IR/SparseTensorDialect.cpp | 35 +++++++++
.../SparseTensor/invalid_encoding.mlir | 72 +++++++++++++++++++
2 files changed, 107 insertions(+)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 028a69da10c1e1..b7567173341eed 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -907,6 +907,41 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
return emitError()
<< "dimension-rank mismatch between encoding and tensor shape: "
<< getDimRank() << " != " << dimRank;
+ Type expType, impType;
+ if (getExplicitVal()) {
+ auto fVal = llvm::dyn_cast<FloatAttr>(getExplicitVal());
+ auto intVal = llvm::dyn_cast<IntegerAttr>(getExplicitVal());
+ if (fVal && fVal.getType() != elementType) {
+ expType = fVal.getType();
+ } else if (intVal && intVal.getType() != elementType) {
+ expType = intVal.getType();
+ }
+ if (expType) {
+ return emitError() << "explicit value type mismatch between encoding and "
+ << "tensor element type: " << expType
+ << " != " << elementType;
+ }
+ }
+
+ if (getImplicitVal()) {
+ auto impFVal = llvm::dyn_cast<FloatAttr>(getImplicitVal());
+ auto impIntVal = llvm::dyn_cast<IntegerAttr>(getImplicitVal());
+ if (impFVal && impFVal.getType() != elementType) {
+ impType = impFVal.getType();
+ } else if (impIntVal && impIntVal.getType() != elementType) {
+ impType = impIntVal.getType();
+ }
+ if (impType) {
+ return emitError() << "implicit value type mismatch between encoding and "
+ << "tensor element type: " << impType
+ << " != " << elementType;
+ }
+ // Currently, we only support zero as the implicit value.
+ if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
+ (impIntVal && impIntVal.getInt() != 0)) {
+ return emitError() << "implicit value must be zero";
+ }
+ }
return success();
}
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 8096c010ac935a..19e8fc95e22813 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -443,3 +443,75 @@ func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
return
}
+
+// -----
+
+#CSR_ExpType = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ explicitVal = 1 : i32,
+ implicitVal = 0.0 : f32
+}>
+
+// expected-error@+1 {{explicit value type mismatch between encoding and tensor element type: 'i32' != 'f32'}}
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_ExpType>)
+
+// -----
+
+#CSR_ImpType = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ explicitVal = 1 : i32,
+ implicitVal = 0.0 : f32
+}>
+
+// expected-error@+1 {{implicit value type mismatch between encoding and tensor element type: 'f32' != 'i32'}}
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
+
+// -----
+
+// expected-error@+1 {{expected a numeric value for explicitVal}}
+#CSR_ExpType = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ explicitVal = "str"
+}>
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ExpType>)
+
+// -----
+
+// expected-error@+1 {{expected a numeric value for implicitVal}}
+#CSR_ImpType = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ implicitVal = "str"
+}>
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpType>)
+
+// -----
+
+#CSR_ImpVal = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ implicitVal = 1 : i32
+}>
+
+// expected-error@+1 {{implicit value must be zero}}
+func.func private @sparse_csr(tensor<?x?xi32, #CSR_ImpVal>)
+
+// -----
+
+#CSR_ImpVal = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 32,
+ crdWidth = 32,
+ implicitVal = 1.0 : f32
+}>
+
+// expected-error@+1 {{implicit value must be zero}}
+func.func private @sparse_csr(tensor<?x?xf32, #CSR_ImpVal>)
>From f4f58e4d65b05b3e8d22f9083d7ca53a7e59d91f Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 25 Apr 2024 17:37:28 +0000
Subject: [PATCH 2/3] new function
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 5 +++
.../SparseTensor/IR/SparseTensorDialect.cpp | 41 ++++++++++---------
2 files changed, 26 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index eefa4c71bbd2ca..37fa4913aa6a60 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -512,6 +512,11 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
void printSymbols(AffineMap &map, AsmPrinter &printer) const;
void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
+
+ //
+ // Explicit/implicit value methods.
+ //
+ Type getMismatchedValueType(Type elementType, Attribute val) const;
}];
let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b7567173341eed..7c938ecaed5abe 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -888,6 +888,19 @@ LogicalResult SparseTensorEncodingAttr::verify(
return success();
}
+Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
+ Attribute val) const {
+ Type type;
+ auto fVal = llvm::dyn_cast<FloatAttr>(val);
+ auto intVal = llvm::dyn_cast<IntegerAttr>(val);
+ if (fVal && fVal.getType() != elementType) {
+ type = fVal.getType();
+ } else if (intVal && intVal.getType() != elementType) {
+ type = intVal.getType();
+ }
+ return type;
+}
+
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
ArrayRef<Size> dimShape, Type elementType,
function_ref<InFlightDiagnostic()> emitError) const {
@@ -907,36 +920,24 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
return emitError()
<< "dimension-rank mismatch between encoding and tensor shape: "
<< getDimRank() << " != " << dimRank;
- Type expType, impType;
+ Type type;
if (getExplicitVal()) {
- auto fVal = llvm::dyn_cast<FloatAttr>(getExplicitVal());
- auto intVal = llvm::dyn_cast<IntegerAttr>(getExplicitVal());
- if (fVal && fVal.getType() != elementType) {
- expType = fVal.getType();
- } else if (intVal && intVal.getType() != elementType) {
- expType = intVal.getType();
- }
- if (expType) {
+ if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
return emitError() << "explicit value type mismatch between encoding and "
- << "tensor element type: " << expType
+ << "tensor element type: " << type
<< " != " << elementType;
}
}
-
if (getImplicitVal()) {
- auto impFVal = llvm::dyn_cast<FloatAttr>(getImplicitVal());
- auto impIntVal = llvm::dyn_cast<IntegerAttr>(getImplicitVal());
- if (impFVal && impFVal.getType() != elementType) {
- impType = impFVal.getType();
- } else if (impIntVal && impIntVal.getType() != elementType) {
- impType = impIntVal.getType();
- }
- if (impType) {
+ auto impVal = getImplicitVal();
+ if ((type = getMismatchedValueType(elementType, impVal))) {
return emitError() << "implicit value type mismatch between encoding and "
- << "tensor element type: " << impType
+ << "tensor element type: " << type
<< " != " << elementType;
}
// Currently, we only support zero as the implicit value.
+ auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
+ auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal);
if ((impFVal && impFVal.getValueAsDouble() != 0.0) ||
(impIntVal && impIntVal.getInt() != 0)) {
return emitError() << "implicit value must be zero";
>From bb03496c566e7f62f9c7667bcdd62c7a3ae09be6 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 25 Apr 2024 23:49:03 +0000
Subject: [PATCH 3/3] use TypedAttr
---
.../SparseTensor/IR/SparseTensorAttrDefs.td | 5 ---
.../SparseTensor/IR/SparseTensorDialect.cpp | 40 +++++++++----------
2 files changed, 18 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 37fa4913aa6a60..eefa4c71bbd2ca 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -512,11 +512,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
void printSymbols(AffineMap &map, AsmPrinter &printer) const;
void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
-
- //
- // Explicit/implicit value methods.
- //
- Type getMismatchedValueType(Type elementType, Attribute val) const;
}];
let genVerifyDecl = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 7c938ecaed5abe..cd3d697fef673d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -888,19 +888,6 @@ LogicalResult SparseTensorEncodingAttr::verify(
return success();
}
-Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
- Attribute val) const {
- Type type;
- auto fVal = llvm::dyn_cast<FloatAttr>(val);
- auto intVal = llvm::dyn_cast<IntegerAttr>(val);
- if (fVal && fVal.getType() != elementType) {
- type = fVal.getType();
- } else if (intVal && intVal.getType() != elementType) {
- type = intVal.getType();
- }
- return type;
-}
-
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
ArrayRef<Size> dimShape, Type elementType,
function_ref<InFlightDiagnostic()> emitError) const {
@@ -920,20 +907,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
return emitError()
<< "dimension-rank mismatch between encoding and tensor shape: "
<< getDimRank() << " != " << dimRank;
- Type type;
if (getExplicitVal()) {
- if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
- return emitError() << "explicit value type mismatch between encoding and "
- << "tensor element type: " << type
- << " != " << elementType;
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal())) {
+ Type attrType = typedAttr.getType();
+ if (attrType != elementType) {
+ return emitError()
+ << "explicit value type mismatch between encoding and "
+ << "tensor element type: " << attrType << " != " << elementType;
+ }
+ } else {
+ return emitError() << "expected typed explicit value";
}
}
if (getImplicitVal()) {
auto impVal = getImplicitVal();
- if ((type = getMismatchedValueType(elementType, impVal))) {
- return emitError() << "implicit value type mismatch between encoding and "
- << "tensor element type: " << type
- << " != " << elementType;
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getImplicitVal())) {
+ Type attrType = typedAttr.getType();
+ if (attrType != elementType) {
+ return emitError()
+ << "implicit value type mismatch between encoding and "
+ << "tensor element type: " << attrType << " != " << elementType;
+ }
+ } else {
+ return emitError() << "expected typed implicit value";
}
// Currently, we only support zero as the implicit value.
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
More information about the Mlir-commits
mailing list