第一次,样例全过代码,比较乱,勿喷
import java.util.*; public class Exam1 { public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        Map<Character,Integer> characterMap=new HashMap<>(); while (in.hasNextLine()) {
            String s = in.nextLine(); //字符串s     1<=length<=50  int k = in.nextInt(); //允许移除的字符个数 0<=k<=length  for(int i=0;i<s.length();i++)
            { char c=s.charAt(i); //Integer integer = characterMap.putIfAbsent(c, 1);  if(characterMap.get(c)==null)
                {
                    characterMap.put(c, 1);
                } else  {
                    characterMap.put(c,characterMap.get(c)+1);
                }
            }
        List<Integer> values=new ArrayList<>(); for(Object key:characterMap.keySet())
        {
            values.add(characterMap.get(key));
        }
        Collections.sort(values);
        Integer [] valuesArr=new Integer[values.size()]; for(int n=0;n<values.size();n++)
        {
            valuesArr[n]=values.get(n);
        } for(int m=0;m<k;m++)
       { for(int j=valuesArr.length-1;j>0;j--) { if(valuesArr[j]>valuesArr[j-1])
               {
                   valuesArr[j]--; break;
               } else  { if(j==1)
                   {
                       valuesArr[0]--;
                   }else  continue;
               }
           }
       } long count=0; for(int i=0;i<valuesArr.length;i++)
       {
           count+=valuesArr[i]*valuesArr[i];
       }
            System.out.println(count);
        }
    }
}