// matrix.h (8.35 KB)
#ifndef RTS_MATRIX_H
#define RTS_MATRIX_H

//#include "rts/vector.h"
#include <stdlib.h>
#include <string.h>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <stim/math/vector.h>
#include <stim/math/vec3.h>
//#include <stim/cuda/cudatools/callable.h>

namespace stim{

/// Dense, dynamically sized matrix of elements of type T.
///
/// Storage is a single heap array in column-major order (compatible with OpenGL):
/// element (row, col) lives at M[col * R + row]. An empty matrix (zero rows or zero
/// columns) owns no storage and has M == NULL.
///
/// Element data is copied with memcpy, so T is expected to be a trivially copyable
/// numeric type. Invalid use (index out of range, mismatched dimensions, determinant
/// of a non-square matrix) prints a message to std::cout and terminates the process
/// with exit(1).
template <class T>
class matrix {
	T* M;								//pointer to the matrix data (NULL when the matrix is empty)
	size_t R;							//number of rows
	size_t C;							//number of columns

	//allocate uninitialized storage for n elements; an empty matrix gets no storage at all
	static T* allocate(size_t n) {
		if (n == 0) return NULL;
		return new T[n];
	}

	//copy n elements from src to dst; skipped entirely for n == 0 so memcpy never sees a NULL pointer
	static void copy_elements(T* dst, const T* src, size_t n) {
		if (n != 0) memcpy(dst, src, n * sizeof(T));
	}

	//bounds-checked read of element (row, col); terminates the program when out of range
	T get(const size_t row, const size_t col) const {
		if (row >= R || col >= C) {
			std::cout << "ERROR: row or column out of range." << std::endl;
			exit(1);
		}
		return M[col * R + row];
	}

	//bounds-checked writable reference to element (row, col); terminates the program when out of range
	T& at(size_t row, size_t col){
		if (row >= R || col >= C) {
			std::cout << "ERROR: row or column out of range." << std::endl;
			exit(1);
		}
		return M[col * R + row];
	}

public:
	/// Create an empty (0 x 0) matrix
	matrix() : M(NULL), R(0), C(0) {}

	/// Create a rows x cols matrix; the element values are left uninitialized
	matrix(size_t rows, size_t cols) : M(allocate(rows * cols)), R(rows), C(cols) {}

	/// Create a rows x cols matrix and copy rows * cols elements from data (column-major order)
	matrix(size_t rows, size_t cols, const T* data) : M(allocate(rows * cols)), R(rows), C(cols) {
		copy_elements(M, data, R * C);
	}

	/// Deep-copy constructor
	matrix(const matrix<T>& cpy) : M(allocate(cpy.R * cpy.C)), R(cpy.R), C(cpy.C) {
		copy_elements(M, cpy.M, R * C);
	}

	/// Release the element storage
	~matrix() {
		delete[] M;							//storage comes from new[], so it must be released with delete[] (NULL is fine)
		M = NULL;
		R = C = 0;
	}

	/// Number of rows
	size_t rows() const {
		return R;
	}

	/// Number of columns
	size_t cols() const {
		return C;
	}

	/// Writable, bounds-checked access to element (row, col)
	T& operator()(size_t row, size_t col) {
		return at(row, col);
	}

	/// Read-only, bounds-checked access to element (row, col) for const matrices
	T operator()(size_t row, size_t col) const {
		return get(row, col);
	}

	/// Assign the scalar rhs to every element; the matrix size is unchanged
	matrix<T>& operator=(const T rhs) {
		const size_t N = R * C;
		for(size_t n = 0; n < N; n++)
			M[n] = rhs;

		return *this;
	}

	/// Deep-copy assignment; the matrix takes on the size of rhs
	matrix<T>& operator=(const matrix<T>& rhs){
		if (this != &rhs) {											//if the matrix isn't self-assigned
			const size_t N = rhs.R * rhs.C;
			T* new_matrix = allocate(N);							//allocate new resources before releasing the old ones
			copy_elements(new_matrix, rhs.M, N);					//copy the matrix

			delete[] M;												//delete the previous array
			M = new_matrix;
			R = rhs.R;
			C = rhs.C;
		}
		return *this;
	}

	//element-wise operations

	/// Add the scalar rhs to every element
	matrix<T> operator+(const T rhs) const {
		matrix<T> result(R, C);					//create a result matrix
		const size_t N = R * C;

		for(size_t i = 0; i < N; i++)
			result.M[i] = M[i] + rhs;			//calculate the operation and assign to result

		return result;
	}

	/// Element-wise sum of two matrices of identical size (terminates the program otherwise)
	matrix<T> operator+(const matrix<T>& rhs) const {
		if (R != rhs.R || C != rhs.C) {
			std::cout << "ERROR: addition is only defined for matrices that are the same size." << std::endl;
			exit(1);
		}
		matrix<T> result(R, C);					//create a result matrix
		const size_t N = R * C;

		for (size_t i = 0; i < N; i++)
			result.M[i] = M[i] + rhs.M[i];		//calculate the operation and assign to result

		return result;
	}

	/// Subtract the scalar rhs from every element
	matrix<T> operator-(const T rhs) const {
		return operator+(-rhs);					//add the negative of rhs
	}

	/// Element-wise difference of two matrices of identical size
	matrix<T> operator-(const matrix<T>& rhs) const {
		return operator+(-rhs);					//add the negated matrix (the size check happens in operator+)
	}

	/// Negate every element
	matrix<T> operator-() const {
		matrix<T> result(R, C);					//create a result matrix
		const size_t N = R * C;

		for (size_t i = 0; i < N; i++)
			result.M[i] = -M[i];				//calculate the operation and assign to result

		return result;
	}

	/// Multiply every element by the scalar rhs
	matrix<T> operator*(const T rhs) const {
		matrix<T> result(R, C);					//create a result matrix
		const size_t N = R * C;

		for(size_t i = 0; i < N; i++)
			result.M[i] = M[i] * rhs;			//calculate the operation and assign to result

		return result;
	}

	/// Divide every element by the scalar rhs
	matrix<T> operator/(const T rhs) const {
		matrix<T> result(R, C);					//create a result matrix
		const size_t N = R * C;

		for(size_t i = 0; i < N; i++)
			result.M[i] = M[i] / rhs;			//calculate the operation and assign to result

		return result;
	}

	/// Matrix multiplication: [R x C] * [rhs.R x rhs.C] requires C == rhs.R and yields [R x rhs.C]
	matrix<T> operator*(const matrix<T>& rhs) const {
		if(C != rhs.R){
			std::cout<<"ERROR: matrix multiplication is undefined for matrices of size ";
			std::cout<<"[ "<<R<<" x "<<C<<" ] and [ "<<rhs.R<<" x "<<rhs.C<<"]"<<std::endl;
			exit(1);
		}

		matrix<T> result(R, rhs.C);				//create the output matrix
		for(size_t c = 0; c < rhs.C; c++){
			for(size_t r = 0; r < R; r++){
				T inner = (T)0;					//stores the running inner product of row r and column c
				for(size_t i = 0; i < C; i++){
					inner += M[i * R + r] * rhs.M[c * rhs.R + i];
				}
				result.M[c * R + r] = inner;
			}
		}
		return result;
	}

	//returns a pointer to the raw matrix data (in column major format)
	T* data(){
		return M;
	}

	//returns a read-only pointer to the raw matrix data (in column major format)
	const T* data() const {
		return M;
	}

	//return a transposed matrix
	matrix<T> transpose() const {
		matrix<T> result(C, R);
		for(size_t c = 0; c < C; c++){
			for(size_t r = 0; r < R; r++){
				result.M[r * C + c] = M[c * R + r];		//element (r, c) of this matrix becomes element (c, r) of the result
			}
		}
		return result;
	}

	///Calculate and return the determinant of the matrix
	///
	/// Uses fraction-free (Bareiss) elimination on a copy of the matrix, so the result is
	/// exact for integral T as long as no intermediate value overflows. Rows are only
	/// exchanged when a pivot is exactly zero; no magnitude-based pivoting is performed.
	/// The determinant of an empty (0 x 0) matrix is 1 (the empty product).
	/// Terminates the program if the matrix is not square.
	T det() const {
		if (R != C) {
			std::cout << "ERROR: a determinant can only be calculated for a square matrix." << std::endl;
			exit(1);
		}
		const size_t n = R;
		if (n == 0) return (T)1;			//empty product
		if (n == 1) return M[0];			//if the matrix only contains one value, return it

		matrix<T> work(*this);				//eliminate on a copy so this matrix is left untouched
		T* A = work.M;						//element (i, j) of the working copy is A[j * n + i]
		bool negate = false;				//tracks the sign flips caused by row exchanges
		T prev = (T)1;						//pivot of the previous step (divisor of the Bareiss update)

		for (size_t k = 0; k + 1 < n; k++) {
			if (A[k * n + k] == (T)0) {							//zero pivot: look below it for a usable row
				size_t p = k + 1;
				while (p < n && A[k * n + p] == (T)0) p++;
				if (p == n) return (T)0;						//the whole column is zero, so the matrix is singular
				for (size_t j = 0; j < n; j++) {				//exchange rows k and p
					const T tmp = A[j * n + k];
					A[j * n + k] = A[j * n + p];
					A[j * n + p] = tmp;
				}
				negate = !negate;								//a row exchange flips the sign of the determinant
			}

			const T pivot = A[k * n + k];
			for (size_t j = k + 1; j < n; j++) {				//for each remaining column
				for (size_t i = k + 1; i < n; i++) {			//for each remaining row
					A[j * n + i] = (A[j * n + i] * pivot - A[k * n + i] * A[j * n + k]) / prev;
				}
			}
			prev = pivot;
		}

		const T result = A[(n - 1) * n + (n - 1)];				//the last pivot holds the determinant (up to sign)
		return negate ? -result : result;
	}

	/// Sum all elements in the matrix
	T sum() const {
		const size_t N = R * C;							//calculate the number of elements in the matrix
		T s = (T)0;										//accumulator for the sum
		for (size_t n = 0; n < N; n++) s += M[n];		//perform the summation
		return s;
	}

	/// Sort rows of the matrix by the specified indices
	/// Row r of the result is row idx[r] of this matrix; idx must hold rows() valid row indices.
	matrix<T> sort_rows(const size_t* idx) const {
		matrix<T> result(R, C);									//the output has the same shape as this matrix
		for (size_t c = 0; c < C; c++) {						//for each column
			for (size_t r = 0; r < R; r++) {					//for each row element
				result.M[c * R + r] = M[c * R + idx[r]];		//copy each element of the row into its new position
			}
		}
		return result;
	}

	/// Sort columns of the matrix by the specified indices
	/// Column c of the result is column idx[c] of this matrix; idx must hold cols() valid column indices.
	matrix<T> sort_cols(const size_t* idx) const {
		matrix<T> result(R, C);									//the output has the same shape as this matrix
		for (size_t c = 0; c < C; c++) {						//for each column
			copy_elements(result.M + c * R, M + idx[c] * R, R);	//columns are contiguous, so copy the entire column at once
		}
		return result;
	}

	/// Return the column specified by index i as an R x 1 matrix
	matrix<T> col(size_t i) const {
		if (i >= C) {
			std::cout << "ERROR: row or column out of range." << std::endl;
			exit(1);
		}
		matrix<T> c(R, 1);										//create a single column matrix
		copy_elements(c.M, M + R * i, R);						//a column holds R contiguous elements
		return c;
	}

	/// Return the row specified by index i as a 1 x C matrix
	matrix<T> row(size_t i) const {
		matrix<T> r(1, C);										//create a single row matrix
		for (size_t c = 0; c < C; c++)
			r.M[c] = get(i, c);									//get() performs the bounds check on i
		return r;
	}

	/// Format the matrix as text, one "| a b c |" line per row
	std::string toStr() const {
		std::stringstream ss;

		for(size_t r = 0; r < R; r++) {
			ss << "| ";
			for(size_t c = 0; c < C; c++) {
				ss << M[c * R + r] << " ";
			}
			ss << "|" << '\n';
		}
		return ss.str();
	}

	/// Write the matrix to a stream in CSV format, one line per row (values separated by ", ")
	/// Note that this leaves the std::fixed flag set on the stream. Nothing is written for an empty matrix.
	void csv(std::ostream& out) const {
		if (R == 0 || C == 0) return;				//no storage to read from
		for (size_t i = 0; i < R; i++) {
			out << std::fixed << M[i];				//first column of row i
			for (size_t j = 1; j < C; j++)
				out << ", " << std::fixed << M[j * R + i];
			out << '\n';
		}
	}

	/// Return the matrix as a CSV-formatted string
	std::string csv() const {
		std::stringstream csvss;
		int digits = std::numeric_limits<double>::max_digits10;
		csvss.precision(digits);					//enough digits to round-trip a double
		csv(csvss);
		return csvss.str();
	}



	//save the data as a CSV file
	void csv(std::string filename) const {
		std::ofstream basisfile(filename.c_str());
		basisfile << csv();
		basisfile.close();
	}

	/// Create an N x N identity matrix
	static matrix<T> I(size_t N) {
		matrix<T> result(N, N);							//create the identity matrix
		const size_t total = N * N;
		for (size_t n = 0; n < total; n++)
			result.M[n] = (T)0;							//set the entire matrix to zero
		for (size_t n = 0; n < N; n++) {
			result.M[n * N + n] = (T)1;					//set the diagonal component to 1
		}
		return result;
	}

	//loads a matrix from a stream in CSV format
	//the matrix must already have its final size: rows() * cols() values are read in row-major order
	void csv(std::istream& in) {
		T v;
		for (size_t r = 0; r < R; r++) {
			for (size_t c = 0; c < C; c++) {
				in >> v;
				if (in.peek() == ',') in.ignore();		//skip the separator (ignore() also works on non-seekable streams)
				at(r, c) = v;
			}
		}
	}

};

}	//end namespace stim


#endif