LeetCode – Kth Smallest Element in a BST (Java)

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it. You may assume k is always valid, i.e., 1 ≤ k ≤ the number of elements in the BST.

Java Solution 1 - Inorder Traversal

An inorder traversal visits the nodes of a BST in ascending order, so the kth node visited is the kth smallest element. Time is O(n) in the worst case.

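import java.util.Stack;

public int kthSmallest(TreeNode root, int k) {
    Stack<TreeNode> stack = new Stack<TreeNode>();
    TreeNode p = root;
    int result = 0;
    while (!stack.isEmpty() || p != null) {
        if (p != null) {
            stack.push(p);             // descend to the leftmost unvisited node
            p = p.left;
        } else {
            TreeNode t = stack.pop();  // nodes are popped in ascending order
            k--;
            if (k == 0)
                result = t.val;
            p = t.right;
        }
    }
    return result;
}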
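To try the iterative traversal end to end, here is a self-contained sketch with a small test tree. The TreeNode class, the insert helper, and the class name KthSmallestDemo are assumptions for illustration (LeetCode supplies its own TreeNode). It also swaps java.util.Stack for ArrayDeque, which the Java documentation recommends over the legacy Stack class, and returns as soon as the kth node is popped instead of finishing the traversal.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class KthSmallestDemo {
    // Minimal node type (assumed); LeetCode provides an equivalent TreeNode class.
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Hypothetical helper: plain BST insertion with no balancing, used only to build test trees.
    static TreeNode insert(TreeNode root, int val) {
        if (root == null) return new TreeNode(val);
        if (val < root.val) root.left = insert(root.left, val);
        else root.right = insert(root.right, val);
        return root;
    }

    // Iterative inorder traversal that stops as soon as the kth node is popped.
    static int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode p = root;
        while (!stack.isEmpty() || p != null) {
            if (p != null) {
                stack.push(p);             // descend to the leftmost unvisited node
                p = p.left;
            } else {
                TreeNode t = stack.pop();  // nodes are popped in ascending order
                if (--k == 0) return t.val;
                p = t.right;               // continue with the right subtree
            }
        }
        throw new IllegalArgumentException("k exceeds the number of nodes");
    }

    public static void main(String[] args) {
        TreeNode root = null;
        for (int v : new int[] {5, 3, 6, 2, 4, 1}) root = insert(root, v);
        System.out.println(kthSmallest(root, 3)); // prints 3
    }
}
```

Returning early means the loop does O(h + k) work, where h is the tree height, rather than always visiting all n nodes.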

The same inorder traversal can also be written as follows; this version returns as soon as the kth node is visited:

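import java.util.Stack;

public int kthSmallest(TreeNode root, int k) {
    Stack<TreeNode> stack = new Stack<TreeNode>();
    TreeNode p = root;
    while (p != null) { stack.push(p); p = p.left; }  // push the leftmost path
    int i = 0;
    while (!stack.isEmpty()) {
        TreeNode t = stack.pop();  // nodes are popped in ascending order
        if (++i == k)
            return t.val;
        TreeNode r = t.right;      // then push the leftmost path of the right subtree
        while (r != null) { stack.push(r); r = r.left; }
    }
    return -1;  // unreachable when 1 <= k <= number of nodes
}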

Java Solution 2 - Extra Data Structure

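We can let each node track its order, i.e., store the size of its left subtree, which is the number of elements in its subtree that are less than itself. Finding the kth smallest element then walks a single root-to-node path, so time is O(h), where h is the tree height; this is O(log(n)) when the tree is balanced.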
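A minimal sketch of this idea, assuming a hypothetical RankedBst class whose nodes carry a leftCount field holding the size of their left subtree. At each node, leftCount + 1 is that node's rank within its own subtree: if k matches, we are done; if k is smaller, go left; otherwise subtract the rank and go right. The sketch does not rebalance, so the O(log(n)) bound assumes a balanced tree (for example, an AVL or red-black tree that also updates the counts during rotations); in general it is O(h).

```java
class RankedBst {
    // Hypothetical node type: each node also stores the size of its left subtree.
    static class Node {
        int val;
        int leftCount;   // number of nodes in this node's left subtree
        Node left, right;
        Node(int val) { this.val = val; }
    }

    private Node root;

    // Unbalanced insertion that keeps leftCount up to date; equal values go right.
    void insert(int val) {
        if (root == null) { root = new Node(val); return; }
        Node cur = root;
        while (true) {
            if (val < cur.val) {
                cur.leftCount++;   // the new value lands in cur's left subtree
                if (cur.left == null) { cur.left = new Node(val); return; }
                cur = cur.left;
            } else {
                if (cur.right == null) { cur.right = new Node(val); return; }
                cur = cur.right;
            }
        }
    }

    // Follows one root-to-node path, so it runs in O(h) time.
    int kthSmallest(int k) {
        Node cur = root;
        while (cur != null) {
            int rank = cur.leftCount + 1;   // cur's 1-based rank within its own subtree
            if (k == rank) return cur.val;
            if (k < rank) {
                cur = cur.left;
            } else {
                k -= rank;                  // skip cur and its entire left subtree
                cur = cur.right;
            }
        }
        throw new IllegalArgumentException("k exceeds the number of elements");
    }

    public static void main(String[] args) {
        RankedBst bst = new RankedBst();
        for (int v : new int[] {5, 3, 6, 2, 4, 1}) bst.insert(v);
        System.out.println(bst.kthSmallest(4)); // prints 4
    }
}
```

Insertion also costs O(h), since only the leftCount fields along the insertion path change. That is what makes this approach attractive when the tree is modified often and kthSmallest is called repeatedly.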


  1. Tank Ibdah on 2015-7-18
  2. Megha Tiwari on 2015-7-21

    I think the program should return the result immediately once k reaches 0; it's unnecessary to keep filling the stack after we have obtained the desired result. Please correct me if I'm wrong.

  3. Arun Prakash on 2015-10-5

    I think the code below will give the required result with an inorder traversal. Why do we need a stack here?

    int count = 0;   // nodes visited so far, shared across recursive calls
    int result = 0;

    void findK(TreeNode p, int k) {
        if (p == null || count >= k) return;
        findK(p.left, k);
        count++;
        if (count == k) {
            result = p.val;
            return;
        }
        findK(p.right, k);
    }

    Correct me if I am wrong.

  4. Jayesh on 2015-11-22
  5. svGuy on 2016-5-2

    There is an easier recursive implementation:

    TreeNode* kth(TreeNode *root, int &num, int k) {
        if (!root) return NULL;

        TreeNode *left = kth(root->left, num, k);
        if (left) return left;

        ++num; // increment num every time you visit a node in inorder traversal
        if (num == k) return root;

        TreeNode *right = kth(root->right, num, k);
        return right;
    }

  6. Mark Naguib on 2016-6-16

    public int KthSmallest(TreeNode root, int k) {
        int c = 0;
        int result = 0;
        KthSmallest(root, k, ref c, ref result);
        return result;
    }

    public void KthSmallest(TreeNode root, int k, ref int c, ref int result) {
        if (root == null) return;
        if (c >= k) return;

        KthSmallest(root.left, k, ref c, ref result);

        c++;  // count this node in inorder position
        if (k == c) { result = root.val; return; }

        KthSmallest(root.right, k, ref c, ref result);
    }
