# spharmonics.py (3.4 KB)
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 18 16:31:36 2017

@author: david
"""

import numpy
import scipy
import matplotlib.pyplot as plt
from matplotlib import cm, colors
from mpl_toolkits.mplot3d import Axes3D
import math


def sph2cart(theta, phi, r):
    """Convert spherical coordinates to Cartesian coordinates.

    Parameters
    ----------
    theta : azimuthal angle (radians), measured in the x-y plane.
    phi : polar angle (radians), measured from the +z axis.
    r : radial distance.

    Returns
    -------
    tuple ``(x, y, z)``; scalars or arrays following the inputs.
    """
    radial_cos_az = r * numpy.cos(theta)
    radial_sin_az = r * numpy.sin(theta)

    x = radial_cos_az * numpy.sin(phi)
    y = radial_sin_az * numpy.sin(phi)
    z = r * numpy.cos(phi)
    return x, y, z

def cart2sph(x, y, z):
    """Convert Cartesian coordinates to spherical coordinates.

    Parameters
    ----------
    x, y, z : Cartesian coordinates (scalars or arrays of matching shape).

    Returns
    -------
    tuple ``(r, theta, phi)`` where ``r`` is the radial distance, ``theta``
    the azimuthal angle in ``(-pi, pi]`` and ``phi`` the polar angle in
    ``[0, pi]`` measured from the +z axis.
    """
    planar_sq = x**2 + y**2
    r = numpy.sqrt(planar_sq + z**2)
    theta = numpy.arctan2(y, x)
    # arctan2 is used instead of arccos(z / r): it is defined at the origin
    # (yielding 0 rather than NaN from 0/0) and cannot leave its domain
    # through rounding of z / r.
    phi = numpy.arctan2(numpy.sqrt(planar_sq), z)
    return r, theta, phi


def sh(theta, phi, l, m):
    """Evaluate the real-valued spherical harmonic of degree ``l``, order ``m``.

    The real harmonic is built from the complex harmonic ``Y`` of order
    ``|m|``:

    * ``m < 0``: ``sqrt(2) * (-1)**m * Im(Y)``
    * ``m > 0``: ``sqrt(2) * (-1)**m * Re(Y)``
    * ``m == 0``: ``Re(Y)``

    Parameters
    ----------
    theta : azimuthal angle(s) in radians.
    phi : polar angle(s) in radians.
    l : degree (integer-valued; float-valued integers are accepted).
    m : order, a scalar with ``-l <= m <= l`` (integer-valued).

    Returns
    -------
    Real value(s) of the harmonic, shaped by broadcasting the angles.
    """
    # Imported locally: a bare ``import scipy`` does not make the ``special``
    # submodule available on every SciPy release.
    from scipy import special

    # Callers may hand over float-valued indices; SciPy expects integers.
    degree = numpy.asarray(l).astype(int)
    order = abs(int(m))

    if hasattr(special, "sph_harm_y"):
        # Newer SciPy API: arguments are (degree, order, polar, azimuth).
        harmonic = special.sph_harm_y(degree, order, phi, theta)
    else:
        # Legacy API: arguments are (order, degree, azimuth, polar).
        harmonic = special.sph_harm(order, degree, theta, phi)

    if order == 0:
        return harmonic.real

    # (-1)**m only depends on the parity of m. Computing it from |m| avoids
    # the ValueError NumPy raises for integer ** negative integer.
    sign = -1.0 if order % 2 else 1.0
    scale = numpy.sqrt(2) * sign
    if m < 0:
        return scale * harmonic.imag
    return scale * harmonic.real
    
#calculate a spherical harmonic value from a set of coefficients and a spherical coordinate
def sh_coeff(theta, phi, C):
    """Evaluate a spherical harmonic expansion at the given angles.

    Parameters
    ----------
    theta : azimuthal angle(s) in radians.
    phi : polar angle(s) in radians.
    C : sequence of coefficients; entry ``k`` weights the harmonic whose
        (l, m) pair is ``i2lm(k)``.

    Returns
    -------
    Array shaped like ``theta`` holding ``sum_k C[k] * sh(theta, phi, l_k, m_k)``.
    """
    # numpy.shape() also accepts scalars and lists, unlike ``theta.shape``.
    total = numpy.zeros(numpy.shape(theta))

    for index, coefficient in enumerate(C):
        l, m = i2lm(index)              #get the 2D indices
        total = total + coefficient * sh(theta, phi, l, m)

    return total

#plot a spherical harmonic function on a sphere using N points
def sh_plot(C, N):
    """Plot a spherical harmonic expansion as colours on the unit sphere.

    Parameters
    ----------
    C : sequence of spherical harmonic coefficients (see ``sh_coeff``).
    N : number of samples along each angular axis (an N x N grid is drawn).

    Opens a matplotlib window and blocks until it is closed.
    """
    polar = numpy.linspace(0, numpy.pi, N)
    azimuth = numpy.linspace(0, 2*numpy.pi, N)
    phi, theta = numpy.meshgrid(polar, azimuth)

    # The Cartesian coordinates of the unit sphere
    x, y, z = sph2cart(theta, phi, 1.0)

    # Evaluate the expansion and normalize it to [0,1] for the colour map
    fcolors = sh_coeff(theta, phi, C)
    fmax, fmin = fcolors.max(), fcolors.min()
    span = fmax - fmin
    if span > 0:
        fcolors = (fcolors - fmin) / span
    else:
        # A constant field (e.g. only the l=0 term) has zero range; use the
        # centre of the colour map instead of dividing by zero.
        fcolors = numpy.full_like(fcolors, 0.5)

    # Set the aspect ratio to 1 so our sphere looks spherical
    fig = plt.figure(figsize=plt.figaspect(1.))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(x, y, z,  rstride=1, cstride=1, facecolors=cm.seismic(fcolors))
    # Turn off the axis planes
    ax.set_axis_off()
    plt.show()
    
def i2lm(i):
    """Convert a linear coefficient index into its (l, m) pair.

    Inverse of ``lm2i``: index ``i`` maps to degree ``l = floor(sqrt(i))``
    and order ``m = i - l * (l + 1)``, so ``-l <= m <= l``.

    Parameters
    ----------
    i : non-negative linear index (scalar or array).

    Returns
    -------
    tuple ``(l, m)`` of Python ints for scalar input, or of integer arrays
    for array input.
    """
    if numpy.ndim(i) == 0:
        # Return plain ints: these are indices, not floating-point values.
        index = int(i)
        l = int(numpy.floor(numpy.sqrt(index)))
        return l, index - l * (l + 1)

    indices = numpy.asarray(i).astype(int)
    l = numpy.floor(numpy.sqrt(indices)).astype(int)
    return l, indices - l * (l + 1)

def lm2i(l, m):
    """Convert a (degree, order) pair into its linear coefficient index.

    The harmonic with ``m == 0`` of degree ``l`` sits at ``l * (l + 1)``;
    other orders are offset from it by ``m``.
    """
    zero_order_index = l * (l + 1)
    return zero_order_index + m

#generates a set of spherical harmonic coefficients from samples using linear least squares
def linfit(theta, phi, s, nc):
    """Fit ``nc`` spherical harmonic coefficients to sampled values.

    Builds and solves the normal equations ``A c = b`` with
    ``A[i, j] = sum(y_i * y_j)`` and ``b[j] = sum(y_j * s)``, where ``y_k``
    is the k-th real harmonic evaluated at the sample angles.
    (see SH technical report in the vascular_viz repository)

    Parameters
    ----------
    theta : azimuthal angles of the samples (radians).
    phi : polar angles of the samples (radians).
    s : sampled values at those angles.
    nc : number of coefficients to fit.

    Returns
    -------
    Array of ``nc`` coefficients, indexed as in ``i2lm``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the normal-equation matrix is singular.
    """
    #evaluate every basis function once instead of inside the double loop
    basis = [sh(theta, phi, *i2lm(k)) for k in range(nc)]

    #allocate space for the matrix and the RHS values
    A = numpy.zeros((nc, nc))
    b = numpy.zeros(nc)

    for i, yi in enumerate(basis):
        b[i] = numpy.sum(yi * s)
        #the matrix is symmetric, so fill both triangles from one product
        for j in range(i, nc):
            A[i, j] = A[j, i] = numpy.sum(yi * basis[j])

    #solve the system of linear equations
    return numpy.linalg.solve(A, b)

#generate a scatter plot in 3D using spherical coordinates
def scatterplot3d(theta, phi, r):
    """Show a 3D scatter plot of samples given in spherical coordinates.

    ``theta``, ``phi`` and ``r`` are passed to ``sph2cart``; the resulting
    Cartesian points are drawn in a new figure, which is then displayed.
    """
    #convert all of the samples to cartesian coordinates
    points = sph2cart(theta, phi, r)

    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    axes.scatter(*points)
    plt.show()