老铁,我第一题思路跟你基本一样,但为啥是0AC啊,给看一眼
package aiqiyi;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;

public class _1 {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int res = 0;
        String s = sc.nextLine();
        int k = Integer.valueOf(sc.nextLine().split("")[0]);
        Map<Character, Integer> map = new HashMap<>();
        int length = s.length();
        for (int i = 0; i < length; i++) {
            if (map.containsKey(s.charAt(i))) {
                int m = map.get(s.charAt(i));
                map.put(s.charAt(i), m + 1);
            } else
                map.put(s.charAt(i), 1);
        }
        Set<Map.Entry<Character, Integer>> set = map.entrySet();
        int[] num = new int[map.size()];
        Iterator it = set.iterator();
        int i = 0;
        while (it.hasNext()) {
            Map.Entry m = (Map.Entry) it.next();
            num[i] = (int) m.getValue();
            i++;
        }
        System.out.println(map.toString());
        Arrays.sort(num);
        for (int m = 0; m < k; m++) {
            num[map.size() - 1]--;
            Arrays.sort(num);
        }
        for (int m = 0; m < map.size(); m++) {
            res += (num[m] * num[m]);
        }
        System.out.println(res);
    }
}