diff --git a/src/common/translators.py b/src/common/translators.py index 3042261e304440961de45a4f06b622c88f18bf81..d8b6dfe2daddc17a1161649b73dcca6ab0c1fb91 100644 --- a/src/common/translators.py +++ b/src/common/translators.py @@ -457,24 +457,26 @@ class DefaultNamespaceTranslator(itranslators.ITranslator): contents.append(" ".join(first_line_contents)) for children in namespace: + children_content = [] if children.is_namespace: - contents.extend(self.namespace.translate(children)) + children_content = self.namespace.translate(children) elif children.is_class: - contents.extend(self.struct.translate(children)) - contents.append("") + children_content = self.struct.translate(children) elif children.is_function: - contents.extend(self.function.translate(children)) - contents.append("") + children_content = self.function.translate(children) elif children.is_data: - contents.extend(self.variable.translate(children)) - contents.append("") + children_content = self.variable.translate(children) elif children.is_enum: - contents.extend(self.enum.translate(children)) - contents.append("") + children_content = self.enum.translate(children) + + for content in children_content: + if type(content) == str: + contents.append(content) + contents.append("") if namespace_name: contents.append("}") @@ -531,25 +533,26 @@ class DefaultFileTranslator(itranslators.ITranslator): contents.extend(self.header.translate(file)) for children in file: + children_content = [] if children.is_namespace: - contents.extend(self.namespace.translate(children)) - contents.append("") + children_content = self.namespace.translate(children) elif children.is_class: - contents.extend(self.struct.translate(children)) - contents.append("") + children_content = self.struct.translate(children) elif children.is_function: - contents.extend(self.function.translate(children)) - contents.append("") + children_content = self.function.translate(children) elif children.is_data: - contents.extend(self.variable.translate(children)) - contents.append("") + children_content = self.variable.translate(children) elif children.is_enum: - contents.extend(self.enum.translate(children)) - contents.append("") + children_content = self.enum.translate(children) + + for content in children_content: + if type(content) == str: + contents.append(content) + contents.append("") contents.extend(self.footer.translate(file)) diff --git a/src/common/utils.py b/src/common/utils.py index 0fc997ab865a0950cead046b31b2d33d53088cbc..42d37ae9378ad3d0683eb8ddda4ff93684d29cb5 100644 --- a/src/common/utils.py +++ b/src/common/utils.py @@ -285,23 +285,21 @@ def generate_macro_from_path(path: str) -> str: return macro -def merge_dicts(dict1: dict, dict2: dict) -> dict: +def merge_dicts(dict1: dict, dict2: dict): """ Merge two dictionaries, combining values from dict2 into dict1. If both dictionaries have a list for a same key, the values from dict2 are merged to the list in dict1. Otherwise, the value from dict2 updates or overwrites the value in dict1. """ - merged = dict1.copy() for key, value in dict2.items(): - if isinstance(merged.get(key), list) and isinstance(value, list): - merged[key].extend(value) - elif merged.get(key): + if isinstance(dict1.get(key), list) and isinstance(value, list): + dict1[key].extend(value) + elif dict1.get(key): warnings.warn(f"Overwriting key {key} in merged dictionary") - merged[key] = value + dict1[key] = value else: - merged[key] = value - return merged + dict1[key] = value def canonical_type_name(data_type: cursors.DataType) -> str: @@ -309,5 +307,27 @@ def canonical_type_name(data_type: cursors.DataType) -> str: for unit in units: if unit._type.kind == constants.UNEXPOSED: # æ— æ³•èŽ·å–æ£ç¡®çš„canonical_name + # TODO(https://git.aqrose.com/aidi/cpp2x/-/issues/366): need canonical return data_type.full_type_name return data_type.canonical_name + + +def get_all_template_instance(data_type: cursors.DataType) -> list[cursors.DataType]: + """ + Recursively retrieves all template instance types from a given DataType, + + Args: + data_type (cursors.DataType): The type to analyze. + + Returns: + list[cursors.DataType]: A list of all template instance types. + """ + result: list[cursors.DataType] = [] + num_template_args = data_type.get_num_template_arguments + if num_template_args > 0: + result.append(data_type) + for i in range(num_template_args): + arg_type = data_type.get_template_argument_type(i) + # Recursively gather template arguments from nested templates + result.extend(get_all_template_instance(arg_type)) + return result diff --git a/src/global_info/collectors.py b/src/global_info/collectors.py index 7253efdf9d048ae5c3220226e8f8ad98def94be2..656ffa229c776d58069eb1598799538938fe5f77 100644 --- a/src/global_info/collectors.py +++ b/src/global_info/collectors.py @@ -3,6 +3,7 @@ from typing import Any from common import utils from cursors import cursors +from extensions import utils as extensions_utils from global_info import icollector @@ -61,3 +62,20 @@ class NonMemberOperatorOverloadCollector(icollector.IGlobalInfoCollector): data = dict() self._search_operator_overload(file, data) return data + + +class TemplateClassDefineCollector(icollector.IGlobalInfoCollector): + + def _search_type_define(self, element: cursors.File | cursors.Namespace, data): + for child in element: + if child.is_class: + if child.is_definition and child.is_template: + data[child.full_namespace_with_name] = child.file_name + + if child.is_namespace: + self._search_type_define(child, data) + + def parse(self, file: cursors.File) -> dict[Hashable, Any]: + data = dict() + self._search_type_define(file, data) + return data diff --git a/src/pybind11/codegen.py b/src/pybind11/codegen.py index 1f6f891e6e2544616c54f1b5ffd8945da8520ab9..9bcd7784a5f540231ef7b9906fda27443e784418 100644 --- a/src/pybind11/codegen.py +++ b/src/pybind11/codegen.py @@ -5,6 +5,7 @@ import sys from common import list as common_list from common import translators as common_translators from cursors import cursors +from pybind11 import utils as pybind11_utils class _CodeGenFileOutputDirectoryTranslator( @@ -14,17 +15,11 @@ class _CodeGenFileOutputDirectoryTranslator( self._output_directory: str = output_directory def translate(self, file: cursors.File) -> list[str]: - # file.name为ç»å¯¹è·¯å¾„ - dirname: str = os.path.dirname(os.path.relpath(file.name)) - types_dirname: str = dirname - - # FIXME(cheny): Remove this. - if dirname.find("bazel-out") != -1: - start_index: int = dirname.find("bazel-out") - end_index: int = dirname.rfind("bin") - types_dirname = dirname[:start_index] + dirname[end_index + 4 :] - - return [os.path.join(self._output_directory, types_dirname)] + return [ + os.path.join( + self._output_directory, pybind11_utils.gen_output_dir_path(file.name) + ) + ] def init_translators(translator_module: str, output_directory: str): diff --git a/src/pybind11/default.py b/src/pybind11/default.py index 5202c73d277c7982d55918c2f96e19836e206a2c..6a3e30701ce83a1478495954dcc1a3cc4de362a3 100644 --- a/src/pybind11/default.py +++ b/src/pybind11/default.py @@ -18,6 +18,7 @@ from pybind11.translators import ( Pybind11DefaultCopyAssignOperatorTranslator, Pybind11DefaultEnumConstantTranslator, Pybind11DefaultEnumTypeTranslator, + Pybind11DefaultFileDirectoryTranslator, Pybind11DefaultFileFooterTranslator, Pybind11DefaultFileHeaderTranslator, Pybind11DefaultFileNameTranslator, @@ -170,12 +171,14 @@ template_class_translator = List( ) -template_define_translator = List( +template_define_translators = List( translator=Pybind11TemplateDefineTranslator( template_class=template_class_translator, ) ) +file_directory_translators = List(translator=Pybind11DefaultFileDirectoryTranslator()) + default_file_name_translators = List(translator=Pybind11DefaultFileNameTranslator()) default_file_footer_translators = List(translator=Pybind11DefaultFileFooterTranslator()) @@ -188,5 +191,6 @@ default_file_translator = Pybind11DefaultFileTranslator( struct=default_class_translators, function=default_function_translators, enum=default_enum_type_translators, - template_define_translator=template_define_translator, + template_define=template_define_translators, + file_directory=file_directory_translators, ) diff --git a/src/pybind11/translators.py b/src/pybind11/translators.py index f9403fa33f54fd36d53e2529081632fb1a539564..6b22a15cfdb8a24eac73e7c3ae2f5bd8f42aea4d 100644 --- a/src/pybind11/translators.py +++ b/src/pybind11/translators.py @@ -59,9 +59,13 @@ class Pybind11DefaultMemberFunctionTranslator( first_line_content: str = f'.def("{pybind11_member_function_name}",' contents: list[str] = [] + storage_info = {} contents.append(first_line_content) return_content: str = utils.canonical_type_name(member_function.return_type) + storage_info.setdefault("template_instance", []).extend( + utils.get_all_template_instance(member_function.return_type) + ) return_content = ( return_content @@ -73,6 +77,9 @@ class Pybind11DefaultMemberFunctionTranslator( parameter_type_contents: list[str] = [] for parameter in member_function.parameters: parameter_type_contents.append(utils.canonical_type_name(parameter.type)) + storage_info.setdefault("template_instance", []).extend( + utils.get_all_template_instance(parameter.type) + ) contents.append( constants.TAB @@ -98,7 +105,7 @@ class Pybind11DefaultMemberFunctionTranslator( constants.TAB + "pybind11::call_guard<pybind11::gil_scoped_release>())" ) - return ["\n".join(contents)] + return ["\n".join(contents), storage_info] class Pybind11DefaultStaticMemberFunctionTranslator( @@ -125,9 +132,13 @@ class Pybind11DefaultStaticMemberFunctionTranslator( first_line_content: str = f'.def_static("{pybind11_member_function_name}",' contents: list[str] = [] + storage_info = {} contents.append(first_line_content) return_content: str = utils.canonical_type_name(member_function.return_type) + storage_info.setdefault("template_instance", []).extend( + utils.get_all_template_instance(member_function.return_type) + ) return_content = ( return_content @@ -139,6 +150,9 @@ class Pybind11DefaultStaticMemberFunctionTranslator( parameter_type_contents: list[str] = [] for parameter in member_function.parameters: parameter_type_contents.append(utils.canonical_type_name(parameter.type)) + storage_info.setdefault("template_instance", []).extend( + utils.get_all_template_instance(parameter.type) + ) contents.append( constants.TAB @@ -159,7 +173,7 @@ class Pybind11DefaultStaticMemberFunctionTranslator( constants.TAB + "pybind11::call_guard<pybind11::gil_scoped_release>())" ) - return ["\n".join(contents)] + return ["\n".join(contents), storage_info] class Pybind11DefaultConstructorTranslator( @@ -183,8 +197,12 @@ class Pybind11DefaultConstructorTranslator( def translate(self, member_function: cursors.MemberFunction) -> list[str]: parameter_type_contents: list[str] = [] + storage_info = {} for parameter in member_function.parameters: parameter_type_contents.append(utils.canonical_type_name(parameter.type)) + storage_info.setdefault("template_instance", []).extend( + utils.get_all_template_instance(parameter.type) + ) parameter_value_contents: list[str | None] = [] for parameter in member_function.parameters: @@ -197,7 +215,10 @@ class Pybind11DefaultConstructorTranslator( parameter_value_contents.remove(None) return [ - f'.def(pybind11::init<{", ".join(parameter_type_contents)}>(){", ".join(parameter_value_contents)})' + ( + f'.def(pybind11::init<{", ".join(parameter_type_contents)}>(){", ".join(parameter_value_contents)})' + ), + storage_info, ] @@ -314,7 +335,10 @@ class Pybind11DefaultMemberVariableTranslator( def translate(self, member_variable: cursors.MemberVariable) -> list[str]: pybind11_member_variable_name: str = self.name.translate(member_variable)[0] - + storage_info = {} + storage_info.setdefault("template_instance", []).extend( + utils.get_all_template_instance(member_variable.type) + ) if member_variable.is_const: return [ f'.def_readonly("{pybind11_member_variable_name}",' @@ -322,8 +346,11 @@ class Pybind11DefaultMemberVariableTranslator( ] return [ - f'.def_readwrite("{pybind11_member_variable_name}",' - f" &{member_variable.belong_to_class.canonical_name}::{member_variable.name})" + ( + f'.def_readwrite("{pybind11_member_variable_name}",' + f" &{member_variable.belong_to_class.canonical_name}::{member_variable.name})" + ), + storage_info, ] @@ -341,7 +368,10 @@ class Pybind11DefaultStaticMemberVariableTranslator(itranslators.ITranslator): pybind11_member_variable_name: str = self.name.translate( static_member_variable )[0] - + storage_info = {} + storage_info.setdefault("template_instance", []).extend( + utils.get_all_template_instance(static_member_variable.type) + ) if static_member_variable.is_const: return [ f'.def_readonly_static("{pybind11_member_variable_name}",' @@ -349,8 +379,11 @@ class Pybind11DefaultStaticMemberVariableTranslator(itranslators.ITranslator): ] return [ - f'.def_readwrite_static("{pybind11_member_variable_name}",' - f" &{static_member_variable.full_namespace_with_name})" + ( + f'.def_readwrite_static("{pybind11_member_variable_name}",' + f" &{static_member_variable.full_namespace_with_name})" + ), + storage_info, ] @@ -386,6 +419,21 @@ class Pybind11DefaultClassTranslator(common_translators.DefaultClassTranslator): ) return contents + def _tanslate_and_merge_data( + self, + translator: common_list.List, + items: list[cursors._BaseCursor], + contents: list[str], + storage_info: dict, + ): + for item in items: + item_content = translator.translate(item) + for content in item_content: + if type(content) == str: + contents.append(content) + elif type(content) == dict: + utils.merge_dicts(storage_info, content) + @classmethod def collector_type_list(cls): return [collectors.NonMemberOperatorOverloadCollector] @@ -395,11 +443,7 @@ class Pybind11DefaultClassTranslator(common_translators.DefaultClassTranslator): def translate(self, struct: cursors.Class) -> list[str]: contents: list[str] = [] - data = {} - if struct.is_template: - # æ¤å¤„ä¸åšç¿»è¯‘,并å‘ä¸Šä¼ é€’æ¨¡ç‰ˆç±»çš„å®šä¹‰ - data["template_define"] = [struct] - return ["", data] + storage_info = {} for nested_enum in struct.all_enum_types: contents.extend(self.enum.translate(nested_enum)) @@ -439,24 +483,40 @@ class Pybind11DefaultClassTranslator(common_translators.DefaultClassTranslator): contents.append(first_line_content) - for constructor in struct.constructors: - contents.extend(self.constructor.translate(constructor)) + self._tanslate_and_merge_data( + self.constructor, struct.constructors, contents, storage_info + ) - for member_function in struct.all_public_functions: - contents.extend(self.function.translate(member_function)) + self._tanslate_and_merge_data( + self.function, struct.all_public_functions, contents, storage_info + ) contents.extend(self._non_member_operator_overload_content(struct)) - for member_variable in struct.all_member_variables: - contents.extend(self.variable.translate(member_variable)) + self._tanslate_and_merge_data( + self.variable, struct.all_member_variables, contents, storage_info + ) - for static_member_variable in struct.all_static_member_variables: - contents.extend(self.static_variable.translate(static_member_variable)) + self._tanslate_and_merge_data( + self.static_variable, + struct.all_static_member_variables, + contents, + storage_info, + ) while None in contents: contents.remove(None) - return ["\n".join(contents) + ";"] + # 获å–基类å¯èƒ½å˜åœ¨çš„模版实例化类型 + for base_class in struct.all_base_classes: + template_instances = utils.get_all_template_instance(base_class.type) + storage_info.setdefault("template_instance", []).extend(template_instances) + + if struct.is_template: + # æ¤å¤„ä¸åšç¿»è¯‘,并å‘ä¸Šä¼ é€’æ¨¡ç‰ˆç±»çš„å®šä¹‰ + storage_info.setdefault("template_define", []).append(struct) + return ["", storage_info] + return ["\n".join(contents) + ";", storage_info] class Pybind11TemplateDefineTranslator(itranslators.ITranslator): @@ -518,17 +578,18 @@ pybind11::class_<{struct.canonical_name}, std::shared_ptr<{struct.canonical_name contents.append(class_define_content) # 接å£ç»‘定 - for constructor in struct.constructors: - contents.extend(self.constructor.translate(constructor)) - - for member_function in struct.all_public_functions: - contents.extend(self.function.translate(member_function)) - - for member_variable in struct.all_member_variables: - contents.extend(self.variable.translate(member_variable)) - - for static_member_variable in struct.all_static_member_variables: - contents.extend(self.static_variable.translate(static_member_variable)) + self._tanslate_and_merge_data( + self.constructor, struct.constructors, contents, {} + ) + self._tanslate_and_merge_data( + self.function, struct.all_public_functions, contents, {} + ) + self._tanslate_and_merge_data( + self.variable, struct.all_member_variables, contents, {} + ) + self._tanslate_and_merge_data( + self.static_variable, struct.all_static_member_variables, contents, {} + ) while None in contents: contents.remove(None) @@ -602,8 +663,7 @@ class Pybind11DefaultEnumTypeTranslator(common_translators.DefaultEnumTypeTransl class Pybind11DefaultFileNameTranslator(common_translators.DefaultFileNameTranslator): def translate(self, file: cursors.File) -> list[str]: - basename: str = super().translate(file)[0] - return [f"pybind11_{basename.split('.')[0]}.cpp"] + return [pybind11_utils.gen_output_file_name(file.base_name)] class Pybind11DefaultFileHeaderTranslator( @@ -708,9 +768,13 @@ class Pybind11DefaultFunctionTranslator(common_translators.DefaultFunctionTransl ) contents: list[str] = [] + storage_info = {} contents.append(first_line_content) return_content: str = utils.canonical_type_name(function.return_type) + storage_info.setdefault("template_instance", []).extend( + utils.get_all_template_instance(function.return_type) + ) return_content = ( return_content @@ -721,6 +785,9 @@ class Pybind11DefaultFunctionTranslator(common_translators.DefaultFunctionTransl parameter_type_contents: list[str] = [] for parameter in function.parameters: parameter_type_contents.append(utils.canonical_type_name(parameter.type)) + storage_info.setdefault("template_instance", []).extend( + utils.get_all_template_instance(parameter.type) + ) contents.append( constants.TAB @@ -767,7 +834,7 @@ class Pybind11DefaultNamespaceTranslator(common_translators.DefaultNamespaceTran def translate(self, namespace: cursors.Namespace) -> list[str]: contents: list[str] = [] - data = {} + storage_info = {} if namespace.parent and namespace.parent.is_namespace: contents.append( @@ -792,15 +859,17 @@ class Pybind11DefaultNamespaceTranslator(common_translators.DefaultNamespaceTran elif children.is_enum: children_content = self.enum.translate(children) - if len(children_content): - contents.append(children_content[0]) - contents.append("") - if len(children_content) > 1 and isinstance(children_content[1], dict): - data = utils.merge_dicts(children_content[1], data) + for content in children_content: + if type(content) == str: + contents.append(content) + elif type(content) == dict: + utils.merge_dicts(storage_info, content) + contents.append("") + while None in contents: contents.remove(None) - return ["\n".join(contents), data] + return ["\n".join(contents), storage_info] def update(self, translators: common_list.List = common_list.List()) -> None: name_list = [] @@ -966,7 +1035,8 @@ class Pybind11DefaultFileTranslator(common_translators.DefaultFileTranslator): function: common_list.List = common_list.List(), variable: common_list.List = common_list.List(), enum: common_list.List = common_list.List(), - template_define_translator: common_list.List = common_list.List(), + template_define: common_list.List = common_list.List(), + file_directory: common_list.List = common_list.List(), ) -> None: self.name: common_list.List = name self.footer: common_list.List = footer @@ -977,22 +1047,47 @@ class Pybind11DefaultFileTranslator(common_translators.DefaultFileTranslator): self.function: common_list.List = function self.variable: common_list.List = variable self.enum: common_list.List = enum - self.template_define_translator: common_list.List = template_define_translator + self.template_define: common_list.List = template_define + self.file_directory: common_list.List = file_directory + + @classmethod + def collector_type_list(cls): + return [collectors.TemplateClassDefineCollector] - def _package_include_headers(self, template_instances: list[str]) -> list[str]: + def _package_include_headers( + self, template_instances: list[cursors.DataType], file: cursors.File + ) -> list[str]: """ 获å–"模版函数引用头文件",并使用(PYBIND11_TEMPLATE_INCLUDE_BY_OTHER)å®åŒ…装 Arguments: - template_instances {list[str]} -- 模版实例化类型 + template_instances {list[cursors.DataType]} -- 模版实例化类型 + file {cursors.File} -- 当å‰æ–‡ä»¶ Returns: list[str] """ - # TODO: 通过template_instances查表获å–"模版函数引用头文件" - include_template_headers_contents: list[str] = [ - "// this is include template headers contents" - ] + gen_path = lambda path: os.path.join( + self.file_directory.translate(file)[0], # type: ignore + pybind11_utils.gen_output_dir_path(path), + pybind11_utils.gen_output_file_name(os.path.basename(path)), + ).replace("\\", "/") + current_output_file_name = gen_path(file.name) + + # 获å–模版函数头文件 + include_template_headers_contents: list[str] = [] + for template_type in template_instances: + query_template = utils.canonical_type_name(template_type).split("<")[0] + result: str = process_global_info.info.query_data( + collectors.TemplateClassDefineCollector, query_template + ) + if result: + file_path: str = gen_path(result) + # éžå½“å‰æ–‡ä»¶ + if file_path != current_output_file_name: + include_template_headers_contents.append(f'#include "{file_path}"') + # åŽ»é‡ + include_template_headers_contents = list(set(include_template_headers_contents)) package_macro = [ "#ifdef PYBIND11_TEMPLATE_INCLUDE_BY_OTHER", @@ -1001,7 +1096,6 @@ class Pybind11DefaultFileTranslator(common_translators.DefaultFileTranslator): "#undef PYBIND11_TEMPLATE_INCLUDE_BY_OTHER", "#endif", ] - return ( package_macro[0:1] + include_template_headers_contents @@ -1012,7 +1106,7 @@ class Pybind11DefaultFileTranslator(common_translators.DefaultFileTranslator): def _origin_non_template_bind_contents( self, file: cursors.File - ) -> tuple[list[str], list[cursors.Class], list[str]]: + ) -> tuple[list[str], list[cursors.Class], list[cursors.DataType]]: """ 获å–原始"éžæ¨¡ç‰ˆç»‘定内容",åŒæ—¶æ”¶é›†"模版类的定义"å’Œ"模版实例化类型" @@ -1020,11 +1114,11 @@ class Pybind11DefaultFileTranslator(common_translators.DefaultFileTranslator): file {cursors.File} Returns: - tuple[list[str], list[cursors.Class], list[str]] -- éžæ¨¡ç‰ˆç»‘定内容, 模版类的定义, 模版实例化类型 + tuple[list[str], list[cursors.Class], list[cursors.DataType]] -- éžæ¨¡ç‰ˆç»‘定内容, 模版类的定义, 模版实例化类型 """ # TODO: 收集模版类的定义, 模版实例化类型 contents: list[str | None] = [] - data = { + storage_info = { "template_define": [], "template_instance": [], } @@ -1045,15 +1139,34 @@ class Pybind11DefaultFileTranslator(common_translators.DefaultFileTranslator): elif children.is_enum: children_content = self.enum.translate(children) - if len(children_content): - contents.append(children_content[0]) - contents.append("") - if len(children_content) > 1 and isinstance(children_content[1], dict): - data = utils.merge_dicts(children_content[1], data) + for content in children_content: + if type(content) == str: + contents.append(content) + elif type(content) == dict: + utils.merge_dicts(storage_info, content) + contents.append("") while None in contents: contents.remove(None) - return (contents, data["template_define"], data["template_instance"]) + + # åŽ»é‡ + template_define = list( + { + template_class.canonical_name: template_class + for template_class in storage_info["template_define"] + }.values() + ) + template_instances = list( + { + utils.canonical_type_name(template_type): template_type + for template_type in storage_info["template_instance"] + }.values() + ) + return ( + contents, + template_define, + template_instances, + ) def translate(self, file: cursors.File) -> list[str]: @@ -1063,24 +1176,24 @@ class Pybind11DefaultFileTranslator(common_translators.DefaultFileTranslator): func_name_content = origin_headers_and_func_name_contents[1] # [2] 获å–"原始的éžæ¨¡ç‰ˆå†…容绑定接å£å‡½æ•°å®žçŽ°", "模版类的定义", "模版实例化类型" - origin_non_template_bind_contents, template_cursors, template_instances = ( + origin_non_template_bind_contents, template_define, template_instances = ( self._origin_non_template_bind_contents(file) ) # [3] 获å–"模版函数引用头文件",并使用(PYBIND11_TEMPLATE_INCLUDE_BY_OTHER)å®åŒ…装 macro_include_template_headers_contents: list[str] = ( - self._package_include_headers(template_instances) + self._package_include_headers(template_instances, file) ) # [4] ä¾æ®template_cursors,翻译模版内容绑定函数 - template_bind_contents: list[str] = ["// this is template bind contents"] - for tempalte_class in template_cursors: + template_bind_contents: list[str] = [] + for tempalte_class in template_define: template_bind_contents.append( - self.template_define_translator.translate(object=tempalte_class)[0] + self.template_define.translate(object=tempalte_class)[0] ) # [5] TODO: ä¾æ®template_instances,翻译注册函数调用逻辑 - template_call_contents: list[str] = ["// this is template call contents"] + template_call_contents: list[str] = [] # [6] "}" footer = self.footer.translate(file) @@ -1125,3 +1238,11 @@ class Pybind11DefaultFileTranslator(common_translators.DefaultFileTranslator): self._write_file(file, content) return [content] + + +class Pybind11DefaultFileDirectoryTranslator(itranslators.ITranslator): + def condition(self, file: cursors.File) -> bool: + return file.is_file + + def translate(self, file: cursors.File) -> list[str]: + return ["pybind11"] diff --git a/src/pybind11/utils.py b/src/pybind11/utils.py index 04474c53d10e22162fd4eaf47976534038e83fde..504d9d965f648e018e670b8144cca576735b3475 100644 --- a/src/pybind11/utils.py +++ b/src/pybind11/utils.py @@ -1,3 +1,5 @@ +import os + from clang import cindex from extensions import utils as extensions_utils @@ -91,3 +93,21 @@ def get_template_func_name(cls: cindex.Cursor) -> str: return "pybind11_" + cls.spelling else: return "pybind11_" + namespace.replace("::", "_") + "_" + cls.spelling + + +def gen_output_dir_path(input_file_name: str) -> str: + # file.name为ç»å¯¹è·¯å¾„ + dirname: str = os.path.dirname(os.path.relpath(input_file_name)) + types_dirname: str = dirname + + # FIXME(cheny): Remove this. + if dirname.find("bazel-out") != -1: + start_index: int = dirname.find("bazel-out") + end_index: int = dirname.rfind("bin") + types_dirname = dirname[:start_index] + dirname[end_index + 4 :] + return types_dirname + + +def gen_output_file_name(input_file_name: str) -> str: + output_file_name = f"pybind11_{input_file_name.split('.')[0]}.cpp" + return output_file_name diff --git a/test/codegen/pybind11/pybind11_codegen_translator.py b/test/codegen/pybind11/pybind11_codegen_translator.py index 8678b8d2ea18c7214392b3a660de04e28ce88951..f971b37bb471d73241fa484a4fdb0dc35c25052f 100644 --- a/test/codegen/pybind11/pybind11_codegen_translator.py +++ b/test/codegen/pybind11/pybind11_codegen_translator.py @@ -3,7 +3,10 @@ import os from cursors.cursors import File from pybind11.default import default_file_translator -from pybind11.translators import Pybind11DefaultFileHeaderTranslator +from pybind11.translators import ( + Pybind11DefaultFileDirectoryTranslator, + Pybind11DefaultFileHeaderTranslator, +) class Pybind11TestFileHeaderTranslator(Pybind11DefaultFileHeaderTranslator): @@ -41,5 +44,11 @@ class Pybind11TestFileHeaderTranslator(Pybind11DefaultFileHeaderTranslator): return ["\n".join(contents), func_name] +class Pybind11TestFileDirectoryTranslator(Pybind11DefaultFileDirectoryTranslator): + def translate(self, file: File) -> list[str]: + return ["test/codegen/pybind11"] + + cpp_translator = copy.deepcopy(default_file_translator) cpp_translator.header.append_front(Pybind11TestFileHeaderTranslator()) +cpp_translator.file_directory.append_front(Pybind11TestFileDirectoryTranslator()) diff --git a/test/data/test_codegen_pybind11_template_class.hpp b/test/data/test_codegen_pybind11_template_class.hpp index 3d4e82d2878c5e8c77b25e563a673103bd1ada8b..4643ffcad104c4366b48052124cd465cec575de6 100644 --- a/test/data/test_codegen_pybind11_template_class.hpp +++ b/test/data/test_codegen_pybind11_template_class.hpp @@ -1,6 +1,9 @@ #pragma once #include <iostream> + +#include "test/data/test_codegen_pybind11_template_class1.hpp" + namespace test_codegen { namespace test_template { template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>> @@ -13,5 +16,7 @@ public: void setValue(T value) { value_ = value; } static T value(T v) { return v; } }; + +// class DerivedClass2 : public test_codegen::TemplateClass1<int> {}; } // namespace test_template } // namespace test_codegen diff --git a/test/data/test_codegen_pybind11_template_class1.hpp b/test/data/test_codegen_pybind11_template_class1.hpp index cef54c31a7b8559f79c00887dccef7dcad1fcfcd..6661e942e8258afaf6cbc6eb0d8ed7437390571e 100644 --- a/test/data/test_codegen_pybind11_template_class1.hpp +++ b/test/data/test_codegen_pybind11_template_class1.hpp @@ -2,11 +2,24 @@ #include <vector> namespace test_codegen { template <typename T> -class template_class1 { +class TemplateClass1 { public: - template_class1(std::vector<T>, T){}; + TemplateClass1() = default; + TemplateClass1(std::vector<T>, T){}; std::vector<T> temp_func(std::vector<T> v, T) { return v; } std::vector<T> temp_variable; }; +template <typename T> +class DerivedClass : public TemplateClass1<int>, + public TemplateClass1<TemplateClass1<int>>, + public TemplateClass1<T> { +public: + DerivedClass() = default; + TemplateClass1<double> test_variable; + TemplateClass1<float> test_function(TemplateClass1<char>) { + return TemplateClass1<float>(); + } +}; + } // namespace test_codegen diff --git a/test/pytest/cpp/translators/pytest_cpp_translators_namespace.py b/test/pytest/cpp/translators/pytest_cpp_translators_namespace.py index 1db0ef4b3bee1ec13cdb2cea1bdc8172fb811bdf..a9e17b6f6f384b60a929d3afa2eb11f6f1e2af28 100644 --- a/test/pytest/cpp/translators/pytest_cpp_translators_namespace.py +++ b/test/pytest/cpp/translators/pytest_cpp_translators_namespace.py @@ -272,5 +272,6 @@ inline void set_a(int value) { } } + }""" ) diff --git a/test/pytest/pybind11/translators/pytest_pybind11_translators_class.py b/test/pytest/pybind11/translators/pytest_pybind11_translators_class.py index a58a708a763ccdf4b78c630f05337a3b5933115d..a6e4c050c423f9f9dd87fa1bad767480c3796aeb 100644 --- a/test/pytest/pybind11/translators/pytest_pybind11_translators_class.py +++ b/test/pytest/pybind11/translators/pytest_pybind11_translators_class.py @@ -157,7 +157,7 @@ public: contents: list[str] = [] for struct in file.all_classes: if translator.condition(struct): - contents.extend(translator.translate(struct)) + contents.append(translator.translate(struct)[0]) assert len(contents) == 3 assert ( diff --git a/test/pytest/pybind11/translators/pytest_pybind11_translators_member_variable.py b/test/pytest/pybind11/translators/pytest_pybind11_translators_member_variable.py index 8a389f9dc80f6676a6f5c8c65fc5b85165b3bdf8..5af105477536e42c460d0303abb23f508e066dbb 100644 --- a/test/pytest/pybind11/translators/pytest_pybind11_translators_member_variable.py +++ b/test/pytest/pybind11/translators/pytest_pybind11_translators_member_variable.py @@ -45,7 +45,7 @@ private: if translator.condition(struct): contents.extend(translator.translate(struct)) - assert len(contents) == 1 + assert len(contents) >= 1 assert ( contents[0] == """pybind11::class_<TestClass, std::shared_ptr<TestClass>>(module, "TestClass") diff --git a/test/pytest/pybind11/translators/pytest_pybind11_translators_parameter.py b/test/pytest/pybind11/translators/pytest_pybind11_translators_parameter.py index 6f784ab0048fb6c6925ed4866e75d7d83894d06d..215b537fbba1a4c1e5ba17647bb3a3971bd1dfb6 100644 --- a/test/pytest/pybind11/translators/pytest_pybind11_translators_parameter.py +++ b/test/pytest/pybind11/translators/pytest_pybind11_translators_parameter.py @@ -63,7 +63,7 @@ public: if translator.condition(struct): contents.extend(translator.translate(struct)) - assert len(contents) == 1 + assert len(contents) >= 1 assert ( contents[0] == """pybind11::class_<TestClass, std::shared_ptr<TestClass>>(module, "TestClass") diff --git a/test/pytest/pybind11/translators/pytest_pybind11_translators_static_member_variable.py b/test/pytest/pybind11/translators/pytest_pybind11_translators_static_member_variable.py index b49c9176983ea496bd33dafad520f82562a1f9ca..96f8b3fb4719cf59ac226b9620cc642104bc1ddd 100644 --- a/test/pytest/pybind11/translators/pytest_pybind11_translators_static_member_variable.py +++ b/test/pytest/pybind11/translators/pytest_pybind11_translators_static_member_variable.py @@ -39,7 +39,7 @@ public: if translator.condition(struct): contents.extend(translator.translate(struct)) - assert len(contents) == 1 + assert len(contents) >= 1 assert ( contents[0] == """pybind11::class_<TestClass, std::shared_ptr<TestClass>>(module, "TestClass") diff --git a/test/pytest/translators/pytest_translators_namespace.py b/test/pytest/translators/pytest_translators_namespace.py index a21ef8b0e61ab52d903c1931d2e6d3fad5ff3587..5dd228448a986e83c3894af68f2ebc08f1a7403d 100644 --- a/test/pytest/translators/pytest_translators_namespace.py +++ b/test/pytest/translators/pytest_translators_namespace.py @@ -123,6 +123,7 @@ int test_variable; namespace Namespace2 { } + }""" )