Skip to content
This repository was archived by the owner on Feb 11, 2025. It is now read-only.

Commit 5669e6a

Browse files
Add support for audio chat completion (#54)
closes #53
1 parent 8263f57 commit 5669e6a

File tree

6 files changed

+165
-8
lines changed

6 files changed

+165
-8
lines changed

src/main/java/dev/ai4j/openai4j/chat/Content.java

+17-1
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@ public final class Content {
2121
private final String text;
2222
@JsonProperty
2323
private final ImageUrl imageUrl;
24+
@JsonProperty
25+
private final InputAudio inputAudio;
2426

2527
public Content(Builder builder) {
2628
this.type = builder.type;
2729
this.text = builder.text;
2830
this.imageUrl = builder.imageUrl;
31+
this.inputAudio = builder.inputAudio;
2932
}
3033

3134
public ContentType type() {
@@ -40,6 +43,10 @@ public ImageUrl imageUrl() {
4043
return imageUrl;
4144
}
4245

46+
public InputAudio inputAudio() {
47+
return inputAudio;
48+
}
49+
4350
@Override
4451
public boolean equals(Object another) {
4552
if (this == another) return true;
@@ -50,7 +57,8 @@ public boolean equals(Object another) {
5057
private boolean equalTo(Content another) {
5158
return Objects.equals(type, another.type)
5259
&& Objects.equals(text, another.text)
53-
&& Objects.equals(imageUrl, another.imageUrl);
60+
&& Objects.equals(imageUrl, another.imageUrl)
61+
&& Objects.equals(inputAudio, another.inputAudio);
5462
}
5563

5664
@Override
@@ -59,6 +67,7 @@ public int hashCode() {
5967
h += (h << 5) + Objects.hashCode(type);
6068
h += (h << 5) + Objects.hashCode(text);
6169
h += (h << 5) + Objects.hashCode(imageUrl);
70+
h += (h << 5) + Objects.hashCode(inputAudio);
6271
return h;
6372
}
6473

@@ -68,6 +77,7 @@ public String toString() {
6877
"type=" + type +
6978
", text=" + text +
7079
", imageUrl=" + imageUrl +
80+
", inputAudio=" + inputAudio +
7181
"}";
7282
}
7383

@@ -83,6 +93,7 @@ public static final class Builder {
8393
private ContentType type;
8494
private String text;
8595
private ImageUrl imageUrl;
96+
private InputAudio inputAudio;
8697

8798
public Builder type(ContentType type) {
8899
this.type = type;
@@ -99,6 +110,11 @@ public Builder imageUrl(ImageUrl imageUrl) {
99110
return this;
100111
}
101112

113+
public Builder inputAudio(InputAudio inputAudio) {
114+
this.inputAudio = inputAudio;
115+
return this;
116+
}
117+
102118
public Content build() {
103119
return new Content(this);
104120
}

src/main/java/dev/ai4j/openai4j/chat/ContentType.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ public enum ContentType {
77
@JsonProperty("text")
88
TEXT,
99
@JsonProperty("image_url")
10-
IMAGE_URL
10+
IMAGE_URL,
11+
@JsonProperty("input_audio")
12+
AUDIO
1113
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package dev.ai4j.openai4j.chat;
2+
3+
import com.fasterxml.jackson.annotation.JsonInclude;
4+
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
5+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
6+
import com.fasterxml.jackson.databind.annotation.JsonNaming;
7+
8+
import java.util.Objects;
9+
10+
@JsonDeserialize(builder = InputAudio.Builder.class)
11+
@JsonInclude(JsonInclude.Include.NON_NULL)
12+
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
13+
public class InputAudio {
14+
15+
private final String data;
16+
private final String format;
17+
18+
private InputAudio(Builder builder) {
19+
data = builder.data;
20+
format = builder.format;
21+
}
22+
23+
public String getData() {
24+
return data;
25+
}
26+
27+
public String getFormat() {
28+
return format;
29+
}
30+
31+
@Override
32+
public boolean equals(Object another) {
33+
if (this == another) return true;
34+
return another instanceof InputAudio
35+
&& equalTo((InputAudio) another);
36+
}
37+
38+
private boolean equalTo(InputAudio another) {
39+
return Objects.equals(data, another.data)
40+
&& Objects.equals(format, another.format);
41+
}
42+
43+
@Override
44+
public int hashCode() {
45+
int h = 5381;
46+
h += (h << 5) + Objects.hashCode(data);
47+
h += (h << 5) + Objects.hashCode(format);
48+
return h;
49+
}
50+
51+
@Override
52+
public String toString() {
53+
return "InputAudio{" +
54+
"data=" + data +
55+
", format=" + format +
56+
"}";
57+
}
58+
59+
public static Builder builder() {
60+
return new Builder();
61+
}
62+
63+
public static final class Builder {
64+
65+
private String data;
66+
private String format;
67+
68+
public Builder data(String data) {
69+
this.data = data;
70+
return this;
71+
}
72+
73+
public Builder format(String format) {
74+
this.format = format;
75+
return this;
76+
}
77+
78+
public static Builder builder() {
79+
return new Builder();
80+
}
81+
82+
public InputAudio build() {
83+
return new InputAudio(this);
84+
}
85+
86+
}
87+
}

src/main/java/dev/ai4j/openai4j/chat/UserMessage.java

+20-6
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ private Builder() {
107107
}
108108

109109
public Builder addText(String text) {
110-
if (this.content == null) {
111-
this.content = new ArrayList<>();
112-
}
110+
initializeContent();
113111
Content content = Content.builder()
114112
.type(TEXT)
115113
.text(text)
@@ -123,9 +121,7 @@ public Builder addImageUrl(String imageUrl) {
123121
}
124122

125123
public Builder addImageUrl(String imageUrl, ImageDetail imageDetail) {
126-
if (this.content == null) {
127-
this.content = new ArrayList<>();
128-
}
124+
initializeContent();
129125
Content content = Content.builder()
130126
.type(IMAGE_URL)
131127
.imageUrl(ImageUrl.builder()
@@ -143,6 +139,18 @@ public Builder addImageUrls(String... imageUrls) {
143139
}
144140
return this;
145141
}
142+
143+
public Builder addInputAudio(InputAudio inputAudio) {
144+
initializeContent();
145+
this.content.add(
146+
Content.builder()
147+
.type(ContentType.AUDIO)
148+
.inputAudio(inputAudio)
149+
.build()
150+
);
151+
152+
return this;
153+
}
146154

147155
public Builder content(List<Content> content) {
148156
if (content != null) {
@@ -164,5 +172,11 @@ public Builder name(String name) {
164172
public UserMessage build() {
165173
return new UserMessage(this);
166174
}
175+
176+
private void initializeContent() {
177+
if (this.content == null) {
178+
this.content = new ArrayList<>();
179+
}
180+
}
167181
}
168182
}

src/test/java/dev/ai4j/openai4j/chat/ChatCompletionAsyncTest.java

+37
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
import org.junit.jupiter.params.ParameterizedTest;
77
import org.junit.jupiter.params.provider.EnumSource;
88

9+
import java.net.URL;
10+
import java.nio.file.Files;
11+
import java.nio.file.Paths;
912
import java.util.Map;
1013
import java.util.concurrent.CompletableFuture;
1114

@@ -479,4 +482,38 @@ void testGpt4Vision() throws Exception {
479482
// then
480483
assertThat(response.content()).containsIgnoringCase("cat");
481484
}
485+
486+
@Test
487+
void testGpt4Audio() throws Exception {
488+
489+
// given
490+
URL resource = getClass().getClassLoader().getResource("sample.b64");
491+
final byte[] bytes = Files.readAllBytes(Paths.get(resource.toURI()));
492+
493+
ChatCompletionRequest request = ChatCompletionRequest.builder()
494+
.model("gpt-4o-audio-preview")
495+
.messages(UserMessage.builder()
496+
.addText("What is on the audio?")
497+
.addInputAudio(InputAudio.builder()
498+
.format("wav")
499+
.data(new String(bytes))
500+
.build())
501+
.build())
502+
.maxCompletionTokens(100)
503+
.temperature(0.0)
504+
.build();
505+
506+
CompletableFuture<ChatCompletionResponse> future = new CompletableFuture<>();
507+
508+
// when
509+
client.chatCompletion(request)
510+
.onResponse(future::complete)
511+
.onError(future::completeExceptionally)
512+
.execute();
513+
514+
ChatCompletionResponse response = future.get(30, SECONDS);
515+
516+
// then
517+
assertThat(response.content()).containsIgnoringCase("hello");
518+
}
482519
}

src/test/resources/sample.b64

+1
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)