-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy path__init__.py
116 lines (97 loc) · 4.98 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import re
import shutil
import subprocess
import sys

from .utils import *
from .sync_talk_nodes import *
class InstallStatus:
    """Persisted progress flags for the one-time SyncTalk installation steps.

    The module-level installer loads an instance from ``install_status.json``
    with ``utils.read_json_as_class`` and saves it back with
    ``utils.export_class_to_json`` after each completed step.
    NOTE(review): presumably the JSON helpers use ``to_dict`` / the constructor
    keyword names — confirm in utils.
    """

    def __init__(self, git_clone_sync_talk: bool, install_dependencies: bool):
        # True once the forked SyncTalk repository has been cloned under repos/.
        self.git_clone_sync_talk = git_clone_sync_talk
        # True once every required package was found installed after setup.
        self.install_dependencies = install_dependencies

    def to_dict(self) -> dict:
        """Return both flags as a plain dict keyed by attribute name."""
        return dict(
            git_clone_sync_talk=self.git_clone_sync_talk,
            install_dependencies=self.install_dependencies,
        )
# Prepare directories.
# Every path is resolved relative to this package's own directory, so the
# installer does not depend on the process's current working directory.
cur_dir = os.path.dirname(__file__)
install_status_path = os.path.join(cur_dir, "install_status.json")
repos_dir = os.path.join(cur_dir, "repos")
# exist_ok=True replaces a separate isdir() check, which could race with
# another process creating the directory between the check and makedirs().
os.makedirs(repos_dir, exist_ok=True)
sync_talk_dir = os.path.join(repos_dir, "SyncTalk")
pytorch3d_dir = os.path.join(repos_dir, "pytorch3d")
# Load which one-time install steps have already completed.
install_status = utils.read_json_as_class(install_status_path, InstallStatus)
# Clone forked SyncTalk repo.
# Runs only until the clone is recorded in install_status.json; later imports
# skip it entirely.
# TODO: updating an existing clone (utils.pull_repository(sync_talk_dir)) is
# currently disabled.
if not install_status.git_clone_sync_talk:
    utils.clone_repository("https://github.com/Ryuukeisyou/SyncTalk.git", sync_talk_dir)
    if os.name == "nt":
        # Restore every tracked file from HEAD across the whole working tree
        # (":/" is git's repository-root pathspec).
        # NOTE(review): presumably works around files a Windows checkout left
        # missing or altered — confirm.
        subprocess.check_call(["git", "restore", "--source=HEAD", ":/"], cwd=sync_talk_dir)
    install_status.git_clone_sync_talk = True
    utils.export_class_to_json(install_status, install_status_path)
# Install dependencies.
_REQUIREMENT_NAME_RE = re.compile(r"[A-Za-z0-9][A-Za-z0-9._-]*")


def _normalize_key(name):
    """Normalize a project name the way ``pkg_resources`` derives ``Distribution.key``.

    Runs of characters other than ASCII letters, digits and ``.`` collapse to a
    single ``-`` and the result is lower-cased
    (e.g. ``face_alignment`` -> ``face-alignment``).
    """
    return re.sub(r"[^A-Za-z0-9.]+", "-", name).lower()


def _installed_package_keys():
    """Return the keys of all distributions currently installed on ``sys.path``.

    A fresh ``pkg_resources.WorkingSet()`` is built on every call. The global
    ``pkg_resources.working_set`` is a snapshot taken when ``pkg_resources`` was
    first imported. It would not include the packages that the pip
    subprocesses below install into this running interpreter.
    """
    # NOTE(review): ``pkg_resources`` is not imported in this file; it is
    # presumably re-exported by ``from .utils import *`` — confirm.
    return {dist.key for dist in pkg_resources.WorkingSet()}


def _requirement_keys(requirements_path):
    """Return the normalized project names listed in a pip requirements file.

    Blank lines, comments, pip options (``-r``, ``--index-url``, ...) and bare
    URL/VCS lines are skipped. Version specifiers, extras and environment
    markers are stripped, so only the project name is compared with installed
    package keys. Markers are not evaluated: a requirement excluded on this
    platform is still counted as required.
    """
    keys = set()
    with open(requirements_path, encoding="utf-8") as f:
        for raw_line in f:
            # pip treats '#' at line start or after whitespace as a comment.
            line = re.sub(r"(^|\s)#.*$", "", raw_line.strip()).strip()
            if not line or line.startswith("-"):
                continue  # blank line or pip option
            if "://" in line.split("@", 1)[0]:
                continue  # bare URL / VCS requirement: no reliable project name
            match = _REQUIREMENT_NAME_RE.match(line)
            if match is not None:
                keys.add(_normalize_key(match.group(0)))
    return keys


def _build_and_install_wheel(wheel_name, source_dir, clone_url=None):
    """pip-install *wheel_name* from ``source_dir/dist``, building the wheel if needed.

    If ``utils.get_valid_wheel`` finds no usable wheel, the source is first
    cloned from *clone_url* (when given) and then built with
    ``setup.py bdist_wheel`` using the current interpreter.
    """
    dist_dir = os.path.join(source_dir, "dist")
    wheel = utils.get_valid_wheel(wheel_name, dist_dir)
    if wheel is None:
        if clone_url is not None:
            utils.clone_repository(clone_url, source_dir)
        setup_path = os.path.join(source_dir, "setup.py")
        subprocess.check_call([sys.executable, setup_path, "bdist_wheel"], cwd=source_dir)
        wheel = utils.get_valid_wheel(wheel_name, dist_dir)
    if wheel is None:
        # Best effort: the final dependency check below reports it as missing.
        print(f"[SyncTalk] No valid {wheel_name} wheel in {dist_dir} after building; skipping install.")
        return
    subprocess.check_call([sys.executable, "-m", "pip", "install", wheel])


if install_status.git_clone_sync_talk and not install_status.install_dependencies:
    # Install portaudio19-dev via apt-get. On POSIX systems without apt-get
    # (e.g. macOS) this step is skipped instead of failing the whole import
    # with FileNotFoundError.
    if os.name == "posix" and shutil.which("apt-get") is not None:
        if not utils.is_package_installed("portaudio19-dev"):
            subprocess.check_call(["sudo", "apt-get", "install", "portaudio19-dev"])

    # Install requirements.txt of SyncTalk and of this package.
    sync_talk_requirements_path = os.path.join(sync_talk_dir, "requirements.txt")
    other_requirements_path = os.path.join(cur_dir, "requirements.txt")
    utils.install_requirements(sync_talk_requirements_path)
    utils.install_requirements(other_requirements_path)

    installed_packages = _installed_package_keys()

    # Install pytorch3d from a locally built wheel.
    if "pytorch3d" not in installed_packages:
        _build_and_install_wheel(
            "pytorch3d",
            pytorch3d_dir,
            clone_url="https://github.com/facebookresearch/pytorch3d.git",
        )

    # Install the extension packages that live in subdirectories of the
    # SyncTalk repo. "raymarching-face" is built from the "raymarching"
    # directory and its wheel is looked up as "raymarching_face".
    ext_list = ["shencoder", "freqencoder", "gridencoder", "raymarching-face"]
    for ext in ext_list:
        if ext in installed_packages:
            continue
        dirname = "raymarching" if ext == "raymarching-face" else ext
        ext_name = "raymarching_face" if ext == "raymarching-face" else ext
        _build_and_install_wheel(ext_name, os.path.join(sync_talk_dir, dirname))

    # Check installations. Success is recorded only when nothing is missing,
    # so an incomplete install is retried on the next import.
    required_packages = {_normalize_key(name) for name in ["pytorch3d", *ext_list]}
    required_packages |= _requirement_keys(sync_talk_requirements_path)
    required_packages |= _requirement_keys(other_requirements_path)
    missing_packages = required_packages - _installed_package_keys()
    if not missing_packages:
        install_status.install_dependencies = True
        utils.export_class_to_json(install_status, install_status_path)
    else:
        print(f"[SyncTalk] Dependencies still missing, will retry on next start: {sorted(missing_packages)}")
# Node registry: each node class is exposed under "<ClassName>(SyncTalk)".
# NOTE(review): presumably consumed by ComfyUI's custom-node loader — confirm.
NODE_CLASS_MAPPINGS = {
    f"{node_cls.__name__}(SyncTalk)": node_cls
    for node_cls in (
        LoadAve,
        AveProcess,
        LoadHubert,
        HubertProcess,
        LoadDeepSpeech,
        DeepSpeechProcess,
        LoadInferenceData,
        LoadNeRFNetwork,
        Inference,
    )
}

# Only the registry is exported by a star import of this package.
__all__ = ["NODE_CLASS_MAPPINGS"]