Add tests for A1ClientFactory

Change-Id: I751fdaf474de3c2fbe5a7ee591284a9af740d4a2
Signed-off-by: elinuxhenrik <henrik.b.andersson@est.tech>
diff --git a/policy-agent/src/main/java/org/oransc/policyagent/clients/A1ClientFactory.java b/policy-agent/src/main/java/org/oransc/policyagent/clients/A1ClientFactory.java
index c150c08..e340e60 100644
--- a/policy-agent/src/main/java/org/oransc/policyagent/clients/A1ClientFactory.java
+++ b/policy-agent/src/main/java/org/oransc/policyagent/clients/A1ClientFactory.java
@@ -22,13 +22,15 @@
 
 import org.oransc.policyagent.clients.A1Client.A1ProtocolType;
 import org.oransc.policyagent.configuration.ApplicationConfig;
-import org.oransc.policyagent.exceptions.ServiceException;
 import org.oransc.policyagent.repository.Ric;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Autowired;
 import reactor.core.publisher.Mono;
 
+/**
+ * Factory for A1 clients that supports four different protocol versions of the A1 api.
+ */
 public class A1ClientFactory {
 
     private static final Logger logger = LoggerFactory.getLogger(A1ClientFactory.class);
@@ -40,6 +42,19 @@
         this.appConfig = appConfig;
     }
 
+    /**
+     * Creates an A1 client with the correct A1 protocol for the provided Ric.
+     *
+     * <p>It detects the protocol version by trial and error, since there is no getVersion method specified in the A1
+     * api yet.
+     *
+     * <p>As a side effect it also sets the protocol version in the provided Ric. This means that after the first
+     * successful creation it won't have to try which protocol to use, but can create the client directly.
+     *
+     * @param ric The Ric to get a client for.
+     * @return a client with the correct protocol, or a ServiceException if none of the protocols are supported by the
+     *         Ric.
+     */
     public Mono<A1Client> createA1Client(Ric ric) {
         return getProtocolVersion(ric) //
             .flatMap(version -> createA1Client(ric, version));
@@ -49,21 +64,20 @@
         if (version == A1ProtocolType.STD_V1) {
             return Mono.just(createStdA1ClientImpl(ric));
         } else if (version == A1ProtocolType.OSC_V1) {
-            return Mono.just(new OscA1Client(ric.getConfig()));
+            return Mono.just(createOscA1Client(ric));
         } else if (version == A1ProtocolType.SDNC_OSC) {
             return Mono.just(createSdncOscA1Client(ric));
-        } else if (version == A1ProtocolType.SDNR_ONAP) {
+        } else { // A1ProtocolType.SDNR_ONAP
             return Mono.just(createSdnrOnapA1Client(ric));
         }
-        return Mono.error(new ServiceException("Not supported protocoltype: " + version));
     }
 
     private Mono<A1Client.A1ProtocolType> getProtocolVersion(Ric ric) {
         if (ric.getProtocolVersion() == A1ProtocolType.UNKNOWN) {
-            return fetchVersion(ric, createSdnrOnapA1Client(ric)) //
-                .onErrorResume(err -> fetchVersion(ric, createSdncOscA1Client(ric)))
-                .onErrorResume(err -> fetchVersion(ric, new OscA1Client(ric.getConfig())))
-                .onErrorResume(err -> fetchVersion(ric, createStdA1ClientImpl(ric)))
+            return fetchVersion(createSdnrOnapA1Client(ric)) //
+                .onErrorResume(err -> fetchVersion(createSdncOscA1Client(ric))) //
+                .onErrorResume(err -> fetchVersion(createOscA1Client(ric))) //
+                .onErrorResume(err -> fetchVersion(createStdA1ClientImpl(ric))) //
                 .doOnNext(version -> ric.setProtocolVersion(version))
                 .doOnNext(version -> logger.debug("Recover ric: {}, protocol version:{}", ric.name(), version)) //
                 .doOnError(t -> logger.warn("Could not get protocol version from RIC: {}", ric.name())); //
@@ -72,6 +86,10 @@
         }
     }
 
+    protected A1Client createOscA1Client(Ric ric) {
+        return new OscA1Client(ric.getConfig());
+    }
+
     protected A1Client createStdA1ClientImpl(Ric ric) {
         return new StdA1Client(ric.getConfig());
     }
@@ -86,9 +104,8 @@
             appConfig.getA1ControllerUsername(), appConfig.getA1ControllerPassword());
     }
 
-    private Mono<A1Client.A1ProtocolType> fetchVersion(Ric ric, A1Client a1Client) {
+    private Mono<A1ProtocolType> fetchVersion(A1Client a1Client) {
         return Mono.just(a1Client) //
             .flatMap(client -> a1Client.getProtocolVersion());
     }
-
 }
diff --git a/policy-agent/src/test/java/org/oransc/policyagent/clients/A1ClientFactoryTest.java b/policy-agent/src/test/java/org/oransc/policyagent/clients/A1ClientFactoryTest.java
new file mode 100644
index 0000000..926485d
--- /dev/null
+++ b/policy-agent/src/test/java/org/oransc/policyagent/clients/A1ClientFactoryTest.java
@@ -0,0 +1,212 @@
+/*-
+ * ========================LICENSE_START=================================
+ * O-RAN-SC
+ * %%
+ * Copyright (C) 2019 Nordix Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ========================LICENSE_END===================================
+ */
+
+package org.oransc.policyagent.clients;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+import ch.qos.logback.classic.Level;
+import ch.qos.logback.classic.spi.ILoggingEvent;
+import ch.qos.logback.core.read.ListAppender;
+import java.util.Vector;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.oransc.policyagent.clients.A1Client.A1ProtocolType;
+import org.oransc.policyagent.configuration.ApplicationConfig;
+import org.oransc.policyagent.configuration.ImmutableRicConfig;
+import org.oransc.policyagent.repository.Ric;
+import org.oransc.policyagent.utils.LoggingUtils;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+@ExtendWith(MockitoExtension.class)
+public class A1ClientFactoryTest {
+    private static final String RIC_NAME = "Name";
+    private static final String EXCEPTION_MESSAGE = "Error";
+
+    @Mock
+    private ApplicationConfig applicationConfigMock;
+
+    @Mock
+    A1Client stdA1ClientMock;
+
+    @Mock
+    A1Client oscA1ClientMock;
+
+    @Mock
+    A1Client sdncOscA1ClientMock;
+
+    @Mock
+    A1Client sdnrOnapA1ClientMock;
+
+    private ImmutableRicConfig ricConfig =
+        ImmutableRicConfig.builder().name(RIC_NAME).baseUrl("baseUrl").managedElementIds(new Vector<>()).build();
+    private Ric ric = new Ric(ricConfig);
+
+    private A1ClientFactory factoryUnderTest;
+
+    @BeforeEach
+    public void createFactoryUnderTest() {
+        factoryUnderTest = spy(new A1ClientFactory(applicationConfigMock));
+    }
+
+    @Test
+    public void createStd_ok() {
+        whenGetProtocolVersionSdnrOnapA1ClientThrowException();
+        whenGetProtocolVersionSdncOscA1ClientThrowException();
+        whenGetProtocolVersionOscA1ClientThrowException();
+        whenGetProtocolVersionStdA1ClientReturnCorrectProtocol();
+
+        StepVerifier.create(factoryUnderTest.createA1Client(ric)) //
+            .expectSubscription() //
+            .expectNext(stdA1ClientMock) //
+            .verifyComplete();
+
+        assertEquals(A1ProtocolType.STD_V1, ric.getProtocolVersion(), "Not correct protocol");
+    }
+
+    @Test
+    public void createOsc_ok() {
+        whenGetProtocolVersionSdnrOnapA1ClientThrowException();
+        whenGetProtocolVersionSdncOscA1ClientThrowException();
+        whenGetProtocolVersionOscA1ClientReturnCorrectProtocol();
+
+        StepVerifier.create(factoryUnderTest.createA1Client(ric)) //
+            .expectSubscription() //
+            .expectNext(oscA1ClientMock) //
+            .verifyComplete();
+
+        assertEquals(A1ProtocolType.OSC_V1, ric.getProtocolVersion(), "Not correct protocol");
+    }
+
+    @Test
+    public void createSdncOsc_ok() {
+        whenGetProtocolVersionSdnrOnapA1ClientThrowException();
+        whenGetProtocolVersionSdncOscA1ClientReturnCorrectProtocol();
+
+        StepVerifier.create(factoryUnderTest.createA1Client(ric)) //
+            .expectSubscription() //
+            .expectNext(sdncOscA1ClientMock) //
+            .verifyComplete();
+
+        assertEquals(A1ProtocolType.SDNC_OSC, ric.getProtocolVersion(), "Not correct protocol");
+    }
+
+    @Test
+    public void createSdnrOnap_ok() {
+        whenGetProtocolVersionSdnrOnapA1ClientReturnCorrectProtocol();
+
+        StepVerifier.create(factoryUnderTest.createA1Client(ric)) //
+            .expectSubscription() //
+            .expectNext(sdnrOnapA1ClientMock) //
+            .verifyComplete();
+
+        assertEquals(A1ProtocolType.SDNR_ONAP, ric.getProtocolVersion(), "Not correct protocol");
+    }
+
+    @Test
+    public void createWithNoProtocol_error() {
+        whenGetProtocolVersionSdnrOnapA1ClientThrowException();
+        whenGetProtocolVersionSdncOscA1ClientThrowException();
+        whenGetProtocolVersionOscA1ClientThrowException();
+        whenGetProtocolVersionStdA1ClientThrowException();
+
+        final ListAppender<ILoggingEvent> logAppender = LoggingUtils.getLogListAppender(A1ClientFactory.class);
+        StepVerifier.create(factoryUnderTest.createA1Client(ric)) //
+            .expectSubscription() //
+            .expectErrorMatches(
+                throwable -> throwable instanceof Exception && throwable.getMessage().equals(EXCEPTION_MESSAGE))
+            .verify();
+
+        assertEquals(Level.WARN, logAppender.list.get(0).getLevel(), "Warning not logged");
+        assertTrue(logAppender.list.toString().contains("Could not get protocol version from RIC: " + RIC_NAME),
+            "Correct message not logged");
+
+        assertEquals(A1ProtocolType.UNKNOWN, ric.getProtocolVersion(), "Not correct protocol");
+    }
+
+    @Test
+    public void createWithProtocolInRic_noTrialAndError() {
+        doReturn(stdA1ClientMock).when(factoryUnderTest).createStdA1ClientImpl(any(Ric.class));
+
+        ric.setProtocolVersion(A1ProtocolType.STD_V1);
+
+        StepVerifier.create(factoryUnderTest.createA1Client(ric)) //
+            .expectSubscription() //
+            .expectNext(stdA1ClientMock) //
+            .verifyComplete();
+
+        assertEquals(A1ProtocolType.STD_V1, ric.getProtocolVersion(), "Not correct protocol");
+
+        verifyNoMoreInteractions(sdnrOnapA1ClientMock);
+        verifyNoMoreInteractions(sdncOscA1ClientMock);
+        verifyNoMoreInteractions(oscA1ClientMock);
+        verifyNoMoreInteractions(stdA1ClientMock);
+    }
+
+    private void whenGetProtocolVersionSdnrOnapA1ClientThrowException() {
+        doReturn(sdnrOnapA1ClientMock).when(factoryUnderTest).createSdnrOnapA1Client(ric);
+        when(sdnrOnapA1ClientMock.getProtocolVersion()).thenReturn(Mono.error(new Exception(EXCEPTION_MESSAGE)));
+    }
+
+    private void whenGetProtocolVersionSdnrOnapA1ClientReturnCorrectProtocol() {
+        doReturn(sdnrOnapA1ClientMock).when(factoryUnderTest).createSdnrOnapA1Client(any(Ric.class));
+        when(sdnrOnapA1ClientMock.getProtocolVersion()).thenReturn(Mono.just(A1ProtocolType.SDNR_ONAP));
+    }
+
+    private void whenGetProtocolVersionSdncOscA1ClientThrowException() {
+        doReturn(sdncOscA1ClientMock).when(factoryUnderTest).createSdncOscA1Client(any(Ric.class));
+        when(sdncOscA1ClientMock.getProtocolVersion()).thenReturn(Mono.error(new Exception(EXCEPTION_MESSAGE)));
+    }
+
+    private void whenGetProtocolVersionSdncOscA1ClientReturnCorrectProtocol() {
+        doReturn(sdncOscA1ClientMock).when(factoryUnderTest).createSdncOscA1Client(any(Ric.class));
+        when(sdncOscA1ClientMock.getProtocolVersion()).thenReturn(Mono.just(A1ProtocolType.SDNC_OSC));
+    }
+
+    private void whenGetProtocolVersionOscA1ClientThrowException() {
+        doReturn(oscA1ClientMock).when(factoryUnderTest).createOscA1Client(any(Ric.class));
+        when(oscA1ClientMock.getProtocolVersion()).thenReturn(Mono.error(new Exception(EXCEPTION_MESSAGE)));
+    }
+
+    private void whenGetProtocolVersionOscA1ClientReturnCorrectProtocol() {
+        doReturn(oscA1ClientMock).when(factoryUnderTest).createOscA1Client(any(Ric.class));
+        when(oscA1ClientMock.getProtocolVersion()).thenReturn(Mono.just(A1ProtocolType.OSC_V1));
+    }
+
+    private void whenGetProtocolVersionStdA1ClientThrowException() {
+        doReturn(stdA1ClientMock).when(factoryUnderTest).createStdA1ClientImpl(any(Ric.class));
+        when(stdA1ClientMock.getProtocolVersion()).thenReturn(Mono.error(new Exception(EXCEPTION_MESSAGE)));
+    }
+
+    private void whenGetProtocolVersionStdA1ClientReturnCorrectProtocol() {
+        doReturn(stdA1ClientMock).when(factoryUnderTest).createStdA1ClientImpl(any(Ric.class));
+        when(stdA1ClientMock.getProtocolVersion()).thenReturn(Mono.just(A1ProtocolType.STD_V1));
+    }
+}