matrix_sym.h 2.54 KB
``````#ifndef STIM_MATRIX_SYM_H
#define STIM_MATRIX_SYM_H

#include <cstddef>
#include <sstream>
#include <string>

#include <stim/cuda/cudatools/callable.h>
#include <stim/math/matrix.h>

/* This class represents a symmetric rank-2 tensor of dimension D, suitable
for representing tensor fields such as structure and diffusion tensors.
Only the D*(D+1)/2 lower-triangular values are stored.
*/
namespace stim{

template <typename T, int D>
class matrix_sym{

protected:
	//values are stored in column-major order as a lower-triangular matrix:
	//	column c holds the D - c values (c, c), (c + 1, c), ..., (D - 1, c)
	T M[D*(D + 1)/2];

	//return the linear index into M of element (r, c)
	//	because the matrix is symmetric, (r, c) and (c, r) map to the same value
	//	r and c are expected to be in [0, D); no bounds checking is performed
	static CUDA_CALLABLE size_t idx(size_t r, size_t c) {
		//if the index is in the upper-triangular portion, swap the indices
		if(r < c){
			size_t t = r;
			r = c;
			c = t;
		}

		//column c starts at c*D - c*(c-1)/2 (the diagonal element (c, c));
		//row r lies (r - c) entries further down that column
		return c * D - c * (c + 1) / 2 + r;
	}

	//calculate the row r and column c (with r >= c) of the value stored at linear index idx
	//	this is the inverse of idx() restricted to the lower triangle
	//	idx is expected to be in [0, D*(D+1)/2); no bounds checking is performed
	static CUDA_CALLABLE void indices(size_t& r, size_t& c, size_t idx) {
		const size_t dim = static_cast<size_t>(D);
		size_t col = 0;
		size_t start = 0;								//linear index of the diagonal element (col, col)
		//exact integer walk: advance one column at a time until idx falls inside column col
		//	(avoids the floating-point sqrt/ceil closed form and its rounding hazards)
		while (col + 1 < dim && idx >= start + (dim - col)) {
			start += dim - col;							//column col holds dim - col values
			col++;
		}
		c = col;
		r = col + (idx - start);
	}

public:
	//return the symmetric matrix associated with this tensor
	//	built by passing the packed values M to stim::matrix<T>::setsym()
	//	NOTE(review): presumably setsym expands the packed lower triangle into a full matrix; confirm in stim/math/matrix.h
	stim::matrix<T> mat() {
		stim::matrix<T> r;
		r.setsym(M);
		return r;
	}

	//access element (r, c); (r, c) and (c, r) refer to the same stored value
	CUDA_CALLABLE T& operator()(int r, int c) {
		return M[idx(r, c)];
	}

	//read-only access to element (r, c) for const tensors
	CUDA_CALLABLE const T& operator()(int r, int c) const {
		return M[idx(r, c)];
	}

	//set every element of the tensor to rhs
	CUDA_CALLABLE matrix_sym<T, D>& operator=(T rhs) {
		const int N = D*(D+1)/2;
		for(int i=0; i<N; i++)
			M[i] = rhs;

		return *this;
	}

	//copy all stored values from rhs
	CUDA_CALLABLE matrix_sym<T, D>& operator=(const matrix_sym<T, D>& rhs) {
		const size_t N = D * (D + 1) / 2;
		for (size_t i = 0; i < N; i++) M[i] = rhs.M[i];
		return *this;
	}

	//return the sum of the diagonal values
	CUDA_CALLABLE T trace() const {
		T tr = 0;
		for (size_t i = 0; i < static_cast<size_t>(D); i++)		//for each diagonal value
			tr += M[idx(i, i)];										//add the value on the diagonal
		return tr;
	}

	//multiply every stored value of B by rhs, in place
	//	NOTE: only B is modified; *this is neither read nor written
	CUDA_CALLABLE void operator_product(matrix_sym<T, D> &B, T rhs) {
		const int N = D*(D+1)/2;
		for(int i=0; i<N; i++)
			B.M[i] *= rhs;
	}

	//return the tensor as a string: one "| v0 v1 ... |" line per row
	std::string str() const {
		std::stringstream ss;
		for(int r = 0; r < D; r++){
			ss << "| ";
			for(int c = 0; c < D; c++)
				ss << (*this)(r, c) << " ";
			ss << "|" << '\n';
		}

		return ss.str();
	}

	//returns an identity matrix (1 on the diagonal, 0 elsewhere)
	static matrix_sym<T, D> identity() {
		matrix_sym<T, D> I;
		I = 0;
		for (size_t i = 0; i < static_cast<size_t>(D); i++)
			I.M[matrix_sym<T, D>::idx(i, i)] = 1;
		return I;
	}
};

}	//end namespace stim

#endif
``````