Espruino/libs/tensorflow/patches/02-quantized.diff
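Adds int8 quantized-model support to the TensorFlow Lite Micro sources bundled with Espruino: the all-ops resolver registers MAX_POOL_2D up to version 2 and CONV_2D up to version 3 (the int8-capable kernel versions), pooling.cc gains an int8 max-pool eval path alongside the existing uint8 one, and softmax's unsupported-type error message is updated accordingly.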
--- a/libs/tensorflow/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
+++ b/libs/tensorflow/tensorflow/lite/experimental/micro/kernels/all_ops_resolver.cc
@@ -62,11 +62,11 @@ AllOpsResolver::AllOpsResolver() {
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
/* min_version */ 1,
/* max_version */ 4);
- AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
+ AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D(), /* min_version */ 1, /* max_version */ 2);
AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX());
AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC());
AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
- AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
+ AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), /* min_version */ 1, /* max_version */ 3);
AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
AddBuiltin(BuiltinOperator_ABS, Register_ABS());
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
--- a/libs/tensorflow/tensorflow/lite/experimental/micro/kernels/pooling.cc
+++ b/libs/tensorflow/tensorflow/lite/experimental/micro/kernels/pooling.cc
@@ -158,6 +158,31 @@ void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node,
GetTensorData<uint8_t>(output));
}
 
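+// Int8 counterpart of MaxEvalQuantizedUInt8 above: runs max pooling over
+// int8-quantized tensors via the reference integer-ops kernel.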
+void MaxEvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
+ TfLitePoolParams* params, OpData* data,
+ const TfLiteTensor* input, TfLiteTensor* output) {
+ int32_t activation_min, activation_max;
+ CalculateActivationRangeInt8(params->activation, output, &activation_min,
+ &activation_max);
+
+ tflite::PoolParams op_params;
+ op_params.stride_height = params->stride_height;
+ op_params.stride_width = params->stride_width;
+ op_params.filter_height = params->filter_height;
+ op_params.filter_width = params->filter_width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.padding_values.width = data->padding.width;
+ op_params.quantized_activation_min = activation_min;
+ op_params.quantized_activation_max = activation_max;
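+ // Max pooling only selects among existing input values, so no
+ // requantization is needed; outputs are just clamped to the activation range.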
+ reference_integer_ops::MaxPool(op_params, GetTensorShape(input),
+ GetTensorData<int8_t>(input), GetTensorShape(output),
+ GetTensorData<int8_t>(output));
+}
+
} // namespace
 
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -214,6 +239,9 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8:
MaxEvalQuantizedUInt8(context, node, params, &data, input, output);
break;
+ case kTfLiteInt8:
+ MaxEvalQuantizedInt8(context, node, params, &data, input, output);
+ break;
default:
context->ReportError(context, "Type %s not currently supported.",
TfLiteTypeGetName(input->type));
--- a/libs/tensorflow/tensorflow/lite/experimental/micro/kernels/softmax.cc
+++ b/libs/tensorflow/tensorflow/lite/experimental/micro/kernels/softmax.cc
@@ -194,8 +194,8 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}
default:
context->ReportError(
- context, "Only float32 and uint8_t supported currently, got %d.",
- input->type);
+ context, "Only float32, int8_t and uint8_t supported currently, got %s.",
+ TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}