diff --git a/essos/coils.py b/essos/coils.py index 7782fe8e..c329640f 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -1,36 +1,35 @@ import jax jax.config.update("jax_enable_x64", True) import jax.numpy as jnp -from jax.lax import fori_loop from jax import tree_util, jit, vmap from functools import partial from .plot import fix_matplotlib_3d -def compute_curvature(gammadash, gammadashdash): - return jnp.linalg.norm(jnp.cross(gammadash, gammadashdash, axis=1), axis=1) / jnp.linalg.norm(gammadash, axis=1)**3 - class Curves: - """ - Class to store the curves + """ Class to store the curves - ----------- Attributes: - dofs (jnp.ndarray - shape (n_indcurves, 3, 2*order+1)): Fourier Coefficients of the independent curves + dofs (jnp.ndarray - shape (n_base_curves, 3, 2*order+1)): Fourier Coefficients of the base curves n_segments (int): Number of segments to discretize the curves + quadpoints (jnp.ndarray - shape (n_segments,)): Quadrature points used to discretize the curves nfp (int): Number of field periods stellsym (bool): Stellarator symmetry order (int): Order of the Fourier series - curves jnp.ndarray - shape (n_indcurves*nfp*(1+stellsym), 3, 2*order+1)): Curves obtained by applying rotations and flipping corresponding to nfp fold rotational symmetry and optionally stellarator symmetry - gamma (jnp.array - shape (n_coils, n_segments, 3)): Discretized curves - gamma_dash (jnp.array - shape (n_coils, n_segments, 3)): Discretized curves derivatives - + n_base_curves (int): Number of base curves before applying symmetries + curves (jnp.ndarray - shape (n_base_curves*nfp*(1+stellsym), 3, 2*order+1)): Curves obtained by applying rotations and flipping corresponding to nfp fold rotational symmetry and optionally stellarator symmetry + gamma (jnp.ndarray - shape (n_curves, n_segments, 3)): Discretized curves + gamma_dash (jnp.ndarray - shape (n_curves, n_segments, 3)): Discretized curves derivatives + gamma_dashdash (jnp.ndarray - shape (n_curves, n_segments, 3)): Discretized curves second derivatives """ - def __init__(self, dofs: jnp.ndarray, n_segments: int = 100, nfp: int = 1, stellsym: bool = True): - dofs = jnp.array(dofs) - # assert isinstance(dofs, jnp.ndarray), "dofs must be a jnp.ndarray" - assert dofs.ndim == 3, "dofs must be a 3D array with shape (n_curves, 3, 2*order+1)" - assert dofs.shape[1] == 3, "dofs must have shape (n_curves, 3, 2*order+1)" - assert dofs.shape[2] % 2 == 1, "dofs must have shape (n_curves, 3, 2*order+1)" + def __init__(self, + dofs: jnp.ndarray, + n_segments: int = 100, + nfp: int = 1, + stellsym: bool = True): + if hasattr(dofs, 'shape'): + assert len(dofs.shape) == 3, "dofs must be a 3D array with shape (n_curves, 3, 2*order+1)" + assert dofs.shape[1] == 3, "dofs must have shape (n_curves, 3, 2*order+1)" + assert dofs.shape[2] % 2 == 1, "dofs must have shape (n_curves, 3, 2*order+1)" assert isinstance(n_segments, int), "n_segments must be an integer" assert n_segments > 2, "n_segments must be greater than 2" assert isinstance(nfp, int), "nfp must be a positive integer" @@ -41,150 +40,165 @@ def __init__(self, dofs: jnp.ndarray, n_segments: int = 100, nfp: int = 1, stell self._n_segments = n_segments self._nfp = nfp self._stellsym = stellsym - self._order = dofs.shape[2]//2 - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self.quadpoints = jnp.linspace(0, 1, self.n_segments, endpoint=False) - self._set_gamma() - self.n_base_curves=dofs.shape[0] - - def __str__(self): - return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ - + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" - - def __repr__(self): - return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ - + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" - - def _tree_flatten(self): - children = (self._dofs,) # arrays / dynamic values - aux_data = {"n_segments": self._n_segments, "nfp": self._nfp, "stellsym": self._stellsym} # static values - return (children, aux_data) - @classmethod - def _tree_unflatten(cls, aux_data, children): - return cls(*children, **aux_data) - - partial(jit, static_argnames=['self']) - def _set_gamma(self): - def fori_createdata(order_index: int, data: jnp.ndarray) -> jnp.ndarray: - return data[0] + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index - 1], jnp.sin(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index], jnp.cos(2 * jnp.pi * order_index * self.quadpoints)), \ - data[1] + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index - 1], 2*jnp.pi *order_index *jnp.cos(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index], -2*jnp.pi *order_index *jnp.sin(2 * jnp.pi * order_index * self.quadpoints)), \ - data[2] + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index - 1], -4*jnp.pi**2*order_index**2*jnp.sin(2 * jnp.pi * order_index * self.quadpoints)) + jnp.einsum("ij,k->ikj", self._curves[:, :, 2 * order_index], -4*jnp.pi**2*order_index**2*jnp.cos(2 * jnp.pi * order_index * self.quadpoints)) - gamma = jnp.einsum("ij,k->ikj", self._curves[:, :, 0], jnp.ones(self.n_segments)) - gamma_dash = jnp.zeros((jnp.size(self._curves, 0), self.n_segments, 3)) - gamma_dashdash = jnp.zeros((jnp.size(self._curves, 0), self.n_segments, 3)) - gamma, gamma_dash, gamma_dashdash = fori_loop(1, self._order+1, fori_createdata, (gamma, gamma_dash, gamma_dashdash)) - length = jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in gamma_dash]) - curvature = vmap(compute_curvature)(gamma_dash, gamma_dashdash) - self._gamma = gamma - self._gamma_dash = gamma_dash - self._gamma_dashdash = gamma_dashdash - self._curvature = curvature - self._length = length - + self.quadpoints = jnp.linspace(0, 1, self._n_segments, endpoint=False) + self._curves = None + self._gamma = None + self._gamma_dash = None + self._gamma_dashdash = None + self._length = None + self._curvature = None + + # reset_cache method + def reset_cache(self): + self._curves = None + self._gamma = None + self._gamma_dash = None + self._gamma_dashdash = None + self._curvature = None + self._length = None + + # dofs property and setter @property def dofs(self): - return self._dofs + return jnp.array(self._dofs) @dofs.setter def dofs(self, new_dofs): - assert isinstance(new_dofs, jnp.ndarray) - assert new_dofs.ndim == 3 - assert jnp.size(new_dofs, 1) == 3 - assert jnp.size(new_dofs, 2) % 2 == 1 + self.reset_cache() self._dofs = new_dofs - self._order = jnp.size(new_dofs, 2)//2 - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self._set_gamma() - - @property - def curves(self): - return self._curves - - @property - def order(self): - return self._order - @order.setter - def order(self, new_order): - assert isinstance(new_order, int) - assert new_order > 0 - self._dofs = jnp.pad(self.dofs, ((0, 0), (0, 0), (0, 2*(new_order-self._order)))) if new_order > self._order else self.dofs[:, :, :2*(new_order)+1] - self._order = new_order - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self._set_gamma() - + # n_segments property and setter @property def n_segments(self): return self._n_segments @n_segments.setter def n_segments(self, new_n_segments): - assert isinstance(new_n_segments, int) - assert new_n_segments > 2 + self.reset_cache() self._n_segments = new_n_segments self.quadpoints = jnp.linspace(0, 1, self._n_segments, endpoint=False) - self._set_gamma() - + + # nfp property and setter @property def nfp(self): return self._nfp @nfp.setter def nfp(self, new_nfp): - assert isinstance(new_nfp, int) - assert new_nfp > 0 + self.reset_cache() self._nfp = new_nfp - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self._set_gamma() - + + # stellsym property and setter @property def stellsym(self): return self._stellsym @stellsym.setter def stellsym(self, new_stellsym): - assert isinstance(new_stellsym, bool) + self.reset_cache() self._stellsym = new_stellsym - self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) - self._set_gamma() + # order property and setter @property - def gamma(self): - return self._gamma + def order(self): + return self.dofs.shape[2]//2 - @gamma.setter - def gamma(self, new_gamma): - self._gamma = new_gamma - + @order.setter + def order(self, new_order): + self.reset_cache() + self._dofs = jnp.pad(self.dofs, ((0,0), (0,0), (0, max(0, 2*(new_order-self.order)))))[:, :, :2*(new_order)+1] + + # n_base_curves property @property - def gamma_dash(self): - return self._gamma_dash - - @gamma_dash.setter - def gamma_dash(self, new_gamma_dash): - self._gamma_dash = new_gamma_dash + def n_base_curves(self): + return self.dofs.shape[0] + # curves property + @property + def curves(self): + if self._curves is None: + self._curves = apply_symmetries_to_curves(self.dofs, self.nfp, self.stellsym) + return self._curves - + # _compute_gamma method + @jit + def _compute_gamma(self): + def create_data(order: int) -> jnp.ndarray: + return jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order - 1], jnp.sin(2 * jnp.pi * order * self.quadpoints)) \ + + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order], jnp.cos(2 * jnp.pi * order * self.quadpoints)) + gamma_0 = jnp.einsum("ij,k->ikj", self.curves[:, :, 0], jnp.ones(self.n_segments)) + gamma_n = vmap(create_data)(jnp.arange(1, self.order+1)) + return gamma_0 + jnp.sum(gamma_n, axis=0) + + # TODO change gamma from a property to a method + # gamma property + @property + def gamma(self): + return self._compute_gamma() + + # _compute_gamma_dash method + @jit + def _compute_gamma_dash(self): + def create_data(order: int) -> jnp.ndarray: + return jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order - 1], 2*jnp.pi * order * jnp.cos(2 * jnp.pi * order * self.quadpoints)) \ + + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order], -2 * jnp.pi * order * jnp.sin(2 * jnp.pi * order * self.quadpoints)) + gamma_dash_n = vmap(create_data)(jnp.arange(1, self.order+1)) + return jnp.sum(gamma_dash_n, axis=0) + + # gamma_dash property + @property + def gamma_dash(self): + return self._compute_gamma_dash() + + # _compute_gamma_dashdash method + @jit + def _compute_gamma_dashdash(self): + def create_data(order: int) -> jnp.ndarray: + return jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order - 1], -4*jnp.pi**2 * order**2 * jnp.sin(2 * jnp.pi * order * self.quadpoints)) \ + + jnp.einsum("ij,k->ikj", self.curves[:, :, 2 * order], -4*jnp.pi**2 * order**2 * jnp.cos(2 * jnp.pi * order * self.quadpoints)) + gamma_dashdash_n = vmap(create_data)(jnp.arange(1, self.order+1)) + return jnp.sum(gamma_dashdash_n, axis=0) + + # gamma_dashdash property @property def gamma_dashdash(self): - return self._gamma_dashdash + return self._compute_gamma_dashdash() - @gamma_dashdash.setter - def gamma_dashdash(self, new_gamma_dashdash): - self._gamma_dashdash = new_gamma_dashdash - + # length property @property def length(self): + if self._length is None: + self._length = jnp.mean(jnp.linalg.norm(self.gamma_dash, axis=2), axis=1) return self._length + # compute_curvature static method + @staticmethod + @jit + def compute_curvature(gammadash, gammadashdash): + return jnp.linalg.norm(jnp.cross(gammadash, gammadashdash, axis=1), axis=1) / jnp.linalg.norm(gammadash, axis=1)**3 + + # curvature property @property def curvature(self): - return self._curvature + return vmap(self.compute_curvature)(self.gamma_dash, self.gamma_dashdash) + + # copy method + def copy(self): + deep_copy = tree_util.tree_map(lambda x: x.copy(), self) + return deep_copy + + # magic methods + def __str__(self): + return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ + + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" + + def __repr__(self): + return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ + + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" def __len__(self): - return jnp.size(self.curves, 0) + return self.curves.shape[0] def __getitem__(self, key): if isinstance(key, int): @@ -308,9 +322,12 @@ def wrap(data): pointData = {**pointData, **extra_data} polyLinesToVTK(str(filename), np.array(x), np.array(y), np.array(z), pointsPerLine=np.array(ppl), pointData=pointData) -class Curves_from_simsopt(Curves): - # This assumes curves have all nfp and stellsym symmetries - def __init__(self, simsopt_curves, nfp=1, stellsym=True): + @classmethod + def from_simsopt(cls, simsopt_curves, nfp=1, stellsym=True): + """ + Create a Curves object from a list of simsopt curves. + This assumes curves have all nfp and stellsym symmetries. + """ if isinstance(simsopt_curves, str): from simsopt import load bs = load(simsopt_curves) @@ -321,79 +338,208 @@ def __init__(self, simsopt_curves, nfp=1, stellsym=True): [curve.x for curve in simsopt_curves] ), (len(simsopt_curves), 3, 2*simsopt_curves[0].order+1)) n_segments = len(simsopt_curves[0].quadpoints) - super().__init__(dofs, n_segments, nfp, stellsym) + return cls(dofs, n_segments, nfp, stellsym) + + def _tree_flatten(self): + children = (self._dofs,) # arrays / dynamic values + aux_data = {"n_segments": self._n_segments, + "nfp": self._nfp, + "stellsym": self._stellsym} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) tree_util.register_pytree_node(Curves, Curves._tree_flatten, Curves._tree_unflatten) -class Coils(Curves): +# TODO: change currents logic: save dofs_currents as dynamic -> alter main +class Coils: + """ Class to store the coils + + Attributes: + curves (Curves): Curves object storing the coil geometry + dofs_currents_raw (jnp.ndarray - shape (n_base_curves,)): Non-normalized currents of the base curves + currents_scale (float): Normalization factor for the currents + dofs_currents (jnp.ndarray - shape (n_base_curves,)): Normalized currents of the base curves + currents (jnp.ndarray - shape (n_base_curves * nfp * (1 + stellsym),)): Currents obtained by applying symmetries to the base currents + dofs_curves (jnp.ndarray - shape (n_base_curves, 3, 2*order+1)): Degrees of freedom of the curves + dofs (jnp.ndarray - shape (n_base_curves * 3 * (2 * order + 1) + n_base_curves,)): Degrees of freedom of the coils (curves and normalized currents) + + """ def __init__(self, curves: Curves, currents: jnp.ndarray): - assert isinstance(curves, Curves) - currents = jnp.array(currents) - assert jnp.size(currents) == jnp.size(curves.dofs, 0) - super().__init__(curves.dofs, curves.n_segments, curves.nfp, curves.stellsym) - self._currents_scale = jnp.mean(jnp.abs(currents)) - self._dofs_currents = currents/self._currents_scale - self._currents = apply_symmetries_to_currents(self._dofs_currents*self._currents_scale, self.nfp, self.stellsym) + # if hasattr(curves, 'n_base_curves') and hasattr(currents, 'size'): + # assert curves.n_base_curves == currents.size, "Number of base curves and number of currents must be the same" - def __str__(self): - return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ - + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" \ - + f"Currents degrees of freedom\n{repr(self.dofs_currents.tolist())}\n" \ - + f"Currents scaling factor\n{self.currents_scale}\n" - - def __repr__(self): - return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ - + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" \ - + f"Currents degrees of freedom\n{repr(self.dofs_currents.tolist())}\n" \ - + f"Currents scaling factor\n{self.currents_scale}\n" + self.curves = curves + self._dofs_currents_raw = currents # Non-normalized base currents + self._currents_scale = None + self._dofs_currents = None + self._currents = None + + # reset_cache method + def reset_cache(self): + self._dofs_currents = None + self._currents_scale = None + self._currents = None + + # dofs_curves property and setter @property def dofs_curves(self): - return self._dofs + return self.curves.dofs @dofs_curves.setter def dofs_curves(self, new_dofs_curves): - self.dofs = new_dofs_curves + self.curves.dofs = new_dofs_curves + + # dofs_currents_raw property and setter + @property + def dofs_currents_raw(self): + return jnp.array(self._dofs_currents_raw) + + @dofs_currents_raw.setter + def dofs_currents_raw(self, new_dofs_currents_raw): + self.reset_cache() + self._dofs_currents_raw = new_dofs_currents_raw + + # currents_scale property and setter + @property + def currents_scale(self): + if self._currents_scale is None: + self._currents_scale = jnp.mean(jnp.abs(self.dofs_currents_raw)) + return self._currents_scale + @currents_scale.setter + def currents_scale(self, new_currents_scale): + self._dofs_currents_raw = self.dofs_currents * new_currents_scale + self._currents_scale = new_currents_scale + self._currents = None + + # dofs_currents property and setter @property def dofs_currents(self): + if self._dofs_currents is None: + self._dofs_currents = self.dofs_currents_raw / self.currents_scale return self._dofs_currents @dofs_currents.setter def dofs_currents(self, new_dofs_currents): - self._dofs_currents = new_dofs_currents - self._currents = apply_symmetries_to_currents(self._dofs_currents*self.currents_scale, self.nfp, self.stellsym) - + self.dofs_currents_raw = new_dofs_currents * self.currents_scale + + # currents property @property - def currents_scale(self): - return self._currents_scale + def currents(self): + if self._currents is None: + self._currents = apply_symmetries_to_currents(self.dofs_currents_raw, self.nfp, self.stellsym) + return self._currents + + # dofs property and setter + @property + def dofs(self): + return jnp.hstack([self.dofs_curves.ravel(), self.dofs_currents]) - @currents_scale.setter - def currents_scale(self, new_currents_scale): - self._currents_scale = new_currents_scale - self._currents = apply_symmetries_to_currents(self.dofs_currents*new_currents_scale, self.nfp, self.stellsym) + @dofs.setter + def dofs(self, new_dofs): + n_curve_dofs = jnp.size(self.dofs_curves) + self.dofs_curves = jnp.reshape(new_dofs[:n_curve_dofs], self.dofs_curves.shape) + self.dofs_currents = new_dofs[n_curve_dofs:] + # TODO: remove x property. This is a placeholder for compatibility with the examples that need to be updated. + # x property and setter @property def x(self): - dofs_curves = jnp.ravel(self.dofs_curves) - dofs_currents = jnp.ravel(self.dofs_currents) - return jnp.concatenate((dofs_curves, dofs_currents)) + return self.dofs @x.setter def x(self, new_dofs): - old_dofs_curves = jnp.ravel(self.dofs) - old_dofs_currents = jnp.ravel(self.dofs_currents) - new_dofs_curves = new_dofs[:old_dofs_curves.shape[0]] - new_dofs_currents = new_dofs[old_dofs_curves.shape[0]:] - self.dofs_curves = jnp.reshape(new_dofs_curves, (self.dofs_curves.shape)) - self.dofs_currents = new_dofs_currents - + self.dofs = new_dofs + + # currents property @property def currents(self): + if self._currents is None: + self._currents = apply_symmetries_to_currents(self.dofs_currents*self.currents_scale, self.nfp, self.stellsym) return self._currents + # gamma property + @property + def gamma(self): + return self.curves.gamma + + # gamma_dash property + @property + def gamma_dash(self): + return self.curves.gamma_dash + + # gamma_dashdash property + @property + def gamma_dashdash(self): + return self.curves.gamma_dashdash + + # length property + @property + def length(self): + return self.curves.length + + # curvature property + @property + def curvature(self): + return self.curves.curvature + + # nfp property + @property + def nfp(self): + return self.curves.nfp + + # stellsym property + @property + def stellsym(self): + return self.curves.stellsym + + # order property + @property + def order(self): + return self.curves.order + + # n_segments property and setter + @property + def n_segments(self): + return self.curves.n_segments + + @n_segments.setter + def n_segments(self, new_n_segments): + self.curves.n_segments = new_n_segments + + # copy method + def copy(self): + coils = Coils(self.curves.copy(), self.dofs_currents_raw.copy()) + + # Initialize caches + coils._dofs_currents = self.dofs_currents + coils._currents_scale = self.currents_scale + coils._currents = self._currents + + return coils + + # magic methods + def __str__(self): + return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ + + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" \ + + f"Currents degrees of freedom\n{repr(self.dofs_currents.tolist())}\n" \ + + f"Currents scaling factor\n{self.currents_scale}\n" + + def __repr__(self): + return f"nfp stellsym order\n{self.nfp} {self.stellsym} {self.order}\n"\ + + f"Degrees of freedom\n{repr(self.dofs.tolist())}\n" \ + + f"Currents degrees of freedom\n{repr(self.dofs_currents.tolist())}\n" \ + + f"Currents scaling factor\n{self.currents_scale}\n" + + def __len__(self): + return len(self.curves) + def __getitem__(self, key): if isinstance(key, int): return Coils(Curves(jnp.expand_dims(self.curves[key], 0), self.n_segments, 1, False), jnp.expand_dims(self.currents[key], 0)) @@ -425,15 +571,6 @@ def __eq__(self, other): return jnp.all(self.dofs == other.dofs) and jnp.all(self.dofs_currents == other.dofs_currents) else: raise TypeError(f"Invalid argument type. Got {type(other)}, expected Coils.") - - def __ne__(self, other): - return not self.__eq__(other) - - - def _tree_flatten(self): - children = (Curves(self.dofs, self.n_segments, self.nfp, self.stellsym), self._dofs_currents) # arrays / dynamic values - aux_data = {} # static values - return (children, aux_data) def save_coils(self, filename: str, text=""): """ @@ -472,34 +609,59 @@ def to_json(self, filename: str): "dofs_currents": self.dofs_currents.tolist(), } import json - with open(filename, "w") as file: + with open(filename, 'w') as file: json.dump(data, file) + + def plot(self, *args, **kwargs): + self.curves.plot(*args, **kwargs) + + def to_vtk(self, *args, **kwargs): + self.curves.to_vtk(*args, **kwargs) -class Coils_from_json(Coils): - def __init__(self, filename: str): - import json - with open(filename , "r") as file: - data = json.load(file) - super().__init__(Curves(jnp.array(data["dofs_curves"]), data["n_segments"], data["nfp"], data["stellsym"]), data["dofs_currents"]) - -class Coils_from_simsopt(Coils): - # This assumes coils have all nfp and stellsym symmetries - def __init__(self, simsopt_coils, nfp=1, stellsym=True): + @classmethod + def from_simsopt(cls, simsopt_coils, nfp=1, stellsym=True): + """ This assumes coils have all nfp and stellsym symmetries""" if isinstance(simsopt_coils, str): from simsopt import load bs = load(simsopt_coils) simsopt_coils = bs.coils curves = [c.curve for c in simsopt_coils] currents = jnp.array([c.current.get_value() for c in simsopt_coils[0:int(len(simsopt_coils)/nfp/(1+stellsym))]]) - super().__init__(Curves_from_simsopt(curves, nfp, stellsym), currents) + return cls(Curves.from_simsopt(curves, nfp, stellsym), currents) + + @classmethod + def from_json(cls, filename: str): + """ Creates a Coils object from a json file""" + import json + with open(filename, "r") as file: + data = json.load(file) + curves = Curves(jnp.array(data["dofs_curves"]), data["n_segments"], data["nfp"], data["stellsym"]) + currents = jnp.array(data["dofs_currents"]) + return cls(curves, currents) + + def _tree_flatten(self): + children = (self.curves, self._dofs_currents_raw) # arrays / dynamic values + aux_data = {} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) tree_util.register_pytree_node(Coils, Coils._tree_flatten, Coils._tree_unflatten) -def CreateEquallySpacedCurves(n_curves: int, order: int, R: float, r: float, n_segments: int = 100, - nfp: int = 1, stellsym: bool = False) -> jnp.ndarray: +def CreateEquallySpacedCurves(n_curves: int, + order: int, + R: float, + r: float, + n_segments: int = 100, + nfp: int = 1, + stellsym: bool = False) -> Curves: + """ Creates n_curves equally spaced on a torus of major radius R and minor radius r using Fourier + representation up to the specified order.""" angles = (jnp.arange(n_curves) + 0.5) * (2 * jnp.pi) / ((1 + int(stellsym)) * nfp * n_curves) curves = jnp.zeros((n_curves, 3, 1 + 2 * order)) @@ -510,17 +672,15 @@ def CreateEquallySpacedCurves(n_curves: int, order: int, R: float, r: float, n_s curves = curves.at[:, 2, 1].set(-r) # z[1] (constant for all) return Curves(curves, n_segments=n_segments, nfp=nfp, stellsym=stellsym) +@partial(jit, static_argnames=["flip"]) def RotatedCurve(curve, phi, flip): - rotmat = jnp.array( - [[jnp.cos(phi), -jnp.sin(phi), 0], - [jnp.sin(phi), jnp.cos(phi), 0], - [0, 0, 1]]).T + rotmat_T = jnp.array( + [[ jnp.cos(phi), jnp.sin(phi), 0], + [-jnp.sin(phi), jnp.cos(phi), 0], + [ 0, 0, 1]]) if flip: - rotmat = rotmat @ jnp.array( - [[1, 0, 0], - [0, -1, 0], - [0, 0, -1]]) - return curve @ rotmat + rotmat_T = rotmat_T @ jnp.diag(jnp.array([1, -1, -1])) + return curve @ rotmat_T @partial(jit, static_argnames=['nfp', 'stellsym']) def apply_symmetries_to_curves(base_curves, nfp, stellsym): @@ -529,11 +689,8 @@ def apply_symmetries_to_curves(base_curves, nfp, stellsym): for k in range(0, nfp): for flip in flip_list: for i in range(len(base_curves)): - if k == 0 and not flip: - curves.append(base_curves[i]) - else: - rotcurve = RotatedCurve(base_curves[i].T, 2*jnp.pi*k/nfp, flip) - curves.append(rotcurve.T) + rotcurve = RotatedCurve(base_curves[i].T, 2*jnp.pi*k/nfp, flip) + curves.append(rotcurve.T) return jnp.array(curves) @partial(jit, static_argnames=['nfp', 'stellsym']) diff --git a/essos/dynamics.py b/essos/dynamics.py index 2de1e98c..d3b7089d 100644 --- a/essos/dynamics.py +++ b/essos/dynamics.py @@ -1,3 +1,4 @@ +from pyexpat import model import jax jax.config.update("jax_enable_x64", True) import jax.numpy as jnp @@ -53,7 +54,7 @@ def compute_orbit_params(xyz, vpar): class Particles(): def __init__(self, initial_xyz=None, initial_vparallel_over_v=None, charge=ALPHA_PARTICLE_CHARGE, mass=ALPHA_PARTICLE_MASS, energy=FUSION_ALPHA_PARTICLE_ENERGY, min_vparallel_over_v=-1, - max_vparallel_over_v=1, field=None, initial_vxvyvz=None, initial_xyz_fullorbit=None): + max_vparallel_over_v=1, field=None, initial_vxvyvz=None, initial_xyz_fullorbit=None, phase_angle_full_orbit = 0): self.charge = charge self.mass = mass self.energy = energy @@ -85,6 +86,21 @@ def to_full_orbit(self, field): self.initial_xyz_fullorbit, self.initial_vxvyvz = gc_to_fullorbit(field=field, initial_xyz=self.initial_xyz, initial_vparallel=self.initial_vparallel, total_speed=self.total_speed, mass=self.mass, charge=self.charge, phase_angle_full_orbit=self.phase_angle_full_orbit) + + def join(self, other, field=None): + assert isinstance(other, Particles), "Cannot join with non-Particles object" + assert self.charge == other.charge, "Cannot join particles with different charges" + assert self.mass == other.mass, "Cannot join particles with different masses" + assert self.energy == other.energy, "Cannot join particles with different energies" + + charge = self.charge + mass = self.mass + energy = self.energy + initial_xyz = jnp.concatenate((self.initial_xyz, other.initial_xyz), axis=0) + initial_vparallel_over_v = jnp.concatenate((self.initial_vparallel_over_v, other.initial_vparallel_over_v), axis=0) + + return Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, charge=charge, mass=mass, energy=energy, field=field) + @partial(jit, static_argnums=(2)) def GuidingCenterCollisionsDiffusionMu(t, initial_condition, @@ -584,50 +600,6 @@ def condition_BioSavart(t, y, args, **kwargs): self._trajectories = self.trace() - if self.particles is not None: - self.energy = jnp.zeros((self.particles.nparticles, self.times_to_trace)) - - if model == 'GuidingCenter' or model == 'GuidingCenterAdaptative' : - @jit - def compute_energy_gc(trajectory): - xyz = trajectory[:, :3] - vpar = trajectory[:, 3] - AbsB = vmap(self.field.AbsB)(xyz) - mu = (self.particles.energy - self.particles.mass * vpar[0]**2 / 2) / AbsB[0] - return self.particles.mass * vpar**2 / 2 + mu * AbsB - self.energy = vmap(compute_energy_gc)(self._trajectories) - elif model == 'GuidingCenterCollisions': - @jit - def compute_energy_gc(trajectory): - return 0.5*self.particles.mass* trajectory[:, 3]**2 - self.energy = vmap(compute_energy_gc)(self._trajectories) - elif model == 'GuidingCenterCollisionsMuIto' or model == 'GuidingCenterCollisionsMuFixed' or model == 'GuidingCenterCollisionsMuAdaptative' : - @jit - def compute_energy_gc(trajectory): - xyz = trajectory[:, :3] - vpar = trajectory[:, 3]*SPEED_OF_LIGHT - mu = trajectory[:, 4]*self.particles.mass*SPEED_OF_LIGHT**2 - AbsB = vmap(self.field.AbsB)(xyz) - return self.particles.mass * vpar**2 / 2 + mu*AbsB - self.energy = vmap(compute_energy_gc)(self._trajectories) - @jit - def compute_vperp_gc(trajectory): - xyz = trajectory[:, :3] - mu = trajectory[:, 4]*self.particles.mass*SPEED_OF_LIGHT**2 - AbsB = vmap(self.field.AbsB)(xyz) - return jnp.sqrt(2.*mu*AbsB/self.particles.mass) - self.vperp_final = vmap(compute_vperp_gc)(self._trajectories) - elif model == 'FullOrbit' or model == 'FullOrbit_Boris' or model == 'FullOrbitCollisions': - @jit - def compute_energy_fo(trajectory): - vxvyvz = trajectory[:, 3:] - return self.particles.mass / 2 * (vxvyvz[:, 0]**2 + vxvyvz[:, 1]**2 + vxvyvz[:, 2]**2) - self.energy = vmap(compute_energy_fo)(self._trajectories) - elif model == 'FieldLine' or model== 'FieldLineAdaptative': - self.energy = jnp.ones((len(initial_conditions), self.times_to_trace)) - - - self.trajectories_xyz = vmap(lambda xyz: vmap(lambda point: self.field.to_xyz(point[:3]))(xyz))(self.trajectories) if isinstance(field, Vmec): @@ -641,11 +613,8 @@ def compute_energy_fo(trajectory): else: self.loss_fractions, self.total_particles_lost, self.lost_times = self.loss_fraction_BioSavart(boundary) else: - self.loss_fractions = None - self.total_particles_lost = None - self.loss_times = None + self.trajectories_xyz = self.trajectories - @partial(jit, static_argnums=(0)) def trace(self): @jit def compute_trajectory(initial_condition, particle_key) -> jnp.ndarray: @@ -848,7 +817,7 @@ def update_state(state, _): solver=diffrax.Dopri8(), args=self.args, saveat=SaveAt(ts=self.times), - throw=False, + throw=True, # adjoint=DirectAdjoint(), progress_meter=self.progress_meter, max_steps=10000000000, @@ -871,15 +840,41 @@ def trajectories(self): def trajectories(self, value): self._trajectories = value - def _tree_flatten(self): - children = (self.trajectories,) # arrays / dynamic values - aux_data = {'field': self.field, 'model': self.model} # static values - return (children, aux_data) - - @classmethod - def _tree_unflatten(cls, aux_data, children): - return cls(*children, **aux_data) + def energy(self): + assert 'GuidingCenter' in self.model or 'FullOrbit' in self.model, "Energy calculation is only available for GuidingCenter and FullOrbit models" + mass = self.particles.mass + + if self.model == 'GuidingCenter' or self.model == 'GuidingCenterAdaptative' or \ + self.model == 'GuidingCenterCollisionsMuIto' or self.model == 'GuidingCenterCollisionsMuFixed' or self.model == 'GuidingCenterCollisionsMuAdaptative': + initial_xyz = self.initial_conditions[:, :3] + initial_vparallel = self.initial_conditions[:, 3] + initial_B = vmap(self.field.AbsB)(initial_xyz) + mu_array = (self.particles.energy - 0.5 * mass * jnp.square(initial_vparallel)) / initial_B + def compute_energy(trajectory, mu): + xyz = trajectory[:, :3] + vpar = trajectory[:, 3] + AbsB = vmap(self.field.AbsB)(xyz) + return 0.5 * mass * jnp.square(vpar) + mu * AbsB + + energy = vmap(compute_energy)(self.trajectories, mu_array) + elif self.model == 'GuidingCenterCollisions': + def compute_energy(trajectory): + return 0.5 * mass * trajectory[:, 3]**2 + energy = vmap(compute_energy)(self.trajectories) + + elif self.model == 'FullOrbit': + def compute_energy(trajectory): + vxvyvz = trajectory[:, 3:] + v_squared = jnp.sum(jnp.square(vxvyvz), axis=1) + return 0.5 * mass * v_squared + energy = vmap(compute_energy)(self.trajectories) + + elif self.model == 'FieldLine' or self.model == 'FieldLineAdaptative': + energy = jnp.ones((len(self.initial_conditions), self.times_to_trace)) + + return energy + def to_vtk(self, filename): try: import numpy as np except ImportError: raise ImportError("The 'numpy' library is required. Please install it using 'pip install numpy'.") @@ -899,7 +894,7 @@ def plot(self, ax=None, show=True, axis_equal=True, n_trajectories_plot=5, **kwa trajectories_xyz = jnp.array(self.trajectories_xyz) n_trajectories_plot = jnp.min(jnp.array([n_trajectories_plot, trajectories_xyz.shape[0]])) for i in random.choice(random.PRNGKey(0), trajectories_xyz.shape[0], (n_trajectories_plot,), replace=False): - ax.plot(trajectories_xyz[i, :, 0], trajectories_xyz[i, :, 1], trajectories_xyz[i, :, 2], linewidth=0.5, **kwargs) + ax.plot(trajectories_xyz[i, :, 0], trajectories_xyz[i, :, 1], trajectories_xyz[i, :, 2], **kwargs) ax.grid(False) if axis_equal: fix_matplotlib_3d(ax) @@ -966,7 +961,7 @@ def poincare_plot(self, shifts = [jnp.pi/2], orientation = 'toroidal', length = """ Plot Poincare plots using scipy to find the roots of an interpolation. Can take particle trace or field lines. Args: - shifts (list, optional): Apply a linear shift to dependent data. Default is [0]. + shifts (list, optional): Apply a linear shift to dependent data. Default is [pi/2]. orientation (str, optional): 'toroidal' - find time values when toroidal angle = shift [0, 2pi]. 'z' - find time values where z coordinate = shift. Default is 'toroidal'. @@ -1053,7 +1048,18 @@ def process_trajectory(X_i, Y_i, T_i): plt.show() return plotting_data - + + def _tree_flatten(self): + children = (self.trajectories, self.initial_conditions, self.times) # arrays / dynamic values + aux_data = {'field': self.field, 'electric_field': self.electric_field, 'model': self.model, 'maxtime': self.maxtime, 'timestep': self.timestep, + 'rtol': self.rtol, 'atol': self.atol, 'particles': self.particles, 'condition': self.condition, 'tag_gc': self.tag_gc} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) + + tree_util.register_pytree_node(Tracing, Tracing._tree_flatten, Tracing._tree_unflatten) diff --git a/essos/fields.py b/essos/fields.py index 4689e76f..b8c324f6 100644 --- a/essos/fields.py +++ b/essos/fields.py @@ -1,82 +1,110 @@ import jax jax.config.update("jax_enable_x64", True) from jax import vmap -from essos.coils import compute_curvature +from essos.coils import Curves import jax.numpy as jnp from functools import partial from jax import jit, jacfwd, grad, vmap, tree_util, lax -from essos.surfaces import SurfaceRZFourier, BdotN_over_B,SurfaceClassifier +from essos.surfaces import SurfaceRZFourier, BdotN_over_B, SurfaceClassifier from essos.plot import fix_matplotlib_3d from essos.util import newton -class BiotSavart(): - def __init__(self, coils): - self.coils = coils - self.currents = coils.currents - self.gamma = coils.gamma - self.gamma_dash = coils.gamma_dash - self.gamma_dashdash = coils.gamma_dashdash - self.coils_length=jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in self.gamma_dash]) - self.coils_curvature= vmap(compute_curvature)(self.gamma_dash, coils.gamma_dashdash) - self.r_axis=jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(self.coils.dofs_curves))) - self.z_axis=jnp.mean(vmap(lambda dofs: dofs[2, 0])(self.coils.dofs_curves)) - +class MagneticField(): + def __init__(self): + pass - @partial(jit, static_argnames=['self']) + @jit def sqrtg(self, points): - return 1. + raise NotImplementedError("sqrtg method not implemented") - @partial(jit, static_argnames=['self']) + @jit def B(self, points): - dif_R = (jnp.array(points)-self.gamma).T - dB = jnp.cross(self.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0)/jnp.linalg.norm(dif_R, axis=0)**3 - dB_sum = jnp.einsum("i,bai", self.currents*1e-7, dB, optimize="greedy") - return jnp.mean(dB_sum, axis=0) - - @partial(jit, static_argnames=['self']) + raise NotImplementedError("B method not implemented") + + @jit def B_covariant(self, points): return self.B(points) - - @partial(jit, static_argnames=['self']) + + @jit def B_contravariant(self, points): return self.B(points) - @partial(jit, static_argnames=['self']) + @jit def AbsB(self, points): return jnp.linalg.norm(self.B(points)) - @partial(jit, static_argnames=['self']) + @jit def dB_by_dX(self, points): return jacfwd(self.B)(points) - - @partial(jit, static_argnames=['self']) + @jit def dAbsB_by_dX(self, points): return grad(self.AbsB)(points) - @partial(jit, static_argnames=['self']) + @jit def grad_B_covariant(self, points): - return jacfwd(self.B_covariant)(points) - - @partial(jit, static_argnames=['self']) + return jacfwd(self.B_covariant)(points) + + @jit def curl_B(self, points): grad_B_cov=self.grad_B_covariant(points) - return jnp.array([grad_B_cov[2][1] -grad_B_cov[1][2], - grad_B_cov[0][2] -grad_B_cov[2][0], - grad_B_cov[1][0] -grad_B_cov[0][1]])/self.sqrtg(points) - - @partial(jit, static_argnames=['self']) - def curl_b(self, points): - return self.curl_B(points)/self.AbsB(points)+jnp.cross(self.B_covariant(points),jnp.array(self.dAbsB_by_dX(points)))/self.AbsB(points)**2/self.sqrtg(points) + return jnp.array([grad_B_cov[2][1] - grad_B_cov[1][2], + grad_B_cov[0][2] - grad_B_cov[2][0], + grad_B_cov[1][0] - grad_B_cov[0][1]])/self.sqrtg(points) - @partial(jit, static_argnames=['self']) + @jit + def curl_b(self, points): + return self.curl_B(points) / self.AbsB(points) + jnp.cross(self.B_covariant(points), jnp.array(self.dAbsB_by_dX(points))) / self.AbsB(points)**2 / self.sqrtg(points) + + @jit def kappa(self, points): - return -jnp.cross(self.B_contravariant(points),self.curl_b(points))*self.sqrtg(points)/self.AbsB(points) + return -jnp.cross(self.B_contravariant(points), self.curl_b(points)) * self.sqrtg(points) / self.AbsB(points) - @partial(jit, static_argnames=['self']) + @jit + def to_xyz(self, points): + raise NotImplementedError("to_xyz method not implemented") + +class BiotSavart(MagneticField): + def __init__(self, coils): + self.coils = coils + # self.r_axis=jnp.mean(jnp.sqrt(vmap(lambda dofs: dofs[0, 0]**2 + dofs[1, 0]**2)(self.coils.dofs_curves))) + # self.z_axis=jnp.mean(vmap(lambda dofs: dofs[2, 0])(self.coils.dofs_curves)) + + @property + def dofs(self): + return self.coils.dofs + @dofs.setter + def dofs(self, new_dofs): + self.coils.dofs = new_dofs + + @jit + def sqrtg(self, points): + return 1. + + @jit + def B(self, points): + dif_R = (jnp.array(points) - self.coils.gamma).T + dB = jnp.cross(self.coils.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0) / jnp.linalg.norm(dif_R, axis=0)**3 + dB_sum = jnp.einsum("i,bai", self.coils.currents*1e-7, dB, optimize="greedy") + return jnp.mean(dB_sum, axis=0) + + @jit def to_xyz(self, points): return points + + def _tree_flatten(self): + children = (self.coils,) + aux_data = {} + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) +tree_util.register_pytree_node(BiotSavart, + BiotSavart._tree_flatten, + BiotSavart._tree_unflatten) + @jit def d_dtheta_fft(f_theta): ntheta = f_theta.shape[-1] @@ -109,84 +137,69 @@ def gamma_dashdash_from_gamma(gamma): d2_dtheta2_fft(gamma[..., 2]), ], axis=-1) -class BiotSavart_from_gamma(): - def __init__(self, gamma,gamma_dash=None,gamma_dashdash=None, currents=None): - if currents is None: - currents = jnp.ones(len(gamma)) - else: - currents = currents +class BiotSavart_from_gamma(MagneticField): + def __init__(self, gamma, gamma_dash=None, gamma_dashdash=None, currents=None): self.currents = currents self.gamma = gamma - self.r_axis=jnp.average(jnp.linalg.norm(jnp.average(gamma,axis=1)[:,0:2],axis=1)) - self.z_axis=jnp.average(jnp.average(gamma,axis=1)[:,2]) - if gamma_dash is not None: - self.gamma_dash = gamma_dash - else: - self.gamma_dash = gamma_dash_from_gamma(gamma) - self.coils_length=jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in self.gamma_dash]) - if gamma_dashdash is not None: - self.gamma_dashdash = gamma_dashdash - else: - self.gamma_dashdash = gamma_dashdash_from_gamma(gamma) - self.coils_curvature= vmap(compute_curvature)(self.gamma_dash, self.gamma_dashdash) + self._gamma_dash = gamma_dash + self._gamma_dashdash = gamma_dashdash + + self.coils_length = None + self.coils_curvature = None + self.r_axis = None + self.z_axis = None + + @property + def gamma_dash(self): + if self._gamma_dash is None: + self._gamma_dash = gamma_dash_from_gamma(self.gamma) + return self._gamma_dash + + @property + def gamma_dashdash(self): + if self._gamma_dashdash is None: + self._gamma_dashdash = gamma_dashdash_from_gamma(self.gamma) + return self._gamma_dashdash + + @property + def coils_length(self): + if self.coils_length is None: + self.coils_length = jnp.array([jnp.mean(jnp.linalg.norm(d1gamma, axis=1)) for d1gamma in self.gamma_dash]) + return self.coils_length + @property + def coils_curvature(self): + if self._coils_curvature is None: + self._coils_curvature = vmap(Curves.compute_curvature)(self.gamma_dash, self.gamma_dashdash) + return self._coils_curvature + + @property + def r_axis(self): + if self._r_axis is None: + self._r_axis = jnp.average(jnp.linalg.norm(jnp.average(self.gamma, axis=1)[:, 0:2], axis=1)) + return self._r_axis + + @property + def z_axis(self): + if self._z_axis is None: + self._z_axis = jnp.average(jnp.average(self.gamma, axis=1)[:, 2]) + return self._z_axis + @partial(jit, static_argnames=['self']) def sqrtg(self, points): return 1. @partial(jit, static_argnames=['self']) def B(self, points): - dif_R = (jnp.array(points)-self.gamma).T - dB = jnp.cross(self.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0)/jnp.linalg.norm(dif_R, axis=0)**3 + dif_R = (jnp.array(points) - self.gamma).T + dB = jnp.cross(self.gamma_dash.T, dif_R, axisa=0, axisb=0, axisc=0) / jnp.linalg.norm(dif_R, axis=0)**3 dB_sum = jnp.einsum("i,bai", self.currents*1e-7, dB, optimize="greedy") return jnp.mean(dB_sum, axis=0) - @partial(jit, static_argnames=['self']) - def B_covariant(self, points): - return self.B(points) - - @partial(jit, static_argnames=['self']) - def B_contravariant(self, points): - return self.B(points) - - @partial(jit, static_argnames=['self']) - def AbsB(self, points): - return jnp.linalg.norm(self.B(points)) - - @partial(jit, static_argnames=['self']) - def dB_by_dX(self, points): - return jacfwd(self.B)(points) - - - @partial(jit, static_argnames=['self']) - def dAbsB_by_dX(self, points): - return grad(self.AbsB)(points) - - @partial(jit, static_argnames=['self']) - def grad_B_covariant(self, points): - return jacfwd(self.B_covariant)(points) - - @partial(jit, static_argnames=['self']) - def curl_B(self, points): - grad_B_cov=self.grad_B_covariant(points) - return jnp.array([grad_B_cov[2][1] -grad_B_cov[1][2], - grad_B_cov[0][2] -grad_B_cov[2][0], - grad_B_cov[1][0] -grad_B_cov[0][1]])/self.sqrtg(points) - - @partial(jit, static_argnames=['self']) - def curl_b(self, points): - return self.curl_B(points)/self.AbsB(points)+jnp.cross(self.B_covariant(points),jnp.array(self.dAbsB_by_dX(points)))/self.AbsB(points)**2/self.sqrtg(points) - - @partial(jit, static_argnames=['self']) - def kappa(self, points): - return -jnp.cross(self.B_contravariant(points),self.curl_b(points))*self.sqrtg(points)/self.AbsB(points) - @partial(jit, static_argnames=['self']) def to_xyz(self, points): return points - - class Vmec(): def __init__(self, wout_filename, ntheta=50, nphi=50, close=True, range_torus='full torus'): self.wout_filename = wout_filename @@ -213,7 +226,7 @@ def __init__(self, wout_filename, ntheta=50, nphi=50, close=True, range_torus='f self.s_half_grid = self.s_full_grid[1:] - 0.5 * self.ds self.r_axis = self.rmnc[0, 0] self.z_axis=self.zmns[0,0] - self.mpol = int(jnp.max(self.xm)+1) + self.mpol = int(jnp.max(self.xm)) self.ntor = int(jnp.max(jnp.abs(self.xn)) / self.nfp) self.range_torus = range_torus self._surface = SurfaceRZFourier(self, ntheta=ntheta, nphi=nphi, close=close, range_torus=range_torus) diff --git a/essos/losses.py b/essos/losses.py new file mode 100644 index 00000000..a1f965ef --- /dev/null +++ b/essos/losses.py @@ -0,0 +1,174 @@ +from functools import partial +import jax.numpy as jnp +from jax import tree_util, jit, grad as jax_grad +from jax.flatten_util import ravel_pytree + +class base_loss: + def __init__(self): + self.losses = [self] + self._dependencies = {} + self._dependencies_buffer = None + self._starting_dofs = None + self._dofs_to_pytree = None + + def clear_cache(self): + self._dependencies_buffer = None + self._starting_dofs = None + self._dofs_to_pytree = None + + @property + def dependencies(self): + return self._dependencies + + @dependencies.setter + def dependencies(self, value): + assert isinstance(value, dict), "dependencies must be a dictionary mapping dependency names to their corresponding objects." + self.clear_cache() + self._dependencies = value + + @property + def dependencies_buffer(self): + if self._dependencies_buffer is None: + self._dependencies_buffer = tree_util.tree_map(lambda x: jnp.zeros_like(x), self.dependencies) + return self._dependencies_buffer + + def __add__(self, other): + if not isinstance(other, base_loss): + raise TypeError("Addition is only defined between base_loss objects.") + + losses_list = [*self.losses, *other.losses] # Flatten the losses + out_loss = composite_loss(losses_list) + out_loss.dependencies = self.dependencies | other.dependencies + return out_loss + + def __iter__(self): + return iter(self.losses) + + def __mul__(self, other): + raise NotImplementedError("Multiplication is only defined in subclasses of base_loss.") + + def __rmul__(self, other): + return self.__mul__(other) + + +class custom_loss(base_loss): + def __init__(self, fun, *args_names, **kwargs): + """ A custom loss function that can take multiple arguments and compute gradients with respect to specified arguments. + + Args: + fun (callable): + The loss function to be optimized. It may take multiple arguments. + All dynamic arguments (i.e., those that require gradients) should be passed as positional arguments, while static arguments (i.e., those that do not require gradients) should be passed as keyword arguments. + args_names (tuple): + A tuple of strings indicating the names of the dynamic arguments. This is used for gradient computation. + *args: Dynamic (differentiable) arguments to be passed to the loss function. + **kwargs: Static (non-differentiable) keyword arguments to be passed to the loss function. + + Returns: + custom_loss: An instance of the custom_loss class. + """ + super().__init__() + self.fun = fun + self.args_names = args_names + self.kwargs = kwargs + + # The dofs of a custom loss are the dofs of its arguments + @property + def starting_dofs(self): + if self._starting_dofs is None: + self._starting_dofs, self._dofs_to_pytree = ravel_pytree(tuple(self.dependencies[arg] for arg in self.args_names)) + return self._starting_dofs + + @property + def dofs_to_pytree(self): + if self._dofs_to_pytree is None: + self._starting_dofs, self._dofs_to_pytree = ravel_pytree(tuple(self.dependencies[arg] for arg in self.args_names)) + return self._dofs_to_pytree + + @partial(jit, static_argnames=['self']) + def __call__(self, dofs: jnp.ndarray) -> float: + args = self.dofs_to_pytree(dofs) + return self.fun(*args, **self.kwargs) + + @partial(jit, static_argnames=['self']) + def call_pytree(self, dofs_pytree) -> float: + return self.fun(*dofs_pytree, **self.kwargs) + + @partial(jit, static_argnames=['self']) + def grad(self, dofs: jnp.ndarray) -> jnp.ndarray: + args = self.dofs_to_pytree(dofs) + gradient = jax_grad(self.fun, argnums=tuple(range(len(args))))(*args, **self.kwargs) + return ravel_pytree(gradient)[0] + + @partial(jit, static_argnames=['self']) + def grad_pytree(self, dofs_pytree) -> dict: + gradient = jax_grad(self.fun, argnums=tuple(range(len(dofs_pytree))))(*dofs_pytree, **self.kwargs) + buffer = self.dependencies_buffer.copy() + for dep, g in zip(self.args_names, gradient): + buffer[dep] = g + return buffer + + def __mul__(self, other): + if not isinstance(other, (int, float)): + raise TypeError("Multiplication is only defined between base_loss and a scalar.") + + new_fun = lambda *args, **kwargs: other * self.fun(*args, **kwargs) + out_loss = custom_loss(new_fun, *self.args_names, **self.kwargs) + return out_loss + + +class composite_loss(base_loss): + def __init__(self, losses: list): + """ A composite loss function that combines multiple loss functions. + + Args: + losses (list): + A list of loss functions to be combined. Each loss function should be an instance of base_loss or its subclasses. + Returns: + composite_loss: An instance of the composite_loss class. + """ + super().__init__() + self.losses = losses + + @property + def dependencies(self): + return self._dependencies + + @dependencies.setter + def dependencies(self, value): + assert isinstance(value, dict), "dependencies must be a dictionary mapping dependency names to their corresponding objects." + self.clear_cache() + self._dependencies = value + for loss in self.losses: + loss.dependencies = self._dependencies + + # The dofs of a composite loss are all the dofs of its dependencies + @property + def starting_dofs(self): + if self._starting_dofs is None: + self._starting_dofs, self._dofs_to_pytree = ravel_pytree(self.dependencies) + return self._starting_dofs + + @property + def dofs_to_pytree(self): + if self._dofs_to_pytree is None: + self._starting_dofs, self._dofs_to_pytree = ravel_pytree(self.dependencies) + return self._dofs_to_pytree + + @partial(jit, static_argnames=['self']) + def __call__(self, dofs: jnp.ndarray) -> float: + dependencies = self.dofs_to_pytree(dofs) + each_loss = [loss.call_pytree(tuple(dependencies[arg] for arg in loss.args_names))\ + for loss in self.losses] + return sum(each_loss) + + @partial(jit, static_argnames=['self']) + def grad(self, dofs: jnp.ndarray) -> jnp.ndarray: + dependencies = self.dofs_to_pytree(dofs) + + grads_each_loss = [loss.grad_pytree(tuple(dependencies[arg] for arg in loss.args_names))\ + for loss in self.losses] + + grad = tree_util.tree_map(lambda *dofs: jnp.sum(jnp.stack(dofs), axis=0), *grads_each_loss) + dofs_grad = ravel_pytree(grad)[0] + return dofs_grad diff --git a/essos/objective_functions.py b/essos/objective_functions.py index a9d040fd..b4366e7c 100644 --- a/essos/objective_functions.py +++ b/essos/objective_functions.py @@ -1,12 +1,14 @@ import jax +# from build.lib.essos import coils jax.config.update("jax_enable_x64", True) import jax.numpy as jnp from jax import jit, vmap +from jax.lax import fori_loop from functools import partial from essos.dynamics import Tracing from essos.fields import BiotSavart,BiotSavart_from_gamma from essos.surfaces import BdotN_over_B, BdotN -from essos.coils import Curves, Coils,compute_curvature +from essos.coils import Curves, Coils from essos.optimization import new_nearaxis_from_x_and_old_nearaxis from essos.constants import mu_0 from essos.coil_perturbation import perturb_curves_systematic, perturb_curves_statistic @@ -70,16 +72,13 @@ def loss_coils_for_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, gradB_nearaxis = field_nearaxis.grad_B_axis.T gradB_coils = vmap(field.dB_by_dX)(points.T) - - coil_length = loss_coil_length(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_length=max_coil_length) - coil_curvature = loss_coil_curvature(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_curvature=max_coil_curvature) - + B_difference_loss = jnp.sum(jnp.abs(jnp.array(B_coils)-jnp.array(B_nearaxis))) gradB_difference_loss = jnp.sum(jnp.abs(jnp.array(gradB_coils)-jnp.array(gradB_nearaxis))) - coil_length_loss = 1e3*jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) - coil_curvature_loss = 1e3*jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) - + coil_length_loss = jnp.maximum(0, jnp.max(field.coils.length-max_coil_length)) + coil_curvature_loss = jnp.maximum(0, jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature) + return B_difference_loss+gradB_difference_loss+coil_length_loss+coil_curvature_loss # @partial(jit, static_argnums=(0, 1)) @@ -105,10 +104,7 @@ def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, len_dofs_nearaxis = len(field_nearaxis.x) field=field_from_dofs(x[:-len_dofs_nearaxis],dofs_curves=dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym) new_field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(x[-len_dofs_nearaxis:], field_nearaxis) - - coil_length = loss_coil_length(x[:-len_dofs_nearaxis],dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_length=max_coil_length) - coil_curvature = loss_coil_curvature(x[:-len_dofs_nearaxis],dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_curvature=max_coil_curvature) - + elongation = new_field_nearaxis.elongation iota = new_field_nearaxis.iota @@ -116,14 +112,13 @@ def loss_coils_and_nearaxis(x, field_nearaxis, dofs_curves, currents_scale, nfp, B_difference_loss = 3*jnp.sum(jnp.abs(B_difference)) gradB_difference_loss = jnp.sum(jnp.abs(gradB_difference)) - coil_length_loss = 1e3*jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) - coil_curvature_loss = 1e3*jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) + coil_length_loss = jnp.maximum(0, jnp.max(field.coils.length-max_coil_length)) + coil_curvature_loss = jnp.maximum(0, jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature) elongation_loss = jnp.sum(jnp.abs(elongation)) iota_loss = 30/jnp.abs(iota) return B_difference_loss+gradB_difference_loss+coil_length_loss+coil_curvature_loss+elongation_loss+iota_loss - def loss_particle_radial_drift(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True, maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) particles.to_full_orbit(field) @@ -198,7 +193,7 @@ def loss_particle_r_cross_final(x,particles,dofs_curves, currents_scale, nfp,n_s r_cross=jnp.sqrt(jnp.square(jnp.sqrt(jnp.square(xyz[:,:,0])+jnp.square(xyz[:,:,1]))-R_axis+1.e-12)+jnp.square(xyz[:,:,2]-Z_axis+1.e-12)) return jnp.linalg.norm((jnp.average(r_cross,axis=1))) -def loss_particle_r_cross_max_constraint(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,target_r=0.4,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): +def loss_particle_r_cross_max(x,particles,dofs_curves, currents_scale, nfp,n_segments=60, stellsym=True,target_r=0.4,maxtime=1e-5, num_steps=300, trace_tolerance=1e-5, model='GuidingCenterAdaptative',boundary=None): field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) #particles.to_full_orbit(field) tracing = Tracing(field=field, model=model, particles=particles, maxtime=maxtime, @@ -271,29 +266,6 @@ def normB_axis(field, npoints=15,target_B_on_axis=5.7): B_axis = vmap(lambda phi: field.AbsB(jnp.array([R_axis * jnp.cos(phi), R_axis * jnp.sin(phi), 0])))(phi_array) return B_axis - -# @partial(jit, static_argnums=(0)) -#def loss_coil_length(field,max_coil_length=31): -# coil_length=jnp.ravel(field.coils_length) -# return jnp.array([jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])]))]) - -# @partial(jit, static_argnums=(0)) -#def loss_coil_curvature(field,max_coil_curvature=0.4): -# coil_curvature=jnp.mean(field.coils_curvature, axis=1) -# return jnp.array([jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])]))]) - -# @partial(jit, static_argnums=(0)) -def loss_coil_length(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,max_coil_length=31): - field=field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments,stellsym) - coil_length=jnp.ravel(field.coils_length) - return jnp.ravel(jnp.array([jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])]))])) - -# @partial(jit, static_argnums=(0)) -def loss_coil_curvature(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,max_coil_curvature=0.4): - field=field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments,stellsym) - coil_curvature=jnp.mean(field.coils_curvature, axis=1) - return jnp.ravel(jnp.array([jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])]))])) - # @partial(jit, static_argnums=(0, 1)) def loss_normB_axis(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True, npoints=15,target_B_on_axis=5.7): field=field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments,stellsym) @@ -310,22 +282,43 @@ def loss_normB_axis_average(x,dofs_curves,currents_scale,nfp,n_segments=60,stell B_axis = vmap(lambda phi: field.AbsB(jnp.array([R_axis * jnp.cos(phi), R_axis * jnp.sin(phi), 0])))(phi_array) return jnp.array([jnp.absolute(jnp.average(B_axis)-target_B_on_axis)]) +@partial(jit, static_argnames=['max_coil_length']) +def loss_coil_length(coils, max_coil_length=0): + return jnp.square(coils.length/max_coil_length - 1) +@partial(jit, static_argnames=['max_coil_curvature']) +def loss_coil_curvature(coils, max_coil_curvature=0): + pointwise_curvature_loss = jnp.square(jnp.maximum(coils.curvature-max_coil_curvature, 0)) + return jnp.mean(pointwise_curvature_loss*jnp.linalg.norm(coils.gamma_dash, axis=-1), axis=1) -# @partial(jit, static_argnums=(0)) -def loss_coil_curvature_new(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,max_coil_curvature=0.4): - field=field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments,stellsym) - coil_curvature=jnp.mean(field.coils_curvature, axis=1) - return jnp.maximum(coil_curvature-max_coil_curvature,0.0) +def compute_candidates(coils, min_separation): + centers = coils.curves.curves[:, :, 0] + a_n = coils.curves.curves[:, :, 2 : 2*coils.order+1 : 2] + b_n = coils.curves.curves[:, :, 1 : 2*coils.order : 2] + radii = jnp.sum(jnp.linalg.norm(a_n, axis=1)+jnp.linalg.norm(b_n, axis=1), axis=1) + + i_vals, j_vals = jnp.triu_indices(len(coils), k=1) + centers_dists = jnp.linalg.norm(centers[i_vals] - centers[j_vals], axis=1) + mask = centers_dists <= min_separation + radii[i_vals] + radii[j_vals] -# @partial(jit, static_argnums=(0)) -def loss_coil_length_new(x,dofs_curves,currents_scale,nfp,n_segments=60,stellsym=True,max_coil_length=31): - field=field_from_dofs(x,dofs_curves,currents_scale,nfp,n_segments,stellsym) - coil_length=jnp.ravel(field.coils_length) - return jnp.maximum(coil_length-max_coil_length,0.0) + return i_vals[mask], j_vals[mask] +@partial(jit, static_argnames=['min_separation']) +def loss_coil_separation(coils, min_separation, candidates=None): + if candidates is None: + candidates = jnp.triu_indices(len(coils), k=1) + def pair_loss(i, j): + gamma_i = coils.gamma[i] + gamma_dash_i = jnp.linalg.norm(coils.gamma_dash[i], axis=-1) + gamma_j = coils.gamma[j] + gamma_dash_j = jnp.linalg.norm(coils.gamma_dash[j], axis=-1) + dists = jnp.linalg.norm(gamma_i[:, None, :] - gamma_j[None, :, :], axis=2) + penalty = jnp.maximum(0, min_separation - dists) + return jnp.mean(jnp.square(penalty)*gamma_dash_i*gamma_dash_j) + losses = jax.vmap(pair_loss)(*candidates) + return jnp.sum(losses) @@ -338,8 +331,8 @@ def loss_optimize_coils_for_particle_confinement(x, particles, dofs_curves, curr particles_drift_loss = loss_particle_radial_drift(x,dofs_curves=dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym, particles=particles, maxtime=maxtime, num_steps=num_steps, trace_tolerance=trace_tolerance, model=model,boundary=boundary) normB_axis_loss = loss_normB_axis(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,npoints=15,target_B_on_axis=target_B_on_axis) - coil_length_loss = loss_coil_length(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_length=max_coil_length) - coil_curvature_loss = loss_coil_curvature(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_curvature=max_coil_curvature) + coil_length_loss = jnp.maximum(0, jnp.max(field.coils.length-max_coil_length)) + coil_curvature_loss = jnp.maximum(0, jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature) loss = jnp.concatenate((normB_axis_loss, coil_length_loss, coil_curvature_loss,particles_drift_loss)) return jnp.sum(loss) @@ -362,14 +355,11 @@ def loss_BdotN(x, vmec, dofs_curves, currents_scale, nfp, max_coil_length=42, field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) bdotn_over_b = BdotN_over_B(vmec.surface, field) - coil_length = loss_coil_length(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_length=max_coil_length) - coil_curvature = loss_coil_curvature(x,dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_curvature=max_coil_curvature) - - bdotn_over_b_loss = jnp.sum(jnp.abs(bdotn_over_b)) - coil_length_loss = jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) - coil_curvature_loss = jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) - + + coil_length_loss = jnp.maximum(0, jnp.max(field.coils.length-max_coil_length)) + coil_curvature_loss = jnp.maximum(0, jnp.max(jnp.mean(field.coils.curvature, axis=1)-max_coil_curvature)) + return bdotn_over_b_loss+coil_length_loss+coil_curvature_loss @partial(jit, static_argnums=(1, 4, 5, 6)) @@ -377,7 +367,6 @@ def loss_BdotN_only(x, vmec, dofs_curves, currents_scale, nfp,n_segments=60, ste field=field_from_dofs(x,dofs_curves, currents_scale, nfp,n_segments, stellsym) bdotn_over_b = BdotN_over_B(vmec.surface, field) - bdotn_over_b_loss = jnp.sum(jnp.abs(bdotn_over_b)) return bdotn_over_b_loss @@ -581,7 +570,7 @@ def loss_lorentz_force_coils(x,dofs_curves,currents_scale,nfp,n_segments=60,stel def lp_force_pure(index,gamma, gamma_dash,gamma_dashdash,currents,quadpoints,p, threshold): """Pure function for minimizing the Lorentz force on a coil. """ - regularization = regularization_circ(1./jnp.average(compute_curvature( gamma_dash.at[index].get(), gamma_dashdash.at[index].get()))) + regularization = regularization_circ(1./jnp.average(Curves.compute_curvature( gamma_dash.at[index].get(), gamma_dashdash.at[index].get()))) B_mutual=jax.vmap(BiotSavart_from_gamma(jnp.roll(gamma, -index, axis=0)[1:], jnp.roll(gamma_dash, -index, axis=0)[1:], jnp.roll(gamma_dashdash, -index, axis=0)[1:], diff --git a/essos/optimization.py b/essos/optimization.py index fb1a24bb..6291fec4 100644 --- a/essos/optimization.py +++ b/essos/optimization.py @@ -64,7 +64,7 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e- dofs_currents = result.x[len_dofs_curves:-len(surface_all.x)] curves = Curves(dofs_curves, n_segments, nfp, stellsym) new_coils = Coils(curves=curves, currents=dofs_currents * coils.currents_scale) - new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta) + new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta,mpol=surface_all.mpol,ntor=surface_all.ntor) new_surface.dofs = result.x[-len(surface_all.x):] return new_coils, new_surface elif 'surface_all' in kwargs and 'field_nearaxis' in kwargs and len(initial_dofs) == len(coils.x) + len(kwargs['surface_all'].x) + len(kwargs['field_nearaxis'].x): @@ -73,7 +73,7 @@ def optimize_loss_function(func, initial_dofs, coils, tolerance_optimization=1e- dofs_currents = result.x[len_dofs_curves:-len(surface_all.x)-len(field_nearaxis.x)] curves = Curves(dofs_curves, n_segments, nfp, stellsym) new_coils = Coils(curves=curves, currents=dofs_currents * coils.currents_scale) - new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta) + new_surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta,mpol=surface_all.mpol,ntor=surface_all.ntor) new_surface.dofs = result.x[-len(surface_all.x)-len(field_nearaxis.x):-len(field_nearaxis.x)] new_field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(result.x[-len(field_nearaxis.x):], field_nearaxis) return new_coils, new_surface, new_field_nearaxis diff --git a/essos/surfaces.py b/essos/surfaces.py index 0048e3cf..2ed0a124 100644 --- a/essos/surfaces.py +++ b/essos/surfaces.py @@ -1,36 +1,92 @@ from functools import partial +import jax import jax.numpy as jnp from jax.scipy.interpolate import RegularGridInterpolator -from jax import jit, vmap, devices, device_put +from jax import tree_util, jit, vmap, devices, device_put from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental.pjit import pjit from essos.plot import fix_matplotlib_3d import jaxkd mesh = Mesh(devices(), ("dev",)) -sharding = NamedSharding(mesh, PartitionSpec("dev", None)) +sharding = NamedSharding(mesh, PartitionSpec("dev")) -@partial(jit, static_argnames=['surface','field']) + +@jit +def toroidal_flux(surface, field, idx=0) -> jnp.ndarray: + curve = surface.gamma[idx] + dl = jnp.roll(curve, -1, axis=0) - curve + A_vals = vmap(field.A)(curve) + Adl = jnp.sum(A_vals * dl, axis=1) + tf = jnp.sum(Adl) + #curve = surface.gamma[idx] + #dl = surface.gammadash_theta[idx] + #A_vals = vmap(field.A)(curve) + #Adl = jnp.sum(A_vals * dl, axis=1)/surface.ntheta + #tf = jnp.sum(Adl) + return tf + +@jit +def poloidal_flux(surface, field, idx=0) -> jnp.ndarray: + curve = surface.gamma[:,idx,:] + dl = jnp.roll(curve, -1, axis=0) - curve + A_vals = vmap(field.A)(curve) + Adl = jnp.sum(A_vals * dl, axis=1) + tf = jnp.sum(Adl) + #curve = surface.gamma[:,idx,:] + #dl = surface.gammadash_phi[:,idx,:] + #A_vals = vmap(field.A)(curve) + #Adl = jnp.sum(A_vals * dl, axis=1)/surface.nphi + #tf = jnp.sum(Adl) + return tf + +# @jit +@partial(pjit, in_shardings=(sharding, None), out_shardings=sharding) def B_on_surface(surface, field): ntheta = surface.ntheta nphi = surface.nphi gamma = surface.gamma gamma_reshaped = gamma.reshape(nphi * ntheta, 3) - gamma_sharded = device_put(gamma_reshaped, sharding) - B_on_surface = jit(vmap(field.B), in_shardings=sharding, out_shardings=sharding)(gamma_sharded) - B_on_surface = B_on_surface.reshape(nphi, ntheta, 3) - return B_on_surface -@partial(jit, static_argnames=['surface','field']) + # Map field.B over all positions + B_on_surface = vmap(field.B)(gamma_reshaped) + + return B_on_surface.reshape(nphi, ntheta, 3) + + +@jit def BdotN(surface, field): B_surface = B_on_surface(surface, field) B_dot_n = jnp.sum(B_surface * surface.unitnormal, axis=2) return B_dot_n -@partial(jit, static_argnames=['surface','field']) -def BdotN_over_B(surface, field): - B_surface = B_on_surface(surface, field) - B_dot_n = jnp.sum(B_surface * surface.unitnormal, axis=2) - return B_dot_n / jnp.linalg.norm(B_surface, axis=2) +@jit +def BdotN_over_B(surface, field, **kwargs): + return BdotN(surface, field) / jnp.linalg.norm(B_on_surface(surface, field), axis=2) + +@jit +def _squared_flux_local(surface, field): + return 0.5 * jnp.mean(BdotN(surface, field)**2 / jnp.sum(B_on_surface(surface, field)**2, axis=2) + * surface.area_element) + +@jit +def _squared_flux_global(surface, field): + return 0.5 * jnp.mean(BdotN(surface, field)**2 * surface.area_element) + +@jit +def _squared_flux_normalized(surface, field): + return 0.5 * jnp.mean(BdotN(surface, field)**2 * surface.area_element) / \ + jnp.mean(jnp.sum(B_on_surface(surface, field)**2, axis=2) * surface.area_element) + +def SquaredFlux(surface, field, definition='local'): + if definition == 'local': + return _squared_flux_local(surface, field) + elif definition == 'quadratic flux': + return _squared_flux_global(surface, field) + elif definition == 'normalized': + return _squared_flux_normalized(surface, field) + else: + raise ValueError(f"Unknown definition: {definition}") def nested_lists_to_array(ll): """ @@ -47,187 +103,368 @@ def nested_lists_to_array(ll): for jm, l in enumerate(ll): arr = arr.at[jm, :len(l)].set(jnp.array([x if x is not None else 0 for x in l])) return arr + class SurfaceRZFourier: - def __init__(self, vmec=None, s=1, ntheta=30, nphi=30, close=True, range_torus='full torus', - rc=None, zs=None, nfp=None): - if rc is not None: - self.rc = rc - self.zs = zs - self.nfp = nfp - self.mpol = rc.shape[0] - self.ntor = (rc.shape[1] - 1) // 2 - m1d = jnp.arange(self.mpol) - n1d = jnp.arange(-self.ntor, self.ntor + 1) - n2d, m2d = jnp.meshgrid(n1d, m1d) - self.xm = m2d.flatten()[self.ntor:] - self.xn = self.nfp*n2d.flatten()[self.ntor:] - indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] - self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] - elif isinstance(vmec, str): - self.input_filename = vmec - import f90nml - all_namelists = f90nml.read(vmec) - nml = all_namelists['indata'] - if 'nfp' in nml: - self.nfp = nml['nfp'] - else: - self.nfp = 1 - rc = nested_lists_to_array(nml['rbc']) - zs = nested_lists_to_array(nml['zbs']) - rbc_first_n = nml.start_index['rbc'][0] - rbc_last_n = rbc_first_n + rc.shape[1] - 1 - zbs_first_n = nml.start_index['zbs'][0] - zbs_last_n = zbs_first_n + zs.shape[1] - 1 - self.ntor = jnp.max(jnp.abs(jnp.array([rbc_first_n, rbc_last_n, zbs_first_n, zbs_last_n], dtype='i'))) - rbc_first_m = nml.start_index['rbc'][1] - rbc_last_m = rbc_first_m + rc.shape[0] - 1 - zbs_first_m = nml.start_index['zbs'][1] - zbs_last_m = zbs_first_m + zs.shape[0] - 1 - self.mpol = max(rbc_last_m, zbs_last_m) - self.rc = jnp.zeros((self.mpol, 2 * self.ntor + 1)) - self.zs = jnp.zeros((self.mpol, 2 * self.ntor + 1)) - m_indices_rc = jnp.arange(rc.shape[0]) + nml.start_index['rbc'][1] - n_indices_rc = jnp.arange(rc.shape[1]) + nml.start_index['rbc'][0] + self.ntor - self.rc = self.rc.at[m_indices_rc[:, None], n_indices_rc].set(rc) - m_indices_zs = jnp.arange(zs.shape[0]) + nml.start_index['zbs'][1] - n_indices_zs = jnp.arange(zs.shape[1]) + nml.start_index['zbs'][0] + self.ntor - self.zs = self.zs.at[m_indices_zs[:, None], n_indices_zs].set(zs) - m1d = jnp.arange(self.mpol) - n1d = jnp.arange(-self.ntor, self.ntor + 1) - n2d, m2d = jnp.meshgrid(n1d, m1d) - self.xm = m2d.flatten()[self.ntor:] - self.xn = self.nfp*n2d.flatten()[self.ntor:] - indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] - self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] - else: - try: - self.nfp = vmec.nfp - self.bmnc = vmec.bmnc - self.xm = vmec.xm - self.xn = vmec.xn - self.rmnc = vmec.rmnc - self.zmns = vmec.zmns - self.xm_nyq = vmec.xm_nyq - self.xn_nyq = vmec.xn_nyq - self.len_xm_nyq = len(self.xm_nyq) - self.ns = vmec.ns - self.s_full_grid = vmec.s_full_grid - self.ds = vmec.ds - self.s_half_grid = vmec.s_half_grid - self.r_axis = vmec.r_axis - self.rmnc_interp = vmap(lambda row: jnp.interp(s, self.s_full_grid, row, left='extrapolate'), in_axes=1)(self.rmnc) - self.zmns_interp = vmap(lambda row: jnp.interp(s, self.s_full_grid, row, left='extrapolate'), in_axes=1)(self.zmns) - self.bmnc_interp = vmap(lambda row: jnp.interp(s, self.s_half_grid, row, left='extrapolate'), in_axes=1)(self.bmnc[1:, :]) - self.mpol = vmec.mpol - self.ntor = vmec.ntor - self.num_dofs = 2 * (self.mpol + 1) * (2 * self.ntor + 1) - self.ntor - (self.ntor + 1) - shape = (int(jnp.max(self.xm)) + 1, int(jnp.max(self.xn)) + 1) - self.rc = jnp.zeros(shape) - self.zs = jnp.zeros(shape) - indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rc = self.rc.at[indices[:, 0], indices[:, 1]].set(self.rmnc_interp) - self.zs = self.zs.at[indices[:, 0], indices[:, 1]].set(self.zmns_interp) - except: - raise ValueError("vmec must be a Vmec object or a string pointing to a VMEC input file.") - self.ntheta = ntheta - self.nphi = nphi - self.range_torus = range_torus - if range_torus == 'full torus': div = 1 - else: div = self.nfp - if range_torus == 'half period': end_val = 0.5 - else: end_val = 1.0 - self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=self.ntheta, endpoint=True if close else False) - self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=self.nphi, endpoint=True if close else False) - self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) - self.num_dofs_rc = len(jnp.ravel(self.rc)[self.ntor:]) - self.num_dofs_zs = len(jnp.ravel(self.zs)[self.ntor:]) - self._dofs = jnp.concatenate((jnp.ravel(self.rc)[self.ntor:], jnp.ravel(self.zs)[self.ntor:])) + def __init__(self, rc, zs, nfp, mpol, ntor, ntheta=30, nphi=30, close=True, range_torus='full torus', + scaling_type=2, scaling_factor=0): + """ rc, zs: dynamic arrays + nfp, mpol, ntor: static """ + + assert isinstance(nfp, int) and nfp > 0, "nfp must be a positive integer." + assert isinstance(mpol, int) and mpol >= 0, "mpol must be a non-negative integer." + assert isinstance(ntor, int) and ntor >= 0, "ntor must be a non-negative integer." + assert isinstance(ntheta, int) and ntheta > 0, "ntheta must be a positive integer." + assert isinstance(nphi, int) and nphi > 0, "nphi must be a positive integer." + assert isinstance(close, bool), "close must be a boolean." + assert range_torus in ['full torus', 'half period'], f"Unknown range_torus: {range_torus}. Choose 'full torus' or 'half period'." + + self._rc = rc + self._zs = zs + self._nfp = nfp + self._mpol = mpol + self._ntor = ntor + + self._gamma = None + self._gammadash_theta = None + self._gammadash_phi = None + self._normal = None + self._unitnormal = None + self._area_element = None + self._xm = None + self._xn = None + + self._ntheta = ntheta + self._nphi = nphi + self._close = close + self._range_torus = range_torus + + self._quadpoints_theta = None + self._quadpoints_phi = None + self._theta2d = None + self._phi2d = None + self._angles = None + + self._scaling_type = scaling_type # 1 for L-1 norm, 2 for L-2 norm, jnp.inf for L-infinity norm + self._scaling_factor = scaling_factor + self._scaling = None + + + @classmethod + def from_input_file(cls, file, ntheta=30, nphi=30, close=True, range_torus='full torus'): + from f90nml import Parser + nml = Parser().read(file)['indata'] + + nfp = nml["nfp"] if "nfp" in nml else 1 + mpol = nml['mpol'] + ntor = nml['ntor'] - self.angles = jnp.einsum('i,jk->ijk', self.xm, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn, self.phi_2d) + rc = jnp.ravel(nested_lists_to_array(nml['rbc']))[2:] + zs = jnp.ravel(nested_lists_to_array(nml['zbs']))[2:] + + surface = cls(rc, zs, nfp, mpol, ntor, ntheta=ntheta, nphi=nphi, close=close, range_torus=range_torus) + return surface - (self._gamma, self._gammadash_theta, self._gammadash_phi, - self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + @classmethod + def from_vmec(cls, vmec, s=1, ntheta=30, nphi=30, close=True, range_torus='full torus'): + nfp = vmec.nfp + mpol = vmec.mpol + ntor = vmec.ntor + + s_full_grid = vmec.s_full_grid + rc = vmap(lambda row: jnp.interp(s, s_full_grid, row, left='extrapolate'), in_axes=1)(vmec.rmnc) + zs = vmap(lambda row: jnp.interp(s, s_full_grid, row, left='extrapolate'), in_axes=1)(vmec.zmns) + + surface = cls(rc, zs, nfp, mpol, ntor, ntheta=ntheta, nphi=nphi, close=close, range_torus=range_torus) + surface._xm = vmec.xm + surface._xn = vmec.xn + + return surface + + @classmethod + def from_wout_file(cls, file, s=1, ntheta=30, nphi=30, close=True, range_torus='full torus'): + from netCDF4 import Dataset + nc = Dataset(file) + + nfp = int(nc.variables["nfp"][0]) + xm = jnp.array(nc.variables["xm"][:]) + xn = jnp.array(nc.variables["xn"][:]) + mpol = int(jnp.max(xm)) + ntor = int(jnp.max(jnp.abs(xn)) / nfp) - if hasattr(self, 'bmnc'): - self._AbsB = self._set_AbsB() + ns = nc.variables["ns"][0] + s_full_grid = jnp.linspace(0, 1, ns) + rc = vmap(lambda row: jnp.interp(s, s_full_grid, row, left='extrapolate'), in_axes=1)(jnp.array(nc.variables["rmnc"][:])) + zs = vmap(lambda row: jnp.interp(s, s_full_grid, row, left='extrapolate'), in_axes=1)(jnp.array(nc.variables["zmns"][:])) + + surface = cls(rc, zs, nfp, mpol, ntor, ntheta=ntheta, nphi=nphi, close=close, range_torus=range_torus) + surface._xm = xm + surface._xn = xn + + return surface + + # reset_cache method + def reset_cache(self): + self._gamma = None + self._gammadash_theta = None + self._gammadash_phi = None + self._normal = None + self._unitnormal = None + self._area_element = None + self._xm = None + self._xn = None + self._angles = None + + # reset_mesh method + def reset_mesh(self): + self._quadpoints_theta = None + self._quadpoints_phi = None + self._theta2d = None + self._phi2d = None + self._angles = None + + # rc property and setter + @property + def rc(self): + return self._rc + + @rc.setter + def rc(self, new_rc): + self._rc = new_rc + self.reset_cache() + + # zs property and setter + @property + def zs(self): + return self._zs + + @zs.setter + def zs(self, new_zs): + self._zs = new_zs + self.reset_cache() + + # nfp property + @property + def nfp(self): + return self._nfp + + # mpol property + @property + def mpol(self): + return self._mpol + + # ntor property + @property + def ntor(self): + return self._ntor + + # xm property + @property + def xm(self): + if self._xm is None: + self._xm = jnp.repeat(jnp.arange(self.mpol + 1), 2 * self.ntor + 1)[self.ntor:] + return self._xm + + # xn property + @property + def xn(self): + if self._xn is None: + self._xn = self.nfp * jnp.tile(jnp.arange(-self.ntor, self.ntor + 1), self.mpol + 1)[self.ntor:] + return self._xn + + # _ntheta property and setter + @property + def ntheta(self): + return self._ntheta + + @ntheta.setter + def ntheta(self, new_ntheta): + self._ntheta = new_ntheta + self.reset_mesh() + + # n_phi property and setter + @property + def nphi(self): + return self._nphi + + @nphi.setter + def nphi(self, new_nphi): + self._nphi = new_nphi + self.reset_mesh() + + # close property and setter + @property + def close(self): + return self._close + + @close.setter + def close(self, new_close): + self._close = new_close + self.reset_mesh() + # range_torus property and setter + @property + def range_torus(self): + return self._range_torus + + @range_torus.setter + def range_torus(self, new_range): + self._range_torus = new_range + self.reset_mesh() + + # _compute_meshgrid method + @jit + def _compute_meshgrid(self): + if self.range_torus == "full torus": + div, end_val = 1., 1. + elif self.range_torus == "half period": + div, end_val = self.nfp, 0.5 + quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=self.ntheta, endpoint=self.close) + quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=self.nphi, endpoint=self.close) + theta2d, phi2d = jnp.meshgrid(quadpoints_theta, quadpoints_phi) + return quadpoints_theta, quadpoints_phi, theta2d, phi2d + + # theta2d property + @property + def theta2d(self): + if self._theta2d is None: + self._quadpoints_theta, self._quadpoints_phi, self._theta2d, self._phi2d = self._compute_meshgrid() + return self._theta2d + + # phi2d property + @property + def phi2d(self): + if self._phi2d is None: + self._quadpoints_theta, self._quadpoints_phi, self._theta2d, self._phi2d = self._compute_meshgrid() + return self._phi2d + + # angles property + @property + def angles(self): + if self._angles is None: + self._angles = jnp.einsum('i,jk->ijk', self.xm, self.theta2d) - jnp.einsum('i,jk->ijk', self.xn, self.phi2d) + return self._angles + + # scaling_type property and setter + @property + def scaling_type(self): + return self._scaling_type + + @scaling_type.setter + def scaling_type(self, new_type): + self._scaling_type = new_type + self._scaling = None + + # scaling_factor property and setter + @property + def scaling_factor(self): + return self._scaling_factor + + @scaling_factor.setter + def scaling_factor(self, new_factor): + self._scaling_factor = new_factor + self._scaling = None + + # scaling property + @property + def scaling(self): + if self._scaling is None: + self._scaling = jnp.exp(self.scaling_factor * jnp.linalg.norm(jnp.vstack([self.xm, self.xn]), ord=self.scaling_type, axis=0)) + return self._scaling + + # dofs property and setter @property def dofs(self): - return self._dofs + return jnp.hstack([self.rc * self.scaling, self.zs * self.scaling]) @dofs.setter def dofs(self, new_dofs): - self._dofs = new_dofs - self.rc = jnp.concatenate((jnp.zeros(self.ntor),new_dofs[:self.num_dofs_rc])).reshape(self.rc.shape) - self.zs = jnp.concatenate((jnp.zeros(self.ntor),new_dofs[self.num_dofs_rc:])).reshape(self.zs.shape) - indices = jnp.array([self.xm, self.xn / self.nfp + self.ntor], dtype=int).T - self.rmnc_interp = self.rc[indices[:, 0], indices[:, 1]] - self.zmns_interp = self.zs[indices[:, 0], indices[:, 1]] - (self._gamma, self._gammadash_theta, self._gammadash_phi, - self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) - # if hasattr(self, 'bmnc'): - # self._AbsB = self._set_AbsB() + self._rc = new_dofs[:self.rc.size] / self.scaling + self._zs = new_dofs[self.rc.size:] / self.scaling + self.reset_cache() - @partial(jit, static_argnames=['self']) - def _set_gamma(self, rmnc_interp, zmns_interp): - phi_2d = self.phi_2d + # _compute_gamma method + @jit + def _compute_gamma(self): angles = self.angles - + print(angles.shape) sin_angles = jnp.sin(angles) cos_angles = jnp.cos(angles) - r_coordinate = jnp.einsum('i,ijk->jk', rmnc_interp, cos_angles) - z_coordinate = jnp.einsum('i,ijk->jk', zmns_interp, sin_angles) - gamma = jnp.transpose(jnp.array([r_coordinate * jnp.cos(phi_2d), r_coordinate * jnp.sin(phi_2d), z_coordinate]), (1, 2, 0)) + phi2d = self.phi2d + sin_phi2d = jnp.sin(phi2d) + cos_phi2d = jnp.cos(phi2d) + rc = self.rc; zs = self.zs; xm = self.xm; xn = self.xn - dX_dtheta = jnp.einsum('i,ijk,i->jk', -self.xm, sin_angles, rmnc_interp) * jnp.cos(phi_2d) - dY_dtheta = jnp.einsum('i,ijk,i->jk', -self.xm, sin_angles, rmnc_interp) * jnp.sin(phi_2d) - dZ_dtheta = jnp.einsum('i,ijk,i->jk', self.xm, cos_angles, zmns_interp) - gammadash_theta = 2*jnp.pi*jnp.transpose(jnp.array([dX_dtheta, dY_dtheta, dZ_dtheta]), (1, 2, 0)) + print(rc.shape, cos_angles.shape) + R = jnp.einsum('i,ijk->jk', rc, cos_angles) + Z = jnp.einsum('i,ijk->jk', zs, sin_angles) + X = R * cos_phi2d + Y = R * sin_phi2d + gamma = jnp.stack([X, Y, Z], axis=-1) - dX_dphi = jnp.einsum('i,ijk,i->jk', self.xn, sin_angles, rmnc_interp) * jnp.cos(phi_2d) - r_coordinate * jnp.sin(phi_2d) - dY_dphi = jnp.einsum('i,ijk,i->jk', self.xn, sin_angles, rmnc_interp) * jnp.sin(phi_2d) + r_coordinate * jnp.cos(phi_2d) - dZ_dphi = jnp.einsum('i,ijk,i->jk', -self.xn, cos_angles, zmns_interp) - gammadash_phi = 2*jnp.pi*jnp.transpose(jnp.array([dX_dphi, dY_dphi, dZ_dphi]), (1, 2, 0)) + dR_dtheta = -jnp.einsum('i,ijk->jk', xm * rc, sin_angles) + dZ_dtheta = jnp.einsum('i,ijk->jk', xm * zs, cos_angles) + dX_dtheta = dR_dtheta * cos_phi2d + dY_dtheta = dR_dtheta * sin_phi2d + gammadash_theta = jnp.stack([dX_dtheta, dY_dtheta, dZ_dtheta], axis=-1) - normal = jnp.cross(gammadash_phi, gammadash_theta, axis=2) - unitnormal = normal / jnp.linalg.norm(normal, axis=2, keepdims=True) + dR_dphi = jnp.einsum('i,ijk->jk', xn*rc, sin_angles) + dZ_dphi = -jnp.einsum('i,ijk->jk', xn*zs, cos_angles) + dX_dphi = dR_dphi * cos_phi2d - R * sin_phi2d + dY_dphi = dR_dphi * sin_phi2d + R * cos_phi2d + gammadash_phi = jnp.stack([dX_dphi, dY_dphi, dZ_dphi], axis=-1) - return (gamma, gammadash_theta, gammadash_phi, normal, unitnormal) - - @partial(jit, static_argnames=['self']) - def _set_AbsB(self): - angles_nyq = jnp.einsum('i,jk->ijk', self.xm_nyq, self.theta_2d) - jnp.einsum('i,jk->ijk', self.xn_nyq, self.phi_2d) - AbsB = jnp.einsum('i,ijk->jk', self.bmnc_interp, jnp.cos(angles_nyq)) - return AbsB + return gamma, gammadash_theta, gammadash_phi + # gamma, gammadash_theta, gammadash_phi properties @property def gamma(self): + if self._gamma is None: + self._gamma, self._gammadash_theta, self._gammadash_phi = self._compute_gamma() return self._gamma @property def gammadash_theta(self): + if self._gammadash_theta is None: + self._gamma, self._gammadash_theta, self._gammadash_phi = self._compute_gamma() return self._gammadash_theta @property def gammadash_phi(self): + if self._gammadash_phi is None: + self._gamma, self._gammadash_theta, self._gammadash_phi = self._compute_gamma() return self._gammadash_phi + + # _compute_properties method + @jit + def _compute_properties(self): + normal = jnp.cross(self.gammadash_theta, self.gammadash_phi, axis=2) + unitnormal = normal / jnp.linalg.norm(normal, axis=2, keepdims=True) + area_element = jnp.linalg.norm(normal, axis=2) + return normal, unitnormal, area_element + # normal, unitnormal, area_element properties @property def normal(self): + if self._normal is None: + self._normal, self._unitnormal, self._area_element = self._compute_properties() return self._normal @property def unitnormal(self): + if self._unitnormal is None: + self._normal, self._unitnormal, self._area_element = self._compute_properties() return self._unitnormal @property - def AbsB(self): - return self._AbsB - + def area_element(self): + if self._area_element is None: + self._normal, self._unitnormal, self._area_element = self._compute_properties() + return self._area_element + + # TODO: remove x property. This is a placeholder for compatibility with the examples that need to be updated. + # x property and setter @property def x(self): return self.dofs @@ -235,7 +472,101 @@ def x(self): @x.setter def x(self, new_dofs): self.dofs = new_dofs - + + @property + def volume(self): + + xyz = self.gamma # shape: (nphi, ntheta, 3) + n = self.normal # shape: (nphi, ntheta, 3) + + integrand = jnp.sum(xyz * n, axis=2) # dot(x, n), shape: (nphi, ntheta) + volume = jnp.mean(integrand) / 3.0 + return volume + + @property + def area(self): + #n = self.normal # (nphi, ntheta, 3) + #norm_n = jnp.linalg.norm(n, axis=2) # shape: (nphi, ntheta) + #avg_area = jnp.mean(norm_n) + #return avg_area + n = self.normal # shape: (nphi, ntheta, 3) + norm_n = jnp.linalg.norm(n, axis=2) + + dphi = 2 * jnp.pi / self.nphi + dtheta = 2 * jnp.pi / self.ntheta + + area = jnp.sum(norm_n) * dphi * dtheta + return area + + # def change_resolution(self, mpol: int, ntor: int, ntheta=None, nphi=None,close=True): + # """ + # Change the values of `mpol` and `ntor`. + # New Fourier coefficients are zero by default. + # Old coefficients outside the new range are discarded. + # """ + # rc_old, zs_old = self.rc, self.zs + # mpol_old, ntor_old = self.mpol, self.ntor + # if ntheta is not None: + # self.ntheta = ntheta + # else: + # ntheta = self.ntheta + + # if nphi is not None: + # self.nphi = nphi + # else: + # nphi = self.nphi + + # #rc_new = jnp.zeros((mpol, 2 * ntor + 1)) + # #zs_new = jnp.zeros((mpol, 2 * ntor + 1)) + # rc_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) + # zs_new = jnp.zeros(((mpol+1)*( 2 * ntor + 1)-ntor)) + # m_keep = min(mpol_old, mpol) + # n_keep = min(ntor_old, ntor) + + # xm_old=self.xm + # xn_old=self.xn + # self.xm = jnp.repeat(jnp.arange(mpol+1), 2*ntor+1)[ntor:] + # self.xn = self.nfp*jnp.tile(jnp.arange(-ntor, ntor + 1), mpol+1)[ntor:] + # # Copy overlapping region + # for l in range(len(self.xm)): + # if self.xm[l]<=m_keep and jnp.abs(self.xn[l]/self.nfp)<=n_keep: + # index=self.xm[l]*(ntor_old*2+1)-self.xn[l]//self.nfp + # rc_new=rc_new.at[l].set(self.rc[index]) + # zs_new=zs_new.at[l].set(self.zs[index]) + + + # # Update attributes + # self.mpol, self.ntor = mpol, ntor + # self.rc, self.zs = rc_new, zs_new + + # self.rmnc_interp = self.rc + # self.zmns_interp = self.zs + + # # Update degrees of freedom + # self.num_dofs_rc = len(jnp.ravel(self.rc)) + # self.num_dofs_zs = len(jnp.ravel(self.zs)) + # self._dofs = jnp.concatenate((self.rescaling_function(jnp.ravel(self.rc)), self.rescaling_function(jnp.ravel(self.zs)))) + + # # Recompute angles and geometry + # if self.range_torus == 'full torus': div = 1 + # else: div = self.nfp + # if self.range_torus == 'half period': end_val = 0.5 + # else: end_val = 1.0 + # self.quadpoints_theta = jnp.linspace(0, 2 * jnp.pi, num=ntheta, endpoint=True if close else False) + # self.quadpoints_phi = jnp.linspace(0, 2 * jnp.pi * end_val / div, num=nphi, endpoint=True if close else False) + # self.theta_2d, self.phi_2d = jnp.meshgrid(self.quadpoints_theta, self.quadpoints_phi) + + # self.angles = (jnp.einsum('i,jk->ijk', self.xm, self.theta_2d)- jnp.einsum('i,jk->ijk', self.xn, self.phi_2d)) + # (self._gamma, self._gammadash_theta, self._gammadash_phi, + # self._normal, self._unitnormal) = self._set_gamma(self.rmnc_interp, self.zmns_interp) + + + # # Recompute AbsB if available + # if hasattr(self, 'bmnc'): + # self._AbsB = self._set_AbsB() + + # return self + def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs): if close: raise NotImplementedError("Call close=True when instantiating the VMEC/SurfaceRZFourier object.") @@ -246,7 +577,7 @@ def plot(self, ax=None, show=True, close=False, axis_equal=True, **kwargs): if ax is None or ax.name != "3d": fig = plt.figure() ax = fig.add_subplot(projection='3d') - + boundary = self.gamma if hasattr(self, 'bmnc'): @@ -299,15 +630,11 @@ def to_vmec(self, filename): nml += 'LASYM = .FALSE.\n' nml += f'NFP = {self.nfp}\n' - for m in range(self.mpol + 1): - nmin = -self.ntor - if m == 0: - nmin = 0 - for n in range(nmin, self.ntor + 1): - rc = self.rc[m, n + self.ntor] - zs = self.zs[m, n + self.ntor] - if jnp.abs(rc) > 0 or jnp.abs(zs) > 0: - nml += f"RBC({n:4d},{m:4d}) ={rc:23.15e}, ZBS({n:4d},{m:4d}) ={zs:23.15e}\n" + # Copy overlapping region + for l in range(len(self.xm)): + rc = self.rc[l] + zs = self.zs[l] + nml += f"RBC({self.xn[l]:4d},{self.xm[l]:4d}) ={rc:23.15e}, ZBS({self.xn[l]:4d},{self.xm[l]:4d}) ={zs:23.15e}\n" nml += '/\n' with open(filename, 'w') as f: @@ -329,6 +656,27 @@ def mean_cross_sectional_area(self): mean_cross_sectional_area = jnp.abs(jnp.mean(jnp.sqrt(x2y2) * dZ_dtheta * detJ))/(2 * jnp.pi) return mean_cross_sectional_area + def _tree_flatten(self): + children = (self._rc, self._zs) # arrays / dynamic values + aux_data = {"nfp": self._nfp, + "mpol": self._mpol, + "ntor": self._ntor, + "ntheta": self._ntheta, + "nphi": self._nphi, + "close": self._close, + "range_torus": self._range_torus, + "scaling_type": self._scaling_type, + "scaling_factor": self._scaling_factor} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) + +tree_util.register_pytree_node(SurfaceRZFourier, + SurfaceRZFourier._tree_flatten, + SurfaceRZFourier._tree_unflatten) + #This class is based on simsopt classifier but translated to fit jax class SurfaceClassifier(): """ @@ -454,3 +802,11 @@ def signed_distance_from_surface_extras(xyz, surface): +def plot_scalar_on_flux_surface(surface, scalar_map): + ''' + surface: the surface object in which to plot the scalar_map + scalar_map: a scalar_map as function of theta and phi + ''' + + + diff --git a/examples/optimize_coils_and_nearaxis.py b/examples/coil_optimization/optimize_coils_and_nearaxis.py similarity index 100% rename from examples/optimize_coils_and_nearaxis.py rename to examples/coil_optimization/optimize_coils_and_nearaxis.py diff --git a/examples/optimize_coils_and_surface.py b/examples/coil_optimization/optimize_coils_and_surface.py similarity index 94% rename from examples/optimize_coils_and_surface.py rename to examples/coil_optimization/optimize_coils_and_surface.py index 1bef5b83..2946903f 100644 --- a/examples/optimize_coils_and_surface.py +++ b/examples/coil_optimization/optimize_coils_and_surface.py @@ -20,8 +20,10 @@ ntheta=30 nphi=30 -input = os.path.join('input_files','input.rotating_ellipse') -surface_initial = SurfaceRZFourier(input, ntheta=ntheta, nphi=nphi, range_torus='half period') +mpol=2 +ntor=2 +input = os.path.join(os.path.dirname(__file__), 'input_files','input.rotating_ellipse') +surface_initial = SurfaceRZFourier(input, ntheta=ntheta, nphi=nphi, range_torus='half period', mpol=mpol, ntor=ntor) # Optimization parameters max_coil_length = 38 @@ -122,18 +124,18 @@ def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves, currents n_segments=60, stellsym=True, max_coil_curvature=0.5, target_B_on_surface=5.7): field=field_from_dofs(x[:-len(surface_all.x)-len(field_nearaxis.x)] ,dofs_curves=dofs_curves, currents_scale=currents_scale, nfp=nfp,n_segments=n_segments, stellsym=stellsym) - surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta) + surface = SurfaceRZFourier(rc=surface_all.rc, zs=surface_all.zs, nfp=nfp, range_torus=surface_all.range_torus, nphi=surface_all.nphi, ntheta=surface_all.ntheta,mpol=surface_all.mpol,ntor=surface_all.ntor) surface.dofs = x[-len(surface_all.x)-len(field_nearaxis.x):-len(field_nearaxis.x)] field_nearaxis = new_nearaxis_from_x_and_old_nearaxis(x[-len(field_nearaxis.x):], field_nearaxis) + + coil_length = field.coils.length + coil_curvature = field.coils.curvature + - coil_length = loss_coil_length(x[:-len(surface_all.x)-len(field_nearaxis.x)],dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_length=max_coil_length) - coil_curvature = loss_coil_curvature(x[:-len(surface_all.x)-len(field_nearaxis.x)],dofs_curves=dofs_curves,currents_scale=currents_scale,nfp=nfp,n_segments=n_segments,stellsym=stellsym,max_coil_curvature=max_coil_curvature) - - - coil_length_loss = 1e3*jnp.max(jnp.concatenate([coil_length-max_coil_length,jnp.array([0])])) - coil_curvature_loss = 1e3*jnp.max(jnp.concatenate([coil_curvature-max_coil_curvature,jnp.array([0])])) - + coil_length_loss = 1e3*jnp.max(jnp.maximum(0, coil_length - max_coil_length)) + coil_curvature_loss = 1e3*jnp.max(jnp.maximum(0, coil_curvature - max_coil_curvature)) + normal_cross_GradB_dot_grad_B_dot_GradB_surface = jnp.sum(jnp.abs(loss_normal_cross_GradB_dot_grad_B_dot_GradB_surface(surface, field))) bdotn_over_b = BdotN_over_B(surface, field) @@ -233,7 +235,6 @@ def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves, currents # tracing_optimized.plot(ax=ax2, show=False) plt.tight_layout() plt.show() - # Save the surface to a VMEC file surface_optimized.to_vmec('input.optimized') @@ -244,6 +245,7 @@ def loss_coils_and_surface(x, surface_all, field_nearaxis, dofs_curves, currents surface_optimized.to_vtk('optimized_surface', field=BiotSavart(coils_optimized)) coils_optimized.to_vtk('optimized_coils') field_nearaxis_optimized.to_vtk('optimized_field_nearaxis', r=major_radius_coils/12, field=BiotSavart(coils_optimized)) + # tracing_initial.to_vtk('initial_tracing') # tracing_optimized.to_vtk('optimized_tracing') diff --git a/examples/optimize_coils_for_nearaxis.py b/examples/coil_optimization/optimize_coils_for_nearaxis.py similarity index 100% rename from examples/optimize_coils_for_nearaxis.py rename to examples/coil_optimization/optimize_coils_for_nearaxis.py diff --git a/examples/optimize_coils_particle_confinement_fullorbit.py b/examples/coil_optimization/optimize_coils_particle_confinement_fullorbit.py similarity index 78% rename from examples/optimize_coils_particle_confinement_fullorbit.py rename to examples/coil_optimization/optimize_coils_particle_confinement_fullorbit.py index e49fe27f..c7a7c756 100644 --- a/examples/optimize_coils_particle_confinement_fullorbit.py +++ b/examples/coil_optimization/optimize_coils_particle_confinement_fullorbit.py @@ -6,7 +6,6 @@ import jax.numpy as jnp import matplotlib.pyplot as plt from essos.dynamics import Particles, Tracing -from essos.fields import BiotSavart from essos.coils import Coils, CreateEquallySpacedCurves from essos.optimization import optimize_loss_function from essos.objective_functions import loss_optimize_coils_for_particle_confinement @@ -51,11 +50,12 @@ # Optimize coils print(f'Optimizing coils with {maximum_function_evaluations} function evaluations and maxtime_tracing={maxtime_tracing}') time0 = time() -coils_optimized = optimize_loss_function(loss_optimize_coils_for_particle_confinement, initial_dofs=coils_initial.x, - coils=coils_initial, tolerance_optimization=1e-4, particles=particles, - maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature, - target_B_on_axis=target_B_on_axis, max_coil_length=max_coil_length, model=model, - maxtime=maxtime_tracing, num_steps=timesteps) +coils_optimized = optimize_loss_function(loss_optimize_coils_for_particle_confinement, initial_dofs=coils_initial.x, coils=coils_initial, + tolerance_optimization=1e-4, particles=particles, maximum_function_evaluations=maximum_function_evaluations, + max_coil_curvature=max_coil_curvature, target_B_on_axis=target_B_on_axis, max_coil_length=max_coil_length, + model=model, maxtime=maxtime_tracing, num_steps=500, trace_tolerance=1e-5) +# coils_optimized = optimize_coils_for_particle_confinement(coils_initial, particles, target_B_on_axis=target_B_on_axis, maxtime=maxtime_tracing, model=model, +# max_coil_length=max_coil_length, maximum_function_evaluations=maximum_function_evaluations, max_coil_curvature=max_coil_curvature) print(f" Optimization took {time()-time0:.2f} seconds") particles.to_full_orbit(BiotSavart(coils_optimized)) @@ -75,14 +75,13 @@ coils_initial.plot(ax=ax1, show=False) tracing_initial.plot(ax=ax1, show=False) for i, trajectory in enumerate(tracing_initial.trajectories): - ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}', linewidth=0.2) + ax3.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') ax3.set_xlabel('R (m)');ax3.set_ylabel('Z (m)');#ax3.legend() coils_optimized.plot(ax=ax2, show=False) tracing_optimized.plot(ax=ax2, show=False) -# for i, trajectory in enumerate(tracing_optimized.trajectories): -# ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}', linewidth=0.2) -# ax4.set_xlabel('R (m)');ax4.set_ylabel('Z (m)');#ax4.legend() -plotting_data = tracing_optimized.poincare_plot(ax=ax4, shifts = [jnp.pi/4, jnp.pi/2, 3*jnp.pi/4], show=False) +for i, trajectory in enumerate(tracing_optimized.trajectories): + ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') +ax4.set_xlabel('R (m)');ax4.set_ylabel('Z (m)');#ax4.legend() plt.tight_layout() plt.show() diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_adam.py b/examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_adam.py similarity index 100% rename from examples/optimize_coils_particle_confinement_guidingcenter_adam.py rename to examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_adam.py diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py b/examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py similarity index 100% rename from examples/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py rename to examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_augmented_lagrangian.py diff --git a/examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py b/examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_lbfgs.py similarity index 100% rename from examples/optimize_coils_particle_confinement_guidingcenter_lbfgs.py rename to examples/coil_optimization/optimize_coils_particle_confinement_guidingcenter_lbfgs.py diff --git a/examples/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py b/examples/coil_optimization/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py similarity index 100% rename from examples/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py rename to examples/coil_optimization/optimize_coils_particle_confinement_loss_fraction_augmented_lagrangian.py diff --git a/examples/coil_optimization/optimize_coils_vmec_surface.py b/examples/coil_optimization/optimize_coils_vmec_surface.py new file mode 100644 index 00000000..b10aab68 --- /dev/null +++ b/examples/coil_optimization/optimize_coils_vmec_surface.py @@ -0,0 +1,86 @@ +import os +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt + +from essos.coils import Coils, CreateEquallySpacedCurves +from essos.fields import BiotSavart +from essos.surfaces import SurfaceRZFourier, BdotN_over_B +from essos.losses import custom_loss + +# In this exmple, `scipy.optimize.least_squares` is used, but any other optimizer, e.g. from +# `scipy.optimize.minimize` or `jaxopt`, can be used as well and may even be preferable. +from scipy.optimize import least_squares + +input_filepath = os.path.join(os.path.dirname(__file__), "input_files") +vmec_input = os.path.join(input_filepath, 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc') + +""" Creating starting coils and surface """ +N_COILS = 3; FOURIER_ORDER = 3; LARGE_R = 10; SMALL_R = 5.6; NFP = 2; N_SEGMENTS = 45; STELLSYM = True # Curve parameters +COIL_CURRENT = 1. # Amperes (optimization does not depend on current magnitude) + +init_curves = CreateEquallySpacedCurves(N_COILS, FOURIER_ORDER, LARGE_R, SMALL_R, n_segments=N_SEGMENTS, nfp=NFP, stellsym=STELLSYM) +init_coils = Coils(curves=init_curves, currents=[COIL_CURRENT]*N_COILS) +init_field = BiotSavart(init_coils) +surface = SurfaceRZFourier.from_wout_file(vmec_input, s=1, ntheta=30, nphi=30, range_torus='half period') + +""" Setting the losses weights and targets """ +LENGTH_WEIGHT = 1.; LENGTH_TARGET = 32. +CURVATURE_WEIGHT = 1.; CURVATURE_TARGET = 0.1 +NORMAL_FIELD_WEIGHT = 1. + +""" Creating the loss functions """ +def loss(field, surface): + return jnp.sum(jnp.abs(BdotN_over_B(surface, field))) + +def loss_length(field): + return jnp.mean(jnp.maximum(0, field.coils.length - LENGTH_TARGET)) + +def loss_curvature(field): + return jnp.mean(jnp.maximum(0, field.coils.curvature - CURVATURE_TARGET)) + +""" Defining custom losses """ +L_normal_field = custom_loss(loss, "field", surface=surface) +L_length = custom_loss(loss_length, "field") +L_curvature = custom_loss(loss_curvature, "field") + +""" Defining total loss + setting dependencies """ +L_total = NORMAL_FIELD_WEIGHT*L_normal_field + LENGTH_WEIGHT*L_length + CURVATURE_WEIGHT*L_curvature +L_total.dependencies = {"field": init_field} + +""" Optimizing the total loss """ +t_start = time() +res = least_squares(L_total, L_total.starting_dofs, L_total.grad, verbose=2, ftol=1e-5, gtol=1e-5, xtol=1e-14, max_nfev=200) +t_end = time() + +print(f"\nOptimization took {t_end - t_start:.2f} seconds") +print("Initial loss:", L_total(L_total.starting_dofs)) +print("Loss after optimization:", L_total(res.x)) + +opt_field = L_total.dofs_to_pytree(res.x)["field"] +opt_coils = opt_field.coils + +fig = plt.figure(figsize=(8, 4)) + +ax1 = fig.add_subplot(121, projection='3d') +init_coils.plot(ax=ax1, show=False) +surface.plot(ax=ax1, show=False) +ax2 = fig.add_subplot(122, projection='3d') +opt_coils.plot(ax=ax2, show=False) +surface.plot(ax=ax2, show=False) +plt.tight_layout() +plt.show() + +EXPORT = False +if EXPORT: + output_filepath = os.path.join(os.path.dirname(__file__), "output") + + """ Save the coils to a json file """ + init_coils.to_json(os.path.join(output_filepath, "init_coils_vmec_surface.json")) + opt_coils.to_json(os.path.join(output_filepath, "opt_coils_vmec_surface.json")) + + """ Save results in vtk format to analyze in Paraview """ + surface.to_vtk(os.path.join(output_filepath, "init_surface_vmec_surface.json"), field=init_field) + surface.to_vtk(os.path.join(output_filepath, "final_surface_vmec_surface.json"), field=opt_field) + init_coils.to_vtk(os.path.join(output_filepath, "init_coils_vmec_surface.json")) + opt_coils.to_vtk(os.path.join(output_filepath, "opt_coils_vmec_surface.json")) \ No newline at end of file diff --git a/examples/optimize_coils_vmec_surface_augmented_lagrangian.py b/examples/coil_optimization/optimize_coils_vmec_surface_augmented_lagrangian.py similarity index 100% rename from examples/optimize_coils_vmec_surface_augmented_lagrangian.py rename to examples/coil_optimization/optimize_coils_vmec_surface_augmented_lagrangian.py diff --git a/examples/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py b/examples/coil_optimization/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py similarity index 100% rename from examples/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py rename to examples/coil_optimization/optimize_coils_vmec_surface_augmented_lagrangian_stochastic.py diff --git a/examples/optimize_multiple_objectives.py b/examples/coil_optimization/optimize_multiple_objectives.py similarity index 100% rename from examples/optimize_multiple_objectives.py rename to examples/coil_optimization/optimize_multiple_objectives.py diff --git a/examples/compare_guidingcenter_fullorbit.py b/examples/compare_guidingcenter_fullorbit.py deleted file mode 100644 index 27a3b518..00000000 --- a/examples/compare_guidingcenter_fullorbit.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -number_of_processors_to_use = 1 # Parallelization, this should divide nparticles -os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' -from jax import vmap -from time import time -import jax.numpy as jnp -import matplotlib.pyplot as plt -from essos.fields import BiotSavart -from essos.coils import Coils_from_json -from essos.constants import PROTON_MASS, ONE_EV -from essos.dynamics import Tracing, Particles -from jax import block_until_ready - -# Input parameters -tmax = 1.e-4 -dt_fo=1.e-9 -nparticles_per_core=2 -nparticles = number_of_processors_to_use*nparticles_per_core -R0 = jnp.linspace(1.23, 1.27, nparticles) -trace_tolerance = 1e-5 -num_steps_gc = 5000 -num_steps_fo = int(tmax/dt_fo) -mass=PROTON_MASS -energy=5000*ONE_EV - -# Load coils and field -json_file = os.path.join(os.path.dirname(__file__), 'input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) -field = BiotSavart(coils) - -# Initialize particles -Z0 = jnp.zeros(nparticles) -phi0 = jnp.zeros(nparticles) -initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T -initial_vparallel_over_v = jnp.linspace(-0.1, 0.1, nparticles) -particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, field=field, initial_vparallel_over_v=initial_vparallel_over_v) - -# Trace in ESSOS -time0 = time() -tracing_guidingcenter = Tracing(field=field, model='GuidingCenterAdaptative', particles=particles, - maxtime=tmax,times_to_trace=num_steps_gc, atol=trace_tolerance,rtol=trace_tolerance) -trajectories_guidingcenter = block_until_ready(tracing_guidingcenter.trajectories) -print(f"ESSOS guiding center tracing took {time()-time0:.2f} seconds") - -time0 = time() -tracing_fullorbit = Tracing(field=field, model='FullOrbit_Boris', particles=particles, - maxtime=tmax, times_to_trace=num_steps_fo,timestep=dt_fo) -trajectories_fullorbit = block_until_ready(tracing_fullorbit.trajectories) -print(f"ESSOS full orbit tracing took {time()-time0:.2f} seconds") - -# Plot trajectories, velocity parallel to the magnetic field, and energy error -fig = plt.figure(figsize=(9, 8)) -ax1 = fig.add_subplot(221, projection='3d') -ax2 = fig.add_subplot(222) -ax3 = fig.add_subplot(223) -ax4 = fig.add_subplot(224) - -coils.plot(ax=ax1, show=False) -tracing_guidingcenter.plot(ax=ax1, show=False) -tracing_fullorbit.plot(ax=ax1, show=False) - -for i, (trajectory_gc, trajectory_fo) in enumerate(zip(trajectories_guidingcenter, trajectories_fullorbit)): - ax2.plot(tracing_guidingcenter.times, jnp.abs(tracing_guidingcenter.energy[i]-particles.energy)/particles.energy, '-', label=f'Particle {i+1} GC', linewidth=1.0, alpha=0.7) - ax2.plot(tracing_fullorbit.times, jnp.abs(tracing_fullorbit.energy[i]-particles.energy)/particles.energy, '--', label=f'Particle {i+1} FO', linewidth=1.0, markersize=0.5, alpha=0.7) - def compute_v_parallel(trajectory_t): - magnetic_field_unit_vector = field.B(trajectory_t[:3]) / field.AbsB(trajectory_t[:3]) - return jnp.dot(trajectory_t[3:], magnetic_field_unit_vector) - v_parallel_fo = vmap(compute_v_parallel)(trajectory_fo) - ax3.plot(tracing_guidingcenter.times, trajectory_gc[:, 3] / particles.total_speed, '-', label=f'Particle {i+1} GC', linewidth=1.1, alpha=0.95) - ax3.plot(tracing_fullorbit.times, v_parallel_fo / particles.total_speed, '--', label=f'Particle {i+1} FO', linewidth=0.5, markersize=0.5, alpha=0.2) - # ax4.plot(jnp.sqrt(trajectory_gc[:,0]**2+trajectory_gc[:,1]**2), trajectory_gc[:, 2], '-', label=f'Particle {i+1} GC', linewidth=1.5, alpha=0.3) - # ax4.plot(jnp.sqrt(trajectory_fo[:,0]**2+trajectory_fo[:,1]**2), trajectory_fo[:, 2], '--', label=f'Particle {i+1} FO', linewidth=1.5, markersize=0.5, alpha=0.2) -tracing_guidingcenter.poincare_plot(ax=ax4, show=False, color='k', label=f'GC', shifts=[jnp.pi/2])#, 0]) -tracing_fullorbit.poincare_plot( ax=ax4, show=False, color='r', label=f'FO', shifts=[jnp.pi/2])#, 0]) - -ax2.set_xlabel('Time (s)') -ax2.set_ylabel('Relative Energy Error') -ax3.set_ylabel(r'$v_{\parallel}/v$') -ax2.legend(loc='upper right') -ax3.set_xlabel('Time (s)') -ax3.legend(loc='upper right') -ax4.set_xlabel('R (m)') -ax4.set_ylabel('Z (m)') -ax4.legend(loc='upper right') -plt.tight_layout() -plt.show() - - -## Save results in vtk format to analyze in Paraview -# tracing.to_vtk('trajectories') -# coils.to_vtk('coils') \ No newline at end of file diff --git a/examples/comparisons_SIMSOPT/coils_biotsavart_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/coils_biotsavart_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index c89fc99f..00000000 --- a/examples/comparisons_SIMSOPT/coils_biotsavart_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,260 +0,0 @@ -import os -from time import time -import jax.numpy as jnp -import matplotlib.pyplot as plt -from jax import block_until_ready -from essos.fields import BiotSavart as BiotSavart_essos -from essos.coils import Coils_from_simsopt, Curves_from_simsopt -from simsopt import load -from simsopt.geo import CurveXYZFourier, curves_to_vtk -from simsopt.field import BiotSavart as BiotSavart_simsopt, coils_via_symmetries -from simsopt.configs import get_ncsx_data, get_w7x_data, get_hsx_data, get_giuliani_data - -output_dir = os.path.join(os.path.dirname(__file__), 'output') -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -list_segments = [30, 100, 300, 1000, 3000] - -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') -nfp_array = [3, 2, 5, 4, 2] -curves_array = [get_ncsx_data()[0], LandremanPaulQA_json_file, get_w7x_data()[0], get_hsx_data()[0], get_giuliani_data()[0]] -currents_array = [get_ncsx_data()[1], None, get_w7x_data()[1], get_hsx_data()[1], get_giuliani_data()[1]] -name_array = ["NCSX", "QA(json)", "W7-X", "HSX", "Giuliani"] - -print(f'Output being saved to {output_dir}') -print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') -for nfp, curves_stel, currents_stel, name in zip(nfp_array, curves_array, currents_array, name_array): - print(f' Running {name} and saving to output directory...') - if currents_stel is None: - json_file_stel = curves_stel - field_simsopt = load(json_file_stel) - coils_simsopt = field_simsopt.coils - curves_simsopt = [coil.curve for coil in coils_simsopt] - currents_simsopt = [coil.current for coil in coils_simsopt] - coils_essos = Coils_from_simsopt(json_file_stel, nfp) - curves_essos = Curves_from_simsopt(json_file_stel, nfp) - else: - coils_simsopt = coils_via_symmetries(curves_stel, currents_stel, nfp, True) - curves_simsopt = [c.curve for c in coils_simsopt] - currents_simsopt = [c.current for c in coils_simsopt] - field_simsopt = BiotSavart_simsopt(coils_simsopt) - - coils_essos = Coils_from_simsopt(coils_simsopt, nfp) - curves_essos = Curves_from_simsopt(curves_simsopt, nfp) - - field_essos = BiotSavart_essos(coils_essos) - - coils_essos_to_simsopt = coils_essos.to_simsopt() - curves_essos_to_simsopt = curves_essos.to_simsopt() - field_essos_to_simsopt = BiotSavart_simsopt(coils_essos_to_simsopt) - - curves_to_vtk(curves_simsopt, os.path.join(output_dir,f"curves_simsopt_{name}")) - curves_essos.to_vtk(os.path.join(output_dir,f"curves_essos_{name}")) - curves_to_vtk(curves_essos_to_simsopt, os.path.join(output_dir,f"curves_essos_to_simsopt_{name}")) - - base_coils_simsopt = coils_simsopt[:int(len(coils_simsopt)/2/nfp)] - R = jnp.mean(jnp.array([jnp.sqrt(coil.curve.x[coil.curve.local_dof_names.index('xc(0)')]**2 - +coil.curve.x[coil.curve.local_dof_names.index('yc(0)')]**2) - for coil in base_coils_simsopt])) - x = jnp.array([R+0.01,R,R]) - y = jnp.array([R,R+0.01,R-0.01]) - z = jnp.array([0.05,0.06,0.07]) - - positions = jnp.array((x,y,z)) - - len_list_segments = len(list_segments) - t_gamma_avg_essos = jnp.zeros(len_list_segments) - t_gamma_avg_simsopt = jnp.zeros(len_list_segments) - gamma_error_avg = jnp.zeros(len_list_segments) - t_gammadash_avg_essos = jnp.zeros(len_list_segments) - t_gammadash_avg_simsopt = jnp.zeros(len_list_segments) - gammadash_error_avg = jnp.zeros(len_list_segments) - t_gammadashdash_avg_essos = jnp.zeros(len_list_segments) - t_gammadashdash_avg_simsopt = jnp.zeros(len_list_segments) - gammadashdash_error_avg = jnp.zeros(len_list_segments) - t_curvature_avg_essos = jnp.zeros(len_list_segments) - t_curvature_avg_simsopt = jnp.zeros(len_list_segments) - curvature_error_avg = jnp.zeros(len_list_segments) - t_B_avg_essos = jnp.zeros(len_list_segments) - t_B_avg_simsopt = jnp.zeros(len_list_segments) - B_error_avg = jnp.zeros(len_list_segments) - t_dB_by_dX_avg_essos = jnp.zeros(len_list_segments) - t_dB_by_dX_avg_simsopt = jnp.zeros(len_list_segments) - dB_by_dX_error_avg = jnp.zeros(len_list_segments) - - gamma_error_simsopt_to_essos = 0 - gamma_error_essos_to_simsopt = 0 - - for i, (coil_simsopt, coil_essos_gamma, coil_essos_to_simsopt) in enumerate(zip(coils_simsopt, coils_essos.gamma, coils_essos_to_simsopt)): - gamma_error_simsopt_to_essos += jnp.linalg.norm(coil_simsopt.curve.gamma()-coil_essos_gamma) - gamma_error_essos_to_simsopt += jnp.linalg.norm(coil_simsopt.curve.gamma()-coil_essos_to_simsopt.curve.gamma()) - - B_error_avg_simsopt_to_essos = 0 - B_error_avg_essos_to_simsopt = 0 - for j, position in enumerate(positions): - field_simsopt.set_points([position]) - field_essos_to_simsopt.set_points([position]) - B_simsopt = field_simsopt.B() - B_essos_to_simsopt = field_essos_to_simsopt.B() - B_simsopt_to_essos = field_essos.B(position) - B_error_avg_simsopt_to_essos += jnp.abs(jnp.linalg.norm(B_simsopt) - jnp.linalg.norm(B_simsopt_to_essos)) - B_error_avg_essos_to_simsopt += jnp.abs(jnp.linalg.norm(B_simsopt) - jnp.linalg.norm(B_essos_to_simsopt)) - B_error_avg_simsopt_to_essos = B_error_avg_simsopt_to_essos/len(positions) - B_error_avg_essos_to_simsopt = B_error_avg_essos_to_simsopt/len(positions) - - fig = plt.figure(figsize = (8, 6)) - X_axis = jnp.arange(2) - plt.bar(X_axis[0] - 0.2, gamma_error_simsopt_to_essos+1e-19, 0.3, label='SIMSOPT to ESSOS coils', color='blue', edgecolor='black', hatch='/') - plt.bar(X_axis[0] + 0.2, gamma_error_essos_to_simsopt+1e-19, 0.3, label='ESSOS to SIMSOPT coils', color='red', edgecolor='black', hatch='-') - plt.bar(X_axis[1] - 0.2, B_error_avg_simsopt_to_essos+1e-19, 0.3, label=r'SIMSOPT to ESSOS $B$', color='blue', edgecolor='black', hatch='||') - plt.bar(X_axis[1] + 0.2, B_error_avg_essos_to_simsopt+1e-19, 0.3, label=r'ESSOS to SIMSOPT $B$', color='red', edgecolor='black', hatch='*') - plt.xticks(X_axis, ['Coil Error', 'B Error']) - plt.xlabel('Parameter', fontsize=14) - plt.ylabel('Error Magnitude', fontsize=14) - plt.yscale('log') - plt.ylim(1e-20, 1e-11) - plt.legend(fontsize=14) - plt.grid(axis='y') - plt.title(f"{name}", fontsize=14) - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"error_gamma_B_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() - - def update_nsegments_simsopt(curve_simsopt, n_segments): - new_curve = CurveXYZFourier(n_segments, curve_simsopt.order) - new_curve.x = curve_simsopt.x - return new_curve - - for index, n_segments in enumerate(list_segments): - coils_essos.n_segments = n_segments - - base_curves_simsopt = [update_nsegments_simsopt(coil_simsopt.curve, n_segments) for coil_simsopt in base_coils_simsopt] - coils_simsopt = coils_via_symmetries(base_curves_simsopt, currents_simsopt[0:len(base_coils_simsopt)], nfp, True) - curves_simsopt = [c.curve for c in coils_simsopt] - - [curve.gamma() for curve in curves_simsopt] - coils_essos.gamma - - start_time = time() - gamma_curves_simsopt = block_until_ready(jnp.array([curve.gamma() for curve in curves_simsopt])) - t_gamma_avg_simsopt = t_gamma_avg_simsopt.at[index].set(t_gamma_avg_simsopt[index] + time() - start_time) - - start_time = time() - gamma_curves_essos = block_until_ready(jnp.array(coils_essos.gamma)) - t_gamma_avg_essos = t_gamma_avg_essos.at[index].set(t_gamma_avg_essos[index] + time() - start_time) - - start_time = time() - gammadash_curves_simsopt = block_until_ready(jnp.array([curve.gammadash() for curve in curves_simsopt])) - t_gammadash_avg_simsopt = t_gammadash_avg_simsopt.at[index].set(t_gammadash_avg_simsopt[index] + time() - start_time) - - start_time = time() - gammadash_curves_essos = block_until_ready(jnp.array(coils_essos.gamma_dash)) - t_gammadash_avg_essos = t_gammadash_avg_essos.at[index].set(t_gammadash_avg_essos[index] + time() - start_time) - - start_time = time() - gammadashdash_curves_simsopt = block_until_ready(jnp.array([curve.gammadashdash() for curve in curves_simsopt])) - t_gammadashdash_avg_simsopt = t_gammadashdash_avg_simsopt.at[index].set(t_gammadashdash_avg_simsopt[index] + time() - start_time) - - start_time = time() - gammadashdash_curves_essos = block_until_ready(jnp.array(coils_essos.gamma_dashdash)) - t_gammadashdash_avg_essos = t_gammadashdash_avg_essos.at[index].set(t_gammadashdash_avg_essos[index] + time() - start_time) - - start_time = time() - curvature_curves_simsopt = block_until_ready(jnp.array([curve.kappa() for curve in curves_simsopt])) - t_curvature_avg_simsopt = t_curvature_avg_simsopt.at[index].set(t_curvature_avg_simsopt[index] + time() - start_time) - - start_time = time() - curvature_curves_essos = block_until_ready(jnp.array(coils_essos.curvature)) - t_curvature_avg_essos = t_curvature_avg_essos.at[index].set(t_curvature_avg_essos[index] + time() - start_time) - - gamma_error_avg = gamma_error_avg. at[index].set(gamma_error_avg[index] + jnp.linalg.norm(gamma_curves_essos - gamma_curves_simsopt)) - gammadash_error_avg = gammadash_error_avg. at[index].set(gammadash_error_avg[index] + jnp.linalg.norm(gammadash_curves_essos - gammadash_curves_simsopt)) - gammadashdash_error_avg = gammadashdash_error_avg.at[index].set(gammadashdash_error_avg[index] + jnp.linalg.norm(gammadashdash_curves_essos - gammadashdash_curves_simsopt)) - curvature_error_avg = curvature_error_avg.at[index].set(curvature_error_avg[index] + jnp.linalg.norm(curvature_curves_essos - curvature_curves_simsopt)) - - field_essos = BiotSavart_essos(coils_essos) - field_simsopt = BiotSavart_simsopt(coils_simsopt) - - for j, position in enumerate(positions): - field_essos.B(position) - time1 = time() - result_B_essos = field_essos.B(position) - t_B_avg_essos = t_B_avg_essos.at[index].set(t_B_avg_essos[index] + time() - time1) - normB_essos = jnp.linalg.norm(result_B_essos) - - field_simsopt.set_points(jnp.array([position])) - field_simsopt.B() - time3 = time() - field_simsopt.set_points(jnp.array([position])) - result_simsopt = field_simsopt.B() - t_B_avg_simsopt = t_B_avg_simsopt.at[index].set(t_B_avg_simsopt[index] + time() - time3) - normB_simsopt = jnp.linalg.norm(jnp.array(result_simsopt)) - - B_error_avg = B_error_avg.at[index].set(B_error_avg[index] + jnp.abs(normB_essos - normB_simsopt)) - - field_essos.dB_by_dX(position) - time1 = time() - field_simsopt.set_points(jnp.array([position])) - result_dB_by_dX_essos = field_essos.dB_by_dX(position) - t_dB_by_dX_avg_essos = t_dB_by_dX_avg_essos.at[index].set(t_dB_by_dX_avg_essos[index] + time() - time1) - norm_dB_by_dX_essos = jnp.linalg.norm(result_dB_by_dX_essos) - - field_simsopt.dB_by_dX() - time3 = time() - field_simsopt.set_points(jnp.array([position])) - result_dB_by_dX_simsopt = field_simsopt.dB_by_dX() - t_dB_by_dX_avg_simsopt = t_dB_by_dX_avg_simsopt.at[index].set(t_dB_by_dX_avg_simsopt[index] + time() - time3) - norm_dB_by_dX_simsopt = jnp.linalg.norm(jnp.array(result_dB_by_dX_simsopt)) - - dB_by_dX_error_avg = dB_by_dX_error_avg.at[index].set(dB_by_dX_error_avg[index] + jnp.abs(norm_dB_by_dX_essos - norm_dB_by_dX_simsopt)) - - X_axis = jnp.arange(len_list_segments) - - fig = plt.figure(figsize = (8, 6)) - plt.bar(X_axis-0.2, B_error_avg, 0.1, label = r"$B_{\text{essos}} - B_{\text{simsopt}}$", color="green", edgecolor="black", hatch="/") - plt.bar(X_axis-0.1, dB_by_dX_error_avg, 0.1, label = r"${B'}_{\text{essos}} - {B'}_{\text{simsopt}}$", color="purple", edgecolor="black", hatch="x") - plt.bar(X_axis+0.0, gamma_error_avg, 0.1, label = r"$\Gamma_{\text{essos}} - \Gamma_{\text{simsopt}}$", color="orange", edgecolor="black", hatch="|") - plt.bar(X_axis+0.1, gammadash_error_avg, 0.1, label = r"${\Gamma'}_{\text{essos}} - {\Gamma'}_{\text{simsopt}}$", color="gray", edgecolor="black", hatch="-") - plt.bar(X_axis+0.2, gammadashdash_error_avg, 0.1, label = r"${\Gamma''}_{\text{essos}} - {\Gamma''}_{\text{simsopt}}$", color="black", edgecolor="black", hatch="*") - plt.bar(X_axis+0.3, curvature_error_avg, 0.1, label = r"$\kappa_{\text{essos}} - \kappa_{\text{simsopt}}$", color="brown", edgecolor="black", hatch="\\") - plt.xticks(X_axis, list_segments) - plt.xlabel("Number of segments of each coil", fontsize=14) - plt.ylabel(f"Difference SIMSOPT vs ESSOS", fontsize=14) - plt.tick_params(axis='both', which='major', labelsize=14) - plt.tick_params(axis='both', which='minor', labelsize=14) - plt.legend(fontsize=14) - plt.yscale("log") - plt.grid(axis='y') - plt.ylim(1e-18, 1e-10) - plt.title(f"{name}", fontsize=14) - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"error_BiotSavart_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() - - fig = plt.figure(figsize = (8, 6)) - plt.bar(X_axis - 0.30, t_B_avg_essos, 0.05, label = r'B ESSOS', color="red", edgecolor="black") - plt.bar(X_axis - 0.25, t_B_avg_simsopt, 0.05, label = r'B SIMSOPT', color="blue", edgecolor="black") - plt.bar(X_axis - 0.20, t_dB_by_dX_avg_essos, 0.05, label = r"$B'$ ESSOS", color="red", edgecolor="black") - plt.bar(X_axis - 0.15, t_dB_by_dX_avg_simsopt, 00.05, label = r"$B'$ SIMSOPT", color="blue", edgecolor="black") - plt.bar(X_axis - 0.10, t_gamma_avg_essos, 0.05, label = r'$\Gamma$ ESSOS', color="red", edgecolor="black", hatch="//") - plt.bar(X_axis - 0.05, t_gamma_avg_simsopt, 0.05, label = r'$\Gamma$ SIMSOPT', color="blue", edgecolor="black", hatch="-") - plt.bar(X_axis + 0.0, t_gammadash_avg_essos, 0.05, label = r"${\Gamma'}$ ESSOS", color="red", edgecolor="black", hatch="\\") - plt.bar(X_axis + 0.05, t_gammadash_avg_simsopt, 0.05, label = r"${\Gamma'}$ SIMSOPT", color="blue", edgecolor="black", hatch="||") - plt.bar(X_axis + 0.10, t_gammadashdash_avg_essos, 0.05, label = r"${\Gamma''}$ ESSOS", color="red", edgecolor="black", hatch="*") - plt.bar(X_axis + 0.15, t_gammadashdash_avg_simsopt, 0.05, label = r"${\Gamma''}$ SIMSOPT", color="blue", edgecolor="black", hatch="|") - plt.bar(X_axis + 0.20, t_curvature_avg_essos, 0.05, label = r"$\kappa$ ESSOS", color="red", edgecolor="black", hatch="x") - plt.bar(X_axis + 0.25, t_curvature_avg_simsopt, 0.05, label = r"$\kappa$ SIMSOPT", color="blue", edgecolor="black", hatch="+") - plt.tick_params(axis='both', which='major', labelsize=14) - plt.tick_params(axis='both', which='minor', labelsize=14) - plt.xticks(X_axis, list_segments) - plt.xlabel("Number of segments of each coil", fontsize=14) - plt.ylabel("Time to evaluate SIMSOPT vs ESSOS (s)", fontsize=14) - plt.grid(axis='y') - # plt.gca().set_ylim((None,0.03)) - plt.yscale("log") - plt.legend(fontsize=14) - plt.title(f"{name}", fontsize=14) - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"time_BiotSavart_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() diff --git a/examples/comparisons_SIMSOPT/fieldlines_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/fieldlines_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index fd0473e9..00000000 --- a/examples/comparisons_SIMSOPT/fieldlines_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,195 +0,0 @@ -import os -import time -import jax.numpy as jnp -from jax import block_until_ready -from simsopt import load -from simsopt.field import (particles_to_vtk, compute_fieldlines, plot_poincare_data) -from essos.coils import Coils_from_simsopt -from essos.dynamics import Tracing -from essos.fields import BiotSavart as BiotSavart_essos -import matplotlib.pyplot as plt - -tmax_fl = 150 -nfieldlines = 3 -axis_shft=0.02 -R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nfieldlines) -nfp = 2 -trace_tolerance_SIMSOPT_array = [1e-5, 1e-7, 1e-9, 1e-11, 1e-13] -trace_tolerance_ESSOS = 1e-7 - -Z0 = jnp.zeros(nfieldlines) -phi0 = jnp.zeros(nfieldlines) - -phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] - -output_dir = os.path.join(os.path.dirname(__file__), 'output') -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') -field_simsopt = load(LandremanPaulQA_json_file) -field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) - -fieldlines_SIMSOPT_array = [] -time_SIMSOPT_array = [] -avg_steps_SIMSOPT = 0 - -print(f'Output being saved to {output_dir}') -print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') -for trace_tolerance_SIMSOPT in trace_tolerance_SIMSOPT_array: - print(f' Tracing SIMSOPT fieldlines with tolerance={trace_tolerance_SIMSOPT}') - t1 = time.time() - fieldlines_SIMSOPT_this_tolerance, fieldlines_SIMSOPT_phi_hits = block_until_ready(compute_fieldlines(field_simsopt, R0, Z0, tmax=tmax_fl, tol=trace_tolerance_SIMSOPT, phis=phis_poincare)) - time_SIMSOPT_array.append(time.time()-t1) - avg_steps_SIMSOPT += sum([len(l) for l in fieldlines_SIMSOPT_this_tolerance])//nfieldlines - print(f" Time for SIMSOPT tracing={time.time()-t1:.3f}s. Avg num steps={avg_steps_SIMSOPT}") - fieldlines_SIMSOPT_array.append(fieldlines_SIMSOPT_this_tolerance) - -particles_to_vtk(fieldlines_SIMSOPT_this_tolerance, os.path.join(output_dir,f'fieldlines_SIMSOPT')) -# plot_poincare_data(fieldlines_phi_hits, phis_poincare, os.path.join(output_dir,f'poincare_fieldline_SIMSOPT.pdf'), dpi=150) - -# Trace in ESSOS -num_steps_essos = int(jnp.mean(jnp.array([len(fieldlines_SIMSOPT[0]) for fieldlines_SIMSOPT in fieldlines_SIMSOPT_array]))) -time_essos = jnp.linspace(0, tmax_fl, num_steps_essos) - -print(f'Tracing ESSOS fieldlines with tolerance={trace_tolerance_ESSOS}') -t1 = time.time() -tracing = block_until_ready(Tracing(field=field_essos, model='FieldLine', initial_conditions=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T, - maxtime=tmax_fl, timesteps=num_steps_essos, tol_step_size=trace_tolerance_ESSOS)) -fieldlines_ESSOS = tracing.trajectories -time_ESSOS = time.time()-t1 -print(f" Time for ESSOS tracing={time.time()-t1:.3f}s. Num steps={len(fieldlines_ESSOS[0])}") - -tracing.to_vtk(os.path.join(output_dir,f'fieldlines_ESSOS')) -# tracing.poincare_plot(phis_poincare, show=False) - -print('Plotting the results to output directory...') -# Plot time comparison in a bar chart -labels = [f'SIMSOPT\nTol={tol}' for tol in trace_tolerance_SIMSOPT_array] + [f'ESSOS\nTol={trace_tolerance_ESSOS}'] -times = time_SIMSOPT_array + [time_ESSOS] -plt.figure() -bars = plt.bar(labels, times, color=['blue']*len(trace_tolerance_SIMSOPT_array) + ['red'], edgecolor=['black']*len(trace_tolerance_SIMSOPT_array) + ['black'], hatch=['//']*len(trace_tolerance_SIMSOPT_array) + ['|']) -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Time (s)') -plt.xticks(rotation=45) -plt.tight_layout() -blue_patch = plt.Line2D([0], [0], color='blue', lw=4, label='SIMSOPT', linestyle='--') -orange_patch = plt.Line2D([0], [0], color='red', lw=4, label=f'ESSOS', linestyle='-') -plt.legend(handles=[blue_patch, orange_patch]) -plt.savefig(os.path.join(output_dir, 'times_fieldlines_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -def interpolate_ESSOS_to_SIMSOPT(fieldine_SIMSOPT, fieldline_ESSOS): - time_SIMSOPT = jnp.array(fieldine_SIMSOPT)[:, 0] # Time values from fieldlines_SIMSOPT - # coords_SIMSOPT = jnp.array(fieldine_SIMSOPT)[:, 1:] # Coordinates (x, y, z) from fieldlines_SIMSOPT - coords_ESSOS = jnp.array(fieldline_ESSOS) - - interp_x = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 0]) - interp_y = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 1]) - interp_z = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 2]) - - coords_ESSOS_interp = jnp.column_stack([ interp_x, interp_y, interp_z]) - - return coords_ESSOS_interp - -relative_error_array = [] -for i, fieldlines_SIMSOPT in enumerate(fieldlines_SIMSOPT_array): - fieldlines_ESSOS_interp = [interpolate_ESSOS_to_SIMSOPT(fieldlines_SIMSOPT[i], fieldlines_ESSOS[i]) for i in range(nfieldlines)] - tracing.trajectories = fieldlines_ESSOS_interp - if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'fieldlines_ESSOS_interp')) - - relative_error_fieldlines_SIMSOPT_vs_ESSOS = [] - plt.figure() - for j in range(nfieldlines): - this_fieldline_SIMSOPT = jnp.array(fieldlines_SIMSOPT[j])[:,1:] - this_fieldlines_ESSOS = fieldlines_ESSOS_interp[j] - average_relative_error = [] - for fieldline_SIMSOPT_t, fieldline_ESSOS_t in zip(this_fieldline_SIMSOPT, this_fieldlines_ESSOS): - relative_error_x = jnp.abs(fieldline_SIMSOPT_t[0] - fieldline_ESSOS_t[0])/(jnp.abs(fieldline_SIMSOPT_t[0])+1e-12) - relative_error_y = jnp.abs(fieldline_SIMSOPT_t[1] - fieldline_ESSOS_t[1])/(jnp.abs(fieldline_SIMSOPT_t[1])+1e-12) - relative_error_z = jnp.abs(fieldline_SIMSOPT_t[2] - fieldline_ESSOS_t[2])/(jnp.abs(fieldline_SIMSOPT_t[2])+1e-12) - average_relative_error.append((relative_error_x + relative_error_y + relative_error_z)/3) - average_relative_error = jnp.array(average_relative_error) - relative_error_fieldlines_SIMSOPT_vs_ESSOS.append(average_relative_error) - plt.plot(jnp.linspace(0, tmax_fl, len(average_relative_error))[1:], average_relative_error[1:], label=f'Fieldline {j}') - plt.legend() - plt.xlabel('Time') - plt.ylabel('Relative Error') - plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir, f'relative_error_fieldlines_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - # relative_error_fieldlines_SIMSOPT_vs_ESSOS = jnp.array(relative_error_fieldlines_SIMSOPT_vs_ESSOS) - # print(f"Relative difference between SIMSOPT and ESSOS fieldlines={relative_error_fieldlines_SIMSOPT_vs_ESSOS}") - relative_error_array.append(relative_error_fieldlines_SIMSOPT_vs_ESSOS) - - plt.figure() - for j in range(nfieldlines): - R_SIMSOPT = jnp.sqrt(fieldlines_SIMSOPT[j][:,1]**2+fieldlines_SIMSOPT[j][:,2]**2) - phi_SIMSOPT = jnp.arctan2(fieldlines_SIMSOPT[j][:,2], fieldlines_SIMSOPT[j][:,1]) - Z_SIMSOPT = fieldlines_SIMSOPT[j][:,3] - - R_ESSOS = jnp.sqrt(fieldlines_ESSOS_interp[j][:,0]**2+fieldlines_ESSOS_interp[j][:,1]**2) - phi_ESSOS = jnp.arctan2(fieldlines_ESSOS_interp[j][:,1], fieldlines_ESSOS_interp[j][:,0]) - Z_ESSOS = fieldlines_ESSOS_interp[j][:,2] - - plt.plot(R_SIMSOPT, Z_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {j}') - plt.plot(R_ESSOS, Z_ESSOS, '--', linewidth=2.5, label=f'ESSOS {j}') - plt.legend() - plt.xlabel('R') - plt.ylabel('Z') - plt.savefig(os.path.join(output_dir,f'fieldlines_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - -# Calculate RMS error for each tolerance -rms_error_array = jnp.array([[jnp.sqrt(jnp.mean(jnp.square(jnp.array(error)))) for error in relative_error] for relative_error in relative_error_array]) - -# Plot RMS error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(rms_error_array.shape[1]): - plt.bar(x + i * bar_width, rms_error_array[:, i], bar_width, label=f'Fieldline {i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('RMS Error') -plt.yscale('log') -plt.xticks(x + bar_width * (rms_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'rms_error_fieldlines_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate maximum error for each tolerance -max_error_array = jnp.array([[jnp.max(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot maximum error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(max_error_array.shape[1]): - plt.bar(x + i * bar_width, max_error_array[:, i], bar_width, label=f'Fieldline {i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Maximum Error') -plt.yscale('log') -plt.xticks(x + bar_width * (max_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'max_error_fieldlines_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate mean error for each tolerance -mean_error_array = jnp.array([[jnp.mean(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot mean error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(mean_error_array.shape[1]): - plt.bar(x + i * bar_width, mean_error_array[:, i], bar_width, label=f'Fieldline {i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Mean Error') -plt.yscale('log') -plt.xticks(x + bar_width * (mean_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'mean_error_fieldlines_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() diff --git a/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index c1b2aba1..00000000 --- a/examples/comparisons_SIMSOPT/fullorbit_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,259 +0,0 @@ -import os -import time -import jax.numpy as jnp -from jax import block_until_ready, random -from simsopt import load -from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) -from essos.coils import Coils_from_simsopt -from essos.constants import PROTON_MASS, ONE_EV -from essos.dynamics import Tracing, Particles -from essos.fields import BiotSavart as BiotSavart_essos -import matplotlib.pyplot as plt - -tmax_full = 1e-5 -nparticles = 3 -axis_shft=0.02 -R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nparticles) -trace_tolerance_SIMSOPT_array = [1e-3, 1e-5, 1e-7, 1e-9]#, 1e-11] -trace_tolerance_ESSOS = 1e-5 -mass=PROTON_MASS -energy=5000*ONE_EV -model_ESSOS_array = ['FullOrbit', 'FullOrbit_Boris'] - -output_dir = os.path.join(os.path.dirname(__file__), 'output') -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -nfp=2 -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') -field_simsopt = load(LandremanPaulQA_json_file) -field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) - -Z0 = jnp.zeros(nparticles) -phi0 = jnp.zeros(nparticles) -initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T -initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nparticles,), minval=-1, maxval=1) - - -phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] - -particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy, field=field_essos) - -# Trace in SIMSOPT -time_SIMSOPT_array = [] -trajectories_SIMSOPT_array = [] -avg_steps_SIMSOPT = 0 -relative_energy_error_SIMSOPT_array = [] -print(f'Output being saved to {output_dir}') -print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') -for trace_tolerance_SIMSOPT in trace_tolerance_SIMSOPT_array: - print(f' Tracing SIMSOPT full orbit with tolerance={trace_tolerance_SIMSOPT}') - t1 = time.time() - trajectories_SIMSOPT_this_tolerance, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( - field=field_simsopt, xyz_inits=particles.initial_xyz, mass=particles.mass, - parallel_speeds=particles.initial_vparallel, tmax=tmax_full, mode='full', - charge=particles.charge, Ekin=particles.energy, tol=trace_tolerance_SIMSOPT)) - time_SIMSOPT_array.append(time.time()-t1) - avg_steps_SIMSOPT += sum([len(l) for l in trajectories_SIMSOPT_this_tolerance])//nparticles - print(f" Time for SIMSOPT tracing={time.time()-t1:.3f}s. Avg num steps={avg_steps_SIMSOPT}") - trajectories_SIMSOPT_array.append(trajectories_SIMSOPT_this_tolerance) - - relative_energy_error_SIMSOPT_array.append([jnp.abs(mass*(trajectory[:,4]**2+trajectory[:,5]**2+trajectory[:,6]**2)/2-particles.energy)/particles.energy - for trajectory in trajectories_SIMSOPT_this_tolerance]) - -particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'full_orbit_SIMSOPT')) - - -# Trace in ESSOS -num_steps_essos = int(jnp.max(jnp.array([len(trajectories_SIMSOPT[0]) for trajectories_SIMSOPT in trajectories_SIMSOPT_array]))) -time_essos = jnp.linspace(0, tmax_full, num_steps_essos) - - -tracing_array = [] -trajectories_ESSOS_array = [] -time_ESSOS_array = [] -for model_ESSOS in model_ESSOS_array: - print(f'Tracing ESSOS full orbit '+('Boris' if model_ESSOS=='FullOrbit_Boris' else f'with tolerance={trace_tolerance_ESSOS}')+f' and plotting the result.') - t1 = time.time() - tracing = block_until_ready(Tracing(field=field_essos, model=model_ESSOS, particles=particles, - maxtime=tmax_full, timesteps=num_steps_essos, tol_step_size=trace_tolerance_ESSOS)) - trajectories_ESSOS = tracing.trajectories - time_ESSOS = time.time()-t1 - print(f" Time for ESSOS tracing={time.time()-t1:.3f}s "+('Boris' if model_ESSOS=='FullOrbit_Boris' else f'')+f". Num steps={len(trajectories_ESSOS[0])}") - tracing.to_vtk(os.path.join(output_dir,f'full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_ESSOS')) - tracing_array.append(tracing) - trajectories_ESSOS_array.append(trajectories_ESSOS) - time_ESSOS_array.append(time_ESSOS) - -print('Plotting the results to output directory...') -plt.figure() -SIMSOPT_energy_interp_this_particle = jnp.zeros((len(trace_tolerance_SIMSOPT_array), nparticles, len(trajectories_SIMSOPT_array[-1][-1][:,0]))) -for j in range(nparticles): - for i, relative_energy_error_SIMSOPT in enumerate(relative_energy_error_SIMSOPT_array): - SIMSOPT_energy_interp_this_particle = SIMSOPT_energy_interp_this_particle.at[i,j].set(jnp.interp(trajectories_SIMSOPT_array[-1][-1][:,0], trajectories_SIMSOPT_array[i][j][:,0], relative_energy_error_SIMSOPT[j][:])) -for i, SIMSOPT_energy_interp in enumerate(SIMSOPT_energy_interp_this_particle): - plt.plot(trajectories_SIMSOPT_array[-1][-1][4:,0], jnp.mean(SIMSOPT_energy_interp, axis=0)[4:], '--', label=f'SIMSOPT Tol={trace_tolerance_SIMSOPT_array[i]}') -for model_ESSOS, tracing, trajectories_ESSOS in zip(model_ESSOS_array, tracing_array, trajectories_ESSOS_array): - relative_energy_error_ESSOS = jnp.abs(tracing.energy-particles.energy)/particles.energy - plt.plot(time_essos[2:], jnp.mean(relative_energy_error_ESSOS, axis=0)[2:], '-', label=f'ESSOS'+(' Boris' if model_ESSOS=='FullOrbit_Boris' else f' Tol={trace_tolerance_ESSOS}')) -plt.legend() -plt.yscale('log') -plt.xlabel('Time (s)') -plt.ylabel('Average Relative Energy Error') -plt.tight_layout() -plt.savefig(os.path.join(output_dir, f'relative_energy_error_full_orbit_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -labels = [f'SIMSOPT Tol={tol}' for tol in trace_tolerance_SIMSOPT_array] -times = time_SIMSOPT_array -plt.figure() -for model_ESSOS, tracing, trajectories_ESSOS, time_ESSOS in zip(model_ESSOS_array, tracing_array, trajectories_ESSOS_array, time_ESSOS_array): - # Plot time comparison in a bar chart - labels += ([f'ESSOS Boris Algorithm'] if model_ESSOS=='FullOrbit_Boris' else [f'ESSOS Tol={trace_tolerance_ESSOS}']) - times += [time_ESSOS] -bars = plt.bar(labels, times, color=['blue']*len(trace_tolerance_SIMSOPT_array) + ['red', 'orange'], edgecolor=['black']*len(trace_tolerance_SIMSOPT_array) + ['black']*2, hatch=['//']*len(trace_tolerance_SIMSOPT_array) + ['|']*2) -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Time (s)') -plt.xticks(rotation=45) -plt.tight_layout() -blue_patch = plt.Line2D([0], [0], color='blue', lw=4, label='SIMSOPT', linestyle='--') -red_patch = plt.Line2D([0], [0], color='red', lw=4, label=f'ESSOS', linestyle='-') -orange_patch = plt.Line2D([0], [0], color='orange', lw=4, label=f'ESSOS\nBoris Algorithm') -plt.legend(handles=[blue_patch, red_patch, orange_patch]) -plt.savefig(os.path.join(output_dir, 'times_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): - time_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 0] # Time values from full orbit SIMSOPT - # coords_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 1:] # Coordinates (x, y, z) from full orbit SIMSOPT - coords_ESSOS = jnp.array(trajectory_ESSOS) - interp_x = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 0]) - interp_y = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 1]) - interp_z = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 2]) - interp_vx = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 3]) - interp_vy = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 4]) - interp_vz = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 5]) - coords_ESSOS_interp = jnp.column_stack([ interp_x, interp_y, interp_z, interp_vx, interp_vy, interp_vz]) - return coords_ESSOS_interp - -for model_ESSOS, tracing, trajectories_ESSOS, time_ESSOS in zip(model_ESSOS_array, tracing_array, trajectories_ESSOS_array, time_ESSOS_array): - - relative_error_array = [] - for i, trajectories_SIMSOPT in enumerate(trajectories_SIMSOPT_array): - trajectories_ESSOS_interp = [interpolate_ESSOS_to_SIMSOPT(trajectories_SIMSOPT[i], trajectories_ESSOS[i]) for i in range(nparticles)] - tracing.trajectories = trajectories_ESSOS_interp - if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_ESSOS_interp')) - - relative_error_trajectories_SIMSOPT_vs_ESSOS = [] - plt.figure() - for j in range(nparticles): - this_trajectory_SIMSOPT = jnp.array(trajectories_SIMSOPT[j])[:,1:] - this_trajectory_ESSOS = trajectories_ESSOS_interp[j] - average_relative_error = [] - for trajectory_SIMSOPT_t, trajectory_ESSOS_t in zip(this_trajectory_SIMSOPT, this_trajectory_ESSOS): - relative_error_x = jnp.abs(trajectory_SIMSOPT_t[0] - trajectory_ESSOS_t[0])/(jnp.abs(trajectory_SIMSOPT_t[0])+1e-12) - relative_error_y = jnp.abs(trajectory_SIMSOPT_t[1] - trajectory_ESSOS_t[1])/(jnp.abs(trajectory_SIMSOPT_t[1])+1e-12) - relative_error_z = jnp.abs(trajectory_SIMSOPT_t[2] - trajectory_ESSOS_t[2])/(jnp.abs(trajectory_SIMSOPT_t[2])+1e-12) - relative_error_vx = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[3])+1e-12) - relative_error_vy = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[4])+1e-12) - relative_error_vz = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[5])+1e-12) - average_relative_error.append((relative_error_x + relative_error_y + relative_error_z + relative_error_vx + relative_error_vy + relative_error_vz)/6) - average_relative_error = jnp.array(average_relative_error) - relative_error_trajectories_SIMSOPT_vs_ESSOS.append(average_relative_error) - plt.plot(jnp.linspace(0, tmax_full, len(average_relative_error))[1:], average_relative_error[1:], label=f'Particle {1+j}') - plt.legend() - plt.xlabel('Time') - plt.ylabel('Relative Error') - plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir, f'relative_error_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+f'_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - relative_error_array.append(relative_error_trajectories_SIMSOPT_vs_ESSOS) - - plt.figure() - for j in range(nparticles): - R_SIMSOPT = jnp.sqrt(trajectories_SIMSOPT[j][:,1]**2+trajectories_SIMSOPT[j][:,2]**2) - phi_SIMSOPT = jnp.arctan2(trajectories_SIMSOPT[j][:,2], trajectories_SIMSOPT[j][:,1]) - Z_SIMSOPT = trajectories_SIMSOPT[j][:,3] - - R_ESSOS = jnp.sqrt(trajectories_ESSOS_interp[j][:,0]**2+trajectories_ESSOS_interp[j][:,1]**2) - phi_ESSOS = jnp.arctan2(trajectories_ESSOS_interp[j][:,1], trajectories_ESSOS_interp[j][:,0]) - Z_ESSOS = trajectories_ESSOS_interp[j][:,2] - - plt.plot(R_SIMSOPT, Z_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(R_ESSOS, Z_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('R') - plt.ylabel('Z') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+f'_RZ_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - plt.figure() - for j in range(nparticles): - time_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,0]) - vx_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,4]) - vx_ESSOS = jnp.array(trajectories_ESSOS_interp[j][:,3]) - # plt.plot(time_SIMSOPT, jnp.abs((vx_SIMSOPT-vx_ESSOS)/vx_SIMSOPT), '-', linewidth=2.5, label=f'Particle {1+j}') - plt.plot(time_SIMSOPT, vx_SIMSOPT/particles.total_speed, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(time_SIMSOPT, vx_ESSOS/particles.total_speed, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('Time (s)') - plt.ylabel(r'$v_x/v$') - # plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+f'_vx_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - # Calculate RMS error for each tolerance - rms_error_array = jnp.array([[jnp.sqrt(jnp.mean(jnp.square(jnp.array(error)))) for error in relative_error] for relative_error in relative_error_array]) - - # Plot RMS error in a bar chart - plt.figure() - bar_width = 0.15 - x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) - for i in range(rms_error_array.shape[1]): - plt.bar(x + i * bar_width, rms_error_array[:, i], bar_width, label=f'Particle {1+i}') - plt.xlabel('Tracing Tolerance of SIMSOPT') - plt.ylabel('RMS Error') - plt.yscale('log') - plt.xticks(x + bar_width * (rms_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) - plt.legend() - plt.tight_layout() - plt.savefig(os.path.join(output_dir, 'rms_error_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) - plt.close() - - # Calculate maximum error for each tolerance - max_error_array = jnp.array([[jnp.max(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) - # Plot maximum error in a bar chart - plt.figure() - bar_width = 0.15 - x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) - for i in range(max_error_array.shape[1]): - plt.bar(x + i * bar_width, max_error_array[:, i], bar_width, label=f'Particle {1+i}') - plt.xlabel('Tracing Tolerance of SIMSOPT') - plt.ylabel('Maximum Error') - plt.yscale('log') - plt.xticks(x + bar_width * (max_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) - plt.legend() - plt.tight_layout() - plt.savefig(os.path.join(output_dir, 'max_error_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) - plt.close() - - # Calculate mean error for each tolerance - mean_error_array = jnp.array([[jnp.mean(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) - # Plot mean error in a bar chart - plt.figure() - bar_width = 0.15 - x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) - for i in range(mean_error_array.shape[1]): - plt.bar(x + i * bar_width, mean_error_array[:, i], bar_width, label=f'Particle {1+i}') - plt.xlabel('Tracing Tolerance of SIMSOPT') - plt.ylabel('Mean Error') - plt.yscale('log') - plt.xticks(x + bar_width * (mean_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) - plt.legend() - plt.tight_layout() - plt.savefig(os.path.join(output_dir, 'mean_error_full_orbit'+('_boris' if model_ESSOS=='FullOrbit_Boris' else '')+'_SIMSOPT_vs_ESSOS.pdf'), dpi=150) - plt.close() diff --git a/examples/comparisons_SIMSOPT/guiding_center_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/guiding_center_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index eb102a77..00000000 --- a/examples/comparisons_SIMSOPT/guiding_center_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,248 +0,0 @@ -import os -import time -import jax.numpy as jnp -from jax import block_until_ready, random -from simsopt import load -from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) -from essos.coils import Coils_from_simsopt -from essos.constants import PROTON_MASS, ONE_EV -from essos.dynamics import Tracing, Particles -from essos.fields import BiotSavart as BiotSavart_essos -import matplotlib.pyplot as plt - -tmax_gc = 1e-4 -nparticles = 5 -axis_shft=0.02 -R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nparticles) -trace_tolerance_SIMSOPT_array = [1e-5, 1e-7, 1e-9, 1e-11] -trace_tolerance_ESSOS = 1e-7 -mass=PROTON_MASS -energy=5000*ONE_EV - -output_dir = os.path.join(os.path.dirname(__file__), 'output') -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -nfp=2 -LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') -field_simsopt = load(LandremanPaulQA_json_file) -field_essos = BiotSavart_essos(Coils_from_simsopt(LandremanPaulQA_json_file, nfp)) - -Z0 = jnp.zeros(nparticles) -phi0 = jnp.zeros(nparticles) -initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T -initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nparticles,), minval=-1, maxval=1) - -phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] - -particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy) - -# Trace in SIMSOPT -time_SIMSOPT_array = [] -trajectories_SIMSOPT_array = [] -avg_steps_SIMSOPT = 0 -relative_energy_error_SIMSOPT_array = [] -print(f'Output being saved to {output_dir}') -print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') -for trace_tolerance_SIMSOPT in trace_tolerance_SIMSOPT_array: - print(f'Tracing SIMSOPT guiding center with tolerance={trace_tolerance_SIMSOPT}') - t1 = time.time() - trajectories_SIMSOPT_this_tolerance, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( - field=field_simsopt, xyz_inits=particles.initial_xyz, mass=particles.mass, - parallel_speeds=particles.initial_vparallel, tmax=tmax_gc, mode='gc_vac', - charge=particles.charge, Ekin=particles.energy, tol=trace_tolerance_SIMSOPT)) - time_SIMSOPT_array.append(time.time()-t1) - avg_steps_SIMSOPT += sum([len(l) for l in trajectories_SIMSOPT_this_tolerance])//nparticles - print(f" Time for SIMSOPT tracing={time.time()-t1:.3f}s. Avg num steps={avg_steps_SIMSOPT}") - trajectories_SIMSOPT_array.append(trajectories_SIMSOPT_this_tolerance) - - relative_energy_SIMSOPT = [] - for i, trajectory in enumerate(trajectories_SIMSOPT_this_tolerance): - xyz = jnp.asarray(trajectory[:, 1:4]) - vpar = trajectory[:, 4] - field_simsopt.set_points(xyz) - AbsB = field_simsopt.AbsB()[:,0] - mu = (particles.energy - particles.mass*vpar[0]**2/2)/AbsB[0] - relative_energy_SIMSOPT.append(jnp.abs(particles.mass*vpar**2/2+mu*AbsB-particles.energy)/particles.energy) - relative_energy_error_SIMSOPT_array.append(relative_energy_SIMSOPT) - -particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) - -# Trace in ESSOS -num_steps_essos = int(jnp.mean(jnp.array([len(trajectories_SIMSOPT[0]) for trajectories_SIMSOPT in trajectories_SIMSOPT_array]))) -time_essos = jnp.linspace(0, tmax_gc, num_steps_essos) - -print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') -t1 = time.time() -tracing = block_until_ready(Tracing(field=field_essos, model='GuidingCenter', particles=particles, - maxtime=tmax_gc, timesteps=num_steps_essos, tol_step_size=trace_tolerance_ESSOS)) -trajectories_ESSOS = tracing.trajectories -time_ESSOS = time.time()-t1 -print(f" Time for ESSOS tracing={time.time()-t1:.3f}s. Num steps={len(trajectories_ESSOS[0])}") -tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) - -relative_energy_error_ESSOS = jnp.abs(tracing.energy-particles.energy)/particles.energy - -print('Plotting the results to output directory...') -plt.figure() -SIMSOPT_energy_interp_this_particle = jnp.zeros((len(trace_tolerance_SIMSOPT_array), nparticles, len(trajectories_SIMSOPT_array[-1][-1][:,0]))) -for j in range(nparticles): - for i, relative_energy_error_SIMSOPT in enumerate(relative_energy_error_SIMSOPT_array): - SIMSOPT_energy_interp_this_particle = SIMSOPT_energy_interp_this_particle.at[i,j].set(jnp.interp(trajectories_SIMSOPT_array[-1][-1][:,0], trajectories_SIMSOPT_array[i][j][:,0], relative_energy_error_SIMSOPT[j][:])) -plt.plot(time_essos[2:], jnp.mean(relative_energy_error_ESSOS, axis=0)[2:], '-', label=f'ESSOS Tol={trace_tolerance_ESSOS}') -for i, SIMSOPT_energy_interp in enumerate(SIMSOPT_energy_interp_this_particle): - plt.plot(trajectories_SIMSOPT_array[-1][-1][4:,0], jnp.mean(SIMSOPT_energy_interp, axis=0)[4:], '--', label=f'SIMSOPT Tol={trace_tolerance_SIMSOPT_array[i]}') -plt.legend() -plt.yscale('log') -plt.xlabel('Time (s)') -plt.ylabel('Average Relative Energy Error') -plt.tight_layout() -plt.savefig(os.path.join(output_dir, f'relative_energy_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Plot time comparison in a bar chart -labels = [f'SIMSOPT\nTol={tol}' for tol in trace_tolerance_SIMSOPT_array] + [f'ESSOS\nTol={trace_tolerance_ESSOS}'] -times = time_SIMSOPT_array + [time_ESSOS] -plt.figure() -bars = plt.bar(labels, times, color=['blue']*len(trace_tolerance_SIMSOPT_array) + ['red'], edgecolor=['black']*len(trace_tolerance_SIMSOPT_array) + ['black'], hatch=['//']*len(trace_tolerance_SIMSOPT_array) + ['|']) -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Time (s)') -plt.xticks(rotation=45) -plt.tight_layout() -blue_patch = plt.Line2D([0], [0], color='blue', lw=4, label='SIMSOPT', linestyle='--') -orange_patch = plt.Line2D([0], [0], color='red', lw=4, label=f'ESSOS', linestyle='-') -plt.legend(handles=[blue_patch, orange_patch]) -plt.savefig(os.path.join(output_dir, 'times_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -def interpolate_ESSOS_to_SIMSOPT(trajectory_SIMSOPT, trajectory_ESSOS): - time_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 0] # Time values from guiding center SIMSOPT - # coords_SIMSOPT = jnp.array(trajectory_SIMSOPT)[:, 1:] # Coordinates (x, y, z) from guiding center SIMSOPT - coords_ESSOS = jnp.array(trajectory_ESSOS) - - interp_x = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 0]) - interp_y = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 1]) - interp_z = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 2]) - interp_v = jnp.interp(time_SIMSOPT, time_essos, coords_ESSOS[:, 3]) - - coords_ESSOS_interp = jnp.column_stack([ interp_x, interp_y, interp_z, interp_v]) - - return coords_ESSOS_interp - -relative_error_array = [] -for i, trajectories_SIMSOPT in enumerate(trajectories_SIMSOPT_array): - trajectories_ESSOS_interp = [interpolate_ESSOS_to_SIMSOPT(trajectories_SIMSOPT[i], trajectories_ESSOS[i]) for i in range(nparticles)] - tracing.trajectories = trajectories_ESSOS_interp - if i==len(trace_tolerance_SIMSOPT_array)-1: tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS_interp')) - - relative_error_trajectories_SIMSOPT_vs_ESSOS = [] - plt.figure() - for j in range(nparticles): - this_trajectory_SIMSOPT = jnp.array(trajectories_SIMSOPT[j])[:,1:] - this_trajectory_ESSOS = trajectories_ESSOS_interp[j] - average_relative_error = [] - for trajectory_SIMSOPT_t, trajectory_ESSOS_t in zip(this_trajectory_SIMSOPT, this_trajectory_ESSOS): - relative_error_x = jnp.abs(trajectory_SIMSOPT_t[0] - trajectory_ESSOS_t[0])/(jnp.abs(trajectory_SIMSOPT_t[0])+1e-12) - relative_error_y = jnp.abs(trajectory_SIMSOPT_t[1] - trajectory_ESSOS_t[1])/(jnp.abs(trajectory_SIMSOPT_t[1])+1e-12) - relative_error_z = jnp.abs(trajectory_SIMSOPT_t[2] - trajectory_ESSOS_t[2])/(jnp.abs(trajectory_SIMSOPT_t[2])+1e-12) - relative_error_v = jnp.abs(trajectory_SIMSOPT_t[3] - trajectory_ESSOS_t[3])/(jnp.abs(trajectory_SIMSOPT_t[3])+1e-12) - average_relative_error.append((relative_error_x + relative_error_y + relative_error_z + relative_error_v)/4) - average_relative_error = jnp.array(average_relative_error) - relative_error_trajectories_SIMSOPT_vs_ESSOS.append(average_relative_error) - plt.plot(jnp.linspace(0, tmax_gc, len(average_relative_error))[1:], average_relative_error[1:], label=f'Particle {1+j}') - plt.legend() - plt.xlabel('Time') - plt.ylabel('Relative Error') - plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir, f'relative_error_guiding_center_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - relative_error_array.append(relative_error_trajectories_SIMSOPT_vs_ESSOS) - - plt.figure() - for j in range(nparticles): - R_SIMSOPT = jnp.sqrt(trajectories_SIMSOPT[j][:,1]**2+trajectories_SIMSOPT[j][:,2]**2) - phi_SIMSOPT = jnp.arctan2(trajectories_SIMSOPT[j][:,2], trajectories_SIMSOPT[j][:,1]) - Z_SIMSOPT = trajectories_SIMSOPT[j][:,3] - - R_ESSOS = jnp.sqrt(trajectories_ESSOS_interp[j][:,0]**2+trajectories_ESSOS_interp[j][:,1]**2) - phi_ESSOS = jnp.arctan2(trajectories_ESSOS_interp[j][:,1], trajectories_ESSOS_interp[j][:,0]) - Z_ESSOS = trajectories_ESSOS_interp[j][:,2] - - plt.plot(R_SIMSOPT, Z_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(R_ESSOS, Z_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('R') - plt.ylabel('Z') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'guiding_center_RZ_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - - plt.figure() - for j in range(nparticles): - time_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,0]) - vpar_SIMSOPT = jnp.array(trajectories_SIMSOPT[j][:,4]) - vpar_ESSOS = jnp.array(trajectories_ESSOS_interp[j][:,3]) - # plt.plot(time_SIMSOPT, jnp.abs((vpar_SIMSOPT-vpar_ESSOS)/vpar_SIMSOPT), '-', linewidth=2.5, label=f'Particle {1+j}') - plt.plot(time_SIMSOPT, vpar_SIMSOPT, '-', linewidth=2.5, label=f'SIMSOPT {1+j}') - plt.plot(time_SIMSOPT, vpar_ESSOS, '--', linewidth=2.5, label=f'ESSOS {1+j}') - plt.legend() - plt.xlabel('Time (s)') - plt.ylabel(r'$v_{\parallel}/v$') - # plt.yscale('log') - plt.tight_layout() - plt.savefig(os.path.join(output_dir,f'guiding_center_vpar_SIMSOPT_vs_ESSOS_tolerance{trace_tolerance_SIMSOPT_array[i]}.pdf'), dpi=150) - plt.close() - -# Calculate RMS error for each tolerance -rms_error_array = jnp.array([[jnp.sqrt(jnp.mean(jnp.square(jnp.array(error)))) for error in relative_error] for relative_error in relative_error_array]) - -# Plot RMS error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(rms_error_array.shape[1]): - plt.bar(x + i * bar_width, rms_error_array[:, i], bar_width, label=f'Particle {1+i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('RMS Error') -plt.yscale('log') -plt.xticks(x + bar_width * (rms_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'rms_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate maximum error for each tolerance -max_error_array = jnp.array([[jnp.max(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot maximum error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(max_error_array.shape[1]): - plt.bar(x + i * bar_width, max_error_array[:, i], bar_width, label=f'Particle {1+i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Maximum Error') -plt.yscale('log') -plt.xticks(x + bar_width * (max_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'max_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() - -# Calculate mean error for each tolerance -mean_error_array = jnp.array([[jnp.mean(jnp.array(error)) for error in relative_error] for relative_error in relative_error_array]) -# Plot mean error in a bar chart -plt.figure() -bar_width = 0.15 -x = jnp.arange(len(trace_tolerance_SIMSOPT_array)) -for i in range(mean_error_array.shape[1]): - plt.bar(x + i * bar_width, mean_error_array[:, i], bar_width, label=f'Particle {1+i}') -plt.xlabel('Tracing Tolerance of SIMSOPT') -plt.ylabel('Mean Error') -plt.yscale('log') -plt.xticks(x + bar_width * (mean_error_array.shape[1] - 1) / 2, [f'Tol={tol}' for tol in trace_tolerance_SIMSOPT_array], rotation=45) -plt.legend() -plt.tight_layout() -plt.savefig(os.path.join(output_dir, 'mean_error_guiding_center_SIMSOPT_vs_ESSOS.pdf'), dpi=150) -plt.close() \ No newline at end of file diff --git a/examples/comparisons_SIMSOPT/surfaces_SIMSOPT_vs_ESSOS.py b/examples/comparisons_SIMSOPT/surfaces_SIMSOPT_vs_ESSOS.py deleted file mode 100644 index 7e1780b3..00000000 --- a/examples/comparisons_SIMSOPT/surfaces_SIMSOPT_vs_ESSOS.py +++ /dev/null @@ -1,72 +0,0 @@ -import os -from time import time -import matplotlib.pyplot as plt -from jax import vmap -import jax.numpy as jnp -from essos.coils import Coils, CreateEquallySpacedCurves -from essos.fields import Vmec, BiotSavart -from essos.surfaces import B_on_surface, BdotN_over_B, SurfaceRZFourier as SurfaceRZFourier_ESSOS -from simsopt.field import BiotSavart as BiotSavart_simsopt -from simsopt.geo import SurfaceRZFourier as SurfaceRZFourier_SIMSOPT -from simsopt.objectives import SquaredFlux - -# Optimization parameters -max_coil_length = 42 -order_Fourier_series_coils = 4 -number_coil_points = 50 -function_evaluations_array = [30]*1 -diff_step_array = [1e-2]*1 -number_coils_per_half_field_period = 3 - -ntheta = 36 -nphi = 32 - -# Initialize VMEC field -vmec_file = os.path.join(os.path.dirname(__file__), '..', 'input_files', - 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc') -vmec = Vmec(vmec_file, ntheta=ntheta, nphi=nphi, close=False) - -# Initialize coils -current_on_each_coil = 1 -number_of_field_periods = vmec.nfp -major_radius_coils = vmec.r_axis -minor_radius_coils = vmec.r_axis/1.5 -curves_essos = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, - order=order_Fourier_series_coils, - R=major_radius_coils, r=minor_radius_coils, - n_segments=number_coil_points, - nfp=number_of_field_periods, stellsym=True) -coils_essos = Coils(curves=curves_essos, currents=[current_on_each_coil]*number_coils_per_half_field_period) -field_essos = BiotSavart(coils_essos) -surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) -# surface_essos.to_vtk("essos_surface") - -coils_simsopt = coils_essos.to_simsopt() -curves_simsopt = curves_essos.to_simsopt() -field_simsopt = BiotSavart_simsopt(coils_simsopt) -surface_simsopt = SurfaceRZFourier_SIMSOPT.from_wout(vmec_file, range="full torus", nphi=nphi, ntheta=ntheta) -field_simsopt.set_points(surface_simsopt.gamma().reshape((-1, 3))) -# surface_simsopt.to_vtk("simsopt_surface") - -print("Gamma") -print(jnp.sum(jnp.abs(surface_simsopt.gamma()-surface_essos.gamma))) - -print('Gamma dash theta') -print(jnp.sum(jnp.abs(surface_simsopt.gammadash2()-surface_essos.gammadash_theta))) - -print('Gamma dash phi') -print(jnp.sum(jnp.abs(surface_simsopt.gammadash1()-surface_essos.gammadash_phi))) - -print('Normal') -print(jnp.sum(jnp.abs(surface_simsopt.normal()-surface_essos.normal))) - -print('Unit normal') -print(jnp.sum(jnp.abs(surface_simsopt.unitnormal()-surface_essos.unitnormal))) - -BdotN_over_B_SIMSOPT = SquaredFlux(surface_simsopt, field_simsopt, definition="normalized").J() -BdotN_over_B_ESSOS = BdotN_over_B(surface_essos, field_essos) - -B_on_surface_simsopt = field_simsopt.B().reshape(surface_simsopt.normal().shape) -B_on_surface_ESSOS = B_on_surface(surface_essos, field_essos) -# print("ESSOS: ", BdotN_over_B_ESSOS) -# print("SIMSOPT: ", BdotN_over_B_SIMSOPT) diff --git a/examples/comparisons_simsopt/coils.py b/examples/comparisons_simsopt/coils.py new file mode 100644 index 00000000..efb7017f --- /dev/null +++ b/examples/comparisons_simsopt/coils.py @@ -0,0 +1,232 @@ +import os +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from jax import block_until_ready +from essos.fields import BiotSavart as BiotSavart_essos +from essos.coils import Coils, Curves +from simsopt import load +from simsopt.geo import CurveXYZFourier, curves_to_vtk +from simsopt.field import BiotSavart as BiotSavart_simsopt, coils_via_symmetries +from simsopt.configs import get_ncsx_data, get_w7x_data, get_hsx_data, get_giuliani_data + +output_dir = os.path.join(os.path.dirname(__file__), '../output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +n_segments = 100 + +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples/', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +nfp_array = [3, 2, 5, 4, 2] +curves_array = [get_ncsx_data()[0], LandremanPaulQA_json_file, get_w7x_data()[0], get_hsx_data()[0], get_giuliani_data()[0]] +currents_array = [get_ncsx_data()[1], None, get_w7x_data()[1], get_hsx_data()[1], get_giuliani_data()[1]] +name_array = ["NCSX", "QA(json)", "W7-X", "HSX", "Giuliani"] + +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') +for nfp, curves_stel, currents_stel, name in zip(nfp_array, curves_array, currents_array, name_array): + print(f' Running {name} and saving to output directory...') + if currents_stel is None: + json_file_stel = curves_stel + field_simsopt = load(json_file_stel) + coils_simsopt = field_simsopt.coils + curves_simsopt = [coil.curve for coil in coils_simsopt] + currents_simsopt = [coil.current for coil in coils_simsopt] + coils_essos = Coils.from_simsopt(json_file_stel, nfp) + curves_essos = Curves.from_simsopt(json_file_stel, nfp) + else: + coils_simsopt = coils_via_symmetries(curves_stel, currents_stel, nfp, True) + curves_simsopt = [c.curve for c in coils_simsopt] + currents_simsopt = [c.current for c in coils_simsopt] + field_simsopt = BiotSavart_simsopt(coils_simsopt) + + coils_essos = Coils.from_simsopt(coils_simsopt, nfp) + curves_essos = Curves.from_simsopt(curves_simsopt, nfp) + + field_essos = BiotSavart_essos(coils_essos) + + coils_essos_to_simsopt = coils_essos.to_simsopt() + curves_essos_to_simsopt = curves_essos.to_simsopt() + field_essos_to_simsopt = BiotSavart_simsopt(coils_essos_to_simsopt) + + # curves_to_vtk(curves_simsopt, os.path.join(output_dir,f"curves_simsopt_{name}")) + # curves_essos.to_vtk(os.path.join(output_dir,f"curves_essos_{name}")) + # curves_to_vtk(curves_essos_to_simsopt, os.path.join(output_dir,f"curves_essos_to_simsopt_{name}")) + + base_coils_simsopt = coils_simsopt[:int(len(coils_simsopt)/2/nfp)] + R = jnp.mean(jnp.array([jnp.sqrt(coil.curve.x[coil.curve.local_dof_names.index('xc(0)')]**2 + +coil.curve.x[coil.curve.local_dof_names.index('yc(0)')]**2) + for coil in base_coils_simsopt])) + x = jnp.array([R+0.01,R,R]) + y = jnp.array([R,R+0.01,R-0.01]) + z = jnp.array([0.05,0.06,0.07]) + + positions = jnp.array((x,y,z)) + + def update_nsegments_simsopt(curve_simsopt, n_segments): + new_curve = CurveXYZFourier(n_segments, curve_simsopt.order) + new_curve.x = curve_simsopt.x + return new_curve + + coils_essos.n_segments = n_segments + + base_curves_simsopt = [update_nsegments_simsopt(coil_simsopt.curve, n_segments) for coil_simsopt in base_coils_simsopt] + coils_simsopt = coils_via_symmetries(base_curves_simsopt, currents_simsopt[0:len(base_coils_simsopt)], nfp, True) + curves_simsopt = [c.curve for c in coils_simsopt] + + # Running the first time for compilation + [curve.gamma() for curve in curves_simsopt] + [curve.gammadash() for curve in curves_simsopt] + [curve.gammadashdash() for curve in curves_simsopt] + coils_essos.gamma + coils_essos.gamma_dash + coils_essos.gamma_dashdash + coils_essos.curvature + coils_essos.reset_cache() + + # Running the second time for coils characteristics comparison + start_time = time() + gamma_curves_simsopt = block_until_ready(jnp.array([curve.gamma() for curve in curves_simsopt])) + t_gamma_avg_simsopt = time() - start_time + + start_time = time() + gamma_curves_essos = block_until_ready(jnp.array(coils_essos.gamma)) + t_gamma_avg_essos = time() - start_time + + start_time = time() + gammadash_curves_simsopt = block_until_ready(jnp.array([curve.gammadash() for curve in curves_simsopt])) + t_gammadash_avg_simsopt = time() - start_time + + start_time = time() + gammadash_curves_essos = block_until_ready(jnp.array(coils_essos.gamma_dash)) + t_gammadash_avg_essos = time() - start_time + + start_time = time() + gammadashdash_curves_simsopt = block_until_ready(jnp.array([curve.gammadashdash() for curve in curves_simsopt])) + t_gammadashdash_avg_simsopt = time() - start_time + + start_time = time() + gammadashdash_curves_essos = block_until_ready(jnp.array(coils_essos.gamma_dashdash)) + t_gammadashdash_avg_essos = time() - start_time + + start_time = time() + curvature_curves_simsopt = block_until_ready(jnp.array([curve.kappa() for curve in curves_simsopt])) + t_curvature_avg_simsopt = time() - start_time + + start_time = time() + curvature_curves_essos = block_until_ready(jnp.array(coils_essos.curvature)) + t_curvature_avg_essos = time() - start_time + + gamma_error_avg = jnp.linalg.norm(gamma_curves_essos - gamma_curves_simsopt) + gammadash_error_avg = jnp.linalg.norm(gammadash_curves_essos - gammadash_curves_simsopt) + gammadashdash_error_avg = jnp.linalg.norm(gammadashdash_curves_essos - gammadashdash_curves_simsopt) + curvature_error_avg = jnp.linalg.norm(curvature_curves_essos - curvature_curves_simsopt) + + # Magnetic field comparison + + field_essos = BiotSavart_essos(coils_essos) + field_simsopt = BiotSavart_simsopt(coils_simsopt) + + t_B_avg_essos = 0 + t_B_avg_simsopt = 0 + B_error_avg = 0 + t_dB_by_dX_avg_essos = 0 + t_dB_by_dX_avg_simsopt = 0 + dB_by_dX_error_avg = 0 + + for position in positions: + field_essos.B(position) + time1 = time() + result_B_essos = field_essos.B(position) + t_B_avg_essos = t_B_avg_essos + time() - time1 + normB_essos = jnp.linalg.norm(result_B_essos) + + field_simsopt.set_points(jnp.array([position])) + field_simsopt.B() + time3 = time() + field_simsopt.set_points(jnp.array([position])) + result_simsopt = field_simsopt.B() + t_B_avg_simsopt = t_B_avg_simsopt + time() - time3 + normB_simsopt = jnp.linalg.norm(jnp.array(result_simsopt)) + + B_error_avg = B_error_avg + jnp.abs(normB_essos - normB_simsopt) + + field_essos.dB_by_dX(position) + time1 = time() + field_simsopt.set_points(jnp.array([position])) + result_dB_by_dX_essos = field_essos.dB_by_dX(position) + t_dB_by_dX_avg_essos = t_dB_by_dX_avg_essos + time() - time1 + norm_dB_by_dX_essos = jnp.linalg.norm(result_dB_by_dX_essos) + + field_simsopt.dB_by_dX() + time3 = time() + field_simsopt.set_points(jnp.array([position])) + result_dB_by_dX_simsopt = field_simsopt.dB_by_dX() + t_dB_by_dX_avg_simsopt = t_dB_by_dX_avg_simsopt + time() - time3 + norm_dB_by_dX_simsopt = jnp.linalg.norm(jnp.array(result_dB_by_dX_simsopt)) + + dB_by_dX_error_avg = dB_by_dX_error_avg + jnp.abs(norm_dB_by_dX_essos - norm_dB_by_dX_simsopt) + + # Labels and corresponding absolute errors (ESSOS - SIMSOPT) + quantities_errors = [ + (r"$B$", jnp.abs(B_error_avg)), + (r"$B'$", jnp.abs(dB_by_dX_error_avg)), + (r"$\Gamma$", jnp.abs(gamma_error_avg)), + (r"$\Gamma'$", jnp.abs(gammadash_error_avg)), + (r"$\Gamma''$", jnp.abs(gammadashdash_error_avg)), + (r"$\kappa$", jnp.abs(curvature_error_avg)), + ] + + labels = [q[0] for q in quantities_errors] + error_vals = [q[1] for q in quantities_errors] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.6 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis, error_vals, bar_width, color="darkorange", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Absolute error") + ax.set_yscale("log") + ax.set_ylim(1e-17, 1e-12) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparisons_coils_error_{name}.pdf"), transparent=True) + plt.close() + + + # Labels and corresponding timings + quantities = [ + (r"$B$", t_B_avg_essos, t_B_avg_simsopt), + (r"$B'$", t_dB_by_dX_avg_essos, t_dB_by_dX_avg_simsopt), + (r"$\Gamma$", t_gamma_avg_essos, t_gamma_avg_simsopt), + (r"$\Gamma'$", t_gammadash_avg_essos, t_gammadash_avg_simsopt), + (r"$\Gamma''$", t_gammadashdash_avg_essos, t_gammadashdash_avg_simsopt), + (r"$\kappa$", t_curvature_avg_essos, t_curvature_avg_simsopt), + ] + + labels = [q[0] for q in quantities] + essos_vals = [q[1] for q in quantities] + simsopt_vals = [q[2] for q in quantities] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.35 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") + ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Computation time (s)") + ax.set_yscale("log") + ax.set_ylim(1e-5, 1e-1) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + ax.legend(fontsize=12) + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparisons_coils_time_{name}.pdf"), transparent=True) + plt.close() diff --git a/examples/comparisons_simsopt/field_lines.py b/examples/comparisons_simsopt/field_lines.py new file mode 100644 index 00000000..445d36e6 --- /dev/null +++ b/examples/comparisons_simsopt/field_lines.py @@ -0,0 +1,179 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax.numpy as jnp +from jax import block_until_ready, random +from simsopt import load +from simsopt.field import (particles_to_vtk, compute_fieldlines, plot_poincare_data) +from essos.coils import Coils +from essos.constants import PROTON_MASS, ONE_EV +from essos.dynamics import Tracing, Particles +from essos.fields import BiotSavart as BiotSavart_essos +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +plt.rcParams.update({'font.size': 18}) + +tmax_fl = 2000 +nfieldlines = 5 +axis_shift=0.02 +R0 = jnp.linspace(1.2125346+axis_shift, 1.295-axis_shift, nfieldlines) +trace_tolerance_array = [1e-5, 1e-7, 1e-9, 1e-11, 1e-13] +trace_tolerance_ESSOS = 1e-9 +mass=PROTON_MASS +energy=5000*ONE_EV + +output_dir = os.path.join(os.path.dirname(__file__), '../output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +nfp=2 +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +field_simsopt = load(LandremanPaulQA_json_file) +field_essos = BiotSavart_essos(Coils.from_simsopt(LandremanPaulQA_json_file, nfp)) + +Z0 = jnp.zeros(nfieldlines) +phi0 = jnp.zeros(nfieldlines) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T +initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nfieldlines,), minval=-1, maxval=1) + +phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] + +particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy) + +# Trace in SIMSOPT +runtime_SIMSOPT_array = [] +trajectories_SIMSOPT_array = [] +avg_steps_SIMSOPT_array = [] + +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}\n') +for trace_tolerance_SIMSOPT in trace_tolerance_array: + print(f'Tracing SIMSOPT field lines with tolerance={trace_tolerance_SIMSOPT}') + t1 = time() + trajectories_SIMSOPT, trajectories_SIMSOPT_phi_hits = block_until_ready(compute_fieldlines( + field_simsopt, R0, Z0, tmax=tmax_fl, tol=trace_tolerance_SIMSOPT, phis=phis_poincare)) + runtime_SIMSOPT = time() - t1 + runtime_SIMSOPT_array.append(runtime_SIMSOPT) + avg_steps_SIMSOPT = sum([len(l) for l in trajectories_SIMSOPT]) // nfieldlines + avg_steps_SIMSOPT_array.append(avg_steps_SIMSOPT) + # print(trajectories_SIMSOPT_this_tolerance[0].shape) + print(f"Time for SIMSOPT tracing={runtime_SIMSOPT:.3f}s. Avg num steps={avg_steps_SIMSOPT}\n") + trajectories_SIMSOPT_array.append(trajectories_SIMSOPT) + + # particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) + +# Trace in ESSOS +runtime_ESSOS_array = [] +times_essos_array = [] +trajectories_ESSOS_array = [] +relative_energy_error_ESSOS_array = [] + +# Creating a tracing object for compilation +compile_tracing = Tracing('FieldLine', field_essos, tmax_fl, initial_conditions=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T, + timesteps=100, method='Dopri5', stepsize='adaptive', tol_step_size=trace_tolerance_array[0]) +block_until_ready(compile_tracing.trajectories) + +for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): + num_steps_essos = avg_steps_SIMSOPT_array[index] + print(f'Tracing ESSOS field lines with tolerance={trace_tolerance_ESSOS}') + start_time = time() + tracing = Tracing('FieldLine', field_essos, tmax_fl, initial_conditions=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T, + timesteps=num_steps_essos, method='Dopri5', stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS) + block_until_ready(tracing.trajectories) + runtime_ESSOS = time() - start_time + runtime_ESSOS_array.append(runtime_ESSOS) + times_essos_array.append(tracing.times) + trajectories_ESSOS_array.append(tracing.trajectories) + # print(tracing.trajectories.shape) + + trajectories_ESSOS = tracing.trajectories + print(f"Time for ESSOS tracing={runtime_ESSOS:.3f}s. Num steps={len(trajectories_ESSOS[0])}\n") + + # tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) + +print('Plotting the results to output directory...') + +# Plot time comparison in a bar chart +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", runtime_ESSOS_array[tolerance_idx], runtime_SIMSOPT_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +essos_vals = [q[1] for q in quantities] +simsopt_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Computation time (s)") +ax.set_yscale('log') +ax.set_ylim(1e0, 1e2) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'comparisons_fl_times.pdf'), dpi=150) + +################################## + +def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): + time_simsopt = trajectory_SIMSOPT[:, 0] # Time values from SIMSOPT trajectory + + interp_x = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 1]) + interp_y = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 2]) + interp_z = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 3]) + + coords_SIMSOPT_interp = jnp.column_stack([interp_x, interp_y, interp_z]) + + return coords_SIMSOPT_interp + +plt.figure(figsize=(9, 6)) + +avg_relative_xyz_error_array = [] +for tolerance_idx in range(len(trace_tolerance_array)): + this_trajectory_SIMSOPT = jnp.stack([interpolate_SIMSOPT_to_ESSOS( + trajectories_SIMSOPT_array[tolerance_idx][particle_idx], times_essos_array[tolerance_idx] + ) for particle_idx in range(nfieldlines)]) + + this_trajectory_ESSOS = trajectories_ESSOS_array[tolerance_idx] + + relative_xyz_errors = jnp.linalg.norm(this_trajectory_ESSOS[:, :, :3] - this_trajectory_SIMSOPT[:, :, :3], axis=2) / (jnp.linalg.norm(this_trajectory_SIMSOPT[:, :, :3], axis=2) + 1e-12) + + avg_relative_xyz_errors = jnp.mean(relative_xyz_errors, axis=0) + avg_relative_xyz_error_array.append(jnp.mean(avg_relative_xyz_errors)) + + plt.plot(times_essos_array[tolerance_idx], avg_relative_xyz_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + +plt.legend() +plt.xlabel('Time (a.u.)') +plt.yscale('log') + +plt.ylabel(r'Relative $x,y,z$ Error') +plt.savefig(os.path.join(output_dir, f'comparisons_fl_error_xyz.pdf'), dpi=150) + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +xyz_vals = [q[1] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.4 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis, xyz_vals, bar_width, label=r"x,y,z", color="darkorange", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Time-averaged relative error") +ax.set_yscale('log') +ax.set_ylim(1e-6, 1e-1) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'comparisons_fl_error.pdf'), dpi=150) + +plt.show() \ No newline at end of file diff --git a/examples/comparisons_simsopt/full_orbit.py b/examples/comparisons_simsopt/full_orbit.py new file mode 100644 index 00000000..ad7c8b65 --- /dev/null +++ b/examples/comparisons_simsopt/full_orbit.py @@ -0,0 +1,262 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax.numpy as jnp +from jax import block_until_ready, random +from simsopt import load +from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) +from essos.coils import Coils +from essos.constants import PROTON_MASS, ONE_EV +from essos.dynamics import Tracing, Particles +from essos.fields import BiotSavart as BiotSavart_essos +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +plt.rcParams.update({'font.size': 18}) + +######################################################################################## +method = 'Boris' # 'Boris' or 'Dopri5' +######################################################################################## + + +tmax = 5e-5 +nparticles = 5 +axis_shft=0.02 +R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nparticles) +trace_tolerance_array = [1e-5, 1e-7, 1e-9, 1e-11, 1e-13] +trace_tolerance_ESSOS = 1e-9 +mass=PROTON_MASS +energy=5000*ONE_EV + +output_dir = os.path.join(os.path.dirname(__file__), '../output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +nfp=2 +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +field_simsopt = load(LandremanPaulQA_json_file) +field_essos = BiotSavart_essos(Coils.from_simsopt(LandremanPaulQA_json_file, nfp)) + +Z0 = jnp.zeros(nparticles) +phi0 = jnp.zeros(nparticles) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T +initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nparticles,), minval=-1, maxval=1) + +particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy, field=field_essos) + +# Trace in SIMSOPT +runtime_SIMSOPT_array = [] +trajectories_SIMSOPT_array = [] +avg_steps_SIMSOPT_array = [] +relative_energy_error_SIMSOPT_array = [] +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}\n') +for trace_tolerance_SIMSOPT in trace_tolerance_array: + print(f'Tracing SIMSOPT full orbit with tolerance={trace_tolerance_SIMSOPT}') + t1 = time() + trajectories_SIMSOPT, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( + field=field_simsopt, xyz_inits=particles.initial_xyz, mass=particles.mass, + parallel_speeds=particles.initial_vparallel, tmax=tmax, mode='full', + charge=particles.charge, Ekin=particles.energy, tol=trace_tolerance_SIMSOPT)) + runtime_SIMSOPT = time() - t1 + runtime_SIMSOPT_array.append(runtime_SIMSOPT) + avg_steps_SIMSOPT = sum([len(l) for l in trajectories_SIMSOPT]) // nparticles + avg_steps_SIMSOPT_array.append(avg_steps_SIMSOPT) + + print(f"Time for SIMSOPT tracing={runtime_SIMSOPT:.3f}s. Avg num steps={avg_steps_SIMSOPT}\n") + trajectories_SIMSOPT_array.append(trajectories_SIMSOPT) + + relative_energy_SIMSOPT = [jnp.abs(0.5 * mass * jnp.sum(jnp.square(trajectory[:, 4:]), axis=1) - particles.energy) / particles.energy + for trajectory in trajectories_SIMSOPT] + + relative_energy_error_SIMSOPT_array.append(relative_energy_SIMSOPT) + + # particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) + +# Trace in ESSOS +runtime_ESSOS_array = [] +times_essos_array = [] +trajectories_ESSOS_array = [] +relative_energy_error_ESSOS_array = [] + +# Creating a tracing object for compilation +if method == 'Dopri5': + compile_tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=100, method='Dopri5', + stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) +else: + compile_tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=100, method='Boris', + stepsize='constant', particles=particles) + +block_until_ready(compile_tracing.trajectories) + +for tolerance_idx, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): + print(f'Tracing ESSOS full orbit with tolerance={trace_tolerance_ESSOS}') + start_time = time() + if method == 'Dopri5': + num_steps_essos = 10000 + tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Dopri5', + stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) + else: + num_steps_essos = avg_steps_SIMSOPT_array[tolerance_idx]*3 + tracing = Tracing('FullOrbit', field_essos, tmax, timesteps=num_steps_essos, method='Boris', + stepsize='constant', particles=particles) + + block_until_ready(tracing.trajectories) + runtime_ESSOS = time() - start_time + runtime_ESSOS_array.append(runtime_ESSOS) + times_essos_array.append(tracing.times) + trajectories_ESSOS_array.append(tracing.trajectories) + # print(tracing.trajectories.shape) + + trajectories_ESSOS = tracing.trajectories + print(f"Time for ESSOS tracing={runtime_ESSOS:.3f}s. Num steps={len(trajectories_ESSOS[0])}\n") + + relative_energy_error_ESSOS = jnp.abs(tracing.energy()-particles.energy)/particles.energy + relative_energy_error_ESSOS_array.append(relative_energy_error_ESSOS) + # tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) + +print('Plotting the results to output directory...') +plt.figure(figsize=(9, 6)) +colors = ['blue', 'orange', 'green', 'red', 'purple'] + +SIMSOPT_energy_interp = [] + +for tolerance_idx in range(len(trace_tolerance_array)): + interpolation = jnp.stack([ + jnp.interp(times_essos_array[tolerance_idx], trajectories_SIMSOPT_array[tolerance_idx][particle_idx][:, 0], relative_energy_error_SIMSOPT_array[tolerance_idx][particle_idx]) + for particle_idx in range(nparticles) + ]) # This will have shape (nparticles, len(times_essos_array[tolerance_idx])) + SIMSOPT_energy_interp.append(interpolation) + + plt.plot(times_essos_array[tolerance_idx]*1000, jnp.mean(interpolation, axis=0), '--', color=colors[tolerance_idx]) + plt.plot(times_essos_array[tolerance_idx]*1000, jnp.mean(relative_energy_error_ESSOS_array[tolerance_idx], axis=0), '-', color=colors[tolerance_idx]) + +legend_elements = [Line2D([0], [0], color=colors[tolerance_idx], linestyle='-', label=fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$") + for tolerance_idx in range(len(trace_tolerance_array))] + +plt.legend(handles=legend_elements, loc='lower right', title='ESSOS (─), SIMSOPT (--)', fontsize=14, title_fontsize=14) +plt.yscale('log') +plt.xlabel('Time (ms)') +plt.ylabel('Average relative energy error') +plt.tight_layout() +if method == 'Dopri5': + plt.savefig(os.path.join(output_dir, f'comparisons_fo_error_energy.pdf'), dpi=150) +else: + plt.savefig(os.path.join(output_dir, f'comparisons_fo_boris_error_energy.pdf'), dpi=150) + +# Plot time comparison in a bar chart + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", runtime_ESSOS_array[tolerance_idx], runtime_SIMSOPT_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +essos_vals = [q[1] for q in quantities] +simsopt_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Computation time (s)") +ax.set_yscale('log') +if method == 'Dopri5': + ax.set_ylim(1e0, 1e3) +else: + ax.set_ylim(1e-1, 1e3) + +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +if method == 'Dopri5': + plt.savefig(os.path.join(output_dir, 'comparisons_fo_times.pdf'), dpi=150) +else: + plt.savefig(os.path.join(output_dir, 'comparisons_fo_boris_times.pdf'), dpi=150) + +################################## + +def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): + time_simsopt = trajectory_SIMSOPT[:, 0] # Time values from SIMSOPT trajectory + + interp_x = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 1]) + interp_y = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 2]) + interp_z = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 3]) + interp_vx = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 4]) + interp_vy = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 5]) + interp_vz = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 6]) + + coords_SIMSOPT_interp = jnp.column_stack([interp_x, interp_y, interp_z, interp_vx, interp_vy, interp_vz]) + + return coords_SIMSOPT_interp + +xyz_error_fig, xyz_error_ax = plt.subplots(figsize=(9, 6)) +v_error_fig, v_error_ax = plt.subplots(figsize=(9, 6)) + +avg_relative_xyz_error_array = [] +avg_relative_v_error_array = [] +for tolerance_idx in range(len(trace_tolerance_array)): + this_trajectory_SIMSOPT = jnp.stack([interpolate_SIMSOPT_to_ESSOS( + trajectories_SIMSOPT_array[tolerance_idx][particle_idx], times_essos_array[tolerance_idx] + ) for particle_idx in range(nparticles)]) + + this_trajectory_ESSOS = trajectories_ESSOS_array[tolerance_idx] + + relative_xyz_errors = jnp.linalg.norm(this_trajectory_ESSOS[:, :, :3] - this_trajectory_SIMSOPT[:, :, :3], axis=2) / (jnp.linalg.norm(this_trajectory_SIMSOPT[:, :, :3], axis=2) + 1e-12) + relative_v_errors = jnp.linalg.norm(this_trajectory_ESSOS[:, :, 3:] - this_trajectory_SIMSOPT[:, :, 3:], axis=2) / (jnp.linalg.norm(this_trajectory_SIMSOPT[:, :, 3:], axis=2) + 1e-12) + + avg_relative_xyz_errors = jnp.mean(relative_xyz_errors, axis=0) + avg_relative_v_errors = jnp.mean(relative_v_errors, axis=0) + avg_relative_xyz_error_array.append(jnp.mean(avg_relative_xyz_errors)) + avg_relative_v_error_array.append(jnp.mean(avg_relative_v_errors)) + + xyz_error_ax.plot(times_essos_array[tolerance_idx]*1000, avg_relative_xyz_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + v_error_ax.plot(times_essos_array[tolerance_idx]*1000, avg_relative_v_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + +for ax, fig in zip([xyz_error_ax, v_error_ax], [xyz_error_fig, v_error_fig]): + ax.legend() + ax.set_xlabel('Time (ms)') + ax.set_yscale('log') + +xyz_error_ax.set_ylabel(r'Relative $x,y,z$ Error') +v_error_ax.set_ylabel(r'Relative $v_x,v_y,v_z$ Error') +if method == 'Dopri5': + xyz_error_fig.savefig(os.path.join(output_dir, f'comparisons_fo_error_xyz.pdf'), dpi=150) + v_error_fig.savefig(os.path.join(output_dir, f'comparisons_fo_error_v.pdf'), dpi=150) +else: + xyz_error_fig.savefig(os.path.join(output_dir, f'comparisons_fo_boris_error_xyz.pdf'), dpi=150) + v_error_fig.savefig(os.path.join(output_dir, f'comparisons_fo_boris_error_v.pdf'), dpi=150) + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx], avg_relative_v_error_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +xyz_vals = [q[1] for q in quantities] +v_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, xyz_vals, bar_width, label=r"x,y,z", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, v_vals, bar_width, label=r"$v_x,v_y,v_z$", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Time-averaged relative error") +ax.set_yscale('log') +ax.set_ylim(1e-6, 1e1) +if method == 'Dopri5': + ax.set_ylim(1e-8, 1e-1) +else: + ax.set_ylim(1e-4, 1e0) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +if method == 'Dopri5': + plt.savefig(os.path.join(output_dir, 'comparisons_fo_errors.pdf'), dpi=150) +else: + plt.savefig(os.path.join(output_dir, 'comparisons_fo_boris_errors.pdf'), dpi=150) + +plt.show() \ No newline at end of file diff --git a/examples/comparisons_simsopt/guiding_center.py b/examples/comparisons_simsopt/guiding_center.py new file mode 100644 index 00000000..8798d854 --- /dev/null +++ b/examples/comparisons_simsopt/guiding_center.py @@ -0,0 +1,230 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +import jax.numpy as jnp +from jax import block_until_ready, random +from simsopt import load +from simsopt.field import (particles_to_vtk, trace_particles, plot_poincare_data) +from essos.coils import Coils +from essos.constants import PROTON_MASS, ONE_EV +from essos.dynamics import Tracing, Particles +from essos.fields import BiotSavart as BiotSavart_essos +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +plt.rcParams.update({'font.size': 18}) + +tmax_gc = 5e-4 +nparticles = 5 +axis_shft=0.02 +R0 = jnp.linspace(1.2125346+axis_shft, 1.295-axis_shft, nparticles) +trace_tolerance_array = [1e-5, 1e-7, 1e-9, 1e-11, 1e-13] +trace_tolerance_ESSOS = 1e-9 +mass=PROTON_MASS +energy=5000*ONE_EV + +output_dir = os.path.join(os.path.dirname(__file__), '../output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +nfp=2 +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +field_simsopt = load(LandremanPaulQA_json_file) +field_essos = BiotSavart_essos(Coils.from_simsopt(LandremanPaulQA_json_file, nfp)) + +Z0 = jnp.zeros(nparticles) +phi0 = jnp.zeros(nparticles) +initial_xyz=jnp.array([R0*jnp.cos(phi0), R0*jnp.sin(phi0), Z0]).T +initial_vparallel_over_v = random.uniform(random.PRNGKey(42), (nparticles,), minval=-1, maxval=1) + +phis_poincare = [(i/4)*(2*jnp.pi/nfp) for i in range(4)] + +particles = Particles(initial_xyz=initial_xyz, initial_vparallel_over_v=initial_vparallel_over_v, mass=mass, energy=energy) + +# Trace in SIMSOPT +runtime_SIMSOPT_array = [] +trajectories_SIMSOPT_array = [] +avg_steps_SIMSOPT_array = [] +relative_energy_error_SIMSOPT_array = [] +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}\n') +for trace_tolerance_SIMSOPT in trace_tolerance_array: + print(f'Tracing SIMSOPT guiding center with tolerance={trace_tolerance_SIMSOPT}') + t1 = time() + trajectories_SIMSOPT, trajectories_SIMSOPT_phi_hits = block_until_ready(trace_particles( + field=field_simsopt, xyz_inits=particles.initial_xyz, mass=particles.mass, + parallel_speeds=particles.initial_vparallel, tmax=tmax_gc, mode='gc_vac', + charge=particles.charge, Ekin=particles.energy, tol=trace_tolerance_SIMSOPT)) + runtime_SIMSOPT = time() - t1 + runtime_SIMSOPT_array.append(runtime_SIMSOPT) + avg_steps_SIMSOPT = sum([len(l) for l in trajectories_SIMSOPT]) // nparticles + avg_steps_SIMSOPT_array.append(avg_steps_SIMSOPT) + # print(trajectories_SIMSOPT_this_tolerance[0].shape) + print(f"Time for SIMSOPT tracing={runtime_SIMSOPT:.3f}s. Avg num steps={avg_steps_SIMSOPT}\n") + trajectories_SIMSOPT_array.append(trajectories_SIMSOPT) + + relative_energy_SIMSOPT = [] + for i, trajectory in enumerate(trajectories_SIMSOPT): + xyz = jnp.asarray(trajectory[:, 1:4]) + vpar = trajectory[:, 4] + field_simsopt.set_points(xyz) + AbsB = field_simsopt.AbsB()[:,0] + mu = (particles.energy - particles.mass*vpar[0]**2/2)/AbsB[0] + relative_energy_SIMSOPT.append(jnp.abs(particles.mass*vpar**2/2+mu*AbsB-particles.energy)/particles.energy) + relative_energy_error_SIMSOPT_array.append(relative_energy_SIMSOPT) + + # particles_to_vtk(trajectories_SIMSOPT_this_tolerance, os.path.join(output_dir,f'guiding_center_SIMSOPT')) + +# Trace in ESSOS +runtime_ESSOS_array = [] +times_essos_array = [] +trajectories_ESSOS_array = [] +relative_energy_error_ESSOS_array = [] + +# Creating a tracing object for compilation +compile_tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=100, method='Dopri5', + stepsize='adaptive', tol_step_size=trace_tolerance_array[0], particles=particles) +block_until_ready(compile_tracing.trajectories) + +for index, trace_tolerance_ESSOS in enumerate(trace_tolerance_array): + num_steps_essos = avg_steps_SIMSOPT_array[index] + print(f'Tracing ESSOS guiding center with tolerance={trace_tolerance_ESSOS}') + start_time = time() + tracing = Tracing('GuidingCenter', field_essos, tmax_gc, timesteps=num_steps_essos, method='Dopri5', + stepsize='adaptive', tol_step_size=trace_tolerance_ESSOS, particles=particles) + block_until_ready(tracing.trajectories) + runtime_ESSOS = time() - start_time + runtime_ESSOS_array.append(runtime_ESSOS) + times_essos_array.append(tracing.times) + trajectories_ESSOS_array.append(tracing.trajectories) + # print(tracing.trajectories.shape) + + trajectories_ESSOS = tracing.trajectories + print(f"Time for ESSOS tracing={runtime_ESSOS:.3f}s. Num steps={len(trajectories_ESSOS[0])}\n") + + relative_energy_error_ESSOS = jnp.abs(tracing.energy()-particles.energy)/particles.energy + relative_energy_error_ESSOS_array.append(relative_energy_error_ESSOS) + # tracing.to_vtk(os.path.join(output_dir,f'guiding_center_ESSOS')) + +print('Plotting the results to output directory...') +plt.figure(figsize=(9, 6)) +colors = ['blue', 'orange', 'green', 'red', 'purple'] + +SIMSOPT_energy_interp = [] + +for tolerance_idx in range(len(trace_tolerance_array)): + interpolation = jnp.stack([ + jnp.interp(times_essos_array[tolerance_idx], trajectories_SIMSOPT_array[tolerance_idx][particle_idx][:, 0], relative_energy_error_SIMSOPT_array[tolerance_idx][particle_idx]) + for particle_idx in range(nparticles) + ]) # This will have shape (nparticles, len(times_essos_array[tolerance_idx])) + SIMSOPT_energy_interp.append(interpolation) + + plt.plot(times_essos_array[tolerance_idx]*1000, jnp.mean(interpolation, axis=0), '--', color=colors[tolerance_idx]) + plt.plot(times_essos_array[tolerance_idx]*1000, jnp.mean(relative_energy_error_ESSOS_array[tolerance_idx], axis=0), '-', color=colors[tolerance_idx]) + +legend_elements = [Line2D([0], [0], color=colors[tolerance_idx], linestyle='-', label=fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$") + for tolerance_idx in range(len(trace_tolerance_array))] + +plt.legend(handles=legend_elements, loc='lower right', title='ESSOS (─), SIMSOPT (--)', fontsize=14, title_fontsize=14) +plt.yscale('log') +plt.xlabel('Time (ms)') +plt.ylabel('Average relative energy error') +plt.tight_layout() +plt.savefig(os.path.join(output_dir, f'comparisons_gc_error_energy.pdf'), dpi=150) + +# Plot time comparison in a bar chart + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", runtime_ESSOS_array[tolerance_idx], runtime_SIMSOPT_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +essos_vals = [q[1] for q in quantities] +simsopt_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Computation time (s)") +ax.set_yscale('log') +ax.set_ylim(1e0, 1e2) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'comparisons_gc_times.pdf'), dpi=150) + +################################## + +def interpolate_SIMSOPT_to_ESSOS(trajectory_SIMSOPT, time_ESSOS): + time_simsopt = trajectory_SIMSOPT[:, 0] # Time values from SIMSOPT trajectory + + interp_x = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 1]) + interp_y = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 2]) + interp_z = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 3]) + interp_v = jnp.interp(time_ESSOS, time_simsopt, trajectory_SIMSOPT[:, 4]) + + coords_SIMSOPT_interp = jnp.column_stack([interp_x, interp_y, interp_z, interp_v]) + + return coords_SIMSOPT_interp + +xyz_error_fig, xyz_error_ax = plt.subplots(figsize=(9, 6)) +vpar_error_fig, vpar_error_ax = plt.subplots(figsize=(9, 6)) + +avg_relative_xyz_error_array = [] +avg_relative_v_error_array = [] +for tolerance_idx in range(len(trace_tolerance_array)): + this_trajectory_SIMSOPT = jnp.stack([interpolate_SIMSOPT_to_ESSOS( + trajectories_SIMSOPT_array[tolerance_idx][particle_idx], times_essos_array[tolerance_idx] + ) for particle_idx in range(nparticles)]) + + this_trajectory_ESSOS = trajectories_ESSOS_array[tolerance_idx] + + relative_xyz_errors = jnp.linalg.norm(this_trajectory_ESSOS[:, :, :3] - this_trajectory_SIMSOPT[:, :, :3], axis=2) / (jnp.linalg.norm(this_trajectory_SIMSOPT[:, :, :3], axis=2) + 1e-12) + relative_v_errors = jnp.abs(this_trajectory_SIMSOPT[:, :, 3] - this_trajectory_ESSOS[:, :, 3]) / (jnp.abs(this_trajectory_SIMSOPT[:, :, 3]) + 1e-12) + + avg_relative_xyz_errors = jnp.mean(relative_xyz_errors, axis=0) + avg_relative_v_errors = jnp.mean(relative_v_errors, axis=0) + avg_relative_xyz_error_array.append(jnp.mean(avg_relative_xyz_errors)) + avg_relative_v_error_array.append(jnp.mean(avg_relative_v_errors)) + + xyz_error_ax.plot(times_essos_array[tolerance_idx]*1000, avg_relative_xyz_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + vpar_error_ax.plot(times_essos_array[tolerance_idx]*1000, avg_relative_v_errors, label=rf'tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$') + +for ax, fig in zip([xyz_error_ax, vpar_error_ax], [xyz_error_fig, vpar_error_fig]): + ax.legend() + ax.set_xlabel('Time (ms)') + ax.set_yscale('log') + +xyz_error_ax.set_ylabel(r'Relative $x,y,z$ Error') +vpar_error_ax.set_ylabel(r'Relative $v_\parallel$ Error') +xyz_error_fig.savefig(os.path.join(output_dir, f'comparisons_gc_error_xyz.pdf'), dpi=150) +vpar_error_fig.savefig(os.path.join(output_dir, f'comparisons_gc_error_vpar.pdf'), dpi=150) + +quantities = [(fr"tol=$10^{{{int(jnp.log10(trace_tolerance_array[tolerance_idx])-1e-3)}}}$", avg_relative_xyz_error_array[tolerance_idx], avg_relative_v_error_array[tolerance_idx]) + for tolerance_idx in range(len(trace_tolerance_array))] + +labels = [q[0] for q in quantities] +xyz_vals = [q[1] for q in quantities] +vpar_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, xyz_vals, bar_width, label=r"x,y,z", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, vpar_vals, bar_width, label=r"$v_\parallel$", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Time-averaged relative error") +ax.set_yscale('log') +ax.set_ylim(1e-6, 1e-1) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=14) +plt.savefig(os.path.join(output_dir, 'comparisons_gc_error.pdf'), dpi=150) + +plt.show() \ No newline at end of file diff --git a/examples/comparisons_simsopt/losses.py b/examples/comparisons_simsopt/losses.py new file mode 100644 index 00000000..837e3ca4 --- /dev/null +++ b/examples/comparisons_simsopt/losses.py @@ -0,0 +1,197 @@ +import os +from time import perf_counter as time +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from jax import block_until_ready +from essos.fields import BiotSavart as BiotSavart_essos +from essos.coils import Coils, Curves +from essos.objective_functions import loss_coil_curvature, loss_coil_separation, compute_candidates, loss_coil_length +from simsopt import load +from simsopt.geo import CurveXYZFourier, curves_to_vtk, CurveCurveDistance, LpCurveCurvature, CurveLength +from simsopt.field import BiotSavart as BiotSavart_simsopt, coils_via_symmetries +from simsopt.configs import get_ncsx_data, get_w7x_data, get_hsx_data, get_giuliani_data + +output_dir = os.path.join(os.path.dirname(__file__), '../output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +n_segments = 100 + +LandremanPaulQA_json_file = os.path.join(os.path.dirname(__file__), '../../examples/', 'input_files', 'SIMSOPT_biot_savart_LandremanPaulQA.json') +nfp_array = [3, 2, 5, 4, 2] +curves_array = [get_ncsx_data()[0], LandremanPaulQA_json_file, get_w7x_data()[0], get_hsx_data()[0], get_giuliani_data()[0]] +currents_array = [get_ncsx_data()[1], None, get_w7x_data()[1], get_hsx_data()[1], get_giuliani_data()[1]] +name_array = ["NCSX", "QA(json)", "W7-X", "HSX", "Giuliani"] + +print(f'Output being saved to {output_dir}') +print(f'SIMSOPT LandremanPaulQA json file location: {LandremanPaulQA_json_file}') +for nfp, curves_stel, currents_stel, name in zip(nfp_array, curves_array, currents_array, name_array): + print(f' Running {name} and saving to output directory...') + if currents_stel is None: + json_file_stel = curves_stel + field_simsopt = load(json_file_stel) + coils_simsopt = field_simsopt.coils + curves_simsopt = [coil.curve for coil in coils_simsopt] + currents_simsopt = [coil.current for coil in coils_simsopt] + coils_essos = Coils.from_simsopt(json_file_stel, nfp) + curves_essos = Curves.from_simsopt(json_file_stel, nfp) + else: + coils_simsopt = coils_via_symmetries(curves_stel, currents_stel, nfp, True) + curves_simsopt = [c.curve for c in coils_simsopt] + currents_simsopt = [c.current for c in coils_simsopt] + field_simsopt = BiotSavart_simsopt(coils_simsopt) + + coils_essos = Coils.from_simsopt(coils_simsopt, nfp) + curves_essos = Curves.from_simsopt(curves_simsopt, nfp) + + field_essos = BiotSavart_essos(coils_essos) + + coils_essos_to_simsopt = coils_essos.to_simsopt() + curves_essos_to_simsopt = curves_essos.to_simsopt() + field_essos_to_simsopt = BiotSavart_simsopt(coils_essos_to_simsopt) + + # curves_to_vtk(curves_simsopt, os.path.join(output_dir,f"curves_simsopt_{name}")) + # curves_essos.to_vtk(os.path.join(output_dir,f"curves_essos_{name}")) + # curves_to_vtk(curves_essos_to_simsopt, os.path.join(output_dir,f"curves_essos_to_simsopt_{name}")) + + base_coils_simsopt = coils_simsopt[:int(len(coils_simsopt)/2/nfp)] + R = jnp.mean(jnp.array([jnp.sqrt(coil.curve.x[coil.curve.local_dof_names.index('xc(0)')]**2 + +coil.curve.x[coil.curve.local_dof_names.index('yc(0)')]**2) + for coil in base_coils_simsopt])) + x = jnp.array([R+0.01,R,R]) + y = jnp.array([R,R+0.01,R-0.01]) + z = jnp.array([0.05,0.06,0.07]) + + positions = jnp.array((x,y,z)) + + def update_nsegments_simsopt(curve_simsopt, n_segments): + new_curve = CurveXYZFourier(n_segments, curve_simsopt.order) + new_curve.x = curve_simsopt.x + return new_curve + + coils_essos.n_segments = n_segments + + base_curves_simsopt = [update_nsegments_simsopt(coil_simsopt.curve, n_segments) for coil_simsopt in base_coils_simsopt] + coils_simsopt = coils_via_symmetries(base_curves_simsopt, currents_simsopt[0:len(base_coils_simsopt)], nfp, True) + curves_simsopt = [c.curve for c in coils_simsopt] + + # Running the first time for compilation + [LpCurveCurvature(curve, p=2, threshold=0).J() for curve in curves_simsopt] + loss_coil_curvature(coils_essos, 0) + [CurveLength(curve).J() for curve in curves_simsopt] + loss_coil_length(coils_essos, 10) + CurveCurveDistance(curves_simsopt, 0.5).J() + loss_coil_separation(coils_essos, 0.5) + + # Running the second time for losses comparison + + start_time = time() + curvature_loss_simsopt = block_until_ready(2*sum([LpCurveCurvature(curve, p=2, threshold=0).J() for curve in curves_simsopt])) + t_curvature_avg_simsopt = time() - start_time + + start_time = time() + curvature_loss_essos = block_until_ready(jnp.sum(loss_coil_curvature(coils_essos, 0))) + t_curvature_avg_essos = time() - start_time + + start_time = time() + length_loss_simsopt = block_until_ready(sum([(CurveLength(curve).J()/10 - 1)**2 for curve in curves_simsopt])) + t_length_avg_simsopt = time() - start_time + print(f"Length loss SIMSOPT: {length_loss_simsopt}") + + start_time = time() + length_loss_essos = block_until_ready(jnp.sum(loss_coil_length(coils_essos, 10))) + t_length_avg_essos = time() - start_time + print(f"Length loss ESSOS: {length_loss_essos}") + + start_time = time() + separation_loss_simsopt = block_until_ready(CurveCurveDistance(curves_simsopt, 0.5).J()) + t_separation_avg_simsopt = time() - start_time + print(f"Separation loss SIMSOPT: {separation_loss_simsopt}") + + start_time = time() + separation_loss_essos = block_until_ready(loss_coil_separation(coils_essos, 0.5)) + t_separation_avg_essos = time() - start_time + print(f"Separation loss ESSOS: {separation_loss_essos}") + + start_time = time() + ind_separation_loss_simsopt = block_until_ready(CurveCurveDistance(curves_simsopt, 0.5).J()) + t_ind_separation_avg_simsopt = time() - start_time + print(f"Independence separation loss SIMSOPT: {ind_separation_loss_simsopt}") + + start_time = time() + ind_separation_loss_essos = block_until_ready(loss_coil_separation(coils_essos, 0.5, candidates=compute_candidates(coils_essos, 0.5))) + t_ind_separation_avg_essos = time() - start_time + print(f"Independence separation loss ESSOS: {ind_separation_loss_essos}") + + length_error_avg = jnp.linalg.norm(length_loss_essos - length_loss_simsopt) / jnp.linalg.norm(length_loss_simsopt) + if length_error_avg == 0: + length_error_avg = jnp.finfo(jnp.float64).eps + curvature_error_avg = jnp.linalg.norm(curvature_loss_essos - curvature_loss_simsopt) / jnp.linalg.norm(curvature_loss_simsopt) + if curvature_error_avg == 0: + curvature_error_avg = jnp.finfo(jnp.float64).eps + separation_error_avg = jnp.linalg.norm(separation_loss_essos - separation_loss_simsopt) / jnp.linalg.norm(separation_loss_simsopt) + ind_separation_error_avg = jnp.linalg.norm(ind_separation_loss_essos - ind_separation_loss_simsopt) / jnp.linalg.norm(ind_separation_loss_simsopt) + print(f"length_error_avg: {length_error_avg:.2e}") + print(f"curvature_error_avg: {curvature_error_avg:.2e}") + print(f"separation_error_avg: {separation_error_avg:.2e}") + # print(f"ind_separation_error_avg: {ind_separation_error_avg:.2e}") + + # Labels and corresponding absolute errors (ESSOS - SIMSOPT) + quantities_errors = [ + (r"$L_\ell$", jnp.abs(length_error_avg)), + (r"$L_\kappa$", jnp.abs(curvature_error_avg)), + (r"$L_\text{sep}$", jnp.abs(separation_error_avg)), + # (r"$L_\text{sep,ind}$", jnp.abs(ind_separation_error_avg)), + ] + + labels = [q[0] for q in quantities_errors] + error_vals = [q[1] for q in quantities_errors] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.6 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis, error_vals, bar_width, color="darkorange", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Relative error") + ax.set_yscale("log") + ax.set_ylim(1e-16, 1e-1) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparisons_losses_error_{name}.pdf"), transparent=True) + plt.close() + + + # Labels and corresponding timings + quantities = [ + (r"$L_\ell$", t_length_avg_essos, t_length_avg_simsopt), + (r"$L_\kappa$", t_curvature_avg_essos, t_curvature_avg_simsopt), + (r"$L_\text{sep}$", t_separation_avg_essos, t_separation_avg_simsopt), + # (r"$L_\text{sep,ind}$", t_ind_separation_avg_essos, t_ind_separation_avg_simsopt), + ] + + labels = [q[0] for q in quantities] + essos_vals = [q[1] for q in quantities] + simsopt_vals = [q[2] for q in quantities] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.35 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") + ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Computation time (s)") + ax.set_yscale("log") + ax.set_ylim(1e-4, 1e0) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + ax.legend(fontsize=12) + plt.tight_layout() + plt.savefig(os.path.join(output_dir, f"comparisons_losses_time_{name}.pdf"), transparent=True) + plt.close() diff --git a/examples/comparisons_simsopt/surfaces.py b/examples/comparisons_simsopt/surfaces.py new file mode 100644 index 00000000..5a6d70d8 --- /dev/null +++ b/examples/comparisons_simsopt/surfaces.py @@ -0,0 +1,205 @@ +import os +from time import perf_counter as time +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +import jax.numpy as jnp +from jax import block_until_ready +from essos.coils import Coils, CreateEquallySpacedCurves +from essos.fields import Vmec, BiotSavart +from essos.surfaces import B_on_surface, BdotN_over_B, SurfaceRZFourier as SurfaceRZFourier_ESSOS, SquaredFlux as SquaredFlux_ESSOS +from simsopt.field import BiotSavart as BiotSavart_simsopt +from simsopt.geo import SurfaceRZFourier as SurfaceRZFourier_SIMSOPT +from simsopt.objectives import SquaredFlux as SquaredFlux_SIMSOPT + +output_dir = os.path.join(os.path.dirname(__file__), '../output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +# Optimization parameters +max_coil_length = 42 +order_Fourier_series_coils = 4 +number_coil_points = 50 +function_evaluations_array = [30]*1 +diff_step_array = [1e-2]*1 +number_coils_per_half_field_period = 3 + +ntheta = 36 +nphi = 32 + +# Initialize VMEC field +vmec_file = os.path.join(os.path.dirname(__file__), '../../examples', 'input_files', + 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc') +vmec = Vmec(vmec_file, ntheta=ntheta, nphi=nphi, close=False) + +# Initialize coils +current_on_each_coil = 1 +number_of_field_periods = vmec.nfp +major_radius_coils = vmec.r_axis +minor_radius_coils = vmec.r_axis/1.5 +curves_essos = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) +coils_essos = Coils(curves=curves_essos, currents=[current_on_each_coil]*number_coils_per_half_field_period) +field_essos = BiotSavart(coils_essos) +surface_essos = SurfaceRZFourier_ESSOS(vmec, ntheta=ntheta, nphi=nphi, close=False) +# surface_essos.to_vtk("essos_surface") + +coils_simsopt = coils_essos.to_simsopt() +curves_simsopt = curves_essos.to_simsopt() +field_simsopt = BiotSavart_simsopt(coils_simsopt) +surface_simsopt = SurfaceRZFourier_SIMSOPT.from_wout(vmec_file, range="full torus", nphi=nphi, ntheta=ntheta) +field_simsopt.set_points(surface_simsopt.gamma().reshape((-1, 3))) +# surface_simsopt.to_vtk("simsopt_surface") + +# Running the first time for compilation +surface_simsopt.gamma() +surface_simsopt.gammadash1() +surface_simsopt.gammadash2() +surface_simsopt.unitnormal() +field_simsopt.B() +SquaredFlux_SIMSOPT(surface_simsopt, field_simsopt).J() +block_until_ready(surface_essos.gamma) + +# Running the second time for surface characteristics comparison + +print("Gamma") +start_time = time() +gamma_essos = block_until_ready(surface_essos.gamma) +t_gamma_essos = time() - start_time + +gamma_simsopt = block_until_ready(surface_simsopt.gamma()) +start_time = time() +t_gamma_simsopt = time() - start_time + +gamma_error = jnp.sum(jnp.abs(gamma_simsopt - gamma_essos)) +print(gamma_error) + + +print('Gamma dash theta') +start_time = time() +gamma_dash_theta_essos = block_until_ready(surface_essos.gammadash_theta) +t_gamma_dash_theta_essos = time() - start_time + +start_time = time() +gamma_dash_theta_simsopt = block_until_ready(surface_simsopt.gammadash2()) +t_gamma_dash_theta_simsopt = time() - start_time + +gamma_dash_theta_error = jnp.sum(jnp.abs(gamma_dash_theta_simsopt - gamma_dash_theta_essos)) +print(gamma_dash_theta_error) + + +print('Gamma dash phi') +start_time = time() +gamma_dash_phi_essos = block_until_ready(surface_essos.gammadash_phi) +t_gamma_dash_phi_essos = time() - start_time + +start_time = time() +gamma_dash_phi_simsopt = block_until_ready(surface_simsopt.gammadash1()) +t_gamma_dash_phi_simsopt = time() - start_time + +gamma_dash_phi_error = jnp.sum(jnp.abs(gamma_dash_phi_simsopt - gamma_dash_phi_essos)) +print(gamma_dash_phi_error) + + +print('Unit normal') +start_time = time() +unit_normal_essos = block_until_ready(surface_essos.unitnormal) +t_unit_normal_essos = time() - start_time + +start_time = time() +unit_normal_simsopt = block_until_ready(surface_simsopt.unitnormal()) +t_unit_normal_simsopt = time() - start_time + +unit_normal_error = jnp.sum(jnp.abs(unit_normal_simsopt - unit_normal_essos)) +print(unit_normal_error) + + +print('B on surface') +start_time = time() +B_on_surface_essos = block_until_ready(B_on_surface(surface_essos, field_essos)) +t_B_on_surface_essos = time() - start_time + +start_time = time() +B_on_surface_simsopt = block_until_ready(field_simsopt.B()) +t_B_on_surface_simsopt = time() - start_time + +B_on_surface_error = jnp.sum(jnp.abs(B_on_surface_simsopt.reshape((nphi, ntheta, 3)) - B_on_surface_essos)) +print(B_on_surface_error) + + +definition = "local" +print("Squared flux", definition) +start_time = time() +sf_essos = block_until_ready(SquaredFlux_ESSOS(surface_essos, field_essos, definition=definition)) +t_squared_flux_essos = time() - start_time + +start_time = time() +sf_simsopt = block_until_ready(SquaredFlux_SIMSOPT(surface_simsopt, field_simsopt, definition=definition).J()) +t_squared_flux_simsopt = time() - start_time + +squared_flux_error = jnp.abs(sf_simsopt - sf_essos) +print(squared_flux_error) + +# Labels and corresponding absolute errors (ESSOS - SIMSOPT) +quantities_errors = [ + (r"$\Gamma$", gamma_error), + (r"$\Gamma'_\theta$", gamma_dash_theta_error), + (r"$\Gamma'_\phi$", gamma_dash_phi_error), + (r"$\mathbf{n}$", unit_normal_error), + # (r"$\mathbf{B}$", B_on_surface_error), + (r"$L_\text{flux}$", squared_flux_error), +] + +labels = [q[0] for q in quantities_errors] +error_vals = [q[1] for q in quantities_errors] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.6 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis, error_vals, bar_width, color="darkorange", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Absolute error") +ax.set_yscale("log") +ax.set_ylim(1e-14, 1e-10) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + +plt.tight_layout() +plt.savefig(os.path.join(output_dir, f"comparisons_surfaces_error.pdf"), transparent=True) + +# Labels and corresponding timings +quantities = [ + (r"$\Gamma$", t_gamma_essos, t_gamma_simsopt), + (r"$\Gamma'_\theta$", t_gamma_dash_theta_essos, t_gamma_dash_theta_simsopt), + (r"$\Gamma'_\phi$", t_gamma_dash_phi_essos, t_gamma_dash_phi_simsopt), + (r"$\mathbf{n}$", t_unit_normal_essos, t_unit_normal_simsopt), + # (r"$\mathbf{B}$", t_B_on_surface_essos, t_B_on_surface_simsopt), + (r"$L_\text{flux}$", t_squared_flux_essos, t_squared_flux_simsopt), +] + +labels = [q[0] for q in quantities] +essos_vals = [q[1] for q in quantities] +simsopt_vals = [q[2] for q in quantities] + +X_axis = jnp.arange(len(labels)) +bar_width = 0.35 + +fig, ax = plt.subplots(figsize=(9, 6)) +ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") +ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + +ax.set_xticks(X_axis) +ax.set_xticklabels(labels) +ax.set_ylabel("Computation time (s)") +ax.set_yscale("log") +ax.set_ylim(1e-7, 1e-1) +ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) +ax.legend(fontsize=12) +plt.tight_layout() +plt.savefig(os.path.join(output_dir, f"comparisons_surfaces_time.pdf"), transparent=True) + +plt.show() diff --git a/examples/comparisons_SIMSOPT/vmec_SIMSOPT_vs_ESSOS.py b/examples/comparisons_simsopt/vmec_import.py similarity index 50% rename from examples/comparisons_SIMSOPT/vmec_SIMSOPT_vs_ESSOS.py rename to examples/comparisons_simsopt/vmec_import.py index 8c0c1d27..adffffbc 100644 --- a/examples/comparisons_SIMSOPT/vmec_SIMSOPT_vs_ESSOS.py +++ b/examples/comparisons_simsopt/vmec_import.py @@ -2,18 +2,21 @@ from time import time import jax.numpy as jnp import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) from jax import block_until_ready, random from essos.fields import Vmec as Vmec_essos from simsopt.mhd import Vmec as Vmec_simsopt, vmec_compute_geometry -output_dir = os.path.join(os.path.dirname(__file__), 'output') + +output_dir = os.path.join(os.path.dirname(__file__), '../output') if not os.path.exists(output_dir): os.makedirs(output_dir) -wout_array = [os.path.join(os.path.dirname(__file__), '..', 'input_files', "wout_LandremanPaul2021_QA_reactorScale_lowres.nc"), - os.path.join(os.path.dirname(__file__), '..', 'input_files', "wout_n3are_R7.75B5.7.nc")] +wout_array = [os.path.join(os.path.dirname(__file__), '../../examples/', 'input_files', "wout_LandremanPaul2021_QA_reactorScale_lowres.nc"), + os.path.join(os.path.dirname(__file__), '../../examples/', 'input_files', "wout_n3are_R7.75B5.7.nc")] name_array = ["LandremanPaulQA", 'NCSX'] + print(f'Output being saved to {output_dir}') for name, wout in zip(name_array, wout_array): print(f' Running comparison with VMEC file located at: {wout}') @@ -71,41 +74,66 @@ def timed_B(s, function): average_time_modB_essos /= len(s_array) average_time_B_essos /= len(s_array) average_time_B_simsopt /= len(s_array) - - fig = plt.figure(figsize = (8, 6)) - X_axis = jnp.arange(4) - Y_axis = [average_time_modB_simsopt, average_time_B_simsopt, average_time_modB_essos, average_time_B_essos] - colors = ['blue', 'blue', 'red', 'red'] - hatches = ['/', '\\', '/', '\\'] - bars = plt.bar(X_axis, Y_axis, width=0.4, color=colors) - for bar, hatch in zip(bars, hatches): bar.set_hatch(hatch) - plt.xticks(X_axis, [r"$|\boldsymbol{B}|$ SIMSOPT", r"$\boldsymbol{B}$ SIMSOPT", r"$|\boldsymbol{B}|$ ESSOS", r"$\boldsymbol{B}$ ESSOS"], fontsize=16) - plt.tick_params(axis='both', which='major', labelsize=14) - plt.tick_params(axis='both', which='minor', labelsize=14) - plt.ylabel("Time to evaluate VMEC field (s)", fontsize=14) - plt.grid(axis='y') - plt.yscale("log") - plt.ylim(1e-6, 1) - plt.title(name, fontsize=14) + error_modB /= len(s_array) + error_B /= len(s_array) + + # Labels and corresponding absolute errors (ESSOS - SIMSOPT) + quantities_errors = [ + (r"$B$", jnp.mean(error_modB)), + (r"$\mathbf{B}$", jnp.mean(error_B)), + ] + + labels = [q[0] for q in quantities_errors] + error_vals = [q[1] for q in quantities_errors] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.4 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis, error_vals, bar_width, color="darkorange", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Relative error") + ax.set_yscale("log") + ax.set_ylim(1e-6, 1e-2) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"time_VMEC_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() - - fig = plt.figure(figsize = (8, 6)) - X_axis = jnp.arange(2) - Y_axis = [jnp.mean(error_modB), jnp.mean(error_B)] - colors = ['purple', 'orange'] - hatches = ['/', '//'] - bars = plt.bar(X_axis, Y_axis, width=0.4, color=colors) - for bar, hatch in zip(bars, hatches): bar.set_hatch(hatch) - plt.xticks(X_axis, [r"$|\boldsymbol{B}|$", r"$\boldsymbol{B}$"], fontsize=16) - plt.tick_params(axis='both', which='major', labelsize=14) - plt.tick_params(axis='both', which='minor', labelsize=14) - plt.ylabel("Relative error SIMSOPT vs ESSOS (%)", fontsize=14) - plt.grid(axis='y') - plt.yscale("log") - plt.ylim(1e-6, 1e-1) - plt.title(name, fontsize=14) + plt.savefig(os.path.join(output_dir, f"comparisons_VMEC_error_{name}.pdf"), transparent=True) + + # Labels and corresponding timings + print(f"Average time to compute |B| in SIMSOPT: {average_time_modB_simsopt:.6f} s") + print(f"Average time to compute B in SIMSOPT: {average_time_B_simsopt:.6f} s") + print(f"Average time to compute |B| in ESSOS: {average_time_modB_essos:.6f} s") + print(f"Average time to compute B in ESSOS: {average_time_B_essos:.6f} s") + print(f"Relative error in |B|: {jnp.mean(error_modB):.6f}") + print(f"Relative error in B: {jnp.mean(error_B):.6f}") + + quantities = [ + (r"$B$", average_time_modB_essos, average_time_modB_simsopt), + (r"$\mathbf{B}$", average_time_B_essos, average_time_B_simsopt), + ] + + labels = [q[0] for q in quantities] + essos_vals = [q[1] for q in quantities] + simsopt_vals = [q[2] for q in quantities] + + X_axis = jnp.arange(len(labels)) + bar_width = 0.4 + + fig, ax = plt.subplots(figsize=(9, 6)) + ax.bar(X_axis - bar_width/2, essos_vals, bar_width, label="ESSOS", color="red", edgecolor="black") + ax.bar(X_axis + bar_width/2, simsopt_vals, bar_width, label="SIMSOPT", color="blue", edgecolor="black") + + ax.set_xticks(X_axis) + ax.set_xticklabels(labels) + ax.set_ylabel("Computation time (s)") + ax.set_yscale("log") + ax.set_ylim(1e-5, 1e-1) + ax.grid(axis='y', which='both', linestyle='--', linewidth=0.6) + ax.legend(fontsize=12) plt.tight_layout() - plt.savefig(os.path.join(output_dir,f"error_VMEC_SIMSOPT_vs_ESSOS_{name}.pdf"), transparent=True) - plt.close() \ No newline at end of file + plt.savefig(os.path.join(output_dir, f"comparisons_VMEC_time_{name}.pdf"), transparent=True) + + plt.show() \ No newline at end of file diff --git a/examples/poincare_guiding_center_coils.py b/examples/fieldline_tracing/poincare_guiding_center_coils.py similarity index 100% rename from examples/poincare_guiding_center_coils.py rename to examples/fieldline_tracing/poincare_guiding_center_coils.py diff --git a/examples/trace_fieldlines_coils.py b/examples/fieldline_tracing/trace_fieldlines_coils.py similarity index 95% rename from examples/trace_fieldlines_coils.py rename to examples/fieldline_tracing/trace_fieldlines_coils.py index 2ea23059..c92148aa 100644 --- a/examples/trace_fieldlines_coils.py +++ b/examples/fieldline_tracing/trace_fieldlines_coils.py @@ -6,7 +6,7 @@ from jax import block_until_ready import matplotlib.pyplot as plt from essos.fields import BiotSavart -from essos.coils import Coils_from_json +from essos.coils import Coils from essos.dynamics import Tracing # Input parameters @@ -19,7 +19,7 @@ # Load coils and field json_file = os.path.join(os.path.dirname(__file__), 'input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) +coils = Coils.from_json(json_file) field = BiotSavart(coils) # Initialize particles diff --git a/examples/trace_fieldlines_vmec.py b/examples/fieldline_tracing/trace_fieldlines_vmec.py similarity index 100% rename from examples/trace_fieldlines_vmec.py rename to examples/fieldline_tracing/trace_fieldlines_vmec.py diff --git a/examples/input_files/input.rotating_ellipse b/examples/input_files/input.rotating_ellipse index a35f3af0..bce19ba5 100644 --- a/examples/input_files/input.rotating_ellipse +++ b/examples/input_files/input.rotating_ellipse @@ -5,10 +5,17 @@ MPOL = 002 NTOR = 002 !----- Boundary Parameters (n,m) ----- - RBC( 000,000) = 10 ZBS( 000,000) = 0 - RBC( 001,000) = 1 ZBS( 001,000) = -1 + RBC( 000,000) = 10. ZBS( 000,000) = 0. + RBC( 001,000) = 1. ZBS( 001,000) = -1. + RBC( 002,000) = 0. ZBS( 002,000) = 0. + RBC( -002,001) = 0. ZBS( -002,001) = 0. RBC(-001,001) = 0.1 ZBS(-001,001) = 0.1 RBC( 000,001) = 2.5 ZBS( 000,001) = 2.5 RBC( 001,001) = -1 ZBS( 001,001) = 1 + RBC( 002,001) = 0 ZBS( 002,001) = 0 RBC(-002,002) = 1E-4 ZBS(-002,002) = 1E-4 + RBC(-001,002) = 0. ZBS(-001,002) = 0. + RBC( 000,002) = 0. ZBS( 000,002) = 0. + RBC( 001,002) = 0. ZBS( 001,002) = 0. + RBC( 002,002) = 0. ZBS( 002,002) = 0. / diff --git a/examples/input_files/input.toroidal_surface b/examples/input_files/input.toroidal_surface index 3a133b24..533ae617 100644 --- a/examples/input_files/input.toroidal_surface +++ b/examples/input_files/input.toroidal_surface @@ -1,14 +1,21 @@ !----- Runtime Parameters ----- &INDATA LASYM = F - NFP = 0001 + NFP = 0002 MPOL = 002 NTOR = 002 !----- Boundary Parameters (n,m) ----- - RBC( 000,000) = 7.75 ZBS( 000,000) = 0 + RBC( 000,000) = 10.0 ZBS( 000,000) = 0 RBC( 001,000) = 0.000001 ZBS( 001,000) = -0.000001 + RBC( 002,000) = 0. ZBS( 002,000) = 0. + RBC( -002,001) = 0. ZBS( -002,001) = 0. RBC(-001,001) = 0.000001 ZBS(-001,001) = 0.000001 - RBC( 000,001) = 2.5 ZBS( 000,001) = 2.5 + RBC( 000,001) = 0.5 ZBS( 000,001) = 0.5 RBC( 001,001) = 0.000001 ZBS( 001,001) = 0.000001 + RBC( 002,001) = 0 ZBS( 002,001) = 0 RBC(-002,002) = 1E-7 ZBS(-002,002) = 1E-7 + RBC(-001,002) = 0. ZBS(-001,002) = 0. + RBC( 000,002) = 0. ZBS( 000,002) = 0. + RBC( 001,002) = 0. ZBS( 001,002) = 0. + RBC( 002,002) = 0. ZBS( 002,002) = 0. / diff --git a/examples/optimize_coils_vmec_surface.py b/examples/optimize_coils_vmec_surface.py deleted file mode 100644 index 57324b25..00000000 --- a/examples/optimize_coils_vmec_surface.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -number_of_processors_to_use = 5 # Parallelization, this should divide ntheta*nphi -os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' -from time import time -import jax.numpy as jnp -import matplotlib.pyplot as plt -from essos.surfaces import BdotN_over_B -from essos.coils import Coils, CreateEquallySpacedCurves -from essos.fields import Vmec, BiotSavart -from essos.objective_functions import loss_BdotN -from essos.optimization import optimize_loss_function - -# Optimization parameters -max_coil_length = 10 -max_coil_curvature = 1.0 -order_Fourier_series_coils = 3 -number_coil_points = order_Fourier_series_coils*15 -maximum_function_evaluations = 50 -number_coils_per_half_field_period = 3 -tolerance_optimization = 1e-5 -ntheta=35 -nphi=35 - -# Initialize VMEC field -vmec = Vmec(os.path.join(os.path.dirname(__file__), 'input_files', - 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), - ntheta=ntheta, nphi=nphi, range_torus='half period') - -# Initialize coils -current_on_each_coil = 1 -number_of_field_periods = vmec.nfp -major_radius_coils = vmec.r_axis -minor_radius_coils = vmec.r_axis/1.8 -curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, - order=order_Fourier_series_coils, - R=major_radius_coils, r=minor_radius_coils, - n_segments=number_coil_points, - nfp=number_of_field_periods, stellsym=True) -coils_initial = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) - -# Optimize coils -print(f'Optimizing coils with {maximum_function_evaluations} function evaluations.') -time0 = time() -coils_optimized = optimize_loss_function(loss_BdotN, initial_dofs=coils_initial.x, coils=coils_initial, tolerance_optimization=tolerance_optimization, - maximum_function_evaluations=maximum_function_evaluations, vmec=vmec, - max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature,) -print(f"Optimization took {time()-time0:.2f} seconds") - - -BdotN_over_B_initial = BdotN_over_B(vmec.surface, BiotSavart(coils_initial)) -BdotN_over_B_optimized = BdotN_over_B(vmec.surface, BiotSavart(coils_optimized)) -curvature=jnp.mean(BiotSavart(coils_optimized).coils.curvature, axis=1) -length=jnp.max(jnp.ravel(BiotSavart(coils_optimized).coils.length)) -print(f"Mean curvature: ",curvature) -print(f"Length:", length) -print(f"Maximum BdotN/B before optimization: {jnp.max(BdotN_over_B_initial):.2e}") -print(f"Maximum BdotN/B after optimization: {jnp.max(BdotN_over_B_optimized):.2e}") - -# Plot coils, before and after optimization -fig = plt.figure(figsize=(8, 4)) -ax1 = fig.add_subplot(121, projection='3d') -ax2 = fig.add_subplot(122, projection='3d') -coils_initial.plot(ax=ax1, show=False) -vmec.surface.plot(ax=ax1, show=False) -coils_optimized.plot(ax=ax2, show=False) -vmec.surface.plot(ax=ax2, show=False) -plt.tight_layout() -plt.show() - -# # Save the coils to a json file -# coils_optimized.to_json("stellarator_coils.json") -# # Load the coils from a json file -# from essos.coils import Coils_from_json -# coils = Coils_from_json("stellarator_coils.json") - -# # Save results in vtk format to analyze in Paraview -# from essos.fields import BiotSavart -# vmec.surface.to_vtk('surface_initial', field=BiotSavart(coils_initial)) -# vmec.surface.to_vtk('surface_final', field=BiotSavart(coils_optimized)) -# coils_initial.to_vtk('coils_initial') -# coils_optimized.to_vtk('coils_optimized') \ No newline at end of file diff --git a/examples/paper/fo_integrators.py b/examples/paper/fo_integrators.py new file mode 100644 index 00000000..1a015711 --- /dev/null +++ b/examples/paper/fo_integrators.py @@ -0,0 +1,92 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +from jax import block_until_ready +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from essos.fields import BiotSavart +from essos.coils import Coils +from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE +from essos.dynamics import Tracing, Particles +import diffrax + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +# Load coils and field +json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils.from_json(json_file) +field = BiotSavart(coils) + +# Particle parameters +nparticles = number_of_processors_to_use +mass=PROTON_MASS +energy=5000*ONE_EV +cyclotron_frequency = ELEMENTARY_CHARGE*0.3/mass +print("cyclotron period:", 1/cyclotron_frequency) + +# Particles initialization +initial_xyz=jnp.array([[1.23, 0, 0]]) +particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.8], field=field) + +# Tracing parameters +tmax = 1e-4 +dt = 1e-9 +num_steps = int(tmax/dt) + +fig, ax = plt.subplots(figsize=(9, 6)) + +method_names = ['Tsit5', 'Dopri5', 'Dopri8', 'Boris'] +methods = [getattr(diffrax, method) for method in method_names[:-1]] + ['Boris'] +for method_name, method in zip(method_names, methods): + if method_name != 'Boris': + energies = [] + tracing_times = [] + for trace_tolerance in [1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15]: + time0 = time() + tracing = Tracing('FullOrbit', field, tmax, method=method, timesteps=num_steps, + stepsize='adaptive', tol_step_size=trace_tolerance, particles=particles) + block_until_ready(tracing.trajectories) + tracing_times += [time() - time0] + + print(f"Tracing with adaptive {method_name} and tolerance {trace_tolerance:.0e} took {tracing_times[-1]:.2f} seconds") + + energies += [jnp.mean(jnp.abs(tracing.energy()-particles.energy)/particles.energy)] + ax.plot(tracing_times, energies, label=f'{method_name} adapt', marker='o', markersize=3, linestyle='-') + + energies = [] + tracing_times = [] + for n_points_in_gyration in [10, 20, 50, 75, 100, 150, 200]: + dt = 1/(n_points_in_gyration*cyclotron_frequency) + num_steps = int(tmax/dt) + time0 = time() + tracing = Tracing('FullOrbit', field, tmax, method=method, timesteps=num_steps, + stepsize="constant", particles=particles) + block_until_ready(tracing.trajectories) + tracing_times += [time() - time0] + + print(f"Tracing with {method_name} and step {dt:.2e} took {tracing_times[-1]:.2f} seconds") + + energies += [jnp.mean(jnp.abs(tracing.energy()-particles.energy)/particles.energy)] + ax.plot(tracing_times, energies, label=f'{method_name}', marker='o', markersize=4, linestyle='-') + + +ax.legend(fontsize=15, loc='upper left') +ax.set_xlabel('Computation time (s)') +ax.set_ylabel('Relative energy error') +ax.set_xscale('log') +ax.set_yscale('log') +ax.set_xlim(1e-1, 1e2) +ax.set_ylim(1e-16, 1e-4) +plt.grid(axis='x', which='both', linestyle='--', linewidth=0.6) +plt.grid(axis='y', which='major', linestyle='--', linewidth=0.6) +plt.tight_layout() +plt.savefig(os.path.join(output_dir, 'fo_integration.pdf')) +plt.show() + +## Save results in vtk format to analyze in Paraview +# tracing.to_vtk('trajectories') +# coils.to_vtk('coils') \ No newline at end of file diff --git a/examples/paper/gc_integrators.py b/examples/paper/gc_integrators.py new file mode 100644 index 00000000..f18e3c4a --- /dev/null +++ b/examples/paper/gc_integrators.py @@ -0,0 +1,103 @@ +import os +import gc +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +from jax import block_until_ready +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from essos.fields import BiotSavart +from essos.coils import Coils +from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE +from essos.dynamics import Tracing, Particles + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +# Load coils and field +json_file = os.path.join(os.path.dirname(__file__), '../examples/input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils.from_json(json_file) +field = BiotSavart(coils) + +# Particle parameters +nparticles = number_of_processors_to_use +mass=PROTON_MASS +energy=5000*ONE_EV +cyclotron_frequency = ELEMENTARY_CHARGE*0.3/mass +print("cyclotron period:", 1/cyclotron_frequency) + +# Particles initialization +initial_xyz=jnp.array([[1.23, 0, 0]]) +particles = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.8]) + +# Tracing parameters +tmax = 1e-4 + +fig, ax = plt.subplots(figsize=(9, 6)) +fig_tol, ax_tol = plt.subplots(figsize=(9, 6)) +markers = ["o-", "^-", "*-", "s-"] +for method, marker in zip(['Tsit5', 'Dopri5', 'Dopri8', 'Kvaerno5'], markers): + dt = 1e-7 + num_steps = int(tmax/dt) + energies = [] + tracing_times = [] + tolerances = [1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, 1e-13, 1e-14, 1e-15, 1e-16] + for tolerance in tolerances: + time0 = time() + tracing = Tracing('GuidingCenter', field, tmax, method=method, timesteps=num_steps, + stepsize='adaptive', tol_step_size=tolerance, particles=particles) + block_until_ready(tracing.trajectories) + tracing_times += [time() - time0] + + print(f"Tracing with adaptive {method} and {tolerance=:.0e} took {tracing_times[-1]:.2f} seconds") + + energies += [jnp.max(jnp.abs(tracing.energy()-particles.energy)/particles.energy)] + ax.plot(tracing_times, energies, label=f'{method} adapt', marker='o', markersize=3) + ax_tol.plot(tolerances, energies, marker, label=f'{method} adapt', clip_on=False, linewidth=2.5) + + if method == 'Kvaerno5': continue + + energies = [] + tracing_times = [] + for dt in [4e-7, 2e-7, 1e-7, 8e-8, 6e-8, 4e-8, 2e-8, 1e-8]: + num_steps = int(tmax/dt) + time0 = time() + tracing = Tracing('GuidingCenter', field, tmax, method=method, + timesteps=num_steps, stepsize="constant", particles=particles) + block_until_ready(tracing.trajectories) + tracing_times += [time() - time0] + + print(f"Tracing with {method} and {dt=:.2e} took {tracing_times[-1]:.2f} seconds") + + energies += [jnp.max(jnp.abs(tracing.energy()-particles.energy)/particles.energy)] + ax.plot(tracing_times, energies, label=f'{method}', marker='o', markersize=4, linestyle='-') + gc.collect() + +ax.set_xlabel('Computation time (s)') +ax_tol.set_xlabel('Tracing tolerance') +ax.set_xlim(1e-1, 1e2) +ax_tol.set_xlim(tolerances[-1], tolerances[0]) + +for axis in [ax, ax_tol]: + axis.legend(fontsize=15) + axis.set_ylabel('Relative energy error') + axis.set_xscale('log') + axis.set_yscale('log') + axis.set_ylim(1e-16, 1e-4) + axis.grid(axis='x', which='both', linestyle='--', linewidth=0.6) + axis.grid(axis='y', which='major', linestyle='--', linewidth=0.6) +for figure in [fig, fig_tol]: + figure.tight_layout() + +for spine in ax_tol.spines.values(): + spine.set_zorder(0) + +fig.savefig(os.path.join(output_dir, 'gc_integration.pdf')) +fig_tol.savefig(os.path.join(output_dir, 'energy_vs_tol.pdf')) +plt.show() + +## Save results in vtk format to analyze in Paraview +# tracing.to_vtk('trajectories') +# coils.to_vtk('coils') \ No newline at end of file diff --git a/examples/paper/gc_vs_fo.py b/examples/paper/gc_vs_fo.py new file mode 100644 index 00000000..216a9b2e --- /dev/null +++ b/examples/paper/gc_vs_fo.py @@ -0,0 +1,87 @@ +import os +number_of_processors_to_use = 1 # Parallelization, this should divide nparticles +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from jax import vmap +from time import time +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from essos.fields import BiotSavart +from essos.coils import Coils +from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE +from essos.dynamics import Tracing, Particles +from jax import block_until_ready + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +# Load coils and field +json_file = os.path.join(os.path.dirname(__file__), '../input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils.from_json(json_file) +field = BiotSavart(coils) + +# Particle parameters +nparticles = number_of_processors_to_use +mass=PROTON_MASS +energy=5000*ONE_EV +cyclotron_frequency = ELEMENTARY_CHARGE*0.3/mass +print("cyclotron period:", 1/cyclotron_frequency) + +# Particles initialization +initial_xyz=jnp.array([[1.23, 0, 0]]) + +particles_passing = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.1], phase_angle_full_orbit=0) +particles_traped = Particles(initial_xyz=initial_xyz, mass=mass, energy=energy, initial_vparallel_over_v=[0.9], phase_angle_full_orbit=0) +particles = particles_passing.join(particles_traped, field=field) + +# Tracing parameters +tmax = 1e-5 +trace_tolerance = 1e-14 +dt_gc = 1e-7 +dt_fo = 1e-9 +num_steps_gc = int(tmax/dt_gc) +num_steps_fo = int(tmax/dt_fo) + +# Trace in ESSOS +time0 = time() +tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, + maxtime=tmax, timestep=dt_gc, atol=trace_tolerance, rtol=trace_tolerance, + times_to_trace=200) +trajectories_guidingcenter = block_until_ready(tracing_gc.trajectories) +print(f"ESSOS guiding center tracing took {time()-time0:.2f} seconds") + +time0 = time() +tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax, + timestep=dt_fo, atol=trace_tolerance, rtol=trace_tolerance, + times_to_trace=600) + +block_until_ready(tracing_fo.trajectories) +print(f"ESSOS full orbit tracing took {time()-time0:.2f} seconds") + +# Plot trajectories, velocity parallel to the magnetic field, and energy error +fig = plt.figure(figsize=(9, 8)) +ax = fig.add_subplot(projection='3d') +coils.plot(ax=ax, show=False) +tracing_gc.plot(ax=ax, show=False, color='black', linewidth=2) +tracing_fo.plot(ax=ax, show=False) +plt.tight_layout() + +plt.figure(figsize=(9, 6)) +plt.plot(tracing_gc.times[1:]*1000, jnp.abs(tracing_gc.energy()[0][1:]/particles.energy-1)+1e-17, label='Guiding Center', color='red') +plt.plot(tracing_fo.times[1:]*1000, jnp.abs(tracing_fo.energy()[0][1:]/particles.energy-1)+1e-17, label='Full Orbit', color='blue') +plt.xlabel('Time (ms)') +plt.ylabel('Relative energy error') +plt.xlim(0, tmax*1000) +# plt.ylim(bottom=0) +plt.yscale('log') +plt.legend() +plt.tight_layout() +plt.savefig(os.path.join(output_dir, 'energies.png'), dpi=300) + +plt.show() + +## Save results in vtk format to analyze in Paraview +# tracing_gc.to_vtk(os.path.join(output_dir, 'trajectories_gc')) +# tracing_fo.to_vtk(os.path.join(output_dir, 'trajectories_fo')) +# coils.to_vtk(os.path.join(output_dir, 'coils')) \ No newline at end of file diff --git a/examples/paper/gradients.py b/examples/paper/gradients.py new file mode 100644 index 00000000..f0eab357 --- /dev/null +++ b/examples/paper/gradients.py @@ -0,0 +1,124 @@ +import os +from functools import partial +number_of_processors_to_use = 8 # Parallelization, this should divide ntheta*nphi +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +from jax import jit, grad, block_until_ready +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from essos.coils import Coils, CreateEquallySpacedCurves +from essos.fields import Vmec +from essos.objective_functions import loss_BdotN + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +# Optimization parameters +max_coil_length = 40 +max_coil_curvature = 0.5 +order_Fourier_series_coils = 6 +number_coil_points = order_Fourier_series_coils*10 +maximum_function_evaluations = 300 +number_coils_per_half_field_period = 4 +tolerance_optimization = 1e-5 +ntheta=32 +nphi=32 + +# Initialize VMEC field +vmec = Vmec(os.path.join(os.path.dirname(__file__), '../examples/input_files', + 'wout_LandremanPaul2021_QA_reactorScale_lowres.nc'), + ntheta=ntheta, nphi=nphi, range_torus='half period') + +# Initialize coils +current_on_each_coil = 1 +number_of_field_periods = vmec.nfp +major_radius_coils = vmec.r_axis +minor_radius_coils = vmec.r_axis/1.5 +curves = CreateEquallySpacedCurves(n_curves=number_coils_per_half_field_period, + order=order_Fourier_series_coils, + R=major_radius_coils, r=minor_radius_coils, + n_segments=number_coil_points, + nfp=number_of_field_periods, stellsym=True) + +coils = Coils(curves=curves, currents=[current_on_each_coil]*number_coils_per_half_field_period) + + +loss_partial = partial(loss_BdotN, dofs_curves=coils.dofs_curves, currents_scale=coils.currents_scale, + nfp=coils.nfp, n_segments=coils.n_segments, stellsym=coils.stellsym, + vmec=vmec, max_coil_length=max_coil_length, max_coil_curvature=max_coil_curvature) +print(loss_partial(coils.x)) +grad_loss_partial = jit(grad(loss_partial)) + +time0 = time() +loss = loss_partial(coils.x) +block_until_ready(loss) +print(f"Loss took {time()-time0:.4f} seconds. Gradient would take {(time()-time0)*(coils.x.size +1):.4f} seconds") + +time0 = time() +loss_comp = loss_partial(coils.x) +block_until_ready(loss_comp) +print(f"Compiled loss took {time()-time0:.4f} seconds. Gradient would take {(time()-time0)*(coils.x.size +1):.4f} seconds") + +time0 = time() +grad_loss = grad_loss_partial(coils.x) +block_until_ready(grad_loss) +print(f"Gradient took {time()-time0:.4f} seconds") + +time0 = time() +grad_loss_comp = grad_loss_partial(coils.x) +block_until_ready(grad_loss_comp) +print(f"Compiled gradient took {time()-time0:.4f} seconds") + +# Parameter to perturb +param = 42 + +# Set the possible perturbations +h_list = jnp.arange(-9, -0.9, 1/3) +h_list = 10.**h_list + +# Number of orders for finite differences +fd_loss = jnp.zeros(4) + +# Array to store the relative difference +fd_diff = jnp.zeros((fd_loss.size, h_list.size)) + +# Compute finite differences +for index, h in enumerate(h_list): + delta = jnp.zeros(coils.x.shape) + delta = delta.at[param].set(h) + + # 1st order finite differences + fd_loss = fd_loss.at[0].set((loss_partial(coils.x+delta)-loss_partial(coils.x))/h) + # 2nd order finite differences + fd_loss = fd_loss.at[1].set((loss_partial(coils.x+delta)-loss_partial(coils.x-delta))/(2*h)) + # 4th order finite differences + fd_loss = fd_loss.at[2].set((loss_partial(coils.x-2*delta)-8*loss_partial(coils.x-delta)+8*loss_partial(coils.x+delta)-loss_partial(coils.x+2*delta))/(12*h)) + # 6th order finite differences + fd_loss = fd_loss.at[3].set((loss_partial(coils.x+3*delta)-9*loss_partial(coils.x+2*delta)+45*loss_partial(coils.x+delta)-45*loss_partial(coils.x-delta)+9*loss_partial(coils.x-2*delta)-loss_partial(coils.x-3*delta))/(60*h)) + + fd_diff_h = jnp.abs((grad_loss[param]-fd_loss)/grad_loss[param]) + fd_diff = fd_diff.at[:, index].set(fd_diff_h) + + +# plot relative difference +plt.figure(figsize=(9, 6)) +plt.plot(h_list, fd_diff[0], "o-", label=f'1st order', clip_on=False, linewidth=2.5) +plt.plot(h_list, fd_diff[1], "^-", label=f'2nd order', clip_on=False, linewidth=2.5) +plt.plot(h_list, fd_diff[2], "*-", label=f'4th order', clip_on=False, linewidth=2.5) +plt.plot(h_list, fd_diff[3], "s-", label=f'6th order', clip_on=False, linewidth=2.5) +plt.legend(fontsize=15) +plt.xlabel('Finite differences stepsize h') +plt.ylabel('Relative error') +plt.xscale('log') +plt.yscale('log') +plt.ylim(1e-13, 1e-1) +plt.xlim(jnp.min(h_list), jnp.max(h_list)) +plt.grid(which='both', axis='x', linestyle='--', linewidth=0.6) +plt.grid(which='major', axis='y', linestyle='--', linewidth=0.6) +for spine in plt.gca().spines.values(): + spine.set_zorder(0) +plt.tight_layout() +plt.savefig(os.path.join(output_dir, 'gradients.pdf')) +plt.show() \ No newline at end of file diff --git a/examples/paper/poincare_plots.py b/examples/paper/poincare_plots.py new file mode 100644 index 00000000..95d27244 --- /dev/null +++ b/examples/paper/poincare_plots.py @@ -0,0 +1,134 @@ +import os +from functools import partial +number_of_processors_to_use = 1 # Parallelization, this should divide ntheta*nphi +os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={number_of_processors_to_use}' +from time import time +from jax import jit, grad, block_until_ready +import jax.numpy as jnp +import matplotlib.pyplot as plt +plt.rcParams.update({'font.size': 18}) +from essos.coils import Coils +from essos.constants import PROTON_MASS, ONE_EV, ELEMENTARY_CHARGE +from essos.fields import BiotSavart +from essos.dynamics import Tracing, Particles + +output_dir = os.path.join(os.path.dirname(__file__), 'output') +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +# Input parameters +tmax_fl = 50000 +tmax_gc = 1e-3 +tmax_fo = 1e-3 + +nparticles = number_of_processors_to_use*1 +nfieldlines = number_of_processors_to_use*8 +s = 0.25 # s-coordinate: flux surface label +trace_tolerance = 1e-15 +dt_fo = 1e-9 +dt_gc = 1e-7 +timesteps_gc = int(tmax_gc/dt_gc) +timesteps_fo = int(tmax_fo/dt_fo) +mass = PROTON_MASS +energy = 4000*ONE_EV +print("cyclotron period:", 1/(ELEMENTARY_CHARGE*0.3/mass)) + +# Load coils and field +json_file = os.path.join(os.path.dirname(__file__), '../input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils.from_json(json_file) +field = BiotSavart(coils) + +R0_fieldlines = jnp.linspace(1.21, 1.41, nfieldlines) +R0_particles= jnp.linspace(1.21, 1.41, nparticles) +Z0_fieldlines = jnp.zeros(nfieldlines) +Z0_particles = jnp.zeros(nparticles) +phi0_fieldlines = jnp.zeros(nfieldlines) +phi0_particles = jnp.zeros(nparticles) + +initial_xyz_fieldlines=jnp.array([R0_fieldlines*jnp.cos(phi0_fieldlines), R0_fieldlines*jnp.sin(phi0_fieldlines), Z0_fieldlines]).T +initial_xyz_particles=jnp.array([R0_particles*jnp.cos(phi0_particles), R0_particles*jnp.sin(phi0_particles), Z0_particles]).T + +particles = Particles(initial_xyz=initial_xyz_particles, mass=mass, energy=energy, field=field, min_vparallel_over_v=0.8) + +# Trace in ESSOS +# time0 = time() +# tracing_fl = Tracing(field=field, model='FieldLine', initial_conditions=initial_xyz_fieldlines, +# maxtime=tmax_fl, timesteps=tmax_fl*10, tol_step_size=trace_tolerance) +# block_until_ready(tracing_fl) +# print(f"ESSOS tracing of {nfieldlines} field lines took {time()-time0:.2f} seconds") + +time0 = time() +tracing_fo = Tracing(field=field, model='FullOrbit', particles=particles, maxtime=tmax_fo, + timesteps=timesteps_fo, tol_step_size=trace_tolerance) +# tracing_fo.trajectories = tracing_fo.trajectories[:, 0::100, :] +# tracing_fo.times = tracing_fo.times[0::100] +# tracing_fo.energy = tracing_fo.energy[:, 0::100] +block_until_ready(tracing_fo) +print(f"ESSOS tracing of {nparticles} particles with FO for {tmax_fo:.1e}s took {time()-time0:.2f} seconds") + +time0 = time() +tracing_gc = Tracing(field=field, model='GuidingCenter', particles=particles, maxtime=tmax_gc, + timesteps=timesteps_gc, tol_step_size=trace_tolerance) +block_until_ready(tracing_gc) +print(f"ESSOS tracing of {nparticles} particles with GC for {tmax_gc:.1e}s took {time()-time0:.2f} seconds") + +# fig = plt.figure(figsize=(9, 6)) +# ax = fig.add_subplot(projection='3d') +# coils.plot(ax=ax, show=False) +# tracing_fl.plot(ax=ax, show=False) +# plt.tight_layout() + +# fig = plt.figure(figsize=(9, 6)) +# ax = fig.add_subplot(projection='3d') +# coils.plot(ax=ax, show=False) +# tracing_fo.plot(ax=ax, show=False) +# plt.tight_layout() + +# fig = plt.figure(figsize=(9, 6)) +# ax = fig.add_subplot(projection='3d') +# coils.plot(ax=ax, show=False) +# tracing_gc.plot(ax=ax, show=False) +# plt.tight_layout() + +# fig, ax = plt.subplots(figsize=(9, 6)) +# time0 = time() +# tracing_fl.poincare_plot(ax=ax, shifts=[jnp.pi/2], show=False, s=0.5) +# print(f"ESSOS Poincare plot of {nfieldlines} field lines took {time()-time0:.2f} seconds") +# plt.xlabel('R (m)') +# plt.ylabel('Z (m)') +# ax.set_xlim(0.3, 1.3) +# ax.set_ylim(-0.3, 0.3) +# plt.grid(visible=False) +# plt.tight_layout() +# plt.savefig(os.path.join(output_dir, 'poincare_plot_fl.png'), dpi=300) +# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_fl.png'), dpi=300) + + +# fig, ax = plt.subplots(figsize=(9, 6)) +# time0 = time() +# tracing_fo.poincare_plot(ax=ax, shifts=[jnp.pi/2], show=False) +# print(f"ESSOS Poincare plot of {nparticles} particles took {time()-time0:.2f} seconds") +# plt.xlabel('R (m)') +# plt.ylabel('Z (m)') +# plt.xlim(0.3, 1.3) +# plt.ylim(-0.3, 0.3) +# plt.grid(visible=False) +# plt.tight_layout() +# plt.savefig(os.path.join(output_dir 'poincare_plot_fo.png'), dpi=300) +# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_fo.png'), dpi=300) + + +# fig, ax = plt.subplots(figsize=(9, 6)) +# time0 = time() +# tracing_gc.poincare_plot(ax=ax, shifts=[jnp.pi/2], show=False) +# print(f"ESSOS Poincare plot of {nparticles} particles took {time()-time0:.2f} seconds") +# plt.xlabel('R (m)') +# plt.ylabel('Z (m)') +# ax.set_xlim(0.3, 1.3) +# ax.set_ylim(-0.3, 0.3) +# plt.grid(visible=False) +# plt.tight_layout() +# plt.savefig(os.path.join(output_dir, 'poincare_plot_gc.png'), dpi=300) +# plt.savefig(os.path.join(os.path.dirname(__file__), 'poincare_plot_gc.png'), dpi=300) + +# plt.show() \ No newline at end of file diff --git a/examples/trace_particles_coils_fullorbit.py b/examples/particle_tracing/trace_particles_coils_fullorbit.py similarity index 100% rename from examples/trace_particles_coils_fullorbit.py rename to examples/particle_tracing/trace_particles_coils_fullorbit.py diff --git a/examples/trace_particles_coils_guidingcenter.py b/examples/particle_tracing/trace_particles_coils_guidingcenter.py similarity index 89% rename from examples/trace_particles_coils_guidingcenter.py rename to examples/particle_tracing/trace_particles_coils_guidingcenter.py index 018317c3..6634674e 100644 --- a/examples/trace_particles_coils_guidingcenter.py +++ b/examples/particle_tracing/trace_particles_coils_guidingcenter.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import matplotlib.pyplot as plt from essos.fields import BiotSavart -from essos.coils import Coils_from_json +from essos.coils import Coils from essos.constants import ALPHA_PARTICLE_MASS, ALPHA_PARTICLE_CHARGE, ONE_EV from essos.dynamics import Tracing, Particles @@ -21,8 +21,8 @@ energy=4000*ONE_EV # Load coils and field -json_file = os.path.join(os.path.dirname(__name__), 'input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') -coils = Coils_from_json(json_file) +json_file = os.path.join(os.path.dirname(__file__), 'input_files', 'ESSOS_biot_savart_LandremanPaulQA.json') +coils = Coils.from_json(json_file) field = BiotSavart(coils) # Initialize particles @@ -49,7 +49,7 @@ tracing.plot(ax=ax1, show=False) for i, trajectory in enumerate(trajectories): - ax2.plot(tracing.times, jnp.abs(tracing.energy[i]-particles.energy)/particles.energy, label=f'Particle {i+1}') + ax2.plot(tracing.times, jnp.abs(tracing.energy()[i]-particles.energy)/particles.energy, label=f'Particle {i+1}') ax3.plot(tracing.times, trajectory[:, 3]/particles.total_speed, label=f'Particle {i+1}') ax4.plot(jnp.sqrt(trajectory[:,0]**2+trajectory[:,1]**2), trajectory[:, 2], label=f'Particle {i+1}') ax2.set_xlabel('Time (s)') diff --git a/examples/trace_particles_coils_guidingcenter_with_classifier.py b/examples/particle_tracing/trace_particles_coils_guidingcenter_with_classifier.py similarity index 100% rename from examples/trace_particles_coils_guidingcenter_with_classifier.py rename to examples/particle_tracing/trace_particles_coils_guidingcenter_with_classifier.py diff --git a/examples/trace_particles_coils_guidingcenter_with_classifier_scaled_currents.py b/examples/particle_tracing/trace_particles_coils_guidingcenter_with_classifier_scaled_currents.py similarity index 100% rename from examples/trace_particles_coils_guidingcenter_with_classifier_scaled_currents.py rename to examples/particle_tracing/trace_particles_coils_guidingcenter_with_classifier_scaled_currents.py diff --git a/examples/trace_particles_vmec.py b/examples/particle_tracing/trace_particles_vmec.py similarity index 100% rename from examples/trace_particles_vmec.py rename to examples/particle_tracing/trace_particles_vmec.py diff --git a/examples/trace_particles_vmec_Electric_field.py b/examples/particle_tracing/trace_particles_vmec_Electric_field.py similarity index 100% rename from examples/trace_particles_vmec_Electric_field.py rename to examples/particle_tracing/trace_particles_vmec_Electric_field.py diff --git a/examples/testing_collisions_velocity_distributions_mu_Adaptative.py b/examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Adaptative.py similarity index 100% rename from examples/testing_collisions_velocity_distributions_mu_Adaptative.py rename to examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Adaptative.py diff --git a/examples/testing_collisions_velocity_distributions_mu_Fixed.py b/examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Fixed.py similarity index 100% rename from examples/testing_collisions_velocity_distributions_mu_Fixed.py rename to examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_Fixed.py diff --git a/examples/testing_collisions_velocity_distributions_mu_time.py b/examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_time.py similarity index 100% rename from examples/testing_collisions_velocity_distributions_mu_time.py rename to examples/particle_tracing_collisions/statistics_collisions_velocity_distributions_mu_time.py diff --git a/examples/trace_particles_coils_guidingcenter_with_classifier_with_collisionsMu.py b/examples/particle_tracing_collisions/trace_particles_coils_guidingcenter_with_classifier_with_collisionsMu.py similarity index 100% rename from examples/trace_particles_coils_guidingcenter_with_classifier_with_collisionsMu.py rename to examples/particle_tracing_collisions/trace_particles_coils_guidingcenter_with_classifier_with_collisionsMu.py diff --git a/examples/trace_particles_vmec_collisionsMu.py b/examples/particle_tracing_collisions/trace_particles_vmec_collisionsMu.py similarity index 100% rename from examples/trace_particles_vmec_collisionsMu.py rename to examples/particle_tracing_collisions/trace_particles_vmec_collisionsMu.py diff --git a/examples/create_perturbed_coils.py b/examples/simple_examples/create_perturbed_coils.py similarity index 100% rename from examples/create_perturbed_coils.py rename to examples/simple_examples/create_perturbed_coils.py diff --git a/examples/create_stellarator_coils.py b/examples/simple_examples/create_stellarator_coils.py similarity index 100% rename from examples/create_stellarator_coils.py rename to examples/simple_examples/create_stellarator_coils.py