Skip to content

Commit

Permalink
兼容tf2,建议RTX3090使用CUDA11.1版本的ptxas
Browse files Browse the repository at this point in the history
  • Loading branch information
kerlomz committed Nov 13, 2020
1 parent c1af547 commit 3e4373f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
11 changes: 6 additions & 5 deletions graph_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Author: kerlomz <[email protected]>
import os
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
from tensorflow.python.framework.errors_impl import NotFoundError
from config import ModelConfig

Expand All @@ -18,14 +19,14 @@ def __init__(self, model_conf: ModelConfig):
self.model_name = self.model_conf.model_name
self.graph_name = self.model_conf.model_name
self.version = self.model_conf.model_version
self.graph = tf.Graph()
self.sess = tf.Session(
self.graph = tf.compat.v1.Graph()
self.sess = tf.compat.v1.Session(
graph=self.graph,
config=tf.ConfigProto(
config=tf.compat.v1.ConfigProto(

# allow_soft_placement=True,
# log_device_placement=True,
gpu_options=tf.GPUOptions(
gpu_options=tf.compat.v1.GPUOptions(
# allocator_type='BFC',
allow_growth=True, # it will cause fragmentation.
# per_process_gpu_memory_fraction=self.model_conf.device_usage
Expand All @@ -49,7 +50,7 @@ def load_model(self):
graph_def_file = f.read()
self.graph_def.ParseFromString(graph_def_file)
with self.graph.as_default():
self.sess.run(tf.global_variables_initializer())
self.sess.run(tf.compat.v1.global_variables_initializer())
_ = tf.import_graph_def(self.graph_def, name="")

self.logger.info('TensorFlow Session {} Loaded.'.format(self.model_conf.model_name))
Expand Down
3 changes: 2 additions & 1 deletion package.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import platform
import distutils
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
from enum import Enum, unique
from utils import SystemUtils
from config import resource_path
Expand All @@ -33,7 +34,7 @@ class Version(Enum):

if __name__ == '__main__':

ver = Version.GPU if tf.test.gpu_device_name() else Version.CPU
ver = Version.CPU

upload = False
server_ip = ""
Expand Down
2 changes: 1 addition & 1 deletion resource/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
20201101
20201109

0 comments on commit 3e4373f

Please sign in to comment.