Commit db331d8a99a1ff623a601d4ef5b40fb12019b890

Authored by David Mayerich
1 parent 8d96a467

added adjacency weight calculations and scaling functions to Network, updated sp…

…harmonics functions to take tuples
Showing 2 changed files with 153 additions and 65 deletions   Show diff stats
@@ -11,6 +11,8 @@ import scipy as sp @@ -11,6 +11,8 @@ import scipy as sp
11 import networkx as nx 11 import networkx as nx
12 import matplotlib.pyplot as plt 12 import matplotlib.pyplot as plt
13 import math 13 import math
  14 +import time
  15 +import spharmonics
14 16
15 ''' 17 '''
16 Definition of the Node class 18 Definition of the Node class
@@ -33,18 +35,12 @@ class Node: @@ -33,18 +35,12 @@ class Node:
33 ''' 35 '''
34 class Fiber: 36 class Fiber:
35 37
36 - def __init__ (self):  
37 - self.v0 = 0  
38 - self.v1 = 0  
39 - self.points = []  
40 - self.radii = []  
41 -  
42 - #NOTE: there is no function overloading in Python  
43 -# def __init__ (self, p1, p2, pois, rads):  
44 -# self.v0 = p1  
45 -# self.v1 = p2  
46 -# self.points = pois  
47 -# self.radii = rads 38 +
  39 + def __init__ (self, p1, p2, pois, rads):
  40 + self.v0 = p1
  41 + self.v1 = p2
  42 + self.points = pois
  43 + self.radii = rads
48 44
49 ''' 45 '''
50 return the length of the fiber. 46 return the length of the fiber.
@@ -83,7 +79,7 @@ class NWT: @@ -83,7 +79,7 @@ class NWT:
83 ''' 79 '''
84 Writes the header given and open file descripion, number of verticies and number of edges. 80 Writes the header given and open file descripion, number of verticies and number of edges.
85 ''' 81 '''
86 - def writeHeader(self, open_file, numVerts, numEdges): 82 + def writeHeader(open_file, numVerts, numEdges):
87 txt = "nwtFileFormat fileid(14B), desc(58B), #vertices(4B), #edges(4B): bindata" 83 txt = "nwtFileFormat fileid(14B), desc(58B), #vertices(4B), #edges(4B): bindata"
88 b = bytearray() 84 b = bytearray()
89 b.extend(txt.encode()) 85 b.extend(txt.encode())
@@ -95,7 +91,7 @@ class NWT: @@ -95,7 +91,7 @@ class NWT:
95 ''' 91 '''
96 Writes a single vertex to a file. 92 Writes a single vertex to a file.
97 ''' 93 '''
98 - def writeVertex(self, open_file, vertex): 94 + def writeVertex(open_file, vertex):
99 open_file.write(struct.pack('<f',vertex.p[0])) 95 open_file.write(struct.pack('<f',vertex.p[0]))
100 open_file.write(struct.pack('<f',vertex.p[1])) 96 open_file.write(struct.pack('<f',vertex.p[1]))
101 open_file.write(struct.pack('<f',vertex.p[2])) 97 open_file.write(struct.pack('<f',vertex.p[2]))
@@ -112,7 +108,7 @@ class NWT: @@ -112,7 +108,7 @@ class NWT:
112 ''' 108 '''
113 Writes a single fiber to a file. 109 Writes a single fiber to a file.
114 ''' 110 '''
115 - def writeFiber(self, open_file, edge): 111 + def writeFiber(open_file, edge):
116 open_file.write(struct.pack('i',edge.v0)) 112 open_file.write(struct.pack('i',edge.v0))
117 open_file.write(struct.pack('i',edge.v1)) 113 open_file.write(struct.pack('i',edge.v1))
118 open_file.write(struct.pack('i',len(edge.points))) 114 open_file.write(struct.pack('i',len(edge.points)))
@@ -127,14 +123,14 @@ class NWT: @@ -127,14 +123,14 @@ class NWT:
127 ''' 123 '''
128 Writes the entire network to a file in str given the vertices array and the edges array. 124 Writes the entire network to a file in str given the vertices array and the edges array.
129 ''' 125 '''
130 - def exportNWT(self, str, vertices, edges): 126 + def exportNWT(str, vertices, edges):
131 with open(str, "wb") as file: 127 with open(str, "wb") as file:
132 - self.writeHeader(file, len(vertices), len(edges)) 128 + NWT.writeHeader(file, len(vertices), len(edges))
133 for i in range(len(vertices)): 129 for i in range(len(vertices)):
134 - self.writeVertex(file, vertices[i]) 130 + NWT.writeVertex(file, vertices[i])
135 131
136 for i in range(len(edges)): 132 for i in range(len(edges)):
137 - self.writeFiber(file, edges[i]) 133 + NWT.writeFiber(file, edges[i])
138 134
139 return 135 return
140 136
@@ -142,7 +138,7 @@ class NWT: @@ -142,7 +138,7 @@ class NWT:
142 ''' 138 '''
143 Reads a single vertex from an open file and returns a node Object. 139 Reads a single vertex from an open file and returns a node Object.
144 ''' 140 '''
145 - def readVertex(self, open_file): 141 + def readVertex(open_file):
146 points = np.tile(0., 3) 142 points = np.tile(0., 3)
147 bytes = open_file.read(4) 143 bytes = open_file.read(4)
148 points[0] = struct.unpack('f', bytes)[0] 144 points[0] = struct.unpack('f', bytes)[0]
@@ -172,7 +168,7 @@ class NWT: @@ -172,7 +168,7 @@ class NWT:
172 ''' 168 '''
173 Reads a single fiber from an open file and returns a Fiber object . 169 Reads a single fiber from an open file and returns a Fiber object .
174 ''' 170 '''
175 - def readFiber(self, open_file): 171 + def readFiber(open_file):
176 bytes = open_file.read(4) 172 bytes = open_file.read(4)
177 vtx0 = int.from_bytes(bytes, byteorder = 'little') 173 vtx0 = int.from_bytes(bytes, byteorder = 'little')
178 bytes = open_file.read(4) 174 bytes = open_file.read(4)
@@ -206,8 +202,11 @@ class NWT: @@ -206,8 +202,11 @@ class NWT:
206 202
207 class Network: 203 class Network:
208 204
209 - def __init__(self, str):  
210 - with open(str, "rb") as file: 205 + def __init__(self, filename, clock=False):
  206 + if clock:
  207 + start_time = time.time()
  208 +
  209 + with open(filename, "rb") as file:
211 header = file.read(72) 210 header = file.read(72)
212 bytes = file.read(4) 211 bytes = file.read(4)
213 numVertex = int.from_bytes(bytes, byteorder='little') 212 numVertex = int.from_bytes(bytes, byteorder='little')
@@ -223,6 +222,8 @@ class Network: @@ -223,6 +222,8 @@ class Network:
223 for i in range(numEdges): 222 for i in range(numEdges):
224 edge = NWT.readFiber(file) 223 edge = NWT.readFiber(file)
225 self.F.append(edge) 224 self.F.append(edge)
  225 + if clock:
  226 + print("Network initialization: " + str(time.time() - start_time) + "s")
226 227
227 ''' 228 '''
228 Creates a graph from a list of nodes and a list of edges. 229 Creates a graph from a list of nodes and a list of edges.
@@ -287,7 +288,7 @@ class Network: @@ -287,7 +288,7 @@ class Network:
287 288
288 lower = self.N[0].p.copy() 289 lower = self.N[0].p.copy()
289 upper = lower.copy() 290 upper = lower.copy()
290 - for i in self.nodeList: 291 + for i in self.N:
291 for c in range(len(lower)): 292 for c in range(len(lower)):
292 if lower[c] > i.p[c]: 293 if lower[c] > i.p[c]:
293 lower[c] = i.p[c] 294 lower[c] = i.p[c]
@@ -317,7 +318,7 @@ class Network: @@ -317,7 +318,7 @@ class Network:
317 R = (200, 200, 200) 318 R = (200, 200, 200)
318 319
319 #generate a meshgrid of the appropriate size and resolution to surround the network 320 #generate a meshgrid of the appropriate size and resolution to surround the network
320 - lower, upper = self.aabb(self.N, self.F) #get the space occupied by the network 321 + lower, upper = self.aabb(self.N, self.F) #get the space occupied by the network
321 x = np.linspace(lower[0], upper[0], R[0]) #get the grid points for uniform sampling of this space 322 x = np.linspace(lower[0], upper[0], R[0]) #get the grid points for uniform sampling of this space
322 y = np.linspace(lower[1], upper[1], R[1]) 323 y = np.linspace(lower[1], upper[1], R[1])
323 z = np.linspace(lower[2], upper[2], R[2]) 324 z = np.linspace(lower[2], upper[2], R[2])
@@ -333,31 +334,106 @@ class Network: @@ -333,31 +334,106 @@ class Network:
333 return D 334 return D
334 335
335 #returns the number of points in the network 336 #returns the number of points in the network
336 - def npoints(self):  
337 - n = 0  
338 - for f in self.F:  
339 - n = n + len(f.points) - 2  
340 - n = n + len(self.N)  
341 - return n 337 + def npoints(self):
  338 + n = 0 #initialize the counter to zero
  339 + for f in self.F: #for each fiber
  340 + n = n + len(f.points) - 2 #count the number of points in the fiber - ignoring the end points
  341 + n = n + len(self.N) #add the number of nodes (shared points) to the node count
  342 + return n #return the number of nodes
  343 +
  344 + #returns all of the points in the network
  345 + def points(self):
  346 + k = self.npoints()
  347 + P = np.zeros((3, k)) #allocate space for the point list
  348 +
  349 + idx = 0
  350 + for f in self.F: #for each fiber in the network
  351 + for ip in range(1, len(f.points)-1): #for each point in the network
  352 + P[:, idx] = f.points[ip] #store the point in the raw point list
  353 + idx = idx + 1
  354 + return P #return the point array
342 355
343 #returns the number of linear segments in the network 356 #returns the number of linear segments in the network
344 def nsegments(self): 357 def nsegments(self):
345 - n = 0  
346 - for f in self.F:  
347 - n = n + len(f.points) - 1  
348 - return n 358 + n = 0 #initialize the segment counter to 0
  359 + for f in self.F: #for each fiber
  360 + n = n + len(f.points) - 1 #calculate the number of line segments in the fiber (points - 1)
  361 + return n #return the number of line segments
349 362
350 - def vectorize(self):  
351 - #initialize three coordinate vectors ()  
352 - n = self.nsegments(self.N, self.F)  
353 - X = np.zeros((n))  
354 - Y = np.zeros((n))  
355 - Z = np.zeros((n)) 363 + #return a list of line segments representing the network
  364 + def segments(self, dtype=np.float32):
  365 + k = self.nsegments() #get the number of line segments
  366 + start = np.zeros((k, 3),dtype=dtype) #start points for the line segments
  367 + end = np.zeros((k, 3), dtype=dtype) #end points for the line segments
  368 +
  369 + idx = 0 #initialize the index counter to zero
  370 + for f in self.F: #for each fiber in the network
  371 + for ip in range(0, len(f.points)-1): #for each point in the network
  372 + start[idx, :] = f.points[ip] #store the point in the raw point list
  373 + idx = idx + 1
  374 +
356 idx = 0 375 idx = 0
357 - for i in range(0, len(self.F)):  
358 - for j in range(0, len(self.F[i].points)-1):  
359 - X[idx] = self.F[i].points[j][0]-self.F[i].points[j+1][0]  
360 - Y[idx] = self.F[i].points[j][1]-self.F[i].points[j+1][1]  
361 - Z[idx] = self.F[i].points[j][2]-self.F[i].points[j+1][2] 376 + for f in self.F: #for each fiber in the network
  377 + for ip in range(1, len(f.points)): #for each point in the network
  378 + end[idx, :] = f.points[ip] #store the point in the raw point list
362 idx = idx + 1 379 idx = idx + 1
363 - return X, Y, Z 380 +
  381 + return start, end
  382 +
  383 + #function returns the fiber associated with a given 1D line segment index
  384 + def segment2fiber(self, idx):
  385 + i = 0
  386 + for f in range(len(self.F)): #for each fiber in the network
  387 + i = i + len(self.F[f].points)-1 #add the number of points in the fiber to i
  388 + if i > idx: #if we encounter idx in this fiber
  389 + return self.F[f].points, f #return the fiber associated with idx and the index into the fiber array
  390 +
  391 + def vectors(self, clock=False, dtype=np.float32):
  392 + if clock:
  393 + start_time = time.time()
  394 + start, end = self.segments(dtype) #retrieve all of the line segments
  395 + v = end - start #calculate the resulting vectors
  396 + l = np.sqrt(v[:, 0]**2 + v[:,1]**2 + v[:,2]**2) #calculate the fiber lengths
  397 + z = l==0 #look for any zero values
  398 + nz = z.sum()
  399 + if nz > 0:
  400 + print("WARNING: " + str(nz) + " line segment(s) of length zero were found in the network and will be removed" )
  401 +
  402 + if clock:
  403 + print("Network::vectors: " + str(time.time() - start_time) + "s")
  404 +
  405 + return np.delete(v, np.where(z), 0)
  406 +
  407 + #scale all values in the network by tuple S = (sx, sy, sz)
  408 + def scale(self, S):
  409 + for f in self.F:
  410 + for p in f.points:
  411 + p[0] = p[0] * S[0]
  412 + p[1] = p[1] * S[1]
  413 + p[2] = p[2] * S[2]
  414 +
  415 + for n in self.N:
  416 + n.p[0] = n.p[0] * S[0]
  417 + n.p[1] = n.p[1] * S[1]
  418 + n.p[2] = n.p[2] * S[2]
  419 +
  420 +
  421 + #calculate the adjacency weighting function for the network given a set of vectors X = (x, y, z) and weight exponent k
  422 + def adjacencyweight(self, P, k=200, length_threshold = 25, dtype=np.float32):
  423 + V = self.vectors(dtype) #get the vectors representing each segment
  424 + #V = V[0:n_vectors, :]
  425 + L = np.expand_dims(np.sqrt((V**2).sum(1)), 1) #calculate the length of each vector
  426 +
  427 + outliers = L > length_threshold #remove outliers based on the length_threshold
  428 + V = np.delete(V, np.where(outliers), 0)
  429 + L = np.delete(L, np.where(outliers))
  430 + V = V/L[:,None] #normalize the vectors
  431 +
  432 + P = np.stack(spharmonics.sph2cart(1, P[0], P[1]), P[0].ndim)
  433 + PV = P[...,None,:] * V
  434 + cos_alpha = PV.sum(PV.ndim-1)
  435 + W = np.abs(cos_alpha) ** k
  436 +
  437 + return W, L
  438 +
  439 +
python/spharmonics.py
@@ -11,16 +11,17 @@ import matplotlib.pyplot as plt @@ -11,16 +11,17 @@ import matplotlib.pyplot as plt
11 from matplotlib import cm, colors 11 from matplotlib import cm, colors
12 from mpl_toolkits.mplot3d import Axes3D 12 from mpl_toolkits.mplot3d import Axes3D
13 import math 13 import math
  14 +import time
14 15
15 16
16 -def sph2cart(theta, phi, r): 17 +def sph2cart(r, theta, phi):
17 x = r * numpy.cos(theta) * numpy.sin(phi) 18 x = r * numpy.cos(theta) * numpy.sin(phi)
18 y = r * numpy.sin(theta) * numpy.sin(phi) 19 y = r * numpy.sin(theta) * numpy.sin(phi)
19 z = r * numpy.cos(phi) 20 z = r * numpy.cos(phi)
20 21
21 return x, y, z 22 return x, y, z
22 23
23 -def cart2sph(x,y,z): 24 +def cart2sph(x, y, z):
24 r = numpy.sqrt(x**2+y**2+z**2) 25 r = numpy.sqrt(x**2+y**2+z**2)
25 theta = numpy.arctan2(y,x) 26 theta = numpy.arctan2(y,x)
26 phi = numpy.arccos(z/r) 27 phi = numpy.arccos(z/r)
@@ -41,15 +42,15 @@ def sh(theta, phi, l, m): @@ -41,15 +42,15 @@ def sh(theta, phi, l, m):
41 else: 42 else:
42 return scipy.special.sph_harm(0, l, theta, phi).real 43 return scipy.special.sph_harm(0, l, theta, phi).real
43 44
44 -#calculate a spherical harmonic value from a set of coefficients and a spherical coordinate  
45 -def sh_coeff(theta, phi, C): 45 +#calculate a spherical harmonic value from a set of coefficients and coordinates P = (theta, phi)
  46 +def sh_coeff(P, C):
46 47
47 - s = numpy.zeros(theta.shape) 48 + s = numpy.zeros(P[0].shape)
48 c = range(len(C)) 49 c = range(len(C))
49 50
50 for c in range(len(C)): 51 for c in range(len(C)):
51 l, m = i2lm(c) #get the 2D indices 52 l, m = i2lm(c) #get the 2D indices
52 - s = s + C[c] * sh(theta, phi, l, m) 53 + s = s + C[c] * sh(P[0], P[1], l, m)
53 54
54 return s 55 return s
55 56
@@ -86,34 +87,45 @@ def lm2i(l, m): @@ -86,34 +87,45 @@ def lm2i(l, m):
86 return l * (l+1) + m 87 return l * (l+1) + m
87 88
88 #generates a set of spherical harmonic coefficients from samples using linear least squares 89 #generates a set of spherical harmonic coefficients from samples using linear least squares
89 -def linfit(theta, phi, s, nc):  
90 - #allocate space for the matrix 90 +def linfit(P, s, nc, clock=False):
  91 + if clock:
  92 + start_time = time.time()
  93 +
  94 + #allocate space for the matrix and RHS values
91 A = numpy.zeros((nc, nc)) 95 A = numpy.zeros((nc, nc))
  96 + b = numpy.zeros(nc)
92 97
93 #calculate each of the matrix coefficients 98 #calculate each of the matrix coefficients
94 #(see SH technical report in the vascular_viz repository) 99 #(see SH technical report in the vascular_viz repository)
95 for i in range(nc): 100 for i in range(nc):
96 li, mi = i2lm(i) 101 li, mi = i2lm(i)
97 - yi = sh(theta, phi, li, mi) 102 + yi = sh(P[0], P[1], li, mi)
98 for j in range(nc): 103 for j in range(nc):
99 lj, mj = i2lm(j) 104 lj, mj = i2lm(j)
100 - yj = sh(theta, phi, lj, mj) 105 + yj = sh(P[0], P[1], lj, mj)
101 A[i, j] = numpy.sum(yi * yj) 106 A[i, j] = numpy.sum(yi * yj)
  107 + b[i] = numpy.sum(yi * s) #calculate the RHS value
102 108
103 #calculate the RHS values 109 #calculate the RHS values
104 - b = numpy.zeros(nc)  
105 - for j in range(nc):  
106 - lj, mj = i2lm(j)  
107 - yj = sh(theta, phi, lj, mj)  
108 - b[j] = numpy.sum(yj * s) 110 + #for j in range(nc):
  111 + # lj, mj = i2lm(j)
  112 + # yj = sh(theta, phi, lj, mj)
  113 + # b[j] = numpy.sum(yj * s)
109 114
  115 + if clock:
  116 + print("SH::linfit:matrix "+str(time.time() - start_time)+"s")
110 #solve the system of linear equations 117 #solve the system of linear equations
111 - return numpy.linalg.solve(A, b) 118 + R = numpy.linalg.solve(A, b)
  119 +
  120 + if clock:
  121 + print("SH::linfit:solution "+str(time.time() - start_time)+"s")
  122 + return R
112 123
113 #generate a scatter plot in 3D using spherical coordinates 124 #generate a scatter plot in 3D using spherical coordinates
114 -def scatterplot3d(theta, phi, r): 125 +def scatterplot3d(P):
  126 + r, theta, phi = P
115 #convert all of the samples to cartesian coordinates 127 #convert all of the samples to cartesian coordinates
116 - X, Y, Z = sph2cart(theta, phi, r) 128 + X, Y, Z = sph2cart(r, theta, phi)
117 129
118 fig = plt.figure() 130 fig = plt.figure()
119 ax = fig.add_subplot(111, projection='3d') 131 ax = fig.add_subplot(111, projection='3d')