#include <iostream>

using namespace std;

template<typename T>
class SharedPtr
{
public:
    //禁止分配内存时抛出异常
    explicit SharedPtr(T *p):ptr(p), use(new (nothrow) size_t(1))
    {
        //分配内存失败
        if (use == nullptr)
        {
            delete ptr;
            ptr = nullptr;
            cout << "分配内存失败";
            exit(-1);
        }            
    }

    explicit SharedPtr(const SharedPtr &rhs) :ptr(rhs.ptr), use(rhs.use)
    {
        ++*use;
    }

    SharedPtr& operator=(const SharedPtr &rhs)
    {
        ++*rhs.use;
        if (--*use == 0)
        {
            delete ptr;
            ptr = nullptr;
            delete use;
            use = nullptr;
        }
        ptr = rhs.ptr;
        use = rhs.use;
        return *this;
    }

    T* get() const
    {
        return ptr;
    }

    size_t use_count() const
    {
        return *use;
    }

    bool unique() const
    {
        return use_count() == 1;
    }

    //交换两个指针的指向
    void swap(SharedPtr &q)
    {
        SharedPtr temp(q);
        q = *this;
        *this = temp;
    }

    T& operator*() const
    {
        return *ptr;
    }

    T* operator->() const
    {
        return & operator*();
    }

    //虚析构函数
    virtual ~SharedPtr()
    {
        if (--*use == 0)
        {
            delete ptr;
            ptr = nullptr;
            delete use;            
            use = nullptr;        
        }    
    }

private:
    T       *ptr;
    size_t    *use;
};



int main()
{
    //测试
    SharedPtr<int> s1(new int(666));
    {
        SharedPtr<int> s2(s1);
        SharedPtr<int> s3(s2);
        cout << "*s1 = " << *s1 << endl;
        cout << "s1 use_count " << s1.use_count() << endl;
        cout << "s1 is unique? " << s1.unique() << endl;
    }

    SharedPtr<int> s4(new int(888));
    SharedPtr<int> s5(s4);

    cout << "Before swap:\n";
    cout << "*s5 " << *s5 << endl;
    cout << "*s1 " << *s1 << endl;
    s5.swap(s1);
    cout << "After swap:\n";
    cout << "*s5 " << *s5 << endl;
    cout << "*s1 " << *s1 << endl;
    cout << *(s1.get()) << endl;

    s1 = s4;

    SharedPtr<double> pd1(new double(2.3333));
    SharedPtr<double> pd2(new double(0.33333));
    cout << "pd1 is unique? " << pd1.unique() << endl;

    auto q = &pd1;

    cout << "Before swap:\n";
    cout << "*pd1 " << *pd1 << endl;
    cout << "*pd2 " << *pd2 << endl;
    q->swap(pd2);
    cout << "After swap:\n";
    cout << "*pd1 " << *pd1 << endl;
    cout << "*pd2 " << *pd2 << endl;
    
    
    cout << "test end!\n";
    return 0;
}