diff --git a/fibernet.py b/fibernet.py new file mode 100644 index 0000000..f37377e --- /dev/null +++ b/fibernet.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Jan 19 2018 + +@author: Jiabing +""" + +import struct +import numpy as np +import scipy as sp +import networkx as nx +import matplotlib.pyplot as plt +import math +from mpl_toolkits.mplot3d import Axes3D +from itertools import chain + +class Node: + def __init__(self, point, outgoing, incoming): + self.p = point + self.o = outgoing + self.i = incoming + + +class Fiber: + def __init__ (self, indices): + self.indices = indices + +class Point: + def __init__(self, x, y, z, r): + self.x = x + self.y = y + self.z = z + self.r = r + +class Edge: + def __init__(self, indices_num, pois, radius): + self.indices_num = indices_num + self.points = pois + self.radius = radius + + +class NWT: + + def readVertex(open_file): + points = np.tile(0., 3) + bytes = open_file.read(4) + points[0] = struct.unpack('f', bytes)[0] + bytes = open_file.read(4) + points[1] = struct.unpack('f', bytes)[0] + bytes = open_file.read(4) + points[2] = struct.unpack('f', bytes)[0] + bytes = open_file.read(4) + + numO = int.from_bytes(bytes, byteorder='little') + outgoing = np.tile(0, numO) + bts = open_file.read(4) + numI = int.from_bytes(bts, byteorder='little') + incoming = np.tile(0, numI) + for j in range(numO): + bytes = open_file.read(4) + outgoing[j] = int.from_bytes(bytes, byteorder='little') + + for j in range(numI): + bytes = open_file.read(4) + incoming[j] = int.from_bytes(bytes, byteorder='little') + + node = Node(points, outgoing, incoming) + return node + + + ''' + Reads a single fiber from an open file and returns a Fiber object . + ''' + def readFiber(open_file): + bytes = open_file.read(4) + vtx0 = int.from_bytes(bytes, byteorder = 'little') + bytes = open_file.read(4) + vtx1 = int.from_bytes(bytes, byteorder = 'little') + bytes = open_file.read(4) + numVerts = int.from_bytes(bytes, byteorder = 'little') + pts = [] + rads = [] + + for j in range(numVerts): + point = np.tile(0., 3) + bytes = open_file.read(4) + point[0] = struct.unpack('f', bytes)[0] + bytes = open_file.read(4) + point[1] = struct.unpack('f', bytes)[0] + bytes = open_file.read(4) + point[2] = struct.unpack('f', bytes)[0] + bytes = open_file.read(4) + radius = struct.unpack('f', bytes)[0] + pts.append(point) + rads.append(radius) + + # F = Fiber(pts) + E = Edge(numVerts, pts, radius) + #E = Edge(pts) + return E + + +class fibernet: + def __init__(self, filename): + + with open(filename, "rb") as file: + header = file.read(72) + bytes = file.read(4) + numVertex = int.from_bytes(bytes, byteorder='little') + bytes = file.read(4) + numEdges = int.from_bytes(bytes, byteorder='little') + + #self.P = [] + self.F = [] + self.N = [] + + for i in range(numVertex): + node = NWT.readVertex(file) + self.N.append(node) + + for i in range( numEdges): + edge = NWT.readFiber(file) + #self.F.append(np.arange(num,num+edge.indices_num,1)) + #self.P= chain(self.P, edge.points) + self.F.append(edge.points) + #num += edge.indices_num + + def aabb(self): + + lower = self.N[0].p.copy() + upper = lower.copy() + for i in self.N: + for c in range(len(lower)): + if lower[c] > i.p[c]: + lower[c] = i.p[c] + if upper[c] < i.p[c]: + upper[c] = i.p[c] + return lower, upper + + + + def distancefield(self, R=(100, 100, 100)): + + #generate a meshgrid of the appropriate size and resolution to surround the network + lower, upper = self.aabb() #get the space occupied by the network + x = np.linspace(lower[0], upper[0], R[0]) #get the grid points for uniform sampling of this space + y = np.linspace(lower[1], upper[1], R[1]) + z = np.linspace(lower[2], upper[2], R[2]) + X, Y, Z = np.meshgrid(x, y, z) + #Z = 150 * numpy.ones(X.shape) + + Q = np.stack((X, Y, Z), 3) + d_x = abs(x[1]-x[0]); + d_y = abs(y[1]-y[0]); + d_z = abs(z[1]-z[0]); + dis1 = math.sqrt(pow(d_x,2)+pow(d_y,2)+pow(d_z ,2)) + #dx = abs(x[1]-x[0]) + + #dy = abs(y[1]-y[0]) + #dz = abs(z[1]-z[0]) + #get a list of all node positions in the network + P = [] + + for e in self.F[12:13]: #12-17 + for p in e: + P.append(p) + + for j in range(len(e)-1): + d_t = e[j+1]-e[j] + dis2 = math.sqrt(pow(d_t[0],2)+pow(d_t[1],2)+pow(d_t[2],2)) + ins = max(int(d_t[0]/d_x), int(d_t[1]/d_y), int(d_t[2]/d_z)) + if( ins>0 ): + ins = ins+1; + for k in range(ins): + p_ins =e[j]+(k+1)*(e[j+1]-e[j])/ins; + P.append(p_ins); + #turn that list into a Numpy array so that we can create a KD tree + P = np.array(P) + + #generate a KD-Tree out of the network point array + tree = sp.spatial.cKDTree(P) + + #specify the resolution of the ouput grid + # R = (200, 200, 200) + + D, I = tree.query(Q) + + + + return D, Q, dis1 + + +''' +##read NWT file +f= fibernet("full_seg.nwt") +#P = tuple(f.P) +plist = f.F + +fig = plt.figure() +ax = fig.add_subplot(111, projection='3d') + +xs = [] +ys = [] +zs = [] +for i in range(10): + for j in range(len(F[i])): + xs.append(P[F[i][j]][0]) + ys.append(P[F[i][j]][1]) + zs.append(P[F[i][j]][2]) + + +#ax.scatter(xs, ys, zs) +#ax = fig.gca(projection='3d') +ax.set_xlabel('X Label') +ax.set_ylabel('Y Label') +ax.set_zlabel('Z Label') + +ax.plot(xs, ys, zs, label='center lines') +ax.legend() +plt.savefig('p.png', dpi=100) +plt.show() +''' \ No newline at end of file -- libgit2 0.21.4