From 3e4373fb55649f06472cddd4c32bd34cb0f64a81 Mon Sep 17 00:00:00 2001 From: kerlomz Date: Fri, 13 Nov 2020 14:26:52 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=BC=E5=AE=B9tf2,=E5=BB=BA=E8=AE=AERTX3090?= =?UTF-8?q?=E4=BD=BF=E7=94=A8CUDA11.1=E7=89=88=E6=9C=AC=E7=9A=84ptxas?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- graph_session.py | 11 ++++++----- package.py | 3 ++- resource/VERSION | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/graph_session.py b/graph_session.py index 8b3a7f0..323dc63 100644 --- a/graph_session.py +++ b/graph_session.py @@ -3,6 +3,7 @@ # Author: kerlomz import os import tensorflow as tf +tf.compat.v1.disable_v2_behavior() from tensorflow.python.framework.errors_impl import NotFoundError from config import ModelConfig @@ -18,14 +19,14 @@ def __init__(self, model_conf: ModelConfig): self.model_name = self.model_conf.model_name self.graph_name = self.model_conf.model_name self.version = self.model_conf.model_version - self.graph = tf.Graph() - self.sess = tf.Session( + self.graph = tf.compat.v1.Graph() + self.sess = tf.compat.v1.Session( graph=self.graph, - config=tf.ConfigProto( + config=tf.compat.v1.ConfigProto( # allow_soft_placement=True, # log_device_placement=True, - gpu_options=tf.GPUOptions( + gpu_options=tf.compat.v1.GPUOptions( # allocator_type='BFC', allow_growth=True, # it will cause fragmentation. # per_process_gpu_memory_fraction=self.model_conf.device_usage @@ -49,7 +50,7 @@ def load_model(self): graph_def_file = f.read() self.graph_def.ParseFromString(graph_def_file) with self.graph.as_default(): - self.sess.run(tf.global_variables_initializer()) + self.sess.run(tf.compat.v1.global_variables_initializer()) _ = tf.import_graph_def(self.graph_def, name="") self.logger.info('TensorFlow Session {} Loaded.'.format(self.model_conf.model_name)) diff --git a/package.py b/package.py index f2c7cc8..7259f4c 100644 --- a/package.py +++ b/package.py @@ -9,6 +9,7 @@ import platform import distutils import tensorflow as tf +tf.compat.v1.disable_v2_behavior() from enum import Enum, unique from utils import SystemUtils from config import resource_path @@ -33,7 +34,7 @@ class Version(Enum): if __name__ == '__main__': - ver = Version.GPU if tf.test.gpu_device_name() else Version.CPU + ver = Version.CPU upload = False server_ip = "" diff --git a/resource/VERSION b/resource/VERSION index 0f81a28..479e4f3 100644 --- a/resource/VERSION +++ b/resource/VERSION @@ -1 +1 @@ -20201101 \ No newline at end of file +20201109 \ No newline at end of file