# fibernet.py 6.39 KB
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 19 2018

@author: Jiabing
"""

import struct
import numpy as np
import scipy as sp
import networkx as nx
import matplotlib.pyplot as plt
import math
from mpl_toolkits.mplot3d import Axes3D
from itertools import chain

class Node:
    """A network vertex: its position plus its outgoing and incoming indices.

    Attributes:
        p: the vertex position, stored exactly as passed in.
        o: the outgoing indices, stored exactly as passed in.
        i: the incoming indices, stored exactly as passed in.
    """

    def __init__(self, point, outgoing, incoming):
        # No copying or validation: the caller's objects are kept by reference.
        self.p, self.o, self.i = point, outgoing, incoming
       
            
class Fiber:
    """Holds the ``indices`` describing one fiber, stored as given."""

    def __init__(self, indices):
        # Kept by reference; no copy is made.
        self.indices = indices
           
class Point:
    """A 3D coordinate triple plus a scalar ``r`` (presumably a radius)."""

    def __init__(self, x, y, z, r):
        # Coordinates first, then the per-point scalar; all stored as given.
        self.x, self.y, self.z = x, y, z
        self.r = r
          
class Edge:
    """One fiber record: its point count, its points and a radius value.

    Attributes:
        indices_num: point count supplied by the caller (NWT.readFiber
            passes the number of vertices it read).
        points: the sequence given as ``pois``, stored by reference.
        radius: the radius value supplied by the caller.
    """

    def __init__(self, indices_num, pois, radius):
        self.indices_num, self.points, self.radius = indices_num, pois, radius
              
                       
class NWT:
    """Readers for records of the binary NWT network file format.

    Every method is static and takes a file object opened in binary mode;
    each call consumes exactly the bytes of one value or record.
    """

    # Explicit little-endian layouts. Integers were always decoded as
    # little-endian, so floats must use the same byte order instead of the
    # platform-native order implied by a bare 'f'.
    _UINT32 = struct.Struct('<I')
    _FLOAT32 = struct.Struct('<f')

    @staticmethod
    def _read_uint32(open_file):
        """Read one little-endian unsigned 32-bit integer.

        Raises:
            struct.error: if fewer than 4 bytes remain (truncated file),
                instead of silently decoding a short read as a small number.
        """
        return NWT._UINT32.unpack(open_file.read(4))[0]

    @staticmethod
    def _read_float32(open_file):
        """Read one little-endian 32-bit float.

        Raises:
            struct.error: if fewer than 4 bytes remain (truncated file).
        """
        return NWT._FLOAT32.unpack(open_file.read(4))[0]

    @staticmethod
    def readVertex(open_file):
        """Read a single vertex record from an open file and return a Node.

        Record layout: three float32 coordinates, a uint32 outgoing count,
        a uint32 incoming count, then that many uint32 outgoing indices
        followed by that many uint32 incoming indices.

        Returns:
            Node whose ``p`` is a float64 array of length 3 and whose ``o`` /
            ``i`` are integer arrays of the outgoing / incoming indices.

        Raises:
            struct.error: if the file ends in the middle of the record.
        """
        point = np.array([NWT._read_float32(open_file) for _ in range(3)],
                         dtype=float)
        # Both counts come before either index list, so read them first.
        num_out = NWT._read_uint32(open_file)
        num_in = NWT._read_uint32(open_file)
        outgoing = np.array([NWT._read_uint32(open_file) for _ in range(num_out)],
                            dtype=int)
        incoming = np.array([NWT._read_uint32(open_file) for _ in range(num_in)],
                            dtype=int)
        return Node(point, outgoing, incoming)

    @staticmethod
    def readFiber(open_file):
        """Read a single fiber record from an open file and return an Edge.

        Record layout: two uint32 vertex indices (presumably the fiber's end
        vertices; read to advance the file but not used here), a uint32
        point count, then for each point four float32 values x, y, z, radius.

        Returns:
            Edge(point_count, points, radius), where ``points`` is a list of
            float64 arrays of length 3 and ``radius`` is the radius of the
            last point (``None`` for a fiber with no points, which previously
            raised UnboundLocalError). The full list of per-point radii is
            also attached as the ``radii`` attribute.

        Raises:
            struct.error: if the file ends in the middle of the record.
        """
        NWT._read_uint32(open_file)  # first vertex index (unused)
        NWT._read_uint32(open_file)  # second vertex index (unused)
        num_verts = NWT._read_uint32(open_file)

        pts = []
        rads = []
        for _ in range(num_verts):
            values = [NWT._read_float32(open_file) for _ in range(4)]
            pts.append(np.array(values[:3], dtype=float))
            rads.append(values[3])

        edge = Edge(num_verts, pts, rads[-1] if rads else None)
        # The per-point radii were previously collected and then discarded.
        edge.radii = rads
        return edge
    

class fibernet:
    """A fiber network loaded from a binary NWT file.

    Attributes:
        N: list of Node objects, one per vertex record.
        F: list of fibers; each fiber is the list of its centerline points
           (length-3 numpy arrays) taken from NWT.readFiber's ``points``.
    """

    def __init__(self, filename):
        """Load every vertex and fiber record stored in *filename*.

        File layout as read here: a 72-byte header (skipped), a little-endian
        uint32 vertex count, a little-endian uint32 edge count, then the
        vertex records followed by the fiber records.
        """
        self.F = []
        self.N = []
        with open(filename, "rb") as file:
            file.read(72)  # header contents are not used
            num_vertex = int.from_bytes(file.read(4), byteorder='little')
            num_edges = int.from_bytes(file.read(4), byteorder='little')

            for _ in range(num_vertex):
                self.N.append(NWT.readVertex(file))

            for _ in range(num_edges):
                self.F.append(NWT.readFiber(file).points)

    def aabb(self):
        """Return ``(lower, upper)``, the corners of the nodes' bounding box.

        Both corners are new float arrays holding the per-axis minimum and
        maximum of every node position ``N[k].p``.

        Raises:
            ValueError: if the network has no nodes.
        """
        coords = np.array([node.p for node in self.N], dtype=float)
        if coords.size == 0:
            raise ValueError("cannot compute the bounding box of an empty network")
        return coords.min(axis=0), coords.max(axis=0)

    @staticmethod
    def _densify(a, b, spacing):
        """Return extra samples along segment a->b, ending exactly at *b*.

        The segment is divided so that consecutive samples are less than one
        grid spacing apart on every axis whose spacing is non-zero. If the
        segment is already shorter than one spacing on all such axes, no
        samples are returned.
        """
        delta = b - a
        # abs(): segments pointing in a negative direction must be densified
        # too (signed components previously gave a count <= 0 for them).
        # Zero-spacing axes (a flat bounding box) are skipped rather than
        # divided by.
        steps = max((int(abs(d) / s) for d, s in zip(delta, spacing) if s > 0),
                    default=0)
        if steps == 0:
            return []
        n = steps + 1
        return [a + k * delta / n for k in range(1, n + 1)]

    def distancefield(self, R=(100, 100, 100), fibers=slice(12, 13)):
        """Sample the distance from a regular grid to the fiber centerlines.

        The grid spans the bounding box from ``aabb()`` with R[0] x R[1] x R[2]
        samples. Each selected fiber's centerline is densified to roughly
        grid resolution, and a KD-tree query gives, for each grid point, the
        distance to the nearest centerline sample.

        Args:
            R: number of grid samples along x, y and z (each must be >= 2).
            fibers: slice selecting which entries of ``self.F`` contribute.
                The default ``slice(12, 13)`` reproduces the previously
                hard-coded (debugging) selection of fiber 12 only; pass
                ``slice(None)`` to use every fiber.

        Returns:
            ``(D, Q, dis1)``: ``D`` is the distance array, shaped like the
            ``np.meshgrid`` output (default 'xy' indexing, i.e.
            (R[1], R[0], R[2])); ``Q`` holds the grid coordinates with a
            trailing axis of size 3; ``dis1`` is the length of a grid cell's
            diagonal.

        Raises:
            ValueError: if the selected fibers contain no points.
        """
        # Imported explicitly: a bare `import scipy` does not load the
        # `spatial` subpackage on older SciPy releases.
        from scipy.spatial import cKDTree

        # Uniform grid covering the space occupied by the network.
        lower, upper = self.aabb()
        x = np.linspace(lower[0], upper[0], R[0])
        y = np.linspace(lower[1], upper[1], R[1])
        z = np.linspace(lower[2], upper[2], R[2])
        X, Y, Z = np.meshgrid(x, y, z)
        Q = np.stack((X, Y, Z), 3)

        d_x = abs(x[1] - x[0])
        d_y = abs(y[1] - y[0])
        d_z = abs(z[1] - z[0])
        dis1 = math.sqrt(d_x ** 2 + d_y ** 2 + d_z ** 2)
        spacing = (d_x, d_y, d_z)

        # Centerline points plus interpolated samples between neighbours.
        P = []
        for fiber in self.F[fibers]:
            P.extend(fiber)
            for a, b in zip(fiber[:-1], fiber[1:]):
                P.extend(self._densify(a, b, spacing))
        if not P:
            raise ValueError("the selected fibers contain no points")

        tree = cKDTree(np.array(P))
        D, _ = tree.query(Q)
        return D, Q, dis1
        
        
# Disabled usage example: the script below is wrapped in a string literal,
# so it is never executed when this module is imported or run.
# NOTE(review): it would not run as-is if re-enabled. It reads `F` and `P`,
# neither of which is defined here (only `f` and `plist = f.F` are). Also,
# `f.F` holds lists of point arrays rather than indices into a shared point
# list, so the `P[F[i][j]]` indexing looks stale; presumably it predates
# that change (see the commented-out lines in fibernet.__init__).
'''              
##read NWT file
f= fibernet("full_seg.nwt")
#P = tuple(f.P)
plist = f.F

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

xs = []
ys = []
zs = []
for i in range(10):
    for j in range(len(F[i])):       
        xs.append(P[F[i][j]][0])
        ys.append(P[F[i][j]][1])
        zs.append(P[F[i][j]][2])

   
#ax.scatter(xs, ys, zs)
#ax = fig.gca(projection='3d')
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')

ax.plot(xs, ys, zs, label='center lines')
ax.legend()
plt.savefig('p.png', dpi=100)
plt.show()
'''