LeetCode – Kth Smallest Element in a BST (Java)

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it. You may assume k is always valid: 1 ≤ k ≤ the number of elements in the BST.

Java Solution 1 - Inorder Traversal

An inorder traversal of a BST visits the nodes in ascending order, so the kth node visited is the kth smallest element. Time is O(n).

public int kthSmallest(TreeNode root, int k) {
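    Stack<TreeNode> stack = new Stack<TreeNode>(); // java.util.Stack
 
    TreeNode p = root;
    int result = 0;
 
    while(!stack.isEmpty() || p!=null){
        if(p!=null){
            // descend left, saving the path on the stack
            stack.push(p);
            p = p.left;
        }else{
            // pop the next node in ascending order
            TreeNode t = stack.pop();
            k--;
            if(k==0)
                result = t.val; // kth smallest found; the loop still finishes the traversal
            p = t.right;
        }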
    }
 
    return result;
}

Similarly, we can also write the inorder traversal as the following:

public int kthSmallest(TreeNode root, int k) {
    Stack<TreeNode> stack = new Stack<TreeNode>();
    TreeNode p = root;
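    // push the leftmost path; its last node is the smallest element
    while(p!=null){
        stack.push(p);
        p=p.left;
    }
    int i=0;
    while(!stack.isEmpty()){
        TreeNode t = stack.pop(); // next node in ascending order
        i++;
 
        if(i==k)
            return t.val;
 
        // push the leftmost path of t's right subtree
        TreeNode r = t.right;
        while(r!=null){
            stack.push(r);
            r=r.left;
        }
 
    }
 
    return -1; // only reached if k exceeds the number of nodes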
}
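Both snippets above rely on LeetCode's TreeNode class and on java.util.Stack without showing them. As a self-contained sketch, assuming the usual LeetCode TreeNode fields (val, left, right), the same iterative inorder traversal can use java.util.ArrayDeque (generally preferred over the legacy Stack class) and return as soon as the kth node is popped, so only part of the tree is visited: O(h + k) time instead of a full O(n) pass.

```java
import java.util.ArrayDeque;
import java.util.Deque;

public class KthSmallestIterative {
    // Assumed LeetCode-style node definition (not shown in the original post).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // Iterative inorder traversal that stops as soon as the kth node is popped.
    static int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode p = root;
        while (p != null || !stack.isEmpty()) {
            // Walk as far left as possible, remembering the path.
            while (p != null) {
                stack.push(p);
                p = p.left;
            }
            TreeNode t = stack.pop();   // next node in ascending order
            if (--k == 0) {
                return t.val;           // early exit: the rest of the tree is never visited
            }
            p = t.right;
        }
        throw new IllegalArgumentException("k exceeds the number of nodes");
    }

    public static void main(String[] args) {
        //        5
        //       / \
        //      3   6
        //     / \
        //    2   4
        //   /
        //  1
        TreeNode root = new TreeNode(5,
                new TreeNode(3, new TreeNode(2, new TreeNode(1), null), new TreeNode(4)),
                new TreeNode(6));
        System.out.println(kthSmallest(root, 3)); // prints 3
    }
}
```

Throwing an exception for an out-of-range k is a design choice of this sketch; the second version above returns -1 instead, which is ambiguous if -1 can be a node value.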

Java Solution 2 - Extra Data Structure

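We can let each node track its order, e.g., by storing the number of nodes in its left subtree, which is the number of elements in its subtree that are smaller than itself. The kth smallest element is then found by a single walk down from the root, in O(h) time, which is O(log(n)) when the tree is balanced.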
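A minimal sketch of this idea, assuming a hypothetical Node class with a leftSize field (the number of nodes in its left subtree) and a plain, unbalanced BST insert that keeps leftSize up to date. At each node, k is compared with leftSize + 1 to decide whether to go left, stop, or go right. Note that this sketch does no rebalancing: the O(log(n)) bound only holds if the tree stays balanced, and a self-balancing tree would also have to update leftSize during rotations.

```java
public class OrderStatisticBst {
    // Hypothetical augmented node: leftSize = number of nodes in the left subtree,
    // i.e., how many elements of this node's subtree are smaller than val.
    static class Node {
        int val;
        int leftSize;
        Node left, right;
        Node(int val) { this.val = val; }
    }

    // Plain (unbalanced) BST insert that maintains leftSize; equal keys go right.
    static Node insert(Node root, int val) {
        if (root == null) return new Node(val);
        if (val < root.val) {
            root.leftSize++;                   // one more element in the left subtree
            root.left = insert(root.left, val);
        } else {
            root.right = insert(root.right, val);
        }
        return root;
    }

    // Walk down one path using the stored counts: O(h) time.
    static int kthSmallest(Node root, int k) {
        Node p = root;
        while (p != null) {
            if (k <= p.leftSize) {
                p = p.left;                    // answer lies in the left subtree
            } else if (k == p.leftSize + 1) {
                return p.val;                  // this node is the kth smallest
            } else {
                k -= p.leftSize + 1;           // skip the left subtree and this node
                p = p.right;
            }
        }
        throw new IllegalArgumentException("k exceeds the number of nodes");
    }

    public static void main(String[] args) {
        Node root = null;
        for (int v : new int[] {5, 3, 6, 2, 4, 1}) {
            root = insert(root, v);
        }
        StringBuilder sb = new StringBuilder();
        for (int k = 1; k <= 6; k++) {
            if (k > 1) sb.append(' ');
            sb.append(kthSmallest(root, k));
        }
        System.out.println(sb); // prints "1 2 3 4 5 6"
    }
}
```

The extra field costs one int per node and must be maintained on every insert (and delete), which pays off when the tree is queried for order statistics many times.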

  • Mark Naguib

    public int KthSmallest(TreeNode root, int k) {
        int c = 0;
        int result = 0;
        KthSmallest(root, k, ref c, ref result);
        return result;
    }

    public void KthSmallest(TreeNode root, int k, ref int c, ref int result) {
        if (root == null) return;
        if (c >= k) return;
        KthSmallest(root.left, k, ref c, ref result);
        c++;
        if (k == c) { result = root.val; return; }
        KthSmallest(root.right, k, ref c, ref result);
    }

  • svGuy

    There is an easier recursive implementation (in C++):

    // num must be 0 on the initial call
    TreeNode* kth(TreeNode *root, int &num, int k) {
        if(!root) return NULL;

        TreeNode *left = kth(root->left, num, k);
        if(left) return left;

        ++num; // increment num every time you visit a node in the inorder traversal
        if(num == k) return root;

        TreeNode *right = kth(root->right, num, k);
        return right;
    }

  • Arun Prakash

    I think the code below gives the required result with an inorder traversal. Why do we need a stack here?

    void findK(TreeNode root, int k) {
        TreeNode p = root;
        while(!p.isEmpty() || k >= 0)
        {
            findK(p.left, k);
            --k;
            if(k == 0) {
                return p.val;
            }
            findK(p.right, k);
        }
    }

    Correct me if I am wrong.
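    A recursive version does not need an explicit stack, because the call stack plays that role (so it still uses O(h) space). The snippet above cannot work as written, though: the method is declared void but returns p.val, TreeNode has no isEmpty() method, and Java passes the int k by value, so a decrement inside one recursive call is invisible to its caller. A minimal corrected sketch, assuming the LeetCode TreeNode fields and using hypothetical instance fields (remaining, result) to share the counter across calls:

```java
public class KthSmallestRecursive {
    // Assumed LeetCode-style node definition.
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // Hypothetical shared state: ints are passed by value in Java, so the
    // remaining count must live where every recursive call can update it.
    private int remaining;
    private int result;

    public int kthSmallest(TreeNode root, int k) {
        remaining = k;
        inorder(root);
        return result;
    }

    // Recursive inorder traversal; the call stack replaces the explicit stack.
    private void inorder(TreeNode node) {
        if (node == null || remaining == 0) return; // empty subtree, or already found
        inorder(node.left);
        if (remaining == 0) return;                 // found inside the left subtree
        if (--remaining == 0) {
            result = node.val;                      // this node is the kth one visited
            return;
        }
        inorder(node.right);
    }

    public static void main(String[] args) {
        //     3
        //    / \
        //   1   4
        //    \
        //     2
        TreeNode root = new TreeNode(3,
                new TreeNode(1, null, new TreeNode(2, null, null)),
                new TreeNode(4, null, null));
        System.out.println(new KthSmallestRecursive().kthSmallest(root, 3)); // prints 3
    }
}
```

    A one-element int[] counter passed down the calls would work equally well; the svGuy version above gets the same effect in C++ with an int& parameter.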

  • Megha Tiwari

    I think the program should return the result as soon as k reaches 0; it's unnecessary to keep filling the stack after we have obtained the desired result. Please correct me if I'm wrong.

  • Tank Ibdah