Skip to content

Commit

Permalink
prompt reworking
Browse files Browse the repository at this point in the history
  • Loading branch information
markpollack committed Jul 29, 2023
1 parent 1b3d6a5 commit 4159758
Show file tree
Hide file tree
Showing 11 changed files with 228 additions and 146 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.springframework.ai.core.prompts;

import java.util.Map;

import org.springframework.ai.core.prompts.messages.AiMessage;

public class AiPromptTemplate extends PromptTemplate {

private boolean example = false;

public AiPromptTemplate(String template) {
super(template);
}

public AiPromptTemplate(String template, boolean example) {
super(template);
this.example = example;
}

@Override
public Prompt create() {
return new Prompt(new AiMessage(render()));
}

@Override
public Prompt create(Map<String, Object> model) {
return new Prompt(new AiMessage(render(model), this.example));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.springframework.ai.core.prompts;

import java.util.Map;

import org.springframework.ai.core.prompts.messages.ChatMessage;

public class ChatPromptTemplate extends PromptTemplate {

private String role;

public ChatPromptTemplate(String template) {
super(template);
}

public ChatPromptTemplate(String template, String role) {
super(template);
this.role = role;
}

@Override
public Prompt create() {
return new Prompt(new ChatMessage(render(), this.role));
}

@Override
public Prompt create(Map<String, Object> model) {
return new Prompt(new ChatMessage(render(model), this.role));
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.springframework.ai.core.prompts;

public class FunctionPromptTemplate extends PromptTemplate {

private String name;

public FunctionPromptTemplate(String template) {
super(template);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ public Prompt(String contents) {
this.messages = Collections.singletonList(new HumanMessage(contents));
}

public Prompt(Message message) {
this.messages = Collections.singletonList(message);
}

public Prompt(List<Message> messages) {
this.messages = messages;
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.util.Map;

import org.springframework.ai.core.prompts.messages.MessageType;

public interface PromptOperations {

String getTemplate();
Expand All @@ -30,6 +32,8 @@ public interface PromptOperations {

String render(Map<String, Object> model);

PromptBuilder prompt();
Prompt create();

Prompt create(Map<String, Object> model);

}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ public String render(Map<String, Object> model) {
return st.render().trim();
}

@Override
public Prompt create() {
return new Prompt(render(new HashMap<>()));
}

@Override
public Prompt create(Map<String, Object> model) {
return new Prompt(render(model));
}

protected Set<String> getInputVariables() {
TokenStream tokens = this.st.impl.tokens;
return IntStream.range(0, tokens.range())
Expand All @@ -97,94 +107,4 @@ protected void validate(Map<String, Object> model) {
"All template variables were not replaced. Missing variable names are " + missingEntries);
}
}

@Override
public PromptBuilder prompt() {
return new PromptTemplatePromptBuilder();
}

public class PromptTemplatePromptBuilder implements PromptBuilder {

private MessageType messageType = MessageType.HUMAN;

private Map<String, Object> model = new HashMap<>();

private Map<String, Object> properties = new HashMap<>();

private boolean containsExample;

private String chatRole;

private String functionName;

@Override
public PromptBuilder system() {
this.messageType = MessageType.SYSTEM;
return this;
}

@Override
public PromptBuilder human() {
this.messageType = MessageType.HUMAN;
return this;
}

@Override
public PromptBuilder ai(boolean containsExample) {
this.messageType = MessageType.AI;
this.containsExample = containsExample;
return this;
}

@Override
public PromptBuilder chat(String chatRole) {
this.messageType = MessageType.CHAT;
this.chatRole = chatRole;
return this;
}

@Override
public PromptBuilder function(String functionName) {
this.messageType = MessageType.FUNCTION;
this.functionName = functionName;
return this;
}

@Override
public PromptBuilder usingModel(Map<String, Object> model) {
this.model = model;
return this;
}

@Override
public PromptBuilder withProperties(Map<String, Object> properties) {
this.properties = properties;
return this;
}

@Override
public Prompt create() {

switch (messageType) {
case HUMAN:
return newPrompt(new HumanMessage(render(model), properties));
case AI:
return newPrompt(new AiMessage(render(model), containsExample, properties));
case CHAT:
return newPrompt(new ChatMessage(render(model), chatRole, properties));
case SYSTEM:
return newPrompt(new SystemMessage(render(model), properties));
case FUNCTION:
return newPrompt(new FunctionMessage(render(model), functionName, properties));
default:
throw new IllegalArgumentException("Invalid MessageType: " + messageType);
}
}

private Prompt newPrompt(Message message) {
return new Prompt(Collections.singletonList(message));
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.springframework.ai.core.prompts;

import java.util.Map;

import org.springframework.ai.core.prompts.messages.SystemMessage;

public class SystemPromptTemplate extends PromptTemplate {

public SystemPromptTemplate(String template) {
super(template);
}

@Override
public Prompt create() {
return new Prompt(new SystemMessage(render()));
}

@Override
public Prompt create(Map<String, Object> model) {
return new Prompt(new SystemMessage(render(model)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public interface Message {

String getContent();

//TODO investigate use of "function_name" and "name" - maybe cna be first class representation vs. map.
Map<String, Object> getProperties();

MessageType getMessageType();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.springframework.ai.core.prompts;

import org.junit.jupiter.api.Test;

public class ChatTests {

// @Test
// void testChat() {
//
// String customerStyle = "American English in a calm and respectful tone";
// String customerEmail = "Arrr, I be fuming that me blender lid "
// + "flew off and splattered me kitchen walls "
// + "with smoothie! And to make matters worse, "
// + "the warranty don't cover the cost of "
// + "cleaning up me kitchen. I need yer help "
// + "right now, matey!";
// ChatOpenAi chatOpenAi = new ChatOpenAi();
// chatOpenAi
//
// }
}
Loading

0 comments on commit 4159758

Please sign in to comment.