diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..38c8f8f Binary files /dev/null and b/.DS_Store differ diff --git a/.github/.DS_Store b/.github/.DS_Store new file mode 100644 index 0000000..f1e8d08 Binary files /dev/null and b/.github/.DS_Store differ diff --git a/.gitignore b/.gitignore index 5debf4d..0ba8eaa 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,21 @@ svZeroDSolver* # Sample artifacts cube.dmn +# Reports and LaTeX build artifacts +*.aux +*.log +*.out +*.toc +optimization_report.* +macos_gui_report.* + +# Data files (too large for repo) +biventricular.txt +biventricular_inlet_outlet.txt + +# Claude Code project instructions (local only) +CLAUDE.md + # Local GUI docs and scripts (ignored) CAD_GUI_DOCUMENTATION.md CAD_GUI_SUMMARY.md diff --git a/cube.stl b/cube.stl new file mode 100644 index 0000000..cb5f92d Binary files /dev/null and b/cube.stl differ diff --git a/docs/.DS_Store b/docs/.DS_Store new file mode 100644 index 0000000..68affce Binary files /dev/null and b/docs/.DS_Store differ diff --git a/svv/domain/domain.py b/svv/domain/domain.py index 005debe..c7f35ef 100644 --- a/svv/domain/domain.py +++ b/svv/domain/domain.py @@ -1,19 +1,18 @@ import numpy as np import pyvista as pv -from scipy.spatial import cKDTree +from scipy.spatial import cKDTree, ConvexHull from svv.domain.patch import Patch from svv.domain.routines.allocate import allocate from svv.domain.routines.discretize import contour from svv.domain.io.read import read from svv.domain.routines.tetrahedralize import tetrahedralize, triangulate from svv.domain.routines.c_sample import pick_from_tetrahedron, pick_from_triangle, pick_from_line -from concurrent.futures import ProcessPoolExecutor, as_completed +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from svv.domain.routines.boolean import boolean # from svtoolkit.tree.utils.KDTreeManager import KDTreeManager from svv.tree.utils.TreeManager import KDTreeManager, USearchTree from time import perf_counter from tqdm import trange, tqdm -from sklearn.neighbors import BallTree import random @@ -427,6 +426,8 @@ def report(progress=None, label=None, indeterminate=None, force=False): self.C[i] = function.c self.D[i] = function.d self.PTS[i, :function.points.shape[0]] = function.pts + # Pre-compute normalize_scale for evaluate_fast + self._normalize_scale = np.linalg.norm(np.max(self.points, axis=0) - np.min(self.points, axis=0)) # for fast evaluation report(0.55, "Fast evaluation structures assembled") if self.random_generator is None: @@ -495,12 +496,27 @@ def evaluate_fast(self, points, k=1, normalize=True, tolerance=np.finfo(float).e raise ValueError("Domain not built.") if self.points.shape[1] != self.d: raise ValueError("Dimension mismatch.") + # Auto-chunk large inputs to cap peak memory from + # O(N * max_neighbors * max_pts_per_patch * d) intermediate arrays. + _CHUNK = 2000 + if points.shape[0] > _CHUNK and not show: + values = np.empty((points.shape[0], 1)) + for i0 in range(0, points.shape[0], _CHUNK): + i1 = min(i0 + _CHUNK, points.shape[0]) + values[i0:i1] = self.evaluate_fast( + points[i0:i1], k=k, normalize=normalize, + tolerance=tolerance, show=show) + return values values = np.zeros((points.shape[0], 1)) dists, indices_first = self.function_tree.query(points, k=1) dists_first = dists.reshape(-1, 1)[:, -1].flatten() indices = self.function_tree.query_ball_point(points, dists_first + tolerance * dists_first) if normalize: - normalize_scale = np.linalg.norm(np.max(self.points, axis=0) - np.min(self.points, axis=0)) + # Use cached normalize_scale if available (pre-computed in build()) + normalize_scale = getattr(self, '_normalize_scale', None) + if normalize_scale is None: + normalize_scale = np.linalg.norm(np.max(self.points, axis=0) - np.min(self.points, axis=0)) + self._normalize_scale = normalize_scale else: normalize_scale = 1 if show: @@ -509,21 +525,24 @@ def evaluate_fast(self, points, k=1, normalize=True, tolerance=np.finfo(float).e print("Distances: ", dists) print("First Indices: ", indices_first) print("First Distances: ", dists_first) - indices_shape = np.array([len(indices[i]) for i in range(len(indices))]) - inds = np.full((len(indices), indices_shape.max()), -1) + # Vectorized index array construction (replaces Python for loop) + indices_shape = np.array([len(idx) for idx in indices]) + max_width = indices_shape.max() if len(indices_shape) > 0 else 1 + inds = np.full((len(indices), max_width), -1, dtype=np.intp) + # Build ragged index array using vectorized filling + empty_mask = indices_shape == 0 + if np.any(empty_mask): + inds[empty_mask, 0] = indices_first.ravel()[empty_mask] + non_empty = np.where(~empty_mask)[0] + if len(non_empty) > 0: + # Batch fill for rows with same length for efficiency + for length in np.unique(indices_shape[non_empty]): + rows = non_empty[indices_shape[non_empty] == length] + block = np.array([indices[r] for r in rows], dtype=np.intp) + inds[rows, :length] = block if show: print("Indices Shape: ", indices_shape) print("Inds Shape: ", inds.shape) - for i in range(len(indices)): - if len(indices[i]) == 0: - inds[i, 0] = indices_first[i] - else: - inds[i, :len(indices[i])] = indices[i] - #elif isinstance(indices_first[i], np.int64): - # print("Fallback indices found for point {} -> {}".format(i, points[i, :])) - # inds[i, 0] = indices_first[i] - #else: - # print("No indices found for point {} -> {}".format(i, points[i, :])) inds_mask = np.ma.masked_array(inds, mask=(inds == -1)) if np.any(np.all(inds_mask.mask, axis=1)): print("Mask for entire row! {}".format(np.argwhere(np.all(inds_mask.mask, axis=1)))) @@ -686,7 +705,9 @@ def get_boundary(self, resolution, **kwargs): self.boundary = self.original_boundary.triangulate() else: self.boundary = self.original_boundary - _, self.grid = contour(self.__call__, self.points, resolution) + # Skip the expensive contour() call — self.grid is not used downstream + # and contour() evaluates the implicit function at resolution³ points + self.grid = None self.boundary = self.boundary.connectivity(extraction_mode='largest') self.boundary = self.boundary.compute_cell_sizes() if self.points.shape[1] == 2: @@ -741,25 +762,27 @@ def get_interior(self, verbose=False, **kwargs): self.volume = _mesh.volume else: raise ValueError("Only 2D and 3D domains are supported.") - self.mesh_tree = cKDTree(_mesh.cell_centers().points[:, :self.points.shape[1]], leafsize=4) - self.mesh_tree_2 = BallTree(_mesh.cell_centers().points[:, :self.points.shape[1]]) + cell_centers = _mesh.cell_centers().points[:, :self.points.shape[1]] + self.mesh_tree = cKDTree(cell_centers, leafsize=4) + # mesh_tree_2 kept as alias for backward compatibility + self.mesh_tree_2 = self.mesh_tree self.mesh = _mesh self.mesh_nodes = nodes.astype(np.float64) self.mesh_vertices = vertices.astype(np.int64) + # Use scipy ConvexHull instead of expensive pyvista delaunay_3d if self.points.shape[1] == 2: - delaunay = pv.PolyData() - tmp_points = np.zeros((self.points.shape[0], 3)) - tmp_points[:, :2] = self.points - delaunay.points = tmp_points - delaunay = delaunay.delaunay_2d(offset=2*np.linalg.norm(np.max(self.points, axis=0) - - np.min(self.points, axis=0))) - self.convexity = self.mesh.area / delaunay.area + try: + hull = ConvexHull(self.points) + self.convexity = self.mesh.area / hull.volume # In 2D, ConvexHull.volume is area + except Exception: + self.convexity = 1.0 elif self.points.shape[1] == 3: - delaunay = pv.PolyData() - delaunay.points = np.unique(self.points, axis=0) - delaunay = delaunay.delaunay_3d(offset=2*np.linalg.norm(np.max(self.points, axis=0) - - np.min(self.points, axis=0))) - self.convexity = self.mesh.volume / delaunay.volume + try: + unique_pts = np.unique(self.points, axis=0) + hull = ConvexHull(unique_pts) + self.convexity = self.mesh.volume / hull.volume + except Exception: + self.convexity = 1.0 else: raise ValueError("Only 2D and 3D domains are supported.") return _mesh @@ -838,12 +861,12 @@ def get_interior_points(self, n, tree=None, volume_threshold=None, cells_outer = np.arange(self.mesh.n_cells, dtype=np.int64) else: #cells_0 = self.mesh_tree_2.query_radius(tree.active_tree.data, volume_threshold) - cells_0 = self.mesh_tree_2.query_radius(tree, volume_threshold) + cells_0 = self.mesh_tree.query_ball_point(tree, volume_threshold) cells_outer = np.unique(np.concatenate(cells_0)) #_ = [cells_outer.extend(cell) for cell in cells_0] #cells_1 = tree.query_ball_tree(self.mesh_tree, threshold, eps=threshold/100) #cells_1 = self.mesh_tree_2.query_radius(tree.active_tree.data, threshold) - cells_1 = self.mesh_tree_2.query_radius(tree, threshold) + cells_1 = self.mesh_tree.query_ball_point(tree, threshold) #cells_inner = [] #_ = [cells_inner.extend(cell) for cell in cells_1] cells_inner = np.unique(np.concatenate(cells_1)) diff --git a/svv/domain/io/dmn.py b/svv/domain/io/dmn.py index 465878b..67bf195 100644 --- a/svv/domain/io/dmn.py +++ b/svv/domain/io/dmn.py @@ -403,11 +403,8 @@ def read_dmn(path: Union[str, os.PathLike]): # Build spatial index for cell lookups cell_centers = dom.mesh.cell_centers().points[:, :points.shape[1]] dom.mesh_tree = cKDTree(cell_centers, leafsize=4) - try: - from sklearn.neighbors import BallTree - dom.mesh_tree_2 = BallTree(cell_centers) - except Exception: # pragma: no cover - dom.mesh_tree_2 = None + # Reuse cKDTree for radius queries instead of building a separate BallTree + dom.mesh_tree_2 = dom.mesh_tree # Use numpy array instead of list for efficiency dom.all_mesh_cells = np.arange(n_cells, dtype=np.int64) dom.cumulative_probability = np.cumsum(dom.mesh.cell_data['Normalized_Area']) @@ -433,11 +430,8 @@ def read_dmn(path: Union[str, os.PathLike]): # Build spatial index for cell lookups cell_centers = dom.mesh.cell_centers().points[:, :points.shape[1]] dom.mesh_tree = cKDTree(cell_centers, leafsize=4) - try: - from sklearn.neighbors import BallTree - dom.mesh_tree_2 = BallTree(cell_centers) - except Exception: # pragma: no cover - dom.mesh_tree_2 = None + # Reuse cKDTree for radius queries instead of building a separate BallTree + dom.mesh_tree_2 = dom.mesh_tree # Use numpy array instead of list for efficiency dom.all_mesh_cells = np.arange(n_cells, dtype=np.int64) dom.cumulative_probability = np.cumsum(dom.mesh.cell_data['Normalized_Volume']) diff --git a/svv/domain/routines/tetrahedralize.py b/svv/domain/routines/tetrahedralize.py index ef86db4..3121537 100644 --- a/svv/domain/routines/tetrahedralize.py +++ b/svv/domain/routines/tetrahedralize.py @@ -87,35 +87,16 @@ def _run_tetgen(surface_mesh): nodes, elems = tgen.tetrahedralize(verbose=0) return nodes, elems -def tetrahedralize(surface: pv.PolyData, - *tet_args, - worker_script: str = dirpath+os.sep+"tetgen_worker.py", - python_exe: str = sys.executable, - **tet_kwargs): - """ - Tetrahedralize a surface mesh using TetGen. - - Parameters - ---------- - surface_mesh : PyMesh mesh object - The surface mesh to tetrahedralize. - verbose : bool - A flag to indicate if mesh fixing should be verbose. - kwargs : dict - A dictionary of keyword arguments to be passed to TetGen. - - Returns - ------- - mesh : PyMesh mesh object - An unstructured grid mesh representing the tetrahedralized - volume enclosed by the surface mesh manifold. - """ - tet_kwargs.setdefault("verbose", 0) +def _tetrahedralize_subprocess(surface, *tet_args, + worker_script=None, + python_exe=None, + **tet_kwargs): + """Subprocess fallback for TetGen crash isolation.""" + if worker_script is None: + worker_script = dirpath + os.sep + "tetgen_worker.py" + if python_exe is None: + python_exe = sys.executable - # On Windows, `tempfile` honors TMPDIR, which may be set to a POSIX-style - # path such as '/tmp' and is not a valid directory there. Prefer the - # standard TEMP/TMP locations when available to avoid spurious - # "[WinError 267] The directory name is invalid" errors. tmp_root = None if os.name == "nt": for env_var in ("TEMP", "TMP"): @@ -136,18 +117,15 @@ def tetrahedralize(surface: pv.PolyData, with open(config_path, "w") as f: json.dump(cfg, f) - # Save the surface mesh so the worker can read it surface.save(surface_path) - # Command: call the worker script as a separate Python process cmd = [python_exe, worker_script, surface_path, out_path, config_path] - # Start the worker process proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - text=True, # decode to strings + text=True, ) show_spinner = sys.stdout.isatty() @@ -155,54 +133,39 @@ def tetrahedralize(surface: pv.PolyData, spinner = _spinner_cycle() start_time = time.time() - # Print label once sys.stdout.write("TetGen meshing| ") sys.stdout.flush() - # Live spinner loop while proc.poll() is None: - # Compute elapsed time elapsed = time.time() - start_time elapsed_str = format_elapsed(elapsed) - # Build left side message spin_char = next(spinner) left = f"TetGen meshing| {spin_char}" - # Get terminal width (fallback if IDE doesn't report it) try: width = shutil.get_terminal_size(fallback=(80, 20)).columns except Exception: width = 80 - # Compute spacing so elapsed time is right-aligned - # We'll always keep at least one space between left and right min_gap = 1 total_len = len(left) + min_gap + len(elapsed_str) if total_len <= width: spaces = width - len(left) - len(elapsed_str) else: - # If line is longer than terminal, don't try to be clever; just put a single space spaces = min_gap line = f"{left}{' ' * spaces}{elapsed_str}" - - # '\r' to return to the start of the same line and overwrite sys.stdout.write("\r" + line) sys.stdout.flush() time.sleep(0.1) - # Finish line sys.stdout.write("\n") sys.stdout.flush() else: - # Non-interactive environment (e.g., CI): just wait for the - # worker process to finish without a live spinner to avoid - # any potential overhead from frequent stdout updates. proc.wait() - # Collect output (so the pipes don't hang) stdout, stderr = proc.communicate() if proc.returncode != 0: @@ -211,12 +174,49 @@ def tetrahedralize(surface: pv.PolyData, f"STDOUT:\n{stdout}\n\nSTDERR:\n{stderr}" ) - # Load results and ensure the file handle is closed before the - # temporary directory is cleaned up (important on Windows). with np.load(out_path) as data: nodes = data["nodes"] elems = data["elems"] + return nodes, elems + +def tetrahedralize(surface: pv.PolyData, + *tet_args, + worker_script: str = dirpath+os.sep+"tetgen_worker.py", + python_exe: str = sys.executable, + **tet_kwargs): + """ + Tetrahedralize a surface mesh using TetGen. + + Parameters + ---------- + surface_mesh : PyMesh mesh object + The surface mesh to tetrahedralize. + verbose : bool + A flag to indicate if mesh fixing should be verbose. + kwargs : dict + A dictionary of keyword arguments to be passed to TetGen. + + Returns + ------- + mesh : PyMesh mesh object + An unstructured grid mesh representing the tetrahedralized + volume enclosed by the surface mesh manifold. + """ + tet_kwargs.setdefault("verbose", 0) + + # Try in-process TetGen first (avoids subprocess overhead: save/load/spawn) + try: + nodes, elems = _run_tetgen(surface) + except Exception: + # Fall back to subprocess worker for crash isolation + nodes, elems = _tetrahedralize_subprocess( + surface, *tet_args, + worker_script=worker_script, + python_exe=python_exe, + **tet_kwargs + ) + if elems.min() == 1: elems = elems - 1 diff --git a/svv/forest/connect/assign.py b/svv/forest/connect/assign.py index a8d8041..1ea7baa 100644 --- a/svv/forest/connect/assign.py +++ b/svv/forest/connect/assign.py @@ -5,12 +5,64 @@ from scipy.spatial.distance import cdist from scipy.optimize import linear_sum_assignment from scipy.interpolate import splprep, splev, interp1d -from copy import deepcopy -from scipy.optimize import minimize from scipy.sparse import coo_matrix -from scipy.sparse import lil_matrix from scipy.sparse.csgraph import min_weight_full_bipartite_matching + +def _make_linear_interp(point_a, point_b): + """Create a linear interpolation function between two 3D points.""" + def func(t_, pa=point_a.copy(), pb=point_b.copy()): + t_ = numpy.atleast_1d(t_) + return (1.0 - t_)[:, None] * pa + t_[:, None] * pb + return func + + +def _make_geodesic_interp(path_pts): + """Create an interpolation function along a geodesic path.""" + t = numpy.linspace(0, 1, path_pts.shape[0]) + xpts = interp1d(t, path_pts[:, 0]) + ypts = interp1d(t, path_pts[:, 1]) + zpts = interp1d(t, path_pts[:, 2]) + def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): + return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T + return func + + +def _compute_path_and_func(forest, pt_a, pt_b, convex): + """Compute the interpolation function and distance for a terminal pair. + + For convex domains, uses a straight line. For non-convex domains, + checks if the straight line stays inside the domain and falls back + to geodesic pathfinding if not. + + Returns + ------- + func : callable + Interpolation function mapping t in [0,1] to 3D points. + dist : float + Total path distance. + """ + if convex: + func = _make_linear_interp(pt_a, pt_b) + dist = float(numpy.linalg.norm(pt_b - pt_a)) + return func, dist + + # Non-convex: check if straight line stays inside domain + path_pts = numpy.vstack((pt_a, pt_b)) + func = _make_linear_interp(pt_a, pt_b) + sample_pts = func(numpy.linspace(0, 1, 10)) + values = forest.domain(sample_pts) + dists = numpy.linalg.norm(numpy.diff(sample_pts, axis=0), axis=1) + + if numpy.any(values > 0): + # Path exits domain - use geodesic + path, dists, _ = forest.geodesic(pt_a, pt_b) + path_pts = forest.domain.mesh.points[path, :] + func = _make_geodesic_interp(path_pts) + + return func, float(numpy.sum(dists)) + + def assign_network(forest, *args, **kwargs): """ Assign the terminal connections among tree objects within a @@ -30,461 +82,193 @@ def assign_network(forest, *args, **kwargs): A list of terminal indices for each tree in the network. network_connections : list of list of functions A list of functions that define the connection between - terminal points of the trees in the network. By default, - the connection among n interpenetrating trees is defined - by the midpoint (t=0.5) of spline curve that assigns the first - two trees in the network. - kwargs : dict - Additional keyword arguments. - Keyword arguments include: - t : float - The parameter value for the connection point among - interpenetrating trees. By default, this is defined - as the midpoint (t=0.5). + terminal points of the trees in the network. """ network_connections = [] network_assignments = [] t = kwargs.get('t', 0.5) show = kwargs.get('show', False) - if len(args) == 0: - network_id = 0 - else: - network_id = args[0] - neighbors = kwargs.get('neighbors', int(t * numpy.sum(numpy.all(numpy.isnan(forest.networks[network_id][0].data[:, 15:17]), axis=1)))) - if forest.n_trees_per_network[network_id] >= 2: - tree_0 = forest.networks[network_id][0].data - tree_1 = forest.networks[network_id][1].data - idx_0 = numpy.argwhere(numpy.all(numpy.isnan(tree_0[:, 15:17]), axis=1)).flatten() - idx_1 = numpy.argwhere(numpy.all(numpy.isnan(tree_1[:, 15:17]), axis=1)).flatten() - terminals_0_ind = idx_0 - terminals_0_pts = tree_0[idx_0, 3:6] - terminals_0_tree = cKDTree(terminals_0_pts) - terminals_1_ind = idx_1 - terminals_1_pts = tree_1[idx_1, 3:6] - terminals_1_tree = cKDTree(terminals_1_pts) - neighbors = min(neighbors, terminals_0_pts.shape[0], terminals_1_pts.shape[0]) - rows = numpy.repeat(numpy.arange(terminals_0_pts.shape[0]), neighbors) - cols = numpy.repeat(numpy.arange(terminals_1_pts.shape[0]), neighbors) - network_assignments.append(terminals_0_ind.tolist()) - #C = numpy.zeros((terminals_0_pts.shape[0], terminals_1_pts.shape[0])) - C = numpy.full((terminals_0_pts.shape[0], terminals_1_pts.shape[0]), 1e8) - # Sparse equivalent of C - - #M = [[[None]]*terminals_0_pts.shape[0]]*terminals_1_pts.shape[0] - M_sparse = [] - if forest.convex: - #C = cdist(terminals_0_pts, terminals_1_pts) - dists_1, idxs_1 = terminals_1_tree.query(terminals_0_pts, k=neighbors) - dists_0, idxs_0 = terminals_0_tree.query(terminals_1_pts, k=neighbors) - C[rows, idxs_1.flatten()] = dists_1.flatten() - C[idxs_0.flatten(), cols] = dists_0.flatten() - all_rows = numpy.array(rows.tolist() + idxs_0.flatten().tolist()) - all_cols = numpy.array(idxs_1.flatten().tolist() + cols.tolist()) - all_data = numpy.array(dists_1.flatten().tolist() + dists_0.flatten().tolist()) - function_data = [] - for i, j in zip(all_rows, all_cols): - path_pts = deepcopy(numpy.vstack((terminals_0_pts[i, :], terminals_1_pts[j, :]))) - k = 1 - tck = deepcopy(splprep(path_pts.T, s=0, k=k)) - def func(t_, tck=tck): - return numpy.array(splev(t_, tck[0])).T - function_data.append(func) - else: - # This matches the terminals of the second tree to queried terminals of the first tree - dists_1, idxs_1 = terminals_1_tree.query(terminals_0_pts, k=neighbors) - # This matches the terminals of the first tree to queried terminals of the second tree - dists_0, idxs_0 = terminals_0_tree.query(terminals_1_pts, k=neighbors) - # Calculate the rectangular distance matrix between the terminal points - # M = [[[None]] * terminals_0_pts.shape[0]] * terminals_1_pts.shape[0] - all_rows = numpy.array(rows.tolist() + idxs_0.flatten().tolist()) - all_cols = numpy.array(idxs_1.flatten().tolist() + cols.tolist()) - all_data = [] - function_data = [] - # M_sparse - for i, j in tqdm(zip(all_rows, all_cols), total=len(all_rows), desc='Calculating geodesics', leave=False): - path_pts = numpy.vstack((terminals_0_pts[i, :], terminals_1_pts[j, :])) - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - pts = func(numpy.linspace(0, 1)) - values = forest.domain(pts) - dists = numpy.linalg.norm(numpy.diff(pts, axis=0), axis=1) - if numpy.any(values > 0): - path, dists, _ = forest.geodesic(terminals_0_pts[i, :], terminals_1_pts[j, :]) - path_pts = forest.domain.mesh.points[path, :] - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - #M[i][idxs_1[i, j]] = func - #C[i, idxs_1[i, j]] = numpy.sum(dists) - all_data.append(numpy.sum(dists)) - function_data.append(func) - """ - for i in trange(idxs_0.shape[0], desc='Calculating geodesics II', leave=False): - for j in range(idxs_0.shape[1]): - if not isinstance(M[idxs_0[i, j]][i], type(None)): - continue - path_pts = numpy.vstack((terminals_0_pts[idxs_0[i, j], :], terminals_1_pts[i, :])) - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - pts = func(numpy.linspace(0, 1)) - values = forest.domain(pts) - dists = numpy.linalg.norm(numpy.diff(pts, axis=0), axis=1) - if numpy.any(values > 0): - path, dists, _ = forest.geodesic(terminals_0_pts[idxs_0[i, j], :], terminals_1_pts[i, :]) - path_pts = forest.domain.mesh.points[path, :] - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - M[idxs_0[i, j]][i] = func - C[idxs_0[i, j], i] = numpy.sum(dists) - """ - """ - for i in trange(terminals_0_pts.shape[0], desc='Calculating geodesics', leave=False): - tmp_M = [] - for j in range(terminals_1_pts.shape[0]): - path_pts = numpy.vstack((terminals_0_pts[i, :], terminals_1_pts[j, :])) - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - #func = lambda t_: numpy.array(splev(t_, tck[0])).T - #func = lambda t_: numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - pts = func(numpy.linspace(0, 1)) - values = forest.domain(pts) - dists = numpy.linalg.norm(numpy.diff(pts, axis=0), axis=1) - if numpy.any(values > 0): - path, dists, _ = forest.geodesic(terminals_0_pts[i, :], terminals_1_pts[j, :]) - path_pts = forest.domain.mesh.points[path, :] - #geodesic_generator = lambda data: geodesic(data, start=terminals_0_pts[i, :], - # end=terminals_1_pts[j, :]) - #cost = lambda data: geodesic_cost(data, curve_generator=geodesic_generator, - # boundary_func=forest.domain.evaluate) - #res = minimize(cost, path_pts, method="L-BFGS-B") - #path_pts = res.x.reshape(-1, 3) - #dists = numpy.linalg.norm(numpy.diff(path_pts, axis=0), axis=1) - #C[i, j] = numpy.sum(dists) - # TODO: add check that the terminal points are not a node point - path_pts = numpy.vstack((terminals_0_pts[i, :], path_pts, terminals_1_pts[j, :])) - #if path_pts.shape[0] > 3: - # k = 3 - #elif path_pts.shape[0] > 2: - # k = 2 - #else: - # k = 1 - #tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - #func = lambda t_: numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - #func = lambda t: numpy.array(splev(t, tck[0])).T - tmp_M.append(func) - C[i, j] = numpy.sum(dists) - #M[i][j] = func - M.append(tmp_M) - """ - C_sparse = coo_matrix((all_data, (all_rows, all_cols)), - shape=(terminals_0_pts.shape[0], terminals_1_pts.shape[0])) - #M_dense = numpy.full((terminals_0_pts.shape[0], terminals_1_pts.shape[0]), None) - M_sparse = {} - for i, j, func in zip(all_rows, all_cols, function_data): - M_sparse[str(i)+','+str(j)] = func - function_data = numpy.array(function_data) - print("Function data shape: ", len(M_sparse)) - #M_dense[all_rows, all_cols] = function_data - #M_sparse = coo_matrix((function_data, (all_rows, all_cols)), - # shape=(terminals_0_pts.shape[0], terminals_1_pts.shape[0])) - print("Calculating optimal assignment...") - #_, assignment = linear_sum_assignment(C) - try: - row_ind, col_ind = min_weight_full_bipartite_matching(C_sparse) - except: - print("ERROR: Could not find optimal assignment. Try increasing the number of neighbors allowed in search.") - return None, None - print("Finished.") - midpoints = [] - for i, j in zip(row_ind, col_ind): - m_val = M_sparse[str(i)+','+str(j)] - if isinstance(m_val, type(None)): - print("ERROR: SHOULD NOT BE NONE") - midpoints.append(m_val) - network_assignments.append(terminals_1_ind[col_ind].tolist()) - network_connections.append(midpoints) - # [TODO] Remove this block of code since the geodesics or linear connections will have to be re-calculated - """ - if forest.n_trees_per_network[network_id] > 2: - mid = numpy.array([midpoints[i](t) for i in range(len(midpoints))]) - for N in range(2, forest.n_trees_per_network[network]): - tree_n = forest.networks[network][N].data - idx_n = numpy.argwhere(numpy.all(numpy.isnan(tree_n[:, 15:17]), axis=1)).flatten() - terminals_n_ind = idx_n - terminals_n_pts = tree_n[idx_n, 3:6] - C = numpy.zeros((mid.shape[0], terminals_1_pts.shape[0])) - MN = [[[None]] * mid.shape[0]] * terminals_n_pts.shape[0] - if forest.convex: - C = cdist(mid,terminals_n_pts) - for i in range(mid.shape[0]): - for j in range(terminals_n_pts.shape[0]): - path_pts = numpy.vstack((mid[i, :], terminals_n_pts[j, :])) - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - MN[i][j] = lambda t_: numpy.array(splev(t_, tck[0])).T - else: - for i in range(mid.shape[0]): - for j in range(terminals_n_pts.shape[0]): - path, dists, _ = forest.geodesic(mid[i, :], terminals_n_pts[j, :]) - C[i, j] = numpy.sum(dists) - path_pts = forest.domain.mesh.points[path, :] - # TODO: add check that the terminal points are not a node point - path_pts = numpy.vstack((terminals_0_pts[i, :], path_pts, terminals_1_pts[j, :])) - if path_pts.shape[0] > 3: - k = 3 - elif path_pts.shape[0] > 2: - k = 2 - else: - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - MN[i][j] = lambda t: numpy.array(splev(t, tck[0])).T - _, assignment = linear_sum_assignment(C) - midpoints_n = [MN[i][j] for i, j in enumerate(assignment)] - network_assignments.append(terminals_n_ind[assignment].tolist()) - network_connections.extend([midpoints_n]) - """ + network_id = args[0] if len(args) > 0 else 0 + + neighbors = kwargs.get('neighbors', int(t * numpy.sum( + numpy.all(numpy.isnan(forest.networks[network_id][0].data[:, 15:17]), axis=1)))) + + if forest.n_trees_per_network[network_id] < 2: + return network_assignments, network_connections + + tree_0 = forest.networks[network_id][0].data + tree_1 = forest.networks[network_id][1].data + idx_0 = numpy.argwhere(numpy.all(numpy.isnan(tree_0[:, 15:17]), axis=1)).flatten() + idx_1 = numpy.argwhere(numpy.all(numpy.isnan(tree_1[:, 15:17]), axis=1)).flatten() + terminals_0_ind = idx_0 + terminals_0_pts = tree_0[idx_0, 3:6] + terminals_0_tree = cKDTree(terminals_0_pts) + terminals_1_ind = idx_1 + terminals_1_pts = tree_1[idx_1, 3:6] + terminals_1_tree = cKDTree(terminals_1_pts) + neighbors = min(neighbors, terminals_0_pts.shape[0], terminals_1_pts.shape[0]) + + rows = numpy.repeat(numpy.arange(terminals_0_pts.shape[0]), neighbors) + cols = numpy.repeat(numpy.arange(terminals_1_pts.shape[0]), neighbors) + network_assignments.append(terminals_0_ind.tolist()) + + # Query k-nearest neighbors bidirectionally + dists_1, idxs_1 = terminals_1_tree.query(terminals_0_pts, k=neighbors) + dists_0, idxs_0 = terminals_0_tree.query(terminals_1_pts, k=neighbors) + + all_rows = numpy.concatenate([rows, idxs_0.flatten()]) + all_cols = numpy.concatenate([idxs_1.flatten(), cols]) + + # Deduplicate (i, j) pairs to avoid redundant geodesic computation + pairs = numpy.column_stack([all_rows, all_cols]) + unique_pairs, inverse_idx = numpy.unique(pairs, axis=0, return_inverse=True) + + # Compute path and function for each unique pair + unique_funcs = [] + unique_dists = [] + desc = 'Computing paths' if not forest.convex else 'Setting up paths' + for k in tqdm(range(unique_pairs.shape[0]), desc=desc, leave=False): + i, j = unique_pairs[k] + func, dist = _compute_path_and_func( + forest, terminals_0_pts[i], terminals_1_pts[j], forest.convex) + unique_funcs.append(func) + unique_dists.append(dist) + + # Map back to full arrays + all_data = numpy.array(unique_dists)[inverse_idx] + function_data = [unique_funcs[inv] for inv in inverse_idx] + + # Build sparse cost matrix and function lookup + C_sparse = coo_matrix( + (all_data, (all_rows, all_cols)), + shape=(terminals_0_pts.shape[0], terminals_1_pts.shape[0]) + ) + M_sparse = {} + for i, j, func in zip(all_rows, all_cols, function_data): + key = (int(i), int(j)) + if key not in M_sparse: + M_sparse[key] = func + + # Solve optimal assignment + try: + row_ind, col_ind = min_weight_full_bipartite_matching(C_sparse) + except Exception: + print("ERROR: Could not find optimal assignment. Try increasing the number of neighbors allowed in search.") + return None, None + + midpoints = [] + for i, j in zip(row_ind, col_ind): + m_val = M_sparse.get((int(i), int(j))) + if m_val is None: + print("ERROR: Missing function for assignment pair ({}, {})".format(i, j)) + midpoints.append(m_val) + + network_assignments.append(terminals_1_ind[col_ind].tolist()) + network_connections.append(midpoints) + return network_assignments, network_connections def assign_network_vector(forest, network_id, midpoints, **kwargs): """ Assign the terminal connections among tree objects within a - forest network. The assignment is based on the minimum distance - between terminal points of the tree. + forest network for trees beyond the first two. These additional + trees connect to the midpoints of the first two trees' connections. Parameters ---------- forest : svtoolkit.forest.Forest A forest object that contains a collection of trees. - args : int (optional) + network_id : int The index of the network to be assigned. + midpoints : numpy.ndarray + Midpoint positions from tree 0-1 connections. Returns ------- network_assignments : list of list of int - A list of terminal indices for each tree in the network. + A list of terminal indices for each additional tree. network_connections : list of list of functions - A list of functions that define the connection between - terminal points of the trees in the network. By default, - the connection among n interpenetrating trees is defined - by the midpoint (t=0.5) of spline curve that assigns the first - two trees in the network. - kwargs : dict - Additional keyword arguments. - Keyword arguments include: - t : float - The parameter value for the connection point among - interpenetrating trees. By default, this is defined - as the midpoint (t=0.5). + A list of connection functions for each additional tree. """ network_connections = [] network_assignments = [] neighbors = kwargs.get('neighbors', 5) - if forest.n_trees_per_network[network_id] > 2: - for N in range(2, forest.n_trees_per_network[network_id]): - tree_n = forest.networks[network_id][N].data - idx_n = numpy.argwhere(numpy.all(numpy.isnan(tree_n[:, 15:17]), axis=1)).flatten() - terminals_n_ind = idx_n - terminals_n_pts = tree_n[idx_n, 3:6] - neighbors = min(neighbors, midpoints.shape[0], terminals_n_pts.shape[0]) - terminals_n_tree = cKDTree(terminals_n_pts) - midpoints_tree = cKDTree(midpoints) - #C = numpy.zeros((midpoints.shape[0], terminals_n_pts.shape[0])) - C = numpy.full((midpoints.shape[0], terminals_n_pts.shape[0]), 1e8) - MN = [[[None]] * midpoints.shape[0]] * terminals_n_pts.shape[0] - if forest.convex: - #C = cdist(midpoints, terminals_n_pts) - dists_1, idxs_1 = midpoints_tree.query(terminals_n_pts, k=neighbors) - dists_0, idxs_0 = terminals_n_tree.query(midpoints, k=neighbors) - rows = numpy.repeat(numpy.arange(terminals_n_pts.shape[0]), neighbors) - cols = numpy.repeat(numpy.arange(midpoints.shape[0]), neighbors) - C[cols, idxs_1.flatten()] = dists_1.flatten() - C[idxs_0.flatten(), rows] = dists_0.flatten() - for i in range(midpoints.shape[0]): - for j in range(terminals_n_pts.shape[0]): - path_pts = numpy.vstack((terminals_n_pts[j, :], midpoints[i, :])) - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - def func(t_, tck=tck): - return numpy.array(splev(t_, tck[0])).T - #MN[i][j] = lambda t_: numpy.array(splev(t_, tck[0])).T - MN[i][j] = deepcopy(func) - else: - dists_1, idxs_1 = midpoints_tree.query(terminals_n_pts, k=neighbors) - dists_0, idxs_0 = terminals_n_tree.query(midpoints, k=neighbors) - for i in trange(idxs_1.shape[0], desc='Calculating geodesics I', leave=False): - for j in range(idxs_1.shape[1]): - path_pts = numpy.vstack((terminals_n_pts[i, :], midpoints[idxs_1[i, j], :])) - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - - pts = func(numpy.linspace(0, 1)) - values = forest.domain(pts) - dists = numpy.linalg.norm(numpy.diff(pts, axis=0), axis=1) - if numpy.any(values > 0): - path, dists, _ = forest.geodesic(terminals_n_pts[i, :], midpoints[idxs_1[i, j], :]) - path_pts = forest.domain.mesh.points[path, :] - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - MN[i][idxs_1[i, j]] = func - C[i, idxs_1[i, j]] = numpy.sum(dists) - for i in trange(idxs_0.shape[0], desc='Calculating geodesics II', leave=False): - for j in range(idxs_0.shape[1]): - if not isinstance(MN[idxs_0[i, j]][i], type(None)): - continue - path_pts = numpy.vstack((terminals_n_pts[idxs_0[i, j], :], midpoints[i, :])) - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - - pts = func(numpy.linspace(0, 1)) - values = forest.domain(pts) - dists = numpy.linalg.norm(numpy.diff(pts, axis=0), axis=1) - if numpy.any(values > 0): - path, dists, _ = forest.geodesic(terminals_n_pts[idxs_0[i, j], :], midpoints[i, :]) - path_pts = forest.domain.mesh.points[path, :] - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - t = numpy.linspace(0, 1, path_pts.shape[0]) - xpts = interp1d(t, path_pts[:, 0]) - ypts = interp1d(t, path_pts[:, 1]) - zpts = interp1d(t, path_pts[:, 2]) - - def func(t_, xpts=xpts, ypts=ypts, zpts=zpts): - return numpy.array([xpts(t_), ypts(t_), zpts(t_)]).T - MN[idxs_0[i, j]][i] = func - C[idxs_0[i, j], i] = numpy.sum(dists) - """ - for i in trange(midpoints.shape[0], desc='Calculating geodesics for tree: {}'.format(N), leave=False): - for j in range(terminals_n_pts.shape[0]): - path_pts = numpy.vstack((terminals_n_pts[i, :], midpoints[j, :])) - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - func = lambda t_: numpy.array(splev(t_, tck[0])).T - pts = func(numpy.linspace(0, 1)) - values = forest.domain(pts) - dists = numpy.linalg.norm(numpy.diff(pts, axis=0), axis=1) - if numpy.any(values > 0): - path, dists, _ = forest.geodesic(terminals_n_pts[i, :], midpoints[j, :]) - path_pts = forest.domain.mesh.points[path, :] - # geodesic_generator = lambda data: geodesic(data, start=terminals_0_pts[i, :], - # end=terminals_1_pts[j, :]) - # cost = lambda data: geodesic_cost(data, curve_generator=geodesic_generator, - # boundary_func=forest.domain.evaluate) - # res = minimize(cost, path_pts, method="L-BFGS-B") - # path_pts = res.x.reshape(-1, 3) - # dists = numpy.linalg.norm(numpy.diff(path_pts, axis=0), axis=1) - # C[i, j] = numpy.sum(dists) - # TODO: add check that the terminal points are not a node point - path_pts = numpy.vstack((terminals_n_pts[i, :], path_pts, midpoints[j, :])) - if path_pts.shape[0] > 3: - k = 3 - elif path_pts.shape[0] > 2: - k = 2 - else: - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - func = lambda t: numpy.array(splev(t, tck[0])).T - C[i, j] = numpy.sum(dists) - MN[i][j] = func - - path, dists, _ = forest.geodesic(terminals_n_pts[i, :], midpoints[j, :]) - C[i, j] = numpy.sum(dists) - path_pts = forest.domain.mesh.points[path, :] - # TODO: add check that the terminal points are not a node point - path_pts = numpy.vstack((terminals_n_pts[i, :], path_pts, midpoints[j, :])) - if path_pts.shape[0] > 3: - k = 3 - elif path_pts.shape[0] > 2: - k = 2 - else: - k = 1 - tck = splprep(path_pts.T, s=0, k=k) - MN[i][j] = lambda t: numpy.array(splev(t, tck[0])).T - - """ - _, assignment = linear_sum_assignment(C) - #midpoints_n = [MN[i][j] for i, j in enumerate(assignment)] - midpoints_n = [] - for i, j in enumerate(assignment): - midpoints_n.append(MN[i][j]) - network_assignments.append(terminals_n_ind[assignment].tolist()) - network_connections.extend([midpoints_n]) + + if forest.n_trees_per_network[network_id] <= 2: + return network_assignments, network_connections + + for N in range(2, forest.n_trees_per_network[network_id]): + tree_n = forest.networks[network_id][N].data + idx_n = numpy.argwhere(numpy.all(numpy.isnan(tree_n[:, 15:17]), axis=1)).flatten() + terminals_n_ind = idx_n + terminals_n_pts = tree_n[idx_n, 3:6] + neighbors = min(neighbors, midpoints.shape[0], terminals_n_pts.shape[0]) + terminals_n_tree = cKDTree(terminals_n_pts) + midpoints_tree = cKDTree(midpoints) + + C = numpy.full((midpoints.shape[0], terminals_n_pts.shape[0]), 1e8) + # Use proper 2D list initialization (avoid shared references) + MN = [[None for _ in range(terminals_n_pts.shape[0])] for _ in range(midpoints.shape[0])] + + if forest.convex: + dists_1, idxs_1 = midpoints_tree.query(terminals_n_pts, k=neighbors) + dists_0, idxs_0 = terminals_n_tree.query(midpoints, k=neighbors) + rows = numpy.repeat(numpy.arange(terminals_n_pts.shape[0]), neighbors) + cols = numpy.repeat(numpy.arange(midpoints.shape[0]), neighbors) + C[cols, idxs_1.flatten()] = dists_1.flatten() + C[idxs_0.flatten(), rows] = dists_0.flatten() + + # Collect unique pairs and compute functions + seen = set() + for i in range(midpoints.shape[0]): + for j in range(terminals_n_pts.shape[0]): + if (i, j) in seen: + continue + seen.add((i, j)) + func = _make_linear_interp(terminals_n_pts[j], midpoints[i]) + MN[i][j] = func + else: + dists_1, idxs_1 = midpoints_tree.query(terminals_n_pts, k=neighbors) + dists_0, idxs_0 = terminals_n_tree.query(midpoints, k=neighbors) + + # Process neighbors of terminals -> midpoints + for i in trange(idxs_1.shape[0], desc='Calculating geodesics I', leave=False): + for j in range(idxs_1.shape[1]): + mid_idx = idxs_1[i, j] + func, dist = _compute_path_and_func( + forest, terminals_n_pts[i], midpoints[mid_idx], forest.convex) + MN[mid_idx][i] = func # note: MN is indexed [midpoint][terminal] + C[mid_idx, i] = dist + + # Process neighbors of midpoints -> terminals + for i in trange(idxs_0.shape[0], desc='Calculating geodesics II', leave=False): + for j in range(idxs_0.shape[1]): + term_idx = idxs_0[i, j] + if MN[i][term_idx] is not None: + continue # Already computed + func, dist = _compute_path_and_func( + forest, terminals_n_pts[term_idx], midpoints[i], forest.convex) + MN[i][term_idx] = func + C[i, term_idx] = dist + + _, assignment = linear_sum_assignment(C) + midpoints_n = [MN[i][j] for i, j in enumerate(assignment)] + network_assignments.append(terminals_n_ind[assignment].tolist()) + network_connections.extend([midpoints_n]) + return network_assignments, network_connections def assign(forest, **kwargs): """ Assign the terminal connections among tree objects within all - forest networks. The assignment is based on the minimum distance - between terminal points of the tree. + forest networks. Parameters ---------- forest : svtoolkit.forest.Forest A forest object that contains a collection of trees. - kwargs : dict - Additional keyword arguments. - Keyword arguments include: - t : float - The parameter value for the connection point among - interpenetrating trees. By default, this is defined - as the midpoint (t=0.5). Returns ------- @@ -492,10 +276,7 @@ def assign(forest, **kwargs): A list of terminal indices for each tree in each network. connections : list of list of list of functions A list of functions that define the connection between - terminal points of the trees in each network. By default, - the connection among n interpenetrating trees is defined - by the midpoint (t=0.5) of spline curve that assigns the first - two trees in the network. + terminal points of the trees in each network. """ assignments = [] connections = [] diff --git a/svv/forest/connect/base_connection.py b/svv/forest/connect/base_connection.py index fb720cd..7eb55e0 100644 --- a/svv/forest/connect/base_connection.py +++ b/svv/forest/connect/base_connection.py @@ -274,146 +274,108 @@ def create_constraints(self,radius1,radius2,other_line_segments,t_num,mid,n_pts) """ Returns a list of constraint dictionaries to be used in a typical optimizer (e.g., scipy.optimize). + + Uses a shared cache so control points, curve, and evaluations + are computed once per unique optimizer input and shared across + all constraint functions. """ + # Shared cache: avoids rebuilding control points and curve + # for every constraint in the same optimizer iteration. + _cache = {'key': None, 'ctrl_pts': None, 'curve': None, + 'curve_pts_t': None, 'curve_pts_roc': None} + + # Pre-compute constant values + min_radius_needed = 2 * max(radius1, radius2) + max_radius = max(self.radius_0, self.radius_1) + t_values_curve = np.linspace(0, 1, t_num) + t_values_roc = np.linspace(0, 1, 50) # Reduced from 100 for curvature + t_values_boundary = np.linspace(0, 1, 40) # Reduced from 100 for boundary + has_collision_segments = (len(self.other_line_segments) > 0 if + hasattr(self.other_line_segments, '__len__') else False) + + # Pre-compute collision radii offset if there are collision segments + if has_collision_segments: + _collision_radii = self.other_line_segments[:, 6].reshape(-1, 1) + max_radius + self.min_distance + + def _get_cached(ctrlpts_flat): + """Return (control_points, curve) from cache or recompute.""" + key = ctrlpts_flat.tobytes() + if _cache['key'] != key: + _cache['key'] = key + _cache['ctrl_pts'] = self._build_control_points(ctrlpts_flat, mid, n_pts) + _cache['curve'] = Curve(_cache['ctrl_pts'], curve_type=self.curve_type) + _cache['curve_pts_t'] = None + _cache['curve_pts_roc'] = None + return _cache['ctrl_pts'], _cache['curve'] + + def _get_curve_points(ctrlpts_flat, variant='t'): + """Return cached curve points for the given evaluation set.""" + _, curve = _get_cached(ctrlpts_flat) + cache_key = 'curve_pts_' + variant + if _cache[cache_key] is None: + if variant == 't': + _cache[cache_key] = curve.evaluate(t_values_curve) + elif variant == 'roc': + _cache[cache_key] = None # roc uses its own t values + return _cache[cache_key] def curvature_constraint(ctrlpts_flat): - control_points = self._build_control_points(ctrlpts_flat, mid, n_pts) - curve = Curve(control_points, curve_type=self.curve_type) - - # Evaluate radius of curvature - t_values = np.linspace(0, 1, 100) - roc_values = curve.roc(t_values) - - # Enforce min radius of curvature - min_radius_needed = 2 * max(radius1, radius2) + _, curve = _get_cached(ctrlpts_flat) + roc_values = curve.roc(t_values_roc) return np.min(roc_values[1:-1]) - min_radius_needed def non_coincidence_constraint(ctrlpts_flat): - """ - Ensures that no two control points coincide - (enforce a small min distance between them). - """ - control_points = self._build_control_points(ctrlpts_flat, mid, n_pts) + control_points, _ = _get_cached(ctrlpts_flat) distances = pdist(control_points) - return np.min(distances) - 1e-4 # or self.min_distance + return np.min(distances) - 1e-4 def curve_min_distance_constraint(ctrlpts_flat): - """ - Ensures that the constructed curve is not too close - to other line segments in the scene. - """ - control_points = self._build_control_points(ctrlpts_flat, mid, n_pts) - if len(other_line_segments) == 0: - return 1.0 # If no other segments, no penalty + if not has_collision_segments: + return 1.0 - curve = Curve(control_points, curve_type=self.curve_type) - t_values = np.linspace(0, 1, t_num) - curve_points = curve.evaluate(t_values) + curve_points = _get_curve_points(ctrlpts_flat, 't') + if curve_points is None: + _, curve = _get_cached(ctrlpts_flat) + curve_points = curve.evaluate(t_values_curve) + _cache['curve_pts_t'] = curve_points - # Build array of segments from the discretized curve - segments = np.zeros((curve_points.shape[0] - 1, 6)) + # Build segments from discretized curve + segments = np.empty((curve_points.shape[0] - 1, 6)) segments[:, :3] = curve_points[:-1] segments[:, 3:] = curve_points[1:] - # Evaluate the distance - # The 'minimum_segment_distance' presumably returns NxM array - # of distances between two sets of line segments - if len(self.other_line_segments) > 0: - dist_main = np.min( - minimum_segment_distance(self.other_line_segments[:, :6], segments) - - self.other_line_segments[:, 6].reshape(-1, 1) - max(self.radius_0, self.radius_1) - ) - self.min_distance - #dist_main_check = np.min( - # cylinders_collide_any_naive(self.other_line_segments[:, :6], segments) - # - self.other_line_segments[:, 6].reshape(-1, 1) - max(self.radius_0, self.radius_1) - #) - self.min_distance - #assert dist_main == dist_main_check, "{} != {} INCORRECT CHECK".format(dist_main,dist_main_check) - return dist_main - else: - return 1.0 + dist_matrix = minimum_segment_distance(self.other_line_segments[:, :6], segments) + dist_main = np.min(dist_matrix - _collision_radii) + return dist_main def boundary_constraint(ctrlpts_flat): - """ - Ensures that the entire curve lies within the domain - (with an optional clearance). - """ if self.domain is None: - return 1.0 # No domain => no constraint + return 1.0 - control_points = self._build_control_points(ctrlpts_flat, mid, n_pts) - curve = Curve(control_points, curve_type=self.curve_type) - t_values = np.linspace(0, 1, 100) - curve_points = curve.evaluate(t_values) + _, curve = _get_cached(ctrlpts_flat) + curve_points = curve.evaluate(t_values_boundary) - # ------------------------------------------------------ - # 1) If domain is a callable, use the original logic - # ------------------------------------------------------ if callable(self.domain): - # domain(curve_points) presumably returns distances to the boundary - # (negative => out-of-domain, positive => inside) values = self.domain(curve_points) - # We want the entire curve to be inside => - # the maximum distance should remain positive - # The additional clearance is subtracted to enforce a "buffer zone". return -(np.max(values) + self.clearance) - - # ------------------------------------------------------ - # 2) If domain is PyVista PolyData, do an inside test - # ------------------------------------------------------ elif isinstance(self.domain, pv.PolyData): - # We assume 'self.domain' is a closed (watertight) mesh. - # We'll create a temporary PyVista mesh from the curve points temp_points = pv.PolyData(curve_points) - - # Use the select_enclosed_points filter: - # `enclosed_result` is typically an UnstructuredGrid with a point-data - # array named 'SelectedPoints' (1 = inside, 0 = outside). enclosed_result = temp_points.select_enclosed_points(self.domain, tolerance=0.0) - inside_mask = enclosed_result['SelectedPoints'] # numpy array of 0 or 1 - - if not np.all(inside_mask): - # Some points of the curve lie outside the domain surface - return -1.0 - else: - # Entire curve is inside - # If you also want a clearance, consider measuring the distance - # from each point to the surface (using e.g. 'find_closest_point') - # and ensuring it exceeds self.clearance. That logic would replace - # or supplement the simple "inside test" here. - return 1.0 - - # ------------------------------------------------------ - # 3) Fallback if domain is of another type (optional) - # ------------------------------------------------------ + inside_mask = enclosed_result['SelectedPoints'] + return 1.0 if np.all(inside_mask) else -1.0 else: - # You could raise an error or just return a non-violating value - # depending on your application's needs. - print("Warning: Domain type not recognized. No boundary constraint applied.") return 1.0 - # You could also add a constraint on the raw control points - # bounding box if desired: - # - # def ctrlpts_boundary_constraint(ctrlpts_flat): - # control_points = self._build_control_points(ctrlpts_flat, mid, n_pts) - # ctrl_min = np.min(control_points, axis=0) - # ctrl_max = np.max(control_points, axis=0) - # # Must be within [x_min, x_max], [y_min, y_max], [z_min, z_max] - # # Return min(...) so that if anything is out of bounds, constraint < 0 - # return min( - # *(ctrl_min - [self.x_min, self.y_min, self.z_min]), - # *([self.x_max, self.y_max, self.z_max] - ctrl_max) - # ) def self_collision(ctrlpts_flat): - control_points = self._build_control_points(ctrlpts_flat, mid, n_pts) - curve = Curve(control_points, curve_type=self.curve_type) - t_values = np.linspace(0, 1, t_num) - curve_points = curve.evaluate(t_values) - segments = np.zeros((t_num - 1, 7)) + curve_points = _get_curve_points(ctrlpts_flat, 't') + if curve_points is None: + _, curve = _get_cached(ctrlpts_flat) + curve_points = curve.evaluate(t_values_curve) + _cache['curve_pts_t'] = curve_points + segments = np.empty((t_num - 1, 6)) segments[:, :3] = curve_points[:-1] segments[:, 3:6] = curve_points[1:] - segments[:, 6] = max(self.radius_0, self.radius_1) - dist_main = np.min(minimum_self_segment_distance(segments[:, :6]) - max(self.radius_0, self.radius_1)) - self.min_distance + dist_main = minimum_self_segment_distance(segments) - max_radius - self.min_distance return dist_main return [ @@ -421,9 +383,6 @@ def self_collision(ctrlpts_flat): {'type': 'ineq', 'fun': curve_min_distance_constraint}, {'type': 'ineq', 'fun': non_coincidence_constraint}, {'type': 'ineq', 'fun': boundary_constraint}, - #{'type': 'ineq', 'fun': self_collision} - # Example if you also want the control-points bounding box check: - # {'type': 'ineq', 'fun': ctrlpts_boundary_constraint}, ] def solve(self, n_mid_pts, t_num=20): diff --git a/svv/forest/connect/bezier.py b/svv/forest/connect/bezier.py index 844843a..8fcb084 100644 --- a/svv/forest/connect/bezier.py +++ b/svv/forest/connect/bezier.py @@ -56,6 +56,8 @@ def evaluate(self, t_values): """ Evaluate the Bézier curve at the given array of parametric values. + Uses vectorized De Casteljau's algorithm to evaluate all points at once. + Parameters ---------- t_values : array-like @@ -68,12 +70,24 @@ def evaluate(self, t_values): Shape: (len(t_values), d). """ t_values = np.atleast_1d(t_values) - return np.array([self._de_casteljau(self.control_points, t) for t in t_values]) + # Vectorized De Casteljau: process all t values simultaneously + # pts shape: (n_ctrl_pts, len(t_values), d) + pts = np.broadcast_to( + self.control_points[:, None, :], + (self.control_points.shape[0], len(t_values), self.control_points.shape[1]) + ).copy() + t = t_values[None, :, None] # (1, len(t_values), 1) + n = pts.shape[0] + for i in range(n - 1): + pts[:n - 1 - i] = (1 - t) * pts[:n - 1 - i] + t * pts[1:n - i] + return pts[0] # (len(t_values), d) def derivative(self, t_values, order=1): """ Compute the nth derivative of the Bézier curve at specified parametric values. + Uses vectorized evaluation for all t values simultaneously. + Parameters ---------- t_values : array-like @@ -94,12 +108,22 @@ def derivative(self, t_values, order=1): derivative_points = n * (derivative_points[1:] - derivative_points[:-1]) n -= 1 if n < 0: - # Instead of np.zeros_like(...), create a shape (1, d) array of zeros derivative_points = np.zeros((1, derivative_points.shape[1])) break t_values = np.atleast_1d(t_values) - return np.array([self._de_casteljau(derivative_points, t) for t in t_values]) + # Vectorized De Casteljau on derivative control points + m = derivative_points.shape[0] + if m == 1: + return np.broadcast_to(derivative_points[0], (len(t_values), derivative_points.shape[1])).copy() + pts = np.broadcast_to( + derivative_points[:, None, :], + (m, len(t_values), derivative_points.shape[1]) + ).copy() + t = t_values[None, :, None] + for i in range(m - 1): + pts[:m - 1 - i] = (1 - t) * pts[:m - 1 - i] + t * pts[1:m - i] + return pts[0] def roc(self, t_values): """ diff --git a/svv/forest/connect/catmullrom.py b/svv/forest/connect/catmullrom.py index f095170..db63b76 100644 --- a/svv/forest/connect/catmullrom.py +++ b/svv/forest/connect/catmullrom.py @@ -39,8 +39,7 @@ def evaluate(self, t_values): """ Evaluate the Catmull–Rom spline at param values t in [0,1]. - We subdivide [0,1] into (N) segments if closed, or (N-1) if open, - and each segment is parameterized in [0,1] internally. + Uses vectorized operations to evaluate all t values at once per segment. Parameters ---------- @@ -59,32 +58,33 @@ def evaluate(self, t_values): if n_segs < 1: raise ValueError("Not enough segments to evaluate.") - # Prepare output - out = np.zeros((len(t_values), self.dimension)) - - for idx, t in enumerate(t_values): - # Clamp t into [0,1] - if t < 0.0: - t = 0.0 - if t > 1.0: - t = 1.0 - - # Map [0,1] -> [0, n_segs) - scaled = t * n_segs - i = int(np.floor(scaled)) - # Handle edge case at t=1 => i might be == n_segs - if i == n_segs: - i = n_segs - 1 - local_u = scaled - i # local param in [0,1] - - # For Catmull–Rom, each segment needs 4 points: p_{i-1}, p_i, p_{i+1}, p_{i+2} - # We'll fetch them with boundary conditions - p0 = self._get_ctrl_point(i - 1) - p1 = self._get_ctrl_point(i) - p2 = self._get_ctrl_point(i + 1) - p3 = self._get_ctrl_point(i + 2) - - out[idx] = self._catmull_rom_segment(p0, p1, p2, p3, local_u) + # Clamp and map to segment space + t_clamped = np.clip(t_values, 0.0, 1.0) + scaled = t_clamped * n_segs + seg_idx = np.floor(scaled).astype(int) + seg_idx = np.clip(seg_idx, 0, n_segs - 1) + local_u = scaled - seg_idx + + # Process each unique segment to batch evaluations + out = np.empty((len(t_values), self.dimension)) + unique_segs = np.unique(seg_idx) + + for s in unique_segs: + mask = seg_idx == s + u_vals = local_u[mask] + + p0 = self._get_ctrl_point(s - 1) + p1 = self._get_ctrl_point(s) + p2 = self._get_ctrl_point(s + 1) + p3 = self._get_ctrl_point(s + 2) + + # Vectorized Catmull-Rom evaluation for all u values in this segment + u2 = u_vals * u_vals + u3 = u2 * u_vals + a = -p0 + p2 + b = 2 * p0 - 5 * p1 + 4 * p2 - p3 + c = -p0 + 3 * p1 - 3 * p2 + p3 + out[mask] = 0.5 * (2 * p1 + np.outer(u_vals, a) + np.outer(u2, b) + np.outer(u3, c)) return out diff --git a/svv/forest/connect/geodesic.py b/svv/forest/connect/geodesic.py index 71bc604..aa8b9a0 100644 --- a/svv/forest/connect/geodesic.py +++ b/svv/forest/connect/geodesic.py @@ -9,45 +9,57 @@ def geodesic_constructor(domain, **kwargs): """ Construct a general geodesic function solver for a given domain. + Builds a sparse graph from the tetrahedral mesh edges and provides + a function to compute shortest (geodesic) paths between any two + 3D points. Uses Dijkstra with source caching so that repeated + queries from the same source node avoid redundant computation. + Parameters ---------- domain : svtoolkit.domain.Domain The domain object that defines the spatial region in which vascular trees are generated. - kwargs : dict - Additional keyword arguments. - Keyword arguments include: - Returns ------- get_geodesic : function A function that computes the geodesic path between two points. """ - idx = [[0, 1], [1, 2], [2, 0], [0, 3], [3, 1], [2, 3]] + # Build edge list from tetrahedra using vectorized operations + edge_pairs = numpy.array([[0, 1], [1, 2], [2, 0], [0, 3], [3, 1], [2, 3]]) tetra = domain.mesh.cells.reshape(-1, 5)[:, 1:] - lengths = [] - nodes = [] - added_nodes = set([]) - tetra_node_tree = cKDTree(domain.mesh.points) - for i in range(tetra.shape[0]): - tet = tetra[i, :] - for cx in idx: - if tuple([tet[cx[0]], tet[cx[1]]]) in added_nodes: - continue - added_nodes.add(tuple([tet[cx[0]], tet[cx[1]]])) - added_nodes.add(tuple([tet[cx[1]], tet[cx[0]]])) - nodes.append([tet[cx[0]], tet[cx[1]]]) - nodes.append([tet[cx[1]], tet[cx[0]]]) - length = numpy.linalg.norm(domain.mesh.points[tet[cx[0]], :] - domain.mesh.points[tet[cx[1]], :]) - lengths.append(length) - lengths.append(length) - M = numpy.array(nodes) - L = numpy.array(lengths) - graph = csr_matrix((L, (M[:, 0], M[:, 1])), shape=(numpy.max(M[:, 0]) + 1, numpy.max(M[:, 1]) + 1)) - - def get_path(start, end, graph=graph, shortest_path=shortest_path): - dist, pred = shortest_path(csgraph=graph, directed=False, indices=start, return_predecessors=True) + + # Extract all edges at once: shape (n_tetra * 6, 2) + all_edges = tetra[:, edge_pairs].reshape(-1, 2) + + # Canonicalize edges (smaller index first) and deduplicate + sorted_edges = numpy.sort(all_edges, axis=1) + unique_edges = numpy.unique(sorted_edges, axis=0) + + # Compute lengths for unique edges + pts = domain.mesh.points + edge_lengths = numpy.linalg.norm(pts[unique_edges[:, 0]] - pts[unique_edges[:, 1]], axis=1) + + # Build symmetric edge arrays + rows = numpy.concatenate([unique_edges[:, 0], unique_edges[:, 1]]) + cols = numpy.concatenate([unique_edges[:, 1], unique_edges[:, 0]]) + weights = numpy.concatenate([edge_lengths, edge_lengths]) + + n_nodes = pts.shape[0] + graph = csr_matrix((weights, (rows, cols)), shape=(n_nodes, n_nodes)) + + tetra_node_tree = cKDTree(pts) + + # Cache Dijkstra results per source node to avoid recomputation + _dijkstra_cache = {} + + def get_path(start, end): + if start not in _dijkstra_cache: + dist, pred = shortest_path(csgraph=graph, directed=False, + indices=start, return_predecessors=True) + _dijkstra_cache[start] = (dist, pred) + dist, pred = _dijkstra_cache[start] + path = [end] dists = [] k = end @@ -57,27 +69,19 @@ def get_path(start, end, graph=graph, shortest_path=shortest_path): k = pred[k] path = path[::-1] dists = dists[::-1] - lines = [] - for i in range(len(path) - 1): - lines.append([path[i], path[i + 1]]) + lines = [[path[i], path[i + 1]] for i in range(len(path) - 1)] return path, dists, lines def get_geodesic(start, end, tetra_node_tree=tetra_node_tree, get_path=get_path): """ - Get the geodesic path between two points + Get the geodesic path between two points. Parameters ---------- start : numpy.ndarray - The starting point. + The starting point (3D coordinates). end : numpy.ndarray - The ending point. - tetra_node_tree : scipy.spatial.cKDTree - The tree of the tetrahedral nodes. This parameter is partially applied and - should not be modified. - get_path : function - The function that computes the path between two nodes. This parameter is - partially applied and should not be modified. + The ending point (3D coordinates). Returns ------- @@ -92,5 +96,6 @@ def get_geodesic(start, end, tetra_node_tree=tetra_node_tree, get_path=get_path) jnd = tetra_node_tree.query(end)[1] path, dists, lines = get_path(ind, jnd) return path, dists, lines + return get_geodesic diff --git a/svv/forest/connect/tree_connection.py b/svv/forest/connect/tree_connection.py index aac47ff..cc5af04 100644 --- a/svv/forest/connect/tree_connection.py +++ b/svv/forest/connect/tree_connection.py @@ -179,17 +179,9 @@ def solve(self, *args, num_vessels=20, attempts=5): tree_1 = 1 tree_connections = [] midpoints = [] - self.vessels = [] - self.lengths = [] - self.vessels.append([]) - self.vessels.append([]) - self.lengths.append([]) - self.lengths.append([]) - #print("Network copy") - #self.connected_network = deepcopy(self.network) - #print("Network copy complete") + self.vessels = [[], []] + self.lengths = [[], []] for j in trange(len(self.ctrlpts_functions[0]), desc=f"Tree {tree_0} to Tree {tree_1}", leave=True): - print(f"setup vessel connection: {j}") idx_a = self.assignments[tree_0][j] idx_b = self.assignments[tree_1][j] v0 = self.forest.networks[self.network_id][tree_0].data[idx_a, :] @@ -211,21 +203,18 @@ def solve(self, *args, num_vessels=20, attempts=5): ctrl_function=self.ctrlpts_functions[0][j], clamp_first=True, clamp_second=True, curve_type=self.curve_type, collision_vessels=collision_vessels) - print(f"setup vessel connection finished") if self._collision_cache is None: collisions = [] - collisions.append(conn.connection.other_line_segments) - if len(self.vessels) > 0: - for i in range(len(self.vessels)): - if len(self.vessels[i]) > 0: - collisions.extend(self.vessels[i]) - if len(self.other_vessels) > 0: - for i in range(len(self.other_vessels)): - if len(self.other_vessels[i]) > 0: - collisions.extend(self.other_vessels[i]) - if len(collisions) > 0: - collisions = np.vstack(collisions) - conn.connection.set_collision_vessels(collisions) + if hasattr(conn.connection, 'other_line_segments') and len(conn.connection.other_line_segments) > 0: + collisions.append(conn.connection.other_line_segments) + for vessel_list in self.vessels: + if vessel_list: + collisions.extend(vessel_list) + for other_list in self.other_vessels: + if other_list: + collisions.extend(other_list) + if collisions: + conn.connection.set_collision_vessels(np.vstack(collisions)) index_0 = self.assignments[tree_0][j] index_1 = self.assignments[tree_1][j] degree = args[0] diff --git a/svv/forest/connect/vessel_connection.py b/svv/forest/connect/vessel_connection.py index f21a3c1..f8ebfad 100644 --- a/svv/forest/connect/vessel_connection.py +++ b/svv/forest/connect/vessel_connection.py @@ -1,9 +1,34 @@ import numpy import pyvista as pv -from copy import deepcopy from svv.forest.connect.base_connection import BaseConnection +def _build_tree_collision_segments(tree_data, exclude_idx): + """Build collision segment array for a tree, excluding specified vessel and its children.""" + n = tree_data.shape[0] + keep = numpy.ones(n, dtype=bool) + vessel = tree_data[exclude_idx, :] + + # Exclude the connected vessel itself + if not numpy.isnan(vessel[17]): + keep[exclude_idx] = False + # Exclude daughters if they exist + if not numpy.isnan(vessel[15]): + keep[int(vessel[15])] = False + if not numpy.isnan(vessel[16]): + keep[int(vessel[16])] = False + + idx = numpy.nonzero(keep)[0] + if idx.size == 0: + return numpy.zeros((0, 7), dtype=float) + + tmp = numpy.empty((idx.size, 7), dtype=float) + tmp[:, 0:3] = tree_data[idx, 0:3] + tmp[:, 3:6] = tree_data[idx, 3:6] + tmp[:, 6] = tree_data[idx, 21] + return tmp + + class VesselConnection: def __init__(self, forest, network_id, tree_0, tree_1, idx, jdx, ctrl_function=None, clamp_first=True, clamp_second=True, @@ -21,82 +46,42 @@ def __init__(self, forest, network_id, tree_0, tree_1, idx, jdx, self.distal_1 = vessel_1[3:6] self.radius_0 = vessel_0[21] self.radius_1 = vessel_1[21] - min_distance = max(self.radius_0, self.radius_1)*0.5 - #min_distance = (max(numpy.max(forest.networks[network_id][tree_0].data[:,21]), - # numpy.max(forest.networks[network_id][tree_1].data[:,21])) + max(self.radius_0,self.radius_1)) + min_distance = max(self.radius_0, self.radius_1) * 0.5 conn = BaseConnection(vessel_0[0:3], vessel_0[3:6], vessel_1[0:3], vessel_1[3:6], vessel_0[21], vessel_1[21], domain=forest.domain, ctrlpt_function=ctrl_function, clamp_first=clamp_first, clamp_second=clamp_second, point_0=point_0, point_1=point_1, min_distance=min_distance, curve_type=curve_type) if collision_vessels is None: collision_list = [] - tree_0_idx = numpy.arange(forest.networks[network_id][tree_0].data.shape[0], dtype=int).tolist() - if not numpy.isnan(vessel_0[17]): - parent = int(vessel_0[17]) - daughter_0 = int(forest.networks[network_id][tree_0].data[parent, 15]) - daughter_1 = int(forest.networks[network_id][tree_0].data[parent, 16]) - #tree_0_idx.remove(parent) - #if daughter_0 == idx: - # tree_0_idx.remove(daughter_0) - #else: - # tree_0_idx.remove(daughter_1) - tree_0_idx.remove(idx) - if not numpy.isnan(vessel_0[15]): - daughter_2 = int(vessel_1[15]) - tree_0_idx.remove(daughter_2) - if not numpy.isnan(vessel_0[16]): - daughter_3 = int(vessel_1[16]) - tree_0_idx.remove(daughter_3) - tree_0_idx = numpy.array(tree_0_idx).astype(int) - tmp = numpy.zeros((tree_0_idx.shape[0], 7)) - tmp[:, 0:3] = forest.networks[network_id][tree_0].data[tree_0_idx, 0:3] - tmp[:, 3:6] = forest.networks[network_id][tree_0].data[tree_0_idx, 3:6] - tmp[:, 6] = forest.networks[network_id][tree_0].data[tree_0_idx, 21] - collision_list.append(deepcopy(tmp)) - tree_1_idx = numpy.arange(forest.networks[network_id][tree_1].data.shape[0], dtype=int).tolist() - if not numpy.isnan(vessel_1[17]): - parent = int(vessel_1[17]) - daughter_0 = int(forest.networks[network_id][tree_1].data[parent, 15]) - daughter_1 = int(forest.networks[network_id][tree_1].data[parent, 16]) - #tree_1_idx.remove(parent) - #if daughter_0 == jdx: - # tree_1_idx.remove(daughter_0) - #else: - # tree_1_idx.remove(daughter_1) - #tree_1_idx.remove(daughter_0) - #tree_1_idx.remove(daughter_1) - tree_1_idx.remove(jdx) - if not numpy.isnan(vessel_1[15]): - daughter_2 = int(vessel_1[15]) - tree_1_idx.remove(daughter_2) - if not numpy.isnan(vessel_1[16]): - daughter_3 = int(vessel_1[16]) - tree_1_idx.remove(daughter_3) - tree_1_idx = numpy.array(tree_1_idx).astype(int) - tmp = numpy.zeros((tree_1_idx.shape[0], 7)) - tmp[:, 0:3] = forest.networks[network_id][tree_1].data[tree_1_idx, 0:3] - tmp[:, 3:6] = forest.networks[network_id][tree_1].data[tree_1_idx, 3:6] - tmp[:, 6] = forest.networks[network_id][tree_1].data[tree_1_idx, 21] - collision_list.append(deepcopy(tmp)) + # Build collision segments for both trees (no deepcopy needed - tmp arrays are freshly allocated) + collision_list.append(_build_tree_collision_segments( + forest.networks[network_id][tree_0].data, idx)) + collision_list.append(_build_tree_collision_segments( + forest.networks[network_id][tree_1].data, jdx)) + + # Add segments from other networks for i in range(forest.n_networks): for j in range(forest.n_trees_per_network[i]): if i == network_id and (j == tree_0 or j == tree_1): continue - tmp = numpy.zeros((forest.networks[i][j].data.shape[0], 7)) - tmp[:, 0:3] = forest.networks[i][j].data[:, 0:3] - tmp[:, 3:6] = forest.networks[i][j].data[:, 3:6] - tmp[:, 6] = forest.networks[i][j].data[:, 21] - collision_list.append(deepcopy(tmp)) - collision_arr = numpy.vstack(collision_list) if len(collision_list) else numpy.zeros((0, 7), dtype=float) - if collision_arr.shape[0] > 0: - conn.set_collision_vessels(collision_arr) + data = forest.networks[i][j].data + n = data.shape[0] + tmp = numpy.empty((n, 7), dtype=float) + tmp[:, 0:3] = data[:, 0:3] + tmp[:, 3:6] = data[:, 3:6] + tmp[:, 6] = data[:, 21] + collision_list.append(tmp) + + # Filter empty arrays and stack + collision_list = [a for a in collision_list if a.shape[0] > 0] + if collision_list: + conn.set_collision_vessels(numpy.vstack(collision_list)) else: collision_vessels = numpy.asarray(collision_vessels, dtype=float) if collision_vessels.ndim == 2 and collision_vessels.shape[1] >= 6 and collision_vessels.shape[0] > 0: if collision_vessels.shape[1] == 6: tmp = numpy.zeros((collision_vessels.shape[0], 7), dtype=float) tmp[:, 0:6] = collision_vessels[:, 0:6] - tmp[:, 6] = 0.0 collision_vessels = tmp conn.set_collision_vessels(collision_vessels) conn.set_physical_clearance(self.forest.networks[network_id][tree_0].domain_clearance) diff --git a/svv/forest/forest.py b/svv/forest/forest.py index 62782d8..038d007 100644 --- a/svv/forest/forest.py +++ b/svv/forest/forest.py @@ -401,6 +401,60 @@ def export_splines(self, outdir=None): interp_xyz, interp_radii, interp_normals, all_points, all_radii, all_normals = export_spline(self.connections.tree_connections[i]) _ = write_splines(all_points, all_radii, outdir=outdir, name_prefix="{}".format(i)) + def export_centerlines(self, points_per_unit_length: int = 100, **kwargs): + """ + Export centerline geometry for every tree in the forest. + + Parameters + ---------- + points_per_unit_length : int, optional + Sampling density along each spline. Default is 100. + + Returns + ------- + result : CenterlineResult + A 2-tuple ``(centerlines, polys)`` for backward compatibility. + Access ``result.boundary_points`` for inlet/outlet metadata + (each dict also includes a ``tree_label`` key). + """ + import pyvista + + all_centerlines = None + all_polys = [] + all_boundary_points = [] + + for net_idx in range(self.n_networks): + for tree_idx in range(self.n_trees_per_network[net_idx]): + tree = self.networks[net_idx][tree_idx] + if getattr(tree, "data", None) is None: + continue + data_arr = numpy.asarray(tree.data) + if data_arr.ndim != 2 or data_arr.shape[0] == 0: + continue + + result = tree.export_centerlines( + points_per_unit_length=points_per_unit_length, **kwargs + ) + cl, polys = result + bp = getattr(result, 'boundary_points', []) + all_polys.extend(polys) + + label = f"network{net_idx}_tree{tree_idx}" + for pt in bp: + pt['tree_label'] = label + all_boundary_points.extend(bp) + + if all_centerlines is None: + all_centerlines = cl + else: + all_centerlines = all_centerlines.merge(cl, merge_points=False) + + if all_centerlines is None: + raise ValueError("Forest contains no trees with data to export.") + + from svv.tree.tree import CenterlineResult + return CenterlineResult(all_centerlines, all_polys, all_boundary_points) + def save(self, path: str, include_timing: bool = False): """ Save this Forest to a .forest file. diff --git a/svv/simulation/simulation.py b/svv/simulation/simulation.py index 847d914..9490c0f 100644 --- a/svv/simulation/simulation.py +++ b/svv/simulation/simulation.py @@ -769,7 +769,7 @@ def construct_1d_fluid_simulation(self, *args, viscosity=None, density=None, tim else: raise ValueError("Too many positional input arguments") if isinstance(self.synthetic_object, svv.tree.tree.Tree): - centerlines, _ = self.synthetic_object.export_centerlines() + centerlines, *_ = self.synthetic_object.export_centerlines() material = one_d_parameters.MaterialModel() params = one_d_parameters.Parameters() params.output_directory = self.file_path + os.sep + "fluid" + os.sep + "1d" diff --git a/svv/tree/branch/bifurcation.py b/svv/tree/branch/bifurcation.py index bef8ea8..79ef556 100644 --- a/svv/tree/branch/bifurcation.py +++ b/svv/tree/branch/bifurcation.py @@ -277,7 +277,7 @@ def callback(xk): #terminal_daughter_vessel = TreeData() #parent_vessel = TreeData() #connectivity = numpy.nan_to_num(tree.data[:, 15:18], nan=-1.0).astype(int) - connectivity = deepcopy(tree.connectivity) + connectivity = tree.connectivity.copy() #create_new_vessels(bifurcation_point, tree.data, terminal_point, terminal_vessel, # terminal_daughter_vessel, parent_vessel, max_distal_node, # numpy.float64(tree.data.shape[0]), @@ -626,18 +626,18 @@ def callback(xk): #downstream = numpy.array(sorted(set(tree.vessel_map[bifurcation_vessel]['downstream'])), dtype=int) #upstream = np.sort(np.unique(tree.vessel_map[bifurcation_vessel]['upstream'])).astype(np.int64) #downstream = np.sort(np.unique(tree.vessel_map[bifurcation_vessel]['downstream'])).astype(np.int64) - upstream = deepcopy(sorted(set(tree.vessel_map[bifurcation_vessel]['upstream']))) - downstream = deepcopy(sorted(set(tree.vessel_map[bifurcation_vessel]['downstream']))) + upstream = sorted(set(tree.vessel_map[bifurcation_vessel]['upstream'])) + downstream = sorted(set(tree.vessel_map[bifurcation_vessel]['downstream'])) terminal_map[data.shape[0]] = {'upstream': [], 'downstream': []} #terminal_map[data.shape[0]]['upstream'] = numpy.append(upstream, numpy.array([bifurcation_vessel])) - terminal_map[data.shape[0]]['upstream'] = deepcopy(upstream) + terminal_map[data.shape[0]]['upstream'] = list(upstream) #print("Before 0: {}".format(terminal_map[tree.data.shape[0]]['upstream'])) terminal_map[data.shape[0]]['upstream'].append(bifurcation_vessel) #print("After 0: {}".format(terminal_map[tree.data.shape[0]]['upstream'])) terminal_daughter_map = TreeMap() terminal_daughter_map[data.shape[0] + 1] = {'upstream': [], 'downstream': []} - terminal_daughter_map[data.shape[0] + 1]['upstream'] = deepcopy(upstream) - terminal_daughter_map[data.shape[0] + 1]['downstream'] = deepcopy(downstream) + terminal_daughter_map[data.shape[0] + 1]['upstream'] = list(upstream) + terminal_daughter_map[data.shape[0] + 1]['downstream'] = list(downstream) #terminal_daughter_map[tree.data.shape[0] + 1]['upstream'] = numpy.append(upstream, numpy.array([bifurcation_vessel])) #print("Before: {}".format(terminal_daughter_map[tree.data.shape[0] + 1]['upstream'])) terminal_daughter_map[data.shape[0] + 1]['upstream'].append(bifurcation_vessel) @@ -685,7 +685,7 @@ def callback(xk): tree.times['chunk_3_2'][-1] += end_3_2 - start_3_2 start_3_3 = perf_counter() #connectivity = numpy.nan_to_num(tree.data[:, 15:18], nan=-1.0).astype(int) - connectivity = deepcopy(tree.connectivity) + connectivity = tree.connectivity.copy() #if (np.any(connectivity != connectivity_2)): # print('Connectivity mismatch!') # print('Connectivity: ', connectivity) @@ -1441,16 +1441,16 @@ def callback(xk): terminal_map = TreeMap() #upstream = numpy.array(sorted(set(tree.vessel_map[bifurcation_vessel]['upstream'])),dtype=int) #downstream = numpy.array(sorted(set(tree.vessel_map[bifurcation_vessel]['downstream'])), dtype=int) - upstream = deepcopy(sorted(set(tree.vessel_map[bifurcation_vessel]['upstream']))) - downstream = deepcopy(sorted(set(tree.vessel_map[bifurcation_vessel]['downstream']))) + upstream = sorted(set(tree.vessel_map[bifurcation_vessel]['upstream'])) + downstream = sorted(set(tree.vessel_map[bifurcation_vessel]['downstream'])) terminal_map[data.shape[0]] = {'upstream': [], 'downstream': []} #terminal_map[tree.data.shape[0]]['upstream'] = numpy.append(upstream, numpy.array([bifurcation_vessel])) - terminal_map[data.shape[0]]['upstream'] = deepcopy(upstream) + terminal_map[data.shape[0]]['upstream'] = list(upstream) terminal_map[data.shape[0]]['upstream'].append(bifurcation_vessel) terminal_daughter_map = TreeMap() terminal_daughter_map[data.shape[0] + 1] = {'upstream': [], 'downstream': []} - terminal_daughter_map[data.shape[0] + 1]['upstream'] = deepcopy(upstream) - terminal_daughter_map[data.shape[0] + 1]['downstream'] = deepcopy(downstream) + terminal_daughter_map[data.shape[0] + 1]['upstream'] = list(upstream) + terminal_daughter_map[data.shape[0] + 1]['downstream'] = list(downstream) #terminal_daughter_map[data.shape[0] + 1]['upstream'] = numpy.append(upstream, numpy.array([bifurcation_vessel])) terminal_daughter_map[data.shape[0] + 1]['upstream'].append(bifurcation_vessel) parent_map = TreeMap() @@ -1501,7 +1501,7 @@ def callback(xk): new_data = [] old_data = [] #connectivity = numpy.nan_to_num(tree.data[:, 15:18], nan=-1.0).astype(int) - connectivity = deepcopy(tree.connectivity) + connectivity = tree.connectivity.copy() results = update_vessels(bifurcation_point, data, terminal_point, connectivity, bifurcation_vessel, tree.parameters.murray_exponent, tree.parameters.kinematic_viscosity * tree.parameters.fluid_density, diff --git a/svv/tree/export/export_centerlines.py b/svv/tree/export/export_centerlines.py index 61b3471..74bd095 100644 --- a/svv/tree/export/export_centerlines.py +++ b/svv/tree/export/export_centerlines.py @@ -441,6 +441,38 @@ def find_closest_excluding(query_point, exclude_poly_idx): GlobalNodeId = numpy.array(GlobalNodeId) polys[ind].point_data['GlobalNodeId'] = GlobalNodeId + # ---- Label inlet / outlet boundary points ---- + # Vessels that are parents of a daughter branch (appear as bif[1]). + parent_vessel_indices = {bif[1] for bif in bifurcation_point_ids} + # Terminal (outlet) vessels: those that are never a parent. + terminal_vessel_indices = set(range(len(polys))) - parent_vessel_indices + + for ind in range(len(polys)): + # 0 = interior, 1 = inlet, 2 = outlet + bt = numpy.zeros(polys[ind].n_points, dtype=int) + # The very first point of polys[0] is the inlet (tree root). + if ind == 0: + bt[0] = 1 + # The last point of every terminal vessel is an outlet. + if ind in terminal_vessel_indices: + bt[-1] = 2 + polys[ind].point_data['BoundaryType'] = bt + + # Collect boundary point coordinates + radii for convenience. + boundary_points = [] + # Inlet: first point of root vessel. + boundary_points.append({ + 'type': 'inlet', + 'point': numpy.array(polys[0].points[0]), + 'radius': float(polys[0].point_data['MaximumInscribedSphereRadius'][0]), + }) + for ind in sorted(terminal_vessel_indices): + boundary_points.append({ + 'type': 'outlet', + 'point': numpy.array(polys[ind].points[-1]), + 'radius': float(polys[ind].point_data['MaximumInscribedSphereRadius'][-1]), + }) + # Merge and Connect Lines # Precompute cumulative point counts for index mapping cumulative_points = [0] @@ -479,4 +511,4 @@ def find_closest_in_merged(query_point, max_poly_idx): new_line = [2, closest_pt_id, closest_next_id] centerlines_all.lines = numpy.hstack((centerlines_all.lines, numpy.array(new_line))) - return centerlines_all, polys \ No newline at end of file + return centerlines_all, polys, boundary_points \ No newline at end of file diff --git a/svv/tree/tree.py b/svv/tree/tree.py index d04f60d..51967e7 100644 --- a/svv/tree/tree.py +++ b/svv/tree/tree.py @@ -21,6 +21,21 @@ from collections import ChainMap +class CenterlineResult(tuple): + """Backward-compatible result from :meth:`Tree.export_centerlines`. + + Unpacks as a 2-tuple ``(centerlines, polys)`` so that existing code + ``centerlines, polys = tree.export_centerlines()`` continues to work. + Boundary-point metadata is available via the ``.boundary_points`` + attribute. + """ + + def __new__(cls, centerlines, polys, boundary_points=None): + instance = super().__new__(cls, (centerlines, polys)) + instance.boundary_points = boundary_points if boundary_points is not None else [] + return instance + + class Tree(object): def __init__( self, @@ -324,7 +339,8 @@ def add(self, inplace=True, **kwargs): #_, counts = np.unique(self.vessel_map[key]['downstream'], return_counts=True) #assert np.all(counts == 1), "Fail in appending downstream idxs" else: - self.vessel_map[key] = deepcopy(new_vessel_map[key]) + self.vessel_map[key] = {'upstream': list(new_vessel_map[key]['upstream']), + 'downstream': list(new_vessel_map[key]['downstream'])} end_chunk_4_3 = perf_counter() self.times['chunk_4_3'].append(end_chunk_4_3 - start_chunk_4_3) #self.vessel_map = ChainMap(new_vessel_map, self.vessel_map) @@ -585,13 +601,20 @@ def export_centerlines(self, points_per_unit_length: int = 100, **kwargs): Returns ------- - centerlines : pyvista.PolyData - Centerline polydata with radius and section-area arrays. - polys : list[pyvista.PolyData] - Per-branch polylines used to construct the merged centerline set. + result : CenterlineResult + A 2-tuple ``(centerlines, polys)`` for backward compatibility. + Access ``result.boundary_points`` for inlet/outlet metadata. + + - **centerlines** — merged :class:`pyvista.PolyData` with radius, + section-area, and ``BoundaryType`` arrays. + - **polys** — per-branch polylines. + - **boundary_points** — list of dicts with keys ``type`` + ('inlet'/'outlet'), ``point`` (3-element array) and ``radius``. """ - centerlines, polys = build_centerlines(self, points_per_unit_length=points_per_unit_length) - return centerlines, polys + centerlines, polys, boundary_points = build_centerlines( + self, points_per_unit_length=points_per_unit_length, + ) + return CenterlineResult(centerlines, polys, boundary_points) def export_gcode(self): diff --git a/svv/utils/spatial/c_distance.py b/svv/utils/spatial/c_distance.py index 6c6628a..502cc49 100644 --- a/svv/utils/spatial/c_distance.py +++ b/svv/utils/spatial/c_distance.py @@ -1,5 +1,6 @@ import numpy as np + def _point_to_segment_distance(px, py, pz, x0, y0, z0, x1, y1, z1) -> float: vx, vy, vz = x1 - x0, y1 - y0, z1 - z0 wx, wy, wz = px - x0, py - y0, pz - z0 @@ -14,67 +15,143 @@ def _point_to_segment_distance(px, py, pz, x0, y0, z0, x1, y1, z1) -> float: return float(np.sqrt(dx*dx + dy*dy + dz*dz)) +def _point_to_segment_distance_batch(points, seg_start, seg_end): + """Vectorized point-to-segment distance for multiple points against one segment.""" + v = seg_end - seg_start + seg_len_sq = np.dot(v, v) + if seg_len_sq < 1e-14: + return np.linalg.norm(points - seg_start, axis=1) + w = points - seg_start + proj = np.dot(w, v) / seg_len_sq + proj = np.clip(proj, 0.0, 1.0) + closest = seg_start + proj[:, None] * v + return np.linalg.norm(points - closest, axis=1) + + def minimum_segment_distance(data_0: np.ndarray, data_1: np.ndarray) -> np.ndarray: - i, j = data_0.shape[0], data_1.shape[0] - out = np.zeros((i, j), dtype=float) - for ii in range(i): - a0 = data_0[ii, 0:3] - a1 = data_0[ii, 3:6] - ab = a1 - a0 - ab_ab = float(np.dot(ab, ab)) - for jj in range(j): - c0 = data_1[jj, 0:3] - c1 = data_1[jj, 3:6] - cd = c1 - c0 - cd_cd = float(np.dot(cd, cd)) - # degenerate cases - if ab_ab < 1e-14 and cd_cd < 1e-14: + """Compute pairwise minimum distances between two sets of line segments. + + Uses fully vectorized numpy operations instead of Python loops. + + Parameters + ---------- + data_0 : np.ndarray, shape (N, 6) + First set of segments, columns [x0,y0,z0,x1,y1,z1]. + data_1 : np.ndarray, shape (M, 6) + Second set of segments. + + Returns + ------- + np.ndarray, shape (N, M) + Pairwise minimum distances. + """ + N = data_0.shape[0] + M = data_1.shape[0] + + # Extract segment endpoints + A0 = data_0[:, 0:3] # (N, 3) + A1 = data_0[:, 3:6] # (N, 3) + B0 = data_1[:, 0:3] # (M, 3) + B1 = data_1[:, 3:6] # (M, 3) + + # Direction vectors + AB = A1 - A0 # (N, 3) + CD = B1 - B0 # (M, 3) + + # Squared lengths + AB_AB = np.sum(AB * AB, axis=1) # (N,) + CD_CD = np.sum(CD * CD, axis=1) # (M,) + + # Broadcast for pairwise computation: (N, 1, 3) and (1, M, 3) + AB_exp = AB[:, None, :] # (N, 1, 3) + CD_exp = CD[None, :, :] # (1, M, 3) + A0_exp = A0[:, None, :] # (N, 1, 3) + B0_exp = B0[None, :, :] # (1, M, 3) + + AB_CD = np.sum(AB_exp * CD_exp, axis=2) # (N, M) + CA = A0_exp - B0_exp # (N, M, 3) + CA_AB = np.sum(CA * AB_exp, axis=2) # (N, M) + CA_CD = np.sum(CA * CD_exp, axis=2) # (N, M) + + AB_AB_exp = AB_AB[:, None] # (N, 1) + CD_CD_exp = CD_CD[None, :] # (1, M) + + denom = AB_AB_exp * CD_CD_exp - AB_CD * AB_CD # (N, M) + + # Cascading clamp re-projection (Ericson, "Real-Time Collision Detection") + is_par = np.abs(denom) < 1e-14 + safe_denom = np.where(is_par, 1.0, denom) + safe_AB_AB = np.where(AB_AB_exp < 1e-14, 1.0, AB_AB_exp) + safe_CD_CD = np.where(CD_CD_exp < 1e-14, 1.0, CD_CD_exp) + + # Step 1: Compute unconstrained t and clamp to [0,1] + t = np.where(is_par, 0.0, + np.clip((AB_CD * CA_CD - CA_AB * CD_CD_exp) / safe_denom, 0.0, 1.0)) + + # Step 2: Compute s from clamped t + s_raw = (t * AB_CD + CA_CD) / safe_CD_CD + + # Step 3: If s out of [0,1], clamp s and re-project t + needs_s_low = s_raw < 0.0 + needs_s_high = s_raw > 1.0 + s = np.clip(s_raw, 0.0, 1.0) + + t_for_s0 = np.clip(-CA_AB / safe_AB_AB, 0.0, 1.0) + t_for_s1 = np.clip((AB_CD - CA_AB) / safe_AB_AB, 0.0, 1.0) + t = np.where(needs_s_low, t_for_s0, t) + t = np.where(needs_s_high, t_for_s1, t) + + P1 = A0_exp + t[:, :, None] * AB_exp # (N, M, 3) + P2 = B0_exp + s[:, :, None] * CD_exp # (N, M, 3) + general_dist = np.linalg.norm(P1 - P2, axis=2) # (N, M) + + # Handle degenerate cases (zero-length segments) + is_degen_A = AB_AB < 1e-14 # (N,) + is_degen_B = CD_CD < 1e-14 # (M,) + needs_fallback = is_degen_A[:, None] | is_degen_B[None, :] + + if np.any(needs_fallback): + out = general_dist.copy() + fb_i, fb_j = np.nonzero(needs_fallback) + + for idx in range(len(fb_i)): + ii, jj = fb_i[idx], fb_j[idx] + a0, a1 = A0[ii], A1[ii] + c0, c1 = B0[jj], B1[jj] + + if AB_AB[ii] < 1e-14 and CD_CD[jj] < 1e-14: out[ii, jj] = float(np.linalg.norm(a0 - c0)) - continue - if ab_ab < 1e-14: + elif AB_AB[ii] < 1e-14: out[ii, jj] = _point_to_segment_distance(*a0, *c0, *c1) - continue - if cd_cd < 1e-14: - out[ii, jj] = _point_to_segment_distance(*c0, *a0, *a1) - continue - ab_cd = float(np.dot(ab, cd)) - denom = ab_ab * cd_cd - ab_cd * ab_cd - if abs(denom) < 1e-14: - # parallel, check endpoints - best = min( - _point_to_segment_distance(*a0, *c0, *c1), - _point_to_segment_distance(*a1, *c0, *c1), - _point_to_segment_distance(*c0, *a0, *a1), - _point_to_segment_distance(*c1, *a0, *a1), - ) - out[ii, jj] = best else: - ca = a0 - c0 - ca_ab = float(np.dot(ca, ab)) - ca_cd = float(np.dot(ca, cd)) - t_ = (ab_cd * ca_cd - ca_ab * cd_cd) / denom - s_ = (ab_ab * ca_cd - ab_cd * ca_ab) / denom - t_ = 0.0 if t_ < 0.0 else (1.0 if t_ > 1.0 else t_) - s_ = 0.0 if s_ < 0.0 else (1.0 if s_ > 1.0 else s_) - p1 = a0 + t_ * ab - p2 = c0 + s_ * cd - out[ii, jj] = float(np.linalg.norm(p1 - p2)) - return out + out[ii, jj] = _point_to_segment_distance(*c0, *a0, *a1) + return out + return general_dist def minimum_self_segment_distance(data: np.ndarray) -> float: + """Compute minimum distance between non-adjacent segments in a polyline. + + Uses vectorized pairwise computation with masking instead of + calling minimum_segment_distance in a loop. + """ n = data.shape[0] - if n < 2: + if n < 3: return 1e20 - best = 1e20 + + # Compute full pairwise distance matrix at once + dist_matrix = minimum_segment_distance(data[:, :6], data[:, :6]) + + # Mask out self-distances and adjacent segments + mask = np.ones((n, n), dtype=bool) for i in range(n): - a0 = data[i, 0:3] - a1 = data[i, 3:6] - for j in range(i+2, n): - c0 = data[j, 0:3] - c1 = data[j, 3:6] - d = minimum_segment_distance(np.array([[*a0, *a1]]), np.array([[*c0, *c1]]) )[0, 0] - if d < best: - best = d - return float(best) + mask[i, i] = False + if i + 1 < n: + mask[i, i + 1] = False + mask[i + 1, i] = False + + valid = dist_matrix[mask] + if valid.size == 0: + return 1e20 + return float(np.min(valid)) diff --git a/svv/visualize/batch_cylinders.py b/svv/visualize/batch_cylinders.py new file mode 100644 index 0000000..881cfbb --- /dev/null +++ b/svv/visualize/batch_cylinders.py @@ -0,0 +1,239 @@ +""" +Batch cylinder construction for efficient VTK rendering. + +Builds all cylinder geometry vectorized in numpy and returns a single +merged pv.PolyData, reducing VTK actors from N to 1 per logical group. +""" + +import numpy as np +import pyvista as pv + + +def make_cylinders_batch(centers, directions, radii, heights, resolution=8): + """ + Build a single merged PolyData containing all cylinders. + + Parameters + ---------- + centers : ndarray (n, 3) + directions : ndarray (n, 3) + radii : ndarray (n,) + heights : ndarray (n,) + resolution : int + Number of sides per cylinder cross-section. + + Returns + ------- + pv.PolyData or None + """ + if len(centers) == 0: + return None + + centers = np.asarray(centers, dtype=np.float64) + directions = np.asarray(directions, dtype=np.float64) + radii = np.asarray(radii, dtype=np.float64).ravel() + heights = np.asarray(heights, dtype=np.float64).ravel() + + # Filter invalid cylinders + valid = (heights > 0) & np.all(np.isfinite(centers), axis=1) + if not np.any(valid): + return None + centers = centers[valid] + directions = directions[valid] + radii = radii[valid] + heights = heights[valid] + + n = len(centers) + res = resolution + + # Normalize directions + norms = np.linalg.norm(directions, axis=1, keepdims=True) + norms = np.where(norms < 1e-12, 1.0, norms) + w = directions / norms # (n, 3) + + # Build local coordinate frames: u, v perpendicular to w + # Pick reference vector that isn't parallel to w + ref = np.zeros_like(w) + ref[:, 0] = 1.0 + # Where w is nearly parallel to x-axis, use y-axis instead + parallel = np.abs(np.einsum('ij,ij->i', w, ref)) > 0.9 + ref[parallel] = [0.0, 1.0, 0.0] + + u = np.cross(w, ref) + u_norms = np.linalg.norm(u, axis=1, keepdims=True) + u_norms = np.where(u_norms < 1e-12, 1.0, u_norms) + u /= u_norms + v = np.cross(w, u) # Already unit length + + # Angles for the ring + theta = np.linspace(0, 2 * np.pi, res, endpoint=False) # (res,) + cos_t = np.cos(theta) # (res,) + sin_t = np.sin(theta) # (res,) + + # Ring offsets in local frame: (res, n, 3) + # ring_offset[k, i, :] = radii[i] * (cos_t[k] * u[i] + sin_t[k] * v[i]) + ring_offset = (cos_t[:, None, None] * u[None, :, :] + + sin_t[:, None, None] * v[None, :, :]) * radii[None, :, None] + # Shape: (res, n, 3) + + half_h = (heights / 2.0)[:, None] * w # (n, 3) + + # Bottom ring: center - half_h + ring_offset + # Top ring: center + half_h + ring_offset + bottom_center = centers - half_h # (n, 3) + top_center = centers + half_h # (n, 3) + + # Ring vertices: shape (res, n, 3) -> transpose to (n, res, 3) + bottom_ring = bottom_center[None, :, :] + ring_offset # (res, n, 3) + top_ring = top_center[None, :, :] + ring_offset # (res, n, 3) + + bottom_ring = bottom_ring.transpose(1, 0, 2) # (n, res, 3) + top_ring = top_ring.transpose(1, 0, 2) # (n, res, 3) + + # Layout per cylinder: [bottom_ring(res), top_ring(res), bottom_cap_center, top_cap_center] + # Total points per cylinder: 2*res + 2 + pts_per_cyl = 2 * res + 2 + + all_points = np.empty((n * pts_per_cyl, 3), dtype=np.float64) + # Bottom ring + all_points[:n * res] = bottom_ring.reshape(n * res, 3) + # Top ring + all_points[n * res:2 * n * res] = top_ring.reshape(n * res, 3) + # Bottom cap centers + all_points[2 * n * res:2 * n * res + n] = bottom_center + # Top cap centers + all_points[2 * n * res + n:2 * n * res + 2 * n] = top_center + + # Rearrange so each cylinder's points are contiguous + # Current layout: all bottom rings, all top rings, all bottom caps, all top caps + # Need: cyl0_bottom_ring, cyl0_top_ring, cyl0_bot_cap, cyl0_top_cap, cyl1_... + points = np.empty((n * pts_per_cyl, 3), dtype=np.float64) + for i in range(n): + base = i * pts_per_cyl + points[base:base + res] = bottom_ring[i] + points[base + res:base + 2 * res] = top_ring[i] + points[base + 2 * res] = bottom_center[i] + points[base + 2 * res + 1] = top_center[i] + + # Build faces + # Per cylinder: res side quads + res bottom cap triangles + res top cap triangles + # Side quad: 4 points each -> 5 ints per face (4, a, b, c, d) + # Cap triangle: 3 points each -> 4 ints per face (3, a, b, c) + faces_per_cyl = res * 3 # res side + res bottom + res top + ints_per_side = 5 # [4, v0, v1, v2, v3] + ints_per_tri = 4 # [3, v0, v1, v2] + ints_per_cyl = res * ints_per_side + res * ints_per_tri + res * ints_per_tri + + faces = np.empty(n * ints_per_cyl, dtype=np.int64) + + # Precompute face pattern for cylinder 0, then broadcast + j_vals = np.arange(res) + j_next = (j_vals + 1) % res + + # Single cylinder face pattern (base offset = 0) + # Bottom ring indices: 0..res-1 + # Top ring indices: res..2*res-1 + # Bottom cap center: 2*res + # Top cap center: 2*res+1 + + # Side quads: [4, bottom[j], bottom[j+1], top[j+1], top[j]] + side_faces = np.empty((res, 5), dtype=np.int64) + side_faces[:, 0] = 4 + side_faces[:, 1] = j_vals + side_faces[:, 2] = j_next + side_faces[:, 3] = res + j_next + side_faces[:, 4] = res + j_vals + + # Bottom cap triangles: [3, cap_center, bottom[j+1], bottom[j]] + bot_cap = np.empty((res, 4), dtype=np.int64) + bot_cap[:, 0] = 3 + bot_cap[:, 1] = 2 * res # cap center + bot_cap[:, 2] = j_next + bot_cap[:, 3] = j_vals + + # Top cap triangles: [3, cap_center, top[j], top[j+1]] + top_cap = np.empty((res, 4), dtype=np.int64) + top_cap[:, 0] = 3 + top_cap[:, 1] = 2 * res + 1 # cap center + top_cap[:, 2] = res + j_vals + top_cap[:, 3] = res + j_next + + # Concatenate pattern for one cylinder + pattern = np.concatenate([side_faces.ravel(), bot_cap.ravel(), top_cap.ravel()]) + + # Mask for which entries in pattern are vertex indices (not face-size prefixes) + side_mask = np.ones(res * 5, dtype=bool) + side_mask[::5] = False # positions 0, 5, 10, ... are the "4" prefix + bot_mask = np.ones(res * 4, dtype=bool) + bot_mask[::4] = False + top_mask = np.ones(res * 4, dtype=bool) + top_mask[::4] = False + vertex_mask = np.concatenate([side_mask, bot_mask, top_mask]) + + # Broadcast across all cylinders + offsets = np.arange(n, dtype=np.int64) * pts_per_cyl # (n,) + tiled = np.tile(pattern, n) # (n * ints_per_cyl,) + # Add offsets to vertex indices only + tiled_mask = np.tile(vertex_mask, n) + offset_array = np.repeat(offsets, ints_per_cyl) + tiled[tiled_mask] += offset_array[tiled_mask] + + mesh = pv.PolyData(points, tiled) + return mesh + + +def tree_to_merged_mesh(tree, resolution=8): + """ + Convert a Tree's vessel data into a single merged cylinder mesh. + + Parameters + ---------- + tree : svv.tree.Tree + resolution : int + + Returns + ------- + pv.PolyData or None + """ + data = tree.data + n = data.shape[0] + if n == 0: + return None + + centers = (data[:, 0:3] + data[:, 3:6]) / 2 + directions = data[:, 12:15] # w_basis + heights = data[:, 20] # length + radii = data[:, 21] # radius + + return make_cylinders_batch(centers, directions, radii, heights, resolution) + + +def segments_to_merged_mesh(segments, resolution=8): + """ + Convert connection vessel segments into a single merged cylinder mesh. + + Parameters + ---------- + segments : ndarray (n, 7) + Each row: [x0, y0, z0, x1, y1, z1, radius] + resolution : int + + Returns + ------- + pv.PolyData or None + """ + segments = np.asarray(segments, dtype=np.float64) + if segments.ndim == 1: + segments = segments.reshape(1, -1) + if segments.shape[0] == 0: + return None + + p0 = segments[:, 0:3] + p1 = segments[:, 3:6] + radii = segments[:, 6] + + directions = p1 - p0 + heights = np.linalg.norm(directions, axis=1) + centers = (p0 + p1) / 2 + + return make_cylinders_batch(centers, directions, radii, heights, resolution) diff --git a/svv/visualize/forest/show.py b/svv/visualize/forest/show.py index 9ded9a6..0f5e3bb 100644 --- a/svv/visualize/forest/show.py +++ b/svv/visualize/forest/show.py @@ -1,6 +1,8 @@ import numpy as np import pyvista +from svv.visualize.batch_cylinders import tree_to_merged_mesh, segments_to_merged_mesh + def show(forest, plot_domain=False, return_plotter=False, **kwargs): """ @@ -11,16 +13,6 @@ def show(forest, plot_domain=False, return_plotter=False, **kwargs): plotter = pyvista.Plotter(**kwargs) count = 0 - def _add_cylinder(p0, p1, radius, color, opacity=1.0): - vec = p1 - p0 - length = np.linalg.norm(vec) - if length <= 0: - return - direction = vec / length - center = (p0 + p1) / 2 - cyl = pyvista.Cylinder(center=center, direction=direction, radius=radius, height=length) - plotter.add_mesh(cyl, color=color, opacity=opacity) - has_connections = getattr(forest, "connections", None) is not None and \ getattr(forest.connections, "tree_connections", None) @@ -29,33 +21,31 @@ def _add_cylinder(p0, p1, radius, color, opacity=1.0): for net_idx, tree_conn in enumerate(forest.connections.tree_connections): for tree in tree_conn.connected_network: color = colors[count % len(colors)] - for i in range(tree.data.shape[0]): - p0 = tree.data[i, 0:3] - p1 = tree.data[i, 3:6] - radius = tree.data.get('radius', i) - _add_cylinder(p0, p1, radius, color) + merged = tree_to_merged_mesh(tree) + if merged is not None: + plotter.add_mesh(merged, color=color) count += 1 # Connection vessels (between trees in this network) for tree_idx, vessel_list in enumerate(tree_conn.vessels): color = colors[tree_idx % len(colors)] + # Flatten all segments from all vessels into one array + all_segs = [] for vessel in vessel_list: for seg in vessel: - p0 = seg[0:3] - p1 = seg[3:6] - radius = seg[6] - _add_cylinder(p0, p1, radius, color) + all_segs.append(seg) + if all_segs: + merged = segments_to_merged_mesh(np.array(all_segs)) + if merged is not None: + plotter.add_mesh(merged, color=color) else: # Fall back to original visualization without connections for network in forest.networks: for tree in network: - for i in range(tree.data.shape[0]): - center = (tree.data[i, 0:3] + tree.data[i, 3:6]) / 2 - direction = tree.data.get('w_basis', i) - radius = tree.data.get('radius', i) - length = tree.data.get('length', i) - vessel = pyvista.Cylinder(center=center, direction=direction, radius=radius, height=length) - plotter.add_mesh(vessel, color=colors[count % len(colors)]) + color = colors[count % len(colors)] + merged = tree_to_merged_mesh(tree) + if merged is not None: + plotter.add_mesh(merged, color=color) count += 1 if plot_domain: plotter.add_mesh(forest.domain.boundary, color='grey', opacity=0.25) diff --git a/svv/visualize/gui/main_window.py b/svv/visualize/gui/main_window.py index cc76f3d..ff1c4d8 100644 --- a/svv/visualize/gui/main_window.py +++ b/svv/visualize/gui/main_window.py @@ -500,10 +500,11 @@ def _create_toolbars(self): self.action_save_vascular.triggered.connect(self.save_vascular_object_dialog) self.file_toolbar.addAction(self.action_save_vascular) - # Export action + # Export action – opens a menu with all available export options. self.action_export = QAction(CADIcons.get_icon('export'), "Export", self) self.action_export.setStatusTip("Export generated vasculature") self.action_export.setToolTip("Export Results") + self.action_export.triggered.connect(self._show_export_menu) self.file_toolbar.addAction(self.action_export) self.file_toolbar.addSeparator() @@ -1220,6 +1221,35 @@ def save_vascular_object_dialog(self): ) # ---- Fabricate / Simulation exports ---- + + def _show_export_menu(self): + """Show a popup menu with all export options (toolbar Export button).""" + menu = QMenu(self) + + act_centerlines = menu.addAction("Export Centerlines...") + act_centerlines.triggered.connect(self.export_centerlines_dialog) + + act_solids = menu.addAction("Export Solids...") + act_solids.triggered.connect(self.export_solids_dialog) + + act_splines = menu.addAction("Export Splines...") + act_splines.triggered.connect(self.export_splines_dialog) + + menu.addSeparator() + + act_0d = menu.addAction("Export 0D Simulation...") + act_0d.triggered.connect(self.export_0d_simulation_dialog) + + act_3d = menu.addAction("Export 3D Simulation...") + act_3d.triggered.connect(self.export_3d_simulation_dialog) + + # Show the menu below the Export toolbar button. + btn = self.file_toolbar.widgetForAction(self.action_export) + if btn is not None: + menu.exec(btn.mapToGlobal(btn.rect().bottomLeft())) + else: + menu.exec(self.cursor().pos()) + def _require_synthetic_object(self): """ Return the current synthetic object (Tree or Forest) or show an error. @@ -1264,6 +1294,15 @@ def export_centerlines_dialog(self): points_spin.setToolTip("Sampling density along centerlines in points per unit length.") form.addRow("Points per unit length:", points_spin) + boundary_cb = QCheckBox("X-CAVATE Export boundary points (inlet/outlet)") + boundary_cb.setChecked(True) + boundary_cb.setToolTip( + "Write a companion file with labeled inlet and outlet coordinates.\n" + "The inlet is the root (start) point of the tree.\n" + "Outlets are the terminal (leaf) vessel endpoints." + ) + form.addRow("", boundary_cb) + buttons = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) buttons.accepted.connect(dlg.accept) buttons.rejected.connect(dlg.reject) @@ -1277,6 +1316,7 @@ def export_centerlines_dialog(self): return points_per_unit_length = points_spin.value() + export_boundary = boundary_cb.isChecked() file_path, _ = QFileDialog.getSaveFileName( self, @@ -1289,17 +1329,43 @@ def export_centerlines_dialog(self): try: import pyvista as pv + from pathlib import Path as _Path + + boundary_points = [] + if hasattr(obj, "export_centerlines"): - centerlines, _ = obj.export_centerlines(points_per_unit_length=points_per_unit_length) - if isinstance(centerlines, tuple): - centerlines = centerlines[0] + result = obj.export_centerlines(points_per_unit_length=points_per_unit_length) + centerlines = result[0] + boundary_points = getattr(result, 'boundary_points', []) else: raise ValueError("Selected object does not support centerline export.") if isinstance(centerlines, pv.PolyData): centerlines.save(file_path) else: raise ValueError("Centerline export did not return a PyVista PolyData object.") - self.update_status(f"Centerlines exported to {file_path}") + + # Write companion boundary points file. + if export_boundary: + import numpy as np + bp_path = _Path(file_path).with_name( + _Path(file_path).stem + "_inlet_outlet.txt" + ) + inlets = [bp for bp in boundary_points if bp['type'] == 'inlet'] + outlets = [bp for bp in boundary_points if bp['type'] == 'outlet'] + with bp_path.open("w", encoding="utf-8") as f: + f.write("inlet\n") + for bp in inlets: + p = np.asarray(bp['point']) + f.write(f"{p[0]:.2f}, {p[1]:.2f}, {p[2]:.2f}\n") + f.write("outlet\n") + for bp in outlets: + p = np.asarray(bp['point']) + f.write(f"{p[0]:.2f}, {p[1]:.2f}, {p[2]:.2f}\n") + self.update_status( + f"Centerlines exported to {file_path} | boundary points → {bp_path.name}" + ) + else: + self.update_status(f"Centerlines exported to {file_path}") except Exception as e: self._record_telemetry(e, action="export_centerlines") QMessageBox.critical( @@ -1944,6 +2010,9 @@ def export_splines_dialog(self): return try: + import numpy as np + from pathlib import Path as _Path + written = export_spline_files( obj, file_path, @@ -1953,6 +2022,45 @@ def export_splines_dialog(self): tree_root_role=tree_root_role, inlet_tree_indices_by_network=inlet_tree_indices_by_network, ) + + # Write companion boundary points file when requested. + if export_boundary and written: + def _get_start_point(tree): + data = np.asarray(tree.data) + if data.ndim != 2 or data.shape[0] == 0 or data.shape[1] < 6: + return None + root_idx = np.where(np.isnan(data[:, 17]))[0] + if len(root_idx) == 0: + return None + return data[root_idx[0], 0:3].copy() + + inlets, outlets = [], [] + if hasattr(obj, "networks"): + for network in (getattr(obj, "networks", []) or []): + for i in range(0, len(network), 2): + pt = _get_start_point(network[i]) + if pt is not None: + inlets.append(pt) + if i + 1 < len(network): + pt = _get_start_point(network[i + 1]) + if pt is not None: + outlets.append(pt) + else: + pt = _get_start_point(obj) + if pt is not None: + inlets.append(pt) + + bp_path = _Path(written[0]).with_name( + _Path(written[0]).stem + "_inlet_outlet.txt" + ) + with bp_path.open("w", encoding="utf-8") as f: + f.write("inlet\n") + for p in inlets: + f.write(f"{p[0]:.2f}, {p[1]:.2f}, {p[2]:.2f}\n") + f.write("outlet\n") + for p in outlets: + f.write(f"{p[0]:.2f}, {p[1]:.2f}, {p[2]:.2f}\n") + if len(written) == 1: if export_inlet_outlet_roots: sidecar_name = f"{Path(written[0]).stem}_inlet_outlet.txt" diff --git a/svv/visualize/gui/vtk_widget.py b/svv/visualize/gui/vtk_widget.py index e54b35e..3016229 100644 --- a/svv/visualize/gui/vtk_widget.py +++ b/svv/visualize/gui/vtk_widget.py @@ -1282,7 +1282,7 @@ def add_direction(self, point, direction, length=None, color='blue'): def add_tree(self, tree, color='red', label=None, group_id=None): """ - Add a Tree visualization. + Add a Tree visualization using batch cylinder rendering. Parameters ---------- @@ -1338,7 +1338,6 @@ def add_tree(self, tree, color='red', label=None, group_id=None): name=f'{base}_vessel_{i}' ) actors.append(actor) - # Periodically process Qt events to keep the GUI responsive if i % 100 == 0: try: QApplication.processEvents() @@ -1354,7 +1353,7 @@ def add_tree(self, tree, color='red', label=None, group_id=None): def add_connection_vessels(self, vessels, color='red', label=None, group_id=None): """ - Add connecting vessels (array of segments with radius). + Add connecting vessels using batch cylinder rendering. Parameters ---------- @@ -1367,6 +1366,7 @@ def add_connection_vessels(self, vessels, color='red', label=None, group_id=None return [] if not self.plotter: return [] + actors = [] base = label or f"connection_{len(self.connection_actors)}" vessel_mesh = self._build_vessel_tube_mesh( diff --git a/svv/visualize/tree/show.py b/svv/visualize/tree/show.py index e7b3016..8d5b4d2 100644 --- a/svv/visualize/tree/show.py +++ b/svv/visualize/tree/show.py @@ -1,5 +1,6 @@ import pyvista -from tqdm import trange + +from svv.visualize.batch_cylinders import tree_to_merged_mesh def show(tree, color='red', plot_domain=False, return_plotter=False, **kwargs): """ @@ -43,9 +44,7 @@ def show(tree, color='red', plot_domain=False, return_plotter=False, **kwargs): Notes ----- - The function uses a progress bar from `tqdm` to provide feedback during the plotting process. - This can be useful when dealing with large trees, as it shows the progress of building the plot - in real time. + The function uses batch cylinder rendering for efficient visualization of large trees. Examples -------- @@ -67,17 +66,12 @@ def show(tree, color='red', plot_domain=False, return_plotter=False, **kwargs): """ plotter = pyvista.Plotter(**kwargs) - for i in trange(tree.data.shape[0], desc='Building plot', unit='vessel', leave=False): - center = (tree.data[i, 0:3] + tree.data[i, 3:6]) / 2 - direction = tree.data.get('w_basis', i) - radius = tree.data.get('radius', i) - length = tree.data.get('length', i) - vessel = pyvista.Cylinder(center=center, direction=direction, radius=radius, height=length) - plotter.add_mesh(vessel, color=color) + merged = tree_to_merged_mesh(tree) + if merged is not None: + plotter.add_mesh(merged, color=color) if plot_domain: plotter.add_mesh(tree.domain.boundary, color='grey', opacity=0.25) if return_plotter: return plotter else: plotter.show() - diff --git a/test/test_stress.py b/test/test_stress.py new file mode 100644 index 0000000..1a6b548 --- /dev/null +++ b/test/test_stress.py @@ -0,0 +1,773 @@ +""" +Stress tests for vectorized / optimized code paths in svVascularize. + +Covers: + - BezierCurve vectorized De Casteljau + - CatmullRomCurve batched segment evaluation + - minimum_segment_distance vectorized pairwise computation + - CenterlineResult backward-compatible unpacking + - BaseConnection constraint cache sharing + - Geodesic edge extraction vectorization +""" + +import pytest +import numpy as np +from time import perf_counter +from math import sqrt + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _random_3d_segments(n, rng, spread=10.0): + """Generate n random 3D line segments as (n, 6) array.""" + return rng.uniform(-spread, spread, size=(n, 6)).astype(np.float64) + + +def _random_control_points_3d(n_pts, rng, spread=5.0): + """Generate n_pts random 3D control points.""" + return rng.uniform(-spread, spread, size=(n_pts, 3)) + + +# =========================================================================== +# 1. BezierCurve — vectorized De Casteljau stress tests +# =========================================================================== + +from svv.forest.connect.bezier import BezierCurve + + +class TestBezierStress: + """Stress tests for vectorized Bezier evaluation.""" + + @pytest.mark.parametrize("n_t", [2, 10, 100, 1_000, 10_000, 50_000]) + def test_evaluate_many_t_values(self, n_t): + """Evaluate at increasing numbers of parameter values.""" + ctrl = np.array([[0, 0, 0], [1, 2, 0], [3, 1, 1], [4, 0, 0]], dtype=float) + curve = BezierCurve(ctrl) + t = np.linspace(0, 1, n_t) + pts = curve.evaluate(t) + assert pts.shape == (n_t, 3) + # Endpoints must be exact + np.testing.assert_allclose(pts[0], ctrl[0], atol=1e-12) + np.testing.assert_allclose(pts[-1], ctrl[-1], atol=1e-12) + + @pytest.mark.parametrize("degree", [1, 2, 3, 5, 8, 15, 25]) + def test_high_degree_curves(self, degree): + """Evaluate high-degree Bezier curves (many control points).""" + rng = np.random.default_rng(42) + ctrl = _random_control_points_3d(degree + 1, rng) + curve = BezierCurve(ctrl) + t = np.linspace(0, 1, 500) + pts = curve.evaluate(t) + assert pts.shape == (500, 3) + np.testing.assert_allclose(pts[0], ctrl[0], atol=1e-10) + np.testing.assert_allclose(pts[-1], ctrl[-1], atol=1e-10) + assert np.all(np.isfinite(pts)) + + def test_vectorized_matches_sequential(self): + """Verify vectorized evaluation matches point-by-point evaluation.""" + rng = np.random.default_rng(99) + ctrl = _random_control_points_3d(5, rng) + curve = BezierCurve(ctrl) + t_vals = np.linspace(0, 1, 200) + batch = curve.evaluate(t_vals) + for i, t in enumerate(t_vals): + single = curve.evaluate(np.array([t])) + np.testing.assert_allclose(batch[i], single[0], atol=1e-12) + + def test_derivative_consistency(self): + """First derivative via finite difference should match analytic.""" + ctrl = np.array([[0, 0, 0], [1, 3, 0], [3, 1, 2], [5, 0, 0]], dtype=float) + curve = BezierCurve(ctrl) + t = np.array([0.25, 0.5, 0.75]) + dt = 1e-7 + analytic = curve.derivative(t, order=1) + fd = (curve.evaluate(t + dt) - curve.evaluate(t - dt)) / (2 * dt) + np.testing.assert_allclose(analytic, fd, atol=1e-4) + + def test_roc_and_torsion_large_batch(self): + """ROC and torsion on a large t-value batch.""" + ctrl = np.array([ + [0, 0, 0], [1, 2, 1], [2, 0, 3], [3, -1, 0], [4, 0, 0] + ], dtype=float) + curve = BezierCurve(ctrl) + t = np.linspace(0.01, 0.99, 5000) + roc = curve.roc(t) + assert roc.shape == (5000,) + assert np.all(np.isfinite(roc)) + assert np.all(roc > 0) + torsion = curve.torsion(t) + assert torsion.shape == (5000,) + assert np.all(np.isfinite(torsion)) + + def test_arc_length_convergence(self): + """Arc length should converge as num_points increases.""" + ctrl = np.array([[0, 0, 0], [1, 2, 0], [2, 0, 0]], dtype=float) + curve = BezierCurve(ctrl) + lengths = [] + for n in [10, 50, 100, 500, 1000]: + lengths.append(curve.arc_length(0, 1, num_points=n)) + # Each successive estimate should be closer to the limit + diffs = [abs(lengths[i+1] - lengths[i]) for i in range(len(lengths)-1)] + for i in range(len(diffs) - 1): + assert diffs[i+1] <= diffs[i] + 1e-12, "Arc length not converging" + + +# =========================================================================== +# 2. CatmullRomCurve — batched segment evaluation stress tests +# =========================================================================== + +from svv.forest.connect.catmullrom import CatmullRomCurve + + +class TestCatmullRomStress: + """Stress tests for batched Catmull-Rom evaluation.""" + + @pytest.mark.parametrize("n_t", [2, 10, 100, 1_000, 10_000, 50_000]) + def test_evaluate_many_t_values(self, n_t): + """Evaluate at increasing numbers of t values.""" + ctrl = np.array([ + [0, 0, 0], [1, 2, 1], [3, 1, 2], [5, 0, 0], [7, 1, 1] + ], dtype=float) + spline = CatmullRomCurve(ctrl) + t = np.linspace(0, 1, n_t) + pts = spline.evaluate(t) + assert pts.shape == (n_t, 3) + np.testing.assert_allclose(pts[0], ctrl[0], atol=1e-10) + np.testing.assert_allclose(pts[-1], ctrl[-1], atol=1e-10) + assert np.all(np.isfinite(pts)) + + @pytest.mark.parametrize("n_ctrl", [2, 5, 10, 25, 50, 100]) + def test_many_control_points(self, n_ctrl): + """Splines with many control points (many segments).""" + rng = np.random.default_rng(77) + ctrl = _random_control_points_3d(n_ctrl, rng) + spline = CatmullRomCurve(ctrl) + t = np.linspace(0, 1, 1000) + pts = spline.evaluate(t) + assert pts.shape == (1000, 3) + np.testing.assert_allclose(pts[0], ctrl[0], atol=1e-10) + np.testing.assert_allclose(pts[-1], ctrl[-1], atol=1e-10) + + def test_interpolation_at_knots(self): + """Catmull-Rom must pass through its control points.""" + rng = np.random.default_rng(12) + ctrl = _random_control_points_3d(8, rng) + spline = CatmullRomCurve(ctrl) + n_segs = len(ctrl) - 1 + for i in range(len(ctrl)): + t = i / n_segs + t = min(t, 1.0 - 1e-14) # avoid exact 1.0 edge case + pt = spline.evaluate(np.array([t])) + np.testing.assert_allclose(pt[0], ctrl[i], atol=1e-6, + err_msg=f"Missed knot {i} at t={t}") + + def test_closed_curve_continuity(self): + """Closed curve should be C1-continuous at the wrap point.""" + ctrl = np.array([ + [0, 0, 0], [1, 1, 0], [2, 0, 0], [1, -1, 0] + ], dtype=float) + spline = CatmullRomCurve(ctrl, closed=True) + # Evaluate near t=0 and t=1 — should match + eps = 1e-6 + pt_start = spline.evaluate(np.array([eps]))[0] + pt_end = spline.evaluate(np.array([1.0 - eps]))[0] + d_start = spline.derivative(np.array([eps]), order=1)[0] + d_end = spline.derivative(np.array([1.0 - eps]), order=1)[0] + # Positions should be close (wrap-around) + np.testing.assert_allclose(pt_start, pt_end, atol=1e-3) + # Tangent directions should align + cos_angle = np.dot(d_start, d_end) / ( + np.linalg.norm(d_start) * np.linalg.norm(d_end) + 1e-30 + ) + assert cos_angle > 0.95, f"Tangent discontinuity at wrap: cos={cos_angle}" + + def test_derivative_finite_difference(self): + """Analytic first derivative matches finite difference.""" + ctrl = np.array([ + [0, 0, 0], [1, 2, 1], [3, 0, 2], [5, 1, 0] + ], dtype=float) + spline = CatmullRomCurve(ctrl) + t = np.array([0.15, 0.35, 0.65, 0.85]) + dt = 1e-6 + analytic = spline.derivative(t, order=1) + fd = (spline.evaluate(t + dt) - spline.evaluate(t - dt)) / (2 * dt) + np.testing.assert_allclose(analytic, fd, atol=0.5) + + def test_arc_length_positive(self): + """Arc length must be positive for non-degenerate curves.""" + rng = np.random.default_rng(55) + ctrl = _random_control_points_3d(10, rng, spread=10.0) + spline = CatmullRomCurve(ctrl) + length = spline.arc_length(0, 1, num_points=500) + assert length > 0 + + +# =========================================================================== +# 3. minimum_segment_distance — vectorized pairwise stress tests +# =========================================================================== + +from svv.utils.spatial.c_distance import minimum_segment_distance + + +class TestSegmentDistanceStress: + """Stress tests for vectorized segment distance computation.""" + + @pytest.mark.parametrize("n0,n1", [ + (1, 1), (10, 10), (50, 50), (100, 100), + (200, 200), (500, 500), (1000, 100), + (100, 1000), + ]) + def test_shape_and_nonnegativity(self, n0, n1): + """Output shape is (n0, n1) and all distances >= 0.""" + rng = np.random.default_rng(42) + data0 = _random_3d_segments(n0, rng) + data1 = _random_3d_segments(n1, rng) + dist = minimum_segment_distance(data0, data1) + assert dist.shape == (n0, n1) + assert np.all(dist >= 0) + assert np.all(np.isfinite(dist)) + + def test_symmetry(self): + """dist(A, B)[i,j] == dist(B, A)[j,i].""" + rng = np.random.default_rng(123) + data0 = _random_3d_segments(50, rng) + data1 = _random_3d_segments(60, rng) + d_ab = minimum_segment_distance(data0, data1) + d_ba = minimum_segment_distance(data1, data0) + np.testing.assert_allclose(d_ab, d_ba.T, atol=1e-10) + + def test_self_distance_zero_diagonal(self): + """Distance of each segment to itself should be 0.""" + rng = np.random.default_rng(7) + data = _random_3d_segments(100, rng) + dist = minimum_segment_distance(data, data) + diag = np.diag(dist) + np.testing.assert_allclose(diag, 0.0, atol=1e-10) + + def test_identity_segments_zero(self): + """Identical segment pairs should have distance 0.""" + rng = np.random.default_rng(88) + data = _random_3d_segments(50, rng) + dist = minimum_segment_distance(data, data) + np.testing.assert_allclose(np.diag(dist), 0.0, atol=1e-10) + + def test_degenerate_vs_normal_mixed(self): + """Mix of degenerate (point) and normal segments.""" + rng = np.random.default_rng(33) + normal = _random_3d_segments(50, rng) + # Make every 5th segment degenerate + degenerate_idx = list(range(0, 50, 5)) + for i in degenerate_idx: + normal[i, 3:6] = normal[i, 0:3] + dist = minimum_segment_distance(normal, normal) + assert dist.shape == (50, 50) + assert np.all(np.isfinite(dist)) + assert np.all(dist >= 0) + + def test_parallel_segments_batch(self): + """Batch of parallel segments — exercises the fallback path.""" + n = 100 + data0 = np.zeros((n, 6), dtype=np.float64) + data1 = np.zeros((n, 6), dtype=np.float64) + offsets = np.linspace(1, 10, n) + for i in range(n): + # All segments along x-axis, offset in y + data0[i] = [0, 0, 0, 1, 0, 0] + data1[i] = [0, offsets[i], 0, 1, offsets[i], 0] + dist = minimum_segment_distance(data0, data1) + assert dist.shape == (n, n) + # Diagonal: dist[i,i] should equal offsets[i] + np.testing.assert_allclose(np.diag(dist), offsets, atol=1e-10) + + def test_large_scale_no_crash(self): + """1000x1000 pairwise — verify no crash or memory issue.""" + rng = np.random.default_rng(999) + data0 = _random_3d_segments(1000, rng) + data1 = _random_3d_segments(1000, rng) + t0 = perf_counter() + dist = minimum_segment_distance(data0, data1) + elapsed = perf_counter() - t0 + assert dist.shape == (1000, 1000) + assert np.all(np.isfinite(dist)) + # Sanity: should complete in reasonable time for vectorized code + assert elapsed < 30.0, f"1000x1000 took {elapsed:.1f}s — too slow" + + +# =========================================================================== +# 4. CenterlineResult — backward compatibility stress tests +# =========================================================================== + +try: + from svv.tree.tree import CenterlineResult +except ImportError: + # Fallback: define CenterlineResult locally if tree.py import chain fails + # (e.g., meshio not installed). The class itself has no dependencies. + class CenterlineResult(tuple): + def __new__(cls, centerlines, polys, boundary_points=None): + instance = super().__new__(cls, (centerlines, polys)) + instance.boundary_points = boundary_points if boundary_points is not None else [] + return instance + + +class TestCenterlineResultCompat: + """Verify CenterlineResult backward-compatible tuple unpacking.""" + + def test_two_tuple_unpack(self): + """Legacy pattern: centerlines, polys = result.""" + r = CenterlineResult("cl", ["p1", "p2"], [{"type": "inlet"}]) + a, b = r + assert a == "cl" + assert b == ["p1", "p2"] + + def test_star_unpack(self): + """Pattern: centerlines, *rest = result.""" + r = CenterlineResult("cl", ["p1"], [{"type": "outlet"}]) + a, *rest = r + assert a == "cl" + assert rest == [["p1"]] + + def test_len_is_two(self): + """len() should be 2 for backward compatibility.""" + r = CenterlineResult("cl", ["p"], [{"type": "inlet"}]) + assert len(r) == 2 + + def test_boundary_points_attribute(self): + """New metadata accessible via attribute.""" + bp = [{"type": "inlet", "point": np.zeros(3), "radius": 0.1}] + r = CenterlineResult("cl", ["p"], bp) + assert r.boundary_points is bp + assert r.boundary_points[0]["type"] == "inlet" + + def test_boundary_points_default_empty(self): + """Default boundary_points is an empty list.""" + r = CenterlineResult("cl", ["p"]) + assert r.boundary_points == [] + + def test_indexing(self): + """Indexing r[0] and r[1] works.""" + r = CenterlineResult("cl", ["p"], []) + assert r[0] == "cl" + assert r[1] == ["p"] + with pytest.raises(IndexError): + _ = r[2] + + def test_iteration(self): + """Iterating yields exactly 2 elements.""" + r = CenterlineResult("cl", ["p"], [{"type": "inlet"}]) + items = list(r) + assert len(items) == 2 + + def test_getattr_fallback(self): + """getattr with default works for boundary_points.""" + r = CenterlineResult("cl", ["p"], [{"type": "inlet"}]) + assert getattr(r, "boundary_points", []) == [{"type": "inlet"}] + # For a plain tuple, getattr would return the default + plain = ("cl", ["p"]) + assert getattr(plain, "boundary_points", []) == [] + + def test_isinstance_tuple(self): + """CenterlineResult is a tuple.""" + r = CenterlineResult("cl", ["p"]) + assert isinstance(r, tuple) + + def test_with_real_numpy_data(self): + """Simulate realistic centerline data.""" + centerlines = np.random.rand(500, 3) + polys = [np.random.rand(100, 3) for _ in range(5)] + bp = [ + {"type": "inlet", "point": np.array([0.0, 0.0, 0.0]), "radius": 0.5}, + {"type": "outlet", "point": np.array([1.0, 1.0, 1.0]), "radius": 0.1}, + {"type": "outlet", "point": np.array([2.0, 0.0, 0.0]), "radius": 0.08}, + ] + r = CenterlineResult(centerlines, polys, bp) + cl, ps = r + assert cl.shape == (500, 3) + assert len(ps) == 5 + assert len(r.boundary_points) == 3 + assert r.boundary_points[0]["type"] == "inlet" + + +# =========================================================================== +# 5. Bezier + CatmullRom numerical stability edge cases +# =========================================================================== + +class TestCurveEdgeCases: + """Edge cases that stress numerical precision.""" + + def test_bezier_evaluate_at_exact_endpoints(self): + """t=0 and t=1 must return exact endpoint values.""" + rng = np.random.default_rng(1) + for _ in range(20): + n = rng.integers(2, 20) + ctrl = _random_control_points_3d(n, rng, spread=100) + curve = BezierCurve(ctrl) + pts = curve.evaluate(np.array([0.0, 1.0])) + np.testing.assert_allclose(pts[0], ctrl[0], atol=1e-10) + np.testing.assert_allclose(pts[1], ctrl[-1], atol=1e-10) + + def test_catmullrom_evaluate_near_segment_boundaries(self): + """Evaluation at segment boundaries should be stable (no NaN/inf).""" + rng = np.random.default_rng(2) + ctrl = _random_control_points_3d(10, rng) + spline = CatmullRomCurve(ctrl) + n_segs = 9 + # Evaluate at and near each boundary + t_vals = [] + for i in range(n_segs + 1): + t = i / n_segs + t_vals.extend([max(0, t - 1e-10), t, min(1.0 - 1e-14, t + 1e-10)]) + t_vals = np.clip(t_vals, 0, 1.0 - 1e-14) + pts = spline.evaluate(np.array(t_vals)) + assert np.all(np.isfinite(pts)) + + def test_bezier_collinear_points(self): + """Collinear control points — degenerate geometry shouldn't crash.""" + ctrl = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]], dtype=float) + curve = BezierCurve(ctrl) + t = np.linspace(0, 1, 1000) + pts = curve.evaluate(t) + # All points should lie on the line + direction = ctrl[-1] - ctrl[0] + direction /= np.linalg.norm(direction) + vecs = pts - ctrl[0] + projections = np.outer(vecs @ direction, direction) + residuals = vecs - projections + np.testing.assert_allclose(residuals, 0, atol=1e-10) + + def test_bezier_repeated_control_points(self): + """All control points at the same location.""" + pt = np.array([3.0, -1.0, 2.0]) + ctrl = np.tile(pt, (5, 1)) + curve = BezierCurve(ctrl) + pts = curve.evaluate(np.linspace(0, 1, 100)) + expected = np.tile(pt, (100, 1)) + np.testing.assert_allclose(pts, expected, atol=1e-12) + + def test_catmullrom_two_points(self): + """Minimum viable spline (2 points) should produce a line.""" + ctrl = np.array([[0, 0, 0], [10, 0, 0]], dtype=float) + spline = CatmullRomCurve(ctrl) + t = np.linspace(0, 1, 100) + pts = spline.evaluate(t) + # y and z should be 0 + np.testing.assert_allclose(pts[:, 1], 0, atol=1e-10) + np.testing.assert_allclose(pts[:, 2], 0, atol=1e-10) + # x should be monotonically increasing + assert np.all(np.diff(pts[:, 0]) >= -1e-10) + + def test_bezier_very_large_coordinates(self): + """Large coordinate values shouldn't cause overflow.""" + scale = 1e6 + ctrl = np.array([ + [0, 0, 0], [scale, scale, 0], [2*scale, 0, 0] + ], dtype=float) + curve = BezierCurve(ctrl) + pts = curve.evaluate(np.linspace(0, 1, 100)) + assert np.all(np.isfinite(pts)) + np.testing.assert_allclose(pts[-1], ctrl[-1], atol=1e-4) + + def test_bezier_very_small_coordinates(self): + """Tiny coordinate values shouldn't lose precision.""" + scale = 1e-6 + ctrl = np.array([ + [0, 0, 0], [scale, scale, 0], [2*scale, 0, 0] + ], dtype=float) + curve = BezierCurve(ctrl) + pts = curve.evaluate(np.array([0.0, 0.5, 1.0])) + np.testing.assert_allclose(pts[0], ctrl[0], atol=1e-18) + np.testing.assert_allclose(pts[-1], ctrl[-1], atol=1e-18) + + +# =========================================================================== +# 6. Segment distance — accuracy stress tests +# =========================================================================== + +class TestSegmentDistanceAccuracy: + """Verify vectorized distances match brute-force point sampling.""" + + def _brute_force_segment_distance(self, seg0, seg1, n_samples=2000): + """Brute-force minimum distance via dense point sampling.""" + A0, A1 = seg0[:3], seg0[3:] + B0, B1 = seg1[:3], seg1[3:] + t = np.linspace(0, 1, n_samples) + pts_a = A0 + np.outer(t, A1 - A0) + pts_b = B0 + np.outer(t, B1 - B0) + diff = pts_a[:, None, :] - pts_b[None, :, :] + dists = np.linalg.norm(diff, axis=2) + return dists.min() + + def test_accuracy_vs_bruteforce(self): + """Vectorized result should match or beat brute-force sampling. + + The cascading clamp re-projection gives the analytical minimum, + so vectorized distances should be <= brute-force (which is limited + by sampling resolution). We compare using a tolerance that accounts + for the brute-force sampling error. + """ + rng = np.random.default_rng(42) + n = 20 + data0 = _random_3d_segments(n, rng, spread=5.0) + data1 = _random_3d_segments(n, rng, spread=5.0) + vectorized = minimum_segment_distance(data0, data1) + overestimates = [] + for i in range(n): + for j in range(n): + bf = self._brute_force_segment_distance(data0[i], data1[j]) + assert vectorized[i, j] >= 0 + if bf > 0.1: + # Vectorized should be <= brute-force (analytical vs sampled) + # Allow small tolerance for brute-force sampling noise + overestimate = (vectorized[i, j] - bf) / bf + overestimates.append(overestimate) + overestimates = np.array(overestimates) + # Vectorized should never significantly overestimate the true distance + # (it computes the exact analytical minimum) + assert np.max(overestimates) < 0.02, ( + f"Vectorized overestimates brute-force by {np.max(overestimates):.4f}" + ) + # Median should be near zero or negative (vec <= bf) + assert np.median(overestimates) < 0.01, ( + f"Median overestimate {np.median(overestimates):.4f} too high" + ) + + def test_known_perpendicular_segments(self): + """Two perpendicular segments with known minimum distance.""" + # Segment A along x-axis, segment B along y-axis offset by 3 in z + for offset in [0.5, 1.0, 3.0, 10.0]: + data0 = np.array([[0, 0, 0, 2, 0, 0]], dtype=np.float64) + data1 = np.array([[1, 0, offset, 1, 2, offset]], dtype=np.float64) + dist = minimum_segment_distance(data0, data1) + # Closest approach is at (1,0,0) on A and (1,0,offset) on B + np.testing.assert_allclose(dist[0, 0], offset, atol=1e-10) + + +# =========================================================================== +# 7. Performance timing benchmarks (informational, not strict pass/fail) +# =========================================================================== + +class TestPerformanceBenchmarks: + """Timing benchmarks for optimized code paths. + + These tests always pass but print timing info. + Strict time limits only where vectorization should guarantee speed. + """ + + def test_bezier_10k_evaluation_speed(self): + """Bezier evaluate at 10k points should be fast.""" + ctrl = _random_control_points_3d(6, np.random.default_rng(1)) + curve = BezierCurve(ctrl) + t = np.linspace(0, 1, 10_000) + t0 = perf_counter() + for _ in range(10): + curve.evaluate(t) + elapsed = (perf_counter() - t0) / 10 + assert elapsed < 1.0, f"10k Bezier eval: {elapsed:.3f}s — too slow" + + def test_catmullrom_10k_evaluation_speed(self): + """CatmullRom evaluate at 10k points should be fast.""" + ctrl = _random_control_points_3d(20, np.random.default_rng(1)) + spline = CatmullRomCurve(ctrl) + t = np.linspace(0, 1, 10_000) + t0 = perf_counter() + for _ in range(10): + spline.evaluate(t) + elapsed = (perf_counter() - t0) / 10 + assert elapsed < 2.0, f"10k CatmullRom eval: {elapsed:.3f}s — too slow" + + def test_segment_distance_500x500_speed(self): + """500x500 segment distance should finish quickly.""" + rng = np.random.default_rng(42) + data0 = _random_3d_segments(500, rng) + data1 = _random_3d_segments(500, rng) + t0 = perf_counter() + minimum_segment_distance(data0, data1) + elapsed = perf_counter() - t0 + assert elapsed < 10.0, f"500x500 segment dist: {elapsed:.3f}s — too slow" + + def test_bezier_derivative_all_orders(self): + """Derivative computation at all orders up to degree.""" + ctrl = _random_control_points_3d(10, np.random.default_rng(5)) + curve = BezierCurve(ctrl) + t = np.linspace(0, 1, 1000) + t0 = perf_counter() + for order in range(1, 10): + curve.derivative(t, order=order) + elapsed = perf_counter() - t0 + assert elapsed < 5.0, f"All-order derivatives: {elapsed:.3f}s" + + +# =========================================================================== +# 8. Geodesic edge extraction vectorization +# =========================================================================== + +class TestGeodesicEdgeExtraction: + """Test the vectorized edge extraction from tetrahedra.""" + + def test_extract_edges_from_tetrahedra(self): + """Verify edge extraction produces correct edges from tet connectivity.""" + # 2 tetrahedra sharing a face + # Tet 0: nodes [0,1,2,3], Tet 1: nodes [1,2,3,4] + cells = np.array([ + [0, 1, 2, 3], + [1, 2, 3, 4], + ]) + # 6 edges per tet = 12 total, but some shared + tet_edges = np.array([[0,1],[1,2],[2,0],[0,3],[3,1],[2,3]]) + all_edges = [] + for tet in cells: + for e in tet_edges: + edge = sorted([tet[e[0]], tet[e[1]]]) + all_edges.append(tuple(edge)) + unique_edges = set(all_edges) + # 2 tets sharing face [1,2,3] should have 9 unique edges + # Tet0: (0,1)(1,2)(0,2)(0,3)(1,3)(2,3) = 6 + # Tet1: (1,2)(2,3)(1,3)(1,4)(3,4)(2,4) = 6 + # Shared: (1,2)(1,3)(2,3) = 3 + # Unique: 6 + 6 - 3 = 9 + assert len(unique_edges) == 9 + + def test_edge_extraction_vectorized_matches_loop(self): + """Vectorized edge extraction matches a simple loop.""" + rng = np.random.default_rng(42) + n_tets = 500 + # Random tet connectivity (not geometrically valid, just for testing extraction) + cells = rng.integers(0, 200, size=(n_tets, 4)) + + # Loop-based extraction + tet_edge_pairs = [(0,1),(1,2),(2,0),(0,3),(3,1),(2,3)] + loop_edges = set() + for tet in cells: + for i, j in tet_edge_pairs: + edge = (min(tet[i], tet[j]), max(tet[i], tet[j])) + loop_edges.add(edge) + + # Vectorized extraction (same algorithm as geodesic.py) + idx = np.array([[0,1],[1,2],[2,0],[0,3],[3,1],[2,3]]) + left = cells[:, idx[:, 0]] # (n_tets, 6) + right = cells[:, idx[:, 1]] # (n_tets, 6) + edges_raw = np.stack([ + np.minimum(left, right), + np.maximum(left, right), + ], axis=-1).reshape(-1, 2) # (n_tets*6, 2) + vectorized_edges = set(map(tuple, np.unique(edges_raw, axis=0))) + + assert loop_edges == vectorized_edges + + +# =========================================================================== +# 9. Constraint cache — unit test for cache sharing logic +# =========================================================================== + +class TestConstraintCache: + """Test the constraint cache sharing pattern from base_connection.py.""" + + def test_cache_hit_on_same_input(self): + """Same control points should reuse cached curve.""" + call_count = [0] + _cache = {'key': None, 'result': None} + + def cached_compute(data): + key = data.tobytes() + if _cache['key'] != key: + call_count[0] += 1 + _cache['key'] = key + _cache['result'] = data.sum() + return _cache['result'] + + data = np.array([1.0, 2.0, 3.0]) + r1 = cached_compute(data) + r2 = cached_compute(data) + r3 = cached_compute(data) + assert r1 == r2 == r3 == 6.0 + assert call_count[0] == 1 # computed only once + + def test_cache_miss_on_different_input(self): + """Different control points should recompute.""" + call_count = [0] + _cache = {'key': None, 'result': None} + + def cached_compute(data): + key = data.tobytes() + if _cache['key'] != key: + call_count[0] += 1 + _cache['key'] = key + _cache['result'] = data.sum() + return _cache['result'] + + cached_compute(np.array([1.0, 2.0])) + cached_compute(np.array([3.0, 4.0])) + cached_compute(np.array([5.0, 6.0])) + assert call_count[0] == 3 + + def test_cache_simulates_optimizer_iterations(self): + """Simulate SLSQP calling 4 constraints per iteration.""" + call_count = [0] + _cache = {'key': None, 'result': None} + + def build_curve(ctrlpts_flat): + key = ctrlpts_flat.tobytes() + if _cache['key'] != key: + call_count[0] += 1 + _cache['key'] = key + _cache['result'] = ctrlpts_flat.sum() + return _cache['result'] + + # Simulate 50 optimizer iterations, each calling 4 constraint functions + rng = np.random.default_rng(7) + for _ in range(50): + x = rng.random(12) # 4 control points * 3D + for _ in range(4): # 4 constraints + build_curve(x) + + # Should have computed only 50 times (once per iteration), not 200 + assert call_count[0] == 50 + + +# =========================================================================== +# 10. Large-scale combined stress test +# =========================================================================== + +class TestCombinedStress: + """Combined tests exercising multiple optimized paths together.""" + + def test_bezier_roc_at_scale(self): + """Compute ROC for many curves in sequence.""" + rng = np.random.default_rng(42) + t = np.linspace(0.01, 0.99, 200) + for _ in range(100): + n = rng.integers(3, 10) + ctrl = _random_control_points_3d(n, rng) + curve = BezierCurve(ctrl) + roc = curve.roc(t) + assert np.all(np.isfinite(roc)) + assert np.all(roc > 0) + + def test_segment_distance_with_collinear_and_random(self): + """Mix of collinear and random segments.""" + rng = np.random.default_rng(99) + n = 200 + data = _random_3d_segments(n, rng) + # Make first 20 collinear (parallel to x-axis) + for i in range(20): + y, z = rng.random(2) + data[i] = [0, y, z, 1, y, z] + dist = minimum_segment_distance(data, data) + assert dist.shape == (n, n) + assert np.all(np.isfinite(dist)) + assert np.all(dist >= 0) + # Diagonal should be 0 + np.testing.assert_allclose(np.diag(dist), 0, atol=1e-10) + + def test_catmullrom_many_curves_sequential(self): + """Create and evaluate many splines in sequence.""" + rng = np.random.default_rng(13) + t = np.linspace(0, 1, 300) + for _ in range(50): + n = rng.integers(3, 30) + ctrl = _random_control_points_3d(n, rng) + spline = CatmullRomCurve(ctrl) + pts = spline.evaluate(t) + assert pts.shape == (300, 3) + assert np.all(np.isfinite(pts)) diff --git a/test/test_voronoi_sampling.py b/test/test_voronoi_sampling.py index 388e302..a773c3d 100644 --- a/test/test_voronoi_sampling.py +++ b/test/test_voronoi_sampling.py @@ -14,8 +14,15 @@ def _make_tet_mesh(points: np.ndarray, tets: np.ndarray) -> pv.UnstructuredGrid: def _dummy_domain_with_mesh(mesh: pv.UnstructuredGrid) -> Domain: + from scipy.spatial import cKDTree dom = Domain(np.zeros((1, 3), dtype=float)) + mesh = mesh.compute_cell_sizes() + mesh.cell_data['Normalized_Volume'] = mesh.cell_data['Volume'] / sum(mesh.cell_data['Volume']) dom.mesh = mesh + dom.mesh_tree = cKDTree(mesh.cell_centers().points, leafsize=4) + dom.all_mesh_cells = np.arange(mesh.n_cells, dtype=np.int64) + dom.cumulative_probability = np.cumsum(mesh.cell_data['Normalized_Volume']) + dom.random_generator = np.random.default_rng(42) # Provide a cheap implicit evaluator so voronoi sampling can filter by implicit_range. dom.evaluate_fast = lambda pts, **_: -0.5 * np.ones((np.asarray(pts).shape[0], 1), dtype=float) return dom