这是我第一题 AC（Accepted，已通过）的代码：

/**
 * Stably partitions a singly linked list around the pivot value {@code m}.
 *
 * <p>Nodes whose {@code val} is less than or equal to {@code m} are placed before nodes whose
 * {@code val} is greater than {@code m}. The original relative order inside each group is kept.
 * The existing nodes are relinked in place during a single traversal; no node of the input is
 * copied.
 *
 * @param head first node of the list; may be {@code null}
 * @param m    pivot value
 * @return head of the partitioned list ({@code head} itself when the list has fewer than two
 *         nodes, or when no node has {@code val <= m})
 */
static ListNode partition(ListNode head, int m) {
    // Nothing to rearrange for an empty or one-node list.
    if (head == null || head.next == null) {
        return head;
    }

    // Sentinel heads for the "low" (<= m) and "high" (> m) sub-lists; they never appear in
    // the returned list.
    ListNode lowDummy = new ListNode(m);
    ListNode highDummy = new ListNode(m);
    ListNode lowTail = lowDummy;
    ListNode highTail = highDummy;

    ListNode node = head;
    while (node != null) {
        // Appending `node` only rewrites the previous tail's link, so node.next is still the
        // original successor when we advance below.
        if (node.val <= m) {
            lowTail.next = node;
            lowTail = node;
        } else {
            highTail.next = node;
            highTail = node;
        }
        node = node.next;
    }

    // The high tail may still point at a node from the original list (possibly a low one),
    // so cut it off before splicing the high list after the low list. When the low list is
    // empty, lowTail is lowDummy and the result is simply the high list.
    highTail.next = null;
    lowTail.next = highDummy.next;
    return lowDummy.next;
}