mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2024-11-23 10:14:13 +08:00
teflon: Use correct convolution params struct
Use TfLiteDepthwiseConvParams for kTfLiteBuiltinDepthwiseConv2d. The layout of stride_width, stride_height, and padding struct members happens to be the same, but we shouldn't depend on that. This prepares for using the activation, dilation_width_factor, and dilation_height_factor members, which are at different offsets. Reviewed-by: Tomeu Vizoso <tomeu@tomeuvizoso.net> Signed-off-by: Philipp Zabel <p.zabel@pengutronix.de> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31981>
This commit is contained in:
parent
e3257f7461
commit
319c56b10a
@ -109,14 +109,24 @@ fill_operation(struct teflon_delegate *delegate, TfLiteContext *tf_context, TfLi
|
||||
switch(node_registration->builtin_code) {
|
||||
case kTfLiteBuiltinConv2d:
|
||||
case kTfLiteBuiltinDepthwiseConv2d: {
|
||||
TfLiteConvParams* params = (TfLiteConvParams*)node->builtin_data;
|
||||
operation->type = PIPE_ML_OPERATION_TYPE_CONVOLUTION;
|
||||
operation->conv.weight_tensor = &tensors[node->inputs->data[1]];
|
||||
operation->conv.bias_tensor = &tensors[node->inputs->data[2]];
|
||||
operation->conv.stride_x = params->stride_width;
|
||||
operation->conv.stride_y = params->stride_height;
|
||||
operation->conv.padding_same = params->padding == kTfLitePaddingSame;
|
||||
operation->conv.depthwise = node_registration->builtin_code == kTfLiteBuiltinDepthwiseConv2d;
|
||||
if (node_registration->builtin_code == kTfLiteBuiltinConv2d) {
|
||||
TfLiteConvParams* params = (TfLiteConvParams*)node->builtin_data;
|
||||
|
||||
operation->conv.stride_x = params->stride_width;
|
||||
operation->conv.stride_y = params->stride_height;
|
||||
operation->conv.padding_same = params->padding == kTfLitePaddingSame;
|
||||
operation->conv.depthwise = false;
|
||||
} else {
|
||||
TfLiteDepthwiseConvParams* params = (TfLiteDepthwiseConvParams*)node->builtin_data;
|
||||
|
||||
operation->conv.stride_x = params->stride_width;
|
||||
operation->conv.stride_y = params->stride_height;
|
||||
operation->conv.padding_same = params->padding == kTfLitePaddingSame;
|
||||
operation->conv.depthwise = true;
|
||||
}
|
||||
operation->conv.pointwise = operation->conv.weight_tensor->dims[1] == 1 && \
|
||||
operation->conv.weight_tensor->dims[2] == 1;
|
||||
break;
|
||||
|
Loading…
Reference in New Issue
Block a user