Coverage for wulffpack/core/geometry.py: 100%

72 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-02 08:51 +0000

1from typing import List 

2import numpy as np 

3from ase import Atoms 

4from ase.build import bulk 

5from spglib import get_symmetry_dataset, standardize_cell 

6 

7 

def get_tetrahedral_volume(triangle, origin):
    """
    Get the volume of the tetrahedron formed by the origin
    and three vertices defined by the input.

    Parameters
    ----------
    triangle : list of list of 3 floats
        Three coordinates forming a triangle
    origin : list of 3 floats
        The origin, i.e., the fourth vertex of the tetrahedron

    Returns
    -------
    float
        The volume

    Raises
    ------
    ValueError
        If `triangle` does not contain exactly three coordinates
    """
    if len(triangle) != 3:
        raise ValueError('triangle argument must contain three coordinates')
    # The fourth vertex has to be the supplied origin; a hard-coded
    # [0, 0, 0] yields wrong volumes whenever the origin is shifted.
    M = np.vstack((triangle, origin))
    # Append a row of ones: |det| / 6 of the resulting 4x4 matrix is
    # the volume of the tetrahedron spanned by the four vertices.
    M = np.vstack((M.transpose(), [1, 1, 1, 1]))
    return abs(np.linalg.det(M)) / 6

30 

31 

def get_angle(v1, v2, tol=1e-5):
    """
    Get the angle between two vectors.

    Parameters
    ----------
    v1
        First vector
    v2
        Second vector
    tol
        If the cosine of the angle deviates from +1 or -1 by less
        than this value, it is clamped to the interval [-1, 1]
        before the angle is evaluated

    Returns
    -------
    float
        The angle in radians
    """
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    cosine = np.dot(v1, v2) / norm_product
    # Round-off can push the cosine marginally outside [-1, 1] for
    # (anti)parallel vectors; only such near-unity values are clamped.
    if abs(abs(cosine) - 1) < tol:
        cosine = min(1.0, max(-1.0, cosine))
    return np.arccos(cosine)

40 

41 

def get_rotation_matrix(theta, u):
    """
    Get the matrix describing a rotation by an angle about an axis.

    Parameters
    ----------
    theta
        Rotation angle in radians
    u
        Rotation axis (three components); it does not have to be
        normalized and is left unmodified

    Returns
    -------
    NumPy array
        The 3x3 rotation matrix
    """
    # Normalize a float copy of the axis: dividing the argument in place
    # would overwrite the caller's array and fails for lists and
    # integer arrays.
    axis = np.array(u, dtype=float)
    axis /= np.linalg.norm(axis)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Rodrigues' rotation formula:
    # R = cos(theta) * I + sin(theta) * [u]_x + (1 - cos(theta)) * u u^T
    cross_matrix = np.array([[0.0, -axis[2], axis[1]],
                             [axis[2], 0.0, -axis[0]],
                             [-axis[1], axis[0], 0.0]])
    return (cos_t * np.eye(3)
            + sin_t * cross_matrix
            + (1 - cos_t) * np.outer(axis, axis))

55 

56 

def get_standardized_structure(structure: Atoms = None, symprec: float = 1e-5) -> Atoms:
    """
    Returns a standardized version of a structure, obtained with
    the spglib routine `standardize_cell`.

    Parameters
    ----------
    structure
        Structure that should be standardized. If None,
        FCC Au with lattice parameter 4.0 is built with
        `ase.build.bulk` and standardized instead.
    symprec
        Numerical tolerance for symmetry analysis, forwarded to spglib.

    Returns
    -------
    Atoms
        Periodic structure built from the cell, scaled positions and
        atomic numbers returned by `standardize_cell`.
    """
    atoms = bulk('Au', crystalstructure='fcc', a=4.0) if structure is None else structure

    # spglib takes the structure as a (lattice, positions, numbers) tuple
    lattice = atoms.get_cell()
    positions = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    std_lattice, std_positions, std_numbers = standardize_cell(
        (lattice, positions, numbers), symprec=symprec)

    return Atoms(std_numbers,
                 scaled_positions=std_positions,
                 cell=std_lattice,
                 pbc=True)

83 

84 

def break_symmetry(full_symmetry: List[np.ndarray],
                   symmetry_axes: List[np.ndarray],
                   inversion: List[bool] = None) -> List[np.ndarray]:
    """
    Filter a list of symmetry operations, keeping only those that
    leave every one of the given vectors unchanged (or, where
    permitted, inverted).

    Parameters
    ----------
    full_symmetry
        List of candidate symmetries
    symmetry_axes
        Vectors that must remain unchanged by symmetry operation
    inversion
        One boolean for each `symmetry_axis`; if True, an operation
        that maps that axis onto its negative is accepted as well.
        If None, inversion is not accepted for any axis.

    Returns
    -------
    list of NumPy arrays
        The elements of `full_symmetry` that passed the test,
        in their original order

    Raises
    ------
    ValueError
        If `inversion` is given but its length differs from
        that of `symmetry_axes`
    """
    if inversion is None:
        allow_flip_flags = [False] * len(symmetry_axes)
    elif len(inversion) == len(symmetry_axes):
        allow_flip_flags = inversion
    else:
        raise ValueError('inversion must be a list of bools with '
                         'the same length as symmetry_axes '
                         '({} != {})'.format(len(inversion),
                                             len(symmetry_axes)))

    def axis_is_respected(operation, axis, allow_flip):
        # An axis is respected if it maps onto itself, or onto its
        # negative when flipping is allowed for that axis
        image = np.dot(operation, axis)
        if np.allclose(image, axis):
            return True
        return bool(allow_flip) and np.allclose(image, -axis)

    return [operation for operation in full_symmetry
            if all(axis_is_respected(operation, axis, allow_flip)
                   for axis, allow_flip in zip(symmetry_axes, allow_flip_flags))]

124 

125 

def get_symmetries(structure: Atoms, symprec: float = 1e-5) -> List[np.ndarray]:
    """
    Get symmetry operations of the point group.

    Parameters
    ----------
    structure
        The structure for which the symmetry operations
        are to be extracted
    symprec
        Numerical tolerance for symmetry analysis, forwarded to spglib.

    Returns
    -------
    list of NumPy arrays
        The symmetry operations of the point group in matrix format.
        Each distinct rotation matrix appears exactly once.
    """
    # spglib takes the structure as a (lattice, positions, numbers) tuple
    spg_structure = (
        structure.get_cell(),
        structure.get_scaled_positions(),
        structure.get_atomic_numbers(),
    )
    symmetry_data = get_symmetry_dataset(spg_structure, symprec=symprec)

    # The dataset may list the same rotation matrix more than once;
    # keep only the first occurrence of each
    rotations = []
    for R in symmetry_data['rotations']:
        if not is_array_in_arrays(R, rotations):
            rotations.append(R)
    return rotations

155 

156 

def is_array_in_arrays(array: np.ndarray, arrays: List[np.ndarray]) -> bool:
    """
    Tests whether a list of arrays contains an array that is
    equal to the given one within the default tolerances of
    `numpy.allclose`.

    Parameters
    ----------
    array
        Array to be searched for
    arrays
        List of arrays to search among

    Returns
    -------
    bool
        True if a matching array is found, False otherwise
    """
    return any(np.allclose(array, candidate) for candidate in arrays)

173 

174 

def where_is_array_in_arrays(array: np.ndarray, arrays: List[np.ndarray]) -> int:
    """
    Returns the index of an array in a list of arrays, where
    arrays are compared with `numpy.allclose`.

    Parameters
    ----------
    array
        Array to search for
    arrays
        List of arrays to search within

    Returns
    -------
    int
        Index of the first matching array, or -1 if there is no match
    """
    matching_indices = (index for index, candidate in enumerate(arrays)
                        if np.allclose(array, candidate))
    return next(matching_indices, -1)