kissinference 1.4.x
train.h
1/* * libkissinference - an inference library for kiss networks
2 * Copyright (C) 2024 Carl Philipp Klemm <carl@uvos.xyz>
3 *
4 * This file is part of libkissinference.
5 *
6 * libkissinference is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Lesser General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * libkissinference is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public License
17 * along with libkissinference. If not, see <http://www.gnu.org/licenses/>.
18 */
19
20#pragma once
21#include <stdbool.h>
22#include <stddef.h>
23
24#ifdef __cplusplus
25extern "C" {
26#endif
27
/**
 * Loss functions that can be selected for a training run.
 *
 * A value of this type is passed as the @p loss argument of
 * kiss_start_training().
 */
typedef enum {
	/**
	 * NOTE(review): the name suggests a negative log-likelihood loss
	 * (usually abbreviated "NLL") - confirm against the implementation.
	 */
	KISS_LOSS_NNL = 0
} kiss_loss_t;
40
46
/* Network description; the full definition lives in kissinference.h. */
struct kiss_network;
/* Library-private training state, kept opaque to users of this header. */
struct kiss_train_priv;
49
54 bool ready;
55 void *user_ptr;
56 struct kiss_network *net;
58 struct kiss_train_priv *priv;
59};
60
62 float *in;
63 size_t in_size;
64 float *out;
65 size_t out_size;
66};
67
/**
 * Callback through which the training code requests training data from the user.
 *
 * @param session  The training session that is requesting data.
 * @param user_ptr The user pointer that was handed to kiss_start_training().
 * @return A kiss_training_data item. The library later passes it to the
 *         kiss_data_free_cb_t callback so that it can be released.
 *         NOTE(review): the meaning of a NULL return is not visible here - confirm.
 */
typedef struct kiss_training_data *(*kiss_data_feed_cb_t)(
	struct kiss_training_session *session,
	void *user_ptr);
76
/**
 * Callback the library uses to free data that was obtained from the user
 * through the kiss_data_feed_cb_t callback.
 *
 * @param data     The data item to release.
 * @param user_ptr The user pointer that was handed to kiss_start_training().
 */
typedef void (*kiss_data_free_cb_t)(
	struct kiss_training_data *data,
	void *user_ptr);
84
91typedef void (*kiss_train_progress_cb_t)(struct kiss_training_session *session, kiss_train_state_t state, float loss, int step, void *usr_ptr);
92
/**
 * Loads a training session archive from disk.
 *
 * @param path       Path of the training session archive.
 * @param batch_size Batch size to use for training.
 * @return A newly created session, to be released with
 *         kiss_free_training_session().
 *         NOTE(review): presumably NULL on failure - confirm.
 */
struct kiss_training_session *kiss_create_training_session(const char *path, int batch_size);
108
/**
 * Loads a training session archive from disk into a session struct that the
 * caller has allocated.
 *
 * @param session    Caller-allocated session struct to initialize.
 * @param path       Path of the training session archive.
 * @param batch_size Batch size to use for training.
 * @return NOTE(review): presumably true on success and false on failure - confirm.
 *         An error description can be obtained with kiss_train_get_strerror().
 *
 * Release the resources with kiss_free_training_session_prealloc(). That
 * call does not free the struct itself.
 */
bool kiss_create_training_session_prealloc(struct kiss_training_session *session,
                                           const char *path, int batch_size);
123
124
131
132
139
150bool kiss_start_training(struct kiss_training_session *session, kiss_loss_t loss, kiss_data_feed_cb_t datafeeder,
151 kiss_data_free_cb_t datafreeer, kiss_train_progress_cb_t progresscb, void *user_ptr);
152
161
162
170
178
184#ifdef __cplusplus
185}
186#endif
bool kiss_create_training_session_prealloc(struct kiss_training_session *session, const char *path, int batch_size)
Loads a training session archive from disk while using a session allocated by the caller.
bool kiss_start_training(struct kiss_training_session *session, kiss_loss_t loss, kiss_data_feed_cb_t datafeeder, kiss_data_free_cb_t datafreeer, kiss_train_progress_cb_t progresscb, void *user_ptr)
Starts training the network.
struct kiss_training_session * kiss_create_training_session(const char *path, int batch_size)
Loads a training session archive from disk.
kiss_train_state_t
Definition train.h:41
const char * kiss_state_to_str(kiss_train_state_t state)
Gets a string describing the state.
void(* kiss_data_free_cb_t)(struct kiss_training_data *data, void *usr_ptr)
A function pointer of this type is used by libkissinference to free data requested by the user of thi...
Definition train.h:83
const char * kiss_train_get_strerror(struct kiss_training_session *session)
Gets an error description string describing the last error that occurred.
void kiss_finish_training(struct kiss_training_session *session)
Finalizes training of the network.
void kiss_free_training_session_prealloc(struct kiss_training_session *session)
Frees the resources associated with a kiss_training_session but not the struct itself.
void kiss_free_training_session(struct kiss_training_session *session)
Frees a kiss_training_session.
void(* kiss_train_progress_cb_t)(struct kiss_training_session *session, kiss_train_state_t state, float loss, int step, void *usr_ptr)
A function pointer of this type is used by libkissinference to inform the user of this library of training progress.
Definition train.h:91
@ KISS_STEP_COMPLETED
Emitted when a training optimizer step completes; not emitted after the final step.
Definition train.h:42
@ KISS_COMPLETED
Emitted when training completes the final optimizer step.
Definition train.h:43
@ KISS_ERROR_STATE
Emitted when an error is encountered.
Definition train.h:44
Struct describing a kiss neural network.
Definition kissinference.h:53
A function pointer of this type is used by libkissinference-train to request data from the user.
Definition train.h:61
size_t in_size
Length of the input array.
Definition train.h:63
float * in
Network input.
Definition train.h:62
float * out
Target output for loss computation.
Definition train.h:64
size_t out_size
Length of the target output array.
Definition train.h:65
Struct describing a training session.
Definition train.h:53
bool ready
Set to true if the session is ready to train.
Definition train.h:54
struct kiss_network * net
A network struct containing network metadata, note this struct is not complete and can not be used fo...
Definition train.h:56
void * user_ptr
Contains user pointer that is passed to the callbacks.
Definition train.h:55