Skip to content

Commit

Permalink
Merge pull request #246 from pyiron/symmetry_test
Browse files Browse the repository at this point in the history
Add simple test for symmerty
  • Loading branch information
jan-janssen authored Apr 23, 2024
2 parents d0225c2 + 4003b2e commit 7cc5821
Showing 1 changed file with 96 additions and 0 deletions.
96 changes: 96 additions & 0 deletions tests/test_elastic_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import unittest

import numpy as np

from atomistics.workflows.elastic.symmetry import (
get_C_from_A2,
get_LAG_Strain_List,
get_symmetry_family_from_SGN,
)


class TestElasticHelper(unittest.TestCase):
def test_get_C_from_A2(self):
A2 = np.arange(21)
C = get_C_from_A2(A2=A2, LC="CI")
self.assertEqual(C.shape, (6, 6))
self.assertEqual(int(sum(np.diag(C))), 6)
self.assertEqual(int(np.sum(C)), 4)
C = get_C_from_A2(A2=A2, LC="HI")
self.assertEqual(C.shape, (6, 6))
self.assertEqual(int(sum(np.diag(C))), 25)
self.assertEqual(int(np.sum(C)), 17)
C = get_C_from_A2(A2=A2, LC="RI")
self.assertEqual(C.shape, (6, 6))
self.assertEqual(int(sum(np.diag(C))), 25)
self.assertEqual(int(np.sum(C)), 16)
C = get_C_from_A2(A2=A2, LC="RII")
self.assertEqual(C.shape, (6, 6))
self.assertEqual(int(sum(np.diag(C))), 25)
self.assertEqual(int(np.sum(C)), 17)
C = get_C_from_A2(A2=A2, LC="TI")
self.assertEqual(C.shape, (6, 6))
self.assertEqual(int(sum(np.diag(C))), 9)
self.assertEqual(int(np.sum(C)), 8)
C = get_C_from_A2(A2=A2, LC="TII")
self.assertEqual(C.shape, (6, 6))
self.assertEqual(int(sum(np.diag(C))), 11)
self.assertEqual(int(np.sum(C)), 9)
C = get_C_from_A2(A2=A2, LC="O")
self.assertEqual(C.shape, (6, 6))
self.assertEqual(int(sum(np.diag(C))), 14)
self.assertEqual(int(np.sum(C)), 12)
C = get_C_from_A2(A2=A2, LC="M")
self.assertEqual(C.shape, (6, 6))
self.assertEqual(int(sum(np.diag(C))), 14)
self.assertEqual(int(np.sum(C)), 58)
C = get_C_from_A2(A2=A2, LC="N")
self.assertEqual(C.shape, (6, 6))
self.assertEqual(int(sum(np.diag(C))), 12)
self.assertEqual(int(np.sum(C)), 72)

def test_get_LAG_Strain_List(self):
Lag_strain_lst = get_LAG_Strain_List(LC="CI")
self.assertEqual(len(Lag_strain_lst), 3)
Lag_strain_lst = get_LAG_Strain_List(LC="HI")
self.assertEqual(len(Lag_strain_lst), 5)
Lag_strain_lst = get_LAG_Strain_List(LC="RI")
self.assertEqual(len(Lag_strain_lst), 6)
Lag_strain_lst = get_LAG_Strain_List(LC="RII")
self.assertEqual(len(Lag_strain_lst), 7)
Lag_strain_lst = get_LAG_Strain_List(LC="TI")
self.assertEqual(len(Lag_strain_lst), 6)
Lag_strain_lst = get_LAG_Strain_List(LC="TII")
self.assertEqual(len(Lag_strain_lst), 7)
Lag_strain_lst = get_LAG_Strain_List(LC="O")
self.assertEqual(len(Lag_strain_lst), 9)
Lag_strain_lst = get_LAG_Strain_List(LC="M")
self.assertEqual(len(Lag_strain_lst), 13)
Lag_strain_lst = get_LAG_Strain_List(LC="N")
self.assertEqual(len(Lag_strain_lst), 21)

def test_get_symmetry_family_from_SGN(self):
LC = get_symmetry_family_from_SGN(SGN=1)
self.assertEqual(LC, "N")
LC = get_symmetry_family_from_SGN(SGN=3)
self.assertEqual(LC, "M")
LC = get_symmetry_family_from_SGN(SGN=16)
self.assertEqual(LC, "O")
LC = get_symmetry_family_from_SGN(SGN=75)
self.assertEqual(LC, "TII")
LC = get_symmetry_family_from_SGN(SGN=89)
self.assertEqual(LC, "TI")
LC = get_symmetry_family_from_SGN(SGN=143)
self.assertEqual(LC, "RII")
LC = get_symmetry_family_from_SGN(SGN=149)
self.assertEqual(LC, "RI")
LC = get_symmetry_family_from_SGN(SGN=168)
self.assertEqual(LC, "HII")
LC = get_symmetry_family_from_SGN(SGN=177)
self.assertEqual(LC, "HI")
LC = get_symmetry_family_from_SGN(SGN=195)
self.assertEqual(LC, "CII")
LC = get_symmetry_family_from_SGN(SGN=207)
self.assertEqual(LC, "CI")
with self.assertRaises(ValueError):
get_symmetry_family_from_SGN(SGN=231)

0 comments on commit 7cc5821

Please sign in to comment.