天天看點

十字鏈表類模板的實作

項目代碼連結:https://github.com/weiyx15/SparseMatrix/tree/master/SparseMatrix

第一次用C++寫類模板,過程十分艱辛,代碼也十分冗雜。下面記錄幾個注意點:

1. 重載雙目運算符(如 *)時用友元函數,這樣左右兩個操作數都可以進行隱式類型轉換

2. 父類的友元不是子類的友元

3. 子類不會繼承父類重載的=運算符(編譯器會為子類隱式生成拷貝賦值運算符,它會隱藏父類的版本)

4. 有動態記憶體分配的類要自己實作拷貝構造函數(深拷貝)、析構函數和重載=運算符(即「三法則」)

----------------------------------------

代碼

OrthoList.h

/*
* 稀疏矩陣類
* 基于十字連結清單的實作
* 用成員函數實作+, 用友元函數重載*
* 20180401 first edition
*/
#ifndef _ORTHOLIST_H
#define _ORTHOLIST_H
#include<iostream>
#include<vector>
#include "matrix.h"
using namespace std;

// One stored (non-zero) element of the sparse matrix. The node sits in two singly
// linked lists at the same time: the list of its row (followed through `right`)
// and the list of its column (followed through `down`).
template <class T>
class Node {
public:
	int row;			// row index of the element
	int col;			// column index of the element
	T val;				// element value
	Node<T> *right;		// next node in the same row, NULL at the end of the row
	Node<T> *down;		// next node in the same column, NULL at the end of the column
};

// Sparse matrix stored as an orthogonal ("cross") linked list.
// ListValueType: type of the stored element values.
template <class ListValueType>
class OrthoList {
public:
	OrthoList(vector< vector<ListValueType> > mat);	// Build from a dense matrix given as rows; elements != 0 are stored
	OrthoList(int row, int col);					// Empty row x col matrix (all head pointers NULL)
	OrthoList(void);								// Empty 0 x 0 matrix
	OrthoList(const OrthoList<ListValueType> &b);	// Deep copy: the object owns dynamically allocated nodes
	virtual ~OrthoList();							// Virtual to support dynamic binding through base pointers

protected:
	// rHead[i] / cHead[j] is NULL when row i / column j is empty; otherwise it is a
	// sentinel node whose `right` / `down` link starts the list of stored elements.
	Node<ListValueType> **rHead, **cHead;
	int nRow, nCol, nElement;						// Rows, columns, number of stored elements

private:
	bool elementWiseAddition(Node<ListValueType> *aNode);	// Add one element into *this; helper for Addition

public:
	Node<ListValueType> **getRHead() const;	// Row head pointer array
	Node<ListValueType> **getCHead() const;	// Column head pointer array
	int getRowNumber() const;				// Number of rows
	int getColumnNumber() const;			// Number of columns
	int getElementNumber() const;			// Number of stored elements
	void printFull() const;					// Print as a dense matrix
	void printSparse() const;				// Print as (row, col, val) triples
	bool add(int r, int c, ListValueType v);			// Insert/replace element; false: (r,c) out of range
	bool del(int r, int c);								// Remove element; true: removed, false: not found
	bool Addition(const OrthoList<ListValueType> &b);	// *this += b; false: sizes differ
	Matrix<ListValueType> toFull() const;				// Convert to a dense Matrix
	OrthoList<ListValueType>& operator= (const OrthoList<ListValueType> &b);	// Deep-copy assignment

	// Matrix product a * b, declared as a friend so it is a non-member binary operator.
	// On a size mismatch it prints an error and returns an empty 0 x 0 matrix.
	// Each result entry is computed by a merge-style walk over row i of `a` and column j
	// of `b`; this assumes row lists are sorted by column and column lists by row.
	// An entry is stored whenever at least one index pair matched, even if the sum is 0.
	friend OrthoList<ListValueType> operator*
	(const OrthoList<ListValueType> &a, const OrthoList<ListValueType> &b)
	{
		int ra = a.getRowNumber();
		int ca = a.getColumnNumber();
		int rb = b.getRowNumber();
		int cb = b.getColumnNumber();
		if (ca != rb)									// Inner dimensions must agree
		{
			cout << "Error: matrices sizes do not match!" << endl;
			OrthoList<ListValueType> nullItem;			// Default-constructed: empty matrix
			return nullItem;
		}
		OrthoList<ListValueType> ans(ra, cb);
		Node<ListValueType> **aRHead = a.getRHead();
		Node<ListValueType> **bCHead = b.getCHead();
		for (int i = 0; i < ra; i++)
		{
			if (aRHead[i] == NULL)						// Empty row of a: row i of the product is empty
			{
				continue;
			}
			for (int j = 0; j < cb; j++)
			{
				if (bCHead[j] == NULL)					// Empty column of b
				{
					continue;
				}
				Node<ListValueType> *p = aRHead[i]->right;	// Walks row i of a
				Node<ListValueType> *q = bCHead[j]->down;	// Walks column j of b
				ListValueType compute = ListValueType();	// Value-initialised: 0 for arithmetic types
				bool has_value = false;
				while (p && q)
				{
					if (p->col < q->row)				// p is behind: advance p
					{
						p = p->right;
					}
					else if (p->col > q->row)			// q is behind: advance q
					{
						q = q->down;
					}
					else								// Matching inner index k
					{
						has_value = true;
						compute += p->val * q->val;		// ans[i,j] += a[i,k] * b[k,j]
						p = p->right;
						q = q->down;
					}
				}
				if (has_value)
				{
					ans.add(i, j, compute);
				}
			}
		}
		return ans;
	}
};

// Build the cross-linked list from a dense matrix given as a vector of rows.
// Every element that compares != 0 becomes a node. The column count is taken from
// mat[0]; an empty `mat` now yields a 0 x 0 matrix (previously mat.at(0) threw).
// Each row must have at least nCol elements (mat.at() throws std::out_of_range
// otherwise); extra elements in longer rows are ignored.
// The scan goes top-to-bottom and left-to-right, so every row list ends up sorted by
// column and every column list sorted by row.
template <class ListValueType>
OrthoList <ListValueType>::OrthoList(vector< vector<ListValueType> > mat)
	:rHead(NULL), cHead(NULL), nRow(mat.size()),
	nCol(mat.empty() ? 0 : mat.at(0).size()), nElement(0)
{
	rHead = new Node<ListValueType>*[nRow]();	// () value-initialises every head to NULL
	cHead = new Node<ListValueType>*[nCol]();
	// Last node appended to each column so far (NULL while the column is still empty),
	// so each append is O(1) instead of a walk to the end of the column.
	vector< Node<ListValueType>* > colTail(nCol);
	for (int i = 0; i < nRow; i++)
	{
		Node<ListValueType> *rowTail = NULL;	// Last node appended to row i
		for (int j = 0; j < nCol; j++)
		{
			if (mat.at(i).at(j) != 0)
			{
				Node<ListValueType> *aNode = new Node<ListValueType>();
				aNode->row = i;
				aNode->col = j;
				aNode->val = mat.at(i).at(j);
				aNode->right = NULL;
				aNode->down = NULL;
				if (rowTail == NULL)			// First element of row i: create its sentinel head
				{
					rHead[i] = new Node<ListValueType>();
					rowTail = rHead[i];
				}
				rowTail->right = aNode;
				rowTail = aNode;
				if (colTail[j] == NULL)			// First element of column j: create its sentinel head
				{
					cHead[j] = new Node<ListValueType>();
					colTail[j] = cHead[j];
				}
				colTail[j]->down = aNode;
				colTail[j] = aNode;
				nElement++;
			}
		}
	}
}

// Create an empty row x col matrix: both head arrays are allocated but every row and
// column list is empty (head pointer NULL).
template <class ListValueType>
OrthoList <ListValueType>::OrthoList(int row, int col): nRow(row), nCol(col), nElement(0)
{
	// The trailing () value-initialises each pointer in the new arrays to NULL.
	rHead = new Node<ListValueType>*[nRow]();
	cHead = new Node<ListValueType>*[nCol]();
}

// Empty 0 x 0 matrix: no head arrays are allocated at all.
template <class ListValueType>
OrthoList <ListValueType>::OrthoList(void)
	: rHead(NULL), cHead(NULL), nRow(0), nCol(0), nElement(0)
{
}

// Deep-copy constructor: rebuild b's rows and columns with freshly allocated nodes.
// Rows are copied in index order and nodes are appended through tail pointers, so
// each copied node costs O(1) and the row/column order of b is preserved.
// nElement is recounted from the nodes actually copied rather than trusted from b.
template <class ListValueType>
OrthoList<ListValueType>::OrthoList(const OrthoList<ListValueType> &b):
	rHead(NULL), cHead(NULL), nRow(b.getRowNumber()), nCol(b.getColumnNumber()), nElement(0)
{
	rHead = new Node<ListValueType>*[nRow]();	// () value-initialises every head to NULL
	cHead = new Node<ListValueType>*[nCol]();
	vector< Node<ListValueType>* > colTail(nCol);	// Last node appended to each column
	Node<ListValueType> **brHead = b.getRHead();
	for (int i = 0; i < nRow; i++)
	{
		if (brHead[i] == NULL)					// Row i of b is empty
		{
			continue;
		}
		Node<ListValueType> *rowTail = NULL;	// Last node appended to row i
		for (Node<ListValueType> *bp = brHead[i]->right; bp != NULL; bp = bp->right)
		{
			Node<ListValueType> *aNode = new Node<ListValueType>();
			aNode->row = bp->row;
			aNode->col = bp->col;
			aNode->val = bp->val;
			aNode->right = NULL;
			aNode->down = NULL;
			if (rowTail == NULL)				// First node of this row: create its sentinel head
			{
				rHead[i] = new Node<ListValueType>();
				rowTail = rHead[i];
			}
			rowTail->right = aNode;
			rowTail = aNode;
			if (colTail[bp->col] == NULL)		// First node of this column: create its sentinel head
			{
				cHead[bp->col] = new Node<ListValueType>();
				colTail[bp->col] = cHead[bp->col];
			}
			colTail[bp->col]->down = aNode;
			colTail[bp->col] = aNode;
			nElement++;
		}
	}
}

// Release every node the matrix owns.
// 1. Each row list (sentinel head plus element nodes) is walked through `right` and deleted.
// 2. The column sentinel heads are deleted. They are separate allocations that no row
//    list reaches; previously they leaked.
// 3. The two head arrays are released with delete[], matching their new[] allocation;
//    the old plain delete was undefined behaviour.
template <class ListValueType>
OrthoList <ListValueType>::~OrthoList()
{
	for (int i = 0; i < nRow; i++)
	{
		Node<ListValueType> *p = rHead[i];
		while (p != NULL)
		{
			Node<ListValueType> *next = p->right;
			delete p;
			p = next;
		}
	}
	for (int j = 0; j < nCol; j++)
	{
		delete cHead[j];	// Sentinel only: the column's element nodes were freed with the rows
	}
	delete[] rHead;
	delete[] cHead;
	rHead = NULL;
	cHead = NULL;
}

// Copy assignment: make *this a deep copy of b (copy-and-swap).
// The copy constructor builds the new lists in a temporary. Its heads and sizes are
// exchanged with *this, and the temporary's destructor then frees what *this held
// before. The previous hand-written copy leaked the old lists, broke on
// self-assignment, and tested cHead[j] with j fixed at 0. That could dereference a
// NULL column head or overwrite an existing one.
template <class ListValueType>
OrthoList<ListValueType>& OrthoList<ListValueType>::operator= (const OrthoList<ListValueType> &b)
{
	if (this == &b)
	{
		return *this;
	}
	OrthoList<ListValueType> copy(b);
	Node<ListValueType> **tmpHead = rHead;
	rHead = copy.rHead;
	copy.rHead = tmpHead;
	tmpHead = cHead;
	cHead = copy.cHead;
	copy.cHead = tmpHead;
	int tmp = nRow;
	nRow = copy.nRow;
	copy.nRow = tmp;
	tmp = nCol;
	nCol = copy.nCol;
	copy.nCol = tmp;
	tmp = nElement;
	nElement = copy.nElement;
	copy.nElement = tmp;
	return *this;
}

// Row head pointer array itself (not a copy); entry i is NULL for an empty row.
template <class ListValueType>
Node<ListValueType> **OrthoList <ListValueType>::getRHead() const
{
	return this->rHead;
}

// Column head pointer array itself (not a copy); entry j is NULL for an empty column.
template <class ListValueType>
Node<ListValueType> **OrthoList <ListValueType>::getCHead() const
{
	return this->cHead;
}

// Number of rows of the matrix.
template <class ListValueType>
int OrthoList <ListValueType>::getRowNumber() const
{
	return this->nRow;
}

// Number of columns of the matrix.
template <class ListValueType>
int OrthoList <ListValueType>::getColumnNumber() const
{
	return this->nCol;
}

// Stored element counter (nElement) as maintained by the constructors and mutators.
template <class ListValueType>
int OrthoList <ListValueType>::getElementNumber() const
{
	return this->nElement;
}

// Print the matrix in dense form: one line per row, elements separated by a space.
// Positions without a node print ListValueType() (0 for arithmetic types).
// The row buffer now has the element type. It used to be vector<int>, which truncated
// floating-point values and did not compile for Complex elements.
template <class ListValueType>
void OrthoList <ListValueType>::printFull() const
{
	for (int i = 0; i < nRow; i++)
	{
		vector<ListValueType> line(nCol, ListValueType());	// One dense row, default-filled
		if (rHead[i] != NULL)
		{
			for (Node<ListValueType> *p = rHead[i]->right; p != NULL; p = p->right)
			{
				line.at(p->col) = p->val;
			}
		}
		for (int j = 0; j < nCol; j++)
		{
			cout << line.at(j) << " ";
		}
		cout << endl;
	}
}

// Print each stored element as a "( row, col, val )" line, visiting rows in index
// order and each row list from its head onward.
template <class ListValueType>
void OrthoList <ListValueType>::printSparse() const
{
	for (int row = 0; row < nRow; row++)
	{
		if (rHead[row] == NULL)		// Nothing stored in this row
		{
			continue;
		}
		for (Node<ListValueType> *cur = rHead[row]->right; cur != NULL; cur = cur->right)
		{
			cout << "( " << cur->row << ", " << cur->col << ", " << cur->val << " )" << endl;
		}
	}
}

// Store value v at (r, c). If (r, c) already holds a node, its value is replaced;
// otherwise a new node is linked into row r and column c at its sorted position and
// nElement is incremented.
// Row lists stay sorted by column and column lists by row, as operator* requires.
// The previous version appended out of order when a row or column head was missing,
// and did not count the new element in that case.
// Returns false, leaving the matrix unchanged, when (r, c) lies outside the matrix.
template <class ListValueType>
bool OrthoList <ListValueType>::add(int r, int c, ListValueType v)
{
	if (r>=nRow || c>=nCol || r<0 || c<0)
	{
		return false;
	}
	// p: last node in row r whose column is < c (or the row's sentinel head).
	Node<ListValueType> *p = rHead[r];
	if (p == NULL)
	{
		p = rHead[r] = new Node<ListValueType>();	// Row was empty: create its sentinel head
	}
	else
	{
		while (p->right != NULL && p->right->col < c)
		{
			p = p->right;
		}
		if (p->right != NULL && p->right->col == c)	// Element already present: replace its value
		{
			p->right->val = v;
			return true;
		}
	}
	// q: last node in column c whose row is < r (or the column's sentinel head).
	Node<ListValueType> *q = cHead[c];
	if (q == NULL)
	{
		q = cHead[c] = new Node<ListValueType>();	// Column was empty: create its sentinel head
	}
	while (q->down != NULL && q->down->row < r)
	{
		q = q->down;
	}
	Node<ListValueType> *aNode = new Node<ListValueType>();
	aNode->row = r;
	aNode->col = c;
	aNode->val = v;
	aNode->right = p->right;
	aNode->down = q->down;
	p->right = aNode;
	q->down = aNode;
	nElement++;
	return true;
}

// Add the single element held by aNode (row, col, val) into this matrix.
// - If this(row, col) already has a node, aNode->val is added to it and aNode is deleted.
// - Otherwise aNode itself is linked into its row and column at the sorted position,
//   and nElement is incremented. The previous version never counted new elements.
// This function always takes ownership of aNode: the node is either linked in or
// deleted. That includes the out-of-range path, where it used to leak.
// Returns false when (row, col) lies outside the matrix. Helper for Addition.
template <class ListValueType>
bool OrthoList <ListValueType>::elementWiseAddition(Node<ListValueType> *aNode)
{
	int r = aNode->row;
	int c = aNode->col;
	if (r>=nRow || c>=nCol || r<0 || c<0)
	{
		delete aNode;
		return false;
	}
	// p: last node in row r whose column is < c (or the row's sentinel head).
	Node<ListValueType> *p = rHead[r];
	if (p == NULL)
	{
		p = rHead[r] = new Node<ListValueType>();	// Row was empty: create its sentinel head
	}
	else
	{
		while (p->right != NULL && p->right->col < c)
		{
			p = p->right;
		}
		if (p->right != NULL && p->right->col == c)	// Element present: accumulate
		{
			p->right->val += aNode->val;
			delete aNode;
			return true;
		}
	}
	// q: last node in column c whose row is < r (or the column's sentinel head).
	Node<ListValueType> *q = cHead[c];
	if (q == NULL)
	{
		q = cHead[c] = new Node<ListValueType>();	// Column was empty: create its sentinel head
	}
	while (q->down != NULL && q->down->row < r)
	{
		q = q->down;
	}
	aNode->right = p->right;
	aNode->down = q->down;
	p->right = aNode;
	q->down = aNode;
	nElement++;
	return true;
}

// Remove the element at (r, c).
// Returns true if a node was found and removed (nElement is decremented). Returns false
// when (r, c) is out of range, row r is empty, or row r has no node in column c.
// When the removed node was the only one in its row or column, that list's sentinel
// head is freed and the head pointer reset to NULL.
// Assumes the node is present in column c's list whenever it is in row r's list.
template <class ListValueType>
bool OrthoList <ListValueType>::del(int r, int c)
{
	if (r < 0 || c < 0 || r >= nRow || c >= nCol || rHead[r] == NULL)
	{
		return false;
	}
	// Find the node just before column c in row r.
	Node<ListValueType> *rowPrev = rHead[r];
	while (rowPrev->right != NULL && rowPrev->right->col != c)
	{
		rowPrev = rowPrev->right;
	}
	if (rowPrev->right == NULL)									// No element in column c
	{
		return false;
	}
	Node<ListValueType> *victim = rowPrev->right;
	// Find the node just before row r in column c.
	Node<ListValueType> *colPrev = cHead[c];
	while (colPrev->down->row != r)
	{
		colPrev = colPrev->down;
	}
	// Unlink from the column, dropping the column head if it becomes empty.
	if (colPrev == cHead[c] && colPrev->down->down == NULL)
	{
		delete cHead[c];
		cHead[c] = NULL;
	}
	else
	{
		colPrev->down = colPrev->down->down;
	}
	// Unlink from the row and free the node, dropping the row head if it becomes empty.
	if (rowPrev == rHead[r] && victim->right == NULL)
	{
		delete victim;
		delete rHead[r];
		rHead[r] = NULL;
	}
	else
	{
		rowPrev->right = victim->right;
		delete victim;
	}
	nElement--;
	return true;
}

// In-place matrix addition: *this += b.
// Returns false without modifying anything when the dimensions differ. Otherwise every
// element of b is copied into a fresh node and merged in via elementWiseAddition.
template <class ListValueType>
bool OrthoList<ListValueType>::Addition(const OrthoList<ListValueType> &b)
{
	if (b.getRowNumber() != nRow || b.getColumnNumber() != nCol)
	{
		return false;
	}
	Node<ListValueType> **otherRows = b.getRHead();
	for (int row = 0; row < nRow; row++)
	{
		if (otherRows[row] == NULL)				// Row of b is empty
		{
			continue;
		}
		for (Node<ListValueType> *src = otherRows[row]->right; src != NULL; src = src->right)
		{
			Node<ListValueType> *copy = new Node<ListValueType>();
			copy->row = src->row;
			copy->col = src->col;
			copy->val = src->val;
			if (!elementWiseAddition(copy))		// Ownership of `copy` passes to the helper
			{
				return false;
			}
		}
	}
	return true;
}

// Convert to a dense Matrix of the same size; every stored node is written to
// mat(row, col), and the remaining entries keep whatever Matrix's constructor set.
template <class ListValueType>
Matrix<ListValueType> OrthoList <ListValueType>::toFull() const
{
	Matrix<ListValueType> dense(nRow, nCol);
	for (int row = 0; row < nRow; row++)
	{
		if (rHead[row] == NULL)
		{
			continue;
		}
		for (Node<ListValueType> *cur = rHead[row]->right; cur != NULL; cur = cur->right)
		{
			dense(cur->row, cur->col) = cur->val;
		}
	}
	return dense;
}


#endif
           

matrix.h

/*
* 稠密矩陣類
* 基于一維數組的實作
* 用友元函數重載+/-/*
* 20180401 first edition
*/
#ifndef _MATRIX_H
#define _MATRIX_H
#include <iostream>
#include <cassert>
using namespace std;
// Dense m x n matrix stored column-major in one heap array: element (i, j) lives at
// data[i + j*m]. The class owns `data` and deep-copies it on copy and assignment.
template <class T>
class Matrix {
    public:
        int m, n;       // number of rows, number of columns
    private:
        T* data;        // m*n elements, or NULL for an empty matrix
    public:
        // Construct an r x c matrix with every element value-initialised (0 for
        // arithmetic types; default-constructed for class types, where the previous
        // memset was undefined behaviour). Non-positive sizes give data == NULL.
        // new[] reports failure by throwing std::bad_alloc, so there is no NULL check.
        Matrix(int r=0, int c=0): m(r), n(c), data(NULL) {
            if (r>0 && c>0)
                data = new T[m*n]();
        }
        // Deep-copy constructor.
        Matrix(const Matrix<T> &a): m(a.m), n(a.n), data(NULL) {
            if (m>0 && n>0) {
                data = new T[m*n];
                const T* src = a.getData();
                for (int k = 0; k < m*n; k++)
                    data[k] = src[k];
            }
        }
        ~Matrix() {
            m = n = 0;
            delete[] data;
            data = NULL;
        }
        // Raw pointer to the column-major element array (may be NULL).
        T* getData() const {
            return data;
        }
        // Element access (i = row, j = column), bounds-checked with assert.
        T& operator() (int i, int j) const {
            assert(i>=0 && i<m && j>=0 && j<n);
            return data[i+j*m];
        }
        // Deep-copy assignment. It is safe on self-assignment: previously the data was
        // freed before being read back. It allocates exactly m*n elements; previously it
        // allocated m*n*sizeof(T). The new buffer is filled before the old one is released.
        Matrix<T>& operator= (const Matrix<T> &a) {
            if (this != &a) {
                const int count = (a.m > 0 && a.n > 0) ? a.m * a.n : 0;
                T* fresh = count > 0 ? new T[count] : NULL;
                for (int k = 0; k < count; k++)
                    fresh[k] = a.data[k];
                delete[] data;
                data = fresh;
                m = a.m;
                n = a.n;
            }
            return *this;
        }
        // Print "(m, n):" followed by the rows, elements separated by tabs.
        void display() const {
            std::cout<<"("<<m<<", "<<n<<"):"<<std::endl;
            for (int i=0; i<m; i++) {
                for (int j=0; j<n; j++)
                    std::cout<<data[i+j*m]<<"\t";
                std::cout<<std::endl;
            }
        }
        // Element-wise sum; sizes must match (asserted).
        friend Matrix<T> operator+ (const Matrix<T>& a, const Matrix<T>& b)
        {
            assert(a.m==b.m && a.n==b.n);
            Matrix<T> r(a.m, a.n);
            for (int i=0; i<a.m; i++)
                for (int j=0; j<a.n; j++)
                    r(i, j) = a(i, j) + b(i, j);
            return r;
        }
        // Element-wise difference; sizes must match (asserted).
        friend Matrix<T> operator- (const Matrix<T>& a, const Matrix<T>& b)
        {
            assert(a.m==b.m && a.n==b.n);
            Matrix<T> r(a.m, a.n);
            for (int i=0; i<a.m; i++)
                for (int j=0; j<a.n; j++)
                    r(i, j) = a(i, j) - b(i, j);
            return r;
        }
        // Matrix product; inner dimensions must agree (asserted).
        friend Matrix<T> operator* (const Matrix<T>& a, const Matrix<T>& b)
        {
            assert(a.n==b.m);
            Matrix<T> r(a.m, b.n);
            for (int i=0; i<a.m; i++)
                for (int j=0; j<b.n; j++)
                    for (int k=0; k<a.n; k++)
                        r(i, j) += a(i, k) * b(k, j);
            return r;
        }

};

#endif
           

complex.h

/*
* 複數類
* 用友元函數重載複數的+/-/*運算
* 20180401 first edition
*/
#ifndef _COMPLEX_H
#define _COMPLEX_H
#include <iostream>
using namespace std;
// Complex number real + imag*i with +, -, *, += and stream output.
class Complex {
public:
	// Both parts default to 0; the constructor is implicit, so a double/int converts to Complex.
	Complex(double r = 0.0, double i = 0.0) : real(r), imag(i) { }
	friend Complex operator+(const Complex &c1, const Complex &c2);
	friend Complex operator-(const Complex &c1, const Complex &c2);
	friend Complex operator*(const Complex &c1, const Complex &c2);
	Complex & operator+=(const Complex &c);
	friend std::ostream & operator<<(std::ostream &out, const Complex &c);
private:
	double real; // real part
	double imag; // imaginary part
};

// The out-of-class definitions below are `inline` because they live in a header:
// without it, including the header from two translation units defines them twice (ODR).
inline Complex operator+(const Complex &c1, const Complex &c2){
	return Complex(c1.real+c2.real, c1.imag+c2.imag);
}
inline Complex operator-(const Complex &c1, const Complex &c2){
	return Complex(c1.real-c2.real, c1.imag-c2.imag);
}
// (a + bi)(c + di) = (ac - bd) + (bc + ad)i
inline Complex operator*(const Complex &c1, const Complex &c2){
	return Complex(c1.real*c2.real-c1.imag*c2.imag, c1.imag*c2.real+c1.real*c2.imag);
}
// Add in place instead of building a temporary sum and assigning it back.
inline Complex & Complex::operator+=(const Complex &c)
{
	real += c.real;
	imag += c.imag;
	return *this;
}
// Writes "(real, imag)".
inline std::ostream & operator<<(std::ostream &out, const Complex &c){
	out << "(" << c.real << ", " << c.imag << ")";
	return out;
}

#endif
           

ComplexOrthoList.h

/*
* 複數稀疏矩陣類
* OrthoList的派生類
* 增加了從檔案讀入的構造函數
*/
#ifndef _COMPLEXORTHOLIST_H
#define _COMPLEXORTHOLIST_H
#include "OrthoList.h"
#include "matrix.h"
#include<fstream>
#include "complex.h"

// Sparse complex matrix: an OrthoList<Complex> that can also be read from a file.
class ComplexOrthoList : public OrthoList<Complex>
{
public:
	ComplexOrthoList(ifstream &fin);				// Read the matrix from a text file
	ComplexOrthoList(void);							// Empty 0 x 0 matrix
	ComplexOrthoList(const ComplexOrthoList &b);	// Deep-copy constructor
	// Empty row x col matrix (no stored elements). Needed by operator* below, which
	// previously used the 0 x 0 default constructor, so every add() failed and the
	// product was always empty.
	ComplexOrthoList(int row, int col) : OrthoList<Complex>(row, col) {}
	// A derived class does not inherit the base's operator=, so it is declared here.
	ComplexOrthoList & operator =(const ComplexOrthoList &b);
	// Matrix product. Redefined here because the base class's friend is not a friend of
	// this class and returns the base type.
	// On a size mismatch it prints an error and returns an empty matrix.
	// It uses a merge-style walk, so it assumes row lists are sorted by column and
	// column lists by row.
	friend ComplexOrthoList operator *(const ComplexOrthoList &a, const ComplexOrthoList &b)
	{
		int ra = a.getRowNumber();
		int ca = a.getColumnNumber();
		int rb = b.getRowNumber();
		int cb = b.getColumnNumber();
		if (ca != rb)									// Inner dimensions must agree
		{
			cout << "Error: matrices sizes do not match!" << endl;
			ComplexOrthoList nullItem;					// Default-constructed: empty matrix
			return nullItem;
		}
		ComplexOrthoList ans(ra, cb);					// Result sized ra x cb
		Node<Complex> **aRHead = a.getRHead();
		Node<Complex> **bCHead = b.getCHead();
		for (int i = 0; i < ra; i++)
		{
			if (aRHead[i] == NULL)						// Empty row of a
			{
				continue;
			}
			for (int j = 0; j < cb; j++)
			{
				if (bCHead[j] == NULL)					// Empty column of b
				{
					continue;
				}
				Node<Complex> *p = aRHead[i]->right;	// Walks row i of a
				Node<Complex> *q = bCHead[j]->down;		// Walks column j of b
				Complex compute = 0;
				bool has_value = false;
				while (p && q)
				{
					if (p->col < q->row)				// p is behind: advance p
					{
						p = p->right;
					}
					else if (p->col > q->row)			// q is behind: advance q
					{
						q = q->down;
					}
					else								// Matching inner index k
					{
						has_value = true;
						compute += p->val * q->val;		// ans[i,j] += a[i,k] * b[k,j]
						p = p->right;
						q = q->down;
					}
				}
				if (has_value)
				{
					ans.add(i, j, compute);
				}
			}
		}
		return ans;
	}
};

// Read a sparse complex matrix from a text stream.
// Format: the first two integers are the row and column counts. They are followed by
// records of four numbers "i j real imag", where i and j are 1-based coordinates.
// Changes from the previous version:
// - A missing or negative header gives an empty 0 x 0 matrix.
// - Records whose coordinates fall outside the matrix are skipped with a warning on
//   cerr; they used to write out of bounds.
// - Records are inserted at their sorted position, so rows stay sorted by column and
//   columns by row even for unsorted files. operator* relies on that order.
// - A repeated coordinate replaces the earlier value, like add().
// `inline` because this definition lives in a header (ODR).
inline ComplexOrthoList::ComplexOrthoList(ifstream &fin)
{
	if (!(fin >> nRow >> nCol) || nRow < 0 || nCol < 0)
	{
		nRow = 0;
		nCol = 0;
	}
	nElement = 0;
	rHead = new Node<Complex>*[nRow]();	// () value-initialises every head to NULL
	cHead = new Node<Complex>*[nCol]();
	int i = 0, j = 0;
	double Real = 0, Imag = 0;
	while (fin >> i >> j >> Real >> Imag)
	{
		i--; j--;						// Convert to 0-based indices
		if (i < 0 || i >= nRow || j < 0 || j >= nCol)
		{
			cerr << "Warning: entry (" << i + 1 << ", " << j + 1
				<< ") is outside the matrix, skipped" << endl;
			continue;
		}
		// p: last node in row i whose column is < j (or the row's sentinel head).
		Node<Complex> *p = rHead[i];
		if (p == NULL)
		{
			p = rHead[i] = new Node<Complex>();
		}
		else
		{
			while (p->right != NULL && p->right->col < j)
			{
				p = p->right;
			}
			if (p->right != NULL && p->right->col == j)	// Repeated coordinate: replace
			{
				p->right->val = Complex(Real, Imag);
				continue;
			}
		}
		// q: last node in column j whose row is < i (or the column's sentinel head).
		Node<Complex> *q = cHead[j];
		if (q == NULL)
		{
			q = cHead[j] = new Node<Complex>();
		}
		while (q->down != NULL && q->down->row < i)
		{
			q = q->down;
		}
		Node<Complex> *aNode = new Node<Complex>();
		aNode->row = i;
		aNode->col = j;
		aNode->val = Complex(Real, Imag);
		aNode->right = p->right;
		aNode->down = q->down;
		p->right = aNode;
		q->down = aNode;
		nElement++;
	}
}

// Empty 0 x 0 matrix via the base default constructor. `inline` because this
// definition lives in a header: otherwise two including translation units would each
// define it (ODR / multiple-definition link error).
inline ComplexOrthoList::ComplexOrthoList(void):OrthoList<Complex>(){}

// Deep copy via the base copy constructor. `inline` because this definition lives in a
// header (ODR / multiple-definition link error otherwise).
inline ComplexOrthoList::ComplexOrthoList(const ComplexOrthoList &b):OrthoList<Complex>(b){}

// Copy assignment: make *this a deep copy of b (copy-and-swap).
// The copy constructor builds the new lists in a temporary. Its heads and sizes are
// exchanged with *this, and the temporary's destructor then frees what *this held
// before. The previous hand-written copy leaked the old lists, broke on
// self-assignment, and tested cHead[j] with j fixed at 0. That could dereference a
// NULL column head or overwrite an existing one.
// `inline` because this definition lives in a header (ODR).
inline ComplexOrthoList &ComplexOrthoList::operator =(const ComplexOrthoList &b)
{
	if (this == &b)
	{
		return *this;
	}
	ComplexOrthoList copy(b);
	Node<Complex> **tmpHead = rHead;
	rHead = copy.rHead;
	copy.rHead = tmpHead;
	tmpHead = cHead;
	cHead = copy.cHead;
	copy.cHead = tmpHead;
	int tmp = nRow;
	nRow = copy.nRow;
	copy.nRow = tmp;
	tmp = nCol;
	nCol = copy.nCol;
	copy.nCol = tmp;
	tmp = nElement;
	nElement = copy.nElement;
	copy.nElement = tmp;
	return *this;
}

#endif
           

test.cpp

#include "OrthoList.h"
#include "matrix.h"
#include "complex.h"
#include "ComplexOrthoList.h"
#include<ctime>

const int TEST_TIME = 1;							// 重複測試取平均的次數

int main()
{
	vector< vector<int> > mat, mat1;
	vector<int> aLine;
	// 第一行:0,1,0
	aLine.push_back(0);
	aLine.push_back(1);
	aLine.push_back(0);
	mat.push_back(aLine);
	aLine.clear();
	// 第二行:2,0,0
	aLine.push_back(2);
	aLine.push_back(0);
	aLine.push_back(0);
	mat.push_back(aLine);
	aLine.clear();
	// 第三行:0,0,3
	aLine.push_back(0);
	aLine.push_back(0);
	aLine.push_back(3);
	mat.push_back(aLine);
	
	cout << "**********************" << endl;
	cout << "****OrthoList test****" << endl;
	cout << "**********************" << endl;
	OrthoList<int> testOList(mat);
	testOList.printSparse();
	testOList.printFull();
	cout << "Add elements" << endl;
	testOList.add(0,0,-1);
	testOList.add(0,1,-2);
	testOList.add(0,2,-3);
	testOList.add(1,0,-4);
	testOList.add(1,1,-5);
	testOList.add(1,2,-6);
	testOList.add(2,0,-7);
	testOList.add(2,1,-8);
	testOList.add(2,2,-9);
	testOList.printSparse();
	testOList.printFull();
	cout << "Delete elements" << endl;
	testOList.del(0,0);
	testOList.del(1,1);
	testOList.del(1,0);
	testOList.printSparse();
	testOList.printFull();
	OrthoList<int> oList1(testOList);
	cout << " New sparse matrix: " << endl;
	oList1.printFull();
	cout << " matrix1 + matrix2: " << endl;
	testOList.Addition(oList1);
	testOList.printSparse();
	testOList.printFull();
	cout << "matrix1 * martrix2: " << endl;
	OrthoList<int> oList2(oList1);
	oList2 = testOList * oList1;
	oList2.printFull();

	cout << "**********************" << endl;
	cout << "*OrthoList test done**" << endl;
	cout << "**********************" << endl;
	cout << endl;
	cout << "----------------------" << endl;
	cout << "-----matrix test------" << endl;
	cout << "----------------------" << endl;
	int m = 2, n = 2;
    Matrix<int> a = oList1.toFull();
    Matrix<int> b = testOList.toFull();
    cout<<"Matrix a"; a.display(); cout<<endl;
    cout<<"Matrix b"; b.display(); cout<<endl;
    Matrix<int> c;
	c = (a+b);
    cout<<"a+b: "; c.display(); cout<<endl;
    Matrix<int> d;
	d = (a*b);
    cout<<"a*b:"; d.display(); 
	cout << "----------------------" << endl;
	cout << "---matrix test done---" << endl;
	cout << "----------------------" << endl << endl;

	cout << "@@@@@@@@@@@@@@@@@@@@@@" << endl;
	cout << "@@@ run time test @@@@" << endl;
	cout << "@@@@@@@@@@@@@@@@@@@@@@" << endl;
	ifstream fin("Real.txt");					// 資料檔案
	assert(fin);								// 檔案是否正常打開
	ComplexOrthoList sY(fin);					// sY: sparse Y matrix
	fin.close();
	fin.open("Real_2.txt");
	assert(fin);
	ComplexOrthoList sY1(fin);
	fin.close();
	ComplexOrthoList sY2(sY);
	Matrix<Complex> dY = sY.toFull();			// dY: dense Y matrix
	Matrix<Complex> dY1 = sY1.toFull();
	Matrix<Complex> dY2(dY);
	int i = 0;
	clock_t t0;
	// 測試矩陣加法
	t0 = clock();
	for (i=0; i<TEST_TIME; i++)
	{
		sY.Addition(sY1);
	}
	clock_t run_sparse_add = clock()-t0;
	run_sparse_add /= TEST_TIME;
	cout << TEST_TIME << " times SPARSE MATRIX addition: " 
		<< run_sparse_add << " ms" << endl;
	t0 = clock();
	for (i=0; i<TEST_TIME; i++)
	{
		dY2 = dY + dY1;
	}
	clock_t run_dense_add = clock()-t0;
	run_dense_add /= TEST_TIME;
	cout << TEST_TIME << " times DENSE MATRIX addition: " 
		<< run_dense_add << " ms" << endl;
	// 測試矩陣乘法
	t0 = clock();
	for (i=0; i<TEST_TIME; i++)
	{
		sY2 = sY * sY1;
	}
	clock_t run_sparse_mul = clock()-t0;
	run_sparse_mul /= TEST_TIME;
	cout << TEST_TIME << " times SPARSE MATRIX multiplication: " 
		<< run_sparse_mul << " ms" << endl;
	t0 = clock();
	for (i=0; i<TEST_TIME; i++)
	{
		dY2 = dY * dY1;
	}
	clock_t run_dense_mul = clock()-t0;
	run_dense_mul /= TEST_TIME;
	cout << TEST_TIME << " times DENSE MATRIX multipication: " 
		<< run_dense_mul << " ms" << endl;
	return 0;
}
           

繼續閱讀