[Mlir-commits] [mlir] 2101590 - NFC: add indexing operator for ArrayAttr

River Riddle llvmlistbot at llvm.org
Fri Feb 14 22:59:17 PST 2020


Author: Uday Bondhugula
Date: 2020-02-14T22:54:37-08:00
New Revision: 2101590a78b7189f89aa06513eeea2dee6a3c45a

URL: https://github.com/llvm/llvm-project/commit/2101590a78b7189f89aa06513eeea2dee6a3c45a
DIFF: https://github.com/llvm/llvm-project/commit/2101590a78b7189f89aa06513eeea2dee6a3c45a.diff

LOG: NFC: add indexing operator for ArrayAttr

Summary: - add ArrayAttr::operator[](unsigned idx)

Differential Revision: https://reviews.llvm.org/D74663

Added: 
    

Modified: 
    mlir/docs/DeclarativeRewrites.md
    mlir/docs/Tutorials/Toy/Ch-7.md
    mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Dialect/VectorOps/VectorOps.cpp
    mlir/lib/IR/Attributes.cpp
    mlir/test/lib/TestDialect/TestDialect.cpp
    mlir/test/lib/TestDialect/TestOps.td
    mlir/test/mlir-tblgen/predicate.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
index 5201eaafe168..0faae5a726f8 100644
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -400,7 +400,7 @@ handy methods on `mlir::Builder`.
 example and decompose the array attribute into two attributes:
 
 ```tablegen
-class getNthAttr<int n> : NativeCodeCall<"$_self.getValue()[" # n # "]">;
+class getNthAttr<int n> : NativeCodeCall<"$_self[" # n # "]">;
 
 def : Pat<(OneAttrOp $attr),
           (TwoAttrOp (getNthAttr<0>:$attr), (getNthAttr<1>:$attr)>;

diff  --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md
index b1872d5accfb..a14d65409982 100644
--- a/mlir/docs/Tutorials/Toy/Ch-7.md
+++ b/mlir/docs/Tutorials/Toy/Ch-7.md
@@ -495,7 +495,7 @@ OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
     return nullptr;
 
   size_t elementIndex = index().getZExtValue();
-  return structAttr.getValue()[elementIndex];
+  return structAttr[elementIndex];
 }
 ```
 

diff  --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
index 019f1dadb14e..2817778dc90e 100644
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -43,7 +43,7 @@ OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
     return nullptr;
 
   size_t elementIndex = index().getZExtValue();
-  return structAttr.getValue()[elementIndex];
+  return structAttr[elementIndex];
 }
 
 /// This is an example of a c++ rewrite pattern for the TransposeOp. It

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 4f677d8b35f5..34e0152e17d0 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -204,6 +204,7 @@ class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
   static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
 
   ArrayRef<Attribute> getValue() const;
+  Attribute operator[](unsigned idx) const;
 
   /// Support range iteration.
   using iterator = llvm::ArrayRef<Attribute>::iterator;

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index e4165914cde5..1ccf9aee6a72 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1292,7 +1292,7 @@ class ArrayCount<int n> : AttrConstraint<
 class IntArrayNthElemEq<int index, int value> : AttrConstraint<
     And<[
       CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
-      CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]"
+      CPred<"$_self.cast<ArrayAttr>()[" # index # "]"
         ".cast<IntegerAttr>().getInt() == " # value>
        ]>,
     "whose " # index # "-th element must be " # value>;
@@ -1300,7 +1300,7 @@ class IntArrayNthElemEq<int index, int value> : AttrConstraint<
 class IntArrayNthElemMinValue<int index, int min> : AttrConstraint<
     And<[
       CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
-      CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]"
+      CPred<"$_self.cast<ArrayAttr>()[" # index # "]"
         ".cast<IntegerAttr>().getInt() >= " # min>
         ]>,
     "whose " # index # "-th element must be at least " # min>;

diff  --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index b4d7aee70b17..174efb66ccd4 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1143,7 +1143,7 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
   shape.reserve(vectorType.getRank());
   unsigned idx = 0;
   for (unsigned e = offsets.size(); idx < e; ++idx)
-    shape.push_back(sizes.getValue()[idx].cast<IntegerAttr>().getInt());
+    shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
     shape.push_back(vectorType.getShape()[idx]);
 

diff  --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 224ef0a21796..bf4d7330a7b6 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -68,6 +68,11 @@ ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
 
 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
 
+Attribute ArrayAttr::operator[](unsigned idx) const {
+  assert(idx < size() && "index out of bounds");
+  return getValue()[idx];
+}
+
 //===----------------------------------------------------------------------===//
 // BoolAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index af1af0dec927..d1f53a2bcdb5 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -42,7 +42,7 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
     auto args = block->getArguments();
     auto e = std::min(arrayAttr.size(), args.size());
     for (unsigned i = 0; i < e; ++i) {
-      if (auto strAttr = arrayAttr.getValue()[i].dyn_cast<StringAttr>())
+      if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
         setNameFn(args[i], strAttr.getValue());
     }
   }

diff  --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index 48adf2dd1585..743cddd2f55f 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -609,12 +609,11 @@ def OpAllAttrConstraint2 : TEST_Op<"all_attr_constraint_of2"> {
   let results = (outs I32);
 }
 def Constraint0 : AttrConstraint<
-    CPred<"$_self.cast<ArrayAttr>().getValue()[0]."
+    CPred<"$_self.cast<ArrayAttr>()[0]."
           "cast<IntegerAttr>().getInt() == 0">,
     "[0] == 0">;
 def Constraint1 : AttrConstraint<
-    CPred<"$_self.cast<ArrayAttr>().getValue()[1]."
-          "cast<IntegerAttr>().getInt() == 1">,
+    CPred<"$_self.cast<ArrayAttr>()[1].cast<IntegerAttr>().getInt() == 1">,
     "[1] == 1">;
 def : Pat<(OpAllAttrConstraint1
             AllAttrConstraintsOf<[Constraint0, Constraint1]>:$attr),

diff  --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index d02d64526a7c..0de9edce7209 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -56,7 +56,7 @@ def OpH : NS_Op<"op_for_arr_value_at_index", []> {
 }
 
 // CHECK-LABEL: OpH::verify()
-// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>().getValue()[0].cast<IntegerAttr>().getInt() == 8)))))
+// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() == 8)))))
 // CHECK-SAME:    return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8");
 
 def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
@@ -64,7 +64,7 @@ def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
 }
 
 // CHECK-LABEL: OpI::verify()
-// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>().getValue()[0].cast<IntegerAttr>().getInt() >= 8)))))
+// CHECK: (((tblgen_attr.cast<ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<ArrayAttr>()[0].cast<IntegerAttr>().getInt() >= 8)))))
 // CHECK-SAME:    return emitOpError("attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8");
 
 def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [


        


More information about the Mlir-commits mailing list