Skip to content

Commit

Permalink
ml wip pt 2
Browse files Browse the repository at this point in the history
  • Loading branch information
sava41 committed Jan 4, 2025
1 parent d86766f commit 44c082d
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 37 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Cross platform 2d infinite canvas app inspired by [dingboard.com](https://dingbo
**Add and Edit Images**
![Image Feature Gif](./resources/images/image.gif)

**Auto Image Segmentation (Desktop Only)**
![Segmentation Feature Gif](./resources/images/image.gif)

## Requirements:
- CMake 3.28 or later
- Python 3.8 or newer
Expand Down Expand Up @@ -48,7 +51,6 @@ https://emscripten.org/docs/getting_started/downloads.html#installation-instruct
```

## Planned Features
- layer AI masking
- canvas history (undo/redo)
- copy/paste
- keyboard shortcuts
Expand Down
20 changes: 10 additions & 10 deletions source/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ namespace mc
addImageToLayer( app, imageData, width, height, 4 );
}
},
app, nullptr, filters, 3, nullptr, SDL_FALSE );
app, nullptr, filters, 3, nullptr, false );

#endif
}
Expand All @@ -129,16 +129,16 @@ namespace mc
{
AppContext* app = reinterpret_cast<mc::AppContext*>( userdata );
#endif
int width = app->textureManager.get( *app->copyTextureHandle.get() ).texture.GetWidth();
int height = app->textureManager.get( *app->copyTextureHandle.get() ).texture.GetHeight();
app->copyTextureHandle.reset();
int width = app->textureManager.get( *app->copyTextureHandle.get() ).texture.GetWidth();
int height = app->textureManager.get( *app->copyTextureHandle.get() ).texture.GetHeight();
app->copyTextureHandle.reset();

const uint8_t* imageData = reinterpret_cast<const uint8_t*>( app->textureMapBuffer.GetConstMappedRange( 0, app->textureMapBuffer.GetSize() ) );
const uint8_t* imageData = reinterpret_cast<const uint8_t*>( app->textureMapBuffer.GetConstMappedRange( 0, app->textureMapBuffer.GetSize() ) );

if( imageData == nullptr )
{
app->textureMapBuffer.Unmap();
return;
if( imageData == nullptr )
{
app->textureMapBuffer.Unmap();
return;
}

// we need to find the stride since the buffer is padded to be a multiple of 256
Expand All @@ -165,7 +165,7 @@ namespace mc
},
app, nullptr, filters, 1, nullptr );
#else
emscripten_browser_file::download( "miskeen.png", "image/png", data.get(), length );
emscripten_browser_file::download( "miskeen.png", "image/png", data.get(), length );
#endif
}
} // namespace mc
21 changes: 14 additions & 7 deletions source/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ SDL_AppResult SDL_AppInit( void** appstate, int argc, char* argv[] )
mc::initPipelines( app );
mc::initImageProcessingPipelines( app );

// print some information about the window
SDL_ShowWindow( app->window );
if( SDL_ShowWindow( app->window ) )
{
SDL_Log( "Window size: %ix%i", app->width, app->height );
SDL_Log( "Backbuffer size: %ix%i", app->bbwidth, app->bbheight );
Expand All @@ -112,8 +111,9 @@ SDL_AppResult SDL_AppInit( void** appstate, int argc, char* argv[] )

app->fontManager.init( app->textureManager, app->device, app->meshManager.getMeshInfo( mc::UnitSquareMeshIndex ) );

app->mlInference = std::make_unique<mc::MlInference>( "C:/Users/sava4/OneDrive/Desktop/miskeenity-canvas/build/Release/sam_preprocess.onnx",
"sam_vit_h_4b8939.onnx", std::thread::hardware_concurrency() );
#if !defined( SDL_PLATFORM_EMSCRIPTEN )
app->mlInference = std::make_unique<mc::MlInference>( "sam_preprocess.onnx", "sam_vit_h_4b8939.onnx", std::thread::hardware_concurrency() );
#endif

SDL_Log( "Application started successfully!" );

Expand Down Expand Up @@ -348,6 +348,13 @@ void proccessUserEvent( const SDL_Event* sdlEvent, mc::AppContext* app )
app->layerEditStart += 1;
}
}
else if( app->mode == mc::Mode::SegmentCut )
{
app->editMaskTextureHandle.reset();
// app->editMaskTextureHandle = std::make_unique<mc::ResourceHandle>(
// app->textureManager.add( nullptr, app->mlInference->getMaxWidth(), app->mlInference->getMaxHeight(), 4, app->device,
// wgpu::TextureUsage::CopyDst | wgpu::TextureUsage::TextureBinding ) );
}
else
{
app->layers.clearSelection();
Expand Down Expand Up @@ -456,7 +463,7 @@ void proccessUserEvent( const SDL_Event* sdlEvent, mc::AppContext* app )
}
}

SDL_AppResult SDL_AppEvent( void* appstate, const SDL_Event* event )
SDL_AppResult SDL_AppEvent( void* appstate, SDL_Event* event )
{
mc::AppContext* app = reinterpret_cast<mc::AppContext*>( appstate );

Expand Down Expand Up @@ -1095,7 +1102,7 @@ SDL_AppResult SDL_AppIterate( void* appstate )
app->rasterizeSelection = false;
}

if( app->mode == mc::Mode::Cut && app->layers.length() > 0 )
if( ( app->mode == mc::Mode::Cut ) && app->layers.length() > 0 )
{
int index = app->layers.getSingleSelectedImage();
mc::Layer layer = app->layers.data()[index];
Expand Down Expand Up @@ -1227,7 +1234,7 @@ SDL_AppResult SDL_AppIterate( void* appstate )
return app->appQuit ? SDL_APP_SUCCESS : SDL_APP_CONTINUE;
}

void SDL_AppQuit( void* appstate )
void SDL_AppQuit( void* appstate, SDL_AppResult result )
{
mc::shutdownUI();

Expand Down
41 changes: 33 additions & 8 deletions source/ml_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "image.h"

#include <SDL3/SDL.h>
#include <SDL3/SDL_filesystem.h>
#define ORT_API_MANUAL_INIT
#include <codecvt>
#include <filesystem>
Expand Down Expand Up @@ -34,8 +35,13 @@ namespace mc
MlInference::MlInference( const std::string& preModelPath, const std::string& samModelPath, int threadsNumber )
: m_valid( true )
{
if( !std::filesystem::exists( preModelPath ) || std::filesystem::exists( samModelPath ) )
const std::string fullModelPath = std::string( SDL_GetBasePath() ) + preModelPath;
const std::string fullSamModelPath = std::string( SDL_GetBasePath() ) + samModelPath;


if( !SDL_GetPathInfo( fullModelPath.c_str(), nullptr ) || !SDL_GetPathInfo( fullSamModelPath.c_str(), nullptr ) )
{
SDL_Log( "lel %s\n %s\n %s", fullModelPath.c_str(), fullSamModelPath.c_str(), SDL_GetError() );
m_valid = false;
return;
}
Expand Down Expand Up @@ -84,14 +90,13 @@ namespace mc
}
}

bool MlInference::pipelineValid() const
MlInference::~MlInference()
{
return m_valid;
}

bool MlInference::loadInput( const uint8_t* buffer, int len, int& width, int& height )
{
if( !m_valid || width > m_onnxData->inputShapePre[3] || height > m_onnxData->inputShapePre[2] )
if( !m_valid || width > getMaxWidth() || height > getMaxHeight() )
{
return false;
}
Expand Down Expand Up @@ -141,10 +146,6 @@ namespace mc
return true;
}

MlInference::~MlInference()
{
}

bool MlInference::getMask( void* imageData, int width, int height, const std::vector<glm::vec2>& points )
{
if( points.size() == 0 )
Expand Down Expand Up @@ -200,4 +201,28 @@ namespace mc
return true;
}

bool MlInference::pipelineValid() const
{
return m_valid;
}

int MlInference::getMaxWidth() const
{
if( m_valid )
{
return m_onnxData->inputShapePre[3];
}

return 0;
}
int MlInference::getMaxHeight() const
{
if( m_valid )
{
return m_onnxData->inputShapePre[2];
}

return 0;
}

} // namespace mc
2 changes: 2 additions & 0 deletions source/ml_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ namespace mc
bool getMask( void* imageData, int width, int height, const std::vector<glm::vec2>& points );

bool pipelineValid() const;
int getMaxWidth() const;
int getMaxHeight() const;

private:
std::unique_ptr<OnnxData> m_onnxData;
Expand Down
89 changes: 80 additions & 9 deletions source/ui.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ namespace mc
ImGuiStyle& style = ImGui::GetStyle();

style.Alpha = 1.0f;
style.DisabledAlpha = 1.0f;
style.DisabledAlpha = 0.5f;
style.WindowPadding = glm::vec2( 12.0f, 12.0f );
style.WindowRounding = 3.0f;
style.WindowBorderSize = 0.0f;
Expand Down Expand Up @@ -750,9 +750,13 @@ namespace mc

if( app->layers.getSingleSelectedImage() >= 0 )
{
std::array<std::string, 3> imageTools = { ICON_LC_CROP, ICON_LC_SQUARE_BOTTOM_DASHED_SCISSORS, ICON_LC_SQUARE_DASHED_MOUSE_POINTER };
std::array<std::string, 3> imageTooltips = { "Crop", "Cut", "TODO: Segment Cut" };
std::array<Mode, 4> imageToolModes = { Mode::Crop, Mode::Cut, Mode::SegmentCut };
std::array<std::string, 3> imageTools = { ICON_LC_CROP, ICON_LC_SQUARE_BOTTOM_DASHED_SCISSORS, ICON_LC_SQUARE_DASHED_MOUSE_POINTER };
#if defined( SDL_PLATFORM_EMSCRIPTEN )
std::array<std::string, 3> imageTooltips = { "Crop", "Cut", "Desktop Only" };
#else
std::array<std::string, 3> imageTooltips = { "Crop", "Cut", "Segment Cut" };
#endif
std::array<Mode, 3> imageToolModes = { Mode::Crop, Mode::Cut, Mode::SegmentCut };

for( size_t i = 0; i < imageTools.size(); i++ )
{
Expand All @@ -762,13 +766,29 @@ namespace mc
color = ImGui::ColorConvertU32ToFloat4( Spectrum::PURPLE400 );
}

#if defined( SDL_PLATFORM_EMSCRIPTEN )
if( i == 2 )
{
ImGui::BeginDisabled();
}
#endif

ImGui::PushStyleColor( ImGuiCol_Button, color );
if( ImGui::Button( imageTools[i].c_str(), buttonSize ) )
{
submitEvent( Events::ChangeMode, { .mode = imageToolModes[i] } );
}
if( ImGui::IsItemHovered( ImGuiHoveredFlags_DelayNormal | ImGuiHoveredFlags_NoSharedDelay | ImGuiHoveredFlags_Stationary ) )
ImGui::SetItemTooltip( imageTooltips[i].c_str() );

#if defined( SDL_PLATFORM_EMSCRIPTEN )
if( i == 2 )
{
ImGui::SetItemTooltip( imageTooltips[i].c_str() );
ImGui::EndDisabled();
}
#endif

ImGui::SameLine( 0.0, buttonSpacing );
ImGui::PopStyleColor( 1 );
}
Expand Down Expand Up @@ -1016,6 +1036,54 @@ namespace mc
}
ImGui::End();
}
else if( app->mode == Mode::SegmentCut )
{
ImGui::SetNextWindowPos( glm::vec2( buttonSpacing, app->height - 180.0 * g_uiScale - buttonSpacing ), ImGuiCond_Appearing );
ImGui::SetNextWindowSize( glm::vec2( 350.0, 180.0 ) * g_uiScale, ImGuiCond_FirstUseEver );

ImGui::Begin( "Cut Image Via Auto Segmentation", nullptr,
ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoScrollbar | ImGuiWindowFlags_NoScrollWithMouse );
{

ImGui::PushItemWidth( ImGui::GetContentRegionAvail().x );
float width = ( ImGui::GetContentRegionAvail().x - 8 ) * 0.5;

if( app->mlInference.get() && app->mlInference->pipelineValid() )
{
ImGui::Text( "Click image to select segmentation regions" );
}
else
{
ImGui::TextColored( ImGui::ColorConvertU32ToFloat4( Spectrum::RED700 ), "Model files not found or are invalid" );
ImGui::BeginDisabled();
}

if( ImGui::Button( "Reset Points", glm::vec2( width, 0.0 ) ) )
{
// todo
}

ImGui::SeparatorText( "" );

if( ImGui::Button( "Apply", glm::vec2( width, 0.0 ) ) )
{
// submitEvent( Events::Cut );
submitEvent( Events::ChangeMode, { .mode = Mode::Cursor } );
}

ImGui::EndDisabled();
if( !app->mlInference.get() && !app->mlInference->pipelineValid() )
{
}

ImGui::SameLine( 0.0, 8.0 );
if( ImGui::Button( "Cancel", glm::vec2( width, 0.0 ) ) )
{
submitEvent( Events::ChangeMode, { .mode = Mode::Cursor } );
}
}
ImGui::End();
}


ImDrawList* drawList = ImGui::GetBackgroundDrawList();
Expand Down Expand Up @@ -1101,7 +1169,7 @@ namespace mc
drawList->AddCircleFilled( g_transformBox.cornerHandleTL, HandleHalfSize * g_uiScale - ceilf( g_uiScale ), color );
}

if( app->mode == Mode::Cut )
if( app->mode == Mode::Cut || app->mode == Mode::SegmentCut )
{
drawShadedRectangleMask(
app->width, app->height, app->selectionAabb * app->viewParams.scale + glm::vec4( app->viewParams.canvasPos, app->viewParams.canvasPos ),
Expand All @@ -1112,10 +1180,13 @@ namespace mc
glm::vec2 uvTop = glm::vec2( app->layers.data()[index].uvTop ) / float( UV_MAX_VALUE );
glm::vec2 uvBottom = glm::vec2( app->layers.data()[index].uvBottom ) / float( UV_MAX_VALUE );

drawList->AddImageQuad( (ImTextureID)(intptr_t)app->textureManager.get( *app->editMaskTextureHandle.get() ).textureView.Get(),
g_transformBox.cornerHandleTL, g_transformBox.cornerHandleTR, g_transformBox.cornerHandleBR, g_transformBox.cornerHandleBL,
uvTop, glm::vec2( uvBottom.x, uvTop.y ), uvBottom, glm::vec2( uvTop.x, uvBottom.y ),
Spectrum::ORANGE600 & 0x00FFFFFF | 0x55000000 );
if( app->editMaskTextureHandle.get() )
{
drawList->AddImageQuad( (ImTextureID)(intptr_t)app->textureManager.get( *app->editMaskTextureHandle.get() ).textureView.Get(),
g_transformBox.cornerHandleTL, g_transformBox.cornerHandleTR, g_transformBox.cornerHandleBR,
g_transformBox.cornerHandleBL, uvTop, glm::vec2( uvBottom.x, uvTop.y ), uvBottom, glm::vec2( uvTop.x, uvBottom.y ),
Spectrum::ORANGE600 & 0x00FFFFFF | 0x55000000 );
}

drawList->AddLine( g_transformBox.cornerHandleTL, g_transformBox.cornerHandleTR, Spectrum::PURPLE400, ceilf( g_uiScale ) );
drawList->AddLine( g_transformBox.cornerHandleTR, g_transformBox.cornerHandleBR, Spectrum::PURPLE400, ceilf( g_uiScale ) );
Expand Down
5 changes: 4 additions & 1 deletion third_party/imgui/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ include(FetchContent)
FetchContent_Declare(
imgui
GIT_REPOSITORY https://github.com/sava41/imgui.git
GIT_TAG 9907cf0f9cf377381b16200d1266467a81a5fbf4
GIT_TAG 2850877545261f070d445085e39721e66de1c732
)
FetchContent_GetProperties(imgui)

Expand All @@ -29,6 +29,9 @@ if(NOT imgui_POPULATED)
target_link_libraries(imgui PUBLIC SDL3::SDL3)

if (NOT CMAKE_SYSTEM_NAME STREQUAL Emscripten)
target_compile_definitions(imgui PRIVATE IMGUI_IMPL_WEBGPU_BACKEND_WGPU)
target_link_libraries(imgui PUBLIC webgpu_dawn)
else()
target_compile_definitions(imgui PRIVATE IMGUI_IMPL_WEBGPU_BACKEND_DAWN)
endif()
endif()
2 changes: 1 addition & 1 deletion third_party/sdl3/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ include(FetchContent)
FetchContent_Declare(
SDL3
GIT_REPOSITORY https://github.com/libsdl-org/SDL.git
GIT_TAG e75175129f83c3d1e85572d03ec070177de8abc4
GIT_TAG preview-3.1.6
)
set(BUILD_SHARED_LIBS FALSE)
FetchContent_MakeAvailable(SDL3)
Expand Down

0 comments on commit 44c082d

Please sign in to comment.