項目代碼連結:https://github.com/weiyx15/SparseMatrix/tree/master/SparseMatrix
第一次用C++寫類模板,過程十分艱辛,代碼十分冗雜。寫幾個注意點吧:
1. 重載雙目運算符用友元函數
2. 父類的友元不是子類的友元
3. 子類不能繼承父類重載的=運算符
4. 有動態記憶體分配的類要自己實作拷貝構造函數(深拷貝)、析構函數和重載=運算符
----------------------------------------
代碼
OrthoList.h
/*
* 稀疏矩陣類
* 基於十字鏈表的實作
* 用成員函數實作+, 用友元函數重載*
* 20180401 first edition
*/
#ifndef _ORTHOLIST_H
#define _ORTHOLIST_H
#include<iostream>
#include<vector>
#include "matrix.h"
using namespace std;
template <class NodeValueType> // NodeValueType: type of the value held by a node
class Node { // One stored entry (or a row/column head sentinel) of the orthogonal list
public:
    // 0-based position of the entry inside the matrix.
    int row;
    int col;
    // The stored value.
    NodeValueType val;
    // Next node to the right in the same row / next node below in the same column.
    Node *right;
    Node *down;
};
template <class ListValueType> // ListValueType: type of the stored matrix values
class OrthoList { // Sparse matrix stored as an orthogonal (cross) linked list
public:
    OrthoList(vector< vector<ListValueType> > mat); // build from a dense row-major matrix
    OrthoList(int row, int col);                    // empty row x col matrix
    OrthoList(void);                                // empty 0 x 0 matrix
    OrthoList(const OrthoList<ListValueType> &b);   // deep copy (the class owns heap nodes)
    virtual ~OrthoList();                           // virtual: the class is used as a base
protected:
    // rHead[i] / cHead[j]: head sentinel node of row i / column j, or NULL when
    // nothing is stored there. Entries hang off sentinel->right (rows) and
    // sentinel->down (columns).
    Node<ListValueType> **rHead, **cHead;
    int nRow, nCol, nElement; // rows, columns, number of stored entries
private:
    bool elementWiseAddition(Node<ListValueType> *aNode); // per-entry helper used by Addition
public:
    Node<ListValueType> **getRHead() const; // row head array
    Node<ListValueType> **getCHead() const; // column head array
    int getRowNumber() const;               // number of rows
    int getColumnNumber() const;            // number of columns
    int getElementNumber() const;           // number of stored entries
    void printFull() const;                 // print as a dense matrix
    void printSparse() const;               // print as (row, col, val) triples
    bool add(int r, int c, ListValueType v); // insert/overwrite (r,c); false if out of range
    bool del(int r, int c);                  // true: removed; false: not found / out of range
    bool Addition(const OrthoList<ListValueType> &b); // this += b; false on size mismatch
    Matrix<ListValueType> toFull() const;    // convert to a dense Matrix
    OrthoList<ListValueType>& operator= (const OrthoList<ListValueType> &b); // deep-copy assignment

    // Matrix product a * b (a friend so both operands are treated symmetrically).
    // On a size mismatch an error is printed and an empty 0 x 0 matrix is returned.
    // Each result entry is computed by merging row i of a with column j of b,
    // which assumes row chains are sorted by column and column chains by row.
    friend OrthoList<ListValueType> operator*
    (const OrthoList<ListValueType> &a, const OrthoList<ListValueType> &b)
    {
        const int ra = a.getRowNumber();
        const int ca = a.getColumnNumber();
        const int rb = b.getRowNumber();
        const int cb = b.getColumnNumber();
        if (ca != rb) // inner dimensions must agree
        {
            cout << "Error: matrices sizes do not match!" << endl;
            OrthoList<ListValueType> nullItem; // default (empty) object
            return nullItem;
        }
        OrthoList<ListValueType> ans(ra, cb);
        Node<ListValueType> **aRHead = a.getRHead();
        Node<ListValueType> **bCHead = b.getCHead();
        for (int i = 0; i < ra; i++)
        {
            if (aRHead[i] == NULL) // empty row of a: row i of the product is empty
            {
                continue;
            }
            for (int j = 0; j < cb; j++)
            {
                if (bCHead[j] == NULL) // empty column of b
                {
                    continue;
                }
                ListValueType compute = 0; // accumulates ans[i,j]
                bool has_value = false;    // true once any term contributed
                Node<ListValueType> *p = aRHead[i]->right;
                Node<ListValueType> *q = bCHead[j]->down;
                while (p && q)
                {
                    if (p->col < q->row)      // p lags behind: advance along the row
                    {
                        p = p->right;
                    }
                    else if (p->col > q->row) // q lags behind: advance down the column
                    {
                        q = q->down;
                    }
                    else                      // matching index k: A[i,k] * B[k,j]
                    {
                        has_value = true;
                        compute += p->val * q->val;
                        p = p->right;
                        q = q->down;
                    }
                }
                if (has_value)
                {
                    ans.add(i, j, compute); // (i, j) visited in increasing order
                }
            }
        }
        return ans;
    }
};
// Builds the orthogonal list from a dense row-major matrix; every entry that
// compares != 0 is stored. The column count is taken from mat[0]; an empty
// mat yields a 0 x 0 matrix. A row shorter than mat[0] makes at() throw
// std::out_of_range (as before).
template <class ListValueType>
OrthoList <ListValueType>::OrthoList(vector< vector<ListValueType> > mat)
:rHead(NULL), cHead(NULL), nRow(static_cast<int>(mat.size())),
nCol(mat.empty() ? 0 : static_cast<int>(mat.at(0).size())), nElement(0)
{
    // Trailing () value-initialises every head pointer to NULL.
    rHead = new Node<ListValueType>*[nRow]();
    cHead = new Node<ListValueType>*[nCol]();
    // Last node on each column chain, so each append is O(1) instead of a
    // walk from the head. Rows are filled one at a time, so a scalar suffices there.
    vector<Node<ListValueType>*> colTail(nCol);
    for (int i = 0; i < nRow; i++)
    {
        Node<ListValueType> *rowTail = NULL;
        for (int j = 0; j < nCol; j++)
        {
            if (mat.at(i).at(j) != 0) // store non-zero entries only
            {
                Node<ListValueType> *aNode = new Node<ListValueType>();
                aNode->row = i;
                aNode->col = j;
                aNode->val = mat.at(i).at(j);
                aNode->right = NULL;
                aNode->down = NULL;
                if (rowTail == NULL) // first entry of row i: create its sentinel
                {
                    rHead[i] = new Node<ListValueType>();
                    rowTail = rHead[i];
                }
                if (colTail[j] == NULL) // first entry of column j
                {
                    cHead[j] = new Node<ListValueType>();
                    colTail[j] = cHead[j];
                }
                // Row-major traversal keeps both chains sorted when appending.
                rowTail->right = aNode;
                rowTail = aNode;
                colTail[j]->down = aNode;
                colTail[j] = aNode;
                nElement++;
            }
        }
    }
}
// Empty row x col matrix: no entries, every row/column head is NULL.
template <class ListValueType>
OrthoList <ListValueType>::OrthoList(int row, int col): nRow(row), nCol(col), nElement(0)
{
    // The trailing () value-initialises each head pointer to NULL.
    rHead = new Node<ListValueType>*[nRow]();
    cHead = new Node<ListValueType>*[nCol]();
}
// Empty 0 x 0 matrix that owns no head arrays.
template <class ListValueType>
OrthoList <ListValueType>::OrthoList(void)
    : rHead(NULL), cHead(NULL), nRow(0), nCol(0), nElement(0)
{
}
// Deep copy: every sentinel and entry node of b is duplicated. The source's
// chain order is preserved by appending through per-row/per-column tail
// pointers (O(1) per node). nElement is recounted from the copied nodes.
template <class ListValueType>
OrthoList<ListValueType>::OrthoList(const OrthoList<ListValueType> &b):
rHead(NULL), cHead(NULL), nRow(b.getRowNumber()), nCol(b.getColumnNumber()), nElement(0)
{
    rHead = new Node<ListValueType>*[nRow](); // all NULL
    cHead = new Node<ListValueType>*[nCol]();
    Node<ListValueType> **brHead = b.getRHead();
    vector<Node<ListValueType>*> rowTail(nRow), colTail(nCol);
    for (int i = 0; i < nRow; i++)
    {
        if (brHead[i] == NULL) // empty row in the source
        {
            continue;
        }
        for (Node<ListValueType> *bp = brHead[i]->right; bp != NULL; bp = bp->right)
        {
            const int r = bp->row;
            const int c = bp->col;
            Node<ListValueType> *aNode = new Node<ListValueType>();
            aNode->row = r;
            aNode->col = c;
            aNode->val = bp->val;
            aNode->right = NULL;
            aNode->down = NULL;
            if (rowTail[r] == NULL)
            {
                rHead[r] = new Node<ListValueType>();
                rowTail[r] = rHead[r];
            }
            if (colTail[c] == NULL)
            {
                cHead[c] = new Node<ListValueType>();
                colTail[c] = cHead[c];
            }
            rowTail[r]->right = aNode;
            rowTail[r] = aNode;
            colTail[c]->down = aNode;
            colTail[c] = aNode;
            nElement++;
        }
    }
}
// Releases every node and both head arrays.
// Each entry node sits on exactly one row chain, so freeing the row chains
// (sentinel + entries) frees every entry once. Column sentinels are separate
// allocations on no row chain and must be freed on their own.
template <class ListValueType>
OrthoList <ListValueType>::~OrthoList()
{
    if (rHead != NULL)
    {
        for (int i = 0; i < nRow; i++)
        {
            Node<ListValueType> *p = rHead[i];
            while (p != NULL)
            {
                Node<ListValueType> *next = p->right;
                delete p;
                p = next;
            }
        }
    }
    if (cHead != NULL)
    {
        for (int j = 0; j < nCol; j++)
        {
            delete cHead[j]; // sentinel only; its entries were freed above
        }
    }
    // The head arrays come from new[], so they must be released with delete[].
    delete[] rHead;
    delete[] cHead;
    rHead = NULL;
    cHead = NULL;
}
// Deep-copy assignment using copy-and-swap. The copy constructor builds a
// private duplicate of b. Its contents are then exchanged with ours, and the
// temporary's destructor releases our previous nodes. This makes the
// operator safe for self-assignment and stops it from leaking the old lists.
template <class ListValueType>
OrthoList<ListValueType>& OrthoList<ListValueType>::operator= (const OrthoList<ListValueType> &b)
{
    if (this == &b)
    {
        return *this;
    }
    OrthoList<ListValueType> copy(b);
    Node<ListValueType> **heads = rHead;
    rHead = copy.rHead;
    copy.rHead = heads;
    heads = cHead;
    cHead = copy.cHead;
    copy.cHead = heads;
    int count = nRow;
    nRow = copy.nRow;
    copy.nRow = count;
    count = nCol;
    nCol = copy.nCol;
    copy.nCol = count;
    count = nElement;
    nElement = copy.nElement;
    copy.nElement = count;
    return *this; // `copy` now owns (and will free) our old lists
}
// Row head array (rHead[i] is row i's sentinel or NULL); ownership stays with this object.
template <class ListValueType>
Node<ListValueType> **OrthoList <ListValueType>::getRHead() const
{
    return this->rHead;
}
// Column head array (cHead[j] is column j's sentinel or NULL); ownership stays with this object.
template <class ListValueType>
Node<ListValueType> **OrthoList <ListValueType>::getCHead() const
{
    return this->cHead;
}
// Number of matrix rows.
template <class ListValueType>
int OrthoList <ListValueType>::getRowNumber() const
{
    return this->nRow;
}
// Number of matrix columns.
template <class ListValueType>
int OrthoList <ListValueType>::getColumnNumber() const
{
    return this->nCol;
}
// Number of stored entries, as tracked in nElement.
template <class ListValueType>
int OrthoList <ListValueType>::getElementNumber() const
{
    return this->nElement;
}
// Prints the matrix in dense form: one line per row, entries separated by a
// space, with missing entries shown as a value-initialised ListValueType.
template <class ListValueType>
void OrthoList <ListValueType>::printFull() const
{
    for (int i = 0; i < nRow; i++)
    {
        // Buffer for one dense row. It uses ListValueType (previously int) so
        // double values are no longer truncated and Complex values compile.
        vector<ListValueType> line(nCol, ListValueType());
        if (rHead[i] != NULL)
        {
            for (Node<ListValueType> *p = rHead[i]->right; p != NULL; p = p->right)
            {
                line.at(p->col) = p->val;
            }
        }
        for (int j = 0; j < nCol; j++)
        {
            cout << line.at(j) << " ";
        }
        cout << endl;
    }
}
// Prints every stored entry as a "( row, col, val )" triple, one per line,
// walking the row chains top to bottom and each chain left to right.
template <class ListValueType>
void OrthoList <ListValueType>::printSparse() const
{
    for (int r = 0; r < nRow; ++r)
    {
        Node<ListValueType> *head = rHead[r];
        if (head == NULL)
        {
            continue; // nothing stored in this row
        }
        for (Node<ListValueType> *cur = head->right; cur != NULL; cur = cur->right)
        {
            cout << "( " << cur->row << ", " << cur->col << ", " << cur->val << " )" << endl;
        }
    }
}
// Stores value v at (r, c). If (r, c) already holds a node, its value is
// overwritten; otherwise a new node is spliced into row r and column c.
// Returns false (changing nothing) when (r, c) is out of range.
// Row chains stay sorted by column and column chains by row, which
// operator* depends on. nElement counts every newly created node.
template <class ListValueType>
bool OrthoList <ListValueType>::add(int r, int c, ListValueType v)
{
    if (r>=nRow || c>=nCol || r<0 || c<0)
    {
        return false;
    }
    // left: last node of row r whose column is < c (the row sentinel if none).
    Node<ListValueType> *left = rHead[r];
    if (left != NULL)
    {
        while (left->right != NULL && left->right->col < c)
        {
            left = left->right;
        }
        if (left->right != NULL && left->right->col == c)
        {
            left->right->val = v; // (r, c) already present: overwrite
            return true;
        }
    }
    Node<ListValueType> *aNode = new Node<ListValueType>();
    aNode->row = r;
    aNode->col = c;
    aNode->val = v;
    if (left == NULL) // row r was empty: create its sentinel
    {
        rHead[r] = new Node<ListValueType>();
        left = rHead[r];
    }
    if (cHead[c] == NULL) // column c was empty: create its sentinel
    {
        cHead[c] = new Node<ListValueType>();
    }
    // up: last node of column c whose row is < r (the column sentinel if none).
    Node<ListValueType> *up = cHead[c];
    while (up->down != NULL && up->down->row < r)
    {
        up = up->down;
    }
    aNode->right = left->right;
    left->right = aNode;
    aNode->down = up->down;
    up->down = aNode;
    nElement++;
    return true;
}
// Adds the single entry aNode (row, col, val) into this matrix. If (row, col)
// is already stored, aNode->val is added to it and aNode is deleted.
// Otherwise aNode itself is spliced into the row and column chains in sorted
// position, and this object takes ownership of it.
// Returns false when (row, col) is out of range. In that case aNode is NOT
// freed and ownership stays with the caller.
template <class ListValueType>
bool OrthoList <ListValueType>::elementWiseAddition(Node<ListValueType> *aNode)
{
    const int r = aNode->row;
    const int c = aNode->col;
    if (r>=nRow || c>=nCol || r<0 || c<0)
    {
        return false;
    }
    // left: last node of row r whose column is < c (the row sentinel if none).
    Node<ListValueType> *left = rHead[r];
    if (left != NULL)
    {
        while (left->right != NULL && left->right->col < c)
        {
            left = left->right;
        }
        if (left->right != NULL && left->right->col == c)
        {
            left->right->val += aNode->val; // existing entry: accumulate
            delete aNode;
            return true;
        }
    }
    else // row r was empty: create its sentinel
    {
        rHead[r] = new Node<ListValueType>();
        left = rHead[r];
    }
    if (cHead[c] == NULL) // column c was empty: create its sentinel
    {
        cHead[c] = new Node<ListValueType>();
    }
    // up: last node of column c whose row is < r (the column sentinel if none).
    Node<ListValueType> *up = cHead[c];
    while (up->down != NULL && up->down->row < r)
    {
        up = up->down;
    }
    aNode->right = left->right;
    left->right = aNode;
    aNode->down = up->down;
    up->down = aNode;
    nElement++;
    return true;
}
// Removes the entry at (r, c). Returns true if it was found and removed, and
// false if (r, c) is out of range or not stored. When the removed node was
// the only one in its row (or column), that row's (column's) sentinel is also
// freed and the head pointer reset to NULL.
template <class ListValueType>
bool OrthoList <ListValueType>::del(int r, int c)
{
    const bool outOfRange = r >= nRow || c >= nCol || r < 0 || c < 0;
    if (outOfRange || rHead[r] == NULL)
    {
        return false;
    }
    // Locate the row predecessor of the target node.
    Node<ListValueType> *rowPrev = rHead[r];
    while (rowPrev->right != NULL && rowPrev->right->col != c)
    {
        rowPrev = rowPrev->right;
    }
    if (rowPrev->right == NULL)
    {
        return false; // row r holds no entry in column c
    }
    Node<ListValueType> *victim = rowPrev->right;

    // Unlink the target from column c.
    Node<ListValueType> *colPrev = cHead[c];
    while (colPrev->down->row != r)
    {
        colPrev = colPrev->down;
    }
    const bool lastInColumn = (colPrev == cHead[c]) && (victim->down == NULL);
    if (lastInColumn)
    {
        delete cHead[c];
        cHead[c] = NULL;
    }
    else
    {
        colPrev->down = victim->down;
    }

    // Unlink the target from row r and free it.
    const bool lastInRow = (rowPrev == rHead[r]) && (victim->right == NULL);
    if (lastInRow)
    {
        delete victim;
        delete rHead[r];
        rHead[r] = NULL;
    }
    else
    {
        rowPrev->right = victim->right;
        delete victim;
    }
    --nElement;
    return true;
}
// In-place matrix addition: this += b. Returns false, leaving this
// unchanged, when the sizes differ. Each stored entry of b is copied into a
// fresh node and handed to elementWiseAddition.
template <class ListValueType>
bool OrthoList<ListValueType>::Addition(const OrthoList<ListValueType> &b)
{
    if (b.getRowNumber() != nRow || b.getColumnNumber() != nCol)
    {
        return false; // size mismatch
    }
    Node<ListValueType> **bRHead = b.getRHead();
    for (int i = 0; i < nRow; i++)
    {
        if (bRHead[i] == NULL)
        {
            continue;
        }
        for (Node<ListValueType> *pb = bRHead[i]->right; pb != NULL; pb = pb->right)
        {
            Node<ListValueType> *aNode = new Node<ListValueType>();
            aNode->row = pb->row;
            aNode->col = pb->col;
            aNode->val = pb->val;
            aNode->right = NULL;
            aNode->down = NULL;
            if (!elementWiseAddition(aNode))
            {
                // On failure the helper did not take ownership: free it here.
                delete aNode;
                return false;
            }
        }
    }
    return true;
}
// Expands the sparse storage into a dense nRow x nCol Matrix. Positions with
// no stored entry keep the value the Matrix constructor gives them.
template <class ListValueType>
Matrix<ListValueType> OrthoList <ListValueType>::toFull() const
{
    Matrix<ListValueType> dense(nRow, nCol);
    for (int row = 0; row < nRow; ++row)
    {
        if (rHead[row] == NULL)
        {
            continue;
        }
        for (Node<ListValueType> *cur = rHead[row]->right; cur != NULL; cur = cur->right)
        {
            dense(cur->row, cur->col) = cur->val;
        }
    }
    return dense;
}
#endif
matrix.h
/*
* 稠密矩陣類
* 基于一維數組的實作
* 用友元函數重載+/-/*
* 20180401 first edition
*/
#ifndef _MATRIX_H
#define _MATRIX_H
#include <iostream>
#include <cassert>
using namespace std;
template <class T>
class Matrix {
public:
    int m, n;   // number of rows, number of columns
private:
    T* data;    // column-major buffer: entry (i, j) is data[i + j*m]; NULL when empty
public:
    // Constructs an r x c matrix whose entries are value-initialised (0 for
    // arithmetic types, T() otherwise). If either size is not positive, the
    // matrix owns no buffer (data == NULL).
    Matrix(int r=0, int c=0): m(r), n(c), data(NULL) {
        if (r>0 && c>0) {
            // `new T[k]()` value-initialises every element. The former memset
            // is only valid for trivially-copyable T. A failed new[] throws
            // std::bad_alloc rather than returning NULL.
            data = new T[m*n]();
        }
    }
    // Deep copy.
    Matrix(const Matrix<T> &a): m(a.m), n(a.n), data(NULL)
    {
        if (m>0 && n>0) {
            data = new T[m*n];
            const T* src = a.getData();
            for (int k=0; k<m*n; k++)
                data[k] = src[k];
        }
    }
    ~Matrix() {
        m = n = 0;
        delete[] data;
        data = NULL;
    }
    // Raw buffer (column-major); ownership stays with the Matrix.
    T* getData() const
    {
        return data;
    }
    // Bounds-asserted element access.
    // NOTE(review): returns a mutable reference even on a const Matrix; kept
    // as-is for compatibility with existing callers.
    T& operator() (int i, int j) const{
        assert(i>=0 && i<m && j>=0 && j<n);
        return data[i+j*m];
    }
    // Deep-copy assignment. The new buffer is built before the old one is
    // released, so self-assignment is safe and no freed memory is read.
    // Exactly m*n elements are allocated (previously m*n*sizeof(T)).
    Matrix<T>& operator= (const Matrix<T> &a) {
        if (this == &a)
            return *this;
        T* fresh = NULL;
        if (a.m>0 && a.n>0) {
            fresh = new T[a.m*a.n];
            for (int k=0; k<a.m*a.n; k++)
                fresh[k] = a.data[k];
        }
        delete[] data;
        data = fresh;
        m = a.m;
        n = a.n;
        return *this;
    }
    // Prints "(m, n):" followed by the rows, entries separated by tabs.
    void display() const {
        std::cout<<"("<<m<<", "<<n<<"):"<<std::endl;
        for (int i=0; i<m; i++) {
            for (int j=0; j<n; j++)
                std::cout<<data[i+j*m]<<"\t";
            std::cout<<std::endl;
        }
    }
    // Element-wise sum; sizes must match (asserted).
    friend Matrix<T> operator+ (const Matrix<T>& a, const Matrix<T>& b)
    {
        assert(a.m==b.m && a.n==b.n);
        Matrix<T> r(a.m, a.n);
        for (int i=0; i<a.m; i++)
            for (int j=0; j<a.n; j++)
                r(i, j) = a(i, j) + b(i, j);
        return r;
    }
    // Element-wise difference; sizes must match (asserted).
    friend Matrix<T> operator- (const Matrix<T>& a, const Matrix<T>& b)
    {
        assert(a.m==b.m && a.n==b.n);
        Matrix<T> r(a.m, a.n);
        for (int i=0; i<a.m; i++)
            for (int j=0; j<a.n; j++)
                r(i, j) = a(i, j) - b(i, j);
        return r;
    }
    // Matrix product; inner dimensions must match (asserted). Relies on the
    // constructor value-initialising r before accumulating.
    friend Matrix<T> operator* (const Matrix<T>& a, const Matrix<T>& b)
    {
        assert(a.n==b.m);
        Matrix<T> r(a.m, b.n);
        for (int i=0; i<a.m; i++)
            for (int j=0; j<b.n; j++)
                for (int k=0; k<a.n; k++)
                    r(i, j) += a(i, k) * b(k, j);
        return r;
    }
};
#endif
complex.h
/*
* 複數類
* 用友元函數重載複數的+/-/*運算
* 20180401 first edition
*/
#ifndef _COMPLEX_H
#define _COMPLEX_H
#include <iostream>
using namespace std;
class Complex {
public:
    Complex(double r = 0.0, double i = 0.0) : real(r), imag(i) { }
    friend Complex operator+(const Complex &c1, const Complex &c2);
    friend Complex operator-(const Complex &c1, const Complex &c2);
    friend Complex operator*(const Complex &c1, const Complex &c2);
    Complex & operator+=(const Complex &c);
    friend std::ostream & operator<<(std::ostream &out, const Complex &c);
private:
    double real; // real part
    double imag; // imaginary part
};
// The definitions below live in a header, so they are `inline`: without it,
// every translation unit including this header would emit its own copy and
// linking two such units fails with multiple-definition errors.
inline Complex operator+(const Complex &c1, const Complex &c2){
    return Complex(c1.real+c2.real, c1.imag+c2.imag);
}
inline Complex operator-(const Complex &c1, const Complex &c2){
    return Complex(c1.real-c2.real, c1.imag-c2.imag);
}
// (a+bi)(c+di) = (ac-bd) + (bc+ad)i
inline Complex operator*(const Complex &c1, const Complex &c2){
    return Complex(c1.real*c2.real-c1.imag*c2.imag, c1.imag*c2.real+c1.real*c2.imag);
}
inline Complex & Complex::operator+=(const Complex &c)
{
    real += c.real;
    imag += c.imag;
    return *this;
}
// Writes the number as "(real, imag)".
inline std::ostream & operator<<(std::ostream &out, const Complex &c){
    out << "(" << c.real << ", " << c.imag << ")";
    return out;
}
#endif
ComplexOrthoList.h
/*
* 複數稀疏矩陣類
* OrthoList的派生類
* 增加了從檔案讀入的構造函數
*/
#ifndef _COMPLEXORTHOLIST_H
#define _COMPLEXORTHOLIST_H
#include "OrthoList.h"
#include "matrix.h"
#include<fstream>
#include "complex.h"
// Sparse complex matrix. It derives from OrthoList<Complex> and adds a
// constructor that reads the matrix from a file.
class ComplexOrthoList : public OrthoList<Complex>
{
public:
    ComplexOrthoList(ifstream &fin);             // read from a file
    ComplexOrthoList(void);                      // empty 0 x 0 matrix
    ComplexOrthoList(const ComplexOrthoList &b); // deep copy
    // Deep copy of a plain OrthoList<Complex>. It is explicit so that base
    // objects never convert to the derived type silently.
    explicit ComplexOrthoList(const OrthoList<Complex> &b): OrthoList<Complex>(b) {}
    ComplexOrthoList & operator =(const ComplexOrthoList &b);
    // assignment is not inherited from the base class, so it is declared here

    // Matrix product. A base-class friend is not a friend of the derived
    // class, so this overload forwards to the base product and wraps the result.
    // (The previous version started from a default 0 x 0 result, so every
    // add() failed and the product was always empty.)
    // On a size mismatch the base operator prints an error and returns an
    // empty matrix.
    friend ComplexOrthoList operator *(const ComplexOrthoList &a, const ComplexOrthoList &b)
    {
        const OrthoList<Complex> &baseA = a;
        const OrthoList<Complex> &baseB = b;
        return ComplexOrthoList(baseA * baseB);
    }
};
// Reads a sparse complex matrix from fin.
// Format: the first two integers are the row and column counts. Every
// following record is "row col real imag", with 1-based row/col indices.
// Records are appended to the end of their row and column chains, so the
// file is assumed to list entries in row-major order with increasing column
// inside a row. NOTE(review): confirm the data files satisfy this. Duplicate
// or out-of-order records would break the sorted-chain invariant.
// Records whose indices fall outside the matrix are skipped with a warning
// (previously they wrote out of bounds).
// This is defined in a header, so it is marked inline to avoid
// multiple-definition link errors.
inline ComplexOrthoList::ComplexOrthoList(ifstream &fin)
{
    fin >> nRow >> nCol; // header: matrix size
    if (nRow < 0 || nCol < 0)
    {
        cerr << "Error: negative matrix size in input file!" << endl;
        nRow = 0;
        nCol = 0;
    }
    nElement = 0;
    rHead = new Node<Complex>*[nRow](); // value-initialised: all NULL
    cHead = new Node<Complex>*[nCol]();
    // Last node on every row / column chain, for O(1) appends.
    vector<Node<Complex>*> rowTail(nRow), colTail(nCol);
    int i = 0, j = 0;
    double Real = 0, Imag = 0;
    while (fin >> i >> j >> Real >> Imag)
    {
        i--; j--; // convert to 0-based indices
        if (i < 0 || i >= nRow || j < 0 || j >= nCol)
        {
            cerr << "Warning: skipping out-of-range entry ("
                 << i + 1 << ", " << j + 1 << ")" << endl;
            continue;
        }
        Node<Complex>* aNode = new Node<Complex>();
        aNode->row = i;
        aNode->col = j;
        aNode->val = Complex(Real, Imag);
        aNode->right = NULL;
        aNode->down = NULL;
        if (rowTail[i] == NULL) // first entry of row i: create its sentinel
        {
            rHead[i] = new Node<Complex>();
            rowTail[i] = rHead[i];
        }
        if (colTail[j] == NULL) // first entry of column j
        {
            cHead[j] = new Node<Complex>();
            colTail[j] = cHead[j];
        }
        rowTail[i]->right = aNode;
        rowTail[i] = aNode;
        colTail[j]->down = aNode;
        colTail[j] = aNode;
        nElement++;
    }
}
// Empty 0 x 0 matrix. It is inline because it is defined in a header;
// otherwise it is multiply defined when the header is included in several TUs.
inline ComplexOrthoList::ComplexOrthoList(void):OrthoList<Complex>(){}
// Deep copy through the base copy constructor. It is inline because it is
// defined in a header.
inline ComplexOrthoList::ComplexOrthoList(const ComplexOrthoList &b):OrthoList<Complex>(b){}
// Deep-copy assignment using copy-and-swap.
// A private copy of b is built first, then its lists are exchanged with ours,
// and the temporary's destructor frees our previous nodes. This fixes the old
// version's leak of the replaced lists, makes self-assignment safe, and
// removes the wrong-index test (rHead[i]/cHead[j] instead of the node's own
// row/col) that could dereference NULL.
// It is inline because it is defined in a header.
inline ComplexOrthoList &ComplexOrthoList::operator =(const ComplexOrthoList &b)
{
    if (this == &b)
    {
        return *this;
    }
    ComplexOrthoList copy(b);
    Node<Complex> **heads = rHead;
    rHead = copy.rHead;
    copy.rHead = heads;
    heads = cHead;
    cHead = copy.cHead;
    copy.cHead = heads;
    int count = nRow;
    nRow = copy.nRow;
    copy.nRow = count;
    count = nCol;
    nCol = copy.nCol;
    copy.nCol = count;
    count = nElement;
    nElement = copy.nElement;
    copy.nElement = count;
    return *this; // `copy` now owns (and will free) the old lists
}
#endif
test.cpp
#include "OrthoList.h"
#include "matrix.h"
#include "complex.h"
#include "ComplexOrthoList.h"
#include<ctime>
const int TEST_TIME = 1; // 重複測試取平均的次數
int main()
{
vector< vector<int> > mat, mat1;
vector<int> aLine;
// 第一行:0,1,0
aLine.push_back(0);
aLine.push_back(1);
aLine.push_back(0);
mat.push_back(aLine);
aLine.clear();
// 第二行:2,0,0
aLine.push_back(2);
aLine.push_back(0);
aLine.push_back(0);
mat.push_back(aLine);
aLine.clear();
// 第三行:0,0,3
aLine.push_back(0);
aLine.push_back(0);
aLine.push_back(3);
mat.push_back(aLine);
cout << "**********************" << endl;
cout << "****OrthoList test****" << endl;
cout << "**********************" << endl;
OrthoList<int> testOList(mat);
testOList.printSparse();
testOList.printFull();
cout << "Add elements" << endl;
testOList.add(0,0,-1);
testOList.add(0,1,-2);
testOList.add(0,2,-3);
testOList.add(1,0,-4);
testOList.add(1,1,-5);
testOList.add(1,2,-6);
testOList.add(2,0,-7);
testOList.add(2,1,-8);
testOList.add(2,2,-9);
testOList.printSparse();
testOList.printFull();
cout << "Delete elements" << endl;
testOList.del(0,0);
testOList.del(1,1);
testOList.del(1,0);
testOList.printSparse();
testOList.printFull();
OrthoList<int> oList1(testOList);
cout << " New sparse matrix: " << endl;
oList1.printFull();
cout << " matrix1 + matrix2: " << endl;
testOList.Addition(oList1);
testOList.printSparse();
testOList.printFull();
cout << "matrix1 * martrix2: " << endl;
OrthoList<int> oList2(oList1);
oList2 = testOList * oList1;
oList2.printFull();
cout << "**********************" << endl;
cout << "*OrthoList test done**" << endl;
cout << "**********************" << endl;
cout << endl;
cout << "----------------------" << endl;
cout << "-----matrix test------" << endl;
cout << "----------------------" << endl;
int m = 2, n = 2;
Matrix<int> a = oList1.toFull();
Matrix<int> b = testOList.toFull();
cout<<"Matrix a"; a.display(); cout<<endl;
cout<<"Matrix b"; b.display(); cout<<endl;
Matrix<int> c;
c = (a+b);
cout<<"a+b: "; c.display(); cout<<endl;
Matrix<int> d;
d = (a*b);
cout<<"a*b:"; d.display();
cout << "----------------------" << endl;
cout << "---matrix test done---" << endl;
cout << "----------------------" << endl << endl;
cout << "@@@@@@@@@@@@@@@@@@@@@@" << endl;
cout << "@@@ run time test @@@@" << endl;
cout << "@@@@@@@@@@@@@@@@@@@@@@" << endl;
ifstream fin("Real.txt"); // 資料檔案
assert(fin); // 檔案是否正常打開
ComplexOrthoList sY(fin); // sY: sparse Y matrix
fin.close();
fin.open("Real_2.txt");
assert(fin);
ComplexOrthoList sY1(fin);
fin.close();
ComplexOrthoList sY2(sY);
Matrix<Complex> dY = sY.toFull(); // dY: dense Y matrix
Matrix<Complex> dY1 = sY1.toFull();
Matrix<Complex> dY2(dY);
int i = 0;
clock_t t0;
// 測試矩陣加法
t0 = clock();
for (i=0; i<TEST_TIME; i++)
{
sY.Addition(sY1);
}
clock_t run_sparse_add = clock()-t0;
run_sparse_add /= TEST_TIME;
cout << TEST_TIME << " times SPARSE MATRIX addition: "
<< run_sparse_add << " ms" << endl;
t0 = clock();
for (i=0; i<TEST_TIME; i++)
{
dY2 = dY + dY1;
}
clock_t run_dense_add = clock()-t0;
run_dense_add /= TEST_TIME;
cout << TEST_TIME << " times DENSE MATRIX addition: "
<< run_dense_add << " ms" << endl;
// 測試矩陣乘法
t0 = clock();
for (i=0; i<TEST_TIME; i++)
{
sY2 = sY * sY1;
}
clock_t run_sparse_mul = clock()-t0;
run_sparse_mul /= TEST_TIME;
cout << TEST_TIME << " times SPARSE MATRIX multiplication: "
<< run_sparse_mul << " ms" << endl;
t0 = clock();
for (i=0; i<TEST_TIME; i++)
{
dY2 = dY * dY1;
}
clock_t run_dense_mul = clock()-t0;
run_dense_mul /= TEST_TIME;
cout << TEST_TIME << " times DENSE MATRIX multipication: "
<< run_dense_mul << " ms" << endl;
return 0;
}