Blame view

python/spharmonics.py 3.4 KB
e105516a   David Mayerich   added spherical h...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
  # -*- coding: utf-8 -*-
  """
  Created on Mon Dec 18 16:31:36 2017
  
  @author: david
  """
  
  import numpy
  import scipy
  import matplotlib.pyplot as plt
  from matplotlib import cm, colors
  from mpl_toolkits.mplot3d import Axes3D
  import math
  
  
  def sph2cart(theta, phi, r):
      """Convert spherical coordinates to Cartesian coordinates.

      theta is the azimuthal angle (measured in the xy-plane), phi is the
      polar angle measured from the +z axis and r is the radius.  Inputs may
      be scalars or arrays that broadcast against each other.

      Returns the tuple (x, y, z).
      """
      #sin(phi) scales both in-plane components, so evaluate it once
      sin_phi = numpy.sin(phi)

      x = r * numpy.cos(theta) * sin_phi
      y = r * numpy.sin(theta) * sin_phi
      z = r * numpy.cos(phi)
      return x, y, z
  
  def cart2sph(x, y, z):
      """Convert Cartesian coordinates to spherical coordinates.

      Returns the tuple (r, theta, phi) where r is the radius, theta is the
      azimuthal angle in the xy-plane (in [-pi, pi]) and phi is the polar
      angle measured from the +z axis (in [0, pi]).  Inputs may be scalars
      or arrays that broadcast against each other.

      The polar angle is computed with arctan2 rather than arccos(z / r):
      this avoids the division by zero (and the resulting NaN) for a point
      at the origin, where phi is reported as 0, and keeps full precision
      close to the poles.
      """
      r = numpy.sqrt(x**2 + y**2 + z**2)
      theta = numpy.arctan2(y, x)
      #distance from the z axis; arctan2(rho, z) is the angle from +z
      rho = numpy.hypot(x, y)
      phi = numpy.arctan2(rho, z)
      return r, theta, phi
  
  
  def sh(theta, phi, l, m):
      """Evaluate the real spherical harmonic of degree l and order m.

      theta is the azimuthal angle and phi is the polar angle (scalars or
      arrays); l and m are scalar indices with |m| <= l.  Integer-valued
      floats are accepted for l and m.

      The real basis is built from the complex harmonic Y of order |m|:
          m < 0 : sqrt(2) * (-1)**m * imag(Y)
          m > 0 : sqrt(2) * (-1)**m * real(Y)
          m = 0 : real(Y)
      """
      #import the subpackage explicitly: a bare "import scipy" does not
      #   guarantee that scipy.special has been loaded
      from scipy import special

      order = abs(m)
      if hasattr(special, "sph_harm_y"):
          #SciPy >= 1.15: sph_harm is deprecated in favour of sph_harm_y, which
          #   takes (degree, order, polar, azimuth) and integer indices
          degree_idx = numpy.asarray(l).astype(int)
          order_idx = numpy.asarray(order).astype(int)
          harmonic = special.sph_harm_y(degree_idx, order_idx, phi, theta)
      else:
          #older SciPy: sph_harm takes (order, degree, azimuth, polar)
          harmonic = special.sph_harm(order, l, theta, phi)

      if m == 0:
          return harmonic.real

      #(-1)**m only depends on the parity of m; computing it this way also
      #   works when m is a negative NumPy integer
      sign = -1.0 if order % 2 else 1.0
      component = harmonic.imag if m < 0 else harmonic.real
      return numpy.sqrt(2) * sign * component
      
  def sh_coeff(theta, phi, C):
      """Evaluate a spherical harmonic expansion at spherical coordinates.

      theta and phi are the azimuthal and polar angles (scalars or arrays)
      and C is a sequence of coefficients.  The coefficient at linear index
      i weights the real harmonic whose (l, m) pair is given by i2lm(i).

      Returns the weighted sum of the harmonics; an empty C yields zeros
      with the shape of theta.
      """
      #numpy.shape also accepts scalars and lists, unlike theta.shape
      total = numpy.zeros(numpy.shape(theta))

      for index, weight in enumerate(C):
          l, m = i2lm(index)              #get the 2D indices
          total = total + weight * sh(theta, phi, l, m)

      return total
  
  def sh_plot(C, N):
      """Plot a spherical harmonic function on the unit sphere.

      C is the sequence of expansion coefficients passed to sh_coeff and N
      is the number of samples used along each of the two angular axes.
      The function values are mapped linearly to [0, 1] and drawn as face
      colors using the seismic colormap.  Blocks in plt.show().
      """
      polar = numpy.linspace(0, numpy.pi, N)
      azimuth = numpy.linspace(0, 2*numpy.pi, N)
      phi, theta = numpy.meshgrid(polar, azimuth)

      # The Cartesian coordinates of the unit sphere
      x = numpy.sin(phi) * numpy.cos(theta)
      y = numpy.sin(phi) * numpy.sin(theta)
      z = numpy.cos(phi)

      # Evaluate the expansion and normalize to [0,1]
      fcolors = sh_coeff(theta, phi, C)
      fmin = fcolors.min()
      span = fcolors.max() - fmin
      if span > 0:
          fcolors = (fcolors - fmin) / span
      else:
          # A constant function (e.g. only the l = 0 term) has no range to
          # stretch; use the middle of the colormap instead of dividing by 0
          fcolors = numpy.full(fcolors.shape, 0.5)

      # Set the aspect ratio to 1 so our sphere looks spherical
      fig = plt.figure(figsize=plt.figaspect(1.))
      ax = fig.add_subplot(111, projection='3d')
      ax.plot_surface(x, y, z,  rstride=1, cstride=1, facecolors=cm.seismic(fcolors))
      # Turn off the axis planes
      ax.set_axis_off()
      plt.show()
      
  def i2lm(i):
      """Convert a linear coefficient index into a (degree, order) pair.

      The degree is floor(sqrt(i)) and the order is the offset of i from
      the centre of that degree's band, l * (l + 1).  Both values are
      returned as floating point numbers.
      """
      degree = numpy.floor(numpy.sqrt(i))
      band_centre = degree * (degree + 1)
      return degree, i - band_centre
  
  def lm2i(l, m):
      """Convert a (degree, order) pair into a linear coefficient index.

      The m = 0 harmonic of degree l sits at index l * (l + 1); the order m
      is the offset from that position.
      """
      band_centre = l * (l+1)
      return band_centre + m
  
  def linfit(theta, phi, s, nc):
      """Fit spherical harmonic coefficients to samples by linear least squares.

      theta and phi are the azimuthal and polar angles of the samples, s
      holds the sampled values and nc is the number of coefficients to fit.

      Builds the normal equations A c = b with
          A[i, j] = sum(y_i * y_j)    and    b[j] = sum(y_j * s)
      where y_k is the k-th real harmonic evaluated at the samples, and
      returns the solution c as an array of length nc.

      Raises numpy.linalg.LinAlgError if A is singular.
      """
      #evaluate every basis function once; they are reused for each row and
      #   column of the matrix and for the right-hand side
          #(see SH technical report in the vascular_viz repository)
      basis = []
      for k in range(nc):
          lk, mk = i2lm(k)
          basis.append(sh(theta, phi, lk, mk))

      #the matrix is symmetric, so only the upper triangle is computed
      A = numpy.zeros((nc, nc))
      for i, yi in enumerate(basis):
          for j in range(i, nc):
              A[i, j] = numpy.sum(yi * basis[j])
              A[j, i] = A[i, j]

      #calculate the RHS values
      b = numpy.zeros(nc)
      for j, yj in enumerate(basis):
          b[j] = numpy.sum(yj * s)

      #solve the system of linear equations
      return numpy.linalg.solve(A, b)
  
  def scatterplot3d(theta, phi, r):
      """Show a 3D scatter plot of samples given in spherical coordinates.

      theta, phi and r are passed to sph2cart and the resulting Cartesian
      points are drawn on a new 3D axes.  Blocks in plt.show().
      """
      #Cartesian (x, y, z) positions of the samples
      points = sph2cart(theta, phi, r)

      figure = plt.figure()
      axes = figure.add_subplot(111, projection='3d')
      axes.scatter(*points)
      plt.show()