Blame view

python/spharmonics.py 3.74 KB
e105516a   David Mayerich   added spherical h...
1
2
3
4
5
6
7
8
9
10
11
12
13
  # -*- coding: utf-8 -*-
  """
  Created on Mon Dec 18 16:31:36 2017
  
  @author: david
  """
  
  import math
  import time
  
  import matplotlib.pyplot as plt
  import numpy
  import scipy
  import scipy.special
  from matplotlib import cm, colors
  from mpl_toolkits.mplot3d import Axes3D
e105516a   David Mayerich   added spherical h...
15
16
  
  
db331d8a   David Mayerich   added adjacency w...
17
  def sph2cart(r, theta, phi):
      """Convert spherical coordinates (r, theta, phi) to Cartesian (x, y, z).

      theta is the azimuthal angle, measured in the x-y plane from +x.
      phi is the polar angle, measured from +z.
      All operations are element-wise, so NumPy arrays are accepted.
      """
      sin_polar = numpy.sin(phi)
      x = r * numpy.cos(theta) * sin_polar
      y = r * numpy.sin(theta) * sin_polar
      return x, y, r * numpy.cos(phi)
  
db331d8a   David Mayerich   added adjacency w...
24
  def cart2sph(x, y, z):
      """Convert Cartesian coordinates to spherical coordinates (r, theta, phi).

      theta is the azimuthal angle, computed as arctan2(y, x) (measured from +x
      in the x-y plane). phi is the polar angle in [0, pi], measured from +z.
      All operations are element-wise, so NumPy arrays are accepted.

      The polar angle is computed as arctan2(sqrt(x**2 + y**2), z) rather than
      arccos(z / r). Mathematically this gives the same angle, but:
        * it is defined at the origin (phi = 0 there, instead of NaN from 0/0);
        * rounding cannot push the argument outside arccos's [-1, 1] domain.

      Returns
      -------
      tuple
          (r, theta, phi)
      """
      r = numpy.sqrt(x**2 + y**2 + z**2)
      theta = numpy.arctan2(y, x)
      phi = numpy.arctan2(numpy.hypot(x, y), z)
      return r, theta, phi
  
  
  def sh(theta, phi, l, m):
      """Evaluate the real spherical harmonic of degree l and order m.

      The real harmonic is built from SciPy's complex harmonic Y = Y_l^|m|:
          m < 0 :  sqrt(2) * (-1)**m * Im(Y)
          m > 0 :  sqrt(2) * (-1)**m * Re(Y)
          m == 0:  Re(Y)

      Parameters
      ----------
      theta : array_like
          Azimuthal angle.
      phi : array_like
          Polar angle.
      l, m : scalar
          Degree and order. Integer-valued floats (such as those returned by
          i2lm) and NumPy integers are accepted; both are converted with int().

      Returns
      -------
      numpy.ndarray or float
          The real harmonic evaluated element-wise at (theta, phi).
      """
      # Convert to plain Python ints for two reasons:
      #   * SciPy's harmonics expect integer degree/order.
      #   * (-1)**m raises ValueError when m is a negative NumPy integer.
      l = int(l)
      m = int(m)

      if hasattr(scipy.special, "sph_harm_y"):
          # SciPy >= 1.15 deprecates sph_harm in favour of sph_harm_y.
          # sph_harm_y takes (degree, order, polar, azimuth).
          y = scipy.special.sph_harm_y(l, abs(m), phi, theta)
      else:
          # Older SciPy: sph_harm takes (order, degree, azimuth, polar).
          y = scipy.special.sph_harm(abs(m), l, theta, phi)

      if m < 0:
          return numpy.sqrt(2) * (-1)**m * y.imag
      if m > 0:
          return numpy.sqrt(2) * (-1)**m * y.real
      return y.real
      
db331d8a   David Mayerich   added adjacency w...
45
46
  #calculate a spherical harmonic value from a set of coefficients and coordinates P = (theta, phi)
  def sh_coeff(P, C):
      """Evaluate a real spherical-harmonic expansion at the given coordinates.

      Parameters
      ----------
      P : sequence of two array_like
          Coordinates (theta, phi).
          P[0] is passed to sh() as theta (the azimuthal angle).
          P[1] is passed to sh() as phi (the polar angle).
          Scalars and lists are accepted as well as arrays.
      C : sequence of float
          Expansion coefficients in flat order. Coefficient k multiplies the
          basis function sh(theta, phi, l, m), where (l, m) = i2lm(k).

      Returns
      -------
      numpy.ndarray
          sum_k C[k] * sh(theta, phi, l_k, m_k), with the shape of P[0].
          If C is empty, this is an array of zeros.
      """
      theta, phi = P[0], P[1]

      # numpy.shape (rather than P[0].shape) also works for scalars and lists
      s = numpy.zeros(numpy.shape(theta))

      for k, coeff in enumerate(C):
          l, m = i2lm(k)                  #get the 2D indices
          s = s + coeff * sh(theta, phi, l, m)

      return s
  
  #plot a spherical harmonic function on a sphere using N points
  def sh_plot(C, N):
      """Render a spherical-harmonic expansion as colours on the unit sphere.

      Parameters
      ----------
      C : sequence of float
          Coefficients in the flat ordering used by sh_coeff / i2lm.
      N : int
          Number of samples along each of theta and phi (an N x N grid).
      """
      phi = numpy.linspace(0, numpy.pi, N)
      theta = numpy.linspace(0, 2*numpy.pi, N)
      phi, theta = numpy.meshgrid(phi, theta)

      # The Cartesian coordinates of the unit sphere
      x = numpy.sin(phi) * numpy.cos(theta)
      y = numpy.sin(phi) * numpy.sin(theta)
      z = numpy.cos(phi)

      # Evaluate the expansion.
      # sh_coeff expects the coordinates packed as one pair P = (theta, phi).
      fcolors = sh_coeff((theta, phi), C)

      # Normalize to [0, 1] for the colormap.
      # A constant field (e.g. only the l = 0 term) has zero range, so it is
      # mapped to the middle of the colormap instead of dividing by zero.
      fmax, fmin = fcolors.max(), fcolors.min()
      if fmax > fmin:
          fcolors = (fcolors - fmin) / (fmax - fmin)
      else:
          fcolors = numpy.full_like(fcolors, 0.5)

      # Set the aspect ratio to 1 so our sphere looks spherical
      fig = plt.figure(figsize=plt.figaspect(1.))
      ax = fig.add_subplot(111, projection='3d')
      ax.plot_surface(x, y, z,  rstride=1, cstride=1, facecolors=cm.seismic(fcolors))

      # Turn off the axis planes
      ax.set_axis_off()
      plt.show()
      
  def i2lm(i):
      """Map a flat coefficient index i to the pair (l, m).

      The indices satisfy i = l*(l+1) + m, with -l <= m <= l.
      l is floor(sqrt(i)), so both values are returned as floats (from
      numpy.floor), not ints.
      """
      degree = numpy.floor(numpy.sqrt(i))
      return degree, i - (degree + 1) * degree
  
  def lm2i(l, m):
      """Map the pair (l, m) to a flat coefficient index, l*(l+1) + m.

      This is the inverse of i2lm.
      """
      return m + (l + 1) * l
  
  #generates a set of spherical harmonic coefficients from samples using linear least squares
  def linfit(P, s, nc, clock=False):
      """Fit nc real spherical-harmonic coefficients to samples by least squares.

      The fit solves the normal equations A c = b, where
          A[i, j] = sum(y_i * y_j)
          b[i]    = sum(y_i * s)
      and y_k is basis function k, with (l, m) = i2lm(k), evaluated at the
      sample coordinates.
      (See the SH technical report in the vascular_viz repository.)

      Parameters
      ----------
      P : sequence of two array_like
          Sample coordinates (theta, phi), passed to sh() as P[0] and P[1].
      s : array_like
          Sample values; multiplied element-wise with each basis function.
      nc : int
          Number of coefficients to fit.
      clock : bool, optional
          If True, print the elapsed time after building the matrix and
          again after solving.

      Returns
      -------
      numpy.ndarray
          The nc fitted coefficients, in the flat order used by i2lm / sh_coeff.

      Raises
      ------
      numpy.linalg.LinAlgError
          If the normal matrix A is singular.
      """
      if clock:
          start_time = time.time()

      # Evaluate each basis function once at all samples.
      # This keeps the number of sh() calls linear in nc instead of quadratic.
      Y = [sh(P[0], P[1], *i2lm(i)) for i in range(nc)]

      #allocate space for the matrix and RHS values
      A = numpy.zeros((nc, nc))
      b = numpy.zeros(nc)

      # A is symmetric: compute each entry once and mirror it.
      # yi * Y[j] equals Y[j] * yi element-wise, so the values are unchanged.
      for i, yi in enumerate(Y):
          b[i] = numpy.sum(yi * s)            #calculate the RHS value
          for j in range(i, nc):
              A[i, j] = A[j, i] = numpy.sum(yi * Y[j])

      if clock:
          print("SH::linfit:matrix "+str(time.time() - start_time)+"s")

      #solve the system of linear equations
      R = numpy.linalg.solve(A, b)

      if clock:
          print("SH::linfit:solution "+str(time.time() - start_time)+"s")
      return R
e105516a   David Mayerich   added spherical h...
123
124
  
  #generate a scatter plot in 3D using spherical coordinates
  def scatterplot3d(P):
      """Show a 3D scatter plot of points given in spherical coordinates.

      Parameters
      ----------
      P : sequence of three array_like
          (r, theta, phi), as accepted by sph2cart.
      """
      radius, azimuth, polar = P

      # Convert every sample to Cartesian (x, y, z) before plotting
      cartesian = sph2cart(radius, azimuth, polar)

      axes = plt.figure().add_subplot(111, projection='3d')
      axes.scatter(*cartesian)
      plt.show()