Skip to content

Commit

Permalink
Merge pull request #40 from Chaste/5-default-arg-templates
Browse files Browse the repository at this point in the history
Fix method default arg templates
  • Loading branch information
kwabenantim authored May 3, 2024
2 parents 79fddca + cfe5273 commit dfa8ee7
Show file tree
Hide file tree
Showing 30 changed files with 384 additions and 296 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Add a package description to `examples/shapes/wrapper/package_info.yaml`:
name: pyshapes
modules:
- name: math_funcs
free_functions: cppwg_ALL
free_functions: CPPWG_ALL
```
Generate the wrappers with:
Expand Down
2 changes: 1 addition & 1 deletion cppwg/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def main() -> None:

logging.basicConfig(
format="%(levelname)s %(message)s",
handlers=[logging.FileHandler("cppwg.log"), logging.StreamHandler()],
handlers=[logging.StreamHandler()],
)
logger = logging.getLogger()

Expand Down
16 changes: 10 additions & 6 deletions cppwg/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
r"castxml version \d+\.\d+\.\d+", castxml_version
).group(0)
logger.info(castxml_version)
logger.info(f"pygccxml version {pygccxml.version}")
logger.info(f"pygccxml version {pygccxml.__version__}")

# Sanitize castxml_cflags
self.castxml_cflags: str = ""
Expand Down Expand Up @@ -278,13 +278,17 @@ def add_class_decls(self) -> None:
for class_info in module_info.class_info_collection:
class_info.decls: List["class_t"] = [] # noqa: F821

for full_name in class_info.full_names:
decl_name = full_name.replace(" ", "") # e.g. Foo<2,2>
for class_cpp_name in class_info.cpp_names:
decl_name = class_cpp_name.replace(" ", "") # e.g. Foo<2,2>

try:
class_decl = self.source_ns.class_(decl_name)

except pygccxml.declarations.runtime_errors.declaration_not_found_t:
logging.warning(
f"Could not find declaration for {decl_name}: trying partial match."
)

if "=" in class_info.template_signature:
# Try to find the class without default template args
# e.g. for template <int A, int B=A> class Foo {};
Expand All @@ -300,10 +304,10 @@ def add_class_decls(self) -> None:
decl_name = ",".join(decl_name.split(",")[0:pos]) + " >"
class_decl = self.source_ns.class_(decl_name)

logging.info(f"Found {decl_name}")

else:
logging.error(
f"Could not find class declaration for {decl_name}"
)
logging.error(f"Could not find declaration for {decl_name}")

class_info.decls.append(class_decl)

Expand Down
36 changes: 18 additions & 18 deletions cppwg/input/class_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,39 @@ class CppClassInfo(CppTypeInfo):
Attributes
----------
full_names : List[str]
cpp_names : List[str]
The C++ names of the class e.g. ["Foo<2,2>", "Foo<3,3>"]
short_names : List[str]
py_names : List[str]
The Python names of the class e.g. ["Foo2_2", "Foo3_3"]
"""

def __init__(self, name: str, class_config: Optional[Dict[str, Any]] = None):

super(CppClassInfo, self).__init__(name, class_config)

self.full_names: List[str] = None
self.short_names: List[str] = None
self.cpp_names: List[str] = None
self.py_names: List[str] = None

def update_short_names(self) -> None:
def update_py_names(self) -> None:
"""
Set the Python names for the class, accounting for template args.
Set the name of the class as it will appear on the Python side. This
collapses template arguments, separating them by underscores and removes
special characters. The return type is a list, as a class can have
multiple names if it is templated. For example, a class "Foo" with
template arguments [[2, 2], [3, 3]] will have a short name list
template arguments [[2, 2], [3, 3]] will have a python name list
["Foo2_2", "Foo3_3"].
"""
# Handles untemplated classes
if self.template_arg_lists is None:
if self.name_override:
self.short_names = [self.name_override]
self.py_names = [self.name_override]
else:
self.short_names = [self.name]
self.py_names = [self.name]
return

self.short_names = []
self.py_names = []

# Table of special characters for removal
rm_chars = {"<": None, ">": None, ",": None, " ": None}
Expand Down Expand Up @@ -89,36 +89,36 @@ def update_short_names(self) -> None:
if idx < len(template_arg_list) - 1:
template_string += "_"

self.short_names.append(type_name + template_string)
self.py_names.append(type_name + template_string)

def update_full_names(self) -> None:
def update_cpp_names(self) -> None:
"""
Set the C++ names for the class, accounting for template args.
Set the name of the class as it should appear in C++.
The return type is a list, as a class can have multiple names
if it is templated. For example, a class "Foo" with
template arguments [[2, 2], [3, 3]] will have a full name list
template arguments [[2, 2], [3, 3]] will have a C++ name list
["Foo<2,2 >", "Foo<3,3 >"].
"""
# Handles untemplated classes
if self.template_arg_lists is None:
self.full_names = [self.name]
self.cpp_names = [self.name]
return

self.full_names = []
self.cpp_names = []
for template_arg_list in self.template_arg_lists:
# Create template string from arg list e.g. [2, 2] -> "<2,2 >"
template_string = ",".join([str(arg) for arg in template_arg_list])
template_string = "<" + template_string + " >"

# Join full name e.g. "Foo<2,2 >"
self.full_names.append(self.name + template_string)
self.cpp_names.append(self.name + template_string)

def update_names(self) -> None:
"""Update the full and short names for the class."""
self.update_full_names()
self.update_short_names()
"""Update the C++ and Python names for the class."""
self.update_cpp_names()
self.update_py_names()

@property
def parent(self) -> "ModuleInfo": # noqa: F821
Expand Down
28 changes: 4 additions & 24 deletions cppwg/input/cpp_type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ class CppTypeInfo(BaseInfo):
The name override specified in config e.g. "CustomFoo" -> "Foo"
template_signature : str
The template signature of the type e.g. "<unsigned DIM_A, unsigned DIM_B = DIM_A>"
template_params : List[str]
List of template parameters e.g. ["DIM_A", "DIM_B"]
template_arg_lists : List[List[Any]]
List of template replacement arguments for the type e.g. [[2, 2], [3, 3]]
List of template replacement arguments e.g. [[2, 2], [3, 3]]
decls : pygccxml.declarations.declaration_t
The pygccxml declarations associated with this type, one per template arg if templated
"""
Expand All @@ -36,32 +38,10 @@ def __init__(self, name: str, type_config: Optional[Dict[str, Any]] = None):
self.source_file: Optional[str] = None
self.name_override: Optional[str] = None
self.template_signature: Optional[str] = None
self.template_params: Optional[List[str]] = None
self.template_arg_lists: Optional[List[List[Any]]] = None
self.decls: Optional[List["declaration_t"]] = None # noqa: F821

if type_config:
for key, value in type_config.items():
setattr(self, key, value)

# TODO: This method is not used, remove it?
def needs_header_file_instantiation(self):
"""Check if this class needs to be instantiated in the header file."""
return (
(self.template_arg_lists is not None)
and (not self.include_file_only)
and (self.needs_instantiation)
)

# TODO: This method is not used, remove it?
def needs_header_file_typdef(self):
"""
Check if this type need to be typdef'd with a nicer name.
The typedefs are declared in the header file. All template classes need this.
"""
return (self.template_arg_lists is not None) and (not self.include_file_only)

# TODO: This method is not used, remove it?
def needs_auto_wrapper_generation(self):
"""Check if this class needs a wrapper to be autogenerated."""
return not self.include_file_only
12 changes: 12 additions & 0 deletions cppwg/input/info_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,16 @@ def extract_templates_from_source(self, feature_info: BaseInfo) -> None:
feature_info.template_signature = template_substitution[
"signature"
]

# Extract ["DIM_A", "DIM_B"] from "<unsigned A, unsigned DIM_B=DIM_A>"
template_params = []
for tp in template_substitution["signature"].split(","):
template_params.append(
tp.replace("<", "")
.replace(">", "")
.split(" ")[1]
.split("=")[0]
.strip()
)
feature_info.template_params = template_params
break
4 changes: 1 addition & 3 deletions cppwg/input/module_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import os
from typing import Any, Dict, List, Optional

from pygccxml.declarations import declaration_t

from cppwg.input.base_info import BaseInfo


Expand Down Expand Up @@ -54,7 +52,7 @@ def parent(self) -> "PackageInfo": # noqa: F821
"""Returns the parent package info object."""
return self.package_info

def is_decl_in_source_path(self, decl: declaration_t) -> bool:
def is_decl_in_source_path(self, decl: "declaration_t") -> bool: # noqa: F821
"""
Check if the declaration is associated with a file in the specified source paths.
Expand Down
3 changes: 3 additions & 0 deletions cppwg/input/package_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class PackageInfo(BaseInfo):
A list of source file names to include
common_include_file : bool
Use a common include file for all source files
exclude_default_args : bool
Exclude default arguments from method wrappers.
"""

def __init__(
Expand Down Expand Up @@ -57,6 +59,7 @@ def __init__(
self.source_hpp_patterns: List[str] = ["*.hpp"]
self.source_hpp_files: List[str] = []
self.common_include_file: bool = False
self.exclude_default_args: bool = False

if package_config:
for key, value in package_config.items():
Expand Down
4 changes: 4 additions & 0 deletions cppwg/parsers/package_info_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,18 @@ def parse(self) -> PackageInfo:
package_config: Dict[str, Any] = {
"name": "cppwg_package",
"common_include_file": True,
"exclude_default_args": False,
"source_hpp_patterns": ["*.hpp"],
}
package_config.update(global_config)

for key in package_config.keys():
if key in self.raw_package_info:
package_config[key] = self.raw_package_info[key]

# Replace boolean strings with booleans
utils.substitute_bool_for_string(package_config, "common_include_file")
utils.substitute_bool_for_string(package_config, "exclude_default_args")

# Create the PackageInfo object from the package config dict
self.package_info = PackageInfo(
Expand Down
28 changes: 14 additions & 14 deletions cppwg/templates/pybind11_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
{includes}
#include "{class_short_name}.%s.hpp"
#include "{class_py_name}.%s.hpp"
namespace py = pybind11;
typedef {class_full_name} {class_short_name};
typedef {class_cpp_name} {class_py_name};
{smart_ptr_handle};
""" % CPPWG_EXT

Expand All @@ -16,43 +16,43 @@
#include <pybind11/stl.h>
{includes}
//#include "PythonObjectConverters.hpp"
#include "{class_short_name}.%s.hpp"
#include "{class_py_name}.%s.hpp"
namespace py = pybind11;
//PYBIND11_CVECTOR_TYPECASTER2();
//PYBIND11_CVECTOR_TYPECASTER3();
typedef {class_full_name} {class_short_name};
typedef {class_cpp_name} {class_py_name};
{smart_ptr_handle};
""" % CPPWG_EXT

class_hpp_header = """\
#ifndef {class_short_name}_hpp__%s_wrapper
#define {class_short_name}_hpp__%s_wrapper
#ifndef {class_py_name}_hpp__%s_wrapper
#define {class_py_name}_hpp__%s_wrapper
#include <pybind11/pybind11.h>
void register_{class_short_name}_class(pybind11::module &m);
#endif // {class_short_name}_hpp__%s_wrapper
void register_{class_py_name}_class(pybind11::module &m);
#endif // {class_py_name}_hpp__%s_wrapper
""" % tuple([CPPWG_EXT]*3)

class_virtual_override_header = """\
class {class_short_name}%s : public {class_short_name}{{
class {class_py_name}%s : public {class_py_name}{{
public:
using {class_short_name}::{class_base_name};
using {class_py_name}::{class_base_name};
""" % CPPWG_CLASS_OVERRIDE_SUFFIX

class_virtual_override_footer = "}\n"

class_definition = """\
void register_{short_name}_class(py::module &m){{
py::class_<{short_name} {overrides_string} {ptr_support} {bases} >(m, "{short_name}")
void register_{class_py_name}_class(py::module &m){{
py::class_<{class_py_name} {overrides_string} {ptr_support} {bases} >(m, "{class_py_name}")
"""

method_virtual_override = """\
{return_type} {method_name}({arg_string}){const_adorn} override {{
PYBIND11_OVERRIDE{overload_adorn}(
{tidy_method_name},
{short_class_name},
{class_py_name},
{method_name},
{args_string});
}}
Expand All @@ -67,7 +67,7 @@ class {class_short_name}%s : public {class_short_name}{{
class_method = """\
.def{def_adorn}(
"{method_name}",
({return_type}({self_ptr})({arg_signature}){const_adorn}) &{class_short_name}::{method_name},
({return_type}({self_ptr})({arg_signature}){const_adorn}) &{class_py_name}::{method_name},
{method_docs} {default_args} {call_policy})
"""

Expand Down
12 changes: 0 additions & 12 deletions cppwg/writers/base_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,3 @@ def tidy_name(self, name: str) -> str:
name = name.replace(key, value)

return name

# TODO: This method is currently a placeholder. Consider implementing or removing.
def default_arg_exclusion_criteria(self) -> bool:
"""
Check if default arguments should be excluded from the wrapper code.
Returns
-------
bool
True if the default arguments should be excluded
"""
return False
Loading

0 comments on commit dfa8ee7

Please sign in to comment.