Skip to content

Commit

Permalink
Add a "lite" factory variant
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Carroll <[email protected]>
  • Loading branch information
mjcarroll committed Mar 4, 2024
1 parent 35d48cd commit 6338d74
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 54 deletions.
4 changes: 2 additions & 2 deletions tools/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ load(

gz_py_binary(
name = "gz_msgs_generate_factory_py",
srcs = ["gz_msgs_generate_factory.py"],
main = "gz_msgs_generate_factory.py",
srcs = ["gz_msgs_generate_factory_lite.py"],
main = "gz_msgs_generate_factory_lite.py",
visibility = GZ_VISIBILITY,
)
90 changes: 38 additions & 52 deletions tools/gz_msgs_generate_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import argparse
import os
import re
import pathlib
import sys

# Create <gz/msgs/MessageTypes.hh>
Expand Down Expand Up @@ -74,7 +74,7 @@
#include "gz/msgs/Factory.hh"
#include "gz/msgs/MessageFactory.hh"
#include "{include_path}/MessageTypes.hh"
#include "{package_path}/MessageTypes.hh"
#include <array>
Expand Down Expand Up @@ -110,21 +110,17 @@ def main(argv=sys.argv[1:]):
description='Generate protobuf factory file',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--cc-output',
'--output-cpp-path',
required=True,
help='The path to the generated cpp file')
help='The basepath of the generated C++ files')
parser.add_argument(
'--hh-output',
'--proto-package',
required=True,
help='The path to the generated hh file')
help='The basepath of the generated C++ files')
parser.add_argument(
'--proto-path',
required=True,
help='The location of the protos')
parser.add_argument(
'--proto-include-path',
required=True,
help='The location of the protos')
parser.add_argument(
'--protos',
type=str,
Expand All @@ -135,55 +131,45 @@ def main(argv=sys.argv[1:]):

args = parser.parse_args(argv)

package_re = re.compile('^package (.*);$')
message_re = re.compile(r'message (\w*)\s?{?$')
headers = []
registrations = []

registrations = dict()
gz_msgs_headers = []
package = None
messages = []
package = [p for p in args.proto_package.split('.') if len(p)]
namespace = '::'.join(package)
package_str = '.'.join(package)
package_path = '/'.join(package)

for proto in args.protos:
try:
with open(proto, 'r') as f:
content = f.readlines()
for line in content:
package_found = package_re.match(line)
if package_found:
package = package_found.group(1).split('.')

message_found = message_re.match(line)
if message_found:
messages.append(message_found.group(1))
except:
pass

if package and messages:
for message in messages:
registrations['_'.join([*package, message])] = register_fn.format(
package_str='.'.join(package),
message_str=message,
message_cpp_type='::'.join([*package, message])
)

split = proto.replace(args.proto_include_path, '')
split = [s for s in split.split("/") if s]
split[-1] = split[-1].replace(".proto", ".pb.h")
gz_msgs_headers.append("#include <" + "/".join(split) + ">")
proto_file = os.path.splitext(os.path.relpath(proto, args.proto_path))[0]
header = proto_file + ".pb.h"
headers.append(f"#include <{header}>")

namespace = '::'.join(package)
include_path = '/'.join(package)
proto_file = '_'.join(pathlib.Path(proto_file).parts)

with open(os.path.join(args.cc_output), 'w') as f:
f.write((cc_source.format(registrations='\n'.join(registrations.values()),
nRegistrations=len(registrations.values()),
# The gazebo extensions to the gazebo compiler write out a series of index files
# which capture the message types
index = os.path.join(args.output_cpp_path, proto_file + ".pb_index")
with open(index, "r") as index_f:
for line in index_f.readlines():
line = line.strip()

message_str = line
message_cpp_type = '::'.join(package) + '::' + message_str

registrations.append(register_fn.format(
package_str=package_str,
message_str=message_str,
message_cpp_type=message_cpp_type))

with open(os.path.join(args.output_cpp_path, *package, 'MessageTypes.hh'), 'w') as f:
f.write(cc_header.format(gz_msgs_headers='\n'.join(headers), namespace=namespace))

with open(os.path.join(args.output_cpp_path, *package, 'register.cc'), 'w') as f:
f.write((cc_source.format(registrations='\n'.join(registrations),
nRegistrations=len(registrations),
namespace=namespace,
include_path=include_path) +
package_path=package_path) +
cc_factory.format(namespace=namespace)))

with open(os.path.join(args.hh_output), 'w') as f:
f.write(cc_header.format(namespace=namespace,
gz_msgs_headers='\n'.join(gz_msgs_headers)))

if __name__ == '__main__':
sys.exit(main())
189 changes: 189 additions & 0 deletions tools/gz_msgs_generate_factory_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
#!/usr/bin/env python3
#
# Copyright (C) 2023 Open Source Robotics Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re
import sys

# Create <gz/msgs/MessageTypes.hh>
cc_header = """/*
* Copyright (C) 2023 Open Source Robotics Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/* This file was automatically generated.
* Do not edit this directly
*/
#ifndef GZ_MSGS_MESSAGE_TYPES_HH_
#define GZ_MSGS_MESSAGE_TYPES_HH_
{gz_msgs_headers}
namespace {namespace} {{
int RegisterAll();
}}
#endif"""

# Create factory registration bits
cc_source = """/*
* Copyright (C) 2023 Open Source Robotics Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
/* This file was automatically generated.
* Do not edit this directly
*/
#include "gz/msgs/Factory.hh"
#include "gz/msgs/MessageFactory.hh"
#include "{include_path}/MessageTypes.hh"
#include <array>
namespace {{
using NamedFactoryFn = std::pair<std::string, gz::msgs::MessageFactory::FactoryFn>;
std::array<NamedFactoryFn, {nRegistrations}> kFactoryFunctions = {{{{
{registrations}
}}}};
}} // namespace
"""

cc_factory = """
namespace {namespace} {{
int RegisterAll() {{
size_t registered = 0;
for (const auto &entry: kFactoryFunctions) {{
gz::msgs::Factory::Register(entry.first, entry.second);
registered++;
}}
return registered;
}}
static int kMessagesRegistered = RegisterAll();
}} // namespace {namespace}
"""

register_fn = """ {{"{package_str}.{message_str}",
[]()->std::unique_ptr<google::protobuf::Message>{{return std::make_unique<{message_cpp_type}>();}}}},"""

def main(argv=sys.argv[1:]):
parser = argparse.ArgumentParser(
description='Generate protobuf factory file',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'--cc-output',
required=True,
help='The path to the generated cpp file')
parser.add_argument(
'--hh-output',
required=True,
help='The path to the generated hh file')
parser.add_argument(
'--proto-path',
required=True,
help='The location of the protos')
parser.add_argument(
'--proto-include-path',
required=True,
help='The location of the protos')
parser.add_argument(
'--protos',
type=str,
nargs='*',
required=True,
help='The list of protos to include'
)

args = parser.parse_args(argv)

package_re = re.compile('^package (.*);$')
message_re = re.compile(r'message (\w*)\s?{?$')

registrations = dict()
gz_msgs_headers = []
package = []
messages = []

for proto in args.protos:
try:
with open(proto, 'r') as f:
content = f.readlines()
for line in content:
package_found = package_re.match(line)
if package_found:
package = package_found.group(1).split('.')

message_found = message_re.match(line)
if message_found:
messages.append(message_found.group(1))
except:
pass

if package and messages:
for message in messages:
registrations['_'.join([*package, message])] = register_fn.format(
package_str='.'.join(package),
message_str=message,
message_cpp_type='::'.join([*package, message])
)

split = proto.replace(args.proto_include_path, '')
split = [s for s in split.split("/") if s]
split[-1] = split[-1].replace(".proto", ".pb.h")
gz_msgs_headers.append("#include <" + "/".join(split) + ">")

namespace = '::'.join(package)
include_path = '/'.join(package)

with open(os.path.join(args.cc_output), 'w') as f:
f.write((cc_source.format(registrations='\n'.join(registrations.values()),
nRegistrations=len(registrations.values()),
namespace=namespace,
include_path=include_path) +
cc_factory.format(namespace=namespace)))

with open(os.path.join(args.hh_output), 'w') as f:
f.write(cc_header.format(namespace=namespace,
gz_msgs_headers='\n'.join(gz_msgs_headers)))

if __name__ == '__main__':
sys.exit(main())

0 comments on commit 6338d74

Please sign in to comment.