#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <new>
#include <utility>
using namespace std;
template<typename T>
class SharedPtr
{
public:
//禁止分配内存时抛出异常
explicit SharedPtr(T *p):ptr(p), use(new (nothrow) size_t(1))
{
//分配内存失败
if (use == nullptr)
{
delete ptr;
ptr = nullptr;
cout << "分配内存失败";
exit(-1);
}
}
explicit SharedPtr(const SharedPtr &rhs) :ptr(rhs.ptr), use(rhs.use)
{
++*use;
}
SharedPtr& operator=(const SharedPtr &rhs)
{
++*rhs.use;
if (--*use == 0)
{
delete ptr;
ptr = nullptr;
delete use;
use = nullptr;
}
ptr = rhs.ptr;
use = rhs.use;
return *this;
}
T* get() const
{
return ptr;
}
size_t use_count() const
{
return *use;
}
bool unique() const
{
return use_count() == 1;
}
//交换两个指针的指向
void swap(SharedPtr &q)
{
SharedPtr temp(q);
q = *this;
*this = temp;
}
T& operator*() const
{
return *ptr;
}
T* operator->() const
{
return & operator*();
}
//虚析构函数
virtual ~SharedPtr()
{
if (--*use == 0)
{
delete ptr;
ptr = nullptr;
delete use;
use = nullptr;
}
}
private:
T *ptr;
size_t *use;
};
int main()
{
//测试
SharedPtr<int> s1(new int(666));
{
SharedPtr<int> s2(s1);
SharedPtr<int> s3(s2);
cout << "*s1 = " << *s1 << endl;
cout << "s1 use_count " << s1.use_count() << endl;
cout << "s1 is unique? " << s1.unique() << endl;
}
SharedPtr<int> s4(new int(888));
SharedPtr<int> s5(s4);
cout << "Before swap:\n";
cout << "*s5 " << *s5 << endl;
cout << "*s1 " << *s1 << endl;
s5.swap(s1);
cout << "After swap:\n";
cout << "*s5 " << *s5 << endl;
cout << "*s1 " << *s1 << endl;
cout << *(s1.get()) << endl;
s1 = s4;
SharedPtr<double> pd1(new double(2.3333));
SharedPtr<double> pd2(new double(0.33333));
cout << "pd1 is unique? " << pd1.unique() << endl;
auto q = &pd1;
cout << "Before swap:\n";
cout << "*pd1 " << *pd1 << endl;
cout << "*pd2 " << *pd2 << endl;
q->swap(pd2);
cout << "After swap:\n";
cout << "*pd1 " << *pd1 << endl;
cout << "*pd2 " << *pd2 << endl;
cout << "test end!\n";
return 0;
}