[Mlir-commits] [mlir] bb39ad4 - [mlir][spirv] Fix verification of nested array constants
Sergei Grechanik
llvmlistbot at llvm.org
Mon Feb 7 13:57:04 PST 2022
Author: Sergei Grechanik
Date: 2022-02-07T13:48:53-08:00
New Revision: bb39ad43ceeae0772e6902641954c32cde6b4da8
URL: https://github.com/llvm/llvm-project/commit/bb39ad43ceeae0772e6902641954c32cde6b4da8
DIFF: https://github.com/llvm/llvm-project/commit/bb39ad43ceeae0772e6902641954c32cde6b4da8.diff
LOG: [mlir][spirv] Fix verification of nested array constants
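Fix the verification function of spirv::ConstantOp to allow nested array
attributes: each element of an array attribute is now verified recursively
against the element type of the spv.array result type. Also verify the
deserialized module in the SPIR-V round-trip translation before printing it.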
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D118939
Added:
Modified:
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Target/SPIRV/TranslateRegistration.cpp
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index b45b422c3c6f4..cb476dcb62307 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1737,17 +1737,13 @@ static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
printer << " : " << constOp.getType();
}
-LogicalResult spirv::ConstantOp::verify() {
- auto opType = getType();
- auto value = valueAttr();
+static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
+ Type opType) {
auto valueType = value.getType();
- // ODS already generates checks to make sure the result type is valid. We just
- // need to additionally check that the value's attribute type is consistent
- // with the result type.
if (value.isa<IntegerAttr, FloatAttr>()) {
if (valueType != opType)
- return emitOpError("result type (")
+ return op.emitOpError("result type (")
<< opType << ") does not match value type (" << valueType << ")";
return success();
}
@@ -1757,7 +1753,9 @@ LogicalResult spirv::ConstantOp::verify() {
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
auto shapedType = valueType.dyn_cast<ShapedType>();
if (!arrayType)
- return emitOpError("must have spv.array result type for array value");
+ return op.emitOpError("result or element type (")
+ << opType << ") does not match value type (" << valueType
+ << "), must be the same or spv.array";
int numElements = arrayType.getNumElements();
auto opElemType = arrayType.getElementType();
@@ -1766,37 +1764,42 @@ LogicalResult spirv::ConstantOp::verify() {
opElemType = t.getElementType();
}
if (!opElemType.isIntOrFloat())
- return emitOpError("only support nested array result type");
+ return op.emitOpError("only support nested array result type");
auto valueElemType = shapedType.getElementType();
if (valueElemType != opElemType) {
- return emitOpError("result element type (")
+ return op.emitOpError("result element type (")
<< opElemType << ") does not match value element type ("
<< valueElemType << ")";
}
if (numElements != shapedType.getNumElements()) {
- return emitOpError("result number of elements (")
+ return op.emitOpError("result number of elements (")
<< numElements << ") does not match value number of elements ("
<< shapedType.getNumElements() << ")";
}
return success();
}
- if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
+ if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
if (!arrayType)
- return emitOpError("must have spv.array result type for array value");
+ return op.emitOpError("must have spv.array result type for array value");
Type elemType = arrayType.getElementType();
- for (Attribute element : attayAttr.getValue()) {
- if (element.getType() != elemType)
- return emitOpError("has array element whose type (")
- << element.getType()
- << ") does not match the result element type (" << elemType
- << ')';
+ for (Attribute element : arrayAttr.getValue()) {
+ // Verify array elements recursively.
+ if (failed(verifyConstantType(op, element, elemType)))
+ return failure();
}
return success();
}
- return emitOpError("cannot have value of type ") << valueType;
+ return op.emitOpError("cannot have value of type ") << valueType;
+}
+
+LogicalResult spirv::ConstantOp::verify() {
+ // ODS already generates checks to make sure the result type is valid. We just
+ // need to additionally check that the value's attribute type is consistent
+ // with the result type.
+ return verifyConstantType(*this, valueAttr(), getType());
}
bool spirv::ConstantOp::isBuildableWith(Type type) {
diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
index aee4ad3bba994..e9d4f34b8f6bf 100644
--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
+++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/SPIRV/Deserialization.h"
@@ -151,6 +152,8 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
FileLineColLoc::get(&deserializationContext,
/*filename=*/"", /*line=*/0, /*column=*/0)));
dstModule->getBody()->push_front(spirvModule.release());
+ if (failed(verify(*dstModule)))
+ return failure();
dstModule->print(output);
return mlir::success();
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 798b843874a5d..ee09bee26cf6c 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -72,6 +72,7 @@ func @const() -> () {
%6 = spv.Constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
%7 = spv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>>
%8 = spv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>>
+ %9 = spv.Constant [[dense<3.0> : vector<2xf32>]] : !spv.array<1 x !spv.array<1xvector<2xf32>>>
return
}
@@ -86,7 +87,7 @@ func @unaccepted_std_attr() -> () {
// -----
func @array_constant() -> () {
- // expected-error @+1 {{has array element whose type ('vector<2xi32>') does not match the result element type ('vector<2xf32>')}}
+ // expected-error @+1 {{result or element type ('vector<2xf32>') does not match value type ('vector<2xi32>')}}
%0 = spv.Constant [dense<3.0> : vector<2xf32>, dense<4> : vector<2xi32>] : !spv.array<2xvector<2xf32>>
return
}
@@ -110,7 +111,7 @@ func @non_nested_array_constant() -> () {
// -----
func @value_result_type_mismatch() -> () {
- // expected-error @+1 {{must have spv.array result type for array value}}
+ // expected-error @+1 {{result or element type ('vector<4xi32>') does not match value type ('tensor<4xi32>')}}
%0 = "spv.Constant"() {value = dense<0> : tensor<4xi32>} : () -> (vector<4xi32>)
}
More information about the Mlir-commits mailing list