1 | package org.atea.nlptools.koreromaoriinterface.services;
|
---|
2 |
|
---|
3 | import java.io.InputStream;
|
---|
4 | import java.util.ArrayList;
|
---|
5 | import java.util.LinkedList;
|
---|
6 | import java.util.List;
|
---|
7 | import java.util.Queue;
|
---|
8 | import java.util.concurrent.Callable;
|
---|
9 | import java.util.concurrent.ExecutorService;
|
---|
10 | import java.util.concurrent.Future;
|
---|
11 |
|
---|
12 | import com.google.gson.Gson;
|
---|
13 | import com.google.gson.JsonSyntaxException;
|
---|
14 |
|
---|
15 | import org.apache.logging.log4j.LogManager;
|
---|
16 | import org.apache.logging.log4j.Logger;
|
---|
17 | import org.atea.nlptools.koreromaoriinterface.exceptions.ReoTuhituhiException;
|
---|
18 | import org.atea.nlptools.koreromaoriinterface.models.AudioFilePart;
|
---|
19 | import org.atea.nlptools.koreromaoriinterface.models.TranscriptionResult;
|
---|
20 | import org.atea.nlptools.koreromaoriinterface.services.HttpRequestService.HttpRequestException;
|
---|
21 |
|
---|
22 | /**
|
---|
23 | * Functions to interact with the Reo Tuhituhi API.
|
---|
24 | */
|
---|
25 | public class ReoTuhituhiApiService
|
---|
26 | {
|
---|
27 | private static final Logger logger = LogManager.getLogger(ReoTuhituhiApiService.class);
|
---|
28 |
|
---|
29 | private final ExecutorService threadPool;
|
---|
30 | private final Gson jsonSerialiser;
|
---|
31 | private final String apiEndpoint;
|
---|
32 | private final String apiKey;
|
---|
33 |
|
---|
34 | public ReoTuhituhiApiService(Gson jsonSerialiser, String apiEndpoint, String apiKey)
|
---|
35 | {
|
---|
36 | this.jsonSerialiser = jsonSerialiser;
|
---|
37 | this.apiEndpoint = apiEndpoint;
|
---|
38 | this.apiKey = apiKey;
|
---|
39 |
|
---|
40 | threadPool = java.util.concurrent.Executors.newFixedThreadPool(3);
|
---|
41 | }
|
---|
42 |
|
---|
43 | /**
|
---|
44 | * Queries the Reo Tuhituhi API to transcribe the given audio files.
|
---|
45 | *
|
---|
46 | * @param audioFileParts The audio files to retrieve a transcription for.
|
---|
47 | * @return A list of {@link TranscriptionResult} objects.
|
---|
48 | * @throws HttpRequestException When the API call fails.
|
---|
49 | * @throws JsonSyntaxException When the result cannot be parsed.
|
---|
50 | */
|
---|
51 | public List<AudioFilePart> getTranscriptions(Iterable<AudioFilePart> audioFileParts)
|
---|
52 | throws HttpRequestException, JsonSyntaxException, Exception
|
---|
53 | {
|
---|
54 | Queue<Future<AudioFilePart>> apiCalls = new LinkedList<Future<AudioFilePart>>();
|
---|
55 |
|
---|
56 | // Queue each transcription request up asynchronously
|
---|
57 | for (AudioFilePart part : audioFileParts)
|
---|
58 | {
|
---|
59 | //Callable<TranscriptionResult> transcriptionTask = getTranscriptionCallable(audioStream);
|
---|
60 | //apiCalls.add(threadPool.submit(transcriptionTask));
|
---|
61 | //logger.debug("Adding transcription task to thread pool.");
|
---|
62 | }
|
---|
63 |
|
---|
64 | List<AudioFilePart> apiResults = new ArrayList<AudioFilePart>(apiCalls.size());
|
---|
65 |
|
---|
66 | for (AudioFilePart part : audioFileParts)
|
---|
67 | {
|
---|
68 | TranscriptionResult res = getTranscription(part.dataStream);
|
---|
69 | part.setTranscriptionResult(res);
|
---|
70 |
|
---|
71 | apiResults.add(part);
|
---|
72 | }
|
---|
73 |
|
---|
74 | // Wait on the result of each call
|
---|
75 | // TODO: Implement proper timeout here
|
---|
76 | // while (!apiCalls.isEmpty())
|
---|
77 | // {
|
---|
78 | // apiResults.add(apiCalls.remove().get());
|
---|
79 | // logger.debug("API call has completed.");
|
---|
80 | // }
|
---|
81 |
|
---|
82 | return apiResults;
|
---|
83 | }
|
---|
84 |
|
---|
85 | /**
|
---|
86 | * Calls the Reo Tuhituhi API to transcribe a wave audio file.
|
---|
87 | *
|
---|
88 | * @param audioStream The wave audio stream.
|
---|
89 | * @return A {@link TranscriptionResult} object.
|
---|
90 | * @throws HttpRequestException Thrown when the API call fails.
|
---|
91 | * @throws JsonSyntaxException Thrown when the result cannot be parsed.
|
---|
92 | * @throws ReoTuhituhiException Thrown when the Reo Tuhituhi API returns an invalid response.
|
---|
93 | */
|
---|
94 | public TranscriptionResult getTranscription(InputStream audioStream)
|
---|
95 | throws HttpRequestException, JsonSyntaxException, ReoTuhituhiException
|
---|
96 | {
|
---|
97 | HttpRequestService request = HttpRequestService
|
---|
98 | .post(apiEndpoint)
|
---|
99 | .authorization("Basic " + apiKey)
|
---|
100 | .send(audioStream);
|
---|
101 |
|
---|
102 | // Check that the request returned an OK status
|
---|
103 | if (!request.ok())
|
---|
104 | {
|
---|
105 | logger.error
|
---|
106 | (
|
---|
107 | "The Reo Tuhituhi API returned a non-OK status code {} with message {}",
|
---|
108 | request.code(),
|
---|
109 | request.message()
|
---|
110 | );
|
---|
111 |
|
---|
112 | throw new ReoTuhituhiException(request.code(), request.message(), "Non-OK status code");
|
---|
113 | }
|
---|
114 |
|
---|
115 | // Check that the content type is valid
|
---|
116 | if (request.contentType() != "application/json")
|
---|
117 | {
|
---|
118 | logger.error
|
---|
119 | (
|
---|
120 | "The Reo Tuhituhi API returned an invalid content type {}. Provided content was {}",
|
---|
121 | request.contentType(),
|
---|
122 | request.body()
|
---|
123 | );
|
---|
124 |
|
---|
125 | throw new ReoTuhituhiException(request.code(), request.message(), "Invalid content type: " + request.contentType());
|
---|
126 | }
|
---|
127 |
|
---|
128 | TranscriptionResult res = jsonSerialiser.fromJson(request.body(), TranscriptionResult.class);
|
---|
129 |
|
---|
130 | return res;
|
---|
131 | }
|
---|
132 |
|
---|
133 | private Callable<AudioFilePart> getTranscriptionCallable(final AudioFilePart part)
|
---|
134 | {
|
---|
135 | return new Callable<AudioFilePart>()
|
---|
136 | {
|
---|
137 | @Override
|
---|
138 | public AudioFilePart call()
|
---|
139 | throws Exception
|
---|
140 | {
|
---|
141 | TranscriptionResult res = getTranscription(part.dataStream);
|
---|
142 | part.setTranscriptionResult(res);
|
---|
143 | return part;
|
---|
144 | }
|
---|
145 | };
|
---|
146 | }
|
---|
147 | }
|
---|