ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 1269 - 대칭 차집합
    카테고리 없음 2025. 5. 9. 21:22

    이 전에 풀었던 1764번 문제를 참고해서 우선 구현했다.

    #include <iostream>
    #include <algorithm>
    #include <vector>
    using namespace std;
    
    int main()
    {
    
        ios_base::sync_with_stdio(false);
        cin.tie(NULL);
        cout.tie(NULL);
    
        int n, m;
        cin >> n >> m;
    
        vector<int> l;
        vector<int> s;
        vector<int> ls;
        int t;
    
        for (int i = 0; i < n; i++)
            cin >> t, l.push_back(t);
        sort(l.begin(), l.end());
    
        for (int i = 0; i < m; i++)
            cin >> t, s.push_back(t);
        sort(s.begin(), s.end());
    
        for (int i = 0; i < n; i++)
            if (!binary_search(s.begin(), s.end(), l[i]))
                ls.push_back(l[i]);
    
        for (int i = 0; i < m; i++)
            if (!binary_search(l.begin(), l.end(), s[i]))
                ls.push_back(s[i]);
    
        cout << ls.size() << "\n";
    }

     

    근데 살짝 느리다. 아쉽쓰

     

    이때 s랑 l은 정렬이 되어있으니 전부 뒤져볼게 아니라 값이 작은지 큰지만 비교하면 되겠지 싶어서 아래와 같이 짰다.

    #include <iostream>
    #include <algorithm>
    #include <vector>
    using namespace std;
    
    void next(vector<int>::iterator &iter, vector<int>::iterator &end)
    {
        if (iter != end)
            iter++, cout << (iter == end) << "\n";
    }
    
    int main()
    {
    
        ios_base::sync_with_stdio(false);
        cin.tie(NULL);
        cout.tie(NULL);
    
        int n, m;
        cin >> n >> m;
    
        vector<int> l;
        vector<int> s;
        vector<int> ls;
        int t;
    
        for (int i = 0; i < n; i++)
            cin >> t, l.push_back(t);
        sort(l.begin(), l.end());
    
        for (int i = 0; i < m; i++)
            cin >> t, s.push_back(t);
        sort(s.begin(), s.end());
    
        auto li = l.begin();
        auto si = s.begin();
        int cnt = 0;
    
        while (li != l.end() || si != s.end())
        {
            if (*li > *si)
                cnt++, si++;
            else if (*li < *si)
                cnt++, li++;
            else
                si++, li++;
        }
    
        cout << cnt;
    }

    근데 안 됐다. 터진다.

    생각해보니 else에서 si랑 li를 둘 다 올리면 한쪽이 iter의 끝에 도달해도 다른쪽 iter가 안 끝났다면 이미 끝난 iter도 무한히 증가하게 될 것이다. 그래서 아래와 같이 수정해봤다.

     

    #include <iostream>
    #include <algorithm>
    #include <vector>
    using namespace std;
    
    void next(vector<int>::iterator &iter, vector<int>::iterator &end)
    {
        if (iter != end)
            iter++;
    }
    
    int main()
    {
    
        ...
        
        auto le = l.end();
        auto se = s.end();
    
        while (li != l.end() || si != s.end())
        {
            cout << *li << " " << *si << "\n";
            if (*li > *si)
                cnt++, next(si, se);
            else if (*li < *si)
                cnt++, next(li, le);
            else
                next(si, se), next(li, le);
        }
    
        cout << cnt;
    }

    이래도 안된다! next는 iter가 끝에 도달하면 더 증가하지 않도록 짜서 제대로 동작하는데 end에 도달한 iter는 가장 마지막 원소를 참조하는 줄 알았는데 그게 아니라 가장 마지막 원소의 iter 그 다음 iter, 즉 끝을 가리키는 iter가 존재한다는 것을 알게 됐다. 그리고 그 iter가 가리키는 값이 0이라서 si만 계속 next 되는 구조인 것이다! 그래서 아래와 같이 수정했다.

     

    #include <iostream>
    #include <algorithm>
    #include <vector>
    using namespace std;
    
    void next(vector<int>::iterator &iter, vector<int>::iterator &end)
    {
        if (iter != end)
            iter++;
    }
    
    int value(vector<int>::iterator iter, vector<int>::iterator end)
    {
        if (iter == end)
            return 100'000'001;
        else
            return *iter;
    }
    
    int main()
    {
    
        ios_base::sync_with_stdio(false);
        cin.tie(NULL);
        cout.tie(NULL);
    
        int n, m;
        cin >> n >> m;
    
        vector<int> l;
        vector<int> s;
        vector<int> ls;
        int t;
    
        for (int i = 0; i < n; i++)
            cin >> t, l.push_back(t);
        sort(l.begin(), l.end());
    
        for (int i = 0; i < m; i++)
            cin >> t, s.push_back(t);
        sort(s.begin(), s.end());
    
        auto li = l.begin();
        auto si = s.begin();
        int cnt = 0;
    
        auto le = l.end();
        auto se = s.end();
    
        while (li < le || si < se)
        {
            int sv = value(si, se);
            int lv = value(li, le);
            if (lv > sv)
                cnt++, next(si, se);
            else if (lv < sv)
                cnt++, next(li, le);
            else
                next(si, se), next(li, le);
        }
    
        cout << cnt;
    }

    iter가 가리키는 값을 반환하는 value 함수를 만들어 iter가 end에 도달하면 주어진 입력값 범위를 초과하는, 즉 항상 가장 큰 값이 되도록 했다. 이러면 while문 하나로 모든 과정을 끝낼 수 있다. 이렇게 하던가 아니면 while 조건을 li < le && si < se로 넣고 한쪽 iter가 다 돌 때까지 겹치는 갯수를 cnt로 놓고 n+m-2*cnt를 해도 될 것 같다. 사실 후자가 더 편한 것 같다. 그래서 아래와 같이  다시 수정했다.

     

    #include <iostream>
    #include <algorithm>
    #include <vector>
    using namespace std;
    
    int main()
    {
        ios_base::sync_with_stdio(false);
        cin.tie(NULL);
        cout.tie(NULL);
        int n, m;
        cin >> n >> m;
        vector<int> l;
        vector<int> s;
        vector<int> ls;
        int t;
        for (int i = 0; i < n; i++)
            cin >> t, l.push_back(t);
        sort(l.begin(), l.end());
        for (int i = 0; i < m; i++)
            cin >> t, s.push_back(t);
        sort(s.begin(), s.end());
        auto li = l.begin();
        auto si = s.begin();
        int cnt = 0;
        auto le = l.end();
        auto se = s.end();
        while (li < le && si < se)
        {
            if (*li > *si)
                si++;
            else if (*li < *si)
                li++;
            else
                si++, li++, cnt++;
        }
        cout << n + m - 2 * cnt;
    }

     

    사실 소모되는 시간의 차이는 거의 없는데 그냥 코드가 짧고 간결해진다.

    댓글

Designed by Tistory.