/*
 * Copyright 2016 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.stream.app.s3.sink;

import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;

import java.io.File;
import java.io.InputStream;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import org.springframework.beans.DirectFieldAccessor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.IntegrationTest;
import org.springframework.boot.test.SpringApplicationConfiguration;
import org.springframework.cloud.stream.messaging.Sink;
import org.springframework.context.annotation.Bean;
import org.springframework.http.MediaType;
import org.springframework.integration.aws.outbound.S3MessageHandler;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.integration.test.util.TestUtils;
import org.springframework.messaging.Message;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

import com.amazonaws.event.ProgressEvent;
import com.amazonaws.event.ProgressEventType;
import com.amazonaws.event.ProgressListener;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.services.s3.model.CannedAccessControlList;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.PutObjectRequest;
import com.amazonaws.services.s3.model.PutObjectResult;
import com.amazonaws.services.s3.model.SetObjectAclRequest;
import com.amazonaws.services.s3.transfer.PersistableTransfer;
import com.amazonaws.services.s3.transfer.internal.S3ProgressListener;
import com.amazonaws.services.s3.transfer.internal.S3ProgressPublisher;
import com.amazonaws.util.Md5Utils;
import com.amazonaws.util.StringInputStream;

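/**
 * Base class for S3 sink tests that run against a Mockito spy of the
 * {@link AmazonS3} client whose {@code putObject} and {@code setObjectAcl}
 * calls are stubbed. Each nested test class supplies its own {@code s3.*}
 * property through {@link IntegrationTest} and implements {@link #test()}.
 *
 * @author Artem Bilan
 */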
@RunWith(SpringJUnit4ClassRunner.class)
@SpringApplicationConfiguration
@DirtiesContext
@TestPropertySource(properties = {
		"cloud.aws.stack.auto=false",
		"cloud.aws.credentials.accessKey=" + AmazonS3SinkMockTests.AWS_ACCESS_KEY,
		"cloud.aws.credentials.secretKey=" + AmazonS3SinkMockTests.AWS_SECRET_KEY,
		"cloud.aws.region.static=" + AmazonS3SinkMockTests.AWS_REGION,
		"s3.bucket=" + AmazonS3SinkMockTests.S3_BUCKET })
public abstract class AmazonS3SinkMockTests {

	/** Placeholder value used for {@code cloud.aws.credentials.accessKey} in the class-level test properties. */
	protected static final String AWS_ACCESS_KEY = "test.accessKey";

	/** Placeholder value used for {@code cloud.aws.credentials.secretKey} in the class-level test properties. */
	protected static final String AWS_SECRET_KEY = "test.secretKey";

	/** Value used for {@code cloud.aws.region.static} in the class-level test properties. */
	protected static final String AWS_REGION = "us-gov-west-1";

	/** Value used for {@code s3.bucket}; the tests expect it as the bucket name of the captured requests. */
	protected static final String S3_BUCKET = "S3_BUCKET";

	/** JUnit rule supplying a temporary folder; the file-upload test creates its payload file in it. */
	@Rule
	public TemporaryFolder temporaryFolder = new TemporaryFolder();

	/** S3 client injected from the application context; {@link #setupTest()} wraps it in a spy. */
	@Autowired
	private AmazonS3Client amazonS3;

	/**
	 * Handler whose {@code transferManager} property is read reflectively so that the spied
	 * client can be installed and later retrieved by the tests.
	 */
	@Autowired
	protected S3MessageHandler s3MessageHandler;

	/** Sink binding; the tests send their messages to its {@code input()} channel. */
	@Autowired
	protected Sink channels;

	/** Latch counted down by the stubbed {@code setObjectAcl} call configured in {@link #setupTest()}. */
	@Autowired
	protected CountDownLatch aclLatch;

	/**
	 * Latch the file-upload test awaits after publishing the transfer-completed event.
	 * NOTE(review): presumably the same bean the progress listener in {@code S3SinkApplication}
	 * counts down, resolved by field name since two {@code CountDownLatch} beans exist - confirm.
	 */
	@Autowired
	protected CountDownLatch transferCompletedLatch;

	/**
	 * Replaces the {@code s3} field of the handler's {@code transferManager} with a spy of the
	 * injected client. On that spy {@code putObject} is stubbed to return an empty
	 * {@link PutObjectResult}, and {@code setObjectAcl} is stubbed to count down
	 * {@link #aclLatch} and return {@code null}.
	 */
	@Before
	public void setupTest() {
		// Stub for uploads: hand back a fresh, empty result on every call.
		Answer<PutObjectResult> uploadAnswer = new Answer<PutObjectResult>() {

			@Override
			public PutObjectResult answer(InvocationOnMock invocation) throws Throwable {
				return new PutObjectResult();
			}

		};

		// Stub for ACL updates: only signal the latch the test is waiting on.
		Answer<Object> aclAnswer = new Answer<Object>() {

			@Override
			public Object answer(InvocationOnMock invocation) throws Throwable {
				AmazonS3SinkMockTests.this.aclLatch.countDown();
				return null;
			}

		};

		AmazonS3 spiedClient = spy(this.amazonS3);
		willAnswer(uploadAnswer).given(spiedClient).putObject(any(PutObjectRequest.class));
		willAnswer(aclAnswer).given(spiedClient).setObjectAcl(any(SetObjectAclRequest.class));

		// Write the spy straight into the transfer manager's "s3" field.
		DirectFieldAccessor transferManagerAccessor =
				new DirectFieldAccessor(TestUtils.getPropertyValue(this.s3MessageHandler, "transferManager"));
		transferManagerAccessor.setPropertyValue("s3", spiedClient);
	}

	/**
	 * The single scenario each nested subclass implements and annotates with {@code @Test}.
	 *
	 * @throws Exception if the scenario fails unexpectedly
	 */
	public abstract void test() throws Exception;

	/**
	 * Sends a {@link File} payload with {@code s3.acl=PublicReadWrite} configured and checks the
	 * captured {@link PutObjectRequest} as well as the {@link SetObjectAclRequest} issued once the
	 * transfer-completed event has been published.
	 */
	@IntegrationTest({ "s3.acl=PublicReadWrite" })
	public static class AmazonS3UploadFileTests extends AmazonS3SinkMockTests {

		@Test
		@Override
		public void test() throws Exception {
			// The spy installed by the base class' @Before method.
			AmazonS3 s3Spy = TestUtils.getPropertyValue(this.s3MessageHandler, "transferManager.s3",
					AmazonS3.class);

			File mp3 = this.temporaryFolder.newFile("foo.mp3");
			Message<?> uploadMessage = MessageBuilder.withPayload(mp3).build();
			this.channels.input().send(uploadMessage);

			ArgumentCaptor<PutObjectRequest> putCaptor = ArgumentCaptor.forClass(PutObjectRequest.class);
			verify(s3Spy, atLeastOnce()).putObject(putCaptor.capture());
			PutObjectRequest putRequest = putCaptor.getValue();
			assertPutRequest(putRequest, mp3);

			// The test publishes the completion event itself through the listener attached to the request.
			ProgressListener requestListener = putRequest.getGeneralProgressListener();
			S3ProgressPublisher.publishProgress(requestListener, ProgressEventType.TRANSFER_COMPLETED_EVENT);

			boolean transferCompleted = this.transferCompletedLatch.await(10, TimeUnit.SECONDS);
			assertTrue(transferCompleted);
			boolean aclApplied = this.aclLatch.await(10, TimeUnit.SECONDS);
			assertTrue(aclApplied);

			ArgumentCaptor<SetObjectAclRequest> aclCaptor = ArgumentCaptor.forClass(SetObjectAclRequest.class);
			verify(s3Spy).setObjectAcl(aclCaptor.capture());
			assertAclRequest(aclCaptor.getValue());
		}

		/**
		 * Checks target, payload kind and metadata of the captured upload request.
		 */
		private static void assertPutRequest(PutObjectRequest request, File expectedFile) throws Exception {
			assertThat(request.getBucketName(), equalTo(S3_BUCKET));
			assertThat(request.getKey(), equalTo("foo.mp3"));
			assertNotNull(request.getFile());
			assertNull(request.getInputStream());

			ObjectMetadata uploadMetadata = request.getMetadata();
			assertThat(uploadMetadata.getContentMD5(), equalTo(Md5Utils.md5AsBase64(expectedFile)));
			assertThat(uploadMetadata.getContentLength(), equalTo(0L));
			assertThat(uploadMetadata.getContentType(), equalTo("audio/mpeg"));
		}

		/**
		 * Checks that the ACL request targets the uploaded object and carries the canned ACL only.
		 */
		private static void assertAclRequest(SetObjectAclRequest request) {
			assertThat(request.getBucketName(), equalTo(S3_BUCKET));
			assertThat(request.getKey(), equalTo("foo.mp3"));
			assertNull(request.getAcl());
			assertThat(request.getCannedAcl(), equalTo(CannedAccessControlList.PublicReadWrite));
		}

	}

	/**
	 * Sends an {@link InputStream} payload with the object key taken from the {@code key} header
	 * ({@code s3.key-expression=headers.key}) and checks the captured {@link PutObjectRequest},
	 * including the metadata contributed for stream payloads.
	 */
	@IntegrationTest({ "s3.key-expression=headers.key" })
	public static class AmazonS3UploadInputStreamTests extends AmazonS3SinkMockTests {

		@Test
		@Override
		public void test() throws Exception {
			// The spy installed by the base class' @Before method.
			AmazonS3 s3Spy = TestUtils.getPropertyValue(this.s3MessageHandler, "transferManager.s3",
					AmazonS3.class);

			InputStream content = new StringInputStream("a");
			this.channels.input().send(
					MessageBuilder.withPayload(content)
							.setHeader("key", "myInputStream")
							.build());

			ArgumentCaptor<PutObjectRequest> requestCaptor = ArgumentCaptor.forClass(PutObjectRequest.class);
			verify(s3Spy, atLeastOnce()).putObject(requestCaptor.capture());
			PutObjectRequest captured = requestCaptor.getValue();

			assertThat(captured.getBucketName(), equalTo(S3_BUCKET));
			assertThat(captured.getKey(), equalTo("myInputStream"));
			assertNull(captured.getFile());
			assertNotNull(captured.getInputStream());

			ObjectMetadata uploadMetadata = captured.getMetadata();
			// The expected digest is computed from the very stream instance that was sent.
			// NOTE(review): presumably the handler rewinds the stream after its own digest pass - confirm.
			String expectedMd5 = Md5Utils.md5AsBase64(content);
			assertThat(uploadMetadata.getContentMD5(), equalTo(expectedMd5));
			assertThat(uploadMetadata.getContentLength(), equalTo(1L));
			assertThat(uploadMetadata.getContentType(), equalTo(MediaType.APPLICATION_JSON_VALUE));
			assertThat(uploadMetadata.getContentDisposition(), equalTo("test.json"));
		}

	}

	/**
	 * Boot application used by the tests. It contributes two single-count latches, a progress
	 * listener that counts down the transfer-completed latch, and a metadata provider that fills
	 * in fixed values for {@link InputStream} payloads.
	 */
	@SpringBootApplication
	public static class S3SinkApplication {

		/** Latch with a count of one, exposed as the {@code aclLatch} bean. */
		@Bean
		public CountDownLatch aclLatch() {
			return new CountDownLatch(1);
		}

		/** Latch with a count of one, exposed as the {@code transferCompletedLatch} bean. */
		@Bean
		public CountDownLatch transferCompletedLatch() {
			return new CountDownLatch(1);
		}

		/**
		 * Listener that ignores persistable-transfer callbacks and counts down
		 * {@link #transferCompletedLatch()} on {@code TRANSFER_COMPLETED_EVENT}.
		 */
		@Bean
		public S3ProgressListener s3ProgressListener() {
			S3ProgressListener completionListener = new S3ProgressListener() {

				@Override
				public void onPersistableTransfer(PersistableTransfer persistableTransfer) {
					// Nothing to do for this callback.
				}

				@Override
				public void progressChanged(ProgressEvent progressEvent) {
					ProgressEventType eventType = progressEvent.getEventType();
					if (eventType != ProgressEventType.TRANSFER_COMPLETED_EVENT) {
						return;
					}
					transferCompletedLatch().countDown();
				}

			};
			return completionListener;
		}

		/**
		 * Provider that sets content length, content type and content disposition when the
		 * message payload is an {@link InputStream}; other payloads are left untouched.
		 */
		@Bean
		public S3MessageHandler.UploadMetadataProvider uploadMetadataProvider() {
			S3MessageHandler.UploadMetadataProvider streamMetadataProvider =
					new S3MessageHandler.UploadMetadataProvider() {

						@Override
						public void populateMetadata(ObjectMetadata metadata, Message<?> message) {
							Object payload = message.getPayload();
							if (!(payload instanceof InputStream)) {
								return;
							}
							metadata.setContentDisposition("test.json");
							metadata.setContentType(MediaType.APPLICATION_JSON_VALUE);
							metadata.setContentLength(1);
						}

					};
			return streamMetadataProvider;
		}
	}

}