Skip to content

Commit

Permalink
Fix tests by changing 'main' to the absolute module paths (#6935)
Browse files Browse the repository at this point in the history
* Fixed xfail marked tests

* Removed ununsed import
  • Loading branch information
janasangeetha authored Oct 23, 2024
1 parent 2f62f98 commit 815bab3
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 43 deletions.
8 changes: 3 additions & 5 deletions tfx/dsl/component/experimental/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,6 @@ def testBeamExecutionNonNullableReturnError(self):
ValueError, 'Non-nullable output \'e\' received None return value'):
beam_dag_runner.BeamDagRunner().run(test_pipeline)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testComponentAnnotation(self):
"""Test component annotation parsed from decorator param."""
instance_1 = injector_1_with_annotation(foo=9, bar='secret')
Expand Down Expand Up @@ -654,18 +652,18 @@ def testComponentAnnotation(self):

# Verify base_type annotation parsed from component decorator is correct.
self.assertEqual(
test_pipeline.components[0].type, '__main__.injector_1_with_annotation'
test_pipeline.components[0].type, 'tfx.dsl.component.experimental.decorators_test.injector_1_with_annotation'
)
self.assertEqual(
test_pipeline.components[0].type_annotation.MLMD_SYSTEM_BASE_TYPE, 1)
self.assertEqual(
test_pipeline.components[1].type,
'__main__.simple_component_with_annotation',
'tfx.dsl.component.experimental.decorators_test.simple_component_with_annotation',
)
self.assertEqual(
test_pipeline.components[1].type_annotation.MLMD_SYSTEM_BASE_TYPE, 2)
self.assertEqual(
test_pipeline.components[2].type, '__main__.verify_with_annotation'
test_pipeline.components[2].type, 'tfx.dsl.component.experimental.decorators_test.verify_with_annotation'
)
self.assertEqual(
test_pipeline.components[2].type_annotation.MLMD_SYSTEM_BASE_TYPE, 3)
Expand Down
8 changes: 3 additions & 5 deletions tfx/dsl/component/experimental/decorators_typeddict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,6 @@ def testBeamExecutionNonNullableReturnError(self):
):
beam_dag_runner.BeamDagRunner().run(test_pipeline)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testComponentAnnotation(self):
"""Test component annotation parsed from decorator param."""
instance_1 = injector_1_with_annotation(foo=9, bar='secret')
Expand Down Expand Up @@ -675,20 +673,20 @@ def testComponentAnnotation(self):

# Verify base_type annotation parsed from component decorator is correct.
self.assertEqual(
test_pipeline.components[0].type, '__main__.injector_1_with_annotation'
test_pipeline.components[0].type, 'tfx.dsl.component.experimental.decorators_typeddict_test.injector_1_with_annotation'
)
self.assertEqual(
test_pipeline.components[0].type_annotation.MLMD_SYSTEM_BASE_TYPE, 1
)
self.assertEqual(
test_pipeline.components[1].type,
'__main__.simple_component_with_annotation',
'tfx.dsl.component.experimental.decorators_typeddict_test.simple_component_with_annotation',
)
self.assertEqual(
test_pipeline.components[1].type_annotation.MLMD_SYSTEM_BASE_TYPE, 2
)
self.assertEqual(
test_pipeline.components[2].type, '__main__.verify_with_annotation'
test_pipeline.components[2].type, 'tfx.dsl.component.experimental.decorators_typeddict_test.verify_with_annotation'
)
self.assertEqual(
test_pipeline.components[2].type_annotation.MLMD_SYSTEM_BASE_TYPE, 3
Expand Down
9 changes: 2 additions & 7 deletions tfx/dsl/components/base/executor_spec_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for tfx.dsl.components.base.executor_spec."""


import pytest
import tensorflow as tf
from tfx.dsl.components.base import base_executor
from tfx.dsl.components.base import executor_spec
Expand All @@ -39,22 +38,18 @@ def testNotImplementedError(self):
'_TestSpecWithoutEncode does not support encoding into IR.'):
_TestSpecWithoutEncode().encode()

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testExecutorClassSpecCopy(self):
spec = executor_spec.ExecutorClassSpec(_DummyExecutor)
spec.add_extra_flags('a')
spec_copy = spec.copy()
del spec
self.assertProtoEquals(
"""
class_path: "__main__._DummyExecutor"
class_path: "tfx.dsl.components.base.executor_spec_test._DummyExecutor"
extra_flags: "a"
""",
spec_copy.encode())

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testBeamExecutorSpecCopy(self):
spec = executor_spec.BeamExecutorSpec(_DummyExecutor)
spec.add_extra_flags('a')
Expand All @@ -64,7 +59,7 @@ def testBeamExecutorSpecCopy(self):
self.assertProtoEquals(
"""
python_executor_spec: {
class_path: "__main__._DummyExecutor"
class_path: "tfx.dsl.components.base.executor_spec_test._DummyExecutor"
extra_flags: "a"
}
beam_pipeline_args: "b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""Tests for tfx.orchestration.portable.input_resolution.input_graph_resolver."""

import pytest
from unittest import mock

from absl.testing import parameterized
Expand Down Expand Up @@ -466,8 +465,6 @@ def testBuildGraphFn_ComplexCase(self, raw_inputs, expected):
result = graph_fn(inputs)
self.assertEqual(result, [Integer(expected)])

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testResolverStrategy(self):
input_graph = self.parse_input_graph("""
nodes {
Expand All @@ -494,7 +491,7 @@ def testResolverStrategy(self):
key: "op_1"
value {
op_node {
op_type: "__main__.RenameStrategy"
op_type: "tfx.orchestration.portable.input_resolution.input_graph_resolver_test.RenameStrategy"
args {
node_id: "dict_1"
}
Expand Down
5 changes: 1 addition & 4 deletions tfx/types/standard_artifacts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for standard TFX Artifact types."""


import pytest
import math
from typing import Any, Dict
from unittest import mock
Expand Down Expand Up @@ -69,7 +68,7 @@ def __eq__(self, other):

_TEST_JSONVALUE_OBJ_RAW = (
'{\"__class__\": \"TfxTestJsonableCls\", \"__module__\":'
' \"__main__\", \"__tfx_object_type__\": '
' \"tfx.types.standard_artifacts_test\", \"__tfx_object_type__\": '
'\"jsonable\", \"x\": 42}')
_TEST_JSONVALUE_OBJ_DECODED = TfxTestJsonableCls(42)

Expand Down Expand Up @@ -120,8 +119,6 @@ def testJsonValueDict(self):
self.assertEqual(_TEST_JSONVALUE_DICT_DECODED,
instance.decode(_TEST_JSONVALUE_DICT_RAW))

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testJsonValueObj(self):
instance = standard_artifacts.JsonValue()
self.assertEqual(_TEST_JSONVALUE_OBJ_RAW,
Expand Down
25 changes: 7 additions & 18 deletions tfx/utils/json_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tests for tfx.utils.json_utils."""


import pytest
import tensorflow as tf
from tfx.proto import trainer_pb2
from tfx.utils import deprecation_utils
Expand All @@ -37,15 +36,13 @@ def __init__(self, a, b, c):

class JsonUtilsTest(tf.test.TestCase):

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testDumpsJsonableObjectRoundtrip(self):
obj = _DefaultJsonableObject(1, {'a': 'b'}, [True])

json_text = json_utils.dumps(obj)
self.assertEqual(
(
'{"__class__": "_DefaultJsonableObject", "__module__": "__main__",'
'{"__class__": "_DefaultJsonableObject", "__module__": "tfx.utils.json_utils_test",'
' "__tfx_object_type__": "jsonable", "a": 1, "b": {"a": "b"}, "c":'
' [true]}'
),
Expand All @@ -57,8 +54,6 @@ def testDumpsJsonableObjectRoundtrip(self):
self.assertDictEqual({'a': 'b'}, actual_obj.b)
self.assertCountEqual([True], actual_obj.c)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testDumpsNestedJsonableObject(self):
nested_obj = _DefaultJsonableObject(1, 2,
trainer_pb2.TrainArgs(num_steps=100))
Expand All @@ -67,9 +62,9 @@ def testDumpsNestedJsonableObject(self):
json_text = json_utils.dumps(obj)
self.assertEqual(
(
'{"__class__": "_DefaultJsonableObject", "__module__": "__main__",'
'{"__class__": "_DefaultJsonableObject", "__module__": "tfx.utils.json_utils_test",'
' "__tfx_object_type__": "jsonable", "a": {"__class__":'
' "_DefaultJsonableObject", "__module__": "__main__",'
' "_DefaultJsonableObject", "__module__": "tfx.utils.json_utils_test",'
' "__tfx_object_type__": "jsonable", "a": 1, "b": 2, "c":'
' {"__class__": "TrainArgs", "__module__": "tfx.proto.trainer_pb2",'
' "__proto_value__": "{\\n \\"num_steps\\": 100\\n}",'
Expand All @@ -85,17 +80,15 @@ def testDumpsNestedJsonableObject(self):
self.assertIsNone(actual_obj.b)
self.assertIsNone(actual_obj.c)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testDumpsNestedClass(self):
obj = _DefaultJsonableObject(_DefaultJsonableObject, None, None)

json_text = json_utils.dumps(obj)
self.assertEqual(
(
'{"__class__": "_DefaultJsonableObject", "__module__": "__main__",'
'{"__class__": "_DefaultJsonableObject", "__module__": "tfx.utils.json_utils_test",'
' "__tfx_object_type__": "jsonable", "a": {"__class__":'
' "_DefaultJsonableObject", "__module__": "__main__",'
' "_DefaultJsonableObject", "__module__": "tfx.utils.json_utils_test",'
' "__tfx_object_type__": "class"}, "b": null, "c": null}'
),
json_text,
Expand All @@ -106,13 +99,11 @@ def testDumpsNestedClass(self):
self.assertIsNone(actual_obj.b)
self.assertIsNone(actual_obj.c)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testDumpsClass(self):
json_text = json_utils.dumps(_DefaultJsonableObject)
self.assertEqual(
(
'{"__class__": "_DefaultJsonableObject", "__module__": "__main__",'
'{"__class__": "_DefaultJsonableObject", "__module__": "tfx.utils.json_utils_test",'
' "__tfx_object_type__": "class"}'
),
json_text,
Expand All @@ -121,13 +112,11 @@ def testDumpsClass(self):
actual_obj = json_utils.loads(json_text)
self.assertEqual(_DefaultJsonableObject, actual_obj)

@pytest.mark.xfail(run=False, reason="PR 6889 This test fails and needs to be fixed. "
"If this test passes, please remove this mark.", strict=True)
def testDumpsDeprecatedClass(self):
json_text = json_utils.dumps(_DeprecatedAlias)
self.assertEqual(
(
'{"__class__": "_DefaultJsonableObject", "__module__": "__main__",'
'{"__class__": "_DefaultJsonableObject", "__module__": "tfx.utils.json_utils_test",'
' "__tfx_object_type__": "class"}'
),
json_text,
Expand Down

0 comments on commit 815bab3

Please sign in to comment.