Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow subclasses to customize rules generation. #34

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
common --incompatible_disable_deprecated_attr_params=false
common --incompatible_no_support_tools_in_action_inputs=false
common --incompatible_new_actions_api=false
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ py_test(
"@com_google_protobuf//:protoc",
"@com_google_protobuf//:timestamp_proto",
],
python_version = "PY2",
)
12 changes: 7 additions & 5 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
protobuf_commit = "099d99759101c295244c24d8954ec85b8ac65ce3"
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

protobuf_sha256 = "c0ab1b088e220c1d56446f34001f0178e590270efdef1c46a77da4b9faa9d7b0"
protobuf_version = "3.6.1.3"

protobuf_sha256 = "73fdad358857e120fd0fa19e071a96e15c0f23bb25f85d3f7009abfd4f264a2a"


def protobuf_rules_gen_repositories():
if "com_google_protobuf" not in native.existing_rules():
native.http_archive(
http_archive(
name = "com_google_protobuf",
sha256 = protobuf_sha256,
strip_prefix = "protobuf-" + protobuf_commit,
url = "https://github.com/google/protobuf/archive/" + protobuf_commit + ".tar.gz",
strip_prefix = "protobuf-" + protobuf_version,
url = "https://github.com/google/protobuf/archive/v" + protobuf_version + ".tar.gz",
)
151 changes: 93 additions & 58 deletions firebase_rules_generator/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,43 +90,6 @@ std::string GetEnumName(const protobuf::EnumDescriptor *enumeration) {
}
}

std::vector<std::string> RequiredFields(const protobuf::Descriptor *message) {
std::vector<std::string> required;
for (int i = 0; i < message->field_count(); ++i) {
const auto *field = message->field(i);
if (field->is_required() && field->containing_oneof() == nullptr) {
required.push_back(field->json_name());
}
}
return required;
}

std::vector<std::string> OptionalFields(const protobuf::Descriptor *message) {
std::vector<std::string> optional;
for (int i = 0; i < message->field_count(); ++i) {
const auto *field = message->field(i);
if ((field->is_optional() || field->is_repeated()) &&
field->containing_oneof() == nullptr) {
optional.push_back(field->json_name());
}
}
return optional;
}

std::vector<std::vector<std::string>> OneOfFields(
const protobuf::Descriptor *message) {
std::vector<std::vector<std::string>> oneofs;
for (int i = 0; i < message->oneof_decl_count(); ++i) {
std::vector<std::string> oneof_names;
const auto *oneof_decl = message->oneof_decl(i);
for (int j = 0; j < oneof_decl->field_count(); ++j) {
oneof_names.push_back(oneof_decl->field(j)->json_name());
}
oneofs.push_back(oneof_names);
}
return oneofs;
}

std::string ToString(std::vector<std::string> vec) {
std::string result = "[";
for (const auto &elem : vec) {
Expand Down Expand Up @@ -213,17 +176,19 @@ bool IsLastIteration(S idx, S size) {

} // namespace

std::string RulesGenerator::RulesFilename(const protobuf::FileDescriptor *file, const std::string &parameter) const {
if (parameter == "bazel") {
return StrCat(StripSuffixString(file->name(), ".proto"), ".pb.rules");
} else {
return RULES_FILE;
}
}

bool RulesGenerator::Generate(const protobuf::FileDescriptor *file,
const std::string &parameter,
protobuf::compiler::GeneratorContext *context,
std::string *error) const {
std::string filename;
if (parameter == "bazel") {
filename = StrCat(StripSuffixString(file->name(), ".proto"),
".pb.rules");
} else {
filename = RULES_FILE;
}
std::string filename = RulesFilename(file, parameter);
protobuf::io::Printer printer(context->Open(filename), '$');

// Start by adding a comment
Expand Down Expand Up @@ -286,12 +251,13 @@ bool RulesGenerator::GenerateMessage(const protobuf::Descriptor *message,
}
}
// Validate inner types
if (message->field_count() > 0) printer.Print(" &&\n");
for (int i = 0; i < message->field_count(); ++i) {
if (!GenerateField(message->field(i), printer, error)) {
const auto &all_fields = AllFields(message);
if (all_fields.size() > 0) printer.Print(" &&\n");
for (size_t i = 0; i < all_fields.size(); ++i) {
if (!GenerateField(all_fields[i], printer, error)) {
return false;
}
if (!IsLastIteration(i, message->field_count())) {
if (!IsLastIteration(i, all_fields.size())) {
printer.Print(" &&\n");
}
}
Expand Down Expand Up @@ -345,6 +311,21 @@ bool RulesGenerator::GenerateEnum(const protobuf::EnumDescriptor *enumeration,
return true;
}

bool RulesGenerator::IsNullableField(
const protobuf::FieldDescriptor *field) const {
const auto &options = field->options().GetExtension(firebase_rules_field);
const auto &msg_options =
field->containing_type()->options().GetExtension(firebase_rules_message);
return (!options.has_nullable() && msg_options.nullable()) ||
options.nullable();
}

bool RulesGenerator::IsReferenceField(
const protobuf::FieldDescriptor *field) const {
const auto &options = field->options().GetExtension(firebase_rules_field);
return options.reference_type();
}

bool RulesGenerator::GenerateField(const protobuf::FieldDescriptor *field,
protobuf::io::Printer &printer,
std::string *error) const {
Expand All @@ -368,9 +349,10 @@ bool RulesGenerator::GenerateField(const protobuf::FieldDescriptor *field,
if (!field->is_repeated()) {
std::map<std::string, std::string> vars;
vars.insert({"name", field->json_name()});
if (options.reference_type() &&
field->type() != protobuf::FieldDescriptor::TYPE_STRING) {
*error = "references must be of type string";
if (IsReferenceField(field) &&
!(field->type() == protobuf::FieldDescriptor::TYPE_STRING||
field->type() == protobuf::FieldDescriptor::TYPE_MESSAGE)) {
*error = "references must be of type string or message";
return false;
}
switch (field->type()) {
Expand Down Expand Up @@ -435,7 +417,7 @@ bool RulesGenerator::GenerateField(const protobuf::FieldDescriptor *field,
printer.Print(vars, "resource.$name$ is $type$");
break;
case protobuf::FieldDescriptor::TYPE_STRING:
if (options.reference_type()) {
if (IsReferenceField(field)) {
vars.insert({"type", "path"});
} else {
vars.insert({"type", "string"});
Expand All @@ -453,6 +435,9 @@ bool RulesGenerator::GenerateField(const protobuf::FieldDescriptor *field,
if (field->message_type()->full_name() == "google.protobuf.Timestamp") {
vars.insert({"type", "timestamp"});
printer.Print(vars, "resource.$name$ is $type$");
} else if (IsReferenceField(field)) {
vars.insert({"type", "path"});
printer.Print(vars, "resource.$name$ is $type$");
} else {
vars.insert({"type", GetMessageName(field->message_type())});
printer.Print(vars, "is$type$Message(resource.$name$)");
Expand All @@ -468,16 +453,13 @@ bool RulesGenerator::GenerateField(const protobuf::FieldDescriptor *field,
printer.Print(" && ($validate$)", "validate", options.validate());
}
printer.Print(")");
const auto &msg_options =
field->containing_type()->options().GetExtension(firebase_rules_message);
bool nullable =
(!options.has_nullable() && msg_options.nullable()) || options.nullable();
if (nullable) {

if (IsNullableField(field)) {
printer.Print(" || resource.$name$ == null", "name", field->json_name());
}
printer.Print(")");
return true;
}
} // namespace experimental

bool RulesGenerator::GenerateMap(const protobuf::FieldDescriptor *map_field,
protobuf::io::Printer &printer,
Expand Down Expand Up @@ -507,6 +489,59 @@ bool RulesGenerator::GenerateMap(const protobuf::FieldDescriptor *map_field,
return true;
}

std::vector<std::string> RulesGenerator::RequiredFields(
const protobuf::Descriptor *message) const {
std::vector<std::string> required;
for (int i = 0; i < message->field_count(); ++i) {
const auto *field = message->field(i);
if (field->is_required() && field->containing_oneof() == nullptr &&
!IgnoreField(field)) {
required.push_back(field->json_name());
}
}
return required;
}

std::vector<std::string> RulesGenerator::OptionalFields(
const protobuf::Descriptor *message) const {
std::vector<std::string> optional;
for (int i = 0; i < message->field_count(); ++i) {
const auto *field = message->field(i);
if ((field->is_optional() || field->is_repeated()) &&
field->containing_oneof() == nullptr && !IgnoreField(field)) {
optional.push_back(field->json_name());
}
}
return optional;
}

std::vector<std::vector<std::string>> RulesGenerator::OneOfFields(
const protobuf::Descriptor *message) const {
std::vector<std::vector<std::string>> oneofs;
for (int i = 0; i < message->oneof_decl_count(); ++i) {
std::vector<std::string> oneof_names;
const auto *oneof_decl = message->oneof_decl(i);
for (int j = 0; j < oneof_decl->field_count(); ++j) {
if (!IgnoreField(oneof_decl->field(j))) {
oneof_names.push_back(oneof_decl->field(j)->json_name());
}
}
oneofs.push_back(oneof_names);
}
return oneofs;
}

std::vector<const protobuf::FieldDescriptor *> RulesGenerator::AllFields(
const protobuf::Descriptor *message) const {
std::vector<const protobuf::FieldDescriptor *> fields;
for (int i = 0; i < message->field_count(); ++i) {
const auto *field = message->field(i);
if (IgnoreField(field)) continue;
fields.push_back(field);
}
return fields;
}

} // namespace experimental
} // namespace rules
} // namespace firebase
Expand Down
22 changes: 22 additions & 0 deletions firebase_rules_generator/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ class RulesGenerator : public protobuf::compiler::CodeGenerator {
protobuf::compiler::GeneratorContext* context,
std::string* error) const override;

virtual std::string RulesFilename(const protobuf::FileDescriptor *file, const std::string &parameter) const;

virtual bool IgnoreField(const protobuf::FieldDescriptor* field) const {
return false;
}

virtual bool IsNullableField(const protobuf::FieldDescriptor* field) const;

virtual bool IsReferenceField(const protobuf::FieldDescriptor* field) const;

private:
bool GenerateMessage(const protobuf::Descriptor* message,
protobuf::io::Printer& printer,
Expand All @@ -54,6 +64,18 @@ class RulesGenerator : public protobuf::compiler::CodeGenerator {

bool GenerateMap(const protobuf::FieldDescriptor* map_field,
protobuf::io::Printer& printer, std::string* error) const;

std::vector<std::string> RequiredFields(
const protobuf::Descriptor* message) const;

std::vector<std::string> OptionalFields(
const protobuf::Descriptor* message) const;

std::vector<std::vector<std::string>> OneOfFields(
const protobuf::Descriptor* message) const;

std::vector<const protobuf::FieldDescriptor*> AllFields(
const protobuf::Descriptor* message) const;
};

} // namespace experimental
Expand Down