From 4003b2ed4d736d60abae8b25d37d8c84ca80de97 Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Mon, 22 Apr 2024 19:08:30 -0500 Subject: [PATCH] Add simple test for symmerty --- tests/test_elastic_helper.py | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 tests/test_elastic_helper.py diff --git a/tests/test_elastic_helper.py b/tests/test_elastic_helper.py new file mode 100644 index 00000000..cf229904 --- /dev/null +++ b/tests/test_elastic_helper.py @@ -0,0 +1,96 @@ +import unittest + +import numpy as np + +from atomistics.workflows.elastic.symmetry import ( + get_C_from_A2, + get_LAG_Strain_List, + get_symmetry_family_from_SGN, +) + + +class TestElasticHelper(unittest.TestCase): + def test_get_C_from_A2(self): + A2 = np.arange(21) + C = get_C_from_A2(A2=A2, LC="CI") + self.assertEqual(C.shape, (6, 6)) + self.assertEqual(int(sum(np.diag(C))), 6) + self.assertEqual(int(np.sum(C)), 4) + C = get_C_from_A2(A2=A2, LC="HI") + self.assertEqual(C.shape, (6, 6)) + self.assertEqual(int(sum(np.diag(C))), 25) + self.assertEqual(int(np.sum(C)), 17) + C = get_C_from_A2(A2=A2, LC="RI") + self.assertEqual(C.shape, (6, 6)) + self.assertEqual(int(sum(np.diag(C))), 25) + self.assertEqual(int(np.sum(C)), 16) + C = get_C_from_A2(A2=A2, LC="RII") + self.assertEqual(C.shape, (6, 6)) + self.assertEqual(int(sum(np.diag(C))), 25) + self.assertEqual(int(np.sum(C)), 17) + C = get_C_from_A2(A2=A2, LC="TI") + self.assertEqual(C.shape, (6, 6)) + self.assertEqual(int(sum(np.diag(C))), 9) + self.assertEqual(int(np.sum(C)), 8) + C = get_C_from_A2(A2=A2, LC="TII") + self.assertEqual(C.shape, (6, 6)) + self.assertEqual(int(sum(np.diag(C))), 11) + self.assertEqual(int(np.sum(C)), 9) + C = get_C_from_A2(A2=A2, LC="O") + self.assertEqual(C.shape, (6, 6)) + self.assertEqual(int(sum(np.diag(C))), 14) + self.assertEqual(int(np.sum(C)), 12) + C = get_C_from_A2(A2=A2, LC="M") + self.assertEqual(C.shape, (6, 6)) + self.assertEqual(int(sum(np.diag(C))), 14) + self.assertEqual(int(np.sum(C)), 58) + C = get_C_from_A2(A2=A2, LC="N") + self.assertEqual(C.shape, (6, 6)) + self.assertEqual(int(sum(np.diag(C))), 12) + self.assertEqual(int(np.sum(C)), 72) + + def test_get_LAG_Strain_List(self): + Lag_strain_lst = get_LAG_Strain_List(LC="CI") + self.assertEqual(len(Lag_strain_lst), 3) + Lag_strain_lst = get_LAG_Strain_List(LC="HI") + self.assertEqual(len(Lag_strain_lst), 5) + Lag_strain_lst = get_LAG_Strain_List(LC="RI") + self.assertEqual(len(Lag_strain_lst), 6) + Lag_strain_lst = get_LAG_Strain_List(LC="RII") + self.assertEqual(len(Lag_strain_lst), 7) + Lag_strain_lst = get_LAG_Strain_List(LC="TI") + self.assertEqual(len(Lag_strain_lst), 6) + Lag_strain_lst = get_LAG_Strain_List(LC="TII") + self.assertEqual(len(Lag_strain_lst), 7) + Lag_strain_lst = get_LAG_Strain_List(LC="O") + self.assertEqual(len(Lag_strain_lst), 9) + Lag_strain_lst = get_LAG_Strain_List(LC="M") + self.assertEqual(len(Lag_strain_lst), 13) + Lag_strain_lst = get_LAG_Strain_List(LC="N") + self.assertEqual(len(Lag_strain_lst), 21) + + def test_get_symmetry_family_from_SGN(self): + LC = get_symmetry_family_from_SGN(SGN=1) + self.assertEqual(LC, "N") + LC = get_symmetry_family_from_SGN(SGN=3) + self.assertEqual(LC, "M") + LC = get_symmetry_family_from_SGN(SGN=16) + self.assertEqual(LC, "O") + LC = get_symmetry_family_from_SGN(SGN=75) + self.assertEqual(LC, "TII") + LC = get_symmetry_family_from_SGN(SGN=89) + self.assertEqual(LC, "TI") + LC = get_symmetry_family_from_SGN(SGN=143) + self.assertEqual(LC, "RII") + LC = get_symmetry_family_from_SGN(SGN=149) + self.assertEqual(LC, "RI") + LC = get_symmetry_family_from_SGN(SGN=168) + self.assertEqual(LC, "HII") + LC = get_symmetry_family_from_SGN(SGN=177) + self.assertEqual(LC, "HI") + LC = get_symmetry_family_from_SGN(SGN=195) + self.assertEqual(LC, "CII") + LC = get_symmetry_family_from_SGN(SGN=207) + self.assertEqual(LC, "CI") + with self.assertRaises(ValueError): + get_symmetry_family_from_SGN(SGN=231) \ No newline at end of file