diff --git a/structuretoolkit/analyse/symmetry.py b/structuretoolkit/analyse/symmetry.py index 773d05d5d..2e8be71db 100644 --- a/structuretoolkit/analyse/symmetry.py +++ b/structuretoolkit/analyse/symmetry.py @@ -384,26 +384,42 @@ def get_primitive_cell( >>> symmetry = Symmetry(structure) >>> len(symmetry.get_primitive_cell()) == len(basis) True + + .. warning:: + Custom arrays defined in the base structures + :attr:`ase.atoms.Atoms.arrays` are not copied to the new structure! """ + if not all(self._structure.pbc): + raise ValueError("Can only symmetrize periodic structures.") ret = spglib.standardize_cell( self._get_spglib_cell(use_elements=use_elements, use_magmoms=use_magmoms), to_primitive=not standardize, ) if ret is None: raise SymmetryError(spglib.spglib.spglib_error.message) - cell, positions, indices = ret - positions = (cell.T @ positions.T).T - new_structure = self._structure.copy() - new_structure.cell = cell - new_structure = new_structure[: len(indices)] + cell, scaled_positions, indices = ret indices_dict = { v: k for k, v in structuretoolkit.common.helper.get_species_indices_dict( structure=self._structure ).items() } - new_structure.symbols = [indices_dict[i] for i in indices] - new_structure.positions = positions + symbols = [indices_dict[i] for i in indices] + arrays = { + k: self._structure.arrays[k] + for k in self._structure.arrays + if k not in ("numbers", "positions") + } + new_structure = type(self._structure)( + symbols=symbols, + scaled_positions=scaled_positions, + cell=cell, + pbc=[True, True, True], + ) + keys = set(arrays) - {"numbers", "positions"} + if len(keys) > 0: + warning(f"Custom arrays {keys} do not carry over to new structure!") + return new_structure def get_ir_reciprocal_mesh( diff --git a/tests/test_symmetry.py b/tests/test_symmetry.py index 1e29f5b71..c7c30eda4 100644 --- a/tests/test_symmetry.py +++ b/tests/test_symmetry.py @@ -173,7 +173,9 @@ def test_get_ir_reciprocal_mesh(self): def test_get_primitive_cell(self): cell = 2.2 * np.identity(3) - basis = Atoms("AlFe", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell) + basis = Atoms( + "AlFe", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True + ) structure = basis.repeat([2, 2, 2]) sym = stk.analyse.get_symmetry(structure=structure) self.assertEqual(len(basis), len(sym.get_primitive_cell(standardize=True))) @@ -199,7 +201,7 @@ def test_get_primitive_cell_hex(self): [0.77, 1.57, 5.74], ] cell = [[2.519, 1.454, 4.590], [-2.519, 1.454, 4.590], [0.0, -2.909, 4.590]] - structure = Atoms(symbols=elements, positions=positions, cell=cell) + structure = Atoms(symbols=elements, positions=positions, cell=cell, pbc=True) structure_repeat = structure.repeat([2, 2, 2]) sym = stk.analyse.get_symmetry(structure=structure_repeat) structure_prim_base = sym.get_primitive_cell()