import java.util.Scanner;
public class Main {
    /**
     * Reads two integers {@code n} and {@code k} from standard input and prints
     * {@link #core(int, int)} for them.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        try (Scanner in = new Scanner(System.in)) {
            int n = in.nextInt();
            int k = in.nextInt();
            System.out.println(core(n, k));
        }
    }

    /**
     * Counts the ordered pairs (a, b) with {@code 1 <= a, b <= n} such that
     * {@code a % b >= k}.
     *
     * <p>The count can be as large as {@code n * n} (e.g. 10^10 for n = 10^5),
     * which does not fit in an {@code int}. It is therefore accumulated and
     * returned as a {@code long}.
     *
     * @param n inclusive upper bound for both a and b; a non-positive n yields 0
     * @param k minimum remainder; any {@code k <= 0} is satisfied by every pair
     * @return the number of qualifying pairs
     */
    public static long core(int n, int k) {
        if (n <= 0) {
            return 0L;
        }
        if (k <= 0) {
            // Every remainder is >= 0, so all n * n pairs qualify.
            return (long) n * n;
        }
        long count = 0L;
        // For b <= k every remainder a % b is < b <= k, so only b > k can contribute.
        for (int b = k + 1; b <= n; b++) {
            // Treat a as ranging over 0..n. a = 0 has remainder 0 < k, so it never counts.
            // Each complete run of b consecutive values of a hits remainders 0..b-1 once,
            // and exactly b - k of those are >= k.
            long fullBlocks = n / b;
            count += fullBlocks * (b - k);
            // The trailing partial run covers remainders 0..n % b; those >= k qualify.
            int rest = n % b;
            if (rest >= k) {
                count += rest - k + 1;
            }
        }
        return count;
    }
}