import java.util.Scanner;

/**
 * Reads two integers {@code n} and {@code k} from standard input and prints the number of
 * ordered pairs {@code (x, y)} with {@code 1 <= x, y <= n} whose remainder {@code x % y} is
 * at least {@code k}.
 *
 * <p>NOTE(review): the original file ended with the stray text "nk problem 100%" (translated)
 * outside of any comment; presumably a note that the solution passed all judge cases. It is
 * kept here as a comment so that the file compiles.
 */
public class Main {

    /**
     * Program entry point: reads {@code n} then {@code k} and prints the pair count.
     *
     * @param args unused command-line arguments
     */
    public static void main(String[] args) {
        // try-with-resources releases the scanner (and the underlying stream) when done.
        try (Scanner in = new Scanner(System.in)) {
            int n = in.nextInt();
            int k = in.nextInt();
            System.out.println(countPairs(n, k));
        }
    }

    /**
     * Counts ordered pairs {@code (x, y)}, {@code 1 <= x, y <= n}, with {@code x % y >= k}.
     *
     * @param n upper bound (inclusive) for both {@code x} and {@code y}
     * @param k minimum remainder required
     * @return the number of qualifying pairs, as a {@code long} because it can exceed
     *         the {@code int} range (e.g. {@code n * n} when {@code k == 0})
     */
    private static long countPairs(int n, int k) {
        if (k == 0) {
            // Every remainder is >= 0, so all n * n pairs qualify. The general formula below
            // would also count x = 0 (remainder 0), hence the separate case.
            return (long) n * n;
        }

        long total = 0L;
        // Only divisors y > k can ever produce a remainder >= k.
        for (int y = k + 1; y <= n; y++) {
            // Each complete run of y consecutive x values has remainders 0..y-1,
            // of which exactly (y - k) are >= k.
            total += (long) (n / y) * (y - k);

            // The incomplete tail has remainders 0..(n % y); those in k..(n % y) qualify.
            int tail = n % y;
            if (tail >= k) {
                total += (long) tail - k + 1;
            }
        }
        return total;
    }
}