Commit db331d8a99a1ff623a601d4ef5b40fb12019b890

Authored by David Mayerich
1 parent 8d96a467

added adjacency weight calculations and scaling functions to Network, updated sp…

…harmonics functions to take tuples
Showing 2 changed files with 153 additions and 65 deletions   Show diff stats
python/network.py
... ... @@ -11,6 +11,8 @@ import scipy as sp
11 11 import networkx as nx
12 12 import matplotlib.pyplot as plt
13 13 import math
  14 +import time
  15 +import spharmonics
14 16  
15 17 '''
16 18 Definition of the Node class
... ... @@ -33,18 +35,12 @@ class Node:
33 35 '''
34 36 class Fiber:
35 37  
36   - def __init__ (self):
37   - self.v0 = 0
38   - self.v1 = 0
39   - self.points = []
40   - self.radii = []
41   -
42   - #NOTE: there is no function overloading in Python
43   -# def __init__ (self, p1, p2, pois, rads):
44   -# self.v0 = p1
45   -# self.v1 = p2
46   -# self.points = pois
47   -# self.radii = rads
  38 +
  39 + def __init__ (self, p1, p2, pois, rads):
  40 + self.v0 = p1
  41 + self.v1 = p2
  42 + self.points = pois
  43 + self.radii = rads
48 44  
49 45 '''
50 46 return the length of the fiber.
... ... @@ -83,7 +79,7 @@ class NWT:
83 79 '''
84 80 Writes the header given and open file descripion, number of verticies and number of edges.
85 81 '''
86   - def writeHeader(self, open_file, numVerts, numEdges):
  82 + def writeHeader(open_file, numVerts, numEdges):
87 83 txt = "nwtFileFormat fileid(14B), desc(58B), #vertices(4B), #edges(4B): bindata"
88 84 b = bytearray()
89 85 b.extend(txt.encode())
... ... @@ -95,7 +91,7 @@ class NWT:
95 91 '''
96 92 Writes a single vertex to a file.
97 93 '''
98   - def writeVertex(self, open_file, vertex):
  94 + def writeVertex(open_file, vertex):
99 95 open_file.write(struct.pack('<f',vertex.p[0]))
100 96 open_file.write(struct.pack('<f',vertex.p[1]))
101 97 open_file.write(struct.pack('<f',vertex.p[2]))
... ... @@ -112,7 +108,7 @@ class NWT:
112 108 '''
113 109 Writes a single fiber to a file.
114 110 '''
115   - def writeFiber(self, open_file, edge):
  111 + def writeFiber(open_file, edge):
116 112 open_file.write(struct.pack('i',edge.v0))
117 113 open_file.write(struct.pack('i',edge.v1))
118 114 open_file.write(struct.pack('i',len(edge.points)))
... ... @@ -127,14 +123,14 @@ class NWT:
127 123 '''
128 124 Writes the entire network to a file in str given the vertices array and the edges array.
129 125 '''
130   - def exportNWT(self, str, vertices, edges):
  126 + def exportNWT(str, vertices, edges):
131 127 with open(str, "wb") as file:
132   - self.writeHeader(file, len(vertices), len(edges))
  128 + NWT.writeHeader(file, len(vertices), len(edges))
133 129 for i in range(len(vertices)):
134   - self.writeVertex(file, vertices[i])
  130 + NWT.writeVertex(file, vertices[i])
135 131  
136 132 for i in range(len(edges)):
137   - self.writeFiber(file, edges[i])
  133 + NWT.writeFiber(file, edges[i])
138 134  
139 135 return
140 136  
... ... @@ -142,7 +138,7 @@ class NWT:
142 138 '''
143 139 Reads a single vertex from an open file and returns a node Object.
144 140 '''
145   - def readVertex(self, open_file):
  141 + def readVertex(open_file):
146 142 points = np.tile(0., 3)
147 143 bytes = open_file.read(4)
148 144 points[0] = struct.unpack('f', bytes)[0]
... ... @@ -172,7 +168,7 @@ class NWT:
172 168 '''
173 169 Reads a single fiber from an open file and returns a Fiber object .
174 170 '''
175   - def readFiber(self, open_file):
  171 + def readFiber(open_file):
176 172 bytes = open_file.read(4)
177 173 vtx0 = int.from_bytes(bytes, byteorder = 'little')
178 174 bytes = open_file.read(4)
... ... @@ -206,8 +202,11 @@ class NWT:
206 202  
207 203 class Network:
208 204  
209   - def __init__(self, str):
210   - with open(str, "rb") as file:
  205 + def __init__(self, filename, clock=False):
  206 + if clock:
  207 + start_time = time.time()
  208 +
  209 + with open(filename, "rb") as file:
211 210 header = file.read(72)
212 211 bytes = file.read(4)
213 212 numVertex = int.from_bytes(bytes, byteorder='little')
... ... @@ -223,6 +222,8 @@ class Network:
223 222 for i in range(numEdges):
224 223 edge = NWT.readFiber(file)
225 224 self.F.append(edge)
  225 + if clock:
  226 + print("Network initialization: " + str(time.time() - start_time) + "s")
226 227  
227 228 '''
228 229 Creates a graph from a list of nodes and a list of edges.
... ... @@ -287,7 +288,7 @@ class Network:
287 288  
288 289 lower = self.N[0].p.copy()
289 290 upper = lower.copy()
290   - for i in self.nodeList:
  291 + for i in self.N:
291 292 for c in range(len(lower)):
292 293 if lower[c] > i.p[c]:
293 294 lower[c] = i.p[c]
... ... @@ -317,7 +318,7 @@ class Network:
317 318 R = (200, 200, 200)
318 319  
319 320 #generate a meshgrid of the appropriate size and resolution to surround the network
320   - lower, upper = self.aabb(self.N, self.F) #get the space occupied by the network
  321 + lower, upper = self.aabb(self.N, self.F) #get the space occupied by the network
321 322 x = np.linspace(lower[0], upper[0], R[0]) #get the grid points for uniform sampling of this space
322 323 y = np.linspace(lower[1], upper[1], R[1])
323 324 z = np.linspace(lower[2], upper[2], R[2])
... ... @@ -333,31 +334,106 @@ class Network:
333 334 return D
334 335  
335 336 #returns the number of points in the network
336   - def npoints(self):
337   - n = 0
338   - for f in self.F:
339   - n = n + len(f.points) - 2
340   - n = n + len(self.N)
341   - return n
  337 + def npoints(self):
  338 + n = 0 #initialize the counter to zero
  339 + for f in self.F: #for each fiber
  340 + n = n + len(f.points) - 2 #count the number of points in the fiber - ignoring the end points
  341 + n = n + len(self.N) #add the number of nodes (shared points) to the node count
  342 + return n #return the number of nodes
  343 +
  344 + #returns all of the points in the network
  345 + def points(self):
  346 + k = self.npoints()
  347 + P = np.zeros((3, k)) #allocate space for the point list
  348 +
  349 + idx = 0
  350 + for f in self.F: #for each fiber in the network
  351 + for ip in range(1, len(f.points)-1): #for each point in the network
  352 + P[:, idx] = f.points[ip] #store the point in the raw point list
  353 + idx = idx + 1
  354 + return P #return the point array
342 355  
343 356 #returns the number of linear segments in the network
344 357 def nsegments(self):
345   - n = 0
346   - for f in self.F:
347   - n = n + len(f.points) - 1
348   - return n
  358 + n = 0 #initialize the segment counter to 0
  359 + for f in self.F: #for each fiber
  360 + n = n + len(f.points) - 1 #calculate the number of line segments in the fiber (points - 1)
  361 + return n #return the number of line segments
349 362  
350   - def vectorize(self):
351   - #initialize three coordinate vectors ()
352   - n = self.nsegments(self.N, self.F)
353   - X = np.zeros((n))
354   - Y = np.zeros((n))
355   - Z = np.zeros((n))
  363 + #return a list of line segments representing the network
  364 + def segments(self, dtype=np.float32):
  365 + k = self.nsegments() #get the number of line segments
  366 + start = np.zeros((k, 3),dtype=dtype) #start points for the line segments
  367 + end = np.zeros((k, 3), dtype=dtype) #end points for the line segments
  368 +
  369 + idx = 0 #initialize the index counter to zero
  370 + for f in self.F: #for each fiber in the network
  371 + for ip in range(0, len(f.points)-1): #for each point in the network
  372 + start[idx, :] = f.points[ip] #store the point in the raw point list
  373 + idx = idx + 1
  374 +
356 375 idx = 0
357   - for i in range(0, len(self.F)):
358   - for j in range(0, len(self.F[i].points)-1):
359   - X[idx] = self.F[i].points[j][0]-self.F[i].points[j+1][0]
360   - Y[idx] = self.F[i].points[j][1]-self.F[i].points[j+1][1]
361   - Z[idx] = self.F[i].points[j][2]-self.F[i].points[j+1][2]
  376 + for f in self.F: #for each fiber in the network
  377 + for ip in range(1, len(f.points)): #for each point in the network
  378 + end[idx, :] = f.points[ip] #store the point in the raw point list
362 379 idx = idx + 1
363   - return X, Y, Z
  380 +
  381 + return start, end
  382 +
  383 + #function returns the fiber associated with a given 1D line segment index
  384 + def segment2fiber(self, idx):
  385 + i = 0
  386 + for f in range(len(self.F)): #for each fiber in the network
  387 + i = i + len(self.F[f].points)-1 #add the number of points in the fiber to i
  388 + if i > idx: #if we encounter idx in this fiber
  389 + return self.F[f].points, f #return the fiber associated with idx and the index into the fiber array
  390 +
  391 + def vectors(self, clock=False, dtype=np.float32):
  392 + if clock:
  393 + start_time = time.time()
  394 + start, end = self.segments(dtype) #retrieve all of the line segments
  395 + v = end - start #calculate the resulting vectors
  396 + l = np.sqrt(v[:, 0]**2 + v[:,1]**2 + v[:,2]**2) #calculate the fiber lengths
  397 + z = l==0 #look for any zero values
  398 + nz = z.sum()
  399 + if nz > 0:
  400 + print("WARNING: " + str(nz) + " line segment(s) of length zero were found in the network and will be removed" )
  401 +
  402 + if clock:
  403 + print("Network::vectors: " + str(time.time() - start_time) + "s")
  404 +
  405 + return np.delete(v, np.where(z), 0)
  406 +
  407 + #scale all values in the network by tuple S = (sx, sy, sz)
  408 + def scale(self, S):
  409 + for f in self.F:
  410 + for p in f.points:
  411 + p[0] = p[0] * S[0]
  412 + p[1] = p[1] * S[1]
  413 + p[2] = p[2] * S[2]
  414 +
  415 + for n in self.N:
  416 + n.p[0] = n.p[0] * S[0]
  417 + n.p[1] = n.p[1] * S[1]
  418 + n.p[2] = n.p[2] * S[2]
  419 +
  420 +
  421 + #calculate the adjacency weighting function for the network given a set of vectors X = (x, y, z) and weight exponent k
  422 + def adjacencyweight(self, P, k=200, length_threshold = 25, dtype=np.float32):
  423 + V = self.vectors(dtype) #get the vectors representing each segment
  424 + #V = V[0:n_vectors, :]
  425 + L = np.expand_dims(np.sqrt((V**2).sum(1)), 1) #calculate the length of each vector
  426 +
  427 + outliers = L > length_threshold #remove outliers based on the length_threshold
  428 + V = np.delete(V, np.where(outliers), 0)
  429 + L = np.delete(L, np.where(outliers))
  430 + V = V/L[:,None] #normalize the vectors
  431 +
  432 + P = np.stack(spharmonics.sph2cart(1, P[0], P[1]), P[0].ndim)
  433 + PV = P[...,None,:] * V
  434 + cos_alpha = PV.sum(PV.ndim-1)
  435 + W = np.abs(cos_alpha) ** k
  436 +
  437 + return W, L
  438 +
  439 +
... ...
python/spharmonics.py
... ... @@ -11,16 +11,17 @@ import matplotlib.pyplot as plt
11 11 from matplotlib import cm, colors
12 12 from mpl_toolkits.mplot3d import Axes3D
13 13 import math
  14 +import time
14 15  
15 16  
16   -def sph2cart(theta, phi, r):
  17 +def sph2cart(r, theta, phi):
17 18 x = r * numpy.cos(theta) * numpy.sin(phi)
18 19 y = r * numpy.sin(theta) * numpy.sin(phi)
19 20 z = r * numpy.cos(phi)
20 21  
21 22 return x, y, z
22 23  
23   -def cart2sph(x,y,z):
  24 +def cart2sph(x, y, z):
24 25 r = numpy.sqrt(x**2+y**2+z**2)
25 26 theta = numpy.arctan2(y,x)
26 27 phi = numpy.arccos(z/r)
... ... @@ -41,15 +42,15 @@ def sh(theta, phi, l, m):
41 42 else:
42 43 return scipy.special.sph_harm(0, l, theta, phi).real
43 44  
44   -#calculate a spherical harmonic value from a set of coefficients and a spherical coordinate
45   -def sh_coeff(theta, phi, C):
  45 +#calculate a spherical harmonic value from a set of coefficients and coordinates P = (theta, phi)
  46 +def sh_coeff(P, C):
46 47  
47   - s = numpy.zeros(theta.shape)
  48 + s = numpy.zeros(P[0].shape)
48 49 c = range(len(C))
49 50  
50 51 for c in range(len(C)):
51 52 l, m = i2lm(c) #get the 2D indices
52   - s = s + C[c] * sh(theta, phi, l, m)
  53 + s = s + C[c] * sh(P[0], P[1], l, m)
53 54  
54 55 return s
55 56  
... ... @@ -86,34 +87,45 @@ def lm2i(l, m):
86 87 return l * (l+1) + m
87 88  
88 89 #generates a set of spherical harmonic coefficients from samples using linear least squares
89   -def linfit(theta, phi, s, nc):
90   - #allocate space for the matrix
  90 +def linfit(P, s, nc, clock=False):
  91 + if clock:
  92 + start_time = time.time()
  93 +
  94 + #allocate space for the matrix and RHS values
91 95 A = numpy.zeros((nc, nc))
  96 + b = numpy.zeros(nc)
92 97  
93 98 #calculate each of the matrix coefficients
94 99 #(see SH technical report in the vascular_viz repository)
95 100 for i in range(nc):
96 101 li, mi = i2lm(i)
97   - yi = sh(theta, phi, li, mi)
  102 + yi = sh(P[0], P[1], li, mi)
98 103 for j in range(nc):
99 104 lj, mj = i2lm(j)
100   - yj = sh(theta, phi, lj, mj)
  105 + yj = sh(P[0], P[1], lj, mj)
101 106 A[i, j] = numpy.sum(yi * yj)
  107 + b[i] = numpy.sum(yi * s) #calculate the RHS value
102 108  
103 109 #calculate the RHS values
104   - b = numpy.zeros(nc)
105   - for j in range(nc):
106   - lj, mj = i2lm(j)
107   - yj = sh(theta, phi, lj, mj)
108   - b[j] = numpy.sum(yj * s)
  110 + #for j in range(nc):
  111 + # lj, mj = i2lm(j)
  112 + # yj = sh(theta, phi, lj, mj)
  113 + # b[j] = numpy.sum(yj * s)
109 114  
  115 + if clock:
  116 + print("SH::linfit:matrix "+str(time.time() - start_time)+"s")
110 117 #solve the system of linear equations
111   - return numpy.linalg.solve(A, b)
  118 + R = numpy.linalg.solve(A, b)
  119 +
  120 + if clock:
  121 + print("SH::linfit:solution "+str(time.time() - start_time)+"s")
  122 + return R
112 123  
113 124 #generate a scatter plot in 3D using spherical coordinates
114   -def scatterplot3d(theta, phi, r):
  125 +def scatterplot3d(P):
  126 + r, theta, phi = P
115 127 #convert all of the samples to cartesian coordinates
116   - X, Y, Z = sph2cart(theta, phi, r)
  128 + X, Y, Z = sph2cart(r, theta, phi)
117 129  
118 130 fig = plt.figure()
119 131 ax = fig.add_subplot(111, projection='3d')
... ...